In [None]:
import json
import cv2
import numpy as np

from torch.utils.data import Dataset

In [None]:
class MyDataset(Dataset):
    """Paired image/prompt dataset read from a JSON-lines index file.

    Each non-blank line of the index is a JSON object with the keys
    ``'source'`` (path of the conditioning image, returned as ``hint``),
    ``'target'`` (path of the image to generate, returned as ``jpg``) and
    ``'prompt'`` (returned unchanged as ``txt``).

    Both images are read with OpenCV, resized to ``dim`` with bicubic
    interpolation, converted from BGR to RGB and cast to ``float32``:
    ``hint`` is scaled to [0, 1] and ``jpg`` to [-1, 1].
    """

    def __init__(self, file_path, dim=(512, 512)):
        """Load the JSON-lines index.

        Args:
            file_path: Path to a text file with one JSON record per line.
            dim: Output size passed to ``cv2.resize`` as ``dsize``, i.e.
                ``(width, height)``. Defaults to ``(512, 512)``.

        Raises:
            json.JSONDecodeError: If a non-blank line is not valid JSON.
        """
        self.dim = tuple(dim)
        with open(file_path, 'r', encoding='utf-8') as f:
            # Skip blank / whitespace-only lines (e.g. stray empty lines at the
            # end of the file) instead of crashing inside json.loads.
            self.data = [json.loads(line) for line in f if line.strip()]

    def __len__(self):
        return len(self.data)

    def _load_rgb(self, filename):
        """Read ``filename``, resize it to ``self.dim`` and return it as RGB.

        Raises:
            FileNotFoundError: If OpenCV cannot read the file. ``cv2.imread``
                signals failure by returning ``None`` instead of raising, which
                would otherwise surface as an obscure error in ``cv2.resize``.
        """
        image = cv2.imread(filename)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {filename!r}")
        image = cv2.resize(image, self.dim, interpolation=cv2.INTER_CUBIC)
        # OpenCV reads images in BGR order; convert to RGB.
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def __getitem__(self, idx):
        """Return sample ``idx`` as ``dict(jpg=target, txt=prompt, hint=source)``.

        ``jpg`` is the target image as float32 RGB in [-1, 1], ``hint`` is the
        source image as float32 RGB in [0, 1], and ``txt`` is the record's
        ``'prompt'`` value.

        Raises:
            KeyError: If the record lacks ``'source'``, ``'target'`` or ``'prompt'``.
            FileNotFoundError: If either image cannot be read.
        """
        item = self.data[idx]

        source_filename = item['source']
        target_filename = item['target']
        prompt = item['prompt']

        source = self._load_rgb(source_filename)
        target = self._load_rgb(target_filename)

        # Normalize source (hint) images to [0, 1].
        source = source.astype(np.float32) / 255.0

        # Normalize target images to [-1, 1].
        target = (target.astype(np.float32) / 127.5) - 1.0

        return dict(jpg=target, txt=prompt, hint=source)

In [None]:
# Build the training dataset from the JSON-lines index at this path.
dataset = MyDataset(file_path="/home/jupyter/gcs/train.txt")

In [None]:
# Inspect one sample without dumping two full image arrays into the notebook:
# show each array's shape, dtype and value range, and non-array values as-is.
sample = dataset[0]
{
    key: (value.shape, value.dtype, float(value.min()), float(value.max()))
    if isinstance(value, np.ndarray)
    else value
    for key, value in sample.items()
}




{'jpg': array([[[-1.        , -1.        , -1.        ],
         [-1.        , -1.        , -1.        ],
         [-1.        , -1.        , -1.        ],
         ...,
         [-1.        , -1.        , -1.        ],
         [-1.        , -1.        , -1.        ],
         [-1.        , -1.        , -1.        ]],
 
        [[-1.        , -1.        , -1.        ],
         [-1.        , -1.        , -1.        ],
         [-1.        , -1.        , -1.        ],
         ...,
         [-0.8117647 , -0.8117647 , -0.8117647 ],
         [-0.81960785, -0.81960785, -0.81960785],
         [-0.81960785, -0.81960785, -0.81960785]],
 
        [[-1.        , -1.        , -1.        ],
         [-1.        , -1.        , -1.        ],
         [-1.        , -1.        , -1.        ],
         ...,
         [-0.8039216 , -0.8039216 , -0.8039216 ],
         [-0.8117647 , -0.8117647 , -0.8117647 ],
         [-0.79607844, -0.79607844, -0.79607844]],
 
        ...,
 
        [[-1.        , -1. 

In [None]:
# Sanity-check one sample: print the prompt and image shapes, then assert the
# invariants MyDataset.__getitem__ is meant to guarantee.
item = dataset[0]
jpg = item['jpg']
txt = item['txt']
hint = item['hint']
print(txt)
print(jpg.shape)
print(hint.shape)

assert jpg.shape == hint.shape, "target (jpg) and source (hint) must have the same shape"
assert jpg.dtype == np.float32 and hint.dtype == np.float32, "images must be float32"
assert -1.0 <= jpg.min() and jpg.max() <= 1.0, "jpg must be normalized to [-1, 1]"
assert 0.0 <= hint.min() and hint.max() <= 1.0, "hint must be normalized to [0, 1]"


a chest xray with No Finding
(512, 512, 3)
(512, 512, 3)


In [None]:
import sys
sys.path.append("ControlNet/")
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from ControlNet.cldm.logger import ImageLogger
from ControlNet.cldm.model import create_model, load_state_dict

In [None]:
# Configs
resume_path = '/home/jupyter/gcs/checkpoints/control_sd21_ini.ckpt'
batch_size = 10
logger_freq = 300  # passed to ImageLogger as batch_frequency
learning_rate = 1e-5
sd_locked = True
only_mid_control = False
seed = 42  # RNG seed so shuffling / sampling noise is repeatable across runs

# Seed the Python, NumPy and torch RNGs before building the model and the
# shuffling DataLoader; without this every run sees a different data order.
pl.seed_everything(seed)


# First use cpu to load models. Pytorch Lightning will automatically move it to GPUs.
model = create_model('ControlNet/models/cldm_v21.yaml').cpu()
model.load_state_dict(load_state_dict(resume_path, location='cpu'))
model.learning_rate = learning_rate
model.sd_locked = sd_locked
model.only_mid_control = only_mid_control


# Misc
dataset = MyDataset("/home/jupyter/gcs/train.txt")
dataloader = DataLoader(dataset, num_workers=0, batch_size=batch_size, shuffle=True)
logger = ImageLogger(batch_frequency=logger_freq)
trainer = pl.Trainer(gpus=1, precision=32, callbacks=[logger])


# Train!
trainer.fit(model, dataloader)

ControlLDM: Running in eps-prediction mode
DiffusionWrapper has 865.91 M params.
making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.
making attention of type 'vanilla' with 512 in_channels
Loaded model config from [ControlNet/models/cldm_v21.yaml]
Loaded state_dict from [/home/jupyter/gcs/checkpoints/control_sd21_ini.ckpt]


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
  rank_zero_warn("You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.")
  rank_zero_deprecation(
  rank_zero_deprecation(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name              | Type                   | Params
-------------------------------------------------------------
0 | model             | DiffusionWrapper       | 865 M 
1 | first_stage_model | AutoencoderKL          | 83.7 M
2 | cond_stage_model  | FrozenOpenCLIPEmbedder | 354 M 
3 | control_model     | ControlNet             | 364 M 
-------------------------------------------------------------
1.2 B     Trainable params
437 M     Non-trainable params
1.7 B     Total params
6,671.302 Total estimated model params size (MB)
  rank_zero_warn(


Epoch 0:   0%|          | 0/20881 [00:00<?, ?it/s] 











Data shape for DDIM sampling is (4, 4, 64, 64), eta 0.0
Running DDIM Sampling with 50 timesteps



DDIM Sampler:   0%|          | 0/50 [00:00<?, ?it/s][A
DDIM Sampler:   2%|▏         | 1/50 [00:00<00:22,  2.20it/s][A
DDIM Sampler:   4%|▍         | 2/50 [00:00<00:21,  2.20it/s][A
DDIM Sampler:   6%|▌         | 3/50 [00:01<00:21,  2.20it/s][A
DDIM Sampler:   8%|▊         | 4/50 [00:01<00:20,  2.20it/s][A
DDIM Sampler:  10%|█         | 5/50 [00:02<00:20,  2.20it/s][A
DDIM Sampler:  12%|█▏        | 6/50 [00:02<00:19,  2.20it/s][A
DDIM Sampler:  14%|█▍        | 7/50 [00:03<00:19,  2.21it/s][A
DDIM Sampler:  16%|█▌        | 8/50 [00:03<00:19,  2.21it/s][A
DDIM Sampler:  18%|█▊        | 9/50 [00:04<00:18,  2.21it/s][A
DDIM Sampler:  20%|██        | 10/50 [00:04<00:18,  2.21it/s][A
DDIM Sampler:  22%|██▏       | 11/50 [00:04<00:17,  2.21it/s][A
DDIM Sampler:  24%|██▍       | 12/50 [00:05<00:17,  2.21it/s][A
DDIM Sampler:  26%|██▌       | 13/50 [00:05<00:16,  2.21it/s][A
DDIM Sampler:  28%|██▊       | 14/50 [00:06<00:16,  2.21it/s][A
DDIM Sampler:  30%|███       | 15/50 [00:0

Epoch 0:   0%|          | 1/20881 [00:49<288:28:41, 49.74s/it, loss=0.161, v_num=0, train/loss_simple_step=0.161, train/loss_vlb_step=0.00183, train/loss_step=0.161, global_step=0.000]



Epoch 0:   0%|          | 2/20881 [00:54<157:34:04, 27.17s/it, loss=0.148, v_num=0, train/loss_simple_step=0.135, train/loss_vlb_step=0.0017, train/loss_step=0.135, global_step=1.000] 



Epoch 0:   0%|          | 3/20881 [00:59<115:45:30, 19.96s/it, loss=0.134, v_num=0, train/loss_simple_step=0.107, train/loss_vlb_step=0.000606, train/loss_step=0.107, global_step=2.000]



Epoch 0:   0%|          | 4/20881 [01:05<94:17:37, 16.26s/it, loss=0.116, v_num=0, train/loss_simple_step=0.0624, train/loss_vlb_step=0.00025, train/loss_step=0.0624, global_step=3.000]



Epoch 0:   0%|          | 5/20881 [01:10<81:58:12, 14.14s/it, loss=0.0967, v_num=0, train/loss_simple_step=0.0185, train/loss_vlb_step=7.39e-5, train/loss_step=0.0185, global_step=4.000]



Epoch 0:   0%|          | 6/20881 [01:16<74:01:32, 12.77


DDIM Sampler:   0%|          | 0/50 [00:00<?, ?it/s][A
DDIM Sampler:   2%|▏         | 1/50 [00:00<00:22,  2.21it/s][A
DDIM Sampler:   4%|▍         | 2/50 [00:00<00:21,  2.21it/s][A
DDIM Sampler:   6%|▌         | 3/50 [00:01<00:21,  2.21it/s][A
DDIM Sampler:   8%|▊         | 4/50 [00:01<00:20,  2.21it/s][A
DDIM Sampler:  10%|█         | 5/50 [00:02<00:20,  2.21it/s][A
DDIM Sampler:  12%|█▏        | 6/50 [00:02<00:19,  2.21it/s][A
DDIM Sampler:  14%|█▍        | 7/50 [00:03<00:19,  2.21it/s][A
DDIM Sampler:  16%|█▌        | 8/50 [00:03<00:19,  2.21it/s][A
DDIM Sampler:  18%|█▊        | 9/50 [00:04<00:18,  2.21it/s][A
DDIM Sampler:  20%|██        | 10/50 [00:04<00:18,  2.21it/s][A
DDIM Sampler:  22%|██▏       | 11/50 [00:04<00:17,  2.21it/s][A
DDIM Sampler:  24%|██▍       | 12/50 [00:05<00:17,  2.21it/s][A
DDIM Sampler:  26%|██▌       | 13/50 [00:05<00:16,  2.21it/s][A
DDIM Sampler:  28%|██▊       | 14/50 [00:06<00:16,  2.21it/s][A
DDIM Sampler:  30%|███       | 15/50 [00:0

Epoch 0:   1%|▏         | 301/20881 [27:45<31:37:30,  5.53s/it, loss=0.113, v_num=0, train/loss_simple_step=0.0832, train/loss_vlb_step=0.000466, train/loss_step=0.0832, global_step=300.0]



Epoch 0:   1%|▏         | 302/20881 [27:48<31:35:22,  5.53s/it, loss=0.111, v_num=0, train/loss_simple_step=0.0803, train/loss_vlb_step=0.000326, train/loss_step=0.0803, global_step=301.0]



Epoch 0:   1%|▏         | 303/20881 [27:53<31:34:29,  5.52s/it, loss=0.114, v_num=0, train/loss_simple_step=0.0735, train/loss_vlb_step=0.000259, train/loss_step=0.0735, global_step=302.0]



Epoch 0:   1%|▏         | 304/20881 [27:59<31:34:29,  5.52s/it, loss=0.111, v_num=0, train/loss_simple_step=0.0923, train/loss_vlb_step=0.000327, train/loss_step=0.0923, global_step=303.0]



Epoch 0:   1%|▏         | 305/20881 [28:03<31:33:02,  5.52s/it, loss=0.109, v_num=0, train/loss_simple_step=0.0648, train/loss_vlb_step=0.000221, train/loss_step=0.0648, global_step=304.0]



Epoch 0:   1%|▏         | 306/20881 [28:


DDIM Sampler:   0%|          | 0/50 [00:00<?, ?it/s][A
DDIM Sampler:   2%|▏         | 1/50 [00:00<00:22,  2.21it/s][A
DDIM Sampler:   4%|▍         | 2/50 [00:00<00:21,  2.21it/s][A
DDIM Sampler:   6%|▌         | 3/50 [00:01<00:21,  2.21it/s][A
DDIM Sampler:   8%|▊         | 4/50 [00:01<00:20,  2.21it/s][A
DDIM Sampler:  10%|█         | 5/50 [00:02<00:20,  2.21it/s][A
DDIM Sampler:  12%|█▏        | 6/50 [00:02<00:19,  2.21it/s][A
DDIM Sampler:  14%|█▍        | 7/50 [00:03<00:19,  2.21it/s][A
DDIM Sampler:  16%|█▌        | 8/50 [00:03<00:19,  2.21it/s][A
DDIM Sampler:  18%|█▊        | 9/50 [00:04<00:18,  2.21it/s][A
DDIM Sampler:  20%|██        | 10/50 [00:04<00:18,  2.21it/s][A
DDIM Sampler:  22%|██▏       | 11/50 [00:04<00:17,  2.21it/s][A
DDIM Sampler:  24%|██▍       | 12/50 [00:05<00:17,  2.21it/s][A
DDIM Sampler:  26%|██▌       | 13/50 [00:05<00:16,  2.21it/s][A
DDIM Sampler:  28%|██▊       | 14/50 [00:06<00:16,  2.21it/s][A
DDIM Sampler:  30%|███       | 15/50 [00:0

Epoch 0:   3%|▎         | 601/20881 [54:52<30:51:25,  5.48s/it, loss=0.125, v_num=0, train/loss_simple_step=0.283, train/loss_vlb_step=0.010, train/loss_step=0.283, global_step=600.0]     



Epoch 0:   3%|▎         | 602/20881 [54:55<30:50:23,  5.47s/it, loss=0.129, v_num=0, train/loss_simple_step=0.220, train/loss_vlb_step=0.00146, train/loss_step=0.220, global_step=601.0]



Epoch 0:   3%|▎         | 603/20881 [55:01<30:50:09,  5.47s/it, loss=0.133, v_num=0, train/loss_simple_step=0.189, train/loss_vlb_step=0.00216, train/loss_step=0.189, global_step=602.0]



Epoch 0:   3%|▎         | 604/20881 [55:06<30:49:54,  5.47s/it, loss=0.135, v_num=0, train/loss_simple_step=0.105, train/loss_vlb_step=0.000408, train/loss_step=0.105, global_step=603.0]



Epoch 0:   3%|▎         | 605/20881 [55:11<30:49:42,  5.47s/it, loss=0.136, v_num=0, train/loss_simple_step=0.0873, train/loss_vlb_step=0.000426, train/loss_step=0.0873, global_step=604.0]



Epoch 0:   3%|▎         | 606/20881 [55:16<30:49