# 🎨 Image Colorization Training - Complete Notebook

**Complete end-to-end image colorization training with U-Net architecture**

- ✅ **50 epochs training**
- ✅ **Checkpoint saving every epoch**
- ✅ **Validation loss computed every epoch**
- ✅ **TPU/GPU/CPU support**
- ✅ **Memory optimized for Kaggle**
- ✅ **COCO dataset compatible**

**Hardware Support:**
- 🚀 TPU (Google Colab/Kaggle TPU)
- 🎮 GPU (CUDA)
- 💻 CPU (fallback)


## 📦 Installation & Setup

In [1]:
# Install required packages.
# pip is run through subprocess with one argv entry per requirement, so
# specifiers such as "torch>=2.0.0" reach pip intact. (In an unquoted
# `!pip install torch>=2.0.0` the shell parses `>=2.0.0` as an output
# redirect: pip installs an unpinned torch and writes its log to a file
# named "=2.0.0".)
import subprocess
import sys

REQUIREMENT_GROUPS = [
    ["torch>=2.0.0", "torchvision>=0.15.0"],
    ["opencv-python>=4.8.0", "Pillow>=10.0.0", "scikit-image>=0.21.0"],
    ["numpy>=1.24.0", "scipy>=1.11.0", "matplotlib>=3.7.0"],
    ["tqdm>=4.65.0", "pyyaml>=6.0", "psutil>=5.9.0"],
]


def _pip_install(*requirements):
    """Run ``pip install`` with this kernel's interpreter; return True on success."""
    result = subprocess.run([sys.executable, "-m", "pip", "install", *requirements])
    return result.returncode == 0


for group in REQUIREMENT_GROUPS:
    if not _pip_install(*group):
        # Non-fatal, as before: already-installed versions may be enough
        # (e.g. on a Kaggle kernel without internet access).
        print(f"⚠️ pip install failed for: {' '.join(group)}")

# TPU support (for Kaggle TPU). The exit code is checked explicitly: a failing
# `!pip` never raises, so wrapping it in try/except always reported success.
# NOTE(review): torch-xla is still installed unconditionally; on a GPU machine
# it may pull in a different torch build — confirm this is intended.
if _pip_install("torch-xla>=2.0.0"):
    print("✅ TPU support installed")
else:
    print("⚠️ TPU support not available (GPU/CPU mode)")

✅ TPU support installed


In [None]:
# Import all required libraries
import os
import sys
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.models import VGG16_Weights

import numpy as np
import cv2
from PIL import Image
import matplotlib.pyplot as plt
from pathlib import Path
import yaml
from datetime import datetime
import time
from tqdm import tqdm
from typing import Tuple, Optional, Dict, List
import json
import psutil
from skimage import color

# TPU support. torch_xla may be installed on machines without a TPU (the
# install cell adds it unconditionally), so a successful import is not enough:
# also confirm that the XLA device is actually backed by TPU hardware.
try:
    import torch_xla
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as pl
except ImportError:
    TPU_AVAILABLE = False
    print("⚠️ TPU not available")
else:
    print("✅ TPU support loaded")
    try:
        # xla_device() never returns None, so comparing it with None cannot
        # detect a missing TPU; ask which hardware backs the device instead.
        TPU_AVAILABLE = xm.xla_device_hw(xm.xla_device()) == "TPU"
    except Exception as exc:  # missing/misconfigured XLA runtime -> no TPU
        print(f"⚠️ XLA runtime unusable: {exc}")
        TPU_AVAILABLE = False
    if not TPU_AVAILABLE:
        print("⚠️ torch_xla is installed but no TPU device was found")

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"TPU available: {TPU_AVAILABLE}")

## 🏗️ Model Architecture - U-Net

In [None]:
class UNetColorizer(nn.Module):
    """U-Net that predicts the two ab chroma channels from a lightness channel.

    Four 2x2 max-pool stages reduce the input to 1/16 of its size, and each
    decoder stage concatenates an upsampled tensor with the matching encoder
    output. Those shapes only agree when H and W are multiples of 16, so
    ``forward`` pads other sizes up to the next multiple (replicating edge
    pixels) and crops the prediction back. Inputs that are already multiples
    of 16 (e.g. 256x256) pass through unpadded.

    Args:
        input_channels: Channels of the input tensor (1 for the L channel).
        output_channels: Channels predicted (2 for a and b).
    """

    # Total spatial downsampling of the encoder: four stride-2 poolings.
    _SIZE_MULTIPLE = 16

    def __init__(self, input_channels=1, output_channels=2):
        super().__init__()

        # Encoder (downsampling)
        self.enc1 = self._conv_block(input_channels, 64)
        self.enc2 = self._conv_block(64, 128)
        self.enc3 = self._conv_block(128, 256)
        self.enc4 = self._conv_block(256, 512)

        # Bottleneck
        self.bottleneck = self._conv_block(512, 1024)

        # Decoder (upsampling); each dec block sees upsampled + skip channels.
        self.upconv4 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.dec4 = self._conv_block(1024, 512)

        self.upconv3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.dec3 = self._conv_block(512, 256)

        self.upconv2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec2 = self._conv_block(256, 128)

        self.upconv1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec1 = self._conv_block(128, 64)

        # Final 1x1 projection to the output channels
        self.final = nn.Conv2d(64, output_channels, 1)

        # Pooling
        self.pool = nn.MaxPool2d(2, 2)

    def _conv_block(self, in_channels, out_channels):
        """Two (3x3 conv -> BatchNorm -> ReLU) layers that keep H and W."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        """Predict ab channels in [-1, 1] (tanh) at the input's H and W.

        Args:
            x: Tensor of shape (N, input_channels, H, W).

        Returns:
            Tensor of shape (N, output_channels, H, W).
        """
        height, width = x.shape[-2:]
        pad_h = (-height) % self._SIZE_MULTIPLE
        pad_w = (-width) % self._SIZE_MULTIPLE
        if pad_h or pad_w:
            # F.pad order for the last two dims is (left, right, top, bottom).
            x = F.pad(x, (0, pad_w, 0, pad_h), mode="replicate")

        # Encoder
        enc1 = self.enc1(x)
        enc2 = self.enc2(self.pool(enc1))
        enc3 = self.enc3(self.pool(enc2))
        enc4 = self.enc4(self.pool(enc3))

        # Bottleneck
        bottleneck = self.bottleneck(self.pool(enc4))

        # Decoder with skip connections
        dec4 = self.dec4(torch.cat([self.upconv4(bottleneck), enc4], dim=1))
        dec3 = self.dec3(torch.cat([self.upconv3(dec4), enc3], dim=1))
        dec2 = self.dec2(torch.cat([self.upconv2(dec3), enc2], dim=1))
        dec1 = self.dec1(torch.cat([self.upconv1(dec2), enc1], dim=1))

        out = torch.tanh(self.final(dec1))
        # Drop any padding added above (a no-op slice when none was added).
        return out[..., :height, :width]

print("✅ U-Net model defined")

## 🎯 Loss Functions

In [None]:
# CIE D65 reference white (2° observer) in XYZ, as used for LAB conversion.
_D65_WHITE_XYZ = (0.95047, 1.0, 1.08883)
# Linear sRGB from XYZ (IEC 61966-2-1, D65).
_XYZ_TO_LINEAR_SRGB = (
    (3.2404542, -1.5371385, -0.4985314),
    (-0.9692660, 1.8760108, 0.0415560),
    (0.0556434, -0.2040259, 1.0572252),
)
# Channel statistics the torchvision ImageNet VGG16 weights expect.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)


class L1Loss(nn.Module):
    """Weighted mean absolute error (used on the predicted ab channels)."""

    def __init__(self, weight: float = 1.0):
        super().__init__()
        self.weight = weight

    def forward(self, predicted: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        return self.weight * F.l1_loss(predicted, target)


class PerceptualLoss(nn.Module):
    """Weighted MSE between VGG16 features of predicted and target images.

    Inputs are normalized LAB tensors: channel 0 is L/100 and channels 1-2
    are a/127 and b/127. They are converted to real sRGB and then
    ImageNet-normalized before entering VGG16, matching what the pretrained
    weights were trained on.
    """

    def __init__(self, weight: float = 1.0):
        super().__init__()
        self.weight = weight

        # Pre-trained VGG16 up to relu3_3; frozen, used only as a feature extractor.
        vgg = models.vgg16(weights=VGG16_Weights.IMAGENET1K_V1)
        self.feature_extractor = nn.Sequential(*list(vgg.features)[:16]).eval()
        for param in self.feature_extractor.parameters():
            param.requires_grad = False

        # Buffers so the statistics follow the module across .to(device).
        self.register_buffer("imagenet_mean", torch.tensor(_IMAGENET_MEAN).view(1, 3, 1, 1))
        self.register_buffer("imagenet_std", torch.tensor(_IMAGENET_STD).view(1, 3, 1, 1))

    def lab_to_rgb_tensor(self, lab_tensor: torch.Tensor) -> torch.Tensor:
        """Differentiably convert normalized LAB to sRGB in [0, 1].

        Args:
            lab_tensor: (N, 3, H, W) with L/100, a/127, b/127 in channels 0-2.

        Returns:
            (N, 3, H, W) sRGB tensor clamped to [0, 1].
        """
        lightness = lab_tensor[:, 0] * 100.0
        a_star = lab_tensor[:, 1] * 127.0
        b_star = lab_tensor[:, 2] * 127.0

        # LAB -> XYZ using the CIE inverse companding function.
        fy = (lightness + 16.0) / 116.0
        f = torch.stack([fy + a_star / 500.0, fy, fy - b_star / 200.0], dim=1)
        delta = 6.0 / 29.0
        xyz = torch.where(f > delta, f ** 3, 3.0 * delta ** 2 * (f - 4.0 / 29.0))
        xyz = xyz * lab_tensor.new_tensor(_D65_WHITE_XYZ).view(1, 3, 1, 1)

        # XYZ -> linear sRGB.
        matrix = lab_tensor.new_tensor(_XYZ_TO_LINEAR_SRGB)
        linear = torch.einsum("ij,njhw->nihw", matrix, xyz)

        # sRGB transfer curve. The power branch is evaluated on a clamped copy
        # so its gradient stays finite where torch.where picks the linear branch.
        srgb = torch.where(
            linear > 0.0031308,
            1.055 * linear.clamp(min=0.0031308) ** (1.0 / 2.4) - 0.055,
            12.92 * linear,
        )
        return torch.clamp(srgb, 0, 1)

    def _vgg_input(self, rgb: torch.Tensor) -> torch.Tensor:
        """ImageNet-normalize an sRGB batch for the VGG feature extractor."""
        return (rgb - self.imagenet_mean) / self.imagenet_std

    def forward(self, predicted_lab: torch.Tensor, target_lab: torch.Tensor) -> torch.Tensor:
        predicted_features = self.feature_extractor(
            self._vgg_input(self.lab_to_rgb_tensor(predicted_lab))
        )
        # The target is a constant for the loss, so skip building its graph.
        with torch.no_grad():
            target_features = self.feature_extractor(
                self._vgg_input(self.lab_to_rgb_tensor(target_lab))
            )
        return self.weight * F.mse_loss(predicted_features, target_features)


class HybridLoss(nn.Module):
    """L1 loss on ab channels plus an optional VGG perceptual loss on LAB images.

    ``forward`` returns ``(total, l1, perceptual)``. The perceptual term is a
    zero tensor when ``perceptual_weight`` is 0, in which case VGG16 is never
    loaded.
    """

    def __init__(self, l1_weight: float = 1.0, perceptual_weight: float = 0.1):
        super().__init__()
        self.l1_loss = L1Loss(l1_weight)
        self.perceptual_loss = PerceptualLoss(perceptual_weight) if perceptual_weight > 0 else None

    def forward(self, predicted_ab, target_ab, predicted_lab, target_lab):
        # L1 loss on ab channels
        l1_loss_value = self.l1_loss(predicted_ab, target_ab)

        # Perceptual loss on full LAB images
        if self.perceptual_loss is not None:
            perceptual_loss_value = self.perceptual_loss(predicted_lab, target_lab)
        else:
            perceptual_loss_value = torch.tensor(0.0, device=predicted_ab.device)

        total_loss = l1_loss_value + perceptual_loss_value

        return total_loss, l1_loss_value, perceptual_loss_value

print("✅ Loss functions defined")

## 📊 Dataset & Data Loading

In [None]:
class ColorizationDataset(Dataset):
    """Images from a single directory, served as normalized LAB tensors.

    Each item is a dict with:
      * ``'l_channel'``   -- (1, H, W) float tensor, L / 100
      * ``'ab_channels'`` -- (2, H, W) float tensor, a / 127 and b / 127
      * ``'lab_image'``   -- (3, H, W) float tensor, the three channels above

    Args:
        dataset_path: Directory holding the images (not searched recursively).
        target_size: Size every image is resized to, in PIL ``(width, height)``
            order (it is passed straight to ``PIL.Image.resize``).

    Raises:
        FileNotFoundError: if ``dataset_path`` is not an existing directory.
    """

    # Suffixes (matched case-insensitively) that are treated as images.
    IMAGE_EXTENSIONS = frozenset({'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff'})

    def __init__(self, dataset_path: str, target_size: Tuple[int, int] = (256, 256)):
        self.dataset_path = Path(dataset_path)
        self.target_size = target_size

        # Load image paths
        self.image_paths = self._load_image_paths()
        print(f"Found {len(self.image_paths)} images")

    def _load_image_paths(self) -> List[Path]:
        """Return the sorted image files directly inside ``dataset_path``.

        The directory is scanned once and suffixes are compared lower-cased,
        so ``.JPG``/``.Jpg`` files are found and nothing is listed twice.
        (Globbing ``*.jpg`` and ``*.JPG`` separately returns every file twice
        on case-insensitive file systems.)
        """
        if not self.dataset_path.is_dir():
            raise FileNotFoundError(f"Dataset directory not found: {self.dataset_path}")
        return sorted(
            path for path in self.dataset_path.iterdir()
            if path.suffix.lower() in self.IMAGE_EXTENSIONS and path.is_file()
        )

    def _rgb_to_lab(self, rgb_image: np.ndarray) -> np.ndarray:
        """Convert an RGB image with values in [0, 1] to LAB via scikit-image."""
        return color.rgb2lab(rgb_image)

    def _load_and_preprocess_image(self, image_path: Path) -> Dict[str, torch.Tensor]:
        """Load one image, resize it and return its normalized LAB tensors."""
        image = Image.open(image_path).convert('RGB')
        image = image.resize(self.target_size, Image.Resampling.LANCZOS)

        rgb_image = np.array(image).astype(np.float32) / 255.0
        lab_image = self._rgb_to_lab(rgb_image)

        # Normalize LAB values
        lab_image[:, :, 0] = lab_image[:, :, 0] / 100.0  # L: 0-100 -> 0-1
        lab_image[:, :, 1] = lab_image[:, :, 1] / 127.0  # a: -127-127 -> -1-1
        lab_image[:, :, 2] = lab_image[:, :, 2] / 127.0  # b: -127-127 -> -1-1

        l_channel = lab_image[:, :, 0:1]
        ab_channels = lab_image[:, :, 1:3]

        # HWC numpy -> CHW float32 tensors
        l_tensor = torch.from_numpy(l_channel).permute(2, 0, 1).float()
        ab_tensor = torch.from_numpy(ab_channels).permute(2, 0, 1).float()
        lab_tensor = torch.from_numpy(lab_image).permute(2, 0, 1).float()

        return {
            'l_channel': l_tensor,
            'ab_channels': ab_tensor,
            'lab_image': lab_tensor
        }

    def __len__(self) -> int:
        return len(self.image_paths)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return sample ``idx``; unreadable images yield an all-zero sample."""
        image_path = self.image_paths[idx]
        try:
            return self._load_and_preprocess_image(image_path)
        except Exception as e:
            print(f"Error loading {image_path}: {e}")
            # Dummy sample shaped like a real one: PIL's target_size is
            # (width, height) while tensors are (C, height, width).
            width, height = self.target_size
            return {
                'l_channel': torch.zeros(1, height, width),
                'ab_channels': torch.zeros(2, height, width),
                'lab_image': torch.zeros(3, height, width)
            }

print("✅ Dataset class defined")

## 🚀 Training Configuration

In [None]:
# Training configuration. Keys are grouped by purpose; the per-device batch
# sizes are picked by the trainer once it knows which device is in use.
CONFIG = dict(
    # --- Dataset ---
    dataset_path='/kaggle/input/coco-2014-dataset-for-yolov3/coco2014/train2014',  # Update this path
    input_size=[256, 256],  # resize target for every image
    # --- Optimization ---
    epochs=50,
    learning_rate=1e-4,
    weight_decay=1e-4,
    # --- Loss weights ---
    l1_weight=1.0,
    perceptual_weight=0.0,  # 0 disables the VGG perceptual term (saves memory)
    # --- Batch size per device type ---
    batch_size_tpu=16,
    batch_size_gpu=4,
    batch_size_cpu=2,
)

for _line in (
    "✅ Configuration loaded",
    f"Dataset path: {CONFIG['dataset_path']}",
    f"Training for {CONFIG['epochs']} epochs",
):
    print(_line)

## 🎯 Complete Training Class

In [None]:
class ImageColorizationTrainer:
    """Train ``UNetColorizer`` on a directory of images on TPU, CUDA or CPU.

    The constructor selects the device (TPU > CUDA > CPU) and builds the
    model, Adam optimizer, ``HybridLoss`` and an 80/20 train/validation
    split. ``train()`` then runs ``config['epochs']`` epochs. Each epoch it
    saves a checkpoint under ``checkpoints/``, plus ``best_model.pth``
    whenever the validation loss improves, and appends one CSV row to
    ``logs/training_log.csv``.

    Args:
        config: Dict with the keys read below: dataset_path, input_size,
            epochs, learning_rate, weight_decay, l1_weight,
            perceptual_weight and batch_size_tpu / batch_size_gpu /
            batch_size_cpu.
    """

    def __init__(self, config):
        self.config = config

        # Setup device (TPU > CUDA > CPU)
        self._setup_device()

        # Output directories for checkpoints and the CSV log
        os.makedirs("checkpoints", exist_ok=True)
        os.makedirs("logs", exist_ok=True)

        self.model = UNetColorizer(input_channels=1, output_channels=2).to(self.device)
        print(f"Model parameters: {sum(p.numel() for p in self.model.parameters()):,}")

        self.optimizer = optim.Adam(
            self.model.parameters(),
            lr=config['learning_rate'],
            weight_decay=config['weight_decay']
        )

        self.criterion = HybridLoss(
            l1_weight=config['l1_weight'],
            perceptual_weight=config['perceptual_weight']
        ).to(self.device)

        self._setup_data()

        # Per-epoch history returned by train()
        self.train_losses = []
        self.val_losses = []
        # Lowest validation loss so far; inf until the first epoch finishes.
        self.best_val_loss = float("inf")

    @staticmethod
    def _find_tpu():
        """Return the XLA device if it is backed by a TPU, otherwise None.

        ``xm.xla_device()`` never returns None, so comparing it with None
        cannot detect a missing TPU. When torch_xla is merely installed, the
        device it yields may be CPU/GPU-backed, or the call may raise.
        """
        try:
            device = xm.xla_device()
            return device if xm.xla_device_hw(device) == "TPU" else None
        except Exception:  # unusable XLA runtime: fall back to CUDA/CPU
            return None

    def _setup_device(self):
        """Pick the device (TPU > CUDA > CPU); sets device, is_tpu and is_cuda."""
        tpu_device = self._find_tpu() if TPU_AVAILABLE else None
        self.is_tpu = tpu_device is not None
        self.is_cuda = not self.is_tpu and torch.cuda.is_available()

        if self.is_tpu:
            self.device = tpu_device
            print(f"🚀 Using TPU device: {self.device}")
            print(f"TPU cores: {xm.xrt_world_size()}")
        elif self.is_cuda:
            self.device = torch.device("cuda")
            print(f"🎮 Using CUDA device: {self.device}")
            print(f"GPU: {torch.cuda.get_device_name(0)}")
            print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
        else:
            self.device = torch.device("cpu")
            print(f"💻 Using CPU device: {self.device}")

    def _setup_data(self):
        """Build the random 80/20 train/validation split and its DataLoaders.

        Raises:
            ValueError: if the training split is smaller than one batch; with
                ``drop_last=True`` there would be no training batches at all.
        """
        dataset = ColorizationDataset(
            dataset_path=self.config['dataset_path'],
            target_size=tuple(self.config['input_size'])
        )

        train_size = int(0.8 * len(dataset))
        val_size = len(dataset) - train_size
        train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

        print(f"📊 Dataset: {len(dataset):,} images")
        print(f"Training: {len(train_dataset):,} images")
        print(f"Validation: {len(val_dataset):,} images")

        # Device-specific batch size, worker count and pinned memory
        if self.is_tpu:
            batch_size, num_workers, pin_memory = self.config['batch_size_tpu'], 0, False
        elif self.is_cuda:
            batch_size, num_workers, pin_memory = self.config['batch_size_gpu'], 2, True
        else:
            batch_size, num_workers, pin_memory = self.config['batch_size_cpu'], 4, False

        if train_size < batch_size:
            raise ValueError(
                f"Training split has {train_size} image(s), fewer than one batch of "
                f"{batch_size}; check the dataset path {self.config['dataset_path']!r} "
                "or lower the batch size."
            )

        train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
            pin_memory=pin_memory,
            drop_last=True
        )
        val_loader = DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=pin_memory,
            drop_last=False
        )

        # Plain DataLoaders are stored; on TPU they are wrapped per epoch by
        # _epoch_loader (see there for why a single wrapper cannot be reused).
        self.train_loader = train_loader
        self.val_loader = val_loader

        print(f"Batch size: {batch_size}")
        print(f"Train batches: {len(train_loader):,}")
        print(f"Val batches: {len(val_loader):,}")

    def _epoch_loader(self, loader):
        """Return an iterable over ``loader`` for one epoch on the active device.

        A torch_xla ParallelLoader feeds its DataLoader through only once, so
        reusing one across epochs leaves epochs 2+ without batches. On TPU a
        fresh ParallelLoader is therefore built each epoch, which is also what
        ``MpDeviceLoader`` does internally. Other devices use the DataLoader
        directly.
        """
        if self.is_tpu:
            return pl.ParallelLoader(loader, [self.device]).per_device_loader(self.device)
        return loader

    def _batch_to_device(self, batch):
        """Return ``(grayscale L, target ab)`` from ``batch`` on ``self.device``."""
        # non_blocking only helps with pinned CUDA memory; TPU keeps the default.
        non_blocking = not self.is_tpu
        grayscale = batch['l_channel'].to(self.device, non_blocking=non_blocking)
        color_ab = batch['ab_channels'].to(self.device, non_blocking=non_blocking)
        return grayscale, color_ab

    def _compute_loss(self, grayscale, color_ab):
        """Run the model on ``grayscale`` and return the total HybridLoss."""
        predicted_ab = self.model(grayscale)
        # Full LAB images for the (optional) perceptual term
        predicted_lab = torch.cat([grayscale, predicted_ab], dim=1)
        target_lab = torch.cat([grayscale, color_ab], dim=1)
        total_loss, _, _ = self.criterion(predicted_ab, color_ab, predicted_lab, target_lab)
        return total_loss

    def train_epoch(self, epoch):
        """Run one optimization pass over the training split; return the mean batch loss."""
        self.model.train()
        total_loss = 0.0
        num_batches = 0

        progress_bar = tqdm(
            self._epoch_loader(self.train_loader),
            desc=f"Epoch {epoch+1}/{self.config['epochs']} - Training"
        )

        for batch in progress_bar:
            grayscale, color_ab = self._batch_to_device(batch)

            self.optimizer.zero_grad()
            loss = self._compute_loss(grayscale, color_ab)
            loss.backward()

            if self.is_tpu:
                xm.optimizer_step(self.optimizer)
            else:
                self.optimizer.step()

            loss_value = loss.item()
            total_loss += loss_value
            num_batches += 1

            progress_bar.set_postfix({
                'Loss': f'{loss_value:.4f}',
                'Avg': f'{total_loss/num_batches:.4f}'
            })

            # Clear cache periodically (only for CUDA)
            if self.is_cuda and num_batches % 100 == 0:
                torch.cuda.empty_cache()

        return total_loss / max(num_batches, 1)

    def validate_epoch(self, epoch):
        """Evaluate on the validation split without gradients; return the mean batch loss."""
        self.model.eval()
        total_loss = 0.0
        num_batches = 0

        progress_bar = tqdm(
            self._epoch_loader(self.val_loader),
            desc=f"Epoch {epoch+1}/{self.config['epochs']} - Validation"
        )

        with torch.no_grad():
            for batch in progress_bar:
                grayscale, color_ab = self._batch_to_device(batch)
                loss_value = self._compute_loss(grayscale, color_ab).item()
                total_loss += loss_value
                num_batches += 1

                progress_bar.set_postfix({
                    'Loss': f'{loss_value:.4f}',
                    'Avg': f'{total_loss/num_batches:.4f}'
                })

        return total_loss / max(num_batches, 1)

    def save_checkpoint(self, epoch, train_loss, val_loss):
        """Save this epoch's checkpoint, plus ``best_model.pth`` on a new best val loss.

        Args:
            epoch: Zero-based epoch index (file names use ``epoch + 1``).
            train_loss: Mean training loss of the epoch.
            val_loss: Mean validation loss of the epoch.
        """
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'train_loss': train_loss,
            'val_loss': val_loss,
            'config': self.config
        }

        # On TPU, torch_xla's xm.save moves tensors to CPU before writing so
        # the file loads without an XLA device.
        save = xm.save if self.is_tpu else torch.save

        checkpoint_path = f"checkpoints/epoch_{epoch+1:02d}.pth"
        save(checkpoint, checkpoint_path)

        if val_loss < self.best_val_loss:
            self.best_val_loss = val_loss
            save(checkpoint, "checkpoints/best_model.pth")
            print(f"✅ New best model saved (val_loss: {val_loss:.4f})")

        print(f"💾 Checkpoint saved: {checkpoint_path}")

    def log_metrics(self, epoch, train_loss, val_loss):
        """Record the epoch's losses in memory and append a row to the CSV log."""
        self.train_losses.append(train_loss)
        self.val_losses.append(val_loss)

        log_entry = f"{datetime.now().isoformat()},Epoch,{epoch+1},Train,{train_loss:.6f},Val,{val_loss:.6f}\n"
        with open("logs/training_log.csv", "a") as f:
            f.write(log_entry)

        print(f"📊 Epoch {epoch+1:2d}: Train={train_loss:.4f}, Val={val_loss:.4f}")

        if self.best_val_loss < float("inf"):
            print(f"    Best Val Loss: {self.best_val_loss:.4f}")

    def train(self):
        """Run all epochs; return ``(train_losses, val_losses)`` lists."""
        print("=" * 60)
        print("🚀 Starting Image Colorization Training")
        print("=" * 60)

        # Start a fresh CSV log with its header row
        with open("logs/training_log.csv", "w") as f:
            f.write("timestamp,type,epoch,metric,train_loss,val_metric,val_loss\n")

        start_time = time.time()

        for epoch in range(self.config['epochs']):
            print(f"\n{'='*20} EPOCH {epoch+1}/{self.config['epochs']} {'='*20}")

            train_loss = self.train_epoch(epoch)
            val_loss = self.validate_epoch(epoch)
            self.save_checkpoint(epoch, train_loss, val_loss)
            self.log_metrics(epoch, train_loss, val_loss)

            # Memory cleanup
            if self.is_cuda:
                torch.cuda.empty_cache()
            elif self.is_tpu:
                xm.mark_step()

        total_time = time.time() - start_time
        print(f"\n🎉 Training completed in {total_time/3600:.1f} hours!")
        print("📁 Checkpoints saved in: checkpoints/")
        print("📊 Training log saved in: logs/training_log.csv")

        return self.train_losses, self.val_losses

print("✅ Training class defined")

## 🎬 Start Training

In [None]:
# Locate the COCO images among the attached Kaggle inputs, then set
# CONFIG['dataset_path'] accordingly. Typical locations:
# '/kaggle/input/coco-2014-dataset-for-yolov3/coco2014/train2014'
# '/kaggle/input/coco2014/train2014'
# '/kaggle/input/ms-coco-2014/train2014'
import os

KAGGLE_INPUT_ROOT = '/kaggle/input'

print("Available input datasets:")
if not os.path.exists(KAGGLE_INPUT_ROOT):
    print("Not running on Kaggle - update CONFIG['dataset_path'] manually")
else:
    for entry in os.listdir(KAGGLE_INPUT_ROOT):
        print(f"  - {entry}")
        if 'coco' not in entry.lower():
            continue
        # Show one level down so the images folder is easy to spot.
        dataset_path = f'{KAGGLE_INPUT_ROOT}/{entry}'
        print(f"    Contents: {os.listdir(dataset_path)}")

In [None]:
# Point CONFIG['dataset_path'] at your dataset first if needed, e.g.:
# CONFIG['dataset_path'] = '/kaggle/input/your-coco-dataset/train2014'  # Update this!

# Build the trainer (device, model, data), then run CONFIG['epochs'] epochs.
trainer = ImageColorizationTrainer(config=CONFIG)
history = trainer.train()
train_losses, val_losses = history

## 📈 Training Results & Visualization

In [None]:
# Plot training curves: both losses side by side with validation alone.
fig, (ax_both, ax_val) = plt.subplots(1, 2, figsize=(12, 5))

ax_both.plot(train_losses, label='Training Loss', color='blue')
ax_both.plot(val_losses, label='Validation Loss', color='red')
ax_both.set_xlabel('Epoch')
ax_both.set_ylabel('Loss')
ax_both.set_title('Training and Validation Loss')
ax_both.legend()
ax_both.grid(True)

ax_val.plot(val_losses, label='Validation Loss', color='red', linewidth=2)
ax_val.set_xlabel('Epoch')
ax_val.set_ylabel('Validation Loss')
ax_val.set_title('Validation Loss Over Time')
ax_val.legend()
ax_val.grid(True)

fig.tight_layout()
fig.savefig('training_curves.png', dpi=300, bbox_inches='tight')
plt.show()

# Best epoch is 1-based; ties resolve to the earliest epoch.
best_val_loss = min(val_losses)
best_epoch = val_losses.index(best_val_loss) + 1

print("\n🎯 Final Results:")
print(f"Final Training Loss: {train_losses[-1]:.4f}")
print(f"Final Validation Loss: {val_losses[-1]:.4f}")
print(f"Best Validation Loss: {best_val_loss:.4f} (Epoch {best_epoch})")

results = dict(
    train_losses=train_losses,
    val_losses=val_losses,
    best_val_loss=best_val_loss,
    best_epoch=best_epoch,
    config=CONFIG,
)

with open('training_results.json', 'w') as results_file:
    json.dump(results, results_file, indent=2)

print("\n💾 Results saved to training_results.json")

## 🎨 Test the Trained Model

In [None]:
def colorize_image(model, image_path, device, size=(256, 256)):
    """Colorize one image with a trained colorization model.

    Args:
        model: Network mapping a (1, 1, H, W) L tensor (L / 100) to a
            (1, 2, H, W) ab tensor normalized by 127.
        image_path: Path to an image readable by PIL.
        device: Device the model lives on.
        size: ``(width, height)`` the image is resized to (PIL order).

    Returns:
        ``(original, colorized)``: float RGB arrays in [0, 1] of shape
        (height, width, 3); ``original`` is the resized input.
    """
    model.eval()

    # Load, resize and scale to [0, 1] RGB
    image = Image.open(image_path).convert('RGB')
    image = image.resize(tuple(size), Image.Resampling.LANCZOS)
    rgb_image = np.array(image).astype(np.float32) / 255.0

    # L channel normalized to [0, 1], the same scaling used for training
    lab_image = color.rgb2lab(rgb_image)
    l_channel = lab_image[:, :, 0:1] / 100.0
    l_tensor = torch.from_numpy(l_channel).permute(2, 0, 1).unsqueeze(0).float().to(device)

    with torch.no_grad():
        predicted_ab = model(l_tensor)

    # (1, 2, H, W) -> (H, W, 2). Index the batch dimension rather than calling
    # squeeze(), which would also drop a spatial axis of size 1.
    predicted_ab = predicted_ab[0].permute(1, 2, 0).cpu().numpy()

    # Undo the normalization: L * 100, ab * 127
    reconstructed_lab = np.zeros_like(lab_image)
    reconstructed_lab[:, :, 0] = l_channel[:, :, 0] * 100.0
    reconstructed_lab[:, :, 1:] = predicted_ab * 127.0

    reconstructed_rgb = np.clip(color.lab2rgb(reconstructed_lab), 0, 1)

    return rgb_image, reconstructed_rgb

# Load the best model
best_model_path = "checkpoints/best_model.pth"
if os.path.exists(best_model_path):
    # The checkpoint written by the trainer holds only tensors and plain
    # Python values, so the restricted weights_only unpickler can load it.
    checkpoint = torch.load(best_model_path, map_location=trainer.device, weights_only=True)
    trainer.model.load_state_dict(checkpoint['model_state_dict'])
    print(f"✅ Loaded best model from epoch {checkpoint['epoch']+1}")

    # Test on a few images from the dataset, at the training resolution
    test_images = list(Path(CONFIG['dataset_path']).glob('*.jpg'))[:5]
    input_size = tuple(CONFIG['input_size'])

    if not test_images:
        print(f"No .jpg images found in {CONFIG['dataset_path']}")
    else:
        plt.figure(figsize=(15, 10))
        for i, img_path in enumerate(test_images):
            try:
                original, colorized = colorize_image(
                    trainer.model, img_path, trainer.device, size=input_size
                )

                plt.subplot(2, 5, i+1)
                plt.imshow(original)
                plt.title(f'Original {i+1}')
                plt.axis('off')

                plt.subplot(2, 5, i+6)
                plt.imshow(colorized)
                plt.title(f'Colorized {i+1}')
                plt.axis('off')
            except Exception as e:
                print(f"Error processing {img_path}: {e}")

        plt.tight_layout()
        plt.savefig('colorization_results.png', dpi=300, bbox_inches='tight')
        plt.show()
else:
    print("❌ No trained model found. Train the model first.")

## 📁 Download Results (for Kaggle)

In [None]:
# Bundle checkpoints, logs, results and plots into one zip for easy download.
import zipfile

ARCHIVE_NAME = 'colorization_training_results.zip'
# Optional single-file artifacts, each archived under its own relative path.
OPTIONAL_ARTIFACTS = (
    'logs/training_log.csv',
    'training_results.json',
    'training_curves.png',
    'colorization_results.png',
)

with zipfile.ZipFile(ARCHIVE_NAME, 'w') as archive:
    for ckpt in Path('checkpoints').glob('*.pth'):
        archive.write(ckpt, f'checkpoints/{ckpt.name}')
    for artifact in filter(os.path.exists, OPTIONAL_ARTIFACTS):
        archive.write(artifact, artifact)

summary_lines = [
    f"📦 All results packaged in: {ARCHIVE_NAME}",
    "\n🎉 Training Complete!",
    "\n📋 Summary:",
    f"   - Trained for {CONFIG['epochs']} epochs",
    "   - Model checkpoints saved every epoch",
    "   - Training and validation losses calculated",
    "   - Best model saved automatically",
    "   - Results visualized and saved",
]
print("\n".join(summary_lines))