In [2]:
import sys
if './' not in sys.path:
	sys.path.append('./')
	
from omegaconf import OmegaConf
import argparse

from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

from ldm.util import instantiate_from_config
from models.util import load_state_dict
from models.logger import ImageLogger
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
config_path = './configs/local_v15.yaml'
learning_rate = 1e-5
batch_size = 4
training_steps = 1e5
resume_path = './ckpt/init_local.ckpt'
default_logdir = './log_local/'
logger_freq = 500
sd_locked = True
num_workers = 4
gpus = -1

In [4]:
config = OmegaConf.load(config_path)
dataset = instantiate_from_config(config['data'])
dataloader = DataLoader(dataset, num_workers=num_workers, batch_size=batch_size, pin_memory=True, shuffle=True)

In [5]:
dataset[1]['jpg'].shape
dataset[1]['local_conditions'].shape

(512, 512, 21)

In [6]:
model = instantiate_from_config(config['model'])
model.load_state_dict(load_state_dict(resume_path, location='cpu'))
model.learning_rate = learning_rate
model.sd_locked = sd_locked

INFO:torch.distributed.nn.jit.instantiator:Created a temporary directory at /tmp/tmp73ckbabe
INFO:torch.distributed.nn.jit.instantiator:Writing /tmp/tmp73ckbabe/_remote_module_non_sriptable.py


No module 'xformers'. Proceeding without it.
UniControlNet: Running in eps-prediction mode
DiffusionWrapper has 859.52 M params.
making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.
making attention of type 'vanilla' with 512 in_channels


Some weights of the model checkpoint at openai/clip-vit-large-patch14 were not used when initializing CLIPTextModel: ['vision_model.encoder.layers.22.mlp.fc2.bias', 'vision_model.encoder.layers.14.self_attn.q_proj.bias', 'vision_model.encoder.layers.0.mlp.fc2.weight', 'vision_model.encoder.layers.12.self_attn.v_proj.bias', 'vision_model.encoder.layers.18.layer_norm2.bias', 'vision_model.encoder.layers.4.layer_norm1.weight', 'vision_model.encoder.layers.10.self_attn.k_proj.bias', 'vision_model.encoder.layers.7.layer_norm1.bias', 'vision_model.encoder.layers.20.self_attn.v_proj.weight', 'vision_model.encoder.layers.15.layer_norm2.bias', 'vision_model.encoder.layers.2.self_attn.out_proj.bias', 'vision_model.encoder.layers.5.mlp.fc1.weight', 'vision_model.encoder.layers.23.mlp.fc2.weight', 'vision_model.encoder.layers.9.mlp.fc2.weight', 'vision_model.encoder.layers.16.layer_norm2.weight', 'vision_model.encoder.layers.17.layer_norm2.weight', 'vision_model.encoder.layers.4.layer_norm2.bias',

Loaded state_dict from [./ckpt/init_local.ckpt]


In [7]:
for batch in dataloader:
    print(batch['local_conditions'][0].shape)

torch.Size([512, 512, 21])
torch.Size([512, 512, 21])
torch.Size([512, 512, 21])
torch.Size([512, 512, 21])


In [8]:
logger = ImageLogger(batch_frequency=logger_freq)
checkpoint_callback = ModelCheckpoint(every_n_train_steps=logger_freq,)

In [9]:
trainer = pl.Trainer(
        gpus=gpus,
        callbacks=[logger, checkpoint_callback], 
        default_root_dir=default_logdir,
        max_steps=training_steps,
        strategy='dp'
    )

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [10]:
trainer.fit(model,dataloader)

  rank_zero_warn("You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.")
  rank_zero_deprecation(
  rank_zero_deprecation(
Missing logger folder: ./log_local/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name              | Type               | Params
---------------------------------------------------------
0 | model             | DiffusionWrapper   | 859 M 
1 | first_stage_model | AutoencoderKL      | 83.7 M
2 | cond_stage_model  | FrozenCLIPEmbedder | 123 M 
3 | local_adapter     | LocalAdapter       | 411 M 
---------------------------------------------------------
1.3 B     Trainable params
206 M     Non-trainable params
1.5 B     Total params
5,912.555 Total estimated model params size (MB)
  rank_zero_warn(


Epoch 0:   0%|                                                                        | 0/4 [00:00<?, ?it/s]

RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/data/maryam.sana/anaconda3/envs/unicontrol/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/data/maryam.sana/anaconda3/envs/unicontrol/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/data/maryam.sana/anaconda3/envs/unicontrol/lib/python3.8/site-packages/pytorch_lightning/overrides/data_parallel.py", line 64, in forward
    output = super().forward(*inputs, **kwargs)
  File "/data/maryam.sana/anaconda3/envs/unicontrol/lib/python3.8/site-packages/pytorch_lightning/overrides/base.py", line 82, in forward
    output = self.module.training_step(*inputs, **kwargs)
  File "/data/maryam.sana/Uni-ControlNet/ldm/models/diffusion/ddpm.py", line 442, in training_step
    loss, loss_dict = self.shared_step(batch)
  File "/data/maryam.sana/Uni-ControlNet/ldm/models/diffusion/ddpm.py", line 836, in shared_step
    loss = self(x, c)
  File "/data/maryam.sana/anaconda3/envs/unicontrol/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/data/maryam.sana/Uni-ControlNet/ldm/models/diffusion/ddpm.py", line 848, in forward
    return self.p_losses(x, c, t, *args, **kwargs)
  File "/data/maryam.sana/Uni-ControlNet/ldm/models/diffusion/ddpm.py", line 888, in p_losses
    model_output = self.apply_model(x_noisy, t, cond)
  File "/data/maryam.sana/Uni-ControlNet/models/uni_controlnet.py", line 59, in apply_model
    local_control = self.local_adapter(x=x_noisy, timesteps=t, context=cond_txt, local_conditions=local_control)
  File "/data/maryam.sana/anaconda3/envs/unicontrol/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/data/maryam.sana/Uni-ControlNet/models/local_adapter.py", line 407, in forward
    h = module(h, emb, context, local_features[self.inject_layers.index(layer_idx)])
  File "/data/maryam.sana/anaconda3/envs/unicontrol/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/data/maryam.sana/Uni-ControlNet/models/local_adapter.py", line 23, in forward
    x = layer(x, context)
  File "/data/maryam.sana/anaconda3/envs/unicontrol/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/data/maryam.sana/Uni-ControlNet/ldm/modules/attention.py", line 334, in forward
    x = block(x, context=context[i])
  File "/data/maryam.sana/anaconda3/envs/unicontrol/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/data/maryam.sana/Uni-ControlNet/ldm/modules/attention.py", line 269, in forward
    return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
  File "/data/maryam.sana/Uni-ControlNet/ldm/modules/diffusionmodules/util.py", line 114, in checkpoint
    return CheckpointFunction.apply(func, len(inputs), *args)
  File "/data/maryam.sana/Uni-ControlNet/ldm/modules/diffusionmodules/util.py", line 129, in forward
    output_tensors = ctx.run_function(*ctx.input_tensors)
  File "/data/maryam.sana/Uni-ControlNet/ldm/modules/attention.py", line 273, in _forward
    x = self.attn2(self.norm2(x), context=context) + x
  File "/data/maryam.sana/anaconda3/envs/unicontrol/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/data/maryam.sana/Uni-ControlNet/ldm/modules/attention.py", line 177, in forward
    sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
  File "/data/maryam.sana/anaconda3/envs/unicontrol/lib/python3.8/site-packages/torch/functional.py", line 330, in einsum
    return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
RuntimeError: einsum(): operands do not broadcast with remapped shapes [original->remapped]: [16, 4096, 40]->[16, 4096, 1, 40] [32, 77, 40]->[32, 1, 77, 40]
