UNET MODEL IMPLEMENTATION


In [None]:
import torch ##using Pytorch
import torch.nn as nn
import torch.nn.functional as F

# Feature-wise Linear Modulation (FiLM) Layer
class FiLM(nn.Module):
    """Produce per-channel FiLM parameters from a conditioning vector.

    The module only *generates* the modulation pair; it does not apply it.
    Callers are expected to compute ``gamma * x + beta`` themselves.
    """

    def __init__(self, cond_dim, num_features):
        super().__init__()
        # One projection yields both halves: first the scales, then the shifts.
        self.film = nn.Linear(cond_dim, 2 * num_features)

    def forward(self, x, cond_vec):
        """Return ``(gamma, beta)``, each reshaped to ``(-1, x.size(1), 1, 1)``.

        ``x`` is consulted only for its channel count (dimension 1); its values
        are never read.
        """
        channels = x.size(1)
        projected = self.film(cond_vec)
        # Split along dim 1 and give each half trailing singleton dims so it
        # broadcasts over the two spatial axes of a feature map.
        gamma, beta = (
            half.view(-1, channels, 1, 1) for half in projected.chunk(2, dim=1)
        )
        return gamma, beta

class ConvBlock(nn.Module):
    """Two consecutive ``Conv2d(3x3) -> GroupNorm(8) -> [FiLM] -> SiLU`` stages.

    FiLM modulation is active only when ``use_film`` is truthy *and* a
    ``cond_dim`` was supplied; otherwise ``film1``/``film2`` are ``None``.
    """

    def __init__(self, in_ch, out_ch, cond_dim=None, use_film=True):
        super().__init__()
        self.use_film = use_film and (cond_dim is not None)

        # Submodules are created in a fixed order (convs, norms, FiLMs, act) so
        # parameter ordering and seeded initialisation stay reproducible.
        conv_kwargs = {"kernel_size": 3, "padding": 1}
        self.conv1 = nn.Conv2d(in_ch, out_ch, **conv_kwargs)
        self.conv2 = nn.Conv2d(out_ch, out_ch, **conv_kwargs)
        self.norm1 = nn.GroupNorm(num_groups=8, num_channels=out_ch)
        self.norm2 = nn.GroupNorm(num_groups=8, num_channels=out_ch)
        if self.use_film:
            self.film1 = FiLM(cond_dim, out_ch)
            self.film2 = FiLM(cond_dim, out_ch)
        else:
            self.film1 = None
            self.film2 = None
        self.act = nn.SiLU()

    def forward(self, x, cond_vec=None):
        """Run both stages; modulate after each norm when conditioning applies."""
        # Modulation needs both a FiLM-enabled block and an actual vector.
        modulate = self.use_film and cond_vec is not None
        stages = (
            (self.conv1, self.norm1, self.film1),
            (self.conv2, self.norm2, self.film2),
        )
        for conv, norm, film in stages:
            x = norm(conv(x))
            if modulate:
                scale, shift = film(x, cond_vec)
                x = scale * x + shift
            x = self.act(x)
        return x

# Full UNet with various conditions
class UNet(nn.Module):
    """Four-level U-Net in which every ConvBlock receives a conditioning vector.

    The encoder widens 64 -> 128 -> 256 -> 512 with 2x max-pooling between
    levels, the bottleneck has 1024 channels, and the decoder mirrors the
    encoder with transposed-conv upsampling and skip concatenation.

    Args:
        in_ch: Number of channels of the input tensor.
        out_ch: Number of channels produced by the final 1x1 convolution.
        cond_dim: Width of the embedded conditioning vector given to each block.
        cond_in_dim: Width of the raw condition fed to ``cond_emb``. Defaults
            to 9, the value that used to be hard-coded.
    """

    def __init__(self, in_ch=4, out_ch=4, cond_dim=128, cond_in_dim=9):
        super().__init__()
        # Creation order is kept stable so state_dict keys, parameter order and
        # seeded initialisation match earlier checkpoints/runs.
        self.enc1 = ConvBlock(in_ch, 64, cond_dim)
        self.enc2 = ConvBlock(64, 128, cond_dim)
        self.enc3 = ConvBlock(128, 256, cond_dim)
        self.enc4 = ConvBlock(256, 512, cond_dim)

        self.pool = nn.MaxPool2d(2)

        self.mid = ConvBlock(512, 1024, cond_dim)

        # Each decoder block sees skip + upsampled channels, hence 2x input width.
        self.up4 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.dec4 = ConvBlock(1024, 512, cond_dim)

        self.up3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.dec3 = ConvBlock(512, 256, cond_dim)

        self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec2 = ConvBlock(256, 128, cond_dim)

        self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec1 = ConvBlock(128, 64, cond_dim)

        self.out_conv = nn.Conv2d(64, out_ch, kernel_size=1)

        # Embeds the raw colour condition into the cond_dim-wide vector.
        self.cond_emb = nn.Linear(cond_in_dim, cond_dim)

    @staticmethod
    def _merge_skip(skip, upsampled):
        """Concatenate ``skip`` and ``upsampled`` on the channel axis.

        When the input height/width is not a multiple of 16, pooling rounds
        the size down and the 2x transposed conv returns a map smaller than
        the skip tensor. In that case the upsampled map is resized to the
        skip's spatial size first; matching sizes are left untouched.
        """
        if upsampled.shape[-2:] != skip.shape[-2:]:
            upsampled = F.interpolate(
                upsampled, size=skip.shape[-2:], mode="nearest"
            )
        return torch.cat([skip, upsampled], dim=1)

    def forward(self, x, color_idx=None, color_rgb=None):
        """Run the network on ``x`` under a colour condition.

        Args:
            x: Input feature map passed to the first ConvBlock.
            color_idx: Preferred condition tensor. Despite the name it is fed
                straight into ``cond_emb`` (a Linear layer), so it presumably
                must be a float tensor whose last dim is ``cond_in_dim`` —
                TODO confirm against the data pipeline.
            color_rgb: Fallback condition, used only when ``color_idx`` is None.

        Returns:
            Tensor produced by the final 1x1 convolution (``out_ch`` channels).

        Raises:
            ValueError: If both ``color_idx`` and ``color_rgb`` are None.
        """
        # color_idx wins whenever it is supplied; color_rgb is only a fallback.
        cond_input = color_idx if color_idx is not None else color_rgb
        if cond_input is None:
            raise ValueError(
                "UNet.forward requires color_idx or color_rgb; both were None"
            )
        cond_vec = self.cond_emb(cond_input)

        # Encoding path
        x1 = self.enc1(x, cond_vec)
        x2 = self.enc2(self.pool(x1), cond_vec)
        x3 = self.enc3(self.pool(x2), cond_vec)
        x4 = self.enc4(self.pool(x3), cond_vec)

        x_mid = self.mid(self.pool(x4), cond_vec)

        # Decoding path: upsample, merge with the matching encoder output, refine.
        x = self.dec4(self._merge_skip(x4, self.up4(x_mid)), cond_vec)
        x = self.dec3(self._merge_skip(x3, self.up3(x)), cond_vec)
        x = self.dec2(self._merge_skip(x2, self.up2(x)), cond_vec)
        x = self.dec1(self._merge_skip(x1, self.up1(x)), cond_vec)

        return self.out_conv(x)


TRAINING SCRIPT

In [None]:
# Training loop: for each epoch, optimise `model` over `train_loader` with a
# weighted L1 + MSE loss under (optional) CUDA autocast, and accumulate
# per-batch metrics that are averaged at the end of the epoch.
# `model`, `opt`, `scaler`, `device`, `train_loader`, `EPOCHS`, `COND_METHOD`
# and `compute_metrics` are all defined outside this section.

# Starts at +inf; not updated in the visible part of the script — presumably
# compared against a validation metric further down (TODO confirm).
best_val = float('inf')
# Presumably the checkpoint destination; no save call is visible in this section.
save_path = f"/content/cond_unet_{COND_METHOD}.pt"

for epoch in range(1, EPOCHS + 1):
    model.train()
    # Per-epoch running sums; divided by the batch count after the inner loop.
    running = {"loss": 0.0, "mse": 0.0, "l1": 0.0}

    for batch in train_loader:
        # Each batch is a mapping with "inp", "target", "color_idx", "color_rgb".
        img = batch["inp"].to(device)
        tgt = batch["target"].to(device)
        color_idx = batch["color_idx"].to(device)
        color_rgb = batch["color_rgb"].to(device)

        opt.zero_grad(set_to_none=True)
        # Autocast is active only when running on a CUDA device.
        with torch.cuda.amp.autocast(enabled=(device.type == "cuda")):
            # NOTE(review): both conditions are passed positionally. If `model`
            # is the UNet defined in this file, color_idx takes precedence
            # whenever it is not None, so color_rgb is never used — confirm
            # this is intended for the chosen COND_METHOD.
            pred = model(img, color_idx, color_rgb)
            # Objective: 0.7 * L1 + 0.3 * MSE.
            loss = F.l1_loss(pred, tgt) * 0.7 + F.mse_loss(pred, tgt) * 0.3
        # `scaler` is presumably a GradScaler (scale -> step -> update) — TODO confirm.
        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()

        # Metrics are computed on a detached prediction so no graph is retained.
        m = compute_metrics(pred.detach(), tgt)
        running["loss"] += loss.item()
        running["mse"]  += m["mse"]
        running["l1"]   += m["l1"]

    # Averaging metrics per epoch
    # NOTE(review): raises ZeroDivisionError if len(train_loader) == 0.
    n_batches = len(train_loader)
    train_log = {k: v / n_batches for k, v in running.items()}
