In [3]:
!pip install torch torchvision torchaudio
!pip install wandb
!pip install opencv-python pillow numpy matplotlib
!pip install tqdm

Collecting opencv-python
  Using cached opencv_python-4.12.0.88-cp37-abi3-win_amd64.whl.metadata (19 kB)
Collecting matplotlib
  Downloading matplotlib-3.10.5-cp310-cp310-win_amd64.whl.metadata (11 kB)
Collecting contourpy>=1.0.1 (from matplotlib)
  Downloading contourpy-1.3.2-cp310-cp310-win_amd64.whl.metadata (5.5 kB)
Collecting cycler>=0.10 (from matplotlib)
  Downloading cycler-0.12.1-py3-none-any.whl.metadata (3.8 kB)
Collecting fonttools>=4.22.0 (from matplotlib)
  Downloading fonttools-4.59.0-cp310-cp310-win_amd64.whl.metadata (110 kB)
Collecting kiwisolver>=1.3.1 (from matplotlib)
  Downloading kiwisolver-1.4.8-cp310-cp310-win_amd64.whl.metadata (6.3 kB)
Collecting pyparsing>=2.3.1 (from matplotlib)
  Downloading pyparsing-3.2.3-py3-none-any.whl.metadata (5.0 kB)
Using cached opencv_python-4.12.0.88-cp37-abi3-win_amd64.whl (39.0 MB)
Downloading matplotlib-3.10.5-cp310-cp310-win_amd64.whl (8.1 MB)
   ---------------------------------------- 0.0/8.1 MB ? eta -:--:--
   --------- 

In [4]:
import os
import json
import torch
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms

class PolygonDataset(Dataset):
    def __init__(self, root_dir="C:/Users/rohan/OneDrive/Desktop/Ayna ML/dataset", split="training"):
        self.root_dir = os.path.join(root_dir, split)
        self.json_path = os.path.join(self.root_dir, "data.json")
        with open(self.json_path, 'r') as f:
            self.data = json.load(f)
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((128, 128)),
            transforms.RandomRotation(30),
            transforms.RandomAffine(degrees=0, scale=(0.8, 1.2))
        ])
        self.color_map = {
            "cyan": 0, "purple": 1, "blue": 2, "red": 3,
            "green": 4, "magenta": 5, "yellow": 6, "orange": 7
        }

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        entry = self.data[idx]
        input_img_path = os.path.join(self.root_dir, "inputs", entry["input_polygon"])
        output_img_path = os.path.join(self.root_dir, "outputs", entry["output_image"])
        input_img = Image.open(input_img_path).convert("L")
        output_img = Image.open(output_img_path).convert("RGB")
        color = self.color_map[entry["colour"]]
        color_tensor = torch.tensor(color, dtype=torch.long)
        input_img = self.transform(input_img)
        output_img = self.transform(output_img)
        return input_img, color_tensor, output_img

# Test the dataset
dataset = PolygonDataset("C:/Users/rohan/OneDrive/Desktop/Ayna ML/dataset", split="training")
input_img, color, output_img = dataset[0]
print("Input shape:", input_img.shape)
print("Color index:", color)
print("Output shape:", output_img.shape)

Input shape: torch.Size([1, 128, 128])
Color index: tensor(0)
Output shape: torch.Size([3, 128, 128])


In [8]:
import json

json_path = 'C:/Users/rohan/OneDrive/Desktop/Ayna ML/dataset/validation/data.json'
with open(json_path, 'r') as f:
    data = json.load(f)
print("First three entries:", data[:3])
colors = set(entry['colour'] for entry in data)
print("Colors in dataset:", colors)

First three entries: [{'input_polygon': 'star.png', 'colour': 'yellow', 'output_image': 'yellow_star.png'}, {'input_polygon': 'triangle.png', 'colour': 'green', 'output_image': 'green_triangle.png'}, {'input_polygon': 'octagon.png', 'colour': 'blue', 'output_image': 'blue_octagon.png'}]
Colors in dataset: {'green', 'yellow', 'blue', 'cyan'}


In [2]:
# checking the local dataset structure
import os
for root, dirs, files in os.walk("C:/Users/rohan/OneDrive/Desktop/Ayna ML"):
    print(root, dirs, files)

C:/Users/rohan/OneDrive/Desktop/Ayna ML ['dataset', '__MACOSX'] ['checkpoint_1.pth', 'checkpoint_10.pth', 'checkpoint_11.pth', 'checkpoint_12.pth', 'checkpoint_13.pth', 'checkpoint_14.pth', 'checkpoint_15.pth', 'checkpoint_16.pth', 'checkpoint_17.pth', 'checkpoint_18.pth', 'checkpoint_19.pth', 'checkpoint_2.pth', 'checkpoint_20.pth', 'checkpoint_21.pth', 'checkpoint_22.pth', 'checkpoint_23.pth', 'checkpoint_24.pth', 'checkpoint_25.pth', 'checkpoint_26.pth', 'checkpoint_27.pth', 'checkpoint_28.pth', 'checkpoint_29.pth', 'checkpoint_3.pth', 'checkpoint_30.pth', 'checkpoint_31.pth', 'checkpoint_32.pth', 'checkpoint_33.pth', 'checkpoint_34.pth', 'checkpoint_35.pth', 'checkpoint_36.pth', 'checkpoint_37.pth', 'checkpoint_38.pth', 'checkpoint_39.pth', 'checkpoint_4.pth', 'checkpoint_40.pth', 'checkpoint_41.pth', 'checkpoint_42.pth', 'checkpoint_43.pth', 'checkpoint_44.pth', 'checkpoint_45.pth', 'checkpoint_46.pth', 'checkpoint_47.pth', 'checkpoint_48.pth', 'checkpoint_49.pth', 'checkpoint_5.p

In [3]:
import os
import json
import torch
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms

class PolygonDataset(Dataset):
    def __init__(self, root_dir="C:/Users/rohan/OneDrive/Desktop/Ayna ML/dataset", split="training"):
        self.root_dir = os.path.join(root_dir, split)
        self.json_path = os.path.join(self.root_dir, "data.json")
        with open(self.json_path, 'r') as f:
            self.data = json.load(f)
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((128, 128)),
            transforms.RandomRotation(30),
            transforms.RandomAffine(degrees=0, scale=(0.8, 1.2))
        ])
        self.color_map = {
            "cyan": 0, "purple": 1, "blue": 2, "red": 3,
            "green": 4, "magenta": 5, "yellow": 6, "orange": 7
        }

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        entry = self.data[idx]
        input_img_path = os.path.join(self.root_dir, "inputs", entry["input_polygon"])
        output_img_path = os.path.join(self.root_dir, "outputs", entry["output_image"])
        input_img = Image.open(input_img_path).convert("L")
        output_img = Image.open(output_img_path).convert("RGB")
        color = self.color_map[entry["colour"]]
        color_tensor = torch.tensor(color, dtype=torch.long)
        input_img = self.transform(input_img)
        output_img = self.transform(output_img)
        return input_img, color_tensor, output_img

# Test the dataset
dataset = PolygonDataset("C:/Users/rohan/OneDrive/Desktop/Ayna ML/dataset", split="training")
input_img, color, output_img = dataset[0]
print("Input shape:", input_img.shape)
print("Color index:", color)
print("Output shape:", output_img.shape)

Input shape: torch.Size([1, 128, 128])
Color index: tensor(0)
Output shape: torch.Size([3, 128, 128])


In [20]:
!pip uninstall wandb -y
!pip install wandb

Found existing installation: wandb 0.21.0
Uninstalling wandb-0.21.0:
  Successfully uninstalled wandb-0.21.0
Collecting wandb
  Using cached wandb-0.21.0-py3-none-win_amd64.whl.metadata (10 kB)
Using cached wandb-0.21.0-py3-none-win_amd64.whl (21.5 MB)
Installing collected packages: wandb
Successfully installed wandb-0.21.0


In [1]:
import wandb
print(wandb.__version__)
wandb.login(key="62ec231cf8dc001f9896900c9e1b60ba5571b11f")

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: C:\Users\rohan\_netrc


0.21.0


[34m[1mwandb[0m: Currently logged in as: [33mrohankanthale0[0m ([33mrohankanthale0-bharati-vidyapeeth[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [17]:
import os
print(os.listdir("C:/Users/rohan/OneDrive/Desktop/Ayna ML"))

['checkpoint_1.pth', 'checkpoint_10.pth', 'checkpoint_11.pth', 'checkpoint_12.pth', 'checkpoint_13.pth', 'checkpoint_14.pth', 'checkpoint_15.pth', 'checkpoint_16.pth', 'checkpoint_17.pth', 'checkpoint_18.pth', 'checkpoint_19.pth', 'checkpoint_2.pth', 'checkpoint_20.pth', 'checkpoint_21.pth', 'checkpoint_22.pth', 'checkpoint_23.pth', 'checkpoint_24.pth', 'checkpoint_25.pth', 'checkpoint_26.pth', 'checkpoint_27.pth', 'checkpoint_28.pth', 'checkpoint_29.pth', 'checkpoint_3.pth', 'checkpoint_30.pth', 'checkpoint_31.pth', 'checkpoint_32.pth', 'checkpoint_33.pth', 'checkpoint_34.pth', 'checkpoint_35.pth', 'checkpoint_36.pth', 'checkpoint_37.pth', 'checkpoint_38.pth', 'checkpoint_39.pth', 'checkpoint_4.pth', 'checkpoint_40.pth', 'checkpoint_41.pth', 'checkpoint_42.pth', 'checkpoint_43.pth', 'checkpoint_44.pth', 'checkpoint_45.pth', 'checkpoint_46.pth', 'checkpoint_47.pth', 'checkpoint_48.pth', 'checkpoint_49.pth', 'checkpoint_5.pth', 'checkpoint_50.pth', 'checkpoint_6.pth', 'checkpoint_7.pth'

In [18]:
# Import wandb first and log in
import wandb
import os
import json
import argparse
from pathlib import Path
from typing import Dict, Tuple, Optional

# PyTorch imports
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset  # Added Dataset import
from torch.nn import MSELoss
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torchvision.transforms as transforms
from torchvision.utils import make_grid

# Other imports
from PIL import Image
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')

# Configuration class for better parameter management
class Config:
    def __init__(self):
        # Training parameters
        self.learning_rate = 5e-4  # Reduced learning rate
        self.epochs = 100  # More epochs for small dataset
        self.batch_size = 8   # Smaller batch size for small dataset
        self.num_workers = 0  # Set to 0 for Windows compatibility
        self.pin_memory = False  # Disable for stability on some systems
        
        # Model parameters
        self.in_channels = 1
        self.out_channels = 3
        self.num_colors = 8
        
        # Paths
        self.dataset_root = "C:/Users/rohan/OneDrive/Desktop/Ayna ML/dataset"
        self.checkpoint_dir = "C:/Users/rohan/OneDrive/Desktop/Ayna ML/checkpoints"
        self.wandb_project = "ayna_unet_improved"
        
        # Training settings
        self.early_stopping_patience = 15  # More patience for small dataset
        self.save_every_n_epochs = 10
        self.validate_every_n_epochs = 1
        
        # Scheduler settings
        self.scheduler_patience = 8  # More patience
        self.scheduler_factor = 0.7  # Less aggressive reduction
        self.min_lr = 1e-8

# Enhanced Dataset with better error handling and augmentations
class PolygonDataset(Dataset):
    def __init__(self, root_dir: str, split: str = "training", augment: bool = True):
        self.root_dir = Path(root_dir) / split
        self.json_path = self.root_dir / "data.json"
        self.augment = augment
        
        # Load and validate data
        try:
            with open(self.json_path, 'r') as f:
                self.data = json.load(f)
        except FileNotFoundError:
            raise FileNotFoundError(f"Data file not found: {self.json_path}")
        except json.JSONDecodeError:
            raise ValueError(f"Invalid JSON file: {self.json_path}")
        
        # Color mapping
        self.color_map = {
            "cyan": 0, "purple": 1, "blue": 2, "red": 3,
            "green": 4, "magenta": 5, "yellow": 6, "orange": 7
        }
        
        # Define transforms
        self.base_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((128, 128), antialias=True)
        ])
        
        self.augment_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((128, 128), antialias=True),
            transforms.RandomRotation(45),  # More aggressive rotation
            transforms.RandomAffine(degrees=0, scale=(0.7, 1.3), translate=(0.1, 0.1)),  # More augmentation
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.3),  # Add vertical flip
            transforms.ColorJitter(brightness=0.3, contrast=0.3) if split == "training" else transforms.Lambda(lambda x: x)
        ])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        entry = self.data[idx]
        
        try:
            # Load images
            input_img_path = self.root_dir / "inputs" / entry["input_polygon"]
            output_img_path = self.root_dir / "outputs" / entry["output_image"]
            
            input_img = Image.open(input_img_path).convert("L")
            output_img = Image.open(output_img_path).convert("RGB")
            
            # Get color encoding
            color = self.color_map.get(entry["colour"])
            if color is None:
                raise ValueError(f"Unknown color: {entry['colour']}")
            
            color_tensor = torch.tensor(color, dtype=torch.long)
            
            
            transform = self.augment_transform if self.augment else self.base_transform
            input_img = transform(input_img)
            output_img = transform(output_img)
            
            return input_img, color_tensor, output_img
            
        except Exception as e:
            print(f"Error loading sample {idx}: {e}")
            # Return a dummy sample in case of error
            return torch.zeros(1, 128, 128), torch.tensor(0, dtype=torch.long), torch.zeros(3, 128, 128)

# Improved UNet with better architecture
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels, dropout_rate=0.1):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout2d(dropout_rate),
            nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        
    def forward(self, x):
        return self.conv(x)

class AttentionBlock(nn.Module):
    def __init__(self, in_channels):
        super(AttentionBlock, self).__init__()
        self.attention = nn.Sequential(
            nn.Conv2d(in_channels, in_channels // 8, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels // 8, in_channels, 1),
            nn.Sigmoid()
        )
        
    def forward(self, x):
        attention_weights = self.attention(x)
        return x * attention_weights

class ImprovedUNet(nn.Module):
    def __init__(self, in_channels=1, out_channels=3, num_colors=8, dropout_rate=0.1):
        super(ImprovedUNet, self).__init__()
        self.num_colors = num_colors
        
        # Color embedding with better dimension
        self.color_embedding = nn.Embedding(num_colors, 32)
        
        # Encoder
        self.enc1 = DoubleConv(in_channels + 32, 64, dropout_rate)
        self.enc2 = DoubleConv(64, 128, dropout_rate)
        self.enc3 = DoubleConv(128, 256, dropout_rate)
        self.enc4 = DoubleConv(256, 512, dropout_rate)
        
        self.pool = nn.MaxPool2d(2)
        
        # Bottleneck with attention
        self.bottleneck = DoubleConv(512, 1024, dropout_rate)
        self.attention = AttentionBlock(1024)
        
        # Decoder
        self.up4 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.dec4 = DoubleConv(1024, 512, dropout_rate)
        
        self.up3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.dec3 = DoubleConv(512, 256, dropout_rate)
        
        self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec2 = DoubleConv(256, 128, dropout_rate)
        
        self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec1 = DoubleConv(128, 64, dropout_rate)
        
        # Output layer
        self.out = nn.Sequential(
            nn.Conv2d(64, out_channels, 1),
            nn.Sigmoid()  # Ensure output is in [0, 1] range
        )
        
        # Initialize weights
        self._initialize_weights()
    
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
    
    def forward(self, x, color):
        # Embed color information
        batch_size, _, height, width = x.shape
        color_embed = self.color_embedding(color)
        color_embed = color_embed.view(batch_size, -1, 1, 1)
        color_embed = color_embed.expand(-1, -1, height, width)
        
        # Concatenate input with color embedding
        x = torch.cat([x, color_embed], dim=1)
        
        # Encoder path
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))
        
        # Bottleneck with attention
        b = self.bottleneck(self.pool(e4))
        b = self.attention(b)
        
        # Decoder path
        d4 = self.up4(b)
        d4 = self.dec4(torch.cat([d4, e4], dim=1))
        
        d3 = self.up3(d4)
        d3 = self.dec3(torch.cat([d3, e3], dim=1))
        
        d2 = self.up2(d3)
        d2 = self.dec2(torch.cat([d2, e2], dim=1))
        
        d1 = self.up1(d2)
        d1 = self.dec1(torch.cat([d1, e1], dim=1))
        
        return self.out(d1)

# Enhanced loss functions
class CombinedLoss(nn.Module):
    def __init__(self, mse_weight=1.0, l1_weight=0.1, perceptual_weight=0.1):
        super(CombinedLoss, self).__init__()
        self.mse_loss = nn.MSELoss()
        self.l1_loss = nn.L1Loss()
        self.mse_weight = mse_weight
        self.l1_weight = l1_weight
        
    def forward(self, pred, target):
        mse = self.mse_loss(pred, target)
        l1 = self.l1_loss(pred, target)
        return self.mse_weight * mse + self.l1_weight * l1

# Training utilities
class EarlyStopping:
    def __init__(self, patience=10, min_delta=0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = float('inf')
        
    def __call__(self, val_loss):
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            return False
        else:
            self.counter += 1
            return self.counter >= self.patience

def save_checkpoint(model, optimizer, scheduler, epoch, loss, filepath):
    """Save model checkpoint"""
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'loss': loss,
    }
    torch.save(checkpoint, filepath)

def load_checkpoint(filepath, model, optimizer=None, scheduler=None):
    """Load model checkpoint"""
    checkpoint = torch.load(filepath)
    model.load_state_dict(checkpoint['model_state_dict'])
    
    if optimizer:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    if scheduler:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    
    return checkpoint['epoch'], checkpoint['loss']

def visualize_predictions(model, dataloader, device, num_samples=4):
    """Visualize model predictions"""
    model.eval()
    with torch.no_grad():
        for i, (input_img, color, target) in enumerate(dataloader):
            if i >= num_samples:
                break
                
            input_img = input_img.to(device)
            color = color.to(device)
            target = target.to(device)
            
            pred = model(input_img, color)
            
            # Create visualization grid
            grid_input = make_grid(input_img[:4], nrow=4, normalize=True)
            grid_target = make_grid(target[:4], nrow=4, normalize=True)
            grid_pred = make_grid(pred[:4], nrow=4, normalize=True)
            
            # Log to wandb
            wandb.log({
                f"input_batch_{i}": wandb.Image(grid_input),
                f"target_batch_{i}": wandb.Image(grid_target),
                f"prediction_batch_{i}": wandb.Image(grid_pred)
            })

# Main training function
def train_model(config: Config, wandb_key: Optional[str] = None):
    # Login to wandb if key provided
    if wandb_key:
        wandb.login(key=wandb_key)
    
    # Initialize wandb
    wandb.init(
        project=config.wandb_project,
        config=vars(config),
        name=f"unet_lr{config.learning_rate}_bs{config.batch_size}"
    )
    
    # Create checkpoint directory
    Path(config.checkpoint_dir).mkdir(parents=True, exist_ok=True)
    
    # Setup datasets
    try:
        train_dataset = PolygonDataset(config.dataset_root, split="training", augment=True)
        val_dataset = PolygonDataset(config.dataset_root, split="validation", augment=False)
    except Exception as e:
        print(f"Error loading datasets: {e}")
        return
    
    print(f"Training samples: {len(train_dataset)}")
    print(f"Validation samples: {len(val_dataset)}")
    
    # Setup data loaders
    train_loader = DataLoader(
        train_dataset, 
        batch_size=config.batch_size, 
        shuffle=True,
        num_workers=0,  # Force to 0 for Windows compatibility
        pin_memory=False,  # Disable for stability
        drop_last=True
    )
    
    val_loader = DataLoader(
        val_dataset, 
        batch_size=config.batch_size,
        num_workers=0,  # Force to 0 for Windows compatibility
        pin_memory=False  # Disable for stability
    )
    
    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    
    # Initialize model
    model = ImprovedUNet(
        in_channels=config.in_channels,
        out_channels=config.out_channels,
        num_colors=config.num_colors
    ).to(device)
    
    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    
    # Setup loss, optimizer, and scheduler
    criterion = CombinedLoss()
    optimizer = Adam(model.parameters(), lr=config.learning_rate, weight_decay=1e-5)
    scheduler = ReduceLROnPlateau(
        optimizer, 
        patience=config.scheduler_patience, 
        factor=config.scheduler_factor,
        min_lr=config.min_lr,
        verbose=True
    )
    
    # Early stopping
    early_stopping = EarlyStopping(patience=config.early_stopping_patience)
    
    # Training loop
    best_val_loss = float('inf')
    
    for epoch in range(config.epochs):
        # Training phase
        model.train()
        train_loss = 0
        train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.epochs} [Train]")
        
        for batch_idx, (input_img, color, output_img) in enumerate(train_pbar):
            input_img = input_img.to(device, non_blocking=True)
            color = color.to(device, non_blocking=True)
            output_img = output_img.to(device, non_blocking=True)
            
            optimizer.zero_grad()
            pred = model(input_img, color)
            loss = criterion(pred, output_img)
            loss.backward()
            
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            
            optimizer.step()
            train_loss += loss.item()
            
            # Update progress bar
            train_pbar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'avg_loss': f'{train_loss/(batch_idx+1):.4f}'
            })
        
        train_loss /= len(train_loader)
        
        # Validation phase
        if (epoch + 1) % config.validate_every_n_epochs == 0:
            model.eval()
            val_loss = 0
            val_pbar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{config.epochs} [Val]")
            
            with torch.no_grad():
                for input_img, color, output_img in val_pbar:
                    input_img = input_img.to(device, non_blocking=True)
                    color = color.to(device, non_blocking=True)
                    output_img = output_img.to(device, non_blocking=True)
                    
                    pred = model(input_img, color)
                    loss = criterion(pred, output_img)
                    val_loss += loss.item()
                    
                    val_pbar.set_postfix({'val_loss': f'{loss.item():.4f}'})
            
            val_loss /= len(val_loader)
            
            # Update scheduler
            scheduler.step(val_loss)
            
            # Log metrics
            current_lr = optimizer.param_groups[0]['lr']
            wandb.log({
                "epoch": epoch + 1,
                "train_loss": train_loss,
                "val_loss": val_loss,
                "learning_rate": current_lr
            })
            
            print(f"Epoch {epoch+1}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, LR: {current_lr:.2e}")
            
            # Save best model
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                save_checkpoint(
                    model, optimizer, scheduler, epoch + 1, val_loss,
                    Path(config.checkpoint_dir) / "best_model.pth"
                )
                print(f"New best model saved with validation loss: {val_loss:.4f}")
            
            # Visualize predictions periodically
            if (epoch + 1) % 10 == 0:
                visualize_predictions(model, val_loader, device)
            
            # Early stopping check
            if early_stopping(val_loss):
                print(f"Early stopping triggered at epoch {epoch+1}")
                break
        
        # Save checkpoint periodically
        if (epoch + 1) % config.save_every_n_epochs == 0:
            save_checkpoint(
                model, optimizer, scheduler, epoch + 1, train_loss,
                Path(config.checkpoint_dir) / f"checkpoint_epoch_{epoch+1}.pth"
            )
    
    # Save final model
    save_checkpoint(
        model, optimizer, scheduler, epoch + 1, val_loss if 'val_loss' in locals() else train_loss,
        Path(config.checkpoint_dir) / "final_model.pth"
    )
    
    print("Training completed!")
    wandb.finish()

# Main execution
if __name__ == "__main__":
    import sys
    
    # Check if running in Jupyter notebook
    if any('ipykernel' in arg for arg in sys.argv):
        # Running in Jupyter - use default configuration
        print("Running in Jupyter notebook - using default configuration")
        config = Config()
        
        # Train the model
        try:
            train_model(config, wandb_key="62ec231cf8dc001f9896900c9e1b60ba5571b11f")
        except KeyboardInterrupt:
            print("Training interrupted by user")
        except Exception as e:
            print(f"Training failed with error: {e}")
            raise
    else:
        # Running from command line - parse arguments
        parser = argparse.ArgumentParser(description='Train UNet for polygon coloring')
        parser.add_argument('--config', type=str, help='Path to config file')
        parser.add_argument('--wandb_key', type=str, help='Weights & Biases API key')
        parser.add_argument('--resume', type=str, help='Path to checkpoint to resume from')
        
        args = parser.parse_args()
        
        # Initialize configuration
        config = Config()
        
        # Override config with command line arguments if provided
        wandb_key = args.wandb_key if args.wandb_key else "62ec231cf8dc001f9896900c9e1b60ba5571b11f"
        
        # Train the model
        try:
            train_model(config, wandb_key=wandb_key)
        except KeyboardInterrupt:
            print("Training interrupted by user")
        except Exception as e:
            print(f"Training failed with error: {e}")
            raise
















[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: C:\Users\rohan\_netrc


Running in Jupyter notebook - using default configuration


Training samples: 56
Validation samples: 5
Using device: cuda
GPU: NVIDIA GeForce RTX 3060 Laptop GPU
GPU Memory: 6.4 GB
Total parameters: 31,318,595
Trainable parameters: 31,318,595


Epoch 1/100 [Train]: 100%|██████████| 7/7 [00:01<00:00,  5.52it/s, loss=0.3340, avg_loss=0.3553]
Epoch 1/100 [Val]: 100%|██████████| 1/1 [00:00<00:00, 34.24it/s, val_loss=0.2436]


Epoch 1: Train Loss: 0.3553, Val Loss: 0.2436, LR: 5.00e-04
New best model saved with validation loss: 0.2436


Epoch 2/100 [Train]: 100%|██████████| 7/7 [00:01<00:00,  6.66it/s, loss=0.2728, avg_loss=0.3057]
Epoch 2/100 [Val]: 100%|██████████| 1/1 [00:00<00:00, 32.46it/s, val_loss=0.4165]


Epoch 2: Train Loss: 0.3057, Val Loss: 0.4165, LR: 5.00e-04


Epoch 3/100 [Train]: 100%|██████████| 7/7 [00:01<00:00,  6.60it/s, loss=0.2936, avg_loss=0.2984]
Epoch 3/100 [Val]: 100%|██████████| 1/1 [00:00<00:00, 31.27it/s, val_loss=0.2044]


Epoch 3: Train Loss: 0.2984, Val Loss: 0.2044, LR: 5.00e-04
New best model saved with validation loss: 0.2044


Epoch 4/100 [Train]: 100%|██████████| 7/7 [00:01<00:00,  6.58it/s, loss=0.2922, avg_loss=0.3014]
Epoch 4/100 [Val]: 100%|██████████| 1/1 [00:00<00:00, 33.88it/s, val_loss=0.1603]


Epoch 4: Train Loss: 0.3014, Val Loss: 0.1603, LR: 5.00e-04
New best model saved with validation loss: 0.1603


Epoch 5/100 [Train]: 100%|██████████| 7/7 [00:01<00:00,  6.55it/s, loss=0.3166, avg_loss=0.2880]
Epoch 5/100 [Val]: 100%|██████████| 1/1 [00:00<00:00, 33.87it/s, val_loss=0.1941]


Epoch 5: Train Loss: 0.2880, Val Loss: 0.1941, LR: 5.00e-04


Epoch 6/100 [Train]: 100%|██████████| 7/7 [00:01<00:00,  6.61it/s, loss=0.2632, avg_loss=0.2864]
Epoch 6/100 [Val]: 100%|██████████| 1/1 [00:00<00:00, 32.88it/s, val_loss=0.1700]


Epoch 6: Train Loss: 0.2864, Val Loss: 0.1700, LR: 5.00e-04


Epoch 7/100 [Train]: 100%|██████████| 7/7 [00:01<00:00,  6.39it/s, loss=0.2714, avg_loss=0.2684]
Epoch 7/100 [Val]: 100%|██████████| 1/1 [00:00<00:00, 32.60it/s, val_loss=0.1728]


Epoch 7: Train Loss: 0.2684, Val Loss: 0.1728, LR: 5.00e-04


Epoch 8/100 [Train]: 100%|██████████| 7/7 [00:01<00:00,  6.49it/s, loss=0.2499, avg_loss=0.2223]
Epoch 8/100 [Val]: 100%|██████████| 1/1 [00:00<00:00, 33.10it/s, val_loss=0.2821]


Epoch 8: Train Loss: 0.2223, Val Loss: 0.2821, LR: 5.00e-04


Epoch 9/100 [Train]: 100%|██████████| 7/7 [00:01<00:00,  6.42it/s, loss=0.2395, avg_loss=0.2554]
Epoch 9/100 [Val]: 100%|██████████| 1/1 [00:00<00:00, 33.77it/s, val_loss=0.2390]


Epoch 9: Train Loss: 0.2554, Val Loss: 0.2390, LR: 5.00e-04


Epoch 10/100 [Train]: 100%|██████████| 7/7 [00:01<00:00,  6.33it/s, loss=0.2567, avg_loss=0.2203]
Epoch 10/100 [Val]: 100%|██████████| 1/1 [00:00<00:00, 33.68it/s, val_loss=0.1780]


Epoch 10: Train Loss: 0.2203, Val Loss: 0.1780, LR: 5.00e-04


Epoch 11/100 [Train]: 100%|██████████| 7/7 [00:01<00:00,  6.59it/s, loss=0.2168, avg_loss=0.2322]
Epoch 11/100 [Val]: 100%|██████████| 1/1 [00:00<00:00, 35.30it/s, val_loss=0.1627]


Epoch 11: Train Loss: 0.2322, Val Loss: 0.1627, LR: 5.00e-04


Epoch 12/100 [Train]: 100%|██████████| 7/7 [00:01<00:00,  6.60it/s, loss=0.1915, avg_loss=0.2218]
Epoch 12/100 [Val]: 100%|██████████| 1/1 [00:00<00:00, 33.49it/s, val_loss=0.1477]


Epoch 12: Train Loss: 0.2218, Val Loss: 0.1477, LR: 5.00e-04
New best model saved with validation loss: 0.1477


Epoch 13/100 [Train]: 100%|██████████| 7/7 [00:01<00:00,  6.56it/s, loss=0.2006, avg_loss=0.2204]
Epoch 13/100 [Val]: 100%|██████████| 1/1 [00:00<00:00, 34.73it/s, val_loss=0.1600]


Epoch 13: Train Loss: 0.2204, Val Loss: 0.1600, LR: 5.00e-04


Epoch 14/100 [Train]: 100%|██████████| 7/7 [00:01<00:00,  6.57it/s, loss=0.2624, avg_loss=0.2175]
Epoch 14/100 [Val]: 100%|██████████| 1/1 [00:00<00:00, 32.68it/s, val_loss=0.1645]


Epoch 14: Train Loss: 0.2175, Val Loss: 0.1645, LR: 5.00e-04


Epoch 15/100 [Train]: 100%|██████████| 7/7 [00:01<00:00,  6.56it/s, loss=0.2383, avg_loss=0.2231]
Epoch 15/100 [Val]: 100%|██████████| 1/1 [00:00<00:00, 34.43it/s, val_loss=0.1469]


Epoch 15: Train Loss: 0.2231, Val Loss: 0.1469, LR: 5.00e-04
New best model saved with validation loss: 0.1469


Epoch 16/100 [Train]: 100%|██████████| 7/7 [00:01<00:00,  6.37it/s, loss=0.2358, avg_loss=0.2136]
Epoch 16/100 [Val]: 100%|██████████| 1/1 [00:00<00:00, 34.02it/s, val_loss=0.1687]


Epoch 16: Train Loss: 0.2136, Val Loss: 0.1687, LR: 5.00e-04


Epoch 17/100 [Train]: 100%|██████████| 7/7 [00:01<00:00,  6.56it/s, loss=0.2576, avg_loss=0.2039]
Epoch 17/100 [Val]: 100%|██████████| 1/1 [00:00<00:00, 32.78it/s, val_loss=0.1701]


Epoch 17: Train Loss: 0.2039, Val Loss: 0.1701, LR: 5.00e-04


Epoch 18/100 [Train]: 100%|██████████| 7/7 [00:01<00:00,  6.54it/s, loss=0.2275, avg_loss=0.2147]
Epoch 18/100 [Val]: 100%|██████████| 1/1 [00:00<00:00, 33.37it/s, val_loss=0.1673]


Epoch 18: Train Loss: 0.2147, Val Loss: 0.1673, LR: 5.00e-04


Epoch 19/100 [Train]: 100%|██████████| 7/7 [00:01<00:00,  6.55it/s, loss=0.2172, avg_loss=0.2300]
Epoch 19/100 [Val]: 100%|██████████| 1/1 [00:00<00:00, 34.54it/s, val_loss=0.1766]


Epoch 19: Train Loss: 0.2300, Val Loss: 0.1766, LR: 5.00e-04


Epoch 20/100 [Train]: 100%|██████████| 7/7 [00:01<00:00,  6.57it/s, loss=0.2084, avg_loss=0.2041]
Epoch 20/100 [Val]: 100%|██████████| 1/1 [00:00<00:00, 33.77it/s, val_loss=0.1786]


Epoch 20: Train Loss: 0.2041, Val Loss: 0.1786, LR: 5.00e-04


Epoch 21/100 [Train]: 100%|██████████| 7/7 [00:01<00:00,  6.58it/s, loss=0.1975, avg_loss=0.2260]
Epoch 21/100 [Val]: 100%|██████████| 1/1 [00:00<00:00, 33.67it/s, val_loss=0.1572]


Epoch 21: Train Loss: 0.2260, Val Loss: 0.1572, LR: 5.00e-04


Epoch 22/100 [Train]: 100%|██████████| 7/7 [00:01<00:00,  6.50it/s, loss=0.2339, avg_loss=0.2164]
Epoch 22/100 [Val]: 100%|██████████| 1/1 [00:00<00:00, 33.27it/s, val_loss=0.1715]


Epoch 22: Train Loss: 0.2164, Val Loss: 0.1715, LR: 5.00e-04


Epoch 23/100 [Train]: 100%|██████████| 7/7 [00:01<00:00,  6.53it/s, loss=0.2009, avg_loss=0.2071]
Epoch 23/100 [Val]: 100%|██████████| 1/1 [00:00<00:00, 32.46it/s, val_loss=0.1740]


Epoch 23: Train Loss: 0.2071, Val Loss: 0.1740, LR: 5.00e-04


Epoch 24/100 [Train]: 100%|██████████| 7/7 [00:01<00:00,  6.62it/s, loss=0.1756, avg_loss=0.2124]
Epoch 24/100 [Val]: 100%|██████████| 1/1 [00:00<00:00, 34.98it/s, val_loss=0.1790]


Epoch 24: Train Loss: 0.2124, Val Loss: 0.1790, LR: 3.50e-04


Epoch 25/100 [Train]: 100%|██████████| 7/7 [00:01<00:00,  6.55it/s, loss=0.2545, avg_loss=0.2210]
Epoch 25/100 [Val]: 100%|██████████| 1/1 [00:00<00:00, 33.78it/s, val_loss=0.1833]


Epoch 25: Train Loss: 0.2210, Val Loss: 0.1833, LR: 3.50e-04


Epoch 26/100 [Train]: 100%|██████████| 7/7 [00:01<00:00,  6.49it/s, loss=0.1784, avg_loss=0.2052]
Epoch 26/100 [Val]: 100%|██████████| 1/1 [00:00<00:00, 35.15it/s, val_loss=0.1911]


Epoch 26: Train Loss: 0.2052, Val Loss: 0.1911, LR: 3.50e-04


Epoch 27/100 [Train]: 100%|██████████| 7/7 [00:01<00:00,  6.67it/s, loss=0.1853, avg_loss=0.1971]
Epoch 27/100 [Val]: 100%|██████████| 1/1 [00:00<00:00, 32.83it/s, val_loss=0.1850]


Epoch 27: Train Loss: 0.1971, Val Loss: 0.1850, LR: 3.50e-04
Early stopping triggered at epoch 27
Training completed!


0,1
epoch,▁▁▂▂▂▂▃▃▃▃▄▄▄▅▅▅▅▆▆▆▆▇▇▇▇██
learning_rate,███████████████████████▁▁▁▁
train_loss,█▆▅▆▅▅▄▂▄▂▃▂▂▂▂▂▁▂▂▁▂▂▁▂▂▁▁
val_loss,▄█▂▁▂▂▂▅▃▂▁▁▁▁▁▂▂▂▂▂▁▂▂▂▂▂▂

0,1
epoch,27.0
learning_rate,0.00035
train_loss,0.19709
val_loss,0.18503


In [3]:
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

Looking in indexes: https://download.pytorch.org/whl/cu118
Note: you may need to restart the kernel to use updated packages.


In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.nn import MSELoss
import os

# GPU Setup Check
print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("CUDA device count:", torch.cuda.device_count())
if torch.cuda.is_available():
    print("CUDA device name:", torch.cuda.get_device_name(0))

# Force GPU usage
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Your training setup with GPU
model = nn.Linear(10, 1).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = MSELoss()

# Ensure model is on GPU
print("Model device:", next(model.parameters()).device)

PyTorch version: 2.5.1
CUDA available: True
CUDA device count: 1
CUDA device name: NVIDIA GeForce RTX 3060 Laptop GPU
Using device: cuda
Model device: cuda:0


In [2]:
!pip install nbformat torch

Collecting nbformat
  Using cached nbformat-5.10.4-py3-none-any.whl.metadata (3.6 kB)
Collecting fastjsonschema>=2.15 (from nbformat)
  Using cached fastjsonschema-2.21.1-py3-none-any.whl.metadata (2.2 kB)
Collecting jsonschema>=2.6 (from nbformat)
  Using cached jsonschema-4.25.0-py3-none-any.whl.metadata (7.7 kB)
Collecting fsspec (from torch)
  Using cached fsspec-2025.7.0-py3-none-any.whl.metadata (12 kB)
Collecting sympy==1.13.1 (from torch)
  Downloading sympy-1.13.1-py3-none-any.whl.metadata (12 kB)
Collecting attrs>=22.2.0 (from jsonschema>=2.6->nbformat)
  Using cached attrs-25.3.0-py3-none-any.whl.metadata (10 kB)
Collecting jsonschema-specifications>=2023.03.6 (from jsonschema>=2.6->nbformat)
  Using cached jsonschema_specifications-2025.4.1-py3-none-any.whl.metadata (2.9 kB)
Collecting referencing>=0.28.4 (from jsonschema>=2.6->nbformat)
  Using cached referencing-0.36.2-py3-none-any.whl.metadata (2.8 kB)
Collecting rpds-py>=0.7.1 (from jsonschema>=2.6->nbformat)
  Download