<a href="https://colab.research.google.com/github/Nekoiii/ML_Practices_colab/blob/main/Train_a_ControlNet_to_Control_SD.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Reference: ControlNet training guide — https://github.com/lllyasviel/ControlNet/blob/main/docs/train.md

In [1]:
from google.colab import drive
# Mount Google Drive at /content/drive; the dataset and model paths used by
# later cells all live under /content/drive/MyDrive.
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
%%script false --no-raise-error
# The %%script magic above hands this cell body to the shell command `false`
# instead of executing it, so the unzip below is disabled. Remove the first
# line to actually extract the archive.
!unzip '/content/drive/MyDrive/datasets/imgs/fill50k.zip' -d '/content/drive/MyDrive/datasets/imgs/'

In [5]:
import json
import cv2
import numpy as np

from torch.utils.data import Dataset

# Root folder of the fill50k dataset on the mounted Drive. The trailing slash
# matters: file paths are built from it by plain string concatenation.
dataset_folder='/content/drive/MyDrive/datasets/imgs/fill50k/'
# Prompt index read by MyDataset, one JSON object per line.
prompt_path=dataset_folder+'prompt.json'

class MyDataset(Dataset):
    """Dataset of (target image, prompt, hint image) samples listed in ``prompt_path``.

    Every non-blank line of the prompt file is parsed as a JSON object with the
    keys ``source``, ``target`` and ``prompt``. The two filenames are resolved
    by appending them to the module-level ``dataset_folder``.
    """

    def __init__(self):
        self.data = []
        with open(prompt_path, 'rt') as f:
            for line in f:
                # Skip blank lines (e.g. a trailing empty line at the end of
                # the file) instead of letting json.loads fail on them.
                if not line.strip():
                    continue
                self.data.append(json.loads(line))

    def __len__(self):
        """Return the number of records loaded from the prompt file."""
        return len(self.data)

    @staticmethod
    def _read_rgb(path):
        """Read the image at ``path`` and return it in RGB channel order.

        Raises:
            FileNotFoundError: If OpenCV could not read the file (cv2.imread
                signals this by returning None rather than raising).
        """
        image = cv2.imread(path)
        if image is None:
            raise FileNotFoundError(f'Could not read image: {path}')
        # Do not forget that OpenCV read images in BGR order.
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def __getitem__(self, idx):
        """Return sample ``idx`` as ``dict(jpg=target, txt=prompt, hint=source)``.

        ``hint`` is the source image as float32 scaled by 1/255, ``jpg`` is the
        target image as float32 scaled to value/127.5 - 1, and ``txt`` is the
        prompt taken from the record unchanged.
        """
        item = self.data[idx]

        source_filename = item['source']
        target_filename = item['target']
        prompt = item['prompt']

        source = self._read_rgb(dataset_folder + source_filename)
        target = self._read_rgb(dataset_folder + target_filename)

        # Normalize source images to [0, 1].
        source = source.astype(np.float32) / 255.0

        # Normalize target images to [-1, 1].
        target = (target.astype(np.float32) / 127.5) - 1.0

        return dict(jpg=target, txt=prompt, hint=source)


In [6]:
# Sanity check: build the dataset, report its size, then inspect one sample.
dataset = MyDataset()
print(len(dataset))

item = dataset[1234]
jpg, txt, hint = item['jpg'], item['txt'], item['hint']
# Prompt text followed by the shapes of the target and hint arrays, one per line.
print(txt, jpg.shape, hint.shape, sep='\n')


50000
burly wood circle with orange background
(512, 512, 3)
(512, 512, 3)


In [8]:
%%writefile config.py
# Read by share.py: when True, share.py calls enable_sliced_attention().
save_memory = False

Writing config.py


In [None]:
%%writefile cldm/hack.py

In [7]:
%%writefile share.py

# Shared setup module: its top-level statements run when a cell first executes
# `from share import *`.
import config
from cldm.hack import disable_verbosity, enable_sliced_attention


disable_verbosity()

# Opt-in switch; the flag is defined in config.py.
if config.save_memory:
    enable_sliced_attention()


Writing share.py


In [None]:
import sys
import os
import torch
from share import *
from cldm.model import create_model

# Model definition (yaml) passed to create_model() inside tool_add_control_sd21.
config_path='/content/drive/MyDrive/StableDifussion/models/cldm_v21.yaml'
# Pretrained checkpoint to convert. The string must not carry trailing
# whitespace: tool_add_control_sd21 checks os.path.exists() on it verbatim.
input_path='/content/drive/MyDrive/StableDifussion/models/v2-1_512-nonema-pruned.ckpt'
# Destination of the converted checkpoint; must not exist yet when converting.
output_path='/content/drive/MyDrive/my_models/controlnet_sd/control_sd21_ini.ckpt'

def get_node_name(name, parent_name):
    """Strip ``parent_name`` from the front of ``name`` when it is a proper prefix.

    Returns ``(True, remainder)`` if ``name`` begins with ``parent_name`` and is
    strictly longer than it; otherwise returns ``(False, '')``.
    """
    prefix_len = len(parent_name)
    is_proper_prefix = len(name) > prefix_len and name[:prefix_len] == parent_name
    return (True, name[prefix_len:]) if is_proper_prefix else (False, '')


def tool_add_control_sd21(input_path,output_path):
    """Build an initial ControlNet checkpoint from a pretrained checkpoint.

    Creates the model described by the module-level ``config_path`` and fills
    its state dict key by key: each key is looked up in the pretrained weights
    (keys starting with ``control_`` are looked up as
    ``model.diffusion_<rest of key>``). Keys with no pretrained counterpart keep
    the values of the newly created model and are reported on stdout. The
    resulting state dict is saved to ``output_path``.

    Args:
        input_path: Existing checkpoint file to read.
        output_path: File to create. It must not exist yet and its directory
            must already exist.

    Raises:
        AssertionError: If one of the path preconditions above does not hold.
    """
    assert os.path.exists(input_path), 'Input model does not exist.'
    assert not os.path.exists(output_path), 'Output filename already exists.'
    # Resolve to an absolute path first: for a bare filename os.path.dirname()
    # returns '', which would wrongly fail this check for the current directory.
    output_dir = os.path.dirname(os.path.abspath(output_path))
    assert os.path.exists(output_dir), 'Output path is not valid.'

    model = create_model(config_path = config_path)

    # torch.load unpickles the file: only load checkpoints from a trusted source.
    # map_location='cpu' lets a checkpoint that was saved from GPU tensors load
    # on a runtime without CUDA.
    pretrained_weights = torch.load(input_path, map_location='cpu')
    if 'state_dict' in pretrained_weights:
        pretrained_weights = pretrained_weights['state_dict']

    scratch_dict = model.state_dict()

    target_dict = {}
    for k in scratch_dict.keys():
        is_control, name = get_node_name(k, 'control_')
        # Control-branch keys are looked up under the 'model.diffusion_' prefix.
        if is_control:
            copy_k = 'model.diffusion_' + name
        else:
            copy_k = k
        if copy_k in pretrained_weights:
            target_dict[k] = pretrained_weights[copy_k].clone()
        else:
            target_dict[k] = scratch_dict[k].clone()
            print(f'These weights are newly added: {k}')

    model.load_state_dict(target_dict, strict=True)
    torch.save(model.state_dict(), output_path)
    print('Done.')

In [None]:
# Training cell: loads the ControlNet-initialised checkpoint and fits it on the
# dataset with PyTorch Lightning.
from share import *

import pytorch_lightning as pl
from torch.utils.data import DataLoader
# NOTE(review): this import rebinds MyDataset, replacing the class defined
# earlier in this notebook (the one reading from dataset_folder on Drive).
# Confirm which implementation is meant to be trained on.
from tutorial_dataset import MyDataset
from cldm.logger import ImageLogger
from cldm.model import create_model, load_state_dict


# Configs
# NOTE(review): resume_path and the yaml path passed to create_model() below are
# relative to the working directory, while the conversion cell uses
# config_path / output_path on Drive. Confirm they refer to the same files.
resume_path = './models/control_sd21_ini.ckpt'
batch_size = 4
logger_freq = 300
learning_rate = 1e-5
sd_locked = True
only_mid_control = False


# First use cpu to load models. Pytorch Lightning will automatically move it to GPUs.
model = create_model('./models/cldm_v21.yaml').cpu()
model.load_state_dict(load_state_dict(resume_path, location='cpu'))
# Copy the training options from the config section onto the model object.
model.learning_rate = learning_rate
model.sd_locked = sd_locked
model.only_mid_control = only_mid_control


# Misc
dataset = MyDataset()
dataloader = DataLoader(dataset, num_workers=0, batch_size=batch_size, shuffle=True)
# Image logging callback, configured with batch_frequency=logger_freq.
logger = ImageLogger(batch_frequency=logger_freq)
trainer = pl.Trainer(gpus=1, precision=32, callbacks=[logger])


# Train!
trainer.fit(model, dataloader)