## 1. Import Libraries

In [None]:
# Standard libraries
import os
import sys
import json
import time
import shutil
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Any
from datetime import datetime
from collections import defaultdict

# Data handling
import numpy as np
import pandas as pd
from PIL import Image

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm

# Deep learning
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# Set visualization style
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (15, 10)
plt.rcParams['font.size'] = 10

# Set random seeds for reproducibility (parameter initialisation in later
# cells draws from these generators).
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print("✅ Libraries imported successfully!")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 2. Setup Paths and Configuration

In [None]:
# Project paths. The root defaults to the original development machine and can
# be overridden with the VITON_PROJECT_ROOT environment variable.
project_root = Path(os.environ.get('VITON_PROJECT_ROOT', r'd:\Projects\AI-Virtual-TryOn'))
data_dir = project_root / 'data' / 'raw' / 'viton-hd'
output_dir = project_root / 'outputs' / 'training'
checkpoint_dir = output_dir / 'checkpoints'
logs_dir = output_dir / 'logs'

# Create output directories
for directory in (output_dir, checkpoint_dir, logs_dir):
    directory.mkdir(parents=True, exist_ok=True)

# Load configurations produced earlier (model architecture and loss settings).
model_config_path = project_root / 'outputs' / 'model_architecture' / 'model_architecture_config.json'
with open(model_config_path, 'r', encoding='utf-8') as f:
    model_config = json.load(f)

loss_config_path = project_root / 'outputs' / 'loss_functions' / 'loss_config.json'
with open(loss_config_path, 'r', encoding='utf-8') as f:
    loss_config = json.load(f)

print(f"📁 Project Root: {project_root}")
print(f"📁 Data Directory: {data_dir}")
print(f"📁 Output Directory: {output_dir}")
print(f"📁 Checkpoint Directory: {checkpoint_dir}")
print(f"📁 Logs Directory: {logs_dir}")
print("\n✅ Loaded model and loss configurations")

# Device configuration
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"\n🖥️ Using device: {device}")

## 3. Training Configuration

In [None]:
# Training hyperparameters.
# NOTE: the scheduler ('decay_*'), 'discriminator_steps', 'save_every',
# 'keep_best', 'validate_every', 'save_images_every', early-stopping and
# 'tensorboard' settings are not yet read by the training code in this
# notebook; they describe the intended full training run.
training_config = {
    # Training settings
    'num_epochs': 100,
    'batch_size': 4,
    'num_workers': 0,  # Windows compatibility

    # Optimizer settings
    'lr_g': 0.0002,  # Generator learning rate
    'lr_d': 0.0002,  # Discriminator learning rate
    'beta1': 0.5,    # Adam beta1
    'beta2': 0.999,  # Adam beta2

    # Scheduler settings
    'decay_after_epoch': 50,  # Start LR decay after this epoch
    'decay_type': 'linear',   # 'linear' or 'step'

    # Training stability
    'gradient_clip': 5.0,      # Gradient clipping value (<= 0 disables)
    'discriminator_steps': 1,  # Discriminator updates per generator update

    # Checkpointing
    'save_every': 5,           # Save checkpoint every N epochs
    'keep_best': True,         # Keep best model based on validation
    'max_checkpoints': 5,      # Maximum checkpoints to keep

    # Validation
    'validate_every': 1,       # Validate every N epochs
    'save_images_every': 5,    # Save sample images every N epochs

    # Early stopping
    'early_stopping': True,
    'patience': 20,            # Epochs without improvement

    # Logging
    'log_every': 50,           # Log every N batches
    'tensorboard': True,       # Use TensorBoard
}

print("="*70)
print("📊 TRAINING CONFIGURATION")
print("="*70)
print("\n🎯 Training Settings:")
print(f"   Epochs: {training_config['num_epochs']}")
print(f"   Batch size: {training_config['batch_size']}")
print(f"   Num workers: {training_config['num_workers']}")

print("\n⚙️ Optimizer Settings:")
print(f"   Generator LR: {training_config['lr_g']}")
print(f"   Discriminator LR: {training_config['lr_d']}")
print(f"   Beta1: {training_config['beta1']}, Beta2: {training_config['beta2']}")

print("\n📉 Scheduler Settings:")
print(f"   Decay after epoch: {training_config['decay_after_epoch']}")
print(f"   Decay type: {training_config['decay_type']}")

print("\n🛡️ Training Stability:")
print(f"   Gradient clipping: {training_config['gradient_clip']}")
print(f"   Discriminator steps: {training_config['discriminator_steps']}")

print("\n💾 Checkpointing:")
print(f"   Save every: {training_config['save_every']} epochs")
print(f"   Keep best: {training_config['keep_best']}")
print(f"   Max checkpoints: {training_config['max_checkpoints']}")

print("\n✅ Validation:")
print(f"   Validate every: {training_config['validate_every']} epochs")
print(f"   Save images every: {training_config['save_images_every']} epochs")

print("\n🛑 Early Stopping:")
print(f"   Enabled: {training_config['early_stopping']}")
print(f"   Patience: {training_config['patience']} epochs")

print("\n" + "="*70)

## 4. Model Building Blocks (copied from Notebook 08)

In [None]:
# Copy model architecture classes from Notebook 08

# Building blocks
class ResidualBlock(nn.Module):
    """Two 3x3 conv + InstanceNorm layers around an identity shortcut.

    Channel count and spatial size are preserved, so blocks can be stacked.
    """

    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1)
        self.norm1 = nn.InstanceNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = nn.InstanceNorm2d(channels)

    def forward(self, x):
        hidden = F.relu(self.norm1(self.conv1(x)))
        # No activation after the second norm; the shortcut is added directly.
        return x + self.norm2(self.conv2(hidden))


class SelfAttention(nn.Module):
    """Self-attention over all spatial positions of a feature map.

    Query/key projections use ``in_channels // 8`` channels. The attended
    value map is scaled by a learnable ``gamma`` (initialised to 0) and added
    to the input, so the layer starts out as an identity mapping.
    """

    def __init__(self, in_channels):
        super().__init__()
        reduced = in_channels // 8
        self.query = nn.Conv2d(in_channels, reduced, 1)
        self.key = nn.Conv2d(in_channels, reduced, 1)
        self.value = nn.Conv2d(in_channels, in_channels, 1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        batch, channels, height, width = x.shape
        positions = height * width

        # (B, N, C') @ (B, C', N) -> (B, N, N) weights, softmax over keys.
        q = self.query(x).view(batch, -1, positions).transpose(1, 2)
        k = self.key(x).view(batch, -1, positions)
        weights = torch.softmax(torch.bmm(q, k), dim=-1)

        v = self.value(x).view(batch, channels, positions)
        attended = torch.bmm(v, weights.transpose(1, 2)).view(batch, channels, height, width)
        return self.gamma * attended + x


class DownsampleBlock(nn.Module):
    """4x4 stride-2 conv -> InstanceNorm -> ReLU (halves even spatial dims)."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
        self.norm = nn.InstanceNorm2d(out_channels)

    def forward(self, x):
        features = self.conv(x)
        features = self.norm(features)
        return F.relu(features)


class UpsampleBlock(nn.Module):
    """4x4 stride-2 transposed conv -> InstanceNorm -> ReLU (doubles spatial dims)."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
        self.norm = nn.InstanceNorm2d(out_channels)

    def forward(self, x):
        normalized = self.norm(self.conv(x))
        return F.relu(normalized)


print("✅ Building blocks defined")
print("   - ResidualBlock, SelfAttention, DownsampleBlock, UpsampleBlock")

## 5. Generator and Discriminator Networks

In [None]:
# Generator Network (U-Net with Self-Attention)
class Generator(nn.Module):
    """U-Net style generator with a residual + self-attention bottleneck.

    Args:
        input_channels: Channels of the conditioning input tensor.
        output_channels: Channels of the generated image.
        ngf: Channel width after the initial 7x7 convolution.
        num_downs: Number of stride-2 down/up-sampling stages. Input height
            and width should be divisible by ``2 ** num_downs`` so that skip
            connections line up spatially when concatenated.
        num_res_blocks: Residual blocks in the bottleneck.

    The output passes through ``tanh`` and therefore lies in [-1, 1].
    """

    def __init__(self, input_channels=41, output_channels=3, ngf=64, num_downs=4, num_res_blocks=9):
        super().__init__()

        # Initial convolution
        self.initial = nn.Sequential(
            nn.Conv2d(input_channels, ngf, 7, 1, 3),
            nn.InstanceNorm2d(ngf),
            nn.ReLU(True)
        )

        # Encoder: width doubles per stage, capped at 512. Record each stage's
        # width, because that is what the matching skip connection contributes.
        self.encoder = nn.ModuleList()
        skip_channels = []
        in_ch = ngf
        for _ in range(num_downs):
            out_ch = min(in_ch * 2, 512)
            self.encoder.append(DownsampleBlock(in_ch, out_ch))
            skip_channels.append(out_ch)
            in_ch = out_ch

        # Bottleneck with residual blocks and attention
        self.bottleneck = nn.ModuleList(ResidualBlock(in_ch) for _ in range(num_res_blocks))
        self.attention = SelfAttention(in_ch)

        # Decoder: each stage consumes [current features, skip] concatenated,
        # so its input width is in_ch + skip width. Assuming in_ch * 2 only
        # holds while the encoder stays below the 512 cap; with the defaults
        # (encoder widths 128/256/512/512) the second stage would receive
        # 256 + 512 channels and fail.
        self.decoder = nn.ModuleList()
        for skip_ch in reversed(skip_channels):
            out_ch = max(in_ch // 2, ngf)
            self.decoder.append(UpsampleBlock(in_ch + skip_ch, out_ch))
            in_ch = out_ch

        # Final convolution, sized from the decoder's actual output width.
        self.final = nn.Sequential(
            nn.Conv2d(in_ch, output_channels, 7, 1, 3),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.initial(x)

        # Encoder, keeping every stage output for the skip connections
        skips = []
        for down in self.encoder:
            x = down(x)
            skips.append(x)

        # Bottleneck
        for block in self.bottleneck:
            x = block(x)
        x = self.attention(x)

        # Decoder: deepest skip first
        for up, skip in zip(self.decoder, reversed(skips)):
            x = up(torch.cat([x, skip], dim=1))

        return self.final(x)


# Discriminator Network (PatchGAN)
class Discriminator(nn.Module):
    """Conditional PatchGAN discriminator built from spectral-normalised convs.

    ``image`` and ``condition`` are concatenated on the channel axis, so their
    channel counts must sum to ``input_channels``. The output is a
    single-channel map of per-patch scores.
    """

    def __init__(self, input_channels=6, ndf=64, n_layers=3):
        super().__init__()

        def sn_conv_stage(in_ch, out_ch, normalize=True):
            # 4x4 stride-2 spectral-norm conv, optional InstanceNorm, LeakyReLU(0.2)
            stage = [nn.utils.spectral_norm(nn.Conv2d(in_ch, out_ch, 4, 2, 1))]
            if normalize:
                stage.append(nn.InstanceNorm2d(out_ch))
            stage.append(nn.LeakyReLU(0.2, True))
            return stage

        # Channel widths: ndf, then doubling per extra layer (capped at 512).
        widths = [ndf]
        for _ in range(n_layers - 1):
            widths.append(min(widths[-1] * 2, 512))

        # First stage is not normalised.
        layers = sn_conv_stage(input_channels, widths[0], normalize=False)
        for prev_ch, next_ch in zip(widths, widths[1:]):
            layers += sn_conv_stage(prev_ch, next_ch)

        # One more stride-2 stage at constant width, then a stride-1 projection to 1 channel.
        top = widths[-1]
        layers += sn_conv_stage(top, top)
        layers.append(nn.utils.spectral_norm(nn.Conv2d(top, 1, 4, 1, 1)))

        self.model = nn.Sequential(*layers)

    def forward(self, image, condition):
        return self.model(torch.cat((image, condition), dim=1))


print("✅ Generator and Discriminator networks defined")
print("   - Generator: U-Net with self-attention")
print("   - Discriminator: PatchGAN with spectral normalization")

## 6. Loss Functions (copied from Notebook 09)

In [None]:
# torchvision supplies the pretrained VGG-19 used by the perceptual loss
from torchvision.models import VGG19_Weights
import torchvision.models as models

# VGG Perceptual Loss
class VGGPerceptualLoss(nn.Module):
    """Weighted L1 distance between frozen ImageNet VGG-19 activations of two images.

    Inputs are assumed to be in [-1, 1] (the generator's tanh range);
    ``normalize`` maps them to [0, 1] and applies the ImageNet mean/std.

    Args:
        layers: Names of the ReLU activations to compare (keys of
            ``layer_name_mapping``); ``weights[i]`` applies to ``layers[i]``.
        weights: Per-layer multipliers; defaults to 1.0 for every layer.

    Raises:
        KeyError: if a layer name is not in ``layer_name_mapping``.
    """

    def __init__(self, layers=('relu1_1', 'relu2_1', 'relu3_1', 'relu4_1', 'relu5_1'), weights=None):
        super().__init__()
        vgg = models.vgg19(weights=VGG19_Weights.IMAGENET1K_V1).features
        for param in vgg.parameters():
            param.requires_grad = False

        # Position of each named ReLU inside torchvision's vgg19().features.
        self.layer_name_mapping = {
            'relu1_1': '1', 'relu2_1': '6', 'relu3_1': '11', 'relu4_1': '20', 'relu5_1': '29'
        }

        # Requested layers, de-duplicated, in caller order (the order `weights` follows).
        self._layer_order = list(dict.fromkeys(layers))
        vgg_index = {name: int(self.layer_name_mapping[name]) for name in self._layer_order}

        # Cut VGG into consecutive slices, each ending at a requested ReLU.
        # forward() chains them, so every image passes through VGG once
        # instead of once per requested layer starting from the input.
        self.features = nn.ModuleDict()
        start = 0
        for name in sorted(self._layer_order, key=vgg_index.__getitem__):
            end = vgg_index[name] + 1
            self.features[name] = nn.Sequential(*[vgg[i] for i in range(start, end)])
            start = end

        self.weights = weights if weights is not None else [1.0] * len(self._layer_order)
        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def normalize(self, x):
        """Map [-1, 1] images to ImageNet-normalised VGG input."""
        x = (x + 1) / 2
        x = (x - self.mean) / self.std
        return x

    def forward(self, x, y):
        x = self.normalize(x)
        y = self.normalize(y)
        # zip() pairs weights with layers in caller order; layers without a
        # weight are still computed (later slices build on them) but not scored.
        weight_by_layer = dict(zip(self._layer_order, self.weights))
        loss = 0.0
        for name, stage in self.features.items():
            x = stage(x)
            y = stage(y)
            if name in weight_by_layer:
                loss += weight_by_layer[name] * F.l1_loss(x, y)
        return loss


# GAN Loss
class GANLoss(nn.Module):
    """Adversarial loss on raw discriminator outputs.

    Args:
        gan_mode: 'vanilla' (BCE with logits), 'lsgan' (mean squared error)
            or 'hinge'.
        target_real_label: Target value for real samples ('vanilla'/'lsgan').
        target_fake_label: Target value for fake samples ('vanilla'/'lsgan').

    Raises:
        ValueError: if ``gan_mode`` is not one of the supported modes.
    """

    def __init__(self, gan_mode='lsgan', target_real_label=1.0, target_fake_label=0.0):
        super().__init__()
        self.gan_mode = gan_mode
        self.real_label = target_real_label
        self.fake_label = target_fake_label

        if gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        elif gan_mode == 'hinge':
            # Hinge terms are computed directly in forward().
            self.loss = None
        else:
            raise ValueError(f"Unsupported GAN mode: {gan_mode}")

    def get_target_tensor(self, prediction, target_is_real):
        """Return a tensor shaped like ``prediction`` filled with the real or fake label."""
        label = self.real_label if target_is_real else self.fake_label
        return torch.full_like(prediction, label)

    def forward(self, prediction, target_is_real):
        """Loss of ``prediction`` against the real (True) or fake (False) target."""
        if self.gan_mode == 'hinge':
            margin = 1.0 - prediction if target_is_real else 1.0 + prediction
            return F.relu(margin).mean()
        target = self.get_target_tensor(prediction, target_is_real)
        return self.loss(prediction, target)


# Combined VITON Loss
class VITONLoss(nn.Module):
    """Weighted sum of adversarial, VGG perceptual and L1 reconstruction losses.

    Args:
        lambda_gan, lambda_perceptual, lambda_l1: Multipliers applied to the
            corresponding generator loss terms.
        vgg_layers: VGG ReLU layer names forwarded to ``VGGPerceptualLoss``.
        gan_mode: Adversarial formulation forwarded to ``GANLoss``.
    """

    def __init__(self, lambda_gan=1.0, lambda_perceptual=10.0, lambda_l1=10.0,
                 vgg_layers=('relu1_1', 'relu2_1', 'relu3_1', 'relu4_1', 'relu5_1'), gan_mode='lsgan'):
        super().__init__()
        self.lambda_gan = lambda_gan
        self.lambda_perceptual = lambda_perceptual
        self.lambda_l1 = lambda_l1

        self.gan_loss = GANLoss(gan_mode=gan_mode)
        self.perceptual_loss = VGGPerceptualLoss(layers=vgg_layers)
        self.l1_loss = nn.L1Loss()

    def compute_generator_loss(self, fake_image, real_image, disc_fake):
        """Return the weighted 'gan', 'perceptual', 'l1' terms and their 'total'.

        NOTE(review): with gan_mode='hinge', GANLoss scores the generator as
        mean(relu(1 - D(fake))) rather than the more common -mean(D(fake)).
        """
        losses = {}
        # The generator wants the discriminator to call its output real.
        losses['gan'] = self.gan_loss(disc_fake, target_is_real=True) * self.lambda_gan
        losses['perceptual'] = self.perceptual_loss(fake_image, real_image) * self.lambda_perceptual
        losses['l1'] = self.l1_loss(fake_image, real_image) * self.lambda_l1
        losses['total'] = losses['gan'] + losses['perceptual'] + losses['l1']
        return losses

    def compute_discriminator_loss(self, disc_real, disc_fake):
        """Return 'real' and 'fake' adversarial terms and their average as 'total'."""
        losses = {}
        losses['real'] = self.gan_loss(disc_real, target_is_real=True)
        losses['fake'] = self.gan_loss(disc_fake, target_is_real=False)
        # Halved so D and G learn at comparable rates (pix2pix convention).
        losses['total'] = (losses['real'] + losses['fake']) * 0.5
        return losses


print("✅ Loss functions defined")
print("   - VGGPerceptualLoss, GANLoss, VITONLoss")

## 7. Create Dummy Dataset (for Pipeline Testing)

In [None]:
# For now, we'll create a simple dummy dataset to test the training loop
# In production, you would load the actual VITON dataset

class DummyVITONDataset(torch.utils.data.Dataset):
    """Random tensors shaped like VITON-HD samples, for exercising the training loop.

    Args:
        num_samples: Number of samples reported by ``len()``.
        image_size: ``(height, width)`` of every tensor; defaults to 1024x768.
            Use a small size for quick smoke tests.
        input_channels: Channels of ``'multi_channel_input'``.

    The image tensors are drawn uniformly from [-1, 1), matching the
    generator's tanh range that the perceptual loss and sample plotting
    assume. The conditioning input stays standard-normal. Samples are freshly
    random on every access.
    """

    def __init__(self, num_samples=100, image_size=(1024, 768), input_channels=41):
        self.num_samples = num_samples
        self.image_size = tuple(image_size)
        self.input_channels = input_channels

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        if idx < 0:
            idx += self.num_samples
        # IndexError lets plain ``for sample in dataset`` iteration terminate.
        if not 0 <= idx < self.num_samples:
            raise IndexError(f"index {idx} out of range for dataset of size {self.num_samples}")

        height, width = self.image_size
        # In a real dataset: load actual images, parsing, pose, etc.
        return {
            'multi_channel_input': torch.randn(self.input_channels, height, width),
            'target_image': torch.rand(3, height, width) * 2 - 1,
            'cloth_image': torch.rand(3, height, width) * 2 - 1,
        }


# Create datasets
print("Creating dummy datasets for testing...")
train_dataset = DummyVITONDataset(num_samples=200)  # Small for testing
val_dataset = DummyVITONDataset(num_samples=50)

# Shared loader settings; pinned host memory only helps host-to-GPU copies.
loader_kwargs = dict(
    batch_size=training_config['batch_size'],
    num_workers=training_config['num_workers'],
    pin_memory=(device == 'cuda'),
)

# Create data loaders
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

print("="*70)
print("📊 DATASET LOADED")
print("="*70)
print(f"\n✅ Train dataset: {len(train_dataset)} samples")
print(f"✅ Val dataset: {len(val_dataset)} samples")
print(f"\n📦 Train batches: {len(train_loader)}")
print(f"📦 Val batches: {len(val_loader)}")
print(f"\n⚙️ Batch size: {training_config['batch_size']}")
print(f"⚙️ Num workers: {training_config['num_workers']}")
print("\n" + "="*70)
print("\n⚠️  Note: Using dummy data for testing. Replace with actual VITON dataset for training.")

## 8. Initialize Models, Optimizers, and Loss

In [None]:
# Initialize models. Keep this construction order: parameter initialisation
# draws from the global RNG seeded in the first cell.
print("Initializing models...")
generator = Generator(
    input_channels=41,
    output_channels=3,
    ngf=64,
    num_downs=4,
    num_res_blocks=9,
).to(device)

discriminator = Discriminator(
    input_channels=6,  # 3 (image) + 3 (condition)
    ndf=64,
    n_layers=3,
).to(device)

# Initialize loss from the saved loss configuration
loss_weights = loss_config['loss_weights']
criterion = VITONLoss(
    lambda_gan=loss_weights['lambda_gan'],
    lambda_perceptual=loss_weights['lambda_perceptual'],
    lambda_l1=loss_weights['lambda_l1'],
    vgg_layers=loss_config['perceptual_loss']['layers'],
    gan_mode='lsgan',
).to(device)

# Initialize optimizers (both share the same Adam betas)
adam_betas = (training_config['beta1'], training_config['beta2'])
optimizer_g = optim.Adam(generator.parameters(), lr=training_config['lr_g'], betas=adam_betas)
optimizer_d = optim.Adam(discriminator.parameters(), lr=training_config['lr_d'], betas=adam_betas)

# Count parameters
def count_parameters(model):
    """Return the number of scalar parameters of ``model`` that require gradients."""
    trainable = filter(lambda param: param.requires_grad, model.parameters())
    return sum(param.numel() for param in trainable)

gen_params = count_parameters(generator)
disc_params = count_parameters(discriminator)
total_params = gen_params + disc_params

print("="*70)
print("🎯 MODELS INITIALIZED")
print("="*70)
print("\n🔷 Generator:")
print(f"   Parameters: {gen_params:,} ({gen_params/1e6:.2f}M)")
print("   Architecture: U-Net with self-attention")

print("\n🔶 Discriminator:")
print(f"   Parameters: {disc_params:,} ({disc_params/1e6:.2f}M)")
print("   Architecture: PatchGAN with spectral normalization")

print(f"\n💾 Total Parameters: {total_params:,} ({total_params/1e6:.2f}M)")

print("\n⚙️ Optimizer: Adam")
print(f"   Generator LR: {training_config['lr_g']}")
print(f"   Discriminator LR: {training_config['lr_d']}")
print(f"   Betas: ({training_config['beta1']}, {training_config['beta2']})")

print("\n📊 Loss Functions:")
print(f"   GAN loss weight: {loss_config['loss_weights']['lambda_gan']}")
print(f"   Perceptual loss weight: {loss_config['loss_weights']['lambda_perceptual']}")
print(f"   L1 loss weight: {loss_config['loss_weights']['lambda_l1']}")

print("\n" + "="*70)

## 9. Training and Validation Functions

In [None]:
def train_one_epoch(generator, discriminator, train_loader, optimizer_g, optimizer_d,
                    criterion, device, epoch, config):
    """Train for one epoch with alternating discriminator / generator updates.

    Each batch must provide 'multi_channel_input', 'target_image' and
    'cloth_image' tensors. Per batch, the discriminator is updated once on
    (real, detached fake) pairs. The generator is then updated against the
    just-updated discriminator.

    Args:
        generator, discriminator: Models to train (switched to train mode).
        train_loader: Iterable of batch dicts.
        optimizer_g, optimizer_d: Optimizers for the two models.
        criterion: Object providing ``compute_discriminator_loss`` and
            ``compute_generator_loss`` (e.g. ``VITONLoss``).
        device: Device the batch tensors are moved to.
        epoch: Epoch number, used only for the progress-bar label.
        config: Dict with 'gradient_clip' (<= 0 disables clipping) and
            'log_every' (progress-bar refresh interval in batches).
            NOTE: 'discriminator_steps' is not consulted; exactly one D
            step runs per batch.

    Returns:
        Dict mapping 'g_loss', 'g_gan', 'g_perceptual', 'g_l1', 'd_loss',
        'd_real', 'd_fake' to their per-batch mean as Python floats (empty
        if the loader yields no batches).
    """
    generator.train()
    discriminator.train()

    metrics = defaultdict(list)
    grad_clip = config['gradient_clip']
    log_every = config['log_every']

    pbar = tqdm(train_loader, desc=f"Epoch {epoch}")
    for batch_idx, batch in enumerate(pbar):
        # Move data to device
        multi_channel_input = batch['multi_channel_input'].to(device)
        target_image = batch['target_image'].to(device)
        cloth_condition = batch['cloth_image'].to(device)

        # One generator forward serves both updates. The generator's weights
        # do not change before its own step, so this equals running G twice,
        # at half the cost.
        fake_image = generator(multi_channel_input)

        # ==================== Train Discriminator ====================
        optimizer_d.zero_grad()
        disc_real = discriminator(target_image, cloth_condition)
        # detach(): the discriminator loss must not backpropagate into G.
        disc_fake = discriminator(fake_image.detach(), cloth_condition)
        d_losses = criterion.compute_discriminator_loss(disc_real, disc_fake)
        d_loss = d_losses['total']
        d_loss.backward()
        if grad_clip > 0:
            torch.nn.utils.clip_grad_norm_(discriminator.parameters(), grad_clip)
        optimizer_d.step()

        # ==================== Train Generator ====================
        optimizer_g.zero_grad()
        # Re-score the non-detached fake with the updated discriminator.
        disc_fake = discriminator(fake_image, cloth_condition)
        g_losses = criterion.compute_generator_loss(fake_image, target_image, disc_fake)
        g_loss = g_losses['total']
        # This also fills D's .grad; optimizer_d.zero_grad() clears it next batch.
        g_loss.backward()
        if grad_clip > 0:
            torch.nn.utils.clip_grad_norm_(generator.parameters(), grad_clip)
        optimizer_g.step()

        # ==================== Log Metrics ====================
        metrics['g_loss'].append(g_loss.item())
        metrics['g_gan'].append(g_losses['gan'].item())
        metrics['g_perceptual'].append(g_losses['perceptual'].item())
        metrics['g_l1'].append(g_losses['l1'].item())
        metrics['d_loss'].append(d_loss.item())
        metrics['d_real'].append(d_losses['real'].item())
        metrics['d_fake'].append(d_losses['fake'].item())

        if batch_idx % log_every == 0:
            pbar.set_postfix({
                'G_loss': f"{g_loss.item():.4f}",
                'D_loss': f"{d_loss.item():.4f}",
                'G_GAN': f"{g_losses['gan'].item():.4f}",
                'G_Perc': f"{g_losses['perceptual'].item():.4f}"
            })

    # Plain floats serialise cleanly (JSON, torch.load with weights_only=True).
    return {name: float(np.mean(values)) for name, values in metrics.items()}


def validate(generator, discriminator, val_loader, criterion, device):
    """Compute generator and discriminator losses on the validation set.

    Both models are switched to eval mode and left there;
    ``train_one_epoch`` switches them back to train mode. No gradients are
    computed and no weights change.

    Returns:
        Dict with the same keys as ``train_one_epoch`` ('g_loss', 'g_gan',
        'g_perceptual', 'g_l1', 'd_loss', 'd_real', 'd_fake'), each the mean
        of the per-batch values as a Python float. Empty if the loader
        yields no batches.
    """
    generator.eval()
    discriminator.eval()

    metrics = defaultdict(list)

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validation"):
            multi_channel_input = batch['multi_channel_input'].to(device)
            target_image = batch['target_image'].to(device)
            cloth_condition = batch['cloth_image'].to(device)

            fake_image = generator(multi_channel_input)

            disc_real = discriminator(target_image, cloth_condition)
            disc_fake = discriminator(fake_image, cloth_condition)

            g_losses = criterion.compute_generator_loss(fake_image, target_image, disc_fake)
            d_losses = criterion.compute_discriminator_loss(disc_real, disc_fake)

            metrics['g_loss'].append(g_losses['total'].item())
            metrics['g_gan'].append(g_losses['gan'].item())
            metrics['g_perceptual'].append(g_losses['perceptual'].item())
            metrics['g_l1'].append(g_losses['l1'].item())
            metrics['d_loss'].append(d_losses['total'].item())
            metrics['d_real'].append(d_losses['real'].item())
            metrics['d_fake'].append(d_losses['fake'].item())

    return {name: float(np.mean(values)) for name, values in metrics.items()}


print("✅ Training and validation functions defined")
print("   - train_one_epoch(): Trains for one epoch with progress bar")
print("   - validate(): Validates model on validation set")
print("   - Includes gradient clipping and metric tracking")

## 10. Checkpointing and Utility Functions

In [None]:
def _to_builtin(value):
    """Recursively replace numpy scalars/arrays in dicts, lists and tuples with plain Python values."""
    if isinstance(value, dict):
        return {key: _to_builtin(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        converted = [_to_builtin(item) for item in value]
        return converted if isinstance(value, list) else tuple(converted)
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, np.ndarray):
        return value.tolist()
    return value


def _checkpoint_epoch(path):
    """Return the epoch encoded in a 'checkpoint_epoch_<N>.pth' name, or None if it is not an integer."""
    try:
        return int(path.stem.rsplit('_', 1)[-1])
    except ValueError:
        return None


def _atomic_torch_save(obj, path):
    """torch.save into a temp file, then os.replace it over ``path`` so a crash mid-write leaves no truncated file."""
    tmp_path = path.with_name(path.name + '.tmp')
    torch.save(obj, tmp_path)
    os.replace(tmp_path, path)


def save_checkpoint(generator, discriminator, optimizer_g, optimizer_d, epoch,
                   metrics, checkpoint_dir, is_best=False, config=None):
    """Save model/optimizer state for ``epoch`` and prune old numbered checkpoints.

    Writes ``checkpoint_epoch_{epoch:03d}.pth`` (plus ``best_model.pth`` when
    ``is_best``) into ``checkpoint_dir`` (str or Path). numpy scalars inside
    ``metrics`` and ``config`` are stored as plain Python numbers, so the files
    can also be read with ``torch.load(..., weights_only=True)``.

    If ``config['max_checkpoints']`` is a positive int, only that many numbered
    checkpoints are kept. They are ranked by epoch number, not by file name;
    file-name order puts epoch 1000 before epoch 999. ``best_model.pth`` is
    never pruned.
    """
    checkpoint_dir = Path(checkpoint_dir)
    checkpoint = {
        'epoch': epoch,
        'generator_state_dict': generator.state_dict(),
        'discriminator_state_dict': discriminator.state_dict(),
        'optimizer_g_state_dict': optimizer_g.state_dict(),
        'optimizer_d_state_dict': optimizer_d.state_dict(),
        'metrics': _to_builtin(metrics),
        'config': _to_builtin(config),
    }

    # Save regular checkpoint
    checkpoint_path = checkpoint_dir / f'checkpoint_epoch_{epoch:03d}.pth'
    _atomic_torch_save(checkpoint, checkpoint_path)
    print(f"💾 Saved checkpoint: {checkpoint_path.name}")

    # Save best model
    if is_best:
        best_path = checkpoint_dir / 'best_model.pth'
        _atomic_torch_save(checkpoint, best_path)
        print(f"🌟 Saved best model: {best_path.name}")

    # Keep only the newest `max_checkpoints` numbered checkpoints
    max_keep = config.get('max_checkpoints') if config else None
    if isinstance(max_keep, int) and max_keep > 0:
        numbered = []
        for path in checkpoint_dir.glob('checkpoint_epoch_*.pth'):
            path_epoch = _checkpoint_epoch(path)
            if path_epoch is not None:
                numbered.append((path_epoch, path))
        numbered.sort(key=lambda item: item[0])
        for _, old_checkpoint in numbered[:-max_keep]:
            old_checkpoint.unlink()
            print(f"🗑️  Removed old checkpoint: {old_checkpoint.name}")


def load_checkpoint(checkpoint_path, generator, discriminator, optimizer_g=None,
                   optimizer_d=None, device='cuda', weights_only=None):
    """Restore state written by ``save_checkpoint`` into the given models/optimizers.

    Args:
        checkpoint_path: Path to the ``.pth`` file.
        generator, discriminator: Modules whose state dicts are overwritten.
        optimizer_g, optimizer_d: Optional optimizers to restore as well.
        device: ``map_location`` passed to ``torch.load``.
        weights_only: Forwarded to ``torch.load`` when not None; otherwise the
            installed PyTorch default applies (``True`` on recent releases).
            Pass ``False`` only for trusted files, e.g. older checkpoints whose
            metrics contain numpy scalars that the restricted loader may reject.

    Returns:
        ``(epoch, metrics)``; ``metrics`` is ``{}`` if the checkpoint has none.
    """
    load_kwargs = {'map_location': device}
    if weights_only is not None:
        load_kwargs['weights_only'] = weights_only
    checkpoint = torch.load(checkpoint_path, **load_kwargs)

    generator.load_state_dict(checkpoint['generator_state_dict'])
    discriminator.load_state_dict(checkpoint['discriminator_state_dict'])

    # Optimizer state is optional (e.g. inference-only restores).
    if optimizer_g is not None:
        optimizer_g.load_state_dict(checkpoint['optimizer_g_state_dict'])
    if optimizer_d is not None:
        optimizer_d.load_state_dict(checkpoint['optimizer_d_state_dict'])

    epoch = checkpoint['epoch']
    metrics = checkpoint.get('metrics', {})

    print(f"✅ Loaded checkpoint from epoch {epoch}")
    return epoch, metrics


def save_sample_images(generator, val_loader, epoch, output_dir, device, num_samples=4):
    """Plot targets vs. generator outputs for the first validation batch and save a PNG.

    Row 0 shows target images, row 1 generated images. Both are mapped from
    [-1, 1] to [0, 1] and clamped before display. At most ``num_samples``
    columns are drawn, fewer if the first batch is smaller. The figure goes to
    ``output_dir / f'samples_epoch_{epoch:03d}.png'``. Leaves ``generator`` in
    eval mode.
    """
    output_dir = Path(output_dir)
    generator.eval()

    batch = next(iter(val_loader), None)
    if batch is None:
        print("⚠️  Validation loader is empty; no sample images saved.")
        return

    with torch.no_grad():
        multi_channel_input = batch['multi_channel_input'][:num_samples].to(device)
        target_image = batch['target_image'][:num_samples].to(device)
        fake_image = generator(multi_channel_input)

    # [-1, 1] -> [0, 1]; clamp so out-of-range values display without imshow clipping warnings.
    fake_image = ((fake_image + 1) / 2).clamp(0, 1).cpu()
    target_image = ((target_image + 1) / 2).clamp(0, 1).cpu()

    n_cols = target_image.size(0)
    if n_cols == 0:
        return

    # squeeze=False keeps `axes` 2-D even for a single column.
    fig, axes = plt.subplots(2, n_cols, figsize=(n_cols * 4, 8), squeeze=False)
    rows = ((target_image, 'Target'), (fake_image, 'Generated'))
    for i in range(n_cols):
        for row, (images, title) in enumerate(rows):
            ax = axes[row, i]
            ax.imshow(images[i].permute(1, 2, 0).numpy())
            ax.set_title(f'{title} {i+1}', fontsize=12)
            ax.axis('off')

    fig.tight_layout()
    save_path = output_dir / f'samples_epoch_{epoch:03d}.png'
    fig.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close(fig)

    print(f"🖼️  Saved sample images: {save_path.name}")


print("✅ Checkpoint and utility functions defined")
print("   - save_checkpoint(): Save model state")
print("   - load_checkpoint(): Load model state")
print("   - save_sample_images(): Save generated image samples")

## 11. Test Training Loop (1 Epoch)

In [None]:
print("="*70)
print("🚀 TESTING TRAINING LOOP")
print("="*70)
print("\nTraining for 1 epoch to test the pipeline...")
print("This will verify that all components work together correctly.\n")

# Test training for 1 epoch
epoch = 1
start_time = time.perf_counter()

# Train
train_metrics = train_one_epoch(
    generator, discriminator, train_loader,
    optimizer_g, optimizer_d, criterion,
    device, epoch, training_config
)

# Validate
val_metrics = validate(generator, discriminator, val_loader, criterion, device)

# Save sample images
samples_dir = output_dir / 'samples'
samples_dir.mkdir(parents=True, exist_ok=True)
save_sample_images(generator, val_loader, epoch, samples_dir, device, num_samples=4)

epoch_time = time.perf_counter() - start_time
# NOTE: the estimate below scales this epoch's time, which reflects the dummy dataset's size.
num_epochs = training_config['num_epochs']

# Print results
print("\n" + "="*70)
print(f"📊 EPOCH {epoch} RESULTS")
print("="*70)

print("\n🔷 Training Metrics:")
print(f"   Generator Loss: {train_metrics['g_loss']:.4f}")
print(f"   - GAN Loss: {train_metrics['g_gan']:.4f}")
print(f"   - Perceptual Loss: {train_metrics['g_perceptual']:.4f}")
print(f"   - L1 Loss: {train_metrics['g_l1']:.4f}")
print(f"   Discriminator Loss: {train_metrics['d_loss']:.4f}")
print(f"   - Real Loss: {train_metrics['d_real']:.4f}")
print(f"   - Fake Loss: {train_metrics['d_fake']:.4f}")

print("\n🔶 Validation Metrics:")
print(f"   Generator Loss: {val_metrics['g_loss']:.4f}")
print(f"   - GAN Loss: {val_metrics['g_gan']:.4f}")
print(f"   - Perceptual Loss: {val_metrics['g_perceptual']:.4f}")
print(f"   - L1 Loss: {val_metrics['g_l1']:.4f}")
print(f"   Discriminator Loss: {val_metrics['d_loss']:.4f}")

print(f"\n⏱️  Epoch Time: {epoch_time:.2f}s ({epoch_time/60:.2f}min)")
print(f"⏱️  Estimated time for {num_epochs} epochs: {epoch_time * num_epochs / 3600:.2f}h")

print("\n" + "="*70)
print("\n✅ TRAINING LOOP TEST SUCCESSFUL!")
print("="*70)

## 12. Save Test Checkpoint

In [None]:
# Save checkpoint after test epoch
combined_metrics = {
    'train': train_metrics,
    'val': val_metrics,
    'epoch_time': epoch_time
}

# With a single epoch there is nothing to compare against, so it is also the best model.
save_checkpoint(
    generator, discriminator,
    optimizer_g, optimizer_d,
    epoch=epoch,
    metrics=combined_metrics,
    checkpoint_dir=checkpoint_dir,
    is_best=True,
    config=training_config
)

print("\n✅ Test checkpoint saved successfully!")

## 13. Training Configuration Summary

In [None]:
# Save complete training configuration
full_training_config = {
    'model': {
        'generator': {
            'architecture': 'U-Net with self-attention',
            'input_channels': 41,
            'output_channels': 3,
            'ngf': 64,
            'num_downs': 4,
            'num_res_blocks': 9,
            'parameters': gen_params
        },
        'discriminator': {
            'architecture': 'PatchGAN with spectral normalization',
            'input_channels': 6,
            'ndf': 64,
            'n_layers': 3,
            'parameters': disc_params
        }
    },
    'training': training_config,
    'loss': loss_config,
    'test_results': {
        'epoch': epoch,
        # float() also turns numpy scalars into JSON-serialisable numbers.
        'train_metrics': {k: float(v) for k, v in train_metrics.items()},
        'val_metrics': {k: float(v) for k, v in val_metrics.items()},
        'epoch_time_seconds': epoch_time
    }
}

# Save configuration
config_save_path = output_dir / 'training_config.json'
with open(config_save_path, 'w', encoding='utf-8') as f:
    json.dump(full_training_config, f, indent=2)

print("="*70)
print("💾 TRAINING CONFIGURATION SAVED")
print("="*70)
print(f"\n📄 Config saved to: {config_save_path}")

# List generated files
print("\n📁 Generated Files:")
checkpoint_files = sorted(checkpoint_dir.glob('*.pth'))
sample_files = sorted(samples_dir.glob('*.png'))

for path in [*checkpoint_files, *sample_files, config_save_path]:
    print(f"   ✓ {path.relative_to(output_dir)}")

print("\n" + "="*70)

## 14. Summary and Next Steps

In [None]:
print("="*70)
print("🎉 TRAINING LOOP COMPLETE!")
print("="*70)

print("\n✅ Completed Tasks:")
print("   1. ✓ Loaded model architectures (Generator & Discriminator)")
print("   2. ✓ Loaded loss functions (GAN, Perceptual, L1)")
print("   3. ✓ Created dummy dataset for testing")
print("   4. ✓ Initialized models and optimizers")
print("   5. ✓ Implemented train_one_epoch() function")
print("   6. ✓ Implemented validate() function")
print("   7. ✓ Implemented checkpoint saving/loading")
print("   8. ✓ Tested complete training pipeline for 1 epoch")
print("   9. ✓ Generated sample images")
print("   10. ✓ Saved configuration and checkpoints")

print("\n📊 Training Pipeline Status:")
print(f"   🔷 Generator: {gen_params/1e6:.2f}M parameters")
print(f"   🔶 Discriminator: {disc_params/1e6:.2f}M parameters")
print(f"   ⏱️  1 Epoch Time: {epoch_time:.2f}s")
print(f"   📦 Train batches: {len(train_loader)}")
print(f"   📦 Val batches: {len(val_loader)}")

print("\n📈 Test Results (1 Epoch):")
print("   Training:")
print(f"   - Generator Loss: {train_metrics['g_loss']:.4f}")
print(f"   - Discriminator Loss: {train_metrics['d_loss']:.4f}")
print("   Validation:")
print(f"   - Generator Loss: {val_metrics['g_loss']:.4f}")
print(f"   - Discriminator Loss: {val_metrics['d_loss']:.4f}")

print("\n📁 Output Files:")
print(f"   - Checkpoints: {checkpoint_dir}")
print(f"   - Sample images: {samples_dir}")
print(f"   - Configuration: {config_save_path}")

print("\n🚀 Ready for Full Training!")
print("\n💡 Next Steps:")
print("   1. Replace DummyVITONDataset with actual VITON-HD dataset")
print("   2. Add TensorBoard logging for real-time monitoring")
print("   3. Implement learning rate scheduling")
print("   4. Add early stopping based on validation loss")
print(f"   5. Train for full {training_config['num_epochs']} epochs")
print("   6. Evaluate on test set")
print("   7. Implement inference pipeline")

print("\n⚠️  Important Notes:")
print("   - Currently using dummy data for pipeline testing")
print("   - Replace with actual dataset from Notebooks 03-07")
print("   - Adjust batch size based on GPU memory")
print("   - Monitor GPU memory usage during training")
print("   - Use gradient accumulation if batch size is too small")

print("\n" + "="*70)
print("\n✅ TRAINING LOOP IMPLEMENTATION COMPLETE!")
print("="*70)