<a href="https://colab.research.google.com/github/SEBIN6/Ayna_Assignment/blob/main/Unet_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## UNZIPPING DATA

In [1]:
import zipfile
import os

# Extract dataset.zip into ./dataset (the archive itself contains a top-level
# `dataset/` folder, hence the `dataset/dataset/...` paths used for training).
# Using the already-imported zipfile module instead of the `!unzip` shell magic
# keeps this cell idempotent: extractall overwrites existing files rather than
# stopping at unzip's interactive "replace?" prompt when the cell is re-run.
with zipfile.ZipFile("dataset.zip") as archive:
    archive.extractall("dataset")

In [2]:
!pip install wandb --quiet

import json
import os
import random
import warnings

import numpy as np
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torchvision.transforms.functional as TF

import wandb


## DATASET CLASS

In [3]:
import os
import json
import torch
from torch.utils.data import Dataset
from PIL import Image

class PolygonColorDataset(Dataset):
    """Pairs of (polygon outline image, colour name) -> coloured polygon image.

    ``root_dir`` must contain an ``inputs/`` folder, an ``outputs/`` folder and a
    ``data.json`` file. ``data.json`` is indexed by position, so it is presumably
    a JSON list of records (dicts); each record names an input image, an output
    image and a colour, with several alternative key spellings accepted (see
    ``__getitem__``).

    Args:
        root_dir: directory of one dataset split.
        transform: optional callable applied separately to the input PIL image
            AND to the target PIL image. Because the two calls are independent,
            the transform must be deterministic (no random augmentation), or the
            input and target would be transformed differently. When falsy,
            images are only converted with ``torchvision.transforms.ToTensor``.
    """

    def __init__(self, root_dir, transform=None):
        self.input_dir = os.path.join(root_dir, "inputs")
        self.output_dir = os.path.join(root_dir, "outputs")
        self.transform = transform

        with open(os.path.join(root_dir, "data.json")) as f:
            self.data = json.load(f)

        # Colour name (lower-case) -> RGB components in [0, 1].
        self.color_map = {
            "red": [1, 0, 0],
            "green": [0, 1, 0],
            "blue": [0, 0, 1],
            "yellow": [1, 1, 0],
            "cyan": [0, 1, 1],
            "magenta": [1, 0, 1],
            "black": [0, 0, 0],
            "white": [1, 1, 1],
            "orange": [1, 0.5, 0],
            "purple": [0.5, 0, 0.5],
        }

        # Unknown colour keys already reported, so each is warned about once
        # rather than on every access (i.e. once per epoch per sample).
        self._warned_colors = set()

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _load_image(path, mode):
        # Open inside a context manager so the file handle is released
        # deterministically; convert() reads the pixel data before it closes.
        with Image.open(path) as img:
            return img.convert(mode)

    def _color_vector(self, color_name):
        """Return ``color_name`` as a float32 RGB tensor of shape (3,).

        Lookup ignores case and surrounding whitespace. Unknown names fall back
        to black and emit a ``UserWarning`` the first time each is seen.
        """
        color_key = color_name.strip().lower()
        rgb = self.color_map.get(color_key)
        if rgb is None:
            if color_key not in self._warned_colors:
                self._warned_colors.add(color_key)
                warnings.warn(f"color '{color_name}' not in map, defaulting to black")
            rgb = self.color_map["black"]
        return torch.tensor(rgb, dtype=torch.float32)

    def __getitem__(self, idx):
        """Return ``(input_img, color_vec, output_img)`` for record ``idx``.

        input_img: the input image converted to grayscale ("L"), then transformed.
        color_vec: float32 tensor of shape (3,) looked up in ``color_map``.
        output_img: the target image converted to "RGB", then transformed.

        Raises:
            KeyError: the record lacks an input name, output name or colour.
            FileNotFoundError: the input or output image file does not exist.
        """
        item = self.data[idx]

        # Accept several key spellings; empty values count as missing.
        input_name = item.get("input_polygon") or item.get("input") or item.get("input_file")
        output_name = item.get("output_image") or item.get("output") or item.get("output_file")
        color_name = item.get("colour") or item.get("color")

        if input_name is None or output_name is None or color_name is None:
            raise KeyError(f"Missing expected keys in data item: {item}")

        input_path = os.path.join(self.input_dir, input_name)
        output_path = os.path.join(self.output_dir, output_name)

        if not os.path.exists(input_path):
            raise FileNotFoundError(f"Input image not found: {input_path}")
        if not os.path.exists(output_path):
            raise FileNotFoundError(f"Output image not found: {output_path}")

        input_img = self._load_image(input_path, "L")  # grayscale outline
        output_img = self._load_image(output_path, "RGB")  # coloured target

        if self.transform:
            input_img = self.transform(input_img)
            output_img = self.transform(output_img)
        else:
            import torchvision.transforms as T
            to_tensor = T.ToTensor()
            input_img = to_tensor(input_img)
            output_img = to_tensor(output_img)

        color_vec = self._color_vector(color_name)

        return input_img, color_vec, output_img


## UNET

In [4]:
class UNet(nn.Module):
    """Small two-level U-Net whose output is conditioned on a colour vector.

    The colour vector is tiled over the spatial grid and concatenated to the
    image channels, so the first encoder block sees
    ``image_channels + color_dims`` channels.

    Args:
        input_channels: channels entering the first encoder block, i.e. image
            channels plus colour-vector length. The default of 4 corresponds to
            a 1-channel image plus a 3-component colour vector.
        output_channels: channels of the predicted image (default 3).

    Note:
        Resolution is halved twice and then doubled twice, so H and W must be
        divisible by 4; otherwise the skip-connection concatenation fails.
    """

    def __init__(self, input_channels=4, output_channels=3):
        super().__init__()

        def conv_block(in_c, out_c):
            # Two 3x3 convolutions (padding=1 keeps H and W), each followed by ReLU.
            return nn.Sequential(
                nn.Conv2d(in_c, out_c, 3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_c, out_c, 3, padding=1),
                nn.ReLU(inplace=True),
            )

        # Encoder: 64 -> 128 -> 256 channels, halving resolution between levels.
        # The first block takes input_channels (image + colour channels).
        self.enc1 = conv_block(input_channels, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = conv_block(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.enc3 = conv_block(128, 256)

        # Decoder: upsample 2x, concatenate the matching encoder features
        # (hence the doubled in-channels of dec1/dec2), then convolve.
        self.up1 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec1 = conv_block(256, 128)
        self.up2 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec2 = conv_block(128, 64)

        # 1x1 projection to the requested number of output channels (no activation).
        self.final = nn.Conv2d(64, output_channels, 1)

    def forward(self, x, color):
        """Predict an image from ``x`` conditioned on ``color``.

        Args:
            x: image batch of shape (B, C_img, H, W), with H and W divisible by 4.
            color: conditioning vectors with B * C_color elements, e.g. shape
                (B, C_color); each vector is tiled over the H x W grid.
                C_img + C_color must equal ``input_channels``.

        Returns:
            Tensor of shape (B, output_channels, H, W): raw outputs of the final
            1x1 convolution, not passed through any activation.
        """
        B, _, H, W = x.shape
        # (B, C_color) -> (B, C_color, 1, 1) -> (B, C_color, H, W).
        # reshape (unlike view) also accepts non-contiguous colour tensors.
        color_map = color.reshape(B, -1, 1, 1).expand(-1, -1, H, W)
        x = torch.cat([x, color_map], dim=1)

        e1 = self.enc1(x)
        e2 = self.enc2(self.pool1(e1))
        e3 = self.enc3(self.pool2(e2))

        d1 = self.up1(e3)
        d1 = torch.cat([d1, e2], dim=1)  # skip connection from level 2
        d1 = self.dec1(d1)

        d2 = self.up2(d1)
        d2 = torch.cat([d2, e1], dim=1)  # skip connection from level 1
        d2 = self.dec2(d2)

        return self.final(d2)


## TRAINING SCRIPT

In [5]:
# Hyper-parameters and paths in one place; they are also recorded in the W&B
# run config so every logged run is self-describing.
SEED = 42
EPOCHS = 10
BATCH_SIZE = 16
LEARNING_RATE = 1e-3
IMAGE_SIZE = 128  # must be divisible by 4 for the two-level UNet
TRAIN_DIR = "dataset/dataset/training"
VAL_DIR = "dataset/dataset/validation"
CHECKPOINT_PATH = "unet_model.pth"

# Seed every RNG in play so the shuffle order and weight initialisation are
# reproducible under Restart & Run All.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

wandb.init(
    project="ayna-polygon-coloring",
    name="unet-training",
    config={
        "seed": SEED,
        "epochs": EPOCHS,
        "batch_size": BATCH_SIZE,
        "lr": LEARNING_RATE,
        "image_size": IMAGE_SIZE,
    },
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The dataset applies this same transform to input and target separately, so it
# must stay deterministic (resize + tensor conversion only, no augmentation).
transform = T.Compose([
    T.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    T.ToTensor()
])

train_ds = PolygonColorDataset(TRAIN_DIR, transform)
val_ds = PolygonColorDataset(VAL_DIR, transform)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE)

model = UNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.MSELoss()

try:
    for epoch in range(EPOCHS):
        model.train()
        train_loss_sum = 0.0
        for x, c, y in tqdm(train_loader, desc=f"epoch {epoch + 1}/{EPOCHS}"):
            x, c, y = x.to(device), c.to(device), y.to(device)

            optimizer.zero_grad()
            loss = criterion(model(x, c), y)
            loss.backward()
            optimizer.step()

            # MSELoss returns a per-batch mean; weight it by batch size so a
            # smaller final batch does not skew the epoch average.
            train_loss_sum += loss.item() * x.size(0)

        model.eval()
        val_loss_sum = 0.0
        with torch.no_grad():
            for x, c, y in val_loader:
                x, c, y = x.to(device), c.to(device), y.to(device)
                val_loss_sum += criterion(model(x, c), y).item() * x.size(0)

        # A single log call per epoch keeps train and val losses on the same W&B step.
        wandb.log({
            "epoch": epoch,
            "train_loss": train_loss_sum / max(len(train_ds), 1),
            "val_loss": val_loss_sum / max(len(val_ds), 1),
        })
finally:
    # Close the run even if training is interrupted, so it is not left "running".
    wandb.finish()

torch.save(model.state_dict(), CHECKPOINT_PATH)


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize?ref=models
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33msebin2308[0m ([33msebin2308-mbccet[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


100%|██████████| 4/4 [00:53<00:00, 13.45s/it]
100%|██████████| 4/4 [00:51<00:00, 12.85s/it]
100%|██████████| 4/4 [00:50<00:00, 12.56s/it]
100%|██████████| 4/4 [00:50<00:00, 12.69s/it]
100%|██████████| 4/4 [00:51<00:00, 12.88s/it]
100%|██████████| 4/4 [00:54<00:00, 13.74s/it]
100%|██████████| 4/4 [00:55<00:00, 13.93s/it]
100%|██████████| 4/4 [00:53<00:00, 13.30s/it]
100%|██████████| 4/4 [00:52<00:00, 13.15s/it]
100%|██████████| 4/4 [00:51<00:00, 12.79s/it]
