# ControlNet Trainer
ControlNet (https://github.com/lllyasviel/ControlNet)<br>
Trainer by Dion Timmer<br>
https://github.com/diontimmer/ControlNet-Trainer

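**Note:** `share.py` originally downloads the SD 2.1 base checkpoint from the `stabilityai` Hugging Face repository (line 122 of `share.py`); change `stabilityai` to `Manojb` to download from a mirror instead. The "Fix 2.1 Download Link and weights only load" cell below already makes this change by overwriting `share.py`, so a manual edit is only needed if you skip that cell.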


In [None]:
# Clone the trainer repository into the current working directory.
!git clone https://github.com/diontimmer/ControlNet-Trainer

Cloning into 'ControlNet-Trainer'...
remote: Enumerating objects: 104, done.[K
remote: Counting objects: 100% (104/104), done.[K
remote: Compressing objects: 100% (91/91), done.[K
remote: Total 104 (delta 19), reused 88 (delta 8), pack-reused 0 (from 0)[K
Receiving objects: 100% (104/104), 922.41 KiB | 40.10 MiB/s, done.
Resolving deltas: 100% (19/19), done.


In [None]:
# @title Fix 2.1 Download Link and weights only load
%%writefile /content/ControlNet-Trainer/share.py

import config
from cldm.model import create_model, load_state_dict
from pytorch_lightning.callbacks import ModelCheckpoint
from safetensors.torch import save_file
from types import SimpleNamespace
import torch
import sys
import json
import os
import urllib.request
from tqdm import tqdm


# `config.config` (from the project's config module) holds the default settings
# as a dict. If the first command-line argument is a .json file, its top-level
# keys override those defaults.
def make_config():
    """Merge an optional JSON override file into the defaults and expose them as attributes.

    Reads ``sys.argv[1]`` when it has a ``.json`` extension (case-insensitive)
    and copies every top-level key into ``config.config``. Afterwards
    ``config.config`` is converted to a ``SimpleNamespace`` so settings are read
    as ``config.<name>``. The conversion happens whether or not a JSON file was
    given; previously it only ran inside the JSON branch, so running without a
    JSON file left a plain dict behind and the first attribute access
    (e.g. ``config.wipe_older_ckpts``) raised ``AttributeError``.
    """
    if len(sys.argv) > 1 and os.path.splitext(sys.argv[1])[1].lower() == ".json":
        json_path = sys.argv[1]
        print("Loading config from json:", json_path)
        # load json and cast to python dict with python types
        with open(json_path, "rt", encoding="utf-8") as f:
            overrides = json.load(f)
        config.config.update(overrides)

    # Only a dict needs wrapping; leave an existing namespace untouched.
    if isinstance(config.config, dict):
        config.config = SimpleNamespace(**config.config)


make_config()
# From here on `config` is the settings namespace, not the config module.
config = config.config


class CustomModelCheckpoint(ModelCheckpoint):
    """ModelCheckpoint that additionally exports a ``.safetensors`` copy.

    After Lightning has written a ``.ckpt`` file, the checkpoint is reloaded on
    CPU and unwrapped to its state dict. When any ``control_model.*`` tensors
    are present, it is reduced to those tensors with the prefix removed. The
    result is saved next to the ``.ckpt`` under the same name with a
    ``.safetensors`` extension.

    When ``config.wipe_older_ckpts`` is true, every other ``.ckpt`` file in the
    same folder is deleted afterwards. ``.safetensors`` files are kept.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _save_checkpoint(self, trainer, filepath):
        super()._save_checkpoint(trainer, filepath)

        # The file was just written by this process, so unpickling it is safe.
        # weights_only=False matches how this module loads the base SD
        # checkpoint. Newer torch releases default to weights_only=True, which
        # can reject non-tensor objects stored in a full Lightning checkpoint.
        state_dict = torch.load(filepath, map_location="cpu", weights_only=False)

        # Unwrap at most two nested "state_dict" levels. This mirrors the
        # previous try/except chain without bare excepts hiding unrelated errors.
        for _ in range(2):
            if isinstance(state_dict, dict) and "state_dict" in state_dict:
                state_dict = state_dict["state_dict"]
            else:
                break

        prefix = "control_model."
        if any(k.startswith(prefix) for k in state_dict):
            # Export only the ControlNet branch, with keys relative to it.
            state_dict = {
                k[len(prefix):]: v
                for k, v in state_dict.items()
                if k.startswith(prefix)
            }

        save_file(state_dict, os.path.splitext(filepath)[0] + ".safetensors")

        if config.wipe_older_ckpts:
            # dirname() is "" for a bare filename; listdir("") would fail.
            ckpt_dir = os.path.dirname(filepath) or "."
            current = os.path.basename(filepath)
            for fname in os.listdir(ckpt_dir):
                if fname.endswith(".ckpt") and fname != current:
                    os.remove(os.path.join(ckpt_dir, fname))


def prepare_model_for_training():
    """Build the ControlNet model from its initial checkpoint and apply training options.

    Everything is created and loaded on CPU; Pytorch Lightning moves the model
    to the GPU(s) later.

    Returns:
        The model, with ``learning_rate``, ``sd_locked`` and
        ``only_mid_control`` copied over from ``config``.
    """
    ckpt_path, model_yaml = create_controlnet_model(sd_version=config.sd_version)

    net = create_model(model_yaml).cpu()
    net.load_state_dict(load_state_dict(ckpt_path, location="cpu"))

    for option in ("learning_rate", "sd_locked", "only_mid_control"):
        setattr(net, option, getattr(config, option))
    return net


def get_latest_ckpt(output_dir=None):
    """Return the path of the most recently modified ``.ckpt`` file, or ``None``.

    Args:
        output_dir: Directory to search. Defaults to ``config.output_dir``, so
            existing ``get_latest_ckpt()`` calls behave as before.

    Returns:
        Full path of the ``.ckpt`` file with the newest modification time.
        Returns ``None`` when the directory holds no ``.ckpt`` files or does
        not exist yet; a missing directory used to raise ``FileNotFoundError``.
    """
    if output_dir is None:
        output_dir = config.output_dir
    try:
        names = os.listdir(output_dir)
    except FileNotFoundError:
        return None

    ckpt_paths = [os.path.join(output_dir, n) for n in names if n.endswith(".ckpt")]
    if not ckpt_paths:
        return None
    # max() keeps the first of equally-new files, like the previous stable sort.
    return max(ckpt_paths, key=os.path.getmtime)


def get_node_name(name, parent_name):
    """Split ``name`` into the prefix ``parent_name`` and the remainder.

    Returns:
        ``(True, remainder)`` when ``name`` starts with ``parent_name`` and is
        strictly longer than it; otherwise ``(False, "")``.
    """
    prefix_len = len(parent_name)
    if len(name) > prefix_len and name.startswith(parent_name):
        return True, name[prefix_len:]
    return False, ""


def _download_with_progress(url, dest):
    """Download ``url`` to ``dest`` while showing a tqdm progress bar.

    Data is streamed into ``dest + ".part"`` and renamed to ``dest`` only after
    ``urlretrieve`` returns successfully. An interrupted or short download
    therefore never leaves a truncated file at ``dest`` that a later run would
    treat as complete. The temporary file is removed on failure.
    """
    tmp_path = dest + ".part"
    try:
        with tqdm(
            unit="B", unit_scale=True, miniters=1, desc=url.split("/")[-1]
        ) as t:

            def _report(block_num, block_size, total_size):
                # urlretrieve passes total_size <= 0 when the size is unknown.
                if total_size > 0:
                    t.total = total_size
                    done = min(block_num * block_size, total_size)
                else:
                    done = block_num * block_size
                t.update(done - t.n)

            urllib.request.urlretrieve(url, filename=tmp_path, reporthook=_report)
        os.replace(tmp_path, dest)
    finally:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def create_controlnet_model(sd_version="2.1"):
    """Return ``(checkpoint_path, model_config_path)`` for the initial ControlNet.

    If ``models/control_sd21_ini.ckpt`` (SD 2.1) or ``models/control_v15_ini.ckpt``
    (any other ``sd_version``) does not exist yet, it is built as follows:

    1. The ControlNet model is instantiated from its YAML config.
    2. The Stable Diffusion base checkpoint is loaded, downloading it first if
       it is missing.
    3. Weights whose keys start with ``control_`` are initialised from the
       matching ``model.diffusion_*`` tensors of the base checkpoint. All other
       keys are copied by name. Keys absent from the base checkpoint keep their
       freshly initialised values.
    4. The result is saved.

    All paths are resolved against this file's directory, not the current
    working directory. A base checkpoint downloaded here is deleted afterwards;
    one that was already present is left in place.

    Args:
        sd_version: ``"2.1"`` selects the SD 2.1 files; any other value selects SD 1.5.

    Returns:
        Tuple of the initial checkpoint path and the ``cldm`` YAML config path.
    """
    script_dir_path = os.path.dirname(os.path.realpath(__file__))
    models_folder_path = os.path.join(script_dir_path, "models")
    is_v21 = sd_version == "2.1"

    output_path = os.path.join(
        models_folder_path, "control_sd21_ini.ckpt" if is_v21 else "control_v15_ini.ckpt"
    )
    config_file = os.path.join(
        models_folder_path, "cldm_v21.yaml" if is_v21 else "cldm_v15.yaml"
    )
    if os.path.exists(output_path):
        return output_path, config_file

    model = create_model(config_path=config_file)
    sd_path = os.path.join(
        models_folder_path,
        "v2-1_512-ema-pruned.ckpt" if is_v21 else "v1-5-pruned.ckpt",
    )

    downloaded = False
    if not os.path.exists(sd_path):
        # The SD 2.1 file comes from the Manojb mirror instead of
        # stabilityai/stable-diffusion-2-1-base.
        # NOTE(review): the runwayml/stable-diffusion-v1-5 URL may also no
        # longer be reachable — verify before relying on the 1.5 path.
        url = (
            "https://huggingface.co/Manojb/stable-diffusion-2-1-base/resolve/main/v2-1_512-ema-pruned.ckpt"
            if is_v21
            else "https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned.ckpt"
        )
        print("Downloading pretrained model...")
        _download_with_progress(url, sd_path)
        downloaded = True
    else:
        print("Pretrained model already exists, skipping download...")

    # NOTE(security): weights_only=False unpickles arbitrary objects, so a
    # malicious file can execute code on load. It is used because these
    # checkpoints may contain non-tensor objects that the weights_only loader
    # rejects. Only point sd_path/url at sources you trust; the 2.1 URL above
    # is a third-party mirror.
    pretrained_weights = torch.load(sd_path, map_location="cpu", weights_only=False)
    if "state_dict" in pretrained_weights:
        pretrained_weights = pretrained_weights["state_dict"]

    scratch_dict = model.state_dict()
    target_dict = {}
    for k, scratch_tensor in scratch_dict.items():
        is_control, name = get_node_name(k, "control_")
        # control_model.X is initialised from the UNet's model.diffusion_model.X.
        copy_k = "model.diffusion_" + name if is_control else k
        if copy_k in pretrained_weights:
            target_dict[k] = pretrained_weights[copy_k].clone()
        else:
            # Newly added weight with no counterpart in the base checkpoint.
            target_dict[k] = scratch_tensor.clone()

    model.load_state_dict(target_dict, strict=True)
    torch.save(model.state_dict(), output_path)
    # Free disk space, but never delete a base checkpoint the user supplied.
    if downloaded:
        os.remove(sd_path)

    return output_path, config_file


Overwriting /content/ControlNet-Trainer/share.py


In [None]:
# @title Fix modules.py to account for new transformer format default
%%writefile /content/ControlNet-Trainer/ldm/modules/encoders/modules.py
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel

import open_clip
from ldm.util import default, count_params


class AbstractEncoder(nn.Module):
    """Base class for conditioning encoders; subclasses must implement ``encode``."""

    def encode(self, *args, **kwargs):
        """Encode the conditioning input. Always raises here; override in subclasses."""
        raise NotImplementedError


class IdentityEncoder(AbstractEncoder):
    """Pass-through encoder: ``encode`` returns its input unchanged."""

    def encode(self, x):
        # No transformation; the conditioning is used exactly as given.
        return x


class ClassEmbedder(nn.Module):
    """Embed integer class labels for cross-attention, with random label dropout.

    Args:
        embed_dim: Size of each class embedding.
        n_classes: Number of embedding rows. The last index (``n_classes - 1``)
            also serves as the "unconditional" class used for dropout and by
            ``get_unconditional_conditioning``.
        key: Batch dict key that holds the class labels.
        ucg_rate: Probability of replacing a label with the unconditional
            class in ``forward`` (skipped when ``disable_dropout`` is set).
    """

    def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1):
        super().__init__()
        self.key = key
        self.embedding = nn.Embedding(n_classes, embed_dim)
        self.n_classes = n_classes
        self.ucg_rate = ucg_rate

    def forward(self, batch, key=None, disable_dropout=False):
        """Return embeddings for ``batch[key]`` with an inserted axis 1.

        Assumes ``batch[key]`` is a 1-D tensor of labels (TODO confirm), which
        yields an output of shape ``(B, 1, embed_dim)``.
        """
        if key is None:
            key = self.key
        # this is for use in crossattn
        c = batch[key][:, None]
        if self.ucg_rate > 0. and not disable_dropout:
            # mask is 1 to keep a label, 0 to swap it for the unconditional class.
            mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
            c = mask * c + (1-mask) * torch.ones_like(c)*(self.n_classes-1)
            c = c.long()
        c = self.embedding(c)
        return c

    def get_unconditional_conditioning(self, bs, device="cuda"):
        """Return ``{key: labels}`` with every label set to the unconditional class.

        The labels are an integer tensor, so the result can be passed straight
        back into ``forward``. ``nn.Embedding`` rejects the floating-point
        indices this method used to produce.
        """
        # The last class index (n_classes - 1) is reserved as the unconditional class.
        uc_class = self.n_classes - 1
        uc = torch.full((bs,), uc_class, dtype=torch.long, device=device)
        return {self.key: uc}


def disabled_train(self, mode=True):
    """No-op replacement for ``nn.Module.train``.

    Assign this over a model's ``train`` attribute to pin its current
    train/eval mode. The requested ``mode`` is ignored and ``self`` is returned,
    so ``module.train()`` chaining keeps working.
    """
    del mode  # intentionally ignored
    return self


class FrozenT5Embedder(AbstractEncoder):
    """Text encoder built on a Hugging Face T5 encoder.

    ``forward``/``encode`` tokenize the text (padded/truncated to
    ``max_length``) and return the encoder's ``last_hidden_state``.
    """

    def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True):  # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
        super().__init__()
        self.tokenizer = T5Tokenizer.from_pretrained(version)
        self.transformer = T5EncoderModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length  # TODO: typical value?
        if freeze:
            self.freeze()

    def freeze(self):
        """Put the transformer in eval mode and stop gradients for all parameters."""
        self.transformer = self.transformer.eval()
        # NOTE: self.train is not overridden, so a later .train() call on a
        # parent module switches the transformer back to training mode.
        self.requires_grad_(False)

    def forward(self, text):
        encoded = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                 return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        input_ids = encoded["input_ids"].to(self.device)
        return self.transformer(input_ids=input_ids).last_hidden_state

    def encode(self, text):
        return self(text)


class FrozenCLIPEmbedder(AbstractEncoder):
    """Text encoder built on a Hugging Face CLIP text model.

    ``layer`` selects the output:

    - ``"last"``: the last hidden state.
    - ``"pooled"``: the pooled output with an inserted sequence axis.
    - ``"hidden"``: the hidden state at ``layer_idx``.
    """
    LAYERS = ["last", "pooled", "hidden"]

    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77,
                 freeze=True, layer="last", layer_idx=None):  # clip-vit-base-patch32
        super().__init__()
        assert layer in self.LAYERS
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        self.layer_idx = layer_idx
        if layer == "hidden":
            assert layer_idx is not None
            assert abs(layer_idx) <= 12

    def freeze(self):
        """Put the transformer in eval mode and stop gradients for all parameters."""
        self.transformer = self.transformer.eval()
        # NOTE: self.train is not overridden, so a later .train() call on a
        # parent module switches the transformer back to training mode.
        self.requires_grad_(False)

    def forward(self, text):
        encoded = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                 return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        input_ids = encoded["input_ids"].to(self.device)
        result = self.transformer(input_ids=input_ids, output_hidden_states=(self.layer == "hidden"))
        if self.layer == "last":
            return result.last_hidden_state
        if self.layer == "pooled":
            return result.pooler_output[:, None, :]
        return result.hidden_states[self.layer_idx]

    def encode(self, text):
        return self(text)


class FrozenOpenCLIPEmbedder(AbstractEncoder):
    """
    Uses the OpenCLIP transformer encoder for text.

    ``layer="last"`` returns the output of all residual blocks.
    ``layer="penultimate"`` skips the final block. Both outputs are passed
    through ``ln_final``.
    """
    LAYERS = [
        # "pooled" is not supported.
        "last",
        "penultimate"
    ]
    def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
                 freeze=True, layer="last"):
        super().__init__()
        assert layer in self.LAYERS
        model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version)
        # Only the text tower is used.
        del model.visual
        self.model = model

        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        # Number of residual blocks to skip at the end of the transformer.
        if self.layer == "last":
            self.layer_idx = 0
        elif self.layer == "penultimate":
            self.layer_idx = 1
        else:
            raise NotImplementedError()

    def freeze(self):
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        tokens = open_clip.tokenize(text)
        z = self.encode_with_transformer(tokens.to(self.device))
        return z

    def encode_with_transformer(self, tokens):
        x = self.model.token_embedding(tokens)  # [batch_size, n_ctx, d_model]
        x = x + self.model.positional_embedding
        # Newer open_clip builds the text transformer with batch_first=True and
        # records it on `transformer.batch_first`. Older releases lack the
        # attribute and expect sequence-first input. With the wrong layout the
        # attention treats the batch as the sequence, producing errors like
        # "The shape of the 2D attn_mask is torch.Size([77, 77]), but should be (3, 3)".
        batch_first = getattr(self.model.transformer, "batch_first", False)
        if not batch_first:
            x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
        if not batch_first:
            x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.model.ln_final(x)
        return x

    def text_transformer_forward(self, x: torch.Tensor, attn_mask = None):
        """Run the residual blocks, stopping ``layer_idx`` blocks before the end."""
        for i, r in enumerate(self.model.transformer.resblocks):
            if i == len(self.model.transformer.resblocks) - self.layer_idx:
                break
            if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
                x = checkpoint(r, x, attn_mask)
            else:
                x = r(x, attn_mask=attn_mask)
        return x

    def encode(self, text):
        return self(text)


class FrozenCLIPT5Encoder(AbstractEncoder):
    """Runs a frozen CLIP and a frozen T5 text encoder side by side.

    ``forward``/``encode`` return ``[clip_embeddings, t5_embeddings]``.
    """

    def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda",
                 clip_max_length=77, t5_max_length=77):
        super().__init__()
        self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length)
        self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
        clip_millions = count_params(self.clip_encoder) * 1.e-6
        t5_millions = count_params(self.t5_encoder) * 1.e-6
        clip_name = type(self.clip_encoder).__name__
        t5_name = type(self.t5_encoder).__name__
        print(f"{clip_name} has {clip_millions:.2f} M parameters, "
              f"{t5_name} comes with {t5_millions:.2f} M params.")

    def encode(self, text):
        return self(text)

    def forward(self, text):
        return [self.clip_encoder.encode(text), self.t5_encoder.encode(text)]




In [None]:
%%capture
%cd /content/
#@title Setup
# Create /content/dataset/{image,conditioning,prompts}, then the repo's output/ and
# logs/ folders (the Options cell points output_dir/logging_dir at them), and install
# the trainer's requirements. %%capture hides all output, including errors.
!mkdir dataset
!mkdir dataset/image
!mkdir dataset/conditioning
!mkdir dataset/prompts
# NOTE(review): the first cell already cloned this repository into the working
# directory; if that was /content, this clone fails because the folder exists
# (the failure is hidden by %%capture).
!git clone https://github.com/diontimmer/ControlNet-Trainer.git
%cd ControlNet-Trainer
!mkdir output
!mkdir logs
!pip install -r requirements.txt

## Download Test Dataset

In [None]:
# Download a test dataset
# fill50k from the lllyasviel/ControlNet Hugging Face repo; it lands at
# /content/training/fill50k.zip. NOTE(review): `login` is imported but unused.
from huggingface_hub import login, hf_hub_download
hf_hub_download(repo_id='lllyasviel/ControlNet',filename='training/fill50k.zip',local_dir='/content/')

training/fill50k.zip:   0%|          | 0.00/229M [00:00<?, ?B/s]

'/content/training/fill50k.zip'

In [None]:
# Extract into the current directory (the repo folder after the Setup cell's %cd),
# which is where the Dataset cell looks for fill50k/. Assumes the archive's
# top-level folder is fill50k/ — confirm.
!unzip -q /content/training/fill50k.zip

## Continue

In [2]:
# Enter the repo folder. The relative path expects the working directory to be
# /content (e.g. after a runtime restart). json is used by the Save Config cell.
%cd ControlNet-Trainer
import json, os

/content/ControlNet-Trainer


In [3]:
#@title Dataset
# Conditioning-image folder, target-image folder and captions JSON.
# For the fill50k test dataset all three point into the same folder.
dataset_conditioning_folder = "/content/ControlNet-Trainer/fill50k" #@param {type:"string"}
dataset_target_folder = "/content/ControlNet-Trainer/fill50k" #@param {type:"string"}
dataset_captions_json = "/content/ControlNet-Trainer/fill50k/prompt.json" #@param {type:"string"}
print('Folders set!')
print(f'Conditionings: {dataset_conditioning_folder}')
print(f'Targets: {dataset_target_folder}')
print(f'Prompts: {dataset_captions_json}')

Folders set!
Conditionings: /content/ControlNet-Trainer/fill50k
Targets: /content/ControlNet-Trainer/fill50k
Prompts: /content/ControlNet-Trainer/fill50k/prompt.json


In [4]:
#@title Options
# Every option below is collected into the `config` dict at the end of this cell,
# which the Save Config cell writes to JSON.
project_name = "default" #@param {type:"string"}
run_name = "" #@param {type:"string"}
sd_version = "2.1" #@param ["2.1", "1.5"]
output_dir = "/content/ControlNet-Trainer/output" #@param {type:"string"}
logging_dir = "/content/ControlNet-Trainer/logs" #@param {type:"string"}
# "latest" presumably resumes from the newest .ckpt in output_dir
# (see get_latest_ckpt in share.py) — confirm in train.py.
resume_ckpt = "latest" #@param {type:"string"}

wandb_key = "" #@param {type:"string"}
resolution = 512 #@param {type:"integer"}
batch_size = 3 #@param {type:"integer"}
image_logger_freq = 250 #@param {type:"integer"}
learning_rate = 1e-5 #@param
# NOTE(review): with fill50k and batch_size 3 one epoch is 16667 steps (see the
# training log), so max_steps = 9000 presumably ends training partway through
# the first epoch — confirm how train.py applies both limits.
max_steps = 9000 #@param {type:"integer"}
max_epochs = 10 #@param {type:"integer"}
# True: after each save, every other .ckpt in the checkpoint folder is deleted
# (the exported .safetensors files are kept).
wipe_older_ckpts = False #@param {type:"boolean"}

save_memory = False #@param {type:"boolean"}
image_logger_disabled = False #@param {type:"boolean"}
save_ckpt_every_n_steps = 250 #@param {type:"integer"}
save_top_k = -1 #@param {type:"integer"}
save_weights_only = False #@param {type:"boolean"}
save_last = False #@param {type:"boolean"}
sd_locked = True #@param {type:"boolean"}
only_mid_control = False #@param {type:"boolean"}
# Not exposed as a form field; edit the value here.
gradient_accumulation_steps = 1

multi_gpu = False #@param {type:"boolean"}

config = {
    "project_name": project_name,
    "run_name": run_name,
    "sd_version": sd_version,
    "output_dir": output_dir,
    "logging_dir": logging_dir,
    "resume_ckpt": resume_ckpt,
    "dataset_conditioning_folder": dataset_conditioning_folder,
    "dataset_target_folder": dataset_target_folder,
    "dataset_captions_json": dataset_captions_json,
    "wandb_key": wandb_key,
    "resolution": resolution,
    "save_memory": save_memory,
    "batch_size": batch_size,
    "image_logger_disabled": image_logger_disabled,
    "image_logger_freq": image_logger_freq,
    "learning_rate": learning_rate,
    "max_steps": max_steps,
    "max_epochs": max_epochs,
    "save_ckpt_every_n_steps": save_ckpt_every_n_steps,
    "save_top_k": save_top_k,
    "save_weights_only": save_weights_only,
    "save_last": save_last,
    "sd_locked": sd_locked,
    "only_mid_control": only_mid_control,
    "gradient_accumulation_steps": gradient_accumulation_steps,
    "wipe_older_ckpts": wipe_older_ckpts,
    "multi_gpu": multi_gpu
}

In [5]:
#@title Save Config
# Write `config` to configs/<config_name>.json, relative to the repo folder.
# The Start Training cell passes that file to train.py.
!mkdir -p configs

config_name = "default_config" #@param {type:"string"}

# Save the dictionary to a json file
with open(f'configs/{config_name}.json', 'w') as f:
    json.dump(config, f, indent=4)

print(f'Config saved as configs/{config_name}.json!')

Config saved as configs/default_config.json!


In [10]:
#@title Start Training
# The JSON path is passed as the first argument; share.py's make_config reads
# sys.argv[1] (presumably imported by train.py) and overrides the defaults with it.
training_config = "default_config" #@param {type:"string"}
!python train.py configs/{training_config}.json

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Epoch 0/9  [90m━━━━━━━━━━━━━━━━[0m 137/16667 [2m0:01:09 •      [0m [2;4m2.25it/s[0m [3mv_num: 8.000     [0m
                                      [2m2:02:32        [0m          [3mtrain/loss_simpl…[0m
                                                               [3m0.015            [0m
                                                               [3mtrain/loss_vlb_s…[0m
                                                               [3m0.000            [0m
                                                               [3mtrain/loss_step: [0m
                                                               [3m0.015            [0m
                                                               [3mglobal_step:     [0m
[2K[1A[2K[1A[2K[1A[2K[1A[2K[1A[2K[1A[2K[1A[2K[1A[2K++++++++++++++++++++
Epoch 0/9  [90m━━━━━━━━━━━━━━━━[0m 137/16667 [2m0:01:09 •      [0m [2;4m2.25it/s[0m [3mv_num

In [7]:
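# Error raised with the unpatched ldm/modules/encoders/modules.py for a batch of 3 prompts:
# RuntimeError: The shape of the 2D attn_mask is torch.Size([77, 77]), but should be (3, 3).
# The batch and sequence axes were swapped before the OpenCLIP text transformer; the
# "Fix modules.py" cell above addresses this.
# Prompts in the failing batch ('txt'): ['sienna circle with coral background', 'sky blue circle with medium slate blue background', 'crimson circle with antique white background']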

#