In [None]:
# Mount Google Drive into the Colab filesystem at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!time cp -r /content/drive/MyDrive/Dissertation/Dataset/Multiple_style3.zip /content/
!unzip -q /content/Multiple_style3.zip -d /content/

In [None]:
# Print the GPU(s) visible to this runtime.
!nvidia-smi

In [None]:
!pip install opencv-python
!pip install einops
!git clone https://github.com/lllyasviel/ControlNet-v1-1-nightly.git
!pip install --upgrade basicsr torchvision

In [None]:
import sys

# Make the cloned ControlNet repo importable (it provides `annotator.midas`).
# The membership check keeps sys.path free of duplicates when the cell is re-run.
CONTROLNET_REPO_DIR = "/content/ControlNet-v1-1-nightly"
if CONTROLNET_REPO_DIR not in sys.path:
    sys.path.append(CONTROLNET_REPO_DIR)

# Dataset

In [None]:
import os

import cv2
import numpy as np
import torch
from annotator.midas import MidasDetector
from einops import rearrange
from PIL import Image
from torch.utils.data import Dataset
from transformers import CLIPTextModel, CLIPTokenizer

class SketchDataset(Dataset):
    """Paired photo/sketch dataset that also yields Canny and depth condition maps.

    ``root_dir`` must contain three folders:

    * ``Source``  - photos named ``<id>_left.png`` (``<id>_left.jpg`` is also accepted)
    * ``Sketch``  - target sketches named ``<id>_right.png``
    * ``Caption`` - UTF-8 prompts named ``<id>.txt``

    ``<id>`` is the part of the source file name before the first underscore.

    Args:
        root_dir: Dataset root folder.
        tokenizer: Callable invoked as ``tokenizer(prompt, padding="max_length",
            max_length=77, truncation=True, return_tensors="pt")``; the result must
            expose ``input_ids`` and ``attention_mask``.
        image_processor: Callable applied to the resized RGB sketch (PIL image).
        size: Side length (pixels) every image is resized to.
        default_lambda: Loss-mixing weight returned when the style parsed from the
            caption is not a key of ``style2lambda`` (including unparseable captions).
    """

    # Label assigned when no "a <style> sketch" pattern is found in a caption.
    UNKNOWN_STYLE = "UnKnown"

    def __init__(self, root_dir, tokenizer, image_processor, size=512, default_lambda=0.5):
        self.Source_dir = os.path.join(root_dir, "Source")
        self.Sketch_dir = os.path.join(root_dir, "Sketch")
        self.Caption_dir = os.path.join(root_dir, "Caption")

        # One sample id per source image, sorted for a stable sample order.
        self.filenames = sorted(
            fname.split('_')[0] for fname in os.listdir(self.Source_dir)
            if fname.endswith(('.jpg', '.png'))
        )

        self.tokenizer = tokenizer
        self.image_processor = image_processor
        self.size = size
        self.default_lambda = default_lambda
        self.MidasDetector = MidasDetector()
        self.style2lambda = {
            "pencil": 0.9,
            "architectural line drawing": 0.5,
            "anime lineart": 0.3
        }

    def __len__(self):
        """Return the number of samples (one per listed source image)."""
        return len(self.filenames)

    @staticmethod
    def _parse_style(prompt):
        """Return the text between the first "a " and the following " sketch".

        Matching is case-insensitive; the returned slice keeps the caption's own
        casing. ``UNKNOWN_STYLE`` is returned when the pattern is absent.
        """
        lowered = prompt.lower()
        try:
            start = lowered.index("a ") + 2
            end = lowered.index(" sketch", start)
        except ValueError:
            return SketchDataset.UNKNOWN_STYLE
        return prompt[start:end].strip()

    def _style_weight(self, style):
        """Look up the loss-mixing weight for ``style`` (case-insensitive).

        Falls back to ``default_lambda`` instead of raising ``KeyError`` when the
        style is unknown.
        """
        return self.style2lambda.get(style.lower(), self.default_lambda)

    def _source_path(self, stem):
        """Resolve the source photo for ``stem``, preferring .png over .jpg."""
        for ext in (".png", ".jpg"):
            candidate = os.path.join(self.Source_dir, stem + "_left" + ext)
            if os.path.exists(candidate):
                return candidate
        # Nothing found: return the .png path so Image.open reports the missing file.
        return os.path.join(self.Source_dir, stem + "_left.png")

    def _load_rgb(self, path):
        """Open ``path``, convert it to RGB and resize it to ``size`` x ``size``."""
        # The context manager closes the underlying file once the pixels are copied.
        with Image.open(path) as image:
            return image.convert("RGB").resize((self.size, self.size))

    @staticmethod
    def _to_chw_tensor(array):
        """Convert an (h, w, c) array to a float (c, h, w) tensor divided by 255."""
        tensor = torch.from_numpy(array).float() / 255.0
        return rearrange(tensor, "h w c ->  c h w")

    def __getitem__(self, idx):
        """Build one training sample.

        Returns:
            dict with keys ``Canny_image`` and ``Depth_image`` (float tensors,
            channel-first, divided by 255), ``Sketch_image`` (output of
            ``image_processor``), ``Caption_ids`` and ``attention_mask`` (tokenizer
            outputs with the leading dimension squeezed), ``Style_lambda`` (float)
            and ``Style_name`` (str).
        """
        fname = self.filenames[idx]

        # Load image files
        Source_image = self._load_rgb(self._source_path(fname))
        Sketch_image = self._load_rgb(os.path.join(self.Sketch_dir, fname + "_right.png"))
        source_array = np.array(Source_image)

        # Derive Canny map (edge map replicated to three channels)
        Canny_image = cv2.Canny(source_array, 100, 200)
        Canny_image = cv2.cvtColor(Canny_image, cv2.COLOR_GRAY2RGB)
        Canny_image = self._to_chw_tensor(Canny_image)

        # Derive Depth map; a 2-D result is replicated to three channels
        Depth_image = self.MidasDetector(source_array)
        if Depth_image.ndim == 2:
            Depth_image = np.repeat(np.expand_dims(Depth_image, axis=2), 3, axis=2)
        Depth_image = self._to_chw_tensor(Depth_image)

        # Open text
        with open(os.path.join(self.Caption_dir, fname + ".txt"), 'r', encoding='utf-8') as f:
            prompt = f.read().strip()

        # Derive style caption
        style = self._parse_style(prompt)

        # Get the text token
        Caption = self.tokenizer(prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt")

        return {
            "Canny_image": Canny_image,
            "Depth_image": Depth_image,
            "Sketch_image": self.image_processor(Sketch_image),
            "Caption_ids": Caption.input_ids.squeeze(0),
            "attention_mask": Caption.attention_mask.squeeze(0),
            "Style_lambda": self._style_weight(style),
            "Style_name": style
        }

# Avg meter

In [None]:
class AverageMeter(object):
  """Keeps the most recent value plus a running sum, count and mean."""

  def __init__(self):
    self.reset()

  def reset(self):
    """Set every tracked statistic back to zero."""
    self.val = self.avg = self.sum = self.count = 0

  def update(self, val, n=1):
    """Record ``val`` as the latest value, counted as ``n`` observations."""
    self.val = val
    weighted = val * n
    self.sum += weighted
    self.count += n
    self.avg = self.sum / self.count

# Loss Function

In [None]:
import torch
import torch
import torch.nn as nn
import torch.nn.functional as F

# Huber_loss
def huber_loss(input, target, delta=1.0, reduction='mean'):
    """Huber loss between ``input`` and ``target``.

    Each element contributes ``0.5 * q**2 + delta * (|d| - q)`` where ``d`` is the
    difference and ``q`` is ``|d|`` capped at ``delta``.

    Args:
        input: Predicted tensor.
        target: Reference tensor.
        delta: Cap applied to the absolute difference for the quadratic part.
        reduction: ``'mean'`` or ``'sum'``; any other value returns the
            unreduced element-wise loss.
    """
    residual = (input - target).abs()
    capped = residual.clamp(max=delta)      # quadratic part
    overflow = residual - capped            # linear part beyond delta
    per_element = 0.5 * capped**2 + delta * overflow

    if reduction == 'sum':
        return per_element.sum()
    if reduction == 'mean':
        return per_element.mean()
    return per_element


class SketchLoss(nn.Module):
    """Weighted sum of MSE, L1 and Huber losses between predicted and true noise.

    A term is evaluated only when its weight is greater than zero.

    Args:
        latent_weight: Weight of the MSE term.
        latent_l1_weight: Weight of the L1 (MAE) term.
        huber_weight: Weight of the Huber term.
        huber_delta: Value passed to ``huber_loss`` as ``delta``.
    """

    def __init__(
        self,
        latent_weight=1.0,        # MSE loss weight
        latent_l1_weight=0.0,      # L1 loss weight
        huber_weight=0.0,         # Huber loss weight
        huber_delta=1.0,         # Huber loss threshold
    ):
        super().__init__()
        self.latent_weight = latent_weight
        self.latent_l1_weight = latent_l1_weight
        self.huber_weight = huber_weight
        self.huber_delta = huber_delta

    def forward(
        self,
        noise_pred,
        noise,
    ):
        """Return the weighted total loss as a tensor.

        Args:
            noise_pred: Predicted noise tensor.
            noise: Target noise tensor.
        """
        # Start from a tensor (same dtype/device as noise_pred) rather than a Python
        # float, so the result is a tensor even when every weight is zero.
        total = noise_pred.new_zeros(())

        # MSE Loss
        if self.latent_weight > 0:
            total = total + self.latent_weight * F.mse_loss(noise_pred, noise)

        # MAE Loss
        if self.latent_l1_weight > 0:
            total = total + self.latent_l1_weight * F.l1_loss(noise_pred, noise)

        # Huber Loss
        if self.huber_weight > 0:
            total = total + self.huber_weight * huber_loss(noise_pred, noise, delta=self.huber_delta)

        return total


# Set seed and Layer select

In [None]:
import random, numpy as np, torch

def set_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs and return a seeded ``torch.Generator``.

    Also switches cuDNN to deterministic mode and disables its benchmark mode.

    Args:
        seed: Integer seed applied to every RNG.

    Returns:
        A ``torch.Generator`` seeded with ``seed``; it lives on CUDA when a GPU is
        available and on the CPU otherwise.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Fall back to a CPU generator so the helper also works on machines without a GPU.
    generator_device = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.Generator(device=generator_device).manual_seed(seed)

# Name fragments of the attention projection layers that are selected.
_ATTENTION_PROJECTIONS = ("to_q", "to_k", "to_v", "to_out.0")


def _collect_attention_linear_modules(unet_model, block_key):
    """Return names of ``torch.nn.Linear`` attention projections under ``block_key``.

    A module is selected when its dotted name contains ``block_key`` and one of
    ``_ATTENTION_PROJECTIONS``, and it is a ``torch.nn.Linear``. Names are returned
    in ``named_modules()`` order.
    """
    return [
        name
        for name, module in unet_model.named_modules()
        if block_key in name
        and isinstance(module, torch.nn.Linear)
        and any(kw in name for kw in _ATTENTION_PROJECTIONS)
    ]


# Select mid layer
def get_mid_block_linear_modules(unet_model):
    """Names of the attention projection Linear layers whose name contains "mid_block"."""
    return _collect_attention_linear_modules(unet_model, "mid_block")


# Select up-sampling layer
def get_up_blocks_linear_modules(unet_model):
    """Names of the attention projection Linear layers whose name contains "up_blocks"."""
    return _collect_attention_linear_modules(unet_model, "up_blocks")


# Select down-sampling layer
def get_down_blocks_linear_modules(unet_model):
    """Names of the attention projection Linear layers whose name contains "down_blocks"."""
    return _collect_attention_linear_modules(unet_model, "down_blocks")

# Training pipeline

In [None]:
from pickle import decode_long
from io import SEEK_SET
import torch
from torch.utils.data import DataLoader
from diffusers import StableDiffusionControlNetPipeline, UNet2DConditionModel, AutoencoderKL, DDPMScheduler, ControlNetModel
from transformers import CLIPTextModel, CLIPTokenizer
from peft import get_peft_model, LoraConfig, TaskType
from torchvision import transforms
from tqdm.notebook import tqdm
from torch.utils.tensorboard import SummaryWriter
from PIL import Image
import numpy as np
import cv2
import matplotlib.pyplot as plt


# Configuration
model_id = "runwayml/stable-diffusion-v1-5"                 # Pretrained Stable Diffusion model id (tokenizer, text encoder, VAE, UNet, scheduler)
controlmodel_id_1 = "lllyasviel/control_v11p_sd15_canny"    # Pretrained ControlNet model id (for Canny map)
controlmodel_id_2 = "lllyasviel/control_v11f1p_sd15_depth"  # Pretrained ControlNet model id (for depth map)
# NOTE(review): the archive is extracted under /content/ in the first cells -
# confirm that this path points at the extracted dataset root.
dataset_path = "/Multiple_style3"   # Dataset path
# No trailing space: the inference cell loads the adapter from "/LoRA", so the
# directory written by training must have exactly the same name.
output_dir = "/LoRA"                # Location to save the LoRA Matrix
device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size = 1      # Batch size
lr = 5e-5           # Learning rate
max_steps = 5000    # Training step
cond_scale_1 = 0.5  # Canny map control intensity
cond_scale_2 = 1.0  # Depth map control intensity
seed = 5711

# Load Pretrained Model
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder").to(device)
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae").to(device)
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet").to(device)
controlnet1 = ControlNetModel.from_pretrained(controlmodel_id_1).to(device)
controlnet2 = ControlNetModel.from_pretrained(controlmodel_id_2).to(device)

# Only the LoRA adapters injected into the UNet are optimized. Freezing the other
# networks stops backward() from computing and storing gradients for their weights.
for frozen_model in (text_encoder, vae, controlnet1, controlnet2):
    frozen_model.requires_grad_(False)

noise_scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")
set_seed(seed)


# LoRA Configuration
target_location = get_mid_block_linear_modules(unet)+get_up_blocks_linear_modules(unet) #Inject layer select

lora_config = LoraConfig(
    r=16,
    lora_alpha=16,
    target_modules=target_location, # Modified target modules, for all layer using ["to_q","to_k","to_v","to_out.0"]
    lora_dropout=0.1,
    bias="none"
)
unet = get_peft_model(unet, lora_config)

# Data pre-processing: PIL image -> tensor, then normalize with mean 0.5 / std 0.5
image_processor = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

# Load data
dataset = SketchDataset(dataset_path, tokenizer, image_processor)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Optimizer: hand over only the parameters that still require gradients.
trainable_params = [param for param in unet.parameters() if param.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=lr)

# Training loop
step = 0
unet.train()
controlnet1.eval()
controlnet2.eval()
text_encoder.eval()
vae.eval()
train_loss = AverageMeter()
writer = SummaryWriter("runs/sketch_lora")
for epoch in range(50):    # training epoch
    # The inner `break` only leaves the batch loop. Without this check every
    # remaining epoch would still build a new iterator and preprocess one batch.
    if step >= max_steps:
        break

    for batch in tqdm(dataloader, desc=f"Epoch{epoch+1}",leave=False):
        if step >= max_steps:
            break

        Cond_image1 = batch["Canny_image"].to(device)
        Cond_image2 = batch["Depth_image"].to(device)
        Target_image = batch["Sketch_image"].to(device)
        Input_ids = batch["Caption_ids"].to(device)
        style_lambda = batch["Style_lambda"].to(device)

        # The VAE, text encoder and ControlNets are not optimized here, so their
        # forward passes run without building an autograd graph.
        with torch.no_grad():
            # Encode to latent space
            latents = vae.encode(Target_image).latent_dist.sample() * 0.18215

            # Add noise
            noise = torch.randn_like(latents)
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],), device=device).long()
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # Get text embedding
            text_emb = text_encoder(Input_ids)[0]

            # ControlNet residuals for the Canny and depth conditions
            down_block_res_samples_1, mid_block_res_sample_1 = controlnet1(
                sample=noisy_latents,
                timestep=timesteps,
                encoder_hidden_states=text_emb,
                controlnet_cond=Cond_image1,
                return_dict=False
            )
            down_block_res_samples_2, mid_block_res_sample_2 = controlnet2(
                sample=noisy_latents,
                timestep=timesteps,
                encoder_hidden_states=text_emb,
                controlnet_cond=Cond_image2,
                return_dict=False
            )

            # Blend both ControlNets with their control intensities
            down_block_res_samples = [
                x1 * cond_scale_1 + x2 * cond_scale_2
                for x1, x2 in zip(down_block_res_samples_1, down_block_res_samples_2)
            ]
            mid_block_res_sample_1 = mid_block_res_sample_1 * cond_scale_1
            mid_block_res_sample_2 = mid_block_res_sample_2 * cond_scale_2

        # Predict noise (gradients are tracked from here on)
        noise_pred = unet(
            sample=noisy_latents,
            timestep=timesteps,
            encoder_hidden_states=text_emb,
            down_block_additional_residuals=down_block_res_samples,
            mid_block_additional_residual=mid_block_res_sample_1+mid_block_res_sample_2,
            ).sample

        # Loss Configuration: style_lambda weights the MSE term, 1 - style_lambda the L1 term
        loss_fn = SketchLoss(
            latent_weight=style_lambda,
            latent_l1_weight=1-style_lambda,
            huber_weight=0,
            )
        loss = loss_fn(
            noise_pred=noise_pred,
            noise=noise
            )

        # Backpropagation
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()
        train_loss.update(loss.item())

        # Every 75 steps: log the current loss and the average since the last log, then restart the average
        if step % 75 == 0:
            print(f"Step {step}, Loss: {loss.item():.4f}, Avg_Loss:{train_loss.avg}")
            writer.add_scalar("Avg_Loss/train", train_loss.avg, step)
            writer.add_scalar("Loss/train", loss.item(), step)
            train_loss.reset()

        step += 1

# Save model
writer.close()
unet.save_pretrained(output_dir)
print(" Model training completed，saved in:", output_dir)

# Inference pipeline

In [None]:
from io import SEEK_SET
import torch
from diffusers import StableDiffusionControlNetPipeline, UNet2DConditionModel, AutoencoderKL, DDPMScheduler, ControlNetModel
from transformers import CLIPTextModel, CLIPTokenizer
from peft import PeftModel,PeftConfig
from PIL import Image
from torchvision import transforms
from torchvision.utils import save_image
import numpy as np
import cv2
import matplotlib.pyplot as plt
from einops import rearrange
from annotator.midas import MidasDetector


device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "runwayml/stable-diffusion-v1-5"                  # pretrained Stable Diffusion model id
controlmodel_id_1 = "lllyasviel/control_v11p_sd15_canny"            # pretrained controlnet1 id (for canny map)
controlmodel_id_2 = "lllyasviel/control_v11f1p_sd15_depth"           # pretrained controlnet2 id (for depth map)
lora_path = "/LoRA"            # LoRA path
source_image_path = "/Soure_img.png"    # Source Image path
result_path = "/output.png"      # Generation result saving path
prompt = "a pencil sketch of a room"  # Prompt
num_inference_steps = 200  # denoise step
cond_scale_1 = 0.5   # Canny map control intensity
cond_scale_2 = 1.0   # Depth map control intensity
seed = 1589          # only used by the optional seeded-latents lines below

# Load pretrained model
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder").to(device)
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae").to(device)
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet").to(device)
controlnet1 = ControlNetModel.from_pretrained(controlmodel_id_1).to(device)
controlnet2 = ControlNetModel.from_pretrained(controlmodel_id_2).to(device)
noise_scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")

# Inject trained LoRA Matrix
unet = PeftModel.from_pretrained(unet, lora_path).to(device)

# Input Image and Prompt
token = tokenizer(prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt")
# Inference only: no autograd graph is needed for the prompt embedding.
with torch.no_grad():
    prompt_emb = text_encoder(token.input_ids.to(device))[0].to(device)

source_image = Image.open(source_image_path).convert("RGB").resize((512, 512))
source_image = np.array(source_image)


# Derive Canny map
canny_image = cv2.Canny(source_image, 100, 200)
canny_image = cv2.cvtColor(canny_image, cv2.COLOR_GRAY2RGB)
canny_image = torch.from_numpy(canny_image).float() / 255.0
canny_image = rearrange(canny_image, 'h w c -> 1 c h w').to(device)

# Derive Depth map
mida = MidasDetector()
# Keep the detector's NumPy output: the preview image at the end is built from it.
depth_map = mida(source_image)
depth_rgb = depth_map
if depth_rgb.ndim == 2:
    depth_rgb = np.expand_dims(depth_rgb, axis=2)
    depth_rgb = np.repeat(depth_rgb, 3, axis=2)
depth_image = torch.from_numpy(depth_rgb).float() / 255.0
depth_image = rearrange(depth_image, "h w c -> 1 c h w").to(device)


# Inference
text_encoder.eval()
vae.eval()
unet.eval()
controlnet1.eval()
controlnet2.eval()
latents = torch.randn((1,4,64,64),device=device)
# Optional reproducible start (uses set_seed from the training cells):
# gen = set_seed(seed)
# latents=torch.randn((1,4,64,64),generator=gen,device=device)
noise_scheduler.set_timesteps(num_inference_steps)


with torch.no_grad():
    for t in noise_scheduler.timesteps:
        down_block_res_samples_1, mid_block_res_sample_1 = controlnet1(
            sample=latents,
            timestep=t,
            encoder_hidden_states=prompt_emb,
            controlnet_cond=canny_image,
            return_dict=False
        )

        down_block_res_samples_2, mid_block_res_sample_2 = controlnet2(
            sample=latents,
            timestep=t,
            encoder_hidden_states=prompt_emb,
            controlnet_cond=depth_image,
            return_dict=False
        )

        # Blend both ControlNets with their control intensities
        down_block_res_samples = [
            x1 * cond_scale_1 + x2 * cond_scale_2
            for x1, x2 in zip(down_block_res_samples_1, down_block_res_samples_2)
        ]
        mid_block_res_sample_1 = mid_block_res_sample_1 * cond_scale_1
        mid_block_res_sample_2 = mid_block_res_sample_2 * cond_scale_2
        mid_block_res_sample = mid_block_res_sample_1 + mid_block_res_sample_2

        noise_inference_pred = unet(
            sample=latents,
            timestep=t,
            encoder_hidden_states=prompt_emb,
            down_block_additional_residuals=down_block_res_samples,
            mid_block_additional_residual=mid_block_res_sample,
        ).sample

        latents = noise_scheduler.step(noise_inference_pred, t, latents).prev_sample

    # Post processing: decode, convert to grayscale, brighten by 0.3, replicate to 3 channels
    image = vae.decode(latents/0.18215).sample
    r,g,b = image[:,0:1],image[:,1:2],image[:,2:3]
    gray = 0.2989*r + 0.5870*g + 0.1140*b
    enhance = gray + 0.3
    image_en = enhance.repeat(1,3,1,1)
    save_image((image_en+1)/2,result_path) # save enhanced image

# Depth preview: built from the detector's NumPy output, because `depth_image` is
# by now a 4-D float tensor that Image.fromarray cannot turn into an 'L' image.
depth_preview = depth_map if depth_map.ndim == 2 else depth_map[..., 0]
Image.fromarray(depth_preview.astype(np.uint8)).save("/content/depth_preview.png")


# Loss Curve

In [None]:
# Launch the TensorBoard visualization: load (or reload) the notebook extension
%reload_ext tensorboard
# Read the logs stored under runs/
%tensorboard --logdir runs/