# Model Training Notebook
## Image Colorization Using a GAN

This notebook demonstrates the complete training pipeline for the image colorization model.

**Contents:**
1. Setup and Configuration
2. Data Loading
3. Model Initialization
4. Training
5. Training Visualization
6. Save Final Model

In [None]:
import os
import sys
sys.path.append('..')

import torch
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime

# Import from modular structure
from src.models import UNetGenerator, PatchDiscriminator
from src.preprocessing import create_dataloaders
from src.training import Trainer

# Set random seeds for reproducibility (PyTorch and NumPy global RNGs)
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# torch.backends.mps is missing on PyTorch builds that predate the MPS backend,
# so look it up defensively instead of raising AttributeError in the first cell.
_mps_backend = getattr(torch.backends, 'mps', None)
mps_available = _mps_backend is not None and _mps_backend.is_available()

# Report the runtime environment so a saved run records what it executed on
print(f"PyTorch Version: {torch.__version__}")
print(f"CUDA Available: {torch.cuda.is_available()}")
print(f"MPS Available: {mps_available}")

## 1. Configuration

In [None]:
# Training Configuration
# Every tunable value lives in this one dict; later cells read from it.
config = {
    # Data (these six entries are forwarded to create_dataloaders)
    'data_dir': '../data/train',
    'image_size': 256,
    'batch_size': 16,
    'val_split': 0.1,
    'test_split': 0.1,
    'num_workers': 4,

    # Training ('num_epochs' is passed to trainer.train; the remaining entries
    # are presumably read by Trainer from `config` -- confirm in src.training)
    'num_epochs': 50,
    'lr_g': 2e-4,
    'lr_d': 2e-4,
    'beta1': 0.5,
    'beta2': 0.999,
    'l1_lambda': 100,

    # Saving ('save_dir' is handed to Trainer as its save_dir argument)
    'save_dir': '../results',
    'checkpoint_every': 5,
}

# Device selection: prefer CUDA, then Apple MPS, otherwise fall back to CPU.
# torch.backends.mps does not exist on PyTorch builds that predate the MPS
# backend, so probe it with getattr rather than assuming the attribute is there.
_mps_backend = getattr(torch.backends, 'mps', None)

if torch.cuda.is_available():
    device = torch.device('cuda')
elif _mps_backend is not None and _mps_backend.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

# Echo the chosen device and the full configuration for the run log
print(f"Using device: {device}")
print("\nConfiguration:")
for key, value in config.items():
    print(f"  {key}: {value}")

## 2. Data Loading

In [None]:
# Build the train / validation / test loaders from the settings in `config`.
# The listed keys are forwarded to create_dataloaders as keyword arguments.
loader_keys = (
    'data_dir',
    'batch_size',
    'image_size',
    'val_split',
    'test_split',
    'num_workers',
)
train_loader, val_loader, test_loader = create_dataloaders(
    **{key: config[key] for key in loader_keys}
)

# Report how many batches each split yields
print("\nData Loaders Created:")
for split_name, split_loader in (
    ("Train", train_loader),
    ("Validation", val_loader),
    ("Test", test_loader),
):
    print(f"  {split_name} batches: {len(split_loader)}")

In [None]:
# Visualize a sample batch
L_sample, AB_sample = next(iter(train_loader))

# Show at most four samples. A batch may hold fewer (small batch_size or a
# tiny dataset), in which case indexing up to sample 3 would raise IndexError.
num_cols = 4
num_shown = min(num_cols, L_sample.size(0))

fig, axes = plt.subplots(2, num_cols, figsize=(12, 6))
fig.suptitle('Sample Training Batch', fontsize=14, fontweight='bold')

# Hide every panel's frame up front so columns without a sample stay blank
for ax in axes.ravel():
    ax.axis('off')

for i in range(num_shown):
    # L channel (grayscale)
    axes[0, i].imshow(L_sample[i, 0].numpy(), cmap='gray')
    axes[0, i].set_title(f'L Channel {i+1}')

    # AB channels rendered as an RGB image: A drives red, B drives blue,
    # green stays zero.
    a_plane = AB_sample[i, 0].numpy()
    b_plane = AB_sample[i, 1].numpy()

    # Size the canvas from the data rather than assuming 256x256, so the cell
    # keeps working when config['image_size'] is changed.
    ab_vis = np.zeros((*a_plane.shape, 3))
    # (x + 1) / 2 rescales for display; assumes AB is normalized to [-1, 1]
    # -- TODO confirm against src.preprocessing.
    ab_vis[:, :, 0] = (a_plane + 1) / 2  # A -> Red
    ab_vis[:, :, 2] = (b_plane + 1) / 2  # B -> Blue
    # Clip explicitly so slightly out-of-range values don't trigger imshow's
    # clipping warning.
    axes[1, i].imshow(np.clip(ab_vis, 0, 1))
    axes[1, i].set_title(f'AB Channels {i+1}')

plt.tight_layout()
plt.show()

# Shapes and value ranges of the batch, as a quick check of the preprocessing
print(f"L shape: {L_sample.shape}")
print(f"AB shape: {AB_sample.shape}")
print(f"L range: [{L_sample.min():.2f}, {L_sample.max():.2f}]")
print(f"AB range: [{AB_sample.min():.2f}, {AB_sample.max():.2f}]")

## 3. Model Initialization

In [None]:
# Build the two networks, then move each one onto the selected device.
# The generator takes 1 input channel and emits 2; the discriminator takes 3.
generator = UNetGenerator(in_channels=1, out_channels=2, features=64)
generator = generator.to(device)

discriminator = PatchDiscriminator(in_channels=3, features=64)
discriminator = discriminator.to(device)

# Count parameters
def count_params(model):
    """Return the total number of elements across `model`'s trainable parameters.

    Only parameters with ``requires_grad`` set are counted; frozen ones are skipped.
    """
    trainable_total = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_total += param.numel()
    return trainable_total

# Report trainable-parameter counts for each network and their sum
print("Model Architecture:")

gen_param_count = count_params(generator)
print(f"  Generator parameters: {gen_param_count:,}")

disc_param_count = count_params(discriminator)
print(f"  Discriminator parameters: {disc_param_count:,}")

print(f"  Total parameters: {gen_param_count + disc_param_count:,}")

In [None]:
# Smoke-test both networks on a single sample before training starts.
# Taking [:1] keeps the batch dimension while using only the first item.
L_test = L_sample[:1].to(device)
AB_test = AB_sample[:1].to(device)

# No gradients are needed for a shape check
with torch.no_grad():
    # Generator: L channel in, predicted AB channels out
    AB_pred = generator(L_test)
    for label, tensor in (("input", L_test), ("output", AB_pred)):
        print(f"Generator {label}: {tensor.shape}")

    # Discriminator: scores the (L, predicted AB) pair
    disc_out = discriminator(L_test, AB_pred)
    print(f"Discriminator output: {disc_out.shape}")

## 4. Training

In [None]:
# Bundle the networks, the train/validation loaders and the settings into the
# training driver. Checkpoints and results go under config['save_dir'].
trainer_kwargs = dict(
    generator=generator,
    discriminator=discriminator,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    config=config,
    save_dir=config['save_dir'],
)
trainer = Trainer(**trainer_kwargs)

print("Trainer initialized successfully!")

In [None]:
# Kick off the training run
banner = "=" * 60
print("Starting training...")
print(banner)

# Leave as None for a fresh run; set to a checkpoint path to resume training
resume_checkpoint = None

history = trainer.train(num_epochs=config['num_epochs'], resume_path=resume_checkpoint)

## 5. Training Visualization

In [None]:
# Plot training history
from src.utils import plot_training_history

# Create the output folder before the figure is written. Nothing earlier in
# this notebook creates ../results/plots, and os.makedirs with exist_ok=True
# does nothing if the folder is already there.
history_plot_path = '../results/plots/training_history.png'
os.makedirs(os.path.dirname(history_plot_path), exist_ok=True)

# Draw the curves recorded by trainer.train(), save them to disk and show inline
plot_training_history(
    history,
    save_path=history_plot_path,
    show=True
)

In [None]:
# Display training summary
# Pull the per-epoch series once; each is indexed by epoch in the order recorded.
train_g_losses = history['train_g_loss']
train_d_losses = history['train_d_loss']
val_g_losses = history['val_g_loss']

rule = "=" * 60

print(rule)
print("TRAINING SUMMARY")
print(rule)
print(f"Total Epochs: {len(train_g_losses)}")
print("")

print("Generator Loss:")
# Guard against an empty history (e.g. zero epochs run), where [0], [-1] and
# min() would otherwise raise.
if len(train_g_losses) > 0:
    print(f"  Initial: {train_g_losses[0]:.4f}")
    print(f"  Final:   {train_g_losses[-1]:.4f}")
else:
    print("  (no training epochs recorded)")
# The best value is the minimum *validation* loss, not a training loss, so
# label it as such to avoid comparing it directly with the two lines above.
if len(val_g_losses) > 0:
    print(f"  Best (validation): {min(val_g_losses):.4f}")
else:
    print("  Best (validation): n/a")
print("")

print("Discriminator Loss:")
if len(train_d_losses) > 0:
    print(f"  Initial: {train_d_losses[0]:.4f}")
    print(f"  Final:   {train_d_losses[-1]:.4f}")
else:
    print("  (no training epochs recorded)")
print(rule)

In [None]:
# Visualize sample predictions after training
generator.eval()

from src.utils import lab2rgb, denormalize_lab

num_cols = 4
fig, axes = plt.subplots(3, num_cols, figsize=(14, 10))
fig.suptitle('Colorization Results After Training', fontsize=14, fontweight='bold')

# Blank every panel first so columns with no sample (a test batch smaller than
# num_cols) don't show empty axis frames. Rows are identified by the titles on
# the first column; axis labels are not used because axis('off') hides them.
for ax in axes.ravel():
    ax.axis('off')

# Fetch exactly one batch. The previous enumerate/break loop pulled a second
# batch from the loader only to discard it; the default also keeps an empty
# test loader from raising StopIteration.
first_batch = next(iter(test_loader), None)

if first_batch is not None:
    L, AB_real = first_batch

    with torch.no_grad():
        L = L.to(device)
        AB_pred = generator(L)

    for i in range(min(num_cols, L.size(0))):
        # Move the leading (channel) axis last for the image helpers / imshow
        L_np = L[i].cpu().numpy().transpose(1, 2, 0)
        AB_pred_np = AB_pred[i].cpu().numpy().transpose(1, 2, 0)
        AB_real_np = AB_real[i].numpy().transpose(1, 2, 0)

        # Recombine L with predicted / real AB and convert to RGB for display
        lab_pred = denormalize_lab(L_np, AB_pred_np)
        lab_real = denormalize_lab(L_np, AB_real_np)

        rgb_pred = lab2rgb(lab_pred)
        rgb_real = lab2rgb(lab_real)

        # Grayscale
        axes[0, i].imshow(L_np.squeeze(), cmap='gray')
        axes[0, i].set_title('Grayscale' if i == 0 else '')

        # Predicted
        axes[1, i].imshow(np.clip(rgb_pred, 0, 1))
        axes[1, i].set_title('Predicted' if i == 0 else '')

        # Ground Truth
        axes[2, i].imshow(np.clip(rgb_real, 0, 1))
        axes[2, i].set_title('Ground Truth' if i == 0 else '')

# savefig does not create missing directories, so make sure the target exists
results_plot_path = '../results/plots/training_results.png'
os.makedirs(os.path.dirname(results_plot_path), exist_ok=True)

plt.tight_layout()
plt.savefig(results_plot_path, dpi=150, bbox_inches='tight')
plt.show()

## 6. Save Final Model

In [None]:
# Persist the final weights of both networks as state dicts
final_dir = '../trained_models'
os.makedirs(final_dir, exist_ok=True)

# File name -> network whose state_dict is written to it
final_files = {
    'generator_final.pth': generator,
    'discriminator_final.pth': discriminator,
}
for file_name, network in final_files.items():
    torch.save(network.state_dict(), f'{final_dir}/{file_name}')

print("Models saved to trained_models/")
for file_name in final_files:
    print(f"  - {file_name}")

## Training Complete

If every cell above ran without errors, training is complete and the final weights are saved in `../trained_models/`. You can now:
1. Run the evaluation notebook (`03_Evaluation.ipynb`) for detailed metrics
2. Use the Gradio UI (`python app.py`) for interactive testing
3. Use the inference script (`python inference.py`) for batch processing