In [1]:
!pip install torch torchvision wandb

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [2]:
from google.colab import files
uploaded=files.upload()

Saving dataset.zip to dataset.zip


In [3]:
import zipfile

with zipfile.ZipFile("/content/dataset.zip", 'r') as zip_ref:
    zip_ref.extractall("/content/")

In [4]:
import os, json
from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision import transforms

class PolygonColorDataset(Dataset):
    def __init__(self, input_dir, output_dir, json_path, transform=None):
        with open(json_path) as f:
            self.data = json.load(f)
        self.input_dir = input_dir
        self.output_dir = output_dir
        self.transform = transform
        self.color_names = sorted(list(set(d['colour'] for d in self.data)))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        entry = self.data[idx]
        input_img = Image.open(os.path.join(self.input_dir, entry['input_polygon'])).convert("L")
        target_img = Image.open(os.path.join(self.output_dir, entry['output_image'])).convert("RGB")

        color_vec = torch.zeros(len(self.color_names))
        color_vec[self.color_names.index(entry['colour'])] = 1

        if self.transform:
            input_img = self.transform(input_img)
            target_img = self.transform(target_img)

        return input_img, color_vec, target_img

In [5]:
from torch.utils.data import DataLoader

transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor()
])

train_data = PolygonColorDataset(
    "/content/dataset/training/inputs",
    "/content/dataset/training/outputs",
    "/content/dataset/training/data.json",
    transform
)

val_data = PolygonColorDataset(
    "/content/dataset/validation/inputs",
    "/content/dataset/validation/outputs",
    "/content/dataset/validation/data.json",
    transform
)

train_loader = DataLoader(train_data, batch_size=8, shuffle=True)
val_loader = DataLoader(val_data, batch_size=8)

In [6]:
import torch.nn as nn

class UNet(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()

        def conv_block(in_c, out_c):
            return nn.Sequential(
                nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.Conv2d(out_c, out_c, kernel_size=3, padding=1),
                nn.ReLU()
            )

        self.enc1 = conv_block(in_channels, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = conv_block(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.enc3 = conv_block(128, 256)
        self.pool3 = nn.MaxPool2d(2)

        self.bottleneck = conv_block(256, 512)

        self.up3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.dec3 = conv_block(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec2 = conv_block(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec1 = conv_block(128, 64)

        self.final = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x, color_vec):
        b, _, h, w = x.shape
        color_img = color_vec.unsqueeze(2).unsqueeze(3).expand(b, color_vec.shape[1], h, w)
        x = torch.cat([x, color_img], dim=1)

        e1 = self.enc1(x)
        e2 = self.enc2(self.pool1(e1))
        e3 = self.enc3(self.pool2(e2))

        b = self.bottleneck(self.pool3(e3))

        d3 = self.dec3(torch.cat([self.up3(b), e3], dim=1))
        d2 = self.dec2(torch.cat([self.up2(d3), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))

        return self.final(d1)

In [13]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet(in_channels=1 + len(train_data.color_names), out_channels=3).to(device)
import pickle
with open("color_names.pkl", "wb") as f:
    pickle.dump(train_data.color_names, f)

import torch.nn.functional as F
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [14]:
from google.colab import files
files.download("color_names.pkl")

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [10]:
import wandb
wandb.login()
wandb.init(project="ayna-unet-colorizer")

num_epochs = 20

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for input_img, color_vec, target_img in train_loader:
        input_img = input_img.to(device)
        color_vec = color_vec.to(device)
        target_img = target_img.to(device)

        output = model(input_img, color_vec)
        loss = criterion(output, target_img)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    avg_loss = running_loss / len(train_loader)
    wandb.log({"Epoch": epoch+1, "Loss": avg_loss})
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")



Epoch [1/20], Loss: 0.1156
Epoch [2/20], Loss: 0.1022
Epoch [3/20], Loss: 0.0972
Epoch [4/20], Loss: 0.0945
Epoch [5/20], Loss: 0.0894
Epoch [6/20], Loss: 0.0862
Epoch [7/20], Loss: 0.0871
Epoch [8/20], Loss: 0.0855
Epoch [9/20], Loss: 0.0820
Epoch [10/20], Loss: 0.0836
Epoch [11/20], Loss: 0.0927
Epoch [12/20], Loss: 0.0890
Epoch [13/20], Loss: 0.0832
Epoch [14/20], Loss: 0.0815
Epoch [15/20], Loss: 0.0799
Epoch [16/20], Loss: 0.0787
Epoch [17/20], Loss: 0.0720
Epoch [18/20], Loss: 0.0636
Epoch [19/20], Loss: 0.0633
Epoch [20/20], Loss: 0.0494


In [11]:
torch.save(model.state_dict(), "unet_colorizer.pth")

In [12]:
from google.colab import files
files.download("unet_colorizer.pth")

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

FileNotFoundError: Cannot find file: training.ipynb