# Ayna ML Assignment – Polygon Colorization with UNet
**Author:** Prithviraj Verma  
**Task:** Train a UNet model to generate an RGB image of a colored polygon given a grayscale polygon image and a color name.

---


# Install & Import Dependencies

In [None]:
!pip install -q wandb

import wandb
import os
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt


In [None]:
# Authenticate this session with Weights & Biases so the training run below can log metrics.
wandb.login()


# Color Map + Dataset Class

In [None]:
# Maps a lowercase colour name to its RGB triple, each component in [0, 1].
# PolygonColorDataset lower-cases the requested name, looks it up here, and
# broadcasts the triple into three constant conditioning channels.
COLOR_MAP = {
    "red": [1, 0, 0],
    "green": [0, 1, 0],
    "blue": [0, 0, 1],
    "yellow": [1, 1, 0],
    "cyan": [0, 1, 1],
    "magenta": [1, 0, 1],
    "white": [1, 1, 1],
    "black": [0, 0, 0],
    "purple": [0.5, 0, 0.5],
    "orange": [1, 0.5, 0]
}

class PolygonColorDataset(Dataset):
    """Pairs a grayscale polygon image and a colour name with its coloured target.

    The JSON file at ``json_path`` must contain a list of entries, each with:
        input_polygon: file name of the polygon image inside ``input_dir``
        output_image:  file name of the coloured target inside ``output_dir``
        colour:        colour name, looked up case-insensitively in ``COLOR_MAP``

    Parameters:
        json_path: Path to the JSON metadata file.
        input_dir: Directory holding the input polygon images.
        output_dir: Directory holding the target images.
        transform: Callable applied to both the input and the target PIL image.
            Defaults to ``transforms.ToTensor()`` when ``None``.

    Raises:
        KeyError: At construction time, if any entry lacks a required key.
    """

    REQUIRED_KEYS = ('input_polygon', 'output_image', 'colour')

    def __init__(self, json_path, input_dir, output_dir, transform=None):
        # JSON is UTF-8 by specification; do not depend on the platform default.
        with open(json_path, 'r', encoding='utf-8') as f:
            self.data = json.load(f)
        self.input_dir = input_dir
        self.output_dir = output_dir
        # Explicit None check so a falsy-but-valid callable is not discarded.
        self.transform = transforms.ToTensor() if transform is None else transform

        # Fail fast on malformed metadata instead of in the middle of an epoch.
        for i, entry in enumerate(self.data):
            missing = [k for k in self.REQUIRED_KEYS if k not in entry]
            if missing:
                raise KeyError(f"Missing keys in entry {i}: {missing} — entry: {entry}")

    def __len__(self):
        """Return the number of entries in the metadata file."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(model_input, target_tensor)`` for entry ``idx``.

        ``model_input`` is the transformed grayscale image with three constant
        colour channels appended along dim 0 (1 + 3 = 4 channels with the
        default ``ToTensor`` transform). ``target_tensor`` is the transformed
        RGB target image.

        Raises:
            ValueError: If the entry's colour name is not in ``COLOR_MAP``.
        """
        entry = self.data[idx]
        in_path = os.path.join(self.input_dir, entry['input_polygon'])
        out_path = os.path.join(self.output_dir, entry['output_image'])

        # Image.open is lazy and keeps the file open. convert() loads the pixel
        # data into a new image, so the handle can be released immediately
        # rather than leaking one descriptor per sample until garbage collection.
        with Image.open(in_path) as img:
            input_img = img.convert('L')
        with Image.open(out_path) as img:
            target_img = img.convert('RGB')

        input_tensor = self.transform(input_img)
        target_tensor = self.transform(target_img)

        color = COLOR_MAP.get(entry['colour'].lower())
        if color is None:
            raise ValueError(f"Unknown color '{entry['colour']}' in entry: {entry}")
        # Broadcast the RGB triple to the spatial size of the input so it can be
        # stacked as extra channels; expand() creates a view and cat() copies it.
        color_tensor = torch.tensor(color, dtype=torch.float32).view(3, 1, 1).expand(3, *input_tensor.shape[1:])
        model_input = torch.cat([input_tensor, color_tensor], dim=0)

        return model_input, target_tensor

# DataLoader Setup

In [None]:
# Root folder of the dataset. Change this single constant to point the notebook
# at another copy of the data (it expects <root>/<split>/{data.json,inputs,outputs}).
DATA_ROOT = 'D:/Anya_data/dataset'
# Batch size shared by the training and validation loaders.
BATCH_SIZE = 8


def _build_split(split):
    """Create the PolygonColorDataset for one split folder under DATA_ROOT."""
    return PolygonColorDataset(
        json_path=f'{DATA_ROOT}/{split}/data.json',
        input_dir=f'{DATA_ROOT}/{split}/inputs',
        output_dir=f'{DATA_ROOT}/{split}/outputs',
    )


train_dataset = _build_split('training')
val_dataset = _build_split('validation')

# Only the training data is shuffled; validation keeps a fixed order so the
# first batch (used by the visualisation helpers) is the same on every run.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

# UNet Model with Dropout

In [None]:
class DoubleConv(nn.Module):
    """Two 3x3 conv -> batch-norm -> ReLU stages with dropout (p=0.1) between them.

    The convolutions use padding 1, so the spatial size is unchanged; only the
    channel count goes from ``in_c`` to ``out_c``.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        # Layers are created in the same order they run, so the Sequential
        # indices (and therefore the state_dict keys) stay block.0 ... block.6.
        stage_one = [
            nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True),
        ]
        stage_two = [
            nn.Conv2d(out_c, out_c, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True),
        ]
        self.block = nn.Sequential(*stage_one, nn.Dropout(p=0.1), *stage_two)

    def forward(self, x):
        """Apply both convolution stages to ``x`` and return the feature map."""
        features = self.block(x)
        return features

class UNet(nn.Module):
    """Three-level UNet mapping an ``in_channels`` image to ``out_channels`` in [0, 1].

    The encoder halves the resolution three times (64 -> 128 -> 256 channels),
    a 512-channel bottleneck follows, and the decoder mirrors the encoder with
    transposed convolutions and skip connections. The output passes through a
    sigmoid.

    Parameters:
        in_channels: Channels of the input tensor (default 4).
        out_channels: Channels of the output tensor (default 3).

    Input height and width must be at least 8 so three poolings leave a
    non-empty map. They do not need to be multiples of 8: upsampled maps are
    zero-padded to the size of their skip connection.
    """

    def __init__(self, in_channels=4, out_channels=3):
        super().__init__()
        self.enc1 = DoubleConv(in_channels, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.enc3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.bottleneck = DoubleConv(256, 512)
        # Each decoder block sees the upsampled map concatenated with the
        # encoder skip of the same width, hence twice the channels on input.
        self.up3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.dec3 = DoubleConv(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec2 = DoubleConv(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec1 = DoubleConv(128, 64)
        self.final_conv = nn.Conv2d(64, out_channels, 1)

    @staticmethod
    def _merge(upsampled, skip):
        """Concatenate ``upsampled`` with ``skip`` along the channel dimension.

        MaxPool2d(2) floors odd sizes, so after the stride-2 transposed
        convolution the upsampled map can be one pixel smaller than the skip.
        It is zero-padded to match; when the sizes already agree this is a
        plain ``torch.cat``.
        """
        pad_h = skip.size(2) - upsampled.size(2)
        pad_w = skip.size(3) - upsampled.size(3)
        if pad_h or pad_w:
            # F.pad order is (left, right, top, bottom) for the last two dims.
            upsampled = nn.functional.pad(
                upsampled,
                [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2],
            )
        return torch.cat([upsampled, skip], 1)

    def forward(self, x):
        """Return the sigmoid-activated prediction for the batch ``x``."""
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool1(e1))
        e3 = self.enc3(self.pool2(e2))
        b = self.bottleneck(self.pool3(e3))
        d3 = self.dec3(self._merge(self.up3(b), e3))
        d2 = self.dec2(self._merge(self.up2(d3), e2))
        d1 = self.dec1(self._merge(self.up1(d2), e1))
        return torch.sigmoid(self.final_conv(d1))

# Training Setup

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Default UNet: 4 input channels, 3 output channels.
model = UNet().to(device)

# Mean absolute error between the predicted and the target image.
criterion = nn.L1Loss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# Halve the learning rate every 10 scheduler steps (the training loop steps once per epoch).
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)


# Training and Validation Loops

In [None]:
def train_one_epoch(model, dataloader):
    """Run one optimisation pass over ``dataloader`` and return the mean loss per sample.

    Uses the module-level ``device``, ``criterion`` and ``optimizer``.

    Parameters:
        model: Network to train; switched to training mode.
        dataloader: Iterable of ``(input, target)`` batches.

    Returns:
        The sample-weighted mean loss over the batches actually processed, or
        ``nan`` when the dataloader yields no batches.
    """
    model.train()
    total_loss = 0.0
    samples_seen = 0
    for x, y in tqdm(dataloader, desc="Training"):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        out = model(x)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()
        # The criterion averages over the batch, so weight by batch size to
        # keep a smaller final batch from being over-represented.
        batch_size = x.size(0)
        total_loss += loss.item() * batch_size
        samples_seen += batch_size
    # Normalise by the samples actually seen rather than len(dataset): the two
    # differ with drop_last or a subset sampler, and some datasets have no len.
    if samples_seen == 0:
        return float("nan")
    return total_loss / samples_seen

def validate(model, dataloader):
    """Evaluate ``model`` on ``dataloader`` and return the mean loss per sample.

    Uses the module-level ``device`` and ``criterion``. Runs in eval mode with
    gradient tracking disabled; no parameters are updated.

    Parameters:
        model: Network to evaluate; switched to evaluation mode.
        dataloader: Iterable of ``(input, target)`` batches.

    Returns:
        The sample-weighted mean loss over the batches actually processed, or
        ``nan`` when the dataloader yields no batches.
    """
    model.eval()
    total_loss = 0.0
    samples_seen = 0
    with torch.no_grad():
        for x, y in tqdm(dataloader, desc="Validation"):
            x, y = x.to(device), y.to(device)
            out = model(x)
            loss = criterion(out, y)
            # Weight the batch-mean loss by batch size (last batch may be smaller).
            batch_size = x.size(0)
            total_loss += loss.item() * batch_size
            samples_seen += batch_size
    # Normalise by the samples actually seen rather than len(dataset): the two
    # differ with drop_last or a subset sampler, and some datasets have no len.
    if samples_seen == 0:
        return float("nan")
    return total_loss / samples_seen


# Start wandb + Train

In [None]:
# Start a W&B run; the per-epoch losses below are logged to it.
wandb.init(project="ayna-polygon-color", name="unet-final")

EPOCHS = 20
for epoch in range(EPOCHS):
    train_loss = train_one_epoch(model, train_loader)
    val_loss = validate(model, val_loader)
    # Advance the LR schedule once per epoch, after that epoch's optimizer updates.
    scheduler.step()
    print(f"Epoch {epoch+1}: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}")
    wandb.log({"train_loss": train_loss, "val_loss": val_loss})


# Visualize Results

In [None]:
def visualize(model, dataloader):
    """Plot polygon, target and prediction for up to three samples of the first batch.

    Parameters:
        model: Trained network; switched to evaluation mode.
        dataloader: DataLoader yielding ``(input, target)`` batches; only the
            first batch is used.
    """
    model.eval()
    # Use the model's own device so the helper does not depend on a global
    # and still works if the model was moved after that global was set.
    model_device = next(model.parameters()).device

    x, y = next(iter(dataloader))
    with torch.no_grad():
        preds = model(x.to(model_device)).cpu()
    x = x.cpu()
    y = y.cpu()

    # A batch may hold fewer than three samples; never index past its end.
    rows = min(3, x.size(0))
    # squeeze=False keeps axs two-dimensional even when rows == 1.
    _, axs = plt.subplots(rows, 3, figsize=(10, 8), squeeze=False)
    for i in range(rows):
        # Channel 0 of the model input is the grayscale polygon; the remaining
        # channels carry the colour condition.
        axs[i][0].imshow(x[i, 0], cmap='gray')
        axs[i][0].set_title("Polygon")
        # imshow expects channels last, tensors are channels first.
        axs[i][1].imshow(y[i].permute(1, 2, 0))
        axs[i][1].set_title("Target")
        axs[i][2].imshow(preds[i].permute(1, 2, 0))
        axs[i][2].set_title("Prediction")
    plt.tight_layout()
    plt.show()

# Call this after training
# visualize(model, val_loader)


# Testing

In [None]:
def test_model(model, dataloader, num_samples=5):
    """
    Visualize predictions on a few validation samples.

    Only the first batch of ``dataloader`` is used, so at most one batch worth
    of samples is shown even if ``num_samples`` is larger.

    Parameters:
        model: Trained UNet model
        dataloader: Validation DataLoader
        num_samples: Number of samples to visualize
    """
    model.eval()
    device = next(model.parameters()).device

    x_batch, y_batch = next(iter(dataloader))
    x_batch = x_batch.to(device)

    with torch.no_grad():
        preds = model(x_batch).cpu()

    x_batch = x_batch.cpu()
    y_batch = y_batch.cpu()

    # Never index past the end of the batch.
    rows = min(num_samples, x_batch.size(0))

    # Plotting predictions. squeeze=False keeps axs two-dimensional, so
    # axs[i][j] also works when a single row is requested.
    _, axs = plt.subplots(rows, 3, figsize=(10, 4 * rows), squeeze=False)

    for i in range(rows):
        # Channel 0 of the model input is the grayscale polygon.
        axs[i][0].imshow(x_batch[i, 0], cmap='gray')
        axs[i][0].set_title("Grayscale Polygon", fontsize=12)
        axs[i][0].axis('off')

        # imshow expects channels last, tensors are channels first.
        axs[i][1].imshow(y_batch[i].permute(1, 2, 0))
        axs[i][1].set_title("Ground Truth (Colored)", fontsize=12)
        axs[i][1].axis('off')

        axs[i][2].imshow(preds[i].permute(1, 2, 0))
        axs[i][2].set_title("Model Prediction", fontsize=12)
        axs[i][2].axis('off')

    plt.tight_layout()
    plt.show()

# Run inference on the first validation batch and plot 5 of its samples
test_model(model, val_loader, num_samples=5)
