In [17]:
import os
import json
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from tqdm import tqdm
import wandb
import matplotlib.pyplot as plt

# Color map used for conditioning
COLOR_MAP = {
    "red": [1, 0, 0],
    "green": [0, 1, 0],
    "blue": [0, 0, 1],
    "yellow": [1, 1, 0],
    "cyan": [0, 1, 1],
    "magenta": [1, 0, 1],
    "white": [1, 1, 1],
    "black": [0, 0, 0],
    "purple": [0.5, 0, 0.5],
    "orange": [1, 0.5, 0]# Added for your dataset
}

# Custom dataset
class PolygonColorDataset(Dataset):
    def __init__(self, json_path, input_dir, output_dir, transform=None):
        with open(json_path, 'r') as f:
            self.data = json.load(f)
        self.input_dir = input_dir
        self.output_dir = output_dir
        self.transform = transform or transforms.ToTensor()

        for i, entry in enumerate(self.data):
            missing = [k for k in ('input_polygon', 'output_image', 'colour') if k not in entry]
            if missing:
                raise KeyError(f"Missing keys in entry {i}: {missing} — entry: {entry}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        entry = self.data[idx]
        in_path = os.path.join(self.input_dir, entry['input_polygon'])
        out_path = os.path.join(self.output_dir, entry['output_image'])

        input_img = Image.open(in_path).convert('L')
        target_img = Image.open(out_path).convert('RGB')

        input_tensor = self.transform(input_img)  # (1, H, W)
        target_tensor = self.transform(target_img)  # (3, H, W)

        color = COLOR_MAP.get(entry['colour'].lower())
        if color is None:
            raise ValueError(f"Unknown color '{entry['colour']}' in entry: {entry}")
        color_tensor = torch.tensor(color, dtype=torch.float32).view(3, 1, 1)
        color_tensor = color_tensor.expand(3, *input_tensor.shape[1:])  # (3, H, W)

        model_input = torch.cat([input_tensor, color_tensor], dim=0)  # (4, H, W)
        return model_input, target_tensor

# Dataset loaders
train_dataset = PolygonColorDataset(
    json_path='D:/Anya_data/dataset/training/data.json',
    input_dir='D:/Anya_data/dataset/training/inputs',
    output_dir='D:/Anya_data/dataset/training/outputs'
)
val_dataset = PolygonColorDataset(
    json_path='D:/Anya_data/dataset/validation/data.json',
    input_dir='D:/Anya_data/dataset/validation/inputs',
    output_dir='D:/Anya_data/dataset/validation/outputs'
)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8)

# UNet model definition
class DoubleConv(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_c, out_c, 3, padding=1),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_c, out_c, 3, padding=1),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True),
        )
    def forward(self, x):
        return self.block(x)

class UNet(nn.Module):
    def __init__(self, in_channels=4, out_channels=3):
        super().__init__()
        self.enc1 = DoubleConv(in_channels, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.enc3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.bottleneck = DoubleConv(256, 512)
        self.up3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.dec3 = DoubleConv(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec2 = DoubleConv(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec1 = DoubleConv(128, 64)
        self.final_conv = nn.Conv2d(64, out_channels, 1)

    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool1(e1))
        e3 = self.enc3(self.pool2(e2))
        b = self.bottleneck(self.pool3(e3))
        d3 = self.dec3(torch.cat([self.up3(b), e3], 1))
        d2 = self.dec2(torch.cat([self.up2(d3), e2], 1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], 1))
        return torch.sigmoid(self.final_conv(d1))

# Training setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet().to(device)
criterion = nn.L1Loss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Training and validation loops
def train_one_epoch(model, dataloader):
    model.train()
    total_loss = 0
    for x, y in tqdm(dataloader, desc="Training"):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        out = model(x)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * x.size(0)
    return total_loss / len(dataloader.dataset)

def validate(model, dataloader):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for x, y in tqdm(dataloader, desc="Validation"):
            x, y = x.to(device), y.to(device)
            out = model(x)
            loss = criterion(out, y)
            total_loss += loss.item() * x.size(0)
    return total_loss / len(dataloader.dataset)

# Initialize W&B
wandb.init(project="ayna-polygon-color", name="unet-baseline")

# Training loop
EPOCHS = 20
for epoch in range(EPOCHS):
    train_loss = train_one_epoch(model, train_loader)
    val_loss = validate(model, val_loader)
    print(f"Epoch {epoch+1}: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}")
    wandb.log({"train_loss": train_loss, "val_loss": val_loss})

# Visualization
def visualize(model, dataloader):
    model.eval()
    x, y = next(iter(dataloader))
    x = x.to(device)
    with torch.no_grad():
        preds = model(x).cpu()
    x = x.cpu()

    fig, axs = plt.subplots(3, 3, figsize=(10, 8))
    for i in range(3):
        axs[i][0].imshow(x[i, 0], cmap='gray')
        axs[i][0].set_title("Polygon")
        axs[i][1].imshow(y[i].permute(1, 2, 0))
        axs[i][1].set_title("Target")
        axs[i][2].imshow(preds[i].permute(1, 2, 0))
        axs[i][2].set_title("Prediction")
    plt.tight_layout()
    plt.show()

# Optional: visualize(model, val_loader)


Training: 100%|██████████| 7/7 [00:09<00:00,  1.35s/it]
Validation: 100%|██████████| 1/1 [00:00<00:00,  3.20it/s]


Epoch 1: Train Loss=0.3954, Val Loss=0.4584


Training: 100%|██████████| 7/7 [00:09<00:00,  1.29s/it]
Validation: 100%|██████████| 1/1 [00:00<00:00,  3.68it/s]


Epoch 2: Train Loss=0.3314, Val Loss=0.3615


Training: 100%|██████████| 7/7 [00:08<00:00,  1.27s/it]
Validation: 100%|██████████| 1/1 [00:00<00:00,  3.69it/s]


Epoch 3: Train Loss=0.2943, Val Loss=0.3123


Training: 100%|██████████| 7/7 [00:09<00:00,  1.30s/it]
Validation: 100%|██████████| 1/1 [00:00<00:00,  3.52it/s]


Epoch 4: Train Loss=0.2576, Val Loss=0.1829


Training: 100%|██████████| 7/7 [00:08<00:00,  1.28s/it]
Validation: 100%|██████████| 1/1 [00:00<00:00,  3.72it/s]


Epoch 5: Train Loss=0.2239, Val Loss=0.2331


Training: 100%|██████████| 7/7 [00:09<00:00,  1.29s/it]
Validation: 100%|██████████| 1/1 [00:00<00:00,  3.75it/s]


Epoch 6: Train Loss=0.1898, Val Loss=0.1949


Training: 100%|██████████| 7/7 [00:08<00:00,  1.28s/it]
Validation: 100%|██████████| 1/1 [00:00<00:00,  3.54it/s]


Epoch 7: Train Loss=0.1603, Val Loss=0.1789


Training: 100%|██████████| 7/7 [00:08<00:00,  1.28s/it]
Validation: 100%|██████████| 1/1 [00:00<00:00,  3.70it/s]


Epoch 8: Train Loss=0.1339, Val Loss=0.1454


Training: 100%|██████████| 7/7 [00:09<00:00,  1.29s/it]
Validation: 100%|██████████| 1/1 [00:00<00:00,  3.67it/s]


Epoch 9: Train Loss=0.1217, Val Loss=0.1059


Training: 100%|██████████| 7/7 [00:08<00:00,  1.28s/it]
Validation: 100%|██████████| 1/1 [00:00<00:00,  3.60it/s]


Epoch 10: Train Loss=0.1001, Val Loss=0.1178


Training: 100%|██████████| 7/7 [00:09<00:00,  1.31s/it]
Validation: 100%|██████████| 1/1 [00:00<00:00,  3.71it/s]


Epoch 11: Train Loss=0.0979, Val Loss=0.0885


Training: 100%|██████████| 7/7 [00:08<00:00,  1.28s/it]
Validation: 100%|██████████| 1/1 [00:00<00:00,  3.72it/s]


Epoch 12: Train Loss=0.0862, Val Loss=0.0767


Training: 100%|██████████| 7/7 [00:09<00:00,  1.29s/it]
Validation: 100%|██████████| 1/1 [00:00<00:00,  3.68it/s]


Epoch 13: Train Loss=0.0775, Val Loss=0.0624


Training: 100%|██████████| 7/7 [00:09<00:00,  1.29s/it]
Validation: 100%|██████████| 1/1 [00:00<00:00,  3.65it/s]


Epoch 14: Train Loss=0.0671, Val Loss=0.0561


Training: 100%|██████████| 7/7 [00:08<00:00,  1.28s/it]
Validation: 100%|██████████| 1/1 [00:00<00:00,  3.71it/s]


Epoch 15: Train Loss=0.0652, Val Loss=0.0482


Training: 100%|██████████| 7/7 [00:09<00:00,  1.29s/it]
Validation: 100%|██████████| 1/1 [00:00<00:00,  3.72it/s]


Epoch 16: Train Loss=0.0596, Val Loss=0.0531


Training: 100%|██████████| 7/7 [00:08<00:00,  1.28s/it]
Validation: 100%|██████████| 1/1 [00:00<00:00,  3.76it/s]


Epoch 17: Train Loss=0.0584, Val Loss=0.0565


Training: 100%|██████████| 7/7 [00:08<00:00,  1.28s/it]
Validation: 100%|██████████| 1/1 [00:00<00:00,  3.65it/s]


Epoch 18: Train Loss=0.0532, Val Loss=0.0374


Training: 100%|██████████| 7/7 [00:08<00:00,  1.28s/it]
Validation: 100%|██████████| 1/1 [00:00<00:00,  3.63it/s]


Epoch 19: Train Loss=0.0497, Val Loss=0.0361


Training: 100%|██████████| 7/7 [00:08<00:00,  1.28s/it]
Validation: 100%|██████████| 1/1 [00:00<00:00,  3.70it/s]


Epoch 20: Train Loss=0.0533, Val Loss=0.0423


In [19]:
import os
import json
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from tqdm import tqdm
import wandb
import matplotlib.pyplot as plt

# Color map
COLOR_MAP = {
    "red": [1, 0, 0],
    "green": [0, 1, 0],
    "blue": [0, 0, 1],
    "yellow": [1, 1, 0],
    "cyan": [0, 1, 1],
    "magenta": [1, 0, 1],
    "white": [1, 1, 1],
    "black": [0, 0, 0],
    "purple": [0.5, 0, 0.5],
    "orange": [1, 0.5, 0]
}

# Custom dataset
class PolygonColorDataset(Dataset):
    def __init__(self, json_path, input_dir, output_dir, transform=None):
        with open(json_path, 'r') as f:
            self.data = json.load(f)
        self.input_dir = input_dir
        self.output_dir = output_dir
        self.transform = transform or transforms.ToTensor()

        for i, entry in enumerate(self.data):
            missing = [k for k in ('input_polygon', 'output_image', 'colour') if k not in entry]
            if missing:
                raise KeyError(f"Missing keys in entry {i}: {missing} — entry: {entry}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        entry = self.data[idx]
        in_path = os.path.join(self.input_dir, entry['input_polygon'])
        out_path = os.path.join(self.output_dir, entry['output_image'])

        input_img = Image.open(in_path).convert('L')
        target_img = Image.open(out_path).convert('RGB')

        input_tensor = self.transform(input_img)  # (1, H, W)
        target_tensor = self.transform(target_img)  # (3, H, W)

        color = COLOR_MAP.get(entry['colour'].lower())
        if color is None:
            raise ValueError(f"Unknown color '{entry['colour']}' in entry: {entry}")
        color_tensor = torch.tensor(color, dtype=torch.float32).view(3, 1, 1)
        color_tensor = color_tensor.expand(3, *input_tensor.shape[1:])

        model_input = torch.cat([input_tensor, color_tensor], dim=0)
        return model_input, target_tensor

# Data transforms with augmentation
data_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ToTensor()
])

# Dataset loaders
train_dataset = PolygonColorDataset(
    json_path='D:/Anya_data/dataset/training/data.json',
    input_dir='D:/Anya_data/dataset/training/inputs',
    output_dir='D:/Anya_data/dataset/training/outputs'
)
val_dataset = PolygonColorDataset(
    json_path='D:/Anya_data/dataset/validation/data.json',
    input_dir='D:/Anya_data/dataset/validation/inputs',
    output_dir='D:/Anya_data/dataset/validation/outputs'
)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8)

# UNet model with Dropout
class DoubleConv(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_c, out_c, 3, padding=1),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
            nn.Conv2d(out_c, out_c, 3, padding=1),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True),
        )
    def forward(self, x):
        return self.block(x)

class UNet(nn.Module):
    def __init__(self, in_channels=4, out_channels=3):
        super().__init__()
        self.enc1 = DoubleConv(in_channels, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.enc3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.bottleneck = DoubleConv(256, 512)
        self.up3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.dec3 = DoubleConv(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec2 = DoubleConv(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec1 = DoubleConv(128, 64)
        self.final_conv = nn.Conv2d(64, out_channels, 1)

    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool1(e1))
        e3 = self.enc3(self.pool2(e2))
        b = self.bottleneck(self.pool3(e3))
        d3 = self.dec3(torch.cat([self.up3(b), e3], 1))
        d2 = self.dec2(torch.cat([self.up2(d3), e2], 1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], 1))
        return torch.sigmoid(self.final_conv(d1))

# Training setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet().to(device)
criterion = nn.L1Loss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

# Training and validation loops
def train_one_epoch(model, dataloader):
    model.train()
    total_loss = 0
    for x, y in tqdm(dataloader, desc="Training"):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        out = model(x)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * x.size(0)
    return total_loss / len(dataloader.dataset)

def validate(model, dataloader):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for x, y in tqdm(dataloader, desc="Validation"):
            x, y = x.to(device), y.to(device)
            out = model(x)
            loss = criterion(out, y)
            total_loss += loss.item() * x.size(0)
    return total_loss / len(dataloader.dataset)

# Initialize W&B
wandb.init(project="ayna-polygon-color", name="unet-augmented")

# Training loop
EPOCHS = 20
for epoch in range(EPOCHS):
    train_loss = train_one_epoch(model, train_loader)
    val_loss = validate(model, val_loader)
    scheduler.step()
    print(f"Epoch {epoch+1}: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}")
    wandb.log({"train_loss": train_loss, "val_loss": val_loss})

# Visualization
def visualize(model, dataloader):
    model.eval()
    x, y = next(iter(dataloader))
    x = x.to(device)
    with torch.no_grad():
        preds = model(x).cpu()
    x = x.cpu()

    fig, axs = plt.subplots(3, 3, figsize=(10, 8))
    for i in range(3):
        axs[i][0].imshow(x[i, 0], cmap='gray')
        axs[i][0].set_title("Polygon")
        axs[i][1].imshow(y[i].permute(1, 2, 0))
        axs[i][1].set_title("Target")
        axs[i][2].imshow(preds[i].permute(1, 2, 0))
        axs[i][2].set_title("Prediction")
    plt.tight_layout()
    plt.show()

# To use after training:
# visualize(model, val_loader)


0,1
train_loss,█▇▆▅▅▄▃▃▂▂▂▂▂▁▁▁▁▁▁▁
val_loss,█▆▆▃▄▄▃▃▂▂▂▂▁▁▁▁▁▁▁▁

0,1
train_loss,0.05325
val_loss,0.04233


Training: 100%|██████████| 7/7 [00:12<00:00,  1.83s/it]
Validation: 100%|██████████| 1/1 [00:00<00:00,  3.40it/s]


Epoch 1: Train Loss=0.4012, Val Loss=0.4527


Training: 100%|██████████| 7/7 [00:12<00:00,  1.75s/it]
Validation: 100%|██████████| 1/1 [00:00<00:00,  3.87it/s]


Epoch 2: Train Loss=0.3465, Val Loss=0.5395


Training: 100%|██████████| 7/7 [00:12<00:00,  1.75s/it]
Validation: 100%|██████████| 1/1 [00:00<00:00,  3.86it/s]


Epoch 3: Train Loss=0.3023, Val Loss=0.2989


Training: 100%|██████████| 7/7 [00:12<00:00,  1.73s/it]
Validation: 100%|██████████| 1/1 [00:00<00:00,  3.80it/s]


Epoch 4: Train Loss=0.2760, Val Loss=0.2804


Training: 100%|██████████| 7/7 [00:12<00:00,  1.72s/it]
Validation: 100%|██████████| 1/1 [00:00<00:00,  3.88it/s]


Epoch 5: Train Loss=0.2403, Val Loss=0.2187


Training: 100%|██████████| 7/7 [00:11<00:00,  1.71s/it]
Validation: 100%|██████████| 1/1 [00:00<00:00,  3.67it/s]


Epoch 6: Train Loss=0.2233, Val Loss=0.2058


Training: 100%|██████████| 7/7 [00:12<00:00,  1.72s/it]
Validation: 100%|██████████| 1/1 [00:00<00:00,  3.78it/s]


Epoch 7: Train Loss=0.2060, Val Loss=0.2888


Training: 100%|██████████| 7/7 [00:11<00:00,  1.70s/it]
Validation: 100%|██████████| 1/1 [00:00<00:00,  3.83it/s]


Epoch 8: Train Loss=0.1778, Val Loss=0.3128


Training: 100%|██████████| 7/7 [00:12<00:00,  1.74s/it]
Validation: 100%|██████████| 1/1 [00:00<00:00,  3.79it/s]


Epoch 9: Train Loss=0.1602, Val Loss=0.2247


Training: 100%|██████████| 7/7 [00:12<00:00,  1.72s/it]
Validation: 100%|██████████| 1/1 [00:00<00:00,  3.86it/s]


Epoch 10: Train Loss=0.1444, Val Loss=0.1757


Training: 100%|██████████| 7/7 [00:11<00:00,  1.71s/it]
Validation: 100%|██████████| 1/1 [00:00<00:00,  3.76it/s]


Epoch 11: Train Loss=0.1357, Val Loss=0.1616


Training: 100%|██████████| 7/7 [00:12<00:00,  1.73s/it]
Validation: 100%|██████████| 1/1 [00:00<00:00,  3.79it/s]


Epoch 12: Train Loss=0.1203, Val Loss=0.1118


Training: 100%|██████████| 7/7 [00:11<00:00,  1.71s/it]
Validation: 100%|██████████| 1/1 [00:00<00:00,  3.77it/s]


Epoch 13: Train Loss=0.1195, Val Loss=0.1707


Training: 100%|██████████| 7/7 [00:12<00:00,  1.72s/it]
Validation: 100%|██████████| 1/1 [00:00<00:00,  3.83it/s]


Epoch 14: Train Loss=0.1095, Val Loss=0.0950


Training: 100%|██████████| 7/7 [00:12<00:00,  1.73s/it]
Validation: 100%|██████████| 1/1 [00:00<00:00,  3.92it/s]


Epoch 15: Train Loss=0.0995, Val Loss=0.1244


Training: 100%|██████████| 7/7 [00:12<00:00,  1.74s/it]
Validation: 100%|██████████| 1/1 [00:00<00:00,  3.80it/s]


Epoch 16: Train Loss=0.0965, Val Loss=0.0830


Training: 100%|██████████| 7/7 [00:12<00:00,  1.72s/it]
Validation: 100%|██████████| 1/1 [00:00<00:00,  3.75it/s]


Epoch 17: Train Loss=0.0920, Val Loss=0.0731


Training: 100%|██████████| 7/7 [00:11<00:00,  1.71s/it]
Validation: 100%|██████████| 1/1 [00:00<00:00,  3.85it/s]


Epoch 18: Train Loss=0.0937, Val Loss=0.0906


Training: 100%|██████████| 7/7 [00:11<00:00,  1.69s/it]
Validation: 100%|██████████| 1/1 [00:00<00:00,  3.66it/s]


Epoch 19: Train Loss=0.0830, Val Loss=0.0686


Training: 100%|██████████| 7/7 [00:11<00:00,  1.71s/it]
Validation: 100%|██████████| 1/1 [00:00<00:00,  3.89it/s]


Epoch 20: Train Loss=0.0809, Val Loss=0.0905
