In [None]:
import json
import os
from typing import Dict, List, Tuple, Optional
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from PIL import Image
import numpy as np


class PolygonColorDataset(Dataset):
    """Paired (input polygon image, target image) samples with a colour condition.

    Every entry of the JSON list at ``data_json_path`` must provide the keys
    ``'input_polygon'`` (file name under ``input_dir``), ``'output_image'``
    (file name under ``output_dir``) and ``'colour'`` (colour name).

    ``__getitem__`` returns a dict with both images resized to
    ``image_size`` x ``image_size`` and normalised to [-1, 1], a one-hot colour
    vector, the colour index/name, and the two file names.
    """

    def __init__(self,
                 data_json_path: str,
                 input_dir: str,
                 output_dir: str,
                 image_size: int = 256,
                 augmentations: Optional[transforms.Compose] = None,
                 color_to_idx: Optional[Dict[str, int]] = None):
        """
        Args:
            data_json_path: Path to the JSON list of samples.
            input_dir: Directory holding the input polygon images.
            output_dir: Directory holding the target images.
            image_size: Side length both images are resized to.
            augmentations: Optional tensor transform applied to input and target
                with identical random draws. It receives float tensors in
                [0, 1] (before normalisation); note that e.g. RandomRotation's
                default fill of 0 therefore pads with black.
            color_to_idx: Optional fixed colour -> index mapping (e.g. the
                training set's mapping when building the validation set). When
                falsy, a mapping over the sorted colour names in the JSON is built.

        Raises:
            ValueError: If a sample uses a colour missing from ``color_to_idx``.
        """
        self.input_dir = input_dir
        self.output_dir = output_dir
        self.image_size = image_size
        self.augmentations = augmentations

        # Load data from JSON
        with open(data_json_path, 'r') as f:
            self.data = json.load(f)

        # Create color mappings
        if color_to_idx:
            self.color_to_idx = dict(color_to_idx)
            # Order names by their index so self.colors[i] is the colour with
            # index i even when the given mapping is not alphabetical.
            self.colors = sorted(self.color_to_idx, key=self.color_to_idx.__getitem__)
        else:
            self.colors = sorted({item['colour'] for item in self.data})
            self.color_to_idx = {color: idx for idx, color in enumerate(self.colors)}

        self.num_colors = len(self.colors)

        # Fail early (rather than with a KeyError inside a DataLoader worker)
        # when a sample uses a colour the mapping does not contain, e.g. a
        # validation colour never seen in training.
        unknown = sorted({item['colour'] for item in self.data} - set(self.color_to_idx))
        if unknown:
            raise ValueError(
                f"Colours {unknown} are not present in color_to_idx {self.color_to_idx}")

        print(f"Dataset initialized with {len(self.data)} samples")
        print(f"Colors: {self.colors}")
        print(f"Color mapping: {self.color_to_idx}")

        # Resize, then convert to float tensors in [0, 1]
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
        ])

        # Normalize to [-1, 1] for better training stability
        self.normalize = transforms.Normalize([0.5], [0.5])

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _load_rgb(path: str):
        """Open the image at ``path`` and return an RGB copy, closing the file."""
        with Image.open(path) as img:
            return img.convert('RGB')

    def _augment_pair(self, input_tensor, output_tensor):
        """Apply ``self.augmentations`` to both tensors with identical random draws.

        The default CPU generator is re-seeded with the same value before each
        call, so transforms that draw from it (torchvision's
        RandomHorizontalFlip / RandomRotation / ColorJitter do) pick the same
        parameters for input and target. The previous CPU RNG state is
        restored afterwards so no fixed seed leaks into the rest of the process.
        """
        # Drawing the seed advances the global RNG, so consecutive samples get
        # different augmentation parameters.
        seed = int(torch.randint(0, 2147483647, (1,)).item())
        saved_state = torch.get_rng_state()
        try:
            torch.default_generator.manual_seed(seed)
            input_aug = self.augmentations(input_tensor)
            torch.default_generator.manual_seed(seed)
            output_aug = self.augmentations(output_tensor)
        finally:
            torch.set_rng_state(saved_state)
        return input_aug, output_aug

    def __getitem__(self, idx):
        item = self.data[idx]

        # Load input polygon image and target output image
        input_image = self._load_rgb(os.path.join(self.input_dir, item['input_polygon']))
        output_image = self._load_rgb(os.path.join(self.output_dir, item['output_image']))

        # Resize + ToTensor -> float tensors in [0, 1]
        input_tensor = self.transform(input_image)
        output_tensor = self.transform(output_image)

        # Augment BEFORE normalising: torchvision colour ops (ColorJitter)
        # assume [0, 1] float images and clamp their results to that range,
        # which would zero out every negative value of a [-1, 1] tensor.
        if self.augmentations:
            input_tensor, output_tensor = self._augment_pair(input_tensor, output_tensor)

        # Normalize to [-1, 1]
        input_tensor = self.normalize(input_tensor)
        output_tensor = self.normalize(output_tensor)

        # Color condition as index and one-hot vector
        color_name = item['colour']
        color_idx = self.color_to_idx[color_name]
        color_onehot = torch.zeros(self.num_colors)
        color_onehot[color_idx] = 1.0

        return {
            'input_image': input_tensor,
            'color_onehot': color_onehot,
            'color_idx': color_idx,
            'color_name': color_name,
            'output_image': output_tensor,
            'input_filename': item['input_polygon'],
            'output_filename': item['output_image']
        }

    def get_color_embedding_dim(self):
        """Return the dimension of color embeddings (number of unique colors)"""
        return self.num_colors


def get_data_loaders(train_json_path: str,
                     val_json_path: str,
                     train_input_dir: str,
                     train_output_dir: str,
                     val_input_dir: str,
                     val_output_dir: str,
                     batch_size: int = 16,
                     image_size: int = 256,
                     num_workers: int = 4,
                     use_augmentations: bool = True):
    """Build training/validation DataLoaders that share the training colour mapping.

    Args:
        train_json_path, val_json_path: JSON sample lists for each split.
        train_input_dir, train_output_dir: Image directories of the training split.
        val_input_dir, val_output_dir: Image directories of the validation split.
        batch_size: Batch size for both loaders.
        image_size: Side length images are resized to.
        num_workers: DataLoader worker processes.
        use_augmentations: Enable random flip/rotation/colour jitter for training.

    Returns:
        ``(train_loader, val_loader, color_info)`` where ``color_info`` holds the
        training set's ``'colors'``, ``'color_to_idx'`` and ``'num_colors'``.
    """
    # Define augmentations for training
    train_augmentations = None
    if use_augmentations:
        # NOTE(review): these transforms are applied to the target image too, so
        # hue/saturation jitter shifts the target away from the conditioning
        # colour name; consider restricting jitter for this task.
        train_augmentations = transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(degrees=15),
            transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05),
        ])

    train_dataset = PolygonColorDataset(
        data_json_path=train_json_path,
        input_dir=train_input_dir,
        output_dir=train_output_dir,
        image_size=image_size,
        augmentations=train_augmentations
    )

    # Use the color mapping from the training set for the validation set
    color_to_idx = train_dataset.color_to_idx

    val_dataset = PolygonColorDataset(
        data_json_path=val_json_path,
        input_dir=val_input_dir,
        output_dir=val_output_dir,
        image_size=image_size,
        augmentations=None,  # No augmentations for validation
        color_to_idx=color_to_idx
    )

    # Pinned host memory only benefits host->GPU copies.
    pin_memory = torch.cuda.is_available()

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
        # Dropping the last partial batch is only safe when at least one full
        # batch exists; otherwise the loader would yield no batches at all.
        drop_last=len(train_dataset) >= batch_size
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=False
    )

    # Return color information for model initialization
    color_info = {
        'colors': train_dataset.colors,
        'color_to_idx': train_dataset.color_to_idx,
        'num_colors': train_dataset.num_colors
    }

    return train_loader, val_loader, color_info


# Example usage and testing
if __name__ == "__main__":
    # Smoke test: build the training split, inspect one sample and one batch.
    training_root = "/content/drive/MyDrive/dataset/dataset/training"
    dataset = PolygonColorDataset(
        data_json_path=os.path.join(training_root, "data.json"),
        input_dir=os.path.join(training_root, "inputs"),
        output_dir=os.path.join(training_root, "outputs"),
        image_size=256
    )

    # Test a single sample
    sample = dataset[0]
    print(f"Sample keys: {sample.keys()}")
    print(f"Input image shape: {sample['input_image'].shape}")
    print(f"Output image shape: {sample['output_image'].shape}")
    print(f"Color onehot shape: {sample['color_onehot'].shape}")
    print(f"Color name: {sample['color_name']}")
    print(f"Color index: {sample['color_idx']}")

    # Test data loader (DataLoader comes from the imports at the top of the cell)
    loader = DataLoader(dataset, batch_size=4, shuffle=True)
    batch = next(iter(loader))
    print(f"\nBatch input shape: {batch['input_image'].shape}")
    print(f"Batch output shape: {batch['output_image'].shape}")
    print(f"Batch color onehot shape: {batch['color_onehot'].shape}")

Dataset initialized with 56 samples
Colors: ['blue', 'cyan', 'green', 'magenta', 'orange', 'purple', 'red', 'yellow']
Color mapping: {'blue': 0, 'cyan': 1, 'green': 2, 'magenta': 3, 'orange': 4, 'purple': 5, 'red': 6, 'yellow': 7}
Sample keys: dict_keys(['input_image', 'color_onehot', 'color_idx', 'color_name', 'output_image', 'input_filename', 'output_filename'])
Input image shape: torch.Size([3, 256, 256])
Output image shape: torch.Size([3, 256, 256])
Color onehot shape: torch.Size([8])
Color name: cyan
Color index: 1

Batch input shape: torch.Size([4, 3, 256, 256])
Batch output shape: torch.Size([4, 3, 256, 256])
Batch color onehot shape: torch.Size([4, 8])


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional


class DoubleConv(nn.Module):
    """Two stacked ``Conv3x3 (no bias) -> BatchNorm -> ReLU`` stages.

    ``padding=1`` keeps the spatial size unchanged.

    Args:
        in_channels: Channels of the incoming tensor.
        out_channels: Channels produced by the second convolution.
        mid_channels: Channels between the two convolutions; ``out_channels``
            is used when this is omitted or falsy.
    """

    def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None):
        super().__init__()
        hidden = mid_channels if mid_channels else out_channels
        layers = []
        for c_in, c_out in ((in_channels, hidden), (hidden, out_channels)):
            layers.extend([
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ])
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Halve H and W with a 2x2 max-pool, then refine with a DoubleConv."""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        pool = nn.MaxPool2d(2)
        refine = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, refine)

    def forward(self, x):
        downsampled = self.maxpool_conv(x)
        return downsampled


class Up(nn.Module):
    """Upsample by 2, concatenate with a skip connection, then apply DoubleConv.

    With ``bilinear=True`` the upsampling is parameter-free and the DoubleConv
    uses ``in_channels // 2`` mid channels. Otherwise a stride-2 transposed
    convolution upsamples while halving the channel count.
    """

    def __init__(self, in_channels: int, out_channels: int, bilinear: bool = True):
        super().__init__()
        if not bilinear:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
            return
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)

    def forward(self, x1, x2):
        """Upsample ``x1``, pad it to ``x2``'s H x W, concat ``[x2, x1]`` on channels, convolve."""
        upsampled = self.up(x1)
        # Tensors are (B, C, H, W); pad the upsampled map so it matches the
        # skip connection's spatial size when the two differ.
        pad_h = x2.size(2) - upsampled.size(2)
        pad_w = x2.size(3) - upsampled.size(3)
        upsampled = F.pad(upsampled, [pad_w // 2, pad_w - pad_w // 2,
                                      pad_h // 2, pad_h - pad_h // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        merged = torch.cat([x2, upsampled], dim=1)
        return self.conv(merged)


class OutConv(nn.Module):
    """1x1 convolution projecting ``in_channels`` feature maps to ``out_channels``."""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        projected = self.conv(x)
        return projected


class FiLM(nn.Module):
    """Feature-wise Linear Modulation: ``gamma(c) * x + beta(c)`` per channel.

    Two independent linear layers map the condition vector to a per-channel
    scale (gamma) and shift (beta), which are broadcast over H and W.
    """

    def __init__(self, condition_dim: int, feature_dim: int):
        super().__init__()
        self.condition_dim = condition_dim
        self.feature_dim = feature_dim

        # condition -> per-channel scale, condition -> per-channel shift
        self.gamma_linear = nn.Linear(condition_dim, feature_dim)
        self.beta_linear = nn.Linear(condition_dim, feature_dim)

    def forward(self, x, condition):
        """
        Args:
            x: Feature map of shape (B, C, H, W), with C == feature_dim.
            condition: Condition vectors of shape (B, condition_dim).

        Returns:
            The modulated feature map, shape (B, C, H, W).
        """
        scale = self.gamma_linear(condition)[:, :, None, None]  # (B, C, 1, 1)
        shift = self.beta_linear(condition)[:, :, None, None]   # (B, C, 1, 1)
        return scale * x + shift


class ConditionalUNet(nn.Module):
    """
    Conditional UNet for polygon coloring, conditioned through FiLM layers.

    A two-layer MLP embeds the colour condition. After each of the five
    encoder stages, a FiLM layer scales and shifts that stage's features using
    the embedding. The modulated features feed both the next encoder stage and
    the matching decoder skip connection. The output goes through ``tanh``, so
    it lies in [-1, 1].
    """

    def __init__(self,
                 n_channels: int = 3,
                 n_classes: int = 3,
                 num_colors: int = 8,
                 color_embed_dim: int = 128,
                 bilinear: bool = True):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.num_colors = num_colors
        self.color_embed_dim = color_embed_dim
        self.bilinear = bilinear

        # Colour condition -> embedding MLP
        self.color_embedding = nn.Sequential(
            nn.Linear(num_colors, color_embed_dim),
            nn.ReLU(inplace=True),
            nn.Linear(color_embed_dim, color_embed_dim),
            nn.ReLU(inplace=True)
        )

        # Channel halving used by the bottleneck/decoder when bilinear=True
        factor = 2 if bilinear else 1
        bottleneck_channels = 1024 // factor

        # Encoder
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, bottleneck_channels)

        # One FiLM per encoder output, sized to that stage's channel count
        self.film1 = FiLM(color_embed_dim, 64)
        self.film2 = FiLM(color_embed_dim, 128)
        self.film3 = FiLM(color_embed_dim, 256)
        self.film4 = FiLM(color_embed_dim, 512)
        self.film5 = FiLM(color_embed_dim, bottleneck_channels)

        # Decoder
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

        self.output_activation = nn.Tanh()

    def forward(self, x, color_condition):
        """
        Args:
            x: Input images, shape (B, n_channels, H, W).
            color_condition: One-hot colour vectors, shape (B, num_colors).

        Returns:
            Generated images, shape (B, n_classes, H, W), values in [-1, 1].
        """
        embedding = self.color_embedding(color_condition)

        encoders = (self.inc, self.down1, self.down2, self.down3, self.down4)
        films = (self.film1, self.film2, self.film3, self.film4, self.film5)

        # Encode, modulating every stage's output with the colour embedding
        skips = []
        hidden = x
        for encode, modulate in zip(encoders, films):
            hidden = modulate(encode(hidden), embedding)
            skips.append(hidden)
        x1, x2, x3, x4, x5 = skips

        # Decode with skip connections from the modulated encoder features
        hidden = self.up1(x5, x4)
        for up, skip in ((self.up2, x3), (self.up3, x2), (self.up4, x1)):
            hidden = up(hidden, skip)

        return self.output_activation(self.outc(hidden))


class AlternativeConditionalUNet(nn.Module):
    """Conditional UNet that injects the colour as extra constant input channels.

    The colour vector is broadcast to one constant H x W plane per entry and
    concatenated to the image channels before a plain UNet. The output goes
    through ``tanh``.
    """

    def __init__(self,
                 n_channels: int = 3,
                 n_classes: int = 3,
                 num_colors: int = 8,
                 bilinear: bool = True):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.num_colors = num_colors
        self.bilinear = bilinear

        factor = 2 if bilinear else 1

        # Encoder; the first stage sees image channels + colour planes
        self.inc = DoubleConv(n_channels + num_colors, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024 // factor)

        # Decoder
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

        self.output_activation = nn.Tanh()

    def forward(self, x, color_condition):
        """
        Args:
            x: Input images, shape (B, n_channels, H, W).
            color_condition: Colour vectors, shape (B, num_colors).

        Returns:
            Generated images, shape (B, n_classes, H, W), values in [-1, 1].
        """
        _, _, height, width = x.shape

        # (B, num_colors) -> (B, num_colors, H, W) constant planes
        color_planes = color_condition[..., None, None].expand(-1, -1, height, width)
        hidden = self.inc(torch.cat([x, color_planes], dim=1))

        skips = [hidden]
        for down in (self.down1, self.down2, self.down3, self.down4):
            hidden = down(hidden)
            skips.append(hidden)
        x1, x2, x3, x4, x5 = skips

        hidden = self.up1(x5, x4)
        hidden = self.up2(hidden, x3)
        hidden = self.up3(hidden, x2)
        hidden = self.up4(hidden, x1)

        return self.output_activation(self.outc(hidden))


def create_model(model_type: str = "film", **kwargs):
    """Instantiate a conditional UNet by name.

    Args:
        model_type: ``"film"`` for ConditionalUNet or ``"concat"`` for
            AlternativeConditionalUNet.
        **kwargs: Forwarded to the chosen model's constructor. ``color_embed_dim``
            only applies to the FiLM model and is ignored for ``"concat"``.

    Raises:
        ValueError: If ``model_type`` is not recognised.
    """
    if model_type == "film":
        return ConditionalUNet(**kwargs)
    if model_type == "concat":
        # AlternativeConditionalUNet has no colour-embedding MLP and its
        # constructor rejects this argument; drop it so callers can pass one
        # shared set of kwargs for either model type.
        kwargs.pop("color_embed_dim", None)
        return AlternativeConditionalUNet(**kwargs)
    raise ValueError(f"Unknown model type: {model_type}")


# Test the models
if __name__ == "__main__":
    def count_parameters(model):
        """Number of trainable parameters in ``model``."""
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

    # FiLM-conditioned variant
    model_film = ConditionalUNet(num_colors=8)

    # Random images with one-hot colours: sample 0 -> colour 0, sample 1 -> colour 3
    batch_size = 2
    x = torch.randn(batch_size, 3, 256, 256)
    color_condition = torch.zeros(batch_size, 8)
    for row, col in enumerate((0, 3)):
        color_condition[row, col] = 1

    output = model_film(x, color_condition)
    print(f"FiLM model output shape: {output.shape}")

    # Channel-concatenation variant on the same inputs
    model_concat = AlternativeConditionalUNet(num_colors=8)
    output_concat = model_concat(x, color_condition)
    print(f"Concatenation model output shape: {output_concat.shape}")

    print(f"FiLM model parameters: {count_parameters(model_film):,}")
    print(f"Concatenation model parameters: {count_parameters(model_concat):,}")

FiLM model output shape: torch.Size([2, 3, 256, 256])
Concatenation model output shape: torch.Size([2, 3, 256, 256])
FiLM model parameters: 17,660,547
Concatenation model parameters: 17,267,715


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import wandb
import os
import time
from datetime import datetime
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
from tqdm import tqdm
import matplotlib.pyplot as plt


class PolygonColoringTrainer:
    """
    Trainer class for the polygon coloring task using conditional UNet.

    Owns the model, optimizer, optional LR scheduler and loss. It runs the
    train/validate loop, logs to Weights & Biases, and writes checkpoints to
    ``config['checkpoint_dir']`` and sample grids to ``config['sample_dir']``.

    ``current_epoch`` is the 0-based index of the epoch being run. The
    ``'epoch'`` stored in a checkpoint is therefore the last *completed* epoch.
    """

    def __init__(self,
                 config: dict,
                 train_loader: DataLoader,
                 val_loader: DataLoader,
                 color_info: dict):
        """
        Args:
            config: Training configuration. Keys read here:
                - required: 'model_type', 'n_channels', 'n_classes',
                  'optimizer' ('adam' | 'adamw'), 'learning_rate',
                  'loss_function' ('mse' | 'l1' | 'huber'), 'checkpoint_dir',
                  'sample_dir';
                - 'epochs' when scheduler == 'cosine';
                - optional: 'color_embed_dim', 'bilinear', 'weight_decay',
                  'scheduler' ('cosine' | 'step'), 'step_size', 'gamma'.
            train_loader: Yields dicts with 'input_image', 'color_onehot' and
                'output_image'.
            val_loader: Same as ``train_loader``, plus 'color_name', which is
                used to title the sample images.
            color_info: Must contain 'num_colors'; stored in checkpoints.

        Raises:
            ValueError: Unsupported optimizer or loss function name.
        """
        self.config = config
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.color_info = color_info

        # Setup device
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {self.device}")

        # Initialize model
        model_kwargs = dict(
            num_colors=color_info['num_colors'],
            n_channels=config['n_channels'],
            n_classes=config['n_classes'],
            bilinear=config.get('bilinear', True),
        )
        if config['model_type'] == 'film':
            # Only the FiLM model has a colour embedding; the concat model's
            # constructor does not accept this argument.
            model_kwargs['color_embed_dim'] = config.get('color_embed_dim', 128)
        self.model = create_model(
            model_type=config['model_type'], **model_kwargs
        ).to(self.device)

        # Count parameters
        num_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        print(f"Model has {num_params:,} trainable parameters")

        # Initialize optimizer
        if config['optimizer'] == 'adam':
            self.optimizer = optim.Adam(
                self.model.parameters(),
                lr=config['learning_rate'],
                weight_decay=config.get('weight_decay', 1e-4)
            )
        elif config['optimizer'] == 'adamw':
            self.optimizer = optim.AdamW(
                self.model.parameters(),
                lr=config['learning_rate'],
                weight_decay=config.get('weight_decay', 1e-2)
            )
        else:
            raise ValueError(f"Unsupported optimizer: {config['optimizer']}")

        # Initialize scheduler (stepped once per epoch in train())
        if config.get('scheduler') == 'cosine':
            self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer,
                T_max=config['epochs'],
                eta_min=config['learning_rate'] * 0.01
            )
        elif config.get('scheduler') == 'step':
            self.scheduler = optim.lr_scheduler.StepLR(
                self.optimizer,
                step_size=config.get('step_size', 50),
                gamma=config.get('gamma', 0.5)
            )
        else:
            self.scheduler = None

        # Initialize loss function
        if config['loss_function'] == 'mse':
            self.criterion = nn.MSELoss()
        elif config['loss_function'] == 'l1':
            self.criterion = nn.L1Loss()
        elif config['loss_function'] == 'huber':
            self.criterion = nn.SmoothL1Loss()
        else:
            raise ValueError(f"Unsupported loss function: {config['loss_function']}")

        # Training state
        self.current_epoch = 0
        self.best_val_loss = float('inf')
        self.train_losses = []
        self.val_losses = []

        # Create output directories
        os.makedirs(config['checkpoint_dir'], exist_ok=True)
        os.makedirs(config['sample_dir'], exist_ok=True)

    def train_epoch(self):
        """Run one optimisation pass over ``train_loader``.

        Returns:
            Mean training loss over all samples seen this epoch.

        Raises:
            RuntimeError: If the loader yields no batches.
        """
        self.model.train()
        total_loss = 0.0
        total_samples = 0
        num_batches = len(self.train_loader)

        pbar = tqdm(self.train_loader, desc=f'Epoch {self.current_epoch + 1}/{self.config["epochs"]}')

        for batch_idx, batch in enumerate(pbar):
            # Move data to device
            input_images = batch['input_image'].to(self.device)
            color_conditions = batch['color_onehot'].to(self.device)
            target_images = batch['output_image'].to(self.device)

            # Forward pass
            self.optimizer.zero_grad()
            predicted_images = self.model(input_images, color_conditions)
            loss = self.criterion(predicted_images, target_images)

            # Backward pass (with optional gradient clipping)
            loss.backward()
            if self.config.get('grad_clip'):
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config['grad_clip'])
            self.optimizer.step()

            # The criteria use the default 'mean' reduction; weighting by batch
            # size gives a true per-sample mean even if a batch is smaller.
            loss_value = loss.item()
            batch_len = input_images.size(0)
            total_loss += loss_value * batch_len
            total_samples += batch_len

            pbar.set_postfix({'Loss': f'{loss_value:.4f}'})

            if batch_idx % self.config.get('log_interval', 50) == 0:
                wandb.log({
                    'train_loss_step': loss_value,
                    'learning_rate': self.optimizer.param_groups[0]['lr'],
                    'epoch': self.current_epoch,
                    'step': self.current_epoch * num_batches + batch_idx
                })

        if total_samples == 0:
            raise RuntimeError(
                "train_loader yielded no batches; check that the training set "
                "has at least batch_size samples when drop_last=True")

        avg_epoch_loss = total_loss / total_samples
        self.train_losses.append(avg_epoch_loss)

        return avg_epoch_loss

    def validate_epoch(self):
        """Evaluate on ``val_loader`` and log sample images from the first batch.

        Returns:
            Mean validation loss over all validation samples.

        Raises:
            RuntimeError: If the loader yields no batches.
        """
        self.model.eval()
        total_loss = 0.0
        total_samples = 0

        # First-batch samples for visualization
        sample_inputs = None
        sample_predictions = None
        sample_targets = None
        sample_colors = []

        with torch.no_grad():
            for batch_idx, batch in enumerate(tqdm(self.val_loader, desc='Validating')):
                input_images = batch['input_image'].to(self.device)
                color_conditions = batch['color_onehot'].to(self.device)
                target_images = batch['output_image'].to(self.device)

                predicted_images = self.model(input_images, color_conditions)
                loss = self.criterion(predicted_images, target_images)

                # Weight by batch size so a smaller final batch is not
                # over-represented in the epoch mean.
                batch_len = input_images.size(0)
                total_loss += loss.item() * batch_len
                total_samples += batch_len

                if batch_idx == 0:
                    sample_inputs = input_images[:4].cpu()
                    sample_predictions = predicted_images[:4].cpu()
                    sample_targets = target_images[:4].cpu()
                    sample_colors = list(batch['color_name'][:4])

        if total_samples == 0:
            raise RuntimeError("val_loader yielded no batches")

        avg_epoch_loss = total_loss / total_samples
        self.val_losses.append(avg_epoch_loss)

        if sample_inputs is not None and len(sample_inputs) > 0:
            self.log_sample_images(
                sample_inputs, sample_predictions, sample_targets, sample_colors
            )

        return avg_epoch_loss

    def log_sample_images(self, inputs, predictions, targets, colors):
        """Log an input/prediction/target grid to wandb and save it as a PNG.

        Uses up to 4 samples (fewer if fewer are given). Image tensors are
        expected in [-1, 1] with shape (N, C, H, W) on the CPU.
        """
        num_samples = min(4, len(inputs), len(predictions), len(targets), len(colors))
        if num_samples == 0:
            return

        def denormalize(tensor):
            # [-1, 1] -> [0, 1]; the clamp keeps imshow/ToPILImage in range.
            return ((tensor[:num_samples] + 1) / 2).clamp(0, 1)

        inputs = denormalize(inputs)
        predictions = denormalize(predictions)
        targets = denormalize(targets)

        # squeeze=False keeps `axes` 2-D even when num_samples == 1
        fig, axes = plt.subplots(3, num_samples, figsize=(4 * num_samples, 12), squeeze=False)

        for i in range(num_samples):
            axes[0, i].imshow(inputs[i].permute(1, 2, 0))
            axes[0, i].set_title(f'Input ({colors[i]})')
            axes[0, i].axis('off')

            axes[1, i].imshow(predictions[i].permute(1, 2, 0))
            axes[1, i].set_title('Predicted')
            axes[1, i].axis('off')

            axes[2, i].imshow(targets[i].permute(1, 2, 0))
            axes[2, i].set_title('Target')
            axes[2, i].axis('off')

        plt.tight_layout()

        wandb.log({
            'sample_images': wandb.Image(fig),
            'epoch': self.current_epoch
        })

        plt.close(fig)

        # Save a rows=(input, predicted, target) grid to disk
        sample_path = os.path.join(
            self.config['sample_dir'],
            f'epoch_{self.current_epoch:03d}.png'
        )
        grid_img = torch.cat([
            torch.cat(list(inputs), dim=2),
            torch.cat(list(predictions), dim=2),
            torch.cat(list(targets), dim=2)
        ], dim=1)

        transforms.ToPILImage()(grid_img).save(sample_path)

    def save_checkpoint(self, is_best=False):
        """Write 'latest_checkpoint.pth' (and 'best_checkpoint.pth' if ``is_best``).

        The stored 'epoch' is ``current_epoch``, i.e. the epoch that just finished.
        """
        checkpoint = {
            'epoch': self.current_epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler else None,
            'train_losses': self.train_losses,
            'val_losses': self.val_losses,
            'best_val_loss': self.best_val_loss,
            'config': self.config,
            'color_info': self.color_info
        }

        latest_path = os.path.join(self.config['checkpoint_dir'], 'latest_checkpoint.pth')
        torch.save(checkpoint, latest_path)

        if is_best:
            best_path = os.path.join(self.config['checkpoint_dir'], 'best_checkpoint.pth')
            torch.save(checkpoint, best_path)
            print(f"New best model saved with validation loss: {self.best_val_loss:.4f}")

    def load_checkpoint(self, checkpoint_path):
        """Restore model/optimizer/scheduler state and loss history.

        Training resumes at the epoch *after* the stored one. In train(),
        save_checkpoint runs only after that epoch's training, validation and
        scheduler step, so re-running it would repeat the epoch and advance
        the scheduler twice.
        """
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        if self.scheduler and checkpoint.get('scheduler_state_dict'):
            self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        completed_epoch = checkpoint['epoch']
        self.current_epoch = completed_epoch + 1
        self.train_losses = checkpoint['train_losses']
        self.val_losses = checkpoint['val_losses']
        self.best_val_loss = checkpoint['best_val_loss']

        print(f"Loaded checkpoint from epoch {completed_epoch}")

    def train(self):
        """Train from ``current_epoch`` up to ``config['epochs']``, validating each epoch."""
        print("Starting training...")
        start_time = time.time()
        last_epoch = self.config['epochs'] - 1

        for epoch in range(self.current_epoch, self.config['epochs']):
            self.current_epoch = epoch

            train_loss = self.train_epoch()
            val_loss = self.validate_epoch()

            # Capture the LR used during this epoch before the scheduler advances it.
            epoch_lr = self.optimizer.param_groups[0]['lr']
            if self.scheduler:
                self.scheduler.step()

            wandb.log({
                'epoch': epoch,
                'train_loss_epoch': train_loss,
                'val_loss_epoch': val_loss,
                'learning_rate': epoch_lr
            })

            is_best = val_loss < self.best_val_loss
            if is_best:
                self.best_val_loss = val_loss

            # Also save after the final epoch so the latest checkpoint reflects
            # the end of training even when it is not on a save_interval boundary.
            if (epoch % self.config.get('save_interval', 10) == 0
                    or is_best or epoch == last_epoch):
                self.save_checkpoint(is_best=is_best)

            elapsed_time = time.time() - start_time
            print(f"Epoch {epoch + 1}/{self.config['epochs']} - "
                  f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, "
                  f"Time: {elapsed_time/60:.1f}min")

        print(f"Training completed! Best validation loss: {self.best_val_loss:.4f}")


def main(config_overrides: Optional[Dict] = None):
    """Configure and run one polygon-coloring UNet training run, logged to wandb.

    Args:
        config_overrides: Optional mapping merged over the default configuration
            below (e.g. ``{'epochs': 5}`` for a quick smoke run). ``None`` keeps
            the defaults, so calling ``main()`` behaves as before.

    The wandb run is always closed, even if data loading, trainer construction
    or training raises. On failure it is finished with exit code 1 so the run
    is marked as failed, and the exception is re-raised.
    """
    import random  # only needed here, for seeding

    # Training configuration
    config = {
        # Model configuration
        'model_type': 'film',  # 'film' or 'concat'
        'n_channels': 3,
        'n_classes': 3,
        'color_embed_dim': 128,
        'bilinear': True,

        # Training configuration
        'epochs': 100,
        'batch_size': 16,
        'learning_rate': 1e-3,
        'optimizer': 'adamw',
        'weight_decay': 1e-2,
        'scheduler': 'cosine',
        'loss_function': 'l1',  # 'mse', 'l1', 'huber'
        'grad_clip': 1.0,
        'seed': 42,

        # Data configuration
        'image_size': 256,
        'num_workers': 4,
        'use_augmentations': True,

        # Logging and saving
        'log_interval': 20,
        'save_interval': 10,
        'checkpoint_dir': 'checkpoints',
        'sample_dir': 'samples',

        # Paths (Google Drive mount in Colab; override via config_overrides elsewhere)
        'train_json': '/content/drive/MyDrive/dataset/dataset/training/data.json',
        'val_json': '/content/drive/MyDrive/dataset/dataset/validation/data.json',
        'train_input_dir': '/content/drive/MyDrive/dataset/dataset/training/inputs',
        'train_output_dir': '/content/drive/MyDrive/dataset/dataset/training/outputs',
        'val_input_dir': '/content/drive/MyDrive/dataset/dataset/validation/inputs',
        'val_output_dir': '/content/drive/MyDrive/dataset/dataset/validation/outputs',
    }
    if config_overrides:
        config.update(config_overrides)

    # Seed the Python, NumPy and PyTorch RNGs before building loaders and the
    # model, so that shuffling and weight init are repeatable across runs.
    seed = config['seed']
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Initialize wandb
    wandb.init(
        project="polygon-coloring",
        config=config,
        name=f"unet-{config['model_type']}-{datetime.now().strftime('%Y%m%d-%H%M%S')}"
    )

    try:
        # Create data loaders
        train_loader, val_loader, color_info = get_data_loaders(
            train_json_path=config['train_json'],
            val_json_path=config['val_json'],
            train_input_dir=config['train_input_dir'],
            train_output_dir=config['train_output_dir'],
            val_input_dir=config['val_input_dir'],
            val_output_dir=config['val_output_dir'],
            batch_size=config['batch_size'],
            image_size=config['image_size'],
            num_workers=config['num_workers'],
            use_augmentations=config['use_augmentations']
        )

        # Initialize trainer and start training
        trainer = PolygonColoringTrainer(config, train_loader, val_loader, color_info)
        trainer.train()
    except BaseException:
        # Close the run as failed (also on KeyboardInterrupt), then propagate.
        wandb.finish(exit_code=1)
        raise

    # Finish wandb run
    wandb.finish()


if __name__ == "__main__":
    main()

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize?ref=models
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mmeghashyam2005[0m ([33mmeghashyam2005-mahindra-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Dataset initialized with 56 samples
Colors: ['blue', 'cyan', 'green', 'magenta', 'orange', 'purple', 'red', 'yellow']
Color mapping: {'blue': 0, 'cyan': 1, 'green': 2, 'magenta': 3, 'orange': 4, 'purple': 5, 'red': 6, 'yellow': 7}
Dataset initialized with 5 samples
Colors: ['blue', 'cyan', 'green', 'magenta', 'orange', 'purple', 'red', 'yellow']
Color mapping: {'blue': 0, 'cyan': 1, 'green': 2, 'magenta': 3, 'orange': 4, 'purple': 5, 'red': 6, 'yellow': 7}
Using device: cuda




Model has 17,660,547 trainable parameters
Starting training...


Epoch 1/100: 100%|██████████| 3/3 [00:10<00:00,  3.64s/it, Loss=0.8191]
Validating: 100%|██████████| 1/1 [00:04<00:00,  4.27s/it]


New best model saved with validation loss: 0.9958
Epoch 1/100 - Train Loss: 0.8564, Val Loss: 0.9958, Time: 0.3min


Epoch 2/100: 100%|██████████| 3/3 [00:04<00:00,  1.48s/it, Loss=0.6776]
Validating: 100%|██████████| 1/1 [00:00<00:00,  2.60it/s]


Epoch 2/100 - Train Loss: 0.7231, Val Loss: 1.0316, Time: 0.4min


Epoch 3/100: 100%|██████████| 3/3 [00:04<00:00,  1.38s/it, Loss=0.5448]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.80it/s]


Epoch 3/100 - Train Loss: 0.5816, Val Loss: 1.1396, Time: 0.5min


Epoch 4/100: 100%|██████████| 3/3 [00:03<00:00,  1.32s/it, Loss=0.5156]
Validating: 100%|██████████| 1/1 [00:00<00:00,  2.65it/s]


Epoch 4/100 - Train Loss: 0.5328, Val Loss: 1.2906, Time: 0.6min


Epoch 5/100: 100%|██████████| 3/3 [00:03<00:00,  1.32s/it, Loss=0.4707]
Validating: 100%|██████████| 1/1 [00:00<00:00,  2.57it/s]


Epoch 5/100 - Train Loss: 0.4763, Val Loss: 1.2523, Time: 0.7min


Epoch 6/100: 100%|██████████| 3/3 [00:04<00:00,  1.65s/it, Loss=0.4154]
Validating: 100%|██████████| 1/1 [00:00<00:00,  2.52it/s]


Epoch 6/100 - Train Loss: 0.4361, Val Loss: 1.1791, Time: 0.8min


Epoch 7/100: 100%|██████████| 3/3 [00:03<00:00,  1.29s/it, Loss=0.3871]
Validating: 100%|██████████| 1/1 [00:00<00:00,  2.56it/s]


Epoch 7/100 - Train Loss: 0.3859, Val Loss: 1.1260, Time: 0.9min


Epoch 8/100: 100%|██████████| 3/3 [00:04<00:00,  1.42s/it, Loss=0.3619]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.70it/s]


New best model saved with validation loss: 0.8148
Epoch 8/100 - Train Loss: 0.3587, Val Loss: 0.8148, Time: 1.0min


Epoch 9/100: 100%|██████████| 3/3 [00:03<00:00,  1.31s/it, Loss=0.3019]
Validating: 100%|██████████| 1/1 [00:00<00:00,  2.44it/s]


New best model saved with validation loss: 0.7222
Epoch 9/100 - Train Loss: 0.3136, Val Loss: 0.7222, Time: 1.1min


Epoch 10/100: 100%|██████████| 3/3 [00:04<00:00,  1.36s/it, Loss=0.2791]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.75it/s]


New best model saved with validation loss: 0.5992
Epoch 10/100 - Train Loss: 0.2878, Val Loss: 0.5992, Time: 1.3min


Epoch 11/100: 100%|██████████| 3/3 [00:04<00:00,  1.59s/it, Loss=0.2652]
Validating: 100%|██████████| 1/1 [00:00<00:00,  2.50it/s]


New best model saved with validation loss: 0.4712
Epoch 11/100 - Train Loss: 0.2852, Val Loss: 0.4712, Time: 1.4min


Epoch 12/100: 100%|██████████| 3/3 [00:04<00:00,  1.37s/it, Loss=0.2525]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.66it/s]


New best model saved with validation loss: 0.2933
Epoch 12/100 - Train Loss: 0.2563, Val Loss: 0.2933, Time: 1.5min


Epoch 13/100: 100%|██████████| 3/3 [00:05<00:00,  1.85s/it, Loss=0.2156]
Validating: 100%|██████████| 1/1 [00:00<00:00,  2.36it/s]


New best model saved with validation loss: 0.2418
Epoch 13/100 - Train Loss: 0.2295, Val Loss: 0.2418, Time: 1.6min


Epoch 14/100: 100%|██████████| 3/3 [00:04<00:00,  1.45s/it, Loss=0.2188]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.73it/s]


Epoch 14/100 - Train Loss: 0.2197, Val Loss: 0.2521, Time: 1.7min


Epoch 15/100: 100%|██████████| 3/3 [00:03<00:00,  1.27s/it, Loss=0.1845]
Validating: 100%|██████████| 1/1 [00:00<00:00,  2.42it/s]


New best model saved with validation loss: 0.2379
Epoch 15/100 - Train Loss: 0.1909, Val Loss: 0.2379, Time: 1.8min


Epoch 16/100: 100%|██████████| 3/3 [00:03<00:00,  1.31s/it, Loss=0.1767]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.64it/s]


Epoch 16/100 - Train Loss: 0.1929, Val Loss: 0.2413, Time: 1.9min


Epoch 17/100: 100%|██████████| 3/3 [00:04<00:00,  1.52s/it, Loss=0.1657]
Validating: 100%|██████████| 1/1 [00:00<00:00,  2.05it/s]


Epoch 17/100 - Train Loss: 0.1684, Val Loss: 0.2458, Time: 2.0min


Epoch 18/100: 100%|██████████| 3/3 [00:04<00:00,  1.36s/it, Loss=0.1826]
Validating: 100%|██████████| 1/1 [00:00<00:00,  2.37it/s]


New best model saved with validation loss: 0.2221
Epoch 18/100 - Train Loss: 0.1634, Val Loss: 0.2221, Time: 2.2min


Epoch 19/100: 100%|██████████| 3/3 [00:05<00:00,  1.76s/it, Loss=0.1626]
Validating: 100%|██████████| 1/1 [00:00<00:00,  2.36it/s]


New best model saved with validation loss: 0.1914
Epoch 19/100 - Train Loss: 0.1533, Val Loss: 0.1914, Time: 2.3min


Epoch 20/100: 100%|██████████| 3/3 [00:04<00:00,  1.35s/it, Loss=0.1368]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.51it/s]


New best model saved with validation loss: 0.1726
Epoch 20/100 - Train Loss: 0.1482, Val Loss: 0.1726, Time: 2.4min


Epoch 21/100: 100%|██████████| 3/3 [00:04<00:00,  1.49s/it, Loss=0.1310]
Validating: 100%|██████████| 1/1 [00:00<00:00,  2.24it/s]


New best model saved with validation loss: 0.1602
Epoch 21/100 - Train Loss: 0.1356, Val Loss: 0.1602, Time: 2.6min


Epoch 22/100: 100%|██████████| 3/3 [00:04<00:00,  1.34s/it, Loss=0.1501]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.66it/s]


New best model saved with validation loss: 0.1568
Epoch 22/100 - Train Loss: 0.1379, Val Loss: 0.1568, Time: 2.7min


Epoch 23/100: 100%|██████████| 3/3 [00:04<00:00,  1.60s/it, Loss=0.1415]
Validating: 100%|██████████| 1/1 [00:00<00:00,  2.26it/s]


Epoch 23/100 - Train Loss: 0.1503, Val Loss: 0.1594, Time: 2.8min


Epoch 24/100: 100%|██████████| 3/3 [00:04<00:00,  1.36s/it, Loss=0.1441]
Validating: 100%|██████████| 1/1 [00:00<00:00,  2.31it/s]


Epoch 24/100 - Train Loss: 0.1427, Val Loss: 0.1622, Time: 2.9min


Epoch 25/100: 100%|██████████| 3/3 [00:05<00:00,  1.77s/it, Loss=0.1302]
Validating: 100%|██████████| 1/1 [00:00<00:00,  2.25it/s]


Epoch 25/100 - Train Loss: 0.1426, Val Loss: 0.1620, Time: 3.0min


Epoch 26/100: 100%|██████████| 3/3 [00:03<00:00,  1.31s/it, Loss=0.1378]
Validating: 100%|██████████| 1/1 [00:00<00:00,  2.19it/s]


Epoch 26/100 - Train Loss: 0.1329, Val Loss: 0.1620, Time: 3.1min


Epoch 27/100: 100%|██████████| 3/3 [00:04<00:00,  1.42s/it, Loss=0.1339]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.50it/s]


New best model saved with validation loss: 0.1559
Epoch 27/100 - Train Loss: 0.1374, Val Loss: 0.1559, Time: 3.2min


Epoch 28/100: 100%|██████████| 3/3 [00:03<00:00,  1.33s/it, Loss=0.1266]
Validating: 100%|██████████| 1/1 [00:00<00:00,  2.19it/s]


New best model saved with validation loss: 0.1488
Epoch 28/100 - Train Loss: 0.1370, Val Loss: 0.1488, Time: 3.3min


Epoch 29/100: 100%|██████████| 3/3 [00:05<00:00,  1.71s/it, Loss=0.1402]
Validating: 100%|██████████| 1/1 [00:00<00:00,  2.13it/s]


New best model saved with validation loss: 0.1440
Epoch 29/100 - Train Loss: 0.1337, Val Loss: 0.1440, Time: 3.5min


Epoch 30/100: 100%|██████████| 3/3 [00:04<00:00,  1.45s/it, Loss=0.1092]
Validating: 100%|██████████| 1/1 [00:00<00:00,  2.20it/s]


New best model saved with validation loss: 0.1411
Epoch 30/100 - Train Loss: 0.1170, Val Loss: 0.1411, Time: 3.6min


Epoch 31/100: 100%|██████████| 3/3 [00:04<00:00,  1.49s/it, Loss=0.1049]
Validating: 100%|██████████| 1/1 [00:00<00:00,  2.16it/s]


Epoch 31/100 - Train Loss: 0.1239, Val Loss: 0.1440, Time: 3.7min


Epoch 32/100: 100%|██████████| 3/3 [00:04<00:00,  1.38s/it, Loss=0.1100]
Validating: 100%|██████████| 1/1 [00:00<00:00,  2.14it/s]


Epoch 32/100 - Train Loss: 0.1192, Val Loss: 0.1438, Time: 3.8min


Epoch 33/100: 100%|██████████| 3/3 [00:05<00:00,  1.79s/it, Loss=0.1356]
Validating: 100%|██████████| 1/1 [00:00<00:00,  2.15it/s]


New best model saved with validation loss: 0.1397
Epoch 33/100 - Train Loss: 0.1240, Val Loss: 0.1397, Time: 3.9min


Epoch 34/100: 100%|██████████| 3/3 [00:04<00:00,  1.38s/it, Loss=0.1287]
Validating: 100%|██████████| 1/1 [00:00<00:00,  2.20it/s]


Epoch 34/100 - Train Loss: 0.1256, Val Loss: 0.1423, Time: 4.0min


Epoch 35/100: 100%|██████████| 3/3 [00:04<00:00,  1.63s/it, Loss=0.1200]
Validating: 100%|██████████| 1/1 [00:00<00:00,  2.05it/s]


Epoch 35/100 - Train Loss: 0.1266, Val Loss: 0.1411, Time: 4.1min


Epoch 36/100: 100%|██████████| 3/3 [00:04<00:00,  1.35s/it, Loss=0.0991]
Validating: 100%|██████████| 1/1 [00:00<00:00,  2.09it/s]


New best model saved with validation loss: 0.1379
Epoch 36/100 - Train Loss: 0.1205, Val Loss: 0.1379, Time: 4.3min


Epoch 37/100: 100%|██████████| 3/3 [00:04<00:00,  1.38s/it, Loss=0.1270]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.47it/s]


New best model saved with validation loss: 0.1347
Epoch 37/100 - Train Loss: 0.1148, Val Loss: 0.1347, Time: 4.4min


Epoch 38/100: 100%|██████████| 3/3 [00:04<00:00,  1.44s/it, Loss=0.1167]
Validating: 100%|██████████| 1/1 [00:00<00:00,  2.05it/s]


New best model saved with validation loss: 0.1338
Epoch 38/100 - Train Loss: 0.1163, Val Loss: 0.1338, Time: 4.5min


Epoch 39/100: 100%|██████████| 3/3 [00:04<00:00,  1.54s/it, Loss=0.1158]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.44it/s]


Epoch 39/100 - Train Loss: 0.1191, Val Loss: 0.1366, Time: 4.6min


Epoch 40/100: 100%|██████████| 3/3 [00:04<00:00,  1.35s/it, Loss=0.1225]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.95it/s]


Epoch 40/100 - Train Loss: 0.1198, Val Loss: 0.1391, Time: 4.7min


Epoch 41/100: 100%|██████████| 3/3 [00:04<00:00,  1.50s/it, Loss=0.1232]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.35it/s]


Epoch 41/100 - Train Loss: 0.1152, Val Loss: 0.1387, Time: 4.8min


Epoch 42/100: 100%|██████████| 3/3 [00:04<00:00,  1.51s/it, Loss=0.1033]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.98it/s]


Epoch 42/100 - Train Loss: 0.1155, Val Loss: 0.1365, Time: 4.9min


Epoch 43/100: 100%|██████████| 3/3 [00:04<00:00,  1.41s/it, Loss=0.0994]
Validating: 100%|██████████| 1/1 [00:00<00:00,  2.02it/s]


New best model saved with validation loss: 0.1330
Epoch 43/100 - Train Loss: 0.1110, Val Loss: 0.1330, Time: 5.0min


Epoch 44/100: 100%|██████████| 3/3 [00:05<00:00,  1.75s/it, Loss=0.0980]
Validating: 100%|██████████| 1/1 [00:00<00:00,  2.00it/s]


New best model saved with validation loss: 0.1312
Epoch 44/100 - Train Loss: 0.1080, Val Loss: 0.1312, Time: 5.2min


Epoch 45/100: 100%|██████████| 3/3 [00:04<00:00,  1.43s/it, Loss=0.1131]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.50it/s]


Epoch 45/100 - Train Loss: 0.1095, Val Loss: 0.1317, Time: 5.3min


Epoch 46/100: 100%|██████████| 3/3 [00:04<00:00,  1.62s/it, Loss=0.1083]
Validating: 100%|██████████| 1/1 [00:00<00:00,  2.05it/s]


Epoch 46/100 - Train Loss: 0.1030, Val Loss: 0.1416, Time: 5.4min


Epoch 47/100: 100%|██████████| 3/3 [00:04<00:00,  1.40s/it, Loss=0.1019]
Validating: 100%|██████████| 1/1 [00:00<00:00,  2.00it/s]


Epoch 47/100 - Train Loss: 0.1027, Val Loss: 0.1369, Time: 5.5min


Epoch 48/100: 100%|██████████| 3/3 [00:05<00:00,  1.82s/it, Loss=0.1101]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.89it/s]


New best model saved with validation loss: 0.1283
Epoch 48/100 - Train Loss: 0.1065, Val Loss: 0.1283, Time: 5.6min


Epoch 49/100: 100%|██████████| 3/3 [00:04<00:00,  1.34s/it, Loss=0.1031]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.94it/s]


Epoch 49/100 - Train Loss: 0.0990, Val Loss: 0.1310, Time: 5.7min


Epoch 50/100: 100%|██████████| 3/3 [00:04<00:00,  1.65s/it, Loss=0.0983]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.41it/s]


New best model saved with validation loss: 0.1275
Epoch 50/100 - Train Loss: 0.0945, Val Loss: 0.1275, Time: 5.8min


Epoch 51/100: 100%|██████████| 3/3 [00:03<00:00,  1.33s/it, Loss=0.0936]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.98it/s]


Epoch 51/100 - Train Loss: 0.1021, Val Loss: 0.1533, Time: 6.0min


Epoch 52/100: 100%|██████████| 3/3 [00:04<00:00,  1.57s/it, Loss=0.0997]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.34it/s]


Epoch 52/100 - Train Loss: 0.0991, Val Loss: 0.1363, Time: 6.1min


Epoch 53/100: 100%|██████████| 3/3 [00:04<00:00,  1.39s/it, Loss=0.0809]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.97it/s]


Epoch 53/100 - Train Loss: 0.0904, Val Loss: 0.1331, Time: 6.2min


Epoch 54/100: 100%|██████████| 3/3 [00:04<00:00,  1.34s/it, Loss=0.0679]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.30it/s]


Epoch 54/100 - Train Loss: 0.0907, Val Loss: 0.1369, Time: 6.3min


Epoch 55/100: 100%|██████████| 3/3 [00:04<00:00,  1.63s/it, Loss=0.0767]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.91it/s]


Epoch 55/100 - Train Loss: 0.0893, Val Loss: 0.1286, Time: 6.4min


Epoch 56/100: 100%|██████████| 3/3 [00:04<00:00,  1.38s/it, Loss=0.0794]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.86it/s]


New best model saved with validation loss: 0.1245
Epoch 56/100 - Train Loss: 0.0824, Val Loss: 0.1245, Time: 6.5min


Epoch 57/100: 100%|██████████| 3/3 [00:05<00:00,  1.69s/it, Loss=0.0894]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.83it/s]


New best model saved with validation loss: 0.1187
Epoch 57/100 - Train Loss: 0.0834, Val Loss: 0.1187, Time: 6.6min


Epoch 58/100: 100%|██████████| 3/3 [00:03<00:00,  1.32s/it, Loss=0.0768]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.82it/s]


Epoch 58/100 - Train Loss: 0.0789, Val Loss: 0.1337, Time: 6.7min


Epoch 59/100: 100%|██████████| 3/3 [00:05<00:00,  1.82s/it, Loss=0.0856]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.90it/s]


Epoch 59/100 - Train Loss: 0.0792, Val Loss: 0.1402, Time: 6.8min


Epoch 60/100: 100%|██████████| 3/3 [00:04<00:00,  1.39s/it, Loss=0.0751]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.88it/s]


Epoch 60/100 - Train Loss: 0.0713, Val Loss: 0.1381, Time: 6.9min


Epoch 61/100: 100%|██████████| 3/3 [00:04<00:00,  1.62s/it, Loss=0.0865]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.34it/s]


Epoch 61/100 - Train Loss: 0.0721, Val Loss: 0.1309, Time: 7.0min


Epoch 62/100: 100%|██████████| 3/3 [00:04<00:00,  1.52s/it, Loss=0.0674]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.91it/s]


Epoch 62/100 - Train Loss: 0.0720, Val Loss: 0.1466, Time: 7.1min


Epoch 63/100: 100%|██████████| 3/3 [00:04<00:00,  1.40s/it, Loss=0.0566]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.22it/s]


Epoch 63/100 - Train Loss: 0.0664, Val Loss: 0.1387, Time: 7.2min


Epoch 64/100: 100%|██████████| 3/3 [00:04<00:00,  1.47s/it, Loss=0.0763]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.77it/s]


Epoch 64/100 - Train Loss: 0.0707, Val Loss: 0.1397, Time: 7.3min


Epoch 65/100: 100%|██████████| 3/3 [00:04<00:00,  1.40s/it, Loss=0.0573]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.31it/s]


Epoch 65/100 - Train Loss: 0.0634, Val Loss: 0.1468, Time: 7.5min


Epoch 66/100: 100%|██████████| 3/3 [00:05<00:00,  1.69s/it, Loss=0.0666]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.83it/s]


Epoch 66/100 - Train Loss: 0.0635, Val Loss: 0.1284, Time: 7.6min


Epoch 67/100: 100%|██████████| 3/3 [00:04<00:00,  1.41s/it, Loss=0.0711]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.72it/s]


Epoch 67/100 - Train Loss: 0.0658, Val Loss: 0.1251, Time: 7.7min


Epoch 68/100: 100%|██████████| 3/3 [00:05<00:00,  1.78s/it, Loss=0.0558]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.80it/s]


New best model saved with validation loss: 0.1158
Epoch 68/100 - Train Loss: 0.0633, Val Loss: 0.1158, Time: 7.8min


Epoch 69/100: 100%|██████████| 3/3 [00:04<00:00,  1.41s/it, Loss=0.0648]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.72it/s]


New best model saved with validation loss: 0.1130
Epoch 69/100 - Train Loss: 0.0621, Val Loss: 0.1130, Time: 7.9min


Epoch 70/100: 100%|██████████| 3/3 [00:04<00:00,  1.65s/it, Loss=0.0498]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.74it/s]


Epoch 70/100 - Train Loss: 0.0547, Val Loss: 0.1152, Time: 8.0min


Epoch 71/100: 100%|██████████| 3/3 [00:04<00:00,  1.42s/it, Loss=0.0581]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.73it/s]


New best model saved with validation loss: 0.1094
Epoch 71/100 - Train Loss: 0.0629, Val Loss: 0.1094, Time: 8.2min


Epoch 72/100: 100%|██████████| 3/3 [00:05<00:00,  1.76s/it, Loss=0.0525]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.74it/s]


Epoch 72/100 - Train Loss: 0.0557, Val Loss: 0.1100, Time: 8.3min


Epoch 73/100: 100%|██████████| 3/3 [00:04<00:00,  1.35s/it, Loss=0.0549]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.77it/s]


Epoch 73/100 - Train Loss: 0.0562, Val Loss: 0.1163, Time: 8.4min


Epoch 74/100: 100%|██████████| 3/3 [00:05<00:00,  1.69s/it, Loss=0.0485]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.71it/s]


Epoch 74/100 - Train Loss: 0.0531, Val Loss: 0.1233, Time: 8.5min


Epoch 75/100: 100%|██████████| 3/3 [00:04<00:00,  1.38s/it, Loss=0.0562]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.76it/s]


Epoch 75/100 - Train Loss: 0.0536, Val Loss: 0.1206, Time: 8.6min


Epoch 76/100: 100%|██████████| 3/3 [00:04<00:00,  1.41s/it, Loss=0.0438]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.11it/s]


Epoch 76/100 - Train Loss: 0.0513, Val Loss: 0.1210, Time: 8.7min


Epoch 77/100: 100%|██████████| 3/3 [00:04<00:00,  1.43s/it, Loss=0.0466]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.73it/s]


Epoch 77/100 - Train Loss: 0.0483, Val Loss: 0.1204, Time: 8.8min


Epoch 78/100: 100%|██████████| 3/3 [00:04<00:00,  1.39s/it, Loss=0.0670]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.43it/s]


Epoch 78/100 - Train Loss: 0.0561, Val Loss: 0.1231, Time: 8.9min


Epoch 79/100: 100%|██████████| 3/3 [00:05<00:00,  1.76s/it, Loss=0.0494]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.76it/s]


Epoch 79/100 - Train Loss: 0.0499, Val Loss: 0.1132, Time: 9.0min


Epoch 80/100: 100%|██████████| 3/3 [00:04<00:00,  1.38s/it, Loss=0.0510]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.65it/s]


Epoch 80/100 - Train Loss: 0.0547, Val Loss: 0.1118, Time: 9.1min


Epoch 81/100: 100%|██████████| 3/3 [00:05<00:00,  1.73s/it, Loss=0.0502]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.71it/s]


Epoch 81/100 - Train Loss: 0.0490, Val Loss: 0.1157, Time: 9.2min


Epoch 82/100: 100%|██████████| 3/3 [00:04<00:00,  1.36s/it, Loss=0.0540]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.65it/s]


Epoch 82/100 - Train Loss: 0.0502, Val Loss: 0.1181, Time: 9.3min


Epoch 83/100: 100%|██████████| 3/3 [00:04<00:00,  1.44s/it, Loss=0.0477]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.16it/s]


Epoch 83/100 - Train Loss: 0.0498, Val Loss: 0.1189, Time: 9.4min


Epoch 84/100: 100%|██████████| 3/3 [00:04<00:00,  1.38s/it, Loss=0.0580]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.68it/s]


Epoch 84/100 - Train Loss: 0.0495, Val Loss: 0.1187, Time: 9.5min


Epoch 85/100: 100%|██████████| 3/3 [00:04<00:00,  1.43s/it, Loss=0.0500]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.10it/s]


Epoch 85/100 - Train Loss: 0.0433, Val Loss: 0.1159, Time: 9.6min


Epoch 86/100: 100%|██████████| 3/3 [00:04<00:00,  1.62s/it, Loss=0.0461]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.65it/s]


Epoch 86/100 - Train Loss: 0.0477, Val Loss: 0.1138, Time: 9.7min


Epoch 87/100: 100%|██████████| 3/3 [00:04<00:00,  1.41s/it, Loss=0.0459]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.71it/s]


Epoch 87/100 - Train Loss: 0.0454, Val Loss: 0.1152, Time: 9.8min


Epoch 88/100: 100%|██████████| 3/3 [00:05<00:00,  1.83s/it, Loss=0.0401]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.63it/s]


Epoch 88/100 - Train Loss: 0.0401, Val Loss: 0.1170, Time: 9.9min


Epoch 89/100: 100%|██████████| 3/3 [00:04<00:00,  1.47s/it, Loss=0.0386]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.59it/s]


Epoch 89/100 - Train Loss: 0.0437, Val Loss: 0.1177, Time: 10.0min


Epoch 90/100: 100%|██████████| 3/3 [00:05<00:00,  1.86s/it, Loss=0.0431]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.63it/s]


Epoch 90/100 - Train Loss: 0.0467, Val Loss: 0.1181, Time: 10.2min


Epoch 91/100: 100%|██████████| 3/3 [00:04<00:00,  1.40s/it, Loss=0.0470]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.56it/s]


Epoch 91/100 - Train Loss: 0.0437, Val Loss: 0.1171, Time: 10.3min


Epoch 92/100: 100%|██████████| 3/3 [00:05<00:00,  1.72s/it, Loss=0.0468]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.44it/s]


Epoch 92/100 - Train Loss: 0.0505, Val Loss: 0.1156, Time: 10.4min


Epoch 93/100: 100%|██████████| 3/3 [00:04<00:00,  1.42s/it, Loss=0.0546]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.58it/s]


Epoch 93/100 - Train Loss: 0.0428, Val Loss: 0.1154, Time: 10.5min


Epoch 94/100: 100%|██████████| 3/3 [00:04<00:00,  1.42s/it, Loss=0.0517]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.06it/s]


Epoch 94/100 - Train Loss: 0.0425, Val Loss: 0.1155, Time: 10.6min


Epoch 95/100: 100%|██████████| 3/3 [00:04<00:00,  1.40s/it, Loss=0.0522]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.58it/s]


Epoch 95/100 - Train Loss: 0.0391, Val Loss: 0.1158, Time: 10.7min


Epoch 96/100: 100%|██████████| 3/3 [00:04<00:00,  1.47s/it, Loss=0.0449]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.05it/s]


Epoch 96/100 - Train Loss: 0.0446, Val Loss: 0.1157, Time: 10.8min


Epoch 97/100: 100%|██████████| 3/3 [00:05<00:00,  1.71s/it, Loss=0.0325]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.60it/s]


Epoch 97/100 - Train Loss: 0.0420, Val Loss: 0.1155, Time: 10.9min


Epoch 98/100: 100%|██████████| 3/3 [00:04<00:00,  1.41s/it, Loss=0.0474]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.55it/s]


Epoch 98/100 - Train Loss: 0.0447, Val Loss: 0.1160, Time: 11.0min


Epoch 99/100: 100%|██████████| 3/3 [00:05<00:00,  1.82s/it, Loss=0.0434]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.46it/s]


Epoch 99/100 - Train Loss: 0.0434, Val Loss: 0.1164, Time: 11.1min


Epoch 100/100: 100%|██████████| 3/3 [00:04<00:00,  1.47s/it, Loss=0.0370]
Validating: 100%|██████████| 1/1 [00:00<00:00,  1.55it/s]


Epoch 100/100 - Train Loss: 0.0445, Val Loss: 0.1159, Time: 11.2min
Training completed! Best validation loss: 0.1094


0,1
epoch,▁▁▁▁▁▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▇▇▇▇▇▇████
learning_rate,█████▇▇▇▇▇▇▇▇▇▇▆▆▆▆▅▄▄▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁
step,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▇▇▇▇▇███
train_loss_epoch,█▆▄▄▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_loss_step,█▅▅▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_loss_epoch,▆▇█▇▅▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,99.0
learning_rate,1e-05
step,297.0
train_loss_epoch,0.04452
train_loss_step,0.04889
val_loss_epoch,0.11595


In [None]:
import os
from google.colab import files

# Offer the best-model checkpoint for browser download through Colab's
# `files` helper. It is expected at checkpoints/best_checkpoint.pth.
checkpoint_dir = 'checkpoints'
checkpoint_path = os.path.join(
    checkpoint_dir,
    'best_checkpoint.pth',
)

# Missing file: explain why nothing is downloaded; otherwise download it.
if not os.path.exists(checkpoint_path):
    print(f"Best checkpoint not found at {checkpoint_path}. Please ensure training completed successfully and the file exists.")
else:
    print(f"Downloading {checkpoint_path}...")
    files.download(checkpoint_path)

Downloading checkpoints/best_checkpoint.pth...


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>