In [2]:
import cv2
import os
from tqdm.notebook import tqdm
from pathlib import Path
from IPython.display import display, clear_output
from diffusers import StableDiffusion3ControlNetPipeline, SD3ControlNetModel
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import Compose, ToTensor, Resize
from transformers import CLIPTextModel, CLIPTokenizer

In [5]:
import kagglehub

dataset_path = kagglehub.dataset_download("jehanbhathena/weather-dataset")

In [6]:
image_paths = []
for root, dirs, files in os.walk(dataset_path):
    for file in files:
        if file.lower().endswith(('.jpg')):
            image_paths.append(os.path.join(root, file))

In [45]:
image_paths[3470]

'/root/.cache/kagglehub/datasets/jehanbhathena/weather-dataset/versions/3/dataset/fogsmog/4357.jpg'

In [7]:
def create_bw_dataset(image_paths, output_dir):
	clear_output()
	if not os.path.exists(output_dir):
		os.makedirs(output_dir)
	if not os.path.exists(Path(output_dir) / "train"):
		os.makedirs(Path(output_dir) / "train")
	if not os.path.exists(Path(output_dir) / "train_colored"):
		os.makedirs(Path(output_dir) / "train_colored")
	progress_bar = tqdm(total=len(image_paths), desc="Creating bw images", unit="image")
	for image_id in range(len(image_paths)):
		try:
			image = cv2.imread(image_paths[image_id])
			gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
		except Exception:
			print(f"Error reading image {image_paths[image_id]}")
			progress_bar.update(1)
			continue
		for x in range(len(gray)):
			for y in range(len(gray[x])):
				if gray[x][y] < 128:
					gray[x][y] = 0
				else:
					gray[x][y] = 255
		cv2.imwrite(Path(output_dir) / "train" / f"img{image_id}.jpg", gray)
		cv2.imwrite(Path(output_dir) / "train_colored" / f"img{image_id}.jpg", image)
		progress_bar.update(1)
	progress_bar.close()

In [8]:
create_bw_dataset(image_paths, "dataset")

Creating bw images:   0%|          | 0/6862 [00:00<?, ?image/s]



Error reading image /root/.cache/kagglehub/datasets/jehanbhathena/weather-dataset/versions/3/dataset/fogsmog/4514.jpg
Error reading image /root/.cache/kagglehub/datasets/jehanbhathena/weather-dataset/versions/3/dataset/snow/1187.jpg




In [9]:
class BWColorizationDatasetCV2(Dataset):
    def __init__(self, bw_dir, color_dir, image_size=512):
        self.bw_dir = bw_dir
        self.color_dir = color_dir
        self.image_size = image_size
        self.bw_images = sorted(os.listdir(bw_dir))
        self.color_images = sorted(os.listdir(color_dir))

    def __len__(self):
        return len(self.bw_images)

    def __getitem__(self, idx):
        bw_path = os.path.join(self.bw_dir, self.bw_images[idx])
        color_path = os.path.join(self.color_dir, self.color_images[idx])

        bw_img = cv2.imread(bw_path, cv2.IMREAD_GRAYSCALE)
        color_img = cv2.imread(color_path, cv2.IMREAD_COLOR)
        color_img = cv2.cvtColor(color_img, cv2.COLOR_BGR2RGB)

        bw_img = cv2.resize(bw_img, (self.image_size, self.image_size))
        color_img = cv2.resize(color_img, (self.image_size, self.image_size))

        # Normalize to [0,1]
        bw_img = bw_img.astype("float32") / 255.0
        color_img = color_img.astype("float32") / 255.0

        # Convert to tensor and scale to [-1,1]
        bw_img = torch.from_numpy(bw_img).unsqueeze(0)  # [1,H,W]
        bw_img = bw_img * 2 - 1

        color_img = torch.from_numpy(color_img).permute(2, 0, 1)  # [3,H,W]
        color_img = color_img * 2 - 1

        return bw_img, color_img

In [None]:
dataset = BWColorizationDatasetCV2("dataset/train", "dataset/train_colored")
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=2)

# Load models
controlnet = SD3ControlNetModel.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large-controlnet-canny",
    torch_dtype=torch.float16,
).to("cuda")

pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large",
    controlnet=controlnet,
    torch_dtype=torch.float16,
).to("cuda")

pipe.enable_model_cpu_offload()  # offload to CPU when not used, save VRAM

# Optimizer
optimizer = torch.optim.AdamW(controlnet.parameters(), lr=1e-5)

# Training loop params
num_epochs = 5
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs * len(dataloader))

controlnet.train()

clear_output()
print("Starting training...")

for epoch in range(num_epochs):
	epoch_loss = 0
	epoch_progress = tqdm(total=len(dataloader), desc="Images", unit="image")
	for bw_img, color_img in dataloader:
		bw_img = bw_img.to("cuda", dtype=torch.float16)
		color_img = color_img.to("cuda", dtype=torch.float16)

		# 1. Encode color_img to latents using pipe.vae
		latents = pipe.vae.encode(color_img).latent_dist.sample() * 0.18215  # scale latent
		latents = latents.to(torch.float16)

		# 2. Sample noise and timesteps
		noise = torch.randn_like(latents)
		timesteps = torch.randint(0, pipe.scheduler.config.num_train_timesteps, (latents.shape[0],), device=latents.device)

		# 3. Add noise to latents (forward diffusion)
		noisy_latents = pipe.scheduler.add_noise(latents, noise, timesteps)

		# 4. Get conditioning image (BW)
		control = bw_img

		# 5. Get text embeddings (you can customize prompt, here empty string)
		prompt = [""] * latents.shape[0]
		text_input = pipe.tokenizer(prompt, padding="max_length", max_length=pipe.tokenizer.model_max_length, truncation=True, return_tensors="pt")
		text_embeddings = pipe.text_encoder(text_input.input_ids.to("cuda"))[0]

		# 6. Forward pass: UNet predicts noise residual conditioned on text and controlnet conditioning image
		noise_pred = pipe.unet(noisy_latents, timesteps, encoder_hidden_states=text_embeddings, controlnet_cond=control).sample

		# 7. Compute loss = MSE between predicted noise and true noise
		loss = torch.nn.functional.mse_loss(noise_pred.float(), noise.float())

		optimizer.zero_grad()
		loss.backward()
		optimizer.step()
		scheduler.step()

		epoch_loss += loss.item()
        
		epoch_progress.update(1)
	
	epoch_progress.close()

	print(f"Epoch {epoch + 1}/{num_epochs} - Loss: {epoch_loss / len(dataloader):.6f}")

Loading pipeline components...:   0%|          | 0/9 [00:00<?, ?it/s]

You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
pipe.controlnet.save_pretrained("qworks-bw-colorizer-controlnet")