<a href="https://colab.research.google.com/github/Yang-star-source/ControlNet-In-Latent-Diffusion/blob/main/ControlNet_In_Latent_Diffusion.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Trained Model Definitions (VAE, DiffusionUnet)

In [None]:
import torch
import torch.nn as nn

class ResNetBlock(nn.Module):
    """Pre-activation residual block.

    Two ``GroupNorm -> SiLU -> 3x3 conv`` stages (dropout sits right before
    the second convolution) whose result is added to a skip path. The skip
    path is a 1x1 convolution when the channel count changes and the
    identity otherwise.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels of the returned feature map.
        dropout_prob: dropout probability used before the second conv.
    """

    def __init__(self, in_channels, out_channels, dropout_prob=0.0):
        super().__init__()
        # GroupNorm uses 32 groups, so both channel counts must be multiples of 32.
        self.norm1 = nn.GroupNorm(32, in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.norm2 = nn.GroupNorm(32, out_channels)
        self.drop = nn.Dropout(dropout_prob)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.silu = nn.SiLU()

        needs_projection = in_channels != out_channels
        self.shortcut = (
            nn.Conv2d(in_channels, out_channels, kernel_size=1)
            if needs_projection
            else nn.Identity()
        )

    def forward(self, x):
        """Return ``residual_branch(x) + shortcut(x)``."""
        hidden = self.conv1(self.silu(self.norm1(x)))
        hidden = self.conv2(self.drop(self.silu(self.norm2(hidden))))
        return hidden + self.shortcut(x)

class AttentionBlock(nn.Module):
    """Single-head self-attention across the spatial positions of a feature map.

    The input is group-normalised, every spatial location is treated as a
    token of width ``in_channels``, and the attended result is added back to
    the un-normalised input.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.norm = nn.GroupNorm(32, in_channels)

        self.to_q = nn.Linear(in_channels, in_channels)
        self.to_k = nn.Linear(in_channels, in_channels)
        self.to_v = nn.Linear(in_channels, in_channels)

        self.to_out = nn.Linear(in_channels, in_channels)

    def forward(self, x):
        """Return ``x`` plus its self-attended projection; the shape is unchanged."""
        batch, channels, height, width = x.shape

        # (B, C, H, W) -> (B, H*W, C): one token per spatial position.
        tokens = self.norm(x).view(batch, channels, -1).permute(0, 2, 1)

        query = self.to_q(tokens)
        key = self.to_k(tokens)
        value = self.to_v(tokens)

        # Scores are (B, H*W, H*W), scaled by 1/sqrt(C) and normalised per query row.
        scores = torch.bmm(query, key.permute(0, 2, 1)) * (channels ** (-0.5))
        weights = scores.softmax(dim=-1)
        attended = torch.bmm(weights, value)

        # Project, then fold the tokens back into (B, C, H, W).
        projected = self.to_out(attended)
        projected = projected.permute(0, 2, 1).view(batch, channels, height, width)
        return projected + x

class MidBlock(nn.Module):
    """Bottleneck stage: residual block -> self-attention -> residual block.

    The channel count stays at ``in_channels`` throughout.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.res1 = ResNetBlock(in_channels, in_channels)
        self.attn1 = AttentionBlock(in_channels)
        self.res2 = ResNetBlock(in_channels, in_channels)

    def forward(self, x):
        """Run the three sub-blocks in sequence."""
        return self.res2(self.attn1(self.res1(x)))

class DownBlock(nn.Module):
    """Encoder stage: two residual blocks, optional attention, then 2x downsampling.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the stage.
        has_attn: insert a self-attention block before downsampling.
    """

    def __init__(self, in_channels, out_channels, has_attn=False):
        super().__init__()
        self.res1 = ResNetBlock(in_channels, out_channels)
        self.res2 = ResNetBlock(out_channels, out_channels)
        # Stride-2 convolution halves the height and width.
        self.down = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1)
        self.attn = AttentionBlock(out_channels) if has_attn else nn.Identity()

    def forward(self, x):
        """Apply res1 -> res2 -> attn -> down."""
        for stage in (self.res1, self.res2, self.attn, self.down):
            x = stage(x)
        return x

class Encoder(nn.Module):
    """VAE encoder: image -> sampled latent plus KL divergence.

    Three stride-2 stages reduce the resolution by a factor of 8. The final
    1x1 convolution emits ``2 * out_channels`` maps, which are split into the
    mean and log-variance of a diagonal Gaussian.
    """

    def __init__(self, in_channels=3, out_channels=4):
        super().__init__()
        self.inp = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1)
        stages = [
            DownBlock(64, 128),
            DownBlock(128, 256),
            DownBlock(256, 512, has_attn=True),
        ]
        self.down_block = nn.Sequential(*stages)
        self.bottle = MidBlock(512)
        head = [
            nn.GroupNorm(32, 512),
            nn.SiLU(),
            nn.Conv2d(512, out_channels * 2, kernel_size=1),
        ]
        self.out = nn.Sequential(*head)

    def reparameterize(self, x):
        """Sample ``z ~ N(mean, std^2)`` with the reparameterisation trick.

        Args:
            x: tensor whose channel axis (dim 1) holds the mean followed by
                the log-variance.

        Returns:
            ``(z, D_kl)`` where ``D_kl`` is the KL divergence to a standard
            normal, summed over dims 1-3 and averaged over the batch.
        """
        mean, log_var = torch.chunk(x, 2, dim=1)
        # Clamp so that exp() below cannot overflow or collapse to zero.
        log_var = torch.clamp(log_var, -30.0, 20.0)

        kl_per_element = 0.5 * (torch.exp(log_var) + mean**2 - log_var - 1)
        # Sum over channel/height/width; the mean is over the batch dimension.
        D_kl = torch.sum(kl_per_element, dim=[1, 2, 3]).mean()

        std = torch.exp(0.5 * log_var)
        noise = torch.randn_like(std)
        return mean + noise * std, D_kl

    def forward(self, x):
        """Encode ``x`` and return ``(z, D_kl)``."""
        features = self.bottle(self.down_block(self.inp(x)))
        return self.reparameterize(self.out(features))

class UpBlock(nn.Module):
    """Decoder stage: two residual blocks, optional attention, then 2x upsampling.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the stage.
        has_attn: insert a self-attention block before upsampling.
    """

    def __init__(self, in_channels, out_channels, has_attn=False):
        super().__init__()
        self.res1 = ResNetBlock(in_channels, out_channels)
        self.res2 = ResNetBlock(out_channels, out_channels)
        upsample_layers = [
            nn.Upsample(scale_factor=2),  # nn.Upsample defaults to nearest-neighbour
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
        ]
        self.up = nn.Sequential(*upsample_layers)
        self.attn = AttentionBlock(out_channels) if has_attn else nn.Identity()

    def forward(self, x):
        """Apply res1 -> res2 -> attn -> up."""
        for stage in (self.res1, self.res2, self.attn, self.up):
            x = stage(x)
        return x

class Decoder(nn.Module):
    """VAE decoder: latent -> image.

    Mirrors the encoder: an input convolution to 512 channels, a bottleneck
    block, three 2x upsampling stages (8x overall) and a final
    ``GroupNorm -> SiLU -> 3x3 conv`` head producing ``out_channels`` maps.
    """

    def __init__(self, in_channels=4, out_channels=3):
        super().__init__()
        self.inp = nn.Conv2d(in_channels, 512, kernel_size=3, padding=1)
        self.bottle = MidBlock(512)
        stages = [
            UpBlock(512, 256, has_attn=True),
            UpBlock(256, 128),
            UpBlock(128, 64),
        ]
        self.up_block = nn.Sequential(*stages)
        head = [
            nn.GroupNorm(32, 64),
            nn.SiLU(),
            nn.Conv2d(64, out_channels, kernel_size=3, padding=1),
        ]
        self.out = nn.Sequential(*head)

    def forward(self, x):
        """Decode the latent ``x`` into an image-shaped tensor."""
        features = self.bottle(self.inp(x))
        return self.out(self.up_block(features))

class PatchGan(nn.Module):
    """Convolutional patch discriminator.

    Four stride-2 convolutions (16x spatial reduction) followed by a stride-1
    convolution that emits one score map; each output location scores a patch
    of the input.
    """

    def __init__(self, in_channels=3):
        super().__init__()
        # The first stage has no normalisation; the later ones use GroupNorm.
        layers = [
            nn.Conv2d(in_channels, 64, kernel_size=3, stride=2, padding=1),
            nn.SiLU(),
        ]
        for width_in, width_out in ((64, 128), (128, 256), (256, 512)):
            layers += [
                nn.Conv2d(width_in, width_out, kernel_size=3, stride=2, padding=1),
                nn.GroupNorm(32, width_out),
                nn.SiLU(),
            ]
        layers.append(nn.Conv2d(512, 1, kernel_size=3, stride=1, padding=1))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the single-channel score map for ``x``."""
        return self.model(x)

class VAE(nn.Module):
    """Variational auto-encoder built from ``Encoder`` and ``Decoder``.

    Args:
        in_channels: channels of the input (and reconstructed) image.
        out_channels: channels of the latent representation.
    """

    def __init__(self, in_channels=3, out_channels=4):
        super().__init__()
        self.encoder = Encoder(in_channels, out_channels)
        # The decoder maps latent channels back to image channels.
        self.decoder = Decoder(out_channels, in_channels)

    def forward(self, x):
        """Return ``(reconstruction, D_kl)`` for the input batch ``x``."""
        latent, kl_divergence = self.encoder(x)
        return self.decoder(latent), kl_divergence




In [None]:
def raw_time_embedding(time, dim):
    """Build sinusoidal timestep features.

    Args:
        time: a scalar, or a tensor of timesteps. A 0-d input becomes a batch
            of one; a 1-d tensor of shape (B,) is treated as a batch.
        dim: embedding width; ``dim // 2`` sine and ``dim // 2`` cosine
            features are produced.

    Returns:
        Tensor whose dim 1 holds all sine features followed by all cosine
        features, on the same device as ``time``.
    """
    steps = time if torch.is_tensor(time) else torch.tensor(time)

    # scalar -> (1, 1); batch (B,) -> (B, 1) so it broadcasts against frequencies
    steps = steps[None, None] if steps.ndim == 0 else steps[:, None]

    # Create the frequency index on the timestep's device to avoid a mismatch.
    freq_index = torch.arange(dim // 2, device=steps.device).float()
    angles = steps / (10000 ** (2 * freq_index / dim))
    return torch.cat((torch.sin(angles), torch.cos(angles)), dim=1)

class time_embedding(nn.Module):
    """Two-layer MLP (Linear -> SiLU -> Linear) that keeps the feature width at ``dim``."""

    def __init__(self, dim):
        super().__init__()
        layers = [nn.Linear(dim, dim), nn.SiLU(), nn.Linear(dim, dim)]
        self.net = nn.Sequential(*layers)

    def forward(self, X):
        """Project the raw timestep features ``X`` through the MLP."""
        projected = self.net(X)
        return projected

class DiffusionResNetBlock(nn.Module):
    """Residual block conditioned on a timestep embedding.

    Same layout as a pre-activation residual block, except that a linear
    projection of the time embedding is added to the feature map between the
    two convolution stages.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels of the returned feature map.
        time_emb_dim: width of the time embedding passed to ``forward``.
        dropout_prob: dropout probability used before the second conv.
    """

    def __init__(self, in_channels, out_channels, time_emb_dim, dropout_prob=0.0):
        super().__init__()
        self.norm1 = nn.GroupNorm(32, in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.norm2 = nn.GroupNorm(32, out_channels)
        self.drop = nn.Dropout(dropout_prob)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.silu = nn.SiLU()

        # Maps the time embedding to one value per output channel so it can be
        # added to the feature map.
        self.time_proj = nn.Linear(time_emb_dim, out_channels)

        needs_projection = in_channels != out_channels
        self.shortcut = (
            nn.Conv2d(in_channels, out_channels, kernel_size=1)
            if needs_projection
            else nn.Identity()
        )

    def forward(self, x, time_emb):
        """Return the time-conditioned residual update of ``x``.

        ``time_emb`` is passed through SiLU and ``time_proj`` and broadcast
        over the two spatial axes.
        """
        hidden = self.conv1(self.silu(self.norm1(x)))

        # (B, C) -> (B, C, 1, 1) so the bias broadcasts over height and width.
        time_bias = self.time_proj(self.silu(time_emb))
        hidden = hidden + time_bias[:, :, None, None]

        hidden = self.conv2(self.drop(self.silu(self.norm2(hidden))))
        return hidden + self.shortcut(x)

class DiffusionAttentionBlock(nn.Module):
    """Single-head spatial self-attention with a residual connection.

    Used inside the diffusion U-Net; it takes no time embedding.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.norm = nn.GroupNorm(32, in_channels)

        self.to_q = nn.Linear(in_channels, in_channels)
        self.to_k = nn.Linear(in_channels, in_channels)
        self.to_v = nn.Linear(in_channels, in_channels)

        self.to_out = nn.Linear(in_channels, in_channels)

    def forward(self, x):
        """Attend over all spatial positions and add the result to ``x``."""
        B, C, H, W = x.shape

        # Flatten the spatial grid: (B, C, H, W) -> (B, H*W, C).
        seq = self.norm(x).view(B, C, -1).permute(0, 2, 1)

        q, k, v = self.to_q(seq), self.to_k(seq), self.to_v(seq)

        # Similarity between every pair of positions, scaled by 1/sqrt(C).
        logits = torch.bmm(q, k.permute(0, 2, 1))
        logits = logits * (C ** (-0.5))
        probs = logits.softmax(dim=-1)

        mixed = self.to_out(torch.bmm(probs, v))

        # Restore the image layout before the residual add.
        mixed = mixed.permute(0, 2, 1).view(B, C, H, W)
        return mixed + x

class DiffusionDownBlock(nn.Module):
    """Time-conditioned encoder stage of the diffusion U-Net.

    Two time-conditioned residual blocks and optional self-attention are
    followed by a stride-2 convolution that halves the spatial resolution.
    """

    def __init__(self, in_channels, out_channels, time_emb_dim, has_attn=False):
        super().__init__()
        self.res1 = DiffusionResNetBlock(in_channels, out_channels, time_emb_dim)
        self.res2 = DiffusionResNetBlock(out_channels, out_channels, time_emb_dim)
        self.down = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1)
        self.attn = DiffusionAttentionBlock(out_channels) if has_attn else nn.Identity()

    def forward(self, x, time_emb):
        """Return ``(downsampled, skip)``.

        ``skip`` is the feature map taken just before downsampling.
        """
        features = self.res2(self.res1(x, time_emb), time_emb)
        features = self.attn(features)
        return self.down(features), features

class DiffusionMidBlock(nn.Module):
    """Bottleneck of the diffusion U-Net.

    Runs a time-conditioned residual block, self-attention, and a second
    time-conditioned residual block; the channel count is unchanged.
    """

    def __init__(self, in_channels, time_emb_dim):
        super().__init__()
        self.res1 = DiffusionResNetBlock(in_channels, in_channels, time_emb_dim)
        self.attn = DiffusionAttentionBlock(in_channels)
        self.res2 = DiffusionResNetBlock(in_channels, in_channels, time_emb_dim)

    def forward(self, x, time_emb):
        """Apply res1 -> attn -> res2; only the residual blocks receive ``time_emb``."""
        attended = self.attn(self.res1(x, time_emb))
        return self.res2(attended, time_emb)

class DiffusionUpBlock(nn.Module):
    """Time-conditioned decoder stage of the diffusion U-Net.

    Upsamples the input 2x, concatenates the matching encoder skip tensor on
    the channel axis, then applies two time-conditioned residual blocks and
    optional self-attention.
    """

    def __init__(self, in_channels, out_channels, time_emb_dim, has_attn=False):
        super().__init__()

        upsample_layers = [
            nn.Upsample(scale_factor=2),
            nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1),
        ]
        self.up = nn.Sequential(*upsample_layers)
        # The first residual block sees in_channels * 2 because the skip tensor
        # is concatenated onto the upsampled features.
        self.res1 = DiffusionResNetBlock(in_channels * 2, out_channels, time_emb_dim)
        self.res2 = DiffusionResNetBlock(out_channels, out_channels, time_emb_dim)
        self.attn = DiffusionAttentionBlock(out_channels) if has_attn else nn.Identity()

    def forward(self, x, skip_connection, time_emb):
        """Fuse the upsampled ``x`` with ``skip_connection`` and refine the result."""
        upsampled = self.up(x)

        # Concatenate along the channel dimension.
        fused = torch.cat([upsampled, skip_connection], dim=1)
        fused = self.res2(self.res1(fused, time_emb), time_emb)
        return self.attn(fused)

class DiffusionUnet(nn.Module):
    """Noise-prediction U-Net operating on latents, with optional control residuals.

    Four downsampling stages, a bottleneck and four upsampling stages. Each
    encoder stage hands its pre-downsampling feature map to the mirrored
    decoder stage as a skip connection.

    Args:
        in_channels: channels of the noisy latent input.
        out_channels: channels of the predicted noise.
        time_dim: width of the sinusoidal and learned time embeddings.
    """

    def __init__(self, in_channels=4, out_channels=4, time_dim=256):
        super().__init__()

        self.time_dim = time_dim
        self.time_embedding = time_embedding(time_dim)

        self.init_conv = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1)

        self.down1 = DiffusionDownBlock(64, 64, time_dim)
        self.down2 = DiffusionDownBlock(64, 128, time_dim)
        self.down3 = DiffusionDownBlock(128, 128, time_dim, has_attn=True)
        self.down4 = DiffusionDownBlock(128, 256, time_dim, has_attn=True)

        self.mid = DiffusionMidBlock(256, time_dim)

        self.up1 = DiffusionUpBlock(256, 128, time_dim, has_attn=True)
        self.up2 = DiffusionUpBlock(128, 128, time_dim, has_attn=True)
        self.up3 = DiffusionUpBlock(128, 64, time_dim)
        self.up4 = DiffusionUpBlock(64, 64, time_dim)

        self.out = nn.Sequential(
            nn.GroupNorm(32, 64),
            nn.SiLU(),
            nn.Conv2d(64, out_channels, kernel_size=3, padding=1)
        )

    def forward(self, x, t, residuals=None):
        """Predict the noise contained in ``x`` at timestep ``t``.

        Args:
            x: noisy latent batch.
            t: timesteps, forwarded to ``raw_time_embedding``.
            residuals: optional sequence of five tensors. The first four are
                added to the four skip connections (shallowest first) and the
                fifth to the bottleneck output. When ``None`` the network
                runs without any control signal.

        Returns:
            Tensor with ``out_channels`` channels.
        """
        use_control = residuals is not None
        if use_control:
            # Unpack up front so a wrong number of residuals fails immediately.
            resi1, resi2, resi3, resi4, resi5 = residuals

        t = raw_time_embedding(t, self.time_dim)
        emb = self.time_embedding(t)

        x = self.init_conv(x)

        x1, skip1 = self.down1(x, emb)
        x2, skip2 = self.down2(x1, emb)
        x3, skip3 = self.down3(x2, emb)
        x4, skip4 = self.down4(x3, emb)

        if use_control:
            skip1 = skip1 + resi1
            skip2 = skip2 + resi2
            skip3 = skip3 + resi3
            skip4 = skip4 + resi4

        x = self.mid(x4, emb)
        if use_control:
            x = x + resi5

        # Decoder consumes the skips deepest-first.
        x = self.up1(x, skip4, emb)
        x = self.up2(x, skip3, emb)
        x = self.up3(x, skip2, emb)
        x = self.up4(x, skip1, emb)

        x = self.out(x)
        return x




# ControlNet

In [None]:
import copy

def zero_conv_layer(in_channels, out_channels):
    """Return a 1x1 convolution whose weight and bias are all zeros.

    The layer therefore outputs zeros until its parameters are trained.
    """
    layer = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
    for parameter in (layer.weight, layer.bias):
        nn.init.zeros_(parameter)
    return layer

class ConditionEncoder(nn.Module):
    """Small convolutional encoder for the conditioning image.

    Three stride-2 convolutions shrink the spatial size by a factor of 8
    (256 -> 128 -> 64 -> 32). ``out_channels`` defaults to 64, the width of
    the diffusion U-Net's ``init_conv`` output, so the result can be added to
    it.
    """

    def __init__(self, in_channels=3, out_channels=64):
        super().__init__()
        # (input channels, output channels, stride) of each conv + SiLU pair.
        stem = (
            (in_channels, 16, 1),
            (16, 16, 2),
            (16, 32, 2),
            (32, 64, 2),
        )
        layers = []
        for width_in, width_out, stride in stem:
            layers.append(nn.Conv2d(width_in, width_out, kernel_size=3, stride=stride, padding=1))
            layers.append(nn.SiLU())
        # The final projection has no activation.
        layers.append(nn.Conv2d(64, out_channels, kernel_size=3, stride=1, padding=1))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Encode the conditioning image ``x``."""
        return self.net(x)

class ControlNet(nn.Module):
    """Trainable copy of the diffusion U-Net encoder, driven by a condition image.

    The time embedding, input convolution, four down blocks and the mid block
    are deep-copied from ``DiffusionUnet_copy``. The output of every copied
    stage goes through a zero-initialised 1x1 convolution, so all five
    residuals are exactly zero before training.

    Args:
        DiffusionUnet_copy: the ``DiffusionUnet`` whose encoder is cloned.
        in_channels: currently unused; kept for interface compatibility.
        out_channels: currently unused; kept for interface compatibility.
    """

    def __init__(self, DiffusionUnet_copy, in_channels=3, out_channels=4):
        super().__init__()
        self.encoder = ConditionEncoder()

        # The copied ``time_embedding`` module does not store its width, so
        # keep the U-Net's ``time_dim`` here for ``raw_time_embedding``.
        self.time_dim = DiffusionUnet_copy.time_dim
        self.time_emb = copy.deepcopy(DiffusionUnet_copy.time_embedding)
        self.init_conv = copy.deepcopy(DiffusionUnet_copy.init_conv)

        self.down1 = copy.deepcopy(DiffusionUnet_copy.down1)
        self.zero_conv1 = zero_conv_layer(64, 64)

        self.down2 = copy.deepcopy(DiffusionUnet_copy.down2)
        self.zero_conv2 = zero_conv_layer(128, 128)

        self.down3 = copy.deepcopy(DiffusionUnet_copy.down3)
        self.zero_conv3 = zero_conv_layer(128, 128)

        self.down4 = copy.deepcopy(DiffusionUnet_copy.down4)
        self.zero_conv4 = zero_conv_layer(256, 256)

        self.mid = copy.deepcopy(DiffusionUnet_copy.mid)
        self.zero_conv5 = zero_conv_layer(256, 256)

    def forward(self, x, t, condition):
        """Compute the five control residuals for the frozen U-Net.

        Args:
            x: noisy latent batch (same input the U-Net receives).
            t: timesteps, forwarded to ``raw_time_embedding``.
            condition: conditioning image batch; after ``ConditionEncoder``
                it must match the shape of ``init_conv(x)`` because the two
                are added.

        Returns:
            Tuple ``(out1, out2, out3, out4, out5)``: zero-conv projections
            of the four skip tensors and of the mid-block output.
        """
        # Fixed: the original read ``self.time_emb.time_dim``, an attribute the
        # copied ``time_embedding`` module never defines (AttributeError).
        t = raw_time_embedding(t, self.time_dim)
        emb = self.time_emb(t)

        c = self.encoder(condition)

        x = self.init_conv(x)

        # Inject the condition features into the latent features.
        x = x + c

        x, skip1 = self.down1(x, emb)
        out1 = self.zero_conv1(skip1)

        x, skip2 = self.down2(x, emb)
        out2 = self.zero_conv2(skip2)

        x, skip3 = self.down3(x, emb)
        out3 = self.zero_conv3(skip3)

        x, skip4 = self.down4(x, emb)
        out4 = self.zero_conv4(skip4)

        x_mid = self.mid(x, emb)
        out5 = self.zero_conv5(x_mid)

        return (out1, out2, out3, out4, out5)




# Upload and unzip zip files

In [None]:
import os

# Archives stored on Google Drive: pre-computed latents and Canny edge maps.
DRIVE_LATENTS_ZIP = "/content/drive/MyDrive/VAE_Training/LATENTS.zip"
DRIVE_CANNY_ZIP = "/content/drive/MyDrive/VAE_Training/Canny.zip"

# Local extraction targets; exist_ok makes re-running this cell harmless.
os.makedirs("/content/train_data/LATENTS",exist_ok=True)
os.makedirs("/content/train_data/CANNY",exist_ok=True)

# Use -q for quiet mode, otherwise unzip prints every extracted file.
# The leading "!" runs the line in a shell (IPython/Colab only), and
# $NAME is replaced with the value of the Python variable of that name.
!unzip -q "$DRIVE_LATENTS_ZIP" -d "/content/train_data/LATENTS"
!unzip -q "$DRIVE_CANNY_ZIP" -d "/content/train_data/CANNY"



In [None]:
from huggingface_hub import hf_hub_download
from torch.optim import AdamW
from torch.utils.data import Dataset , DataLoader
import numpy as np
import cv2
import os
import torch
import torch.nn.functional as F
from tqdm import tqdm

# Directory on Google Drive that receives the periodic ControlNet checkpoints
# written by the training loop; created up front if it does not exist yet.
SAVE_DIR = "/content/drive/MyDrive/VAE_Training/ControlNet_Checkpoints"
os.makedirs(SAVE_DIR, exist_ok=True)

class ControlNetDataset(Dataset):
    """Pairs pre-computed latent tensors with Canny edge images.

    Files are paired purely by their position in the sorted directory
    listings, so both directories must use names that sort into the same
    order. Only the file counts are checked.

    Args:
        latent_dir: directory of tensors written with ``torch.save``.
        canny_dir: directory of edge-map images readable by OpenCV.
        image_size: ``(width, height)`` every edge map is resized to.
    """

    def __init__(self, latent_dir, canny_dir, image_size=(256, 256)):
        self.latent_dir = latent_dir
        self.canny_dir = canny_dir
        self.image_size = tuple(image_size)
        self.latent_files = sorted(os.listdir(latent_dir))
        self.canny_files = sorted(os.listdir(canny_dir))

        assert len(self.latent_files) == len(self.canny_files) , "Mismatch!"

    def __len__(self):
        """Number of (latent, edge map) pairs."""
        return len(self.latent_files)

    def __getitem__(self, index):
        """Return ``(latent, canny)`` for the pair at ``index``.

        ``canny`` is a float tensor in (C, H, W) layout scaled to [0.0, 1.0].

        Raises:
            OSError: if OpenCV cannot read the edge-map file.
        """
        latent_path = os.path.join(self.latent_dir, self.latent_files[index])
        # Always load onto the CPU: DataLoader workers must not initialise
        # CUDA, and pin_memory only works on CPU tensors. The training loop
        # moves the batch to the target device.
        latent = torch.load(latent_path, map_location="cpu")

        canny_path = os.path.join(self.canny_dir, self.canny_files[index])
        canny = cv2.imread(canny_path)
        if canny is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError(f"cv2.imread could not read {canny_path!r}")
        canny = cv2.cvtColor(canny, cv2.COLOR_BGR2RGB)

        canny = cv2.resize(canny, self.image_size, interpolation=cv2.INTER_AREA)

        # (H,W,C) -> (C,H,W) | [0,255] -> [0.0,1.0]
        canny = torch.from_numpy(canny).permute(2, 0, 1).float() / 255.0
        return latent, canny

# ---- Data -------------------------------------------------------------------
dataset = ControlNetDataset("/content/train_data/LATENTS/LATENTS","/content/train_data/CANNY/Canny")
dataloader = DataLoader(dataset,
                        batch_size=32,
                        shuffle=True,
                        num_workers=2,
                        pin_memory=True)

# Sanity check: pull one batch and show the tensor shapes.
sample_latents, sample_canny = next(iter(dataloader))
print(f"Latents Batch Shape: {sample_latents.shape}")
print(f"Canny Batch Shape:   {sample_canny.shape}")

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ---- Models -----------------------------------------------------------------
repo_id = "ziyang06315/latent_diffusion_from_scratch"
UNET_PATH = hf_hub_download(repo_id, filename="unet_epoch_500.pth")

# Disabled: VAE loading (not needed here, the dataset already yields latents).
# VAE_CHECKPOINT = hf_hub_download(repo_id, filename="checkpoint.pth")
#
# vae=VAE().to(DEVICE)
# checkpoint = torch.load(VAE_CHECKPOINT,map_location=DEVICE)
# vae.load_state_dict(checkpoint['vae_state_dict'])
# vae.eval()

unet = DiffusionUnet().to(DEVICE)
# map_location lets a checkpoint saved on a GPU load on a CPU-only host too.
unet.load_state_dict(torch.load(UNET_PATH, map_location=DEVICE))
unet.eval()
# The U-Net stays frozen; only the ControlNet is optimised.
for param in unet.parameters():
    param.requires_grad = False

controlnet = ControlNet(unet).to(DEVICE)
controlnet.train()

optimizer = AdamW(controlnet.parameters(),lr=1e-5)

# ---- Noise schedule (linear betas) --------------------------------------------
beta_start = 0.0001
beta_end = 0.02
num_timesteps = 1000

betas = torch.linspace(beta_start, beta_end, num_timesteps,device=DEVICE)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)

# ---- Training loop ------------------------------------------------------------
EPOCH = 100
for epoch in range(EPOCH):
    epoch_loss = 0.0
    progress_bar = tqdm(dataloader,desc=f"Epoch {epoch+1}/{EPOCH}")
    for latent,canny in progress_bar:
        latent = latent.to(DEVICE)
        canny = canny.to(DEVICE)

        # Forward diffusion: noise every latent at its own random timestep.
        noise = torch.randn_like(latent)
        t = torch.randint(0,num_timesteps,(latent.shape[0],),device=DEVICE)

        # Per-sample coefficients reshaped to (B,1,1,1) for broadcasting.
        sqrt_alpha_cumprod = sqrt_alphas_cumprod[t][:,None,None,None]
        sqrt_one_minus_alpha_cumprod = sqrt_one_minus_alphas_cumprod[t][:,None,None,None]

        noisy_latents = sqrt_alpha_cumprod * latent + sqrt_one_minus_alpha_cumprod * noise

        optimizer.zero_grad()
        # Gradients reach the ControlNet through the residuals added inside
        # the frozen U-Net.
        residuals = controlnet(noisy_latents,t,canny)
        noise_pred = unet(noisy_latents,t,residuals)

        loss = F.mse_loss(noise_pred,noise)
        # Read the scalar once; every .item() call forces a device sync.
        loss_value = loss.item()
        epoch_loss += loss_value
        progress_bar.set_postfix(loss=loss_value)

        loss.backward()
        optimizer.step()

    average_loss = epoch_loss / len(dataloader)
    print(f"Epoch {epoch+1} Average Loss: {average_loss:.6f}")

    # Checkpoint the ControlNet weights every 10 epochs.
    if (epoch+1) % 10 == 0:
        SAVE_PATH = os.path.join(SAVE_DIR,f"controlnet_epoch_{epoch+1}.pth")
        torch.save(controlnet.state_dict(),SAVE_PATH)
        print(f"Saved at {SAVE_PATH}")

print("Training Finished")


