In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from diffusers import StableDiffusionControlNetImg2ImgPipeline, ControlNetModel, UNet2DConditionModel
from diffusers import DDPMScheduler, AutoencoderKL
from diffusers.optimization import get_cosine_schedule_with_warmup
from transformers import CLIPTextModel, CLIPTokenizer
from PIL import Image
import numpy as np
import os
import glob

In [None]:
# Configuration
SKETCH_DIR = "../../data/google_image/input"
IMAGE_DIR = "../../data/google_image/target"
CONTROLNET_MODEL = "lllyasviel/sd-controlnet-canny"  # Pre-trained edge-based ControlNet
SD_MODEL = "runwayml/stable-diffusion-v1-5"
OUTPUT_DIR = "final_model_2"
BATCH_SIZE = 2                   # Adjust based on GPU memory
GRAD_ACCUM_STEPS = 2             # Accumulate gradients for larger effective batch size
NUM_EPOCHS = 3
LR = 1e-5
RESOLUTION = 512                 # SD 1.5 uses 512x512
MIXED_PRECISION = "fp16"         # Use "no" if GPU doesn't support fp16

In [None]:
class SketchDataset(Dataset):
    def __init__(self, sketch_dir, image_dir, tokenizer, resolution=512):
        self.sketch_paths = sorted(glob.glob(os.path.join(sketch_dir, "*.png")))
        self.image_paths = sorted(glob.glob(os.path.join(image_dir, "*.png")))
        self.tokenizer = tokenizer
        self.resolution = resolution
        self.prompt = "photorealistic render, realistic, photograph, outside"  # Fixed prompt
        
        # Verify dataset
        if len(self.sketch_paths) != len(self.image_paths):
            print(f"Warning: Sketch count ({len(self.sketch_paths)}) doesn't match image count ({len(self.image_paths)})")
        
    def __len__(self):
        return min(len(self.sketch_paths), len(self.image_paths))

    def __getitem__(self, idx):
        # Load sketch and ensure proper format
        sketch = Image.open(self.sketch_paths[idx])
        if sketch.mode != "RGB":
            sketch = sketch.convert("RGB")
        sketch = sketch.resize((self.resolution, self.resolution))
        
        # Load real image
        real_img = Image.open(self.image_paths[idx])
        if real_img.mode != "RGB":
            real_img = real_img.convert("RGB")
        real_img = real_img.resize((self.resolution, self.resolution))
        
        # Tokenize text
        inputs = self.tokenizer(
            self.prompt,
            max_length=self.tokenizer.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        
        # Convert to tensors and normalize
        sketch_tensor = (torch.tensor(np.array(sketch)).float().permute(2, 0, 1) / 255.0 * 2.0 - 1.0)
        real_tensor = (torch.tensor(np.array(real_img)).float().permute(2, 0, 1) / 127.5 - 1.0)
        
        return {
            "pixel_values": real_tensor,
            "conditioning_pixel_values": sketch_tensor,
            "input_ids": inputs.input_ids.squeeze(0)}

In [4]:
# Initialize models
controlnet = ControlNetModel.from_pretrained(CONTROLNET_MODEL)
tokenizer = CLIPTokenizer.from_pretrained(SD_MODEL, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(SD_MODEL, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(SD_MODEL, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(SD_MODEL, subfolder="unet")

In [None]:
# Freeze all models except ControlNet
for param in controlnet.parameters():
    param.requires_grad = True 
    
for param in unet.parameters():
    param.requires_grad = False
for param in text_encoder.parameters():
    param.requires_grad = False
for param in vae.parameters():
    param.requires_grad = False

In [6]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# Optimizer (only train ControlNet)
optimizer = torch.optim.AdamW(controlnet.parameters(), lr=LR, weight_decay=0.05)
dataset = SketchDataset(SKETCH_DIR, IMAGE_DIR, tokenizer, resolution=RESOLUTION)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

# Scheduler
noise_scheduler = DDPMScheduler.from_pretrained(SD_MODEL, subfolder="scheduler")

# Training setup
controlnet.to(device)
text_encoder.to(device)
vae.to(device)
unet.to(device)

UNet2DConditionModel(
  (conv_in): Conv2d(4, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (time_proj): Timesteps()
  (time_embedding): TimestepEmbedding(
    (linear_1): Linear(in_features=320, out_features=1280, bias=True)
    (act): SiLU()
    (linear_2): Linear(in_features=1280, out_features=1280, bias=True)
  )
  (down_blocks): ModuleList(
    (0): CrossAttnDownBlock2D(
      (attentions): ModuleList(
        (0-1): 2 x Transformer2DModel(
          (norm): GroupNorm(32, 320, eps=1e-06, affine=True)
          (proj_in): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
          (transformer_blocks): ModuleList(
            (0): BasicTransformerBlock(
              (norm1): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
              (attn1): Attention(
                (to_q): Linear(in_features=320, out_features=320, bias=False)
                (to_k): Linear(in_features=320, out_features=320, bias=False)
                (to_v): Linear(in_features=320, out_fe

In [7]:
# Mixed precision
scaler = torch.amp.GradScaler("cuda", enabled=MIXED_PRECISION == "fp16")

print(f"Starting training with {len(dataset)} examples")
print(f"Device: {device}, Batch size: {BATCH_SIZE}, Accum steps: {GRAD_ACCUM_STEPS}")
print(f"Using ControlNet: {CONTROLNET_MODEL}")

# Create output directory
os.makedirs(OUTPUT_DIR, exist_ok=True)

Starting training with 36 examples
Device: cuda, Batch size: 1, Accum steps: 4
Using ControlNet: lllyasviel/sd-controlnet-canny


In [None]:
# Training loop

global_step = 0
for epoch in range(NUM_EPOCHS):
    print(f"===== Epoch: {epoch}, Step: {global_step} =====")

    controlnet.train()
    total_loss = 0
    for step, batch in enumerate(dataloader):
        # Move batch to device
        pixel_values = batch["pixel_values"].to(device)
        conditioning_values = batch["conditioning_pixel_values"].to(device)
        input_ids = batch["input_ids"].to(device)

        # Convert images to latents
        with torch.no_grad():
            latents = vae.encode(pixel_values).latent_dist.sample() * 0.18215
            
            # Encode text
            encoder_hidden_states = text_encoder(input_ids)[0]

        # Sample noise and timesteps
        noise = torch.randn_like(latents)
        timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (latents.shape[0],), device=device)
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

        # ControlNet forward with mixed precision
        with torch.amp.autocast("cuda", enabled=MIXED_PRECISION == "fp16"):
            # Forward through ControlNet
            down_block_res_samples, mid_block_res_sample = controlnet(
                noisy_latents,
                timesteps,
                encoder_hidden_states=encoder_hidden_states,
                controlnet_cond=conditioning_values,
                return_dict=False,
            )
            
            # Predict noise residual
            model_pred = unet(
                noisy_latents,
                timesteps,
                encoder_hidden_states=encoder_hidden_states,
                down_block_additional_residuals=down_block_res_samples,
                mid_block_additional_residual=mid_block_res_sample,
            ).sample
        
            # Compute loss
            loss = torch.nn.functional.mse_loss(model_pred, noise, reduction="mean")
            loss = loss / GRAD_ACCUM_STEPS
            print("Loss:", loss.item())
        
        # Backpropagation with gradient accumulation
        scaler.scale(loss).backward()
        total_loss += loss.item()
        
        if (step + 1) % GRAD_ACCUM_STEPS == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            global_step += 1
            
            if global_step % 5 == 0:
                avg_loss = total_loss * GRAD_ACCUM_STEPS / 5
                print(f"Epoch {epoch}, Step {global_step}, Loss: {avg_loss:.4f}")
                total_loss = 0

===== Epoch: 0, Step: 0 =====


  deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)


Loss: 0.003892288077622652
Loss: 0.02731848508119583
Loss: 0.03729882836341858
Loss: 0.020688112825155258
Loss: 0.03698461502790451
Loss: 0.009142505936324596
Loss: 0.013230041600763798
Loss: 0.0006327779847197235
Loss: 0.004590954631567001
Loss: 0.004443539772182703
Loss: 0.1007956713438034
Loss: 0.017717715352773666
Loss: 0.0508512482047081
Loss: 0.0012459794525057077
Loss: 0.018883388489484787
Loss: 0.010083112865686417
Loss: 0.016805417835712433
Loss: 0.047780923545360565
Loss: 0.022824283689260483
Loss: 0.01697903871536255
Epoch 0, Step 5, Loss: 0.3698
Loss: 0.0636143758893013
Loss: 0.004050725139677525
Loss: 0.002037922851741314
Loss: 0.00931905023753643
Loss: 0.005027182400226593
Loss: 0.01643446832895279
Loss: 0.015170787461102009
Loss: 0.026057451963424683
Loss: 0.017387758940458298
Loss: 0.0304858535528183
Loss: 0.006382010877132416
Loss: 0.003982558846473694
Loss: 0.0008740065386518836
Loss: 0.04261345416307449
Loss: 0.024621307849884033
Loss: 0.08426712453365326
===== Epoch

In [None]:
# Save final model
controlnet.save_pretrained(os.path.join(OUTPUT_DIR, "final_model"))
print(f"Training complete! Final model saved to {OUTPUT_DIR}/final_model")

Training complete! Final model saved to final_model_2/final_model


: 

### Prediction

In [9]:
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
import torch
from PIL import Image

# Load fine-tuned ControlNet
controlnet = ControlNetModel.from_pretrained(
    "final_model/final_model",
    torch_dtype=torch.float16
)

pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    controlnet=controlnet,
    torch_dtype=torch.float16,
    safety_checker=None  # Disable if you get NSFW false positives
).to("cuda")

# Memory optimizations
pipe.enable_model_cpu_offload()

# Process sketch
sketch = Image.open("demo.png").convert("RGB").resize((512, 512))

# Generate image
image = pipe(
    "photorealistic render, realistic, photograph, outside angle, high quality, sharp details",
    image=sketch,
    num_inference_steps=30,
    guidance_scale=7.5,
    generator=torch.Generator(device="cuda").manual_seed(42)
).images[0]

image.save("output.png")

Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

You have disabled the safety checker for <class 'diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .


  0%|          | 0/30 [00:00<?, ?it/s]