In [None]:
!pip install datasets

In [1]:
from google.colab import files
uploaded = files.upload()

In [None]:
import csv
import gc
import os
from matplotlib import pyplot as plt
import pandas as pd
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
from transformers import CLIPModel
from tqdm import tqdm
from data import dataloader
from diffusion_utils import load_latest_checkpoint, save_checkpoint
from torch import nn
import torch.nn.functional as F
from torch.amp import GradScaler

In [44]:
device = (
    torch.device("cuda")
    if torch.cuda.is_available()
    else torch.device("mps")
    if torch.backends.mps.is_available()
    else torch.device("cpu")
)

In [45]:
text_encoder = CLIPModel.from_pretrained("openai/clip-vit-base-patch32", torch_dtype=torch.float16).to(device)
model_id = "lllyasviel/control_v11p_sd15_seg"
controlnet = ControlNetModel.from_pretrained(model_id, torch_dtype=torch.float16)
controlnet.to(device)

ControlNetModel(
  (conv_in): Conv2d(4, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (time_proj): Timesteps()
  (time_embedding): TimestepEmbedding(
    (linear_1): Linear(in_features=320, out_features=1280, bias=True)
    (act): SiLU()
    (linear_2): Linear(in_features=1280, out_features=1280, bias=True)
  )
  (controlnet_cond_embedding): ControlNetConditioningEmbedding(
    (conv_in): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (blocks): ModuleList(
      (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): Conv2d(32, 96, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (4): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (5): Conv2d(96, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    )
    (conv_out): Conv2d(256, 320, ker

In [46]:
# print('model', controlnet)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
)
pipe.to(device)
vae = pipe.vae

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

In [47]:
gc.collect()

0

In [48]:
save_dir = "weights/controlnet/"
os.makedirs(save_dir, exist_ok=True)

In [49]:
loss_file_path = os.path.join(save_dir, "loss_val.csv")
if not os.path.exists(loss_file_path):
    with open(loss_file_path, mode="w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["epoch", "epoch_loss"])

In [50]:
# Define optimizer and loss
optim = torch.optim.AdamW(
    controlnet.parameters(), lr=1e-4, weight_decay=1e-2, betas=(0.9, 0.999)
)

In [51]:
# criterion = torch.nn.MSELoss()  # For pixel-wise tasks
criterion = nn.SmoothL1Loss()       # For better stability - showing Nan loss for MSE -- HuberLoss (SmoothL1: l1 + MSE loss)

In [81]:
nn_model, optim, start_epoch, loss = load_latest_checkpoint(
    controlnet, optim, save_dir, device=device
)


No checkpoints found, starting from scratch.


In [53]:
# Training loop
controlnet.train()
num_epochs = 32
timesteps = 500

In [54]:
scaler = GradScaler("cuda")
torch.cuda.empty_cache()
gc.collect()

29

In [77]:
def initialize_weights(model):
    for m in model.modules():
        if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
            nn.init.kaiming_normal_(m.weight, mode="fan_in", nonlinearity="relu")
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)

In [82]:
def train_model(nn_model, data_loader, start_epoch, n_epoch):
    upsample_block = nn.Sequential(
        # Upsample from [4, 1280, 4, 4] to [4, 1280, 256, 256] using F.interpolate
        nn.Conv2d(
            1280, 640, kernel_size=3, padding=1, stride=1
        ),  # Reduce channels from 1280 to 640
        nn.BatchNorm2d(640),
        nn.ReLU(), # Depthwise separable convolution: reduces channels and memory usage
        nn.Conv2d(
            640, 320, kernel_size=3, padding=1, stride=1
        ), # Reduce channels to 320
        nn.BatchNorm2d(320),
        nn.ReLU(),
        nn.Conv2d(320, 3, kernel_size=3, padding=1, stride=1)  # Convert to 3 channels (RGB)
    ).to(device).to(dtype=torch.float16)
    initialize_weights(upsample_block)

    for name, param in controlnet.named_parameters():
        if "mid_block" in name:
            # print('params in--inside mid block')
            param.requires_grad = True
        else:
            param.requires_grad = False

    for name, module in nn_model.named_modules():
        if 'mid_block' in name:
            # print('module--inside mid block')
            # print(f"Layer name: {name}, Layer type: {type(module)}")
            module.requires_grad = True  # Fine-tune higher layers
        else:
            module.requires_grad = False  # Freeze other layers

    for ep in range(start_epoch, num_epochs):
        epoch_loss = 0.0
        accumulation_steps = 4

        pbar = tqdm(data_loader, mininterval=2)
        for i, batch in enumerate(pbar):
            optim.zero_grad(set_to_none=True)
            images, masks, text_emb = batch
            print("masks shape---before", masks.shape)

            masks = masks.repeat(1, 3, 1, 1)    # 1 channel to 3 channel conversion
            print('images shaepe', images.shape)
            print('masks shape---after', masks.shape)
            print('text emb shapessss-------', text_emb.shape)
            images, masks, text_emb = (
                images.to(device).to(dtype=torch.float16),
                masks.to(device).to(dtype=torch.float16),
                text_emb.to(device).to(dtype=torch.float16),
            )
            with torch.autocast(device_type='cuda'):
                text_emb_resized = nn.Linear(512, 768).to(device)(
                    text_emb
                )  # Resize to match 768 features
                t = torch.randint(1, timesteps + 1, (images.shape[0],)).to(device)

                latents = vae.encode(images).latent_dist.sample().to(device)

                # Forward pass
                out_model = nn_model(
                    sample=latents,
                    timestep=t,
                    encoder_hidden_states=text_emb_resized,
                    controlnet_cond=masks,  # Segmentation masks as conditioning
                )
                # print("training after", dir(out_model), type(out_model))
                # print(out_model.down_block_res_samples[0].shape)  # Check if this contains the image
                print(out_model.mid_block_res_sample.shape)
                generated_image = out_model.mid_block_res_sample

                generated_image = upsample_block(generated_image)
                print("gen image------", generated_image.shape)
                generated_image_resized = F.interpolate(
                    generated_image,
                    size=(320, 320),
                    mode="bilinear",
                    align_corners=False,
                )
                print("gen image------", generated_image_resized.shape)
                print('image orig', images.shape)
                loss = criterion(generated_image_resized, images)   # F.mse_loss
                print(f"Epoch {ep+1}/{num_epochs}, Loss: {loss.item()}")

            epoch_loss += loss.item()
            # model_engine.backward(loss)
            # model_engine.step()
            # scaler.scale(loss).backward()
            loss.backward()
            for name, param in nn_model.named_parameters():
              if param.requires_grad and param.grad is None:
                print(f"NaN gradient detected in parameter {name}")

            if (i + 1) % accumulation_steps == 0 or i == len(pbar):
                # scaler.unscale_(optim)
                torch.nn.utils.clip_grad_norm_(upsample_block.parameters(), max_norm=1.0)
                optim.step()
                # scaler.step(optim)
                # scaler.update()
                optim.zero_grad(set_to_none=True)

            del (
                images,
                masks,
                text_emb,
                loss,
                t,
                generated_image,
                generated_image_resized,
            )
            torch.cuda.empty_cache()

        # Calculate and log average loss for the epoch
        avg_loss = epoch_loss / len(dataloader)

        with open(loss_file_path, mode="a", newline="") as f:
            writer = csv.writer(f)
            writer.writerow([ep + 1, avg_loss])

        if ep % 4 == 0 or ep == int(n_epoch - 1):
            save_checkpoint(nn_model, optim, ep, epoch_loss, save_dir)
            print("saved model at " + save_dir + f"model_{ep}.pth")

    # Plot losses
    data = pd.read_csv(loss_file_path)
    plt.figure(figsize=(8, 6))
    plt.plot(
        data["epoch"],
        data["epoch_loss"],
        marker="o",
        linestyle="-",
        color="b",
        label="Training Loss",
    )
    plt.title("Training Loss Over Epochs")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.show()

In [83]:
def main():
    gc.collect()
    torch.cuda.empty_cache()
    train_model(nn_model=nn_model, data_loader=dataloader, start_epoch=start_epoch, n_epoch=num_epochs)

In [84]:
if __name__ == "__main__":
    main()

  0%|          | 0/1499 [00:00<?, ?it/s]

-----label name b'mourning_warbler'
-----label name b'yellow_breasted_chat'
-----label name b'house_sparrow'
-----label name b'american_three_toed_woodpecker'
-----label name b'western_gull'
-----label name b'loggerhead_shrike'
-----label name b'caspian_tern'
-----label name b'caspian_tern'
-----label name b'ring_billed_gull'-----label name
 b'horned_lark'
-----label name masks shape---before torch.Size([4, 1, 320, 320])
images shaepe torch.Size([4, 3, 320, 320])
masks shape---after torch.Size([4, 3, 320, 320])
text emb shapessss------- torch.Size([4, 1, 512])
b'ovenbird'-----label name
 b'cactus_wren'
-----label name -----label name b'cape_glossy_starling'b'clay_colored_sparrow'

torch.Size([4, 1280, 5, 5])
gen image------ torch.Size([4, 3, 5, 5])
gen image------ torch.Size([4, 3, 320, 320])
image orig torch.Size([4, 3, 320, 320])
Epoch 1/32, Loss: nan
-----label name -----label nameb'horned_puffin' b'red_breasted_merganser'

-----label name b'western_meadowlark'
-----label name b'eur

  0%|          | 2/1499 [00:02<27:37,  1.11s/it]


-----label name b'caspian_tern'
-----label name b'summer_tanager'
masks shape---before torch.Size([4, 1, 320, 320])
images shaepe torch.Size([4, 3, 320, 320])
masks shape---after torch.Size([4, 3, 320, 320])
text emb shapessss------- torch.Size([4, 1, 512])
torch.Size([4, 1280, 5, 5])
gen image------ torch.Size([4, 3, 5, 5])
gen image------ torch.Size([4, 3, 320, 320])
image orig torch.Size([4, 3, 320, 320])
-----label name b'palm_warbler'
-----label name b'mourning_warbler'
Epoch 1/32, Loss: nan
-----label name b'scott_oriole'
-----label name b'grasshopper_sparrow'
masks shape---before torch.Size([4, 1, 320, 320])
images shaepe torch.Size([4, 3, 320, 320])
masks shape---after torch.Size([4, 3, 320, 320])
text emb shapessss------- torch.Size([4, 1, 512])
torch.Size([4, 1280, 5, 5])
gen image------ torch.Size([4, 3, 5, 5])
gen image------ torch.Size([4, 3, 320, 320])
image orig torch.Size([4, 3, 320, 320])
-----label name b'green_jay'
-----label name b'common_tern'
Epoch 1/32, Loss: na

  0%|          | 5/1499 [00:04<24:45,  1.01it/s]


KeyboardInterrupt: 