## 📦 Step 2: Install Required Packages

In [None]:
# !pip install -q diffusers transformers accelerate scikit-image

# print("✅ Packages installed successfully!")

## 📚 Step 3: Import Libraries

In [None]:
import os
import numpy as np
import h5py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import GradScaler
from diffusers import UNet2DModel, DDPMScheduler
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import json
from datetime import datetime
import time

# PYTORCH_CUDA_ALLOC_CONF is parsed when the CUDA caching allocator is first
# initialised, which happens on the first real CUDA call (e.g. get_device_name /
# get_device_properties below). Setting it after that point has no effect, so it
# is exported here, before the device is touched. torch.cuda.is_available() only
# queries the device count and does not initialise the allocator.
# NOTE: if CUDA was already initialised earlier in this kernel (e.g. this cell is
# re-run), restart the kernel for the setting to take effect.
if torch.cuda.is_available():
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True,max_split_size_mb:128'

# Check environment
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")

    # Release cached-but-unused blocks (e.g. left over from earlier cell runs)
    torch.cuda.empty_cache()

    # Let cuDNN benchmark kernels for the fixed input size (faster, but not bit-deterministic)
    torch.backends.cudnn.benchmark = True

    print("‚úÖ GPU memory optimizations enabled")

## ⚙️ Step 4: Configuration

**📝 Update the paths below to point to your H5 files (they are currently set to Kaggle input datasets).**

In [None]:
# ============================================================
# PATHS
# ============================================================
TRAIN_H5_PATH = "/kaggle/input/datasets-train-val/train.h5"
VAL_H5_PATH = "/kaggle/input/datasets-train-val/val.h5"
CHECKPOINT_DIR = "/kaggle/working/colorization_checkpoints_128"
LOG_DIR = "/kaggle/working/colorization_logs_128"

# Create output directories (no-op if they already exist)
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.makedirs(LOG_DIR, exist_ok=True)

# ============================================================
# TRAINING CONFIGURATION
# ============================================================

# Optimisation settings
NUM_EPOCHS = 27
BATCH_SIZE = 88
LEARNING_RATE = 2e-4
NUM_WORKERS = 2
ACCUMULATION_STEPS = 1       # NOTE: currently unused - train_one_epoch steps the optimizer every batch

# Model parameters
IMG_SIZE = 128               # 128x128 resolution
L_CHANNELS = 1
AB_CHANNELS = 2
UNET_IN_CHANNELS = 3         # L + AB
UNET_OUT_CHANNELS = 2        # AB only (predicted noise)

# Diffusion parameters (DDPM, linear beta schedule)
NUM_TRAIN_TIMESTEPS = 1000
BETA_START = 0.0001
BETA_END = 0.02
BETA_SCHEDULE = "linear"

# Validation / logging / checkpoint settings
VAL_EVERY_N_EPOCHS = 1       # NOTE: currently unused - the training loop validates every epoch
LOG_EVERY_N_STEPS = 100
SAVE_CHECKPOINT_EVERY_N_STEPS = 300  # Mid-epoch checkpoint every 300 training steps
MAX_GRAD_NORM = 1.0
USE_MIXED_PRECISION = True
SEED = 42

# H5 dataset key
H5_DATASET_KEY = "images"

# Resume from checkpoint: set to "" to start fresh (a path that does not exist
# also results in a fresh start).
# On resume the optimizer and LR-scheduler state - including the learning rate -
# are restored from the checkpoint, so changing LEARNING_RATE here does not
# affect a resumed run.
# BATCH_SIZE=88 / LEARNING_RATE=2e-4 were chosen by linear scaling for dual-GPU training.
RESUME_FROM_CHECKPOINT = "/kaggle/input/12thcheckpoint/checkpoint_epoch_12_step_600.pth"

# Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Multi-GPU toggle (auto-enable if >1 CUDA devices)
GPU_COUNT = torch.cuda.device_count() if torch.cuda.is_available() else 0
USE_DATA_PARALLEL = GPU_COUNT > 1

# Set seeds. cudnn.benchmark is enabled in the environment cell, so runs are
# reproducible only up to cuDNN's (non-deterministic) kernel selection.
torch.manual_seed(SEED)
np.random.seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print("=" * 70)
print("üöÄ CONFIGURATION")
print("=" * 70)
print(f"Device: {device} | CUDA devices: {GPU_COUNT}")
print(f"DataParallel: {'ON' if USE_DATA_PARALLEL else 'OFF'}")
print(f"Image Size: {IMG_SIZE}√ó{IMG_SIZE}")
print(f"Batch Size: {BATCH_SIZE}")
print(f"Epochs: {NUM_EPOCHS}")
print(f"Learning Rate: {LEARNING_RATE}")
print(f"Mixed Precision: {USE_MIXED_PRECISION}")
print(f"Diffusion Steps: {NUM_TRAIN_TIMESTEPS}")
print("=" * 70)

## 🗂️ Step 5: Dataset Class

In [None]:
class ColorDatasetH5(Dataset):
    """
    Dataset of LAB images stored in a single HDF5 dataset.

    Each item is split into the lightness channel ``L`` (the condition) and the
    two chroma channels ``AB`` (the diffusion target). Both are returned as
    float32 tensors in channel-first layout.

    Assumes the stored array is (N, H, W, 3) with channels ordered L, A, B and
    already normalised (expected: 128x128, float16) - TODO confirm against the
    preprocessing script.

    The HDF5 file is opened lazily, at most once per process, and the handle is
    reused by every ``__getitem__`` call instead of reopening the file for each
    sample. The handle is tagged with the owning process id, so DataLoader
    worker processes (forked after this object was created) open their own
    handle rather than reusing an inherited one.
    """
    def __init__(self, h5_path, dataset_key='images'):
        self.h5_path = h5_path
        self.dataset_key = dataset_key

        # Per-process HDF5 handle, populated lazily by _dataset().
        self._file = None
        self._dset = None
        self._pid = None

        # Read metadata with a short-lived handle so the main process holds no
        # open file that forked workers would inherit.
        with h5py.File(h5_path, 'r') as f:
            if dataset_key not in f:
                available_keys = list(f.keys())
                raise KeyError(f"Key '{dataset_key}' not found. Available: {available_keys}")

            self.length = len(f[dataset_key])
            self.shape = f[dataset_key].shape
            self.dtype = f[dataset_key].dtype

        print(f"‚úì Dataset: {h5_path}")
        print(f"  - Images: {self.length:,}")
        print(f"  - Shape: {self.shape}")
        print(f"  - Dtype: {self.dtype}")

    def _dataset(self):
        """Return the HDF5 dataset, opening the file if this process has no handle yet."""
        pid = os.getpid()
        if self._dset is None or self._pid != pid:
            # Never reuse a handle inherited through fork(); open a fresh one.
            self._file = h5py.File(self.h5_path, 'r')
            self._dset = self._file[self.dataset_key]
            self._pid = pid
        return self._dset

    def __getstate__(self):
        # h5py objects cannot be pickled (required when workers use the 'spawn'
        # start method). Drop the handle; each process reopens it on demand.
        state = self.__dict__.copy()
        state['_file'] = None
        state['_dset'] = None
        state['_pid'] = None
        return state

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # Load a single image as float32
        lab = self._dataset()[idx].astype(np.float32)

        # Split LAB channels: L is the condition, AB the target
        L = lab[:, :, 0:1]   # (H, W, 1)
        AB = lab[:, :, 1:3]  # (H, W, 2)

        # (H, W, C) -> (C, H, W) for PyTorch
        L = torch.from_numpy(L).permute(2, 0, 1)   # (1, H, W)
        AB = torch.from_numpy(AB).permute(2, 0, 1) # (2, H, W)

        return L, AB

print("‚úÖ Dataset class defined")

## 📊 Step 6: Create DataLoaders

In [None]:
print("Loading datasets...")

# Build the train and validation datasets (in that order) from the configured H5 files.
train_dataset, val_dataset = (
    ColorDatasetH5(h5_path=split_path, dataset_key=H5_DATASET_KEY)
    for split_path in (TRAIN_H5_PATH, VAL_H5_PATH)
)

# DataLoader seeding: `seed_worker` gives each worker process deterministic
# NumPy and `random` state derived from that worker's torch seed.


def seed_worker(worker_id):
    """Seed NumPy and Python's ``random`` inside a DataLoader worker.

    PyTorch seeds each worker's torch RNG itself (``torch.initial_seed()`` differs
    per worker), but does not seed NumPy or ``random``. This follows the
    reproducibility recipe from the PyTorch docs and derives both from the
    worker's torch seed.

    Args:
        worker_id: Worker index supplied by DataLoader (unused; the torch seed
            is already worker-specific).
    """
    import random  # local import: `random` is not imported at notebook top level

    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

# Seeded generator: the train shuffle order is reproducible for a given SEED
# and generator state.
g = torch.Generator()
g.manual_seed(SEED)

# Settings shared by both loaders
_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=True)

train_loader = DataLoader(
    train_dataset,
    shuffle=True,
    drop_last=True,
    worker_init_fn=seed_worker,
    generator=g,
    **_loader_kwargs,
)
val_loader = DataLoader(val_dataset, shuffle=False, drop_last=False, **_loader_kwargs)
del _loader_kwargs

# Summary (step-time estimate assumes ~15 s per training step)
_rule = "=" * 70
_n_train_batches = len(train_loader)
print("\n".join([
    "\n" + _rule,
    "üìä DATASET INFO",
    _rule,
    f"Train: {len(train_dataset):,} images ({_n_train_batches:,} batches)",
    f"Val:   {len(val_dataset):,} images ({len(val_loader):,} batches)",
    f"Batch size: {BATCH_SIZE}",
    f"Steps per epoch: {_n_train_batches:,}",
    f"Estimated time per epoch: {_n_train_batches * 15 / 3600:.1f} hours",
    _rule,
]))
del _rule, _n_train_batches

## 🧠 Step 7: Define Diffusion Model

In [None]:
class ColorDiffusionModel(nn.Module):
    """Noise-prediction network for colorization, conditioned on the L channel.

    The wrapped UNet receives the lightness channel concatenated (along the
    channel axis) with the noisy AB channels, and its ``.sample`` output is
    taken as the predicted noise. The diffusion scheduler is stored alongside
    so callers can noise clean AB targets with the model's own schedule.
    """

    def __init__(self, unet, noise_scheduler):
        super().__init__()
        self.unet = unet
        self.noise_scheduler = noise_scheduler

    def forward(self, L, AB, timesteps):
        """Predict the noise in ``AB`` given the condition ``L`` at ``timesteps``."""
        conditioned_input = torch.cat((L, AB), dim=1)
        unet_output = self.unet(conditioned_input, timesteps)
        return unet_output.sample

    def add_noise(self, AB, noise, timesteps):
        """Noise clean ``AB`` to ``timesteps`` via the attached scheduler's ``add_noise``."""
        scheduler = self.noise_scheduler
        return scheduler.add_noise(AB, noise, timesteps)


# Build the UNet: three plain down-sampling stages followed by one attention
# stage, mirrored on the way up. Channel widths are kept modest for memory.
print("Initializing UNet model...")
_down_blocks = ("DownBlock2D",) * 3 + ("AttnDownBlock2D",)
_up_blocks = ("AttnUpBlock2D",) + ("UpBlock2D",) * 3
unet = UNet2DModel(
    sample_size=IMG_SIZE,
    in_channels=UNET_IN_CHANNELS,
    out_channels=UNET_OUT_CHANNELS,
    layers_per_block=2,
    block_out_channels=(96, 192, 384, 512),
    down_block_types=_down_blocks,
    up_block_types=_up_blocks,
    attention_head_dim=8,
)
del _down_blocks, _up_blocks

# DDPM scheduler; the model is trained to predict the added noise (epsilon)
noise_scheduler = DDPMScheduler(
    num_train_timesteps=NUM_TRAIN_TIMESTEPS,
    beta_schedule=BETA_SCHEDULE,
    beta_start=BETA_START,
    beta_end=BETA_END,
    prediction_type="epsilon",
)

model = ColorDiffusionModel(unet, noise_scheduler).to(device)

# Trade compute for memory when the UNet supports gradient checkpointing
_enable_checkpointing = getattr(model.unet, 'enable_gradient_checkpointing', None)
if _enable_checkpointing is not None:
    _enable_checkpointing()
    print("‚úì Gradient checkpointing enabled")
del _enable_checkpointing

# Optionally replicate across GPUs
if not USE_DATA_PARALLEL:
    print("Using single GPU or CPU")
else:
    print(f"üèéÔ∏è Enabling DataParallel across {GPU_COUNT} GPUs")
    model = torch.nn.DataParallel(model)

# Parameter statistics are taken from the bare model, not the DataParallel wrapper
model_for_count = getattr(model, 'module', model)
_param_info = [(p.numel(), p.requires_grad) for p in model_for_count.parameters()]
total_params = sum(count for count, _ in _param_info)
trainable_params = sum(count for count, is_trainable in _param_info if is_trainable)
del _param_info

_rule = "=" * 70
print("\n".join([
    "\n" + _rule,
    "üß† MODEL ARCHITECTURE",
    _rule,
    f"Parameters: {total_params:,}",
    f"Trainable: {trainable_params:,}",
    f"Model size (FP32): {total_params * 4 / (1024**2):.1f} MB",
    f"Model size (FP16): {total_params * 2 / (1024**2):.1f} MB",
    _rule,
]))
del _rule

## ⚡ Step 8: Optimizer & Scheduler

In [None]:
from torch.optim.lr_scheduler import CosineAnnealingLR

# AdamW with decoupled weight decay over all model parameters
optimizer = torch.optim.AdamW(
    model.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.999), weight_decay=0.01
)

# Cosine decay from LEARNING_RATE to 1e-6 over NUM_EPOCHS scheduler steps
lr_scheduler = CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS, eta_min=1e-6)

# Loss scaling for fp16 autocast; disabled when mixed precision is off
scaler = GradScaler(enabled=USE_MIXED_PRECISION)

print("\n".join([
    "=" * 70,
    "‚ö° OPTIMIZATION",
    "=" * 70,
    "Optimizer: AdamW",
    f"LR: {LEARNING_RATE}",
    "Scheduler: CosineAnnealingLR",
    f"Mixed Precision: {USE_MIXED_PRECISION}",
    "=" * 70,
]))

## 🏋️ Step 9: Training & Validation Functions

In [None]:
def train_one_epoch(model, train_loader, optimizer, scaler, device, epoch, lr_scheduler, start_step=0):
    """Train for one epoch with mixed precision and periodic mid-epoch checkpoints.

    Args:
        model: ColorDiffusionModel, optionally wrapped in nn.DataParallel.
        train_loader: DataLoader yielding ``(L, AB)`` batches.
        optimizer: Optimizer over the model's parameters.
        scaler: GradScaler used for loss scaling under autocast.
        device: Device the batches are moved to.
        epoch: Zero-based epoch index (used for logging and checkpoint names).
        lr_scheduler: Scheduler whose state is stored in mid-epoch checkpoints
            (it is not stepped here).
        start_step: Number of batches of this epoch already trained in an
            earlier session. These batches are drawn from the loader and
            discarded (they are still loaded from disk).

    Returns:
        Mean training loss over the batches actually trained in this call,
        or 0.0 if none were trained.

    Side effects:
        Keeps the global ``current_training_step`` equal to the number of
        batches of this epoch whose optimizer step has completed, so an
        emergency checkpoint resumes at the first unfinished batch. Writes a
        checkpoint to CHECKPOINT_DIR every SAVE_CHECKPOINT_EVERY_N_STEPS batches.
    """
    import itertools  # local import: not imported at notebook top level

    global current_training_step
    # Reset at epoch start; otherwise an interrupt before this epoch's first
    # step completes would report the previous epoch's final step count.
    current_training_step = start_step

    model.train()
    total_loss = 0.0
    steps_processed = 0  # batches actually trained in this call (skipped ones excluded)

    # Access underlying model for add_noise (DP-safe)
    model_module = model.module if hasattr(model, 'module') else model

    # Skip already-trained batches *before* tqdm sees them, so the bar (which
    # starts at start_step) ends at len(train_loader) instead of overshooting.
    batches = itertools.islice(enumerate(train_loader), start_step, None)
    progress_bar = tqdm(
        batches,
        desc=f"Epoch {epoch+1}/{NUM_EPOCHS} [Train]",
        initial=start_step,
        total=len(train_loader),
    )

    for step, (L, AB) in progress_bar:
        L = L.to(device, non_blocking=True)
        AB = AB.to(device, non_blocking=True)

        # Random timesteps, one per sample
        batch_size = L.shape[0]
        timesteps = torch.randint(
            0, NUM_TRAIN_TIMESTEPS, (batch_size,),
            device=device, dtype=torch.long
        )

        # Add noise (use underlying module for DP compatibility)
        noise = torch.randn_like(AB)
        noisy_AB = model_module.add_noise(AB, noise, timesteps)

        # Forward pass with mixed precision
        with torch.amp.autocast('cuda', enabled=USE_MIXED_PRECISION):
            predicted_noise = model(L, noisy_AB, timesteps)
            loss = F.mse_loss(predicted_noise, noise)

        # Backward pass; unscale first so clipping sees the true gradient norm
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()
        steps_processed += 1
        # This batch's update has been applied; only now count it as done.
        current_training_step = step + 1

        # Update progress bar
        if (step + 1) % LOG_EVERY_N_STEPS == 0:
            avg_loss = total_loss / steps_processed
            progress_bar.set_postfix({
                'loss': f'{avg_loss:.4f}',
                'lr': f'{optimizer.param_groups[0]["lr"]:.6f}'
            })

        # Save mid-epoch checkpoint
        if (step + 1) % SAVE_CHECKPOINT_EVERY_N_STEPS == 0:
            avg_loss = total_loss / steps_processed
            # Save underlying model state (DP-safe)
            model_to_save = model.module if hasattr(model, 'module') else model
            checkpoint = {
                'epoch': epoch,
                'step': step + 1,
                'model_state_dict': model_to_save.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'lr_scheduler_state_dict': lr_scheduler.state_dict(),
                'scaler_state_dict': scaler.state_dict(),
                'train_loss': avg_loss,
                'data_parallel': hasattr(model, 'module'),
            }
            checkpoint_path = os.path.join(CHECKPOINT_DIR, f'checkpoint_epoch_{epoch+1}_step_{step+1}.pth')
            torch.save(checkpoint, checkpoint_path)
            print(f"\nüíæ Mid-epoch checkpoint saved: {os.path.basename(checkpoint_path)} (step {step+1})")

    return total_loss / steps_processed if steps_processed > 0 else 0.0


@torch.no_grad()
def validate(model, val_loader, device, epoch, seed=None):
    """Return the mean noise-prediction MSE over the validation set.

    Timesteps and noise are drawn from a generator that is re-seeded on every
    call. Every epoch is therefore scored on the same (timestep, noise) draws,
    so changes in val loss reflect the model rather than sampling luck. The
    best-model selection relies on this.

    Args:
        model: ColorDiffusionModel, optionally wrapped in nn.DataParallel.
        val_loader: DataLoader yielding ``(L, AB)`` batches.
        device: Device the batches and the sampling generator live on.
        epoch: Zero-based epoch index (progress-bar label only).
        seed: Seed for the validation sampling generator; defaults to the
            global ``SEED``.

    Returns:
        Sample-weighted mean MSE over the validation set, or 0.0 if the loader
        yields no batches.
    """
    if seed is None:
        seed = SEED
    model.eval()

    # Dedicated generator: fixed draws per epoch, and the global RNG stream
    # used by training is left untouched.
    generator = torch.Generator(device=device)
    generator.manual_seed(seed)

    total_loss = 0.0
    total_samples = 0
    progress_bar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS} [Val]")

    # Access underlying model for add_noise (DP-safe)
    model_module = model.module if hasattr(model, 'module') else model

    for L, AB in progress_bar:
        L = L.to(device, non_blocking=True)
        AB = AB.to(device, non_blocking=True)

        batch_size = L.shape[0]
        timesteps = torch.randint(
            0, NUM_TRAIN_TIMESTEPS, (batch_size,),
            device=device, dtype=torch.long, generator=generator
        )

        noise = torch.randn(AB.shape, device=AB.device, dtype=AB.dtype, generator=generator)
        noisy_AB = model_module.add_noise(AB, noise, timesteps)

        with torch.amp.autocast('cuda', enabled=USE_MIXED_PRECISION):
            predicted_noise = model(L, noisy_AB, timesteps)
            loss = F.mse_loss(predicted_noise, noise)

        # Weight by batch size: the last batch can be smaller (drop_last=False).
        # The running average uses our own counters; tqdm's `n` is refreshed
        # only periodically.
        total_loss += loss.item() * batch_size
        total_samples += batch_size
        progress_bar.set_postfix({'loss': f'{total_loss / total_samples:.4f}'})

    return total_loss / total_samples if total_samples else 0.0

print("‚úÖ Training functions defined")

## 💾 Step 10: Checkpoint Management

In [None]:
def _strip_module_prefix(state_dict: dict):
    return { (k[7:] if k.startswith('module.') else k): v for k, v in state_dict.items() }

def _add_module_prefix(state_dict: dict):
    return { (k if k.startswith('module.') else ('module.' + k)): v for k, v in state_dict.items() }



def save_checkpoint(epoch, model, optimizer, lr_scheduler, scaler,
                   train_loss, val_loss, is_best=False):
    """Write an end-of-epoch checkpoint, and also ``best_model.pth`` when *is_best*.

    Weights come from the bare model when *model* is an ``nn.DataParallel``
    wrapper, so saved keys carry no ``'module.'`` prefix; the ``data_parallel``
    entry records whether a wrapper was in use.

    Returns:
        Path of the per-epoch file ``checkpoint_epoch_{epoch+1}.pth`` in CHECKPOINT_DIR.
    """
    is_wrapped = hasattr(model, 'module')
    bare_model = model.module if is_wrapped else model

    state = dict(
        epoch=epoch,
        model_state_dict=bare_model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        lr_scheduler_state_dict=lr_scheduler.state_dict(),
        scaler_state_dict=scaler.state_dict(),
        train_loss=train_loss,
        val_loss=val_loss,
        data_parallel=is_wrapped,
    )

    epoch_path = os.path.join(CHECKPOINT_DIR, f'checkpoint_epoch_{epoch+1}.pth')
    torch.save(state, epoch_path)

    if not is_best:
        return epoch_path

    # Same payload duplicated under a fixed name for easy lookup
    torch.save(state, os.path.join(CHECKPOINT_DIR, 'best_model.pth'))
    print(f"  üíæ Saved best model (val_loss: {val_loss:.4f})")
    return epoch_path


def load_checkpoint(checkpoint_path):
    """Restore model, optimizer, LR-scheduler and GradScaler state from a checkpoint.

    Operates on the notebook-level ``model``, ``optimizer``, ``lr_scheduler``,
    ``scaler`` and ``device`` globals. Weights are always loaded into the bare
    model (never the DataParallel wrapper). The only key mismatch that gets
    repaired is therefore a leading ``'module.'`` prefix from a checkpoint saved
    off a wrapped model.

    Args:
        checkpoint_path: Path of a ``.pth`` file written by this notebook.

    Returns:
        ``(start_epoch, start_step)``. For a checkpoint with a ``'step'`` key
        (mid-epoch), training resumes in the same epoch after ``step`` batches.
        Otherwise (end-of-epoch) it resumes at the start of the next epoch.

    Raises:
        RuntimeError: if the weights do not fit the model, even after stripping
            a ``'module.'`` prefix.
    """
    print(f"Loading checkpoint: {checkpoint_path}")
    # NOTE(security): torch.load unpickles the file - only load checkpoints you trust.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    model_to_load = model.module if hasattr(model, 'module') else model
    state_dict = checkpoint['model_state_dict']
    try:
        model_to_load.load_state_dict(state_dict, strict=True)
    except RuntimeError:
        stripped = _strip_module_prefix(state_dict)
        if stripped.keys() == state_dict.keys():
            # No prefix to strip: the mismatch is genuine, report it unchanged.
            raise
        # If this also fails, its error (chained to the first) lists the real mismatch.
        model_to_load.load_state_dict(stripped, strict=True)
        print("   Adjusted keys: stripped 'module.' prefix")

    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    lr_scheduler.load_state_dict(checkpoint['lr_scheduler_state_dict'])
    scaler.load_state_dict(checkpoint['scaler_state_dict'])

    # Handle both mid-epoch and end-of-epoch checkpoints
    if 'step' in checkpoint:
        # Mid-epoch checkpoint: resume inside the same epoch
        start_epoch = checkpoint['epoch']
        start_step = checkpoint['step']
        print(f"  ‚úì Resumed from epoch {checkpoint['epoch'] + 1}, step {start_step}")
        print(f"  Train loss: {checkpoint['train_loss']:.4f}")
        return start_epoch, start_step

    # End-of-epoch checkpoint: resume at the next epoch
    start_epoch = checkpoint['epoch'] + 1
    print(f"  ‚úì Resumed from epoch {checkpoint['epoch'] + 1}")
    print(f"  Train loss: {checkpoint['train_loss']:.4f}")
    print(f"  Val loss: {checkpoint['val_loss']:.4f}")
    return start_epoch, 0

print("‚úÖ Checkpoint functions defined")

## 📈 Step 11: Training History Plotting

In [None]:

class TrainingHistory:
    """Track per-epoch metrics, and plot / persist them to LOG_DIR.

    ``best_val_loss`` / ``best_epoch`` track the lowest validation loss seen.
    Ties keep the earlier epoch (strict ``<`` comparison).
    """

    def __init__(self):
        self.epochs = []
        self.train_losses = []
        self.val_losses = []
        self.learning_rates = []
        self.epoch_times = []          # seconds per epoch
        self.best_val_loss = float('inf')
        self.best_epoch = 0

    def update(self, epoch, train_loss, val_loss, lr, epoch_time):
        """Record one epoch's metrics.

        Args:
            epoch: Epoch label to record (the training loop passes 1-based numbers).
            train_loss: Mean training loss for the epoch.
            val_loss: Validation loss for the epoch.
            lr: Learning rate after the epoch's scheduler step.
            epoch_time: Wall-clock duration of the epoch in seconds.

        Returns:
            True if ``val_loss`` is a new best (strictly lower than any before).
        """
        self.epochs.append(epoch)
        self.train_losses.append(train_loss)
        self.val_losses.append(val_loss)
        self.learning_rates.append(lr)
        self.epoch_times.append(epoch_time)

        is_best = val_loss < self.best_val_loss
        if is_best:
            self.best_val_loss = val_loss
            self.best_epoch = epoch

        return is_best

    def plot(self):
        """Plot loss, learning-rate and epoch-time curves, and save them as training_curves.png.

        The figure is closed after display. This method is called once per
        epoch, and unclosed figures would otherwise accumulate (matplotlib warns
        beyond 20 open figures on non-inline backends).
        """
        fig, axes = plt.subplots(1, 3, figsize=(18, 5))

        # Loss curves
        axes[0].plot(self.epochs, self.train_losses, 'o-', label='Train Loss', linewidth=2)
        axes[0].plot(self.epochs, self.val_losses, 's-', label='Val Loss', linewidth=2)
        axes[0].axvline(x=self.best_epoch, color='r', linestyle='--',
                       label=f'Best (epoch {self.best_epoch})', alpha=0.7)
        axes[0].set_xlabel('Epoch', fontsize=12, fontweight='bold')
        axes[0].set_ylabel('Loss (MSE)', fontsize=12, fontweight='bold')
        axes[0].set_title('üìâ Training & Validation Loss', fontsize=14, fontweight='bold')
        axes[0].legend()
        axes[0].grid(True, alpha=0.3)

        # Learning rate (log scale: cosine decay spans orders of magnitude)
        axes[1].plot(self.epochs, self.learning_rates, 'd-', color='green', linewidth=2)
        axes[1].set_xlabel('Epoch', fontsize=12, fontweight='bold')
        axes[1].set_ylabel('Learning Rate', fontsize=12, fontweight='bold')
        axes[1].set_title('üìä Learning Rate Schedule', fontsize=14, fontweight='bold')
        axes[1].set_yscale('log')
        axes[1].grid(True, alpha=0.3)

        # Epoch times, in hours
        epoch_times_hours = [t / 3600 for t in self.epoch_times]
        axes[2].plot(self.epochs, epoch_times_hours, '^-', color='purple', linewidth=2)
        axes[2].set_xlabel('Epoch', fontsize=12, fontweight='bold')
        axes[2].set_ylabel('Time (hours)', fontsize=12, fontweight='bold')
        axes[2].set_title('‚è±Ô∏è Time per Epoch', fontsize=14, fontweight='bold')
        axes[2].grid(True, alpha=0.3)

        if epoch_times_hours:
            avg_time = np.mean(epoch_times_hours)
            axes[2].axhline(y=avg_time, color='red', linestyle='--', alpha=0.5)
            axes[2].text(0.5, 0.95, f'Avg: {avg_time:.1f}h',
                        transform=axes[2].transAxes, ha='center', va='top',
                        bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

        plt.tight_layout()
        plt.savefig(os.path.join(LOG_DIR, 'training_curves.png'), dpi=150, bbox_inches='tight')
        plt.show()
        plt.close(fig)  # release the figure; this method runs every epoch

    def save(self):
        """Write the recorded history to LOG_DIR/training_history.json (overwriting it)."""
        history_dict = {
            'epochs': self.epochs,
            'train_losses': self.train_losses,
            'val_losses': self.val_losses,
            'learning_rates': self.learning_rates,
            'epoch_times': self.epoch_times,
            'best_val_loss': float(self.best_val_loss),
            'best_epoch': int(self.best_epoch)
        }

        history_path = os.path.join(LOG_DIR, 'training_history.json')
        with open(history_path, 'w') as f:
            json.dump(history_dict, f, indent=4)

history = TrainingHistory()
print("‚úÖ Training history initialized")

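## 🔄 Step 12: Resume from Checkpoint (if needed)

**💡 Two types of checkpoints:**
1. **End-of-epoch**: `checkpoint_epoch_X.pth` (saved after each complete epoch)
2. **Mid-epoch**: `checkpoint_epoch_X_step_Y.pth` (saved every `SAVE_CHECKPOINT_EVERY_N_STEPS` training steps)

An interrupted run (`KeyboardInterrupt`) additionally writes an `emergency_checkpoint*.pth` file, which can be resumed the same way.

**📝 To resume from checkpoint:**
- Find the latest checkpoint file (`.pth`)
- Set `RESUME_FROM_CHECKPOINT` in Step 4 (Configuration) to its full path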


In [None]:
start_epoch = 0
start_step = 0

if RESUME_FROM_CHECKPOINT and os.path.exists(RESUME_FROM_CHECKPOINT):
    start_epoch, start_step = load_checkpoint(RESUME_FROM_CHECKPOINT)
    print(f"\n‚úÖ Resuming from epoch {start_epoch + 1}, step {start_step}")
else:
    if RESUME_FROM_CHECKPOINT:
        # A configured but missing path (typo, dataset not attached) used to fall
        # through silently to a fresh run; make that visible before hours of training.
        print(f"\nWARNING: RESUME_FROM_CHECKPOINT not found: {RESUME_FROM_CHECKPOINT!r}")
    print(f"\n‚úÖ Starting fresh training")

print(f"üìÖ Training epochs: {start_epoch + 1} to {NUM_EPOCHS}")

# Global step counter read by the emergency-checkpoint handler; when resuming it
# must start from the steps already completed in this epoch.
current_training_step = start_step

## 🚀 Step 13: MAIN TRAINING LOOP





In [None]:
print("\n" + "=" * 70)
print("üöÄ STARTING TRAINING")
print("=" * 70)
print(f"Training from epoch {start_epoch + 1} to {NUM_EPOCHS}")
print(f"Total samples: {len(train_dataset):,} train, {len(val_dataset):,} val")
print(f"Steps per epoch: {len(train_loader):,}")
# Rough estimate (~15 s per step) over the epochs that remain to be run
print(f"Estimated time: {len(train_loader) * 15 * (NUM_EPOCHS - start_epoch) / 3600:.1f} hours total")
print("=" * 70)
print()

training_start = time.time()

# State read by the KeyboardInterrupt handler. Initialised explicitly so that a
# re-run of this cell never picks up stale values from an earlier run.
epoch = start_epoch
epoch_done = False   # True once the current epoch's train + val + LR step have completed
train_loss = 0.0     # placeholders stored if interrupted before any epoch finishes
val_loss = 0.0
interrupted = False

try:
    for epoch in range(start_epoch, NUM_EPOCHS):
        epoch_done = False
        epoch_start = time.time()

        # Determine if we're resuming mid-epoch
        current_start_step = start_step if epoch == start_epoch else 0
        # Sync the emergency counter with this epoch before any batch is loaded;
        # it would otherwise still hold the previous epoch's final step count.
        current_training_step = current_start_step

        # Seed the shuffle per epoch so that epoch N's batch order depends only
        # on (SEED, N). Otherwise it depends on how many epochs the generator has
        # produced in this session, and after a restart the batches skipped for
        # a mid-epoch resume would not be the ones that were actually trained.
        # NOTE: checkpoints written before this change used the unseeded order.
        if train_loader.generator is not None:
            train_loader.generator.manual_seed(SEED + epoch)

        # Train (skips already-processed batches when resuming)
        train_loss = train_one_epoch(model, train_loader, optimizer, scaler, device, epoch, lr_scheduler, current_start_step)

        # Only the first epoch of this session can be a partial one
        if epoch == start_epoch:
            start_step = 0

        # Validate
        val_loss = validate(model, val_loader, device, epoch)

        # Update learning rate
        lr_scheduler.step()
        current_lr = optimizer.param_groups[0]['lr']
        # From here on the epoch counts as finished: an interrupt must resume at
        # the next epoch rather than replay this one (which would step the LR
        # scheduler a second time).
        epoch_done = True

        # Calculate epoch time
        epoch_time = time.time() - epoch_start

        # Update history
        is_best = history.update(epoch + 1, train_loss, val_loss, current_lr, epoch_time)

        # Print summary
        print(f"\n{'='*70}")
        print(f"üìä Epoch {epoch + 1}/{NUM_EPOCHS} Summary")
        print(f"{'='*70}")
        print(f"  Train Loss: {train_loss:.4f}")
        print(f"  Val Loss:   {val_loss:.4f}")
        print(f"  LR:         {current_lr:.6f}")
        print(f"  Time:       {epoch_time/3600:.2f} hours")

        if is_best:
            print(f"  üåü NEW BEST MODEL!")

        # Estimate remaining time
        if history.epoch_times:
            avg_time = np.mean(history.epoch_times)
            remaining = (NUM_EPOCHS - epoch - 1) * avg_time
            print(f"  ‚è±Ô∏è ETA: {remaining/3600:.1f} hours ({NUM_EPOCHS - epoch - 1} epochs left)")

        print(f"{'='*70}\n")

        # Save checkpoint
        checkpoint_path = save_checkpoint(
            epoch, model, optimizer, lr_scheduler, scaler,
            train_loss, val_loss, is_best=is_best
        )
        print(f"  üíæ Checkpoint saved: {os.path.basename(checkpoint_path)}\n")

        # Save history
        history.save()

        # Plot every epoch
        history.plot()

except KeyboardInterrupt:
    interrupted = True
    print("\n‚ö†Ô∏è Training interrupted!")

    # Losses of the last finished epoch (0.0 placeholders if none finished yet)
    emergency_train_loss = train_loss
    emergency_val_loss = val_loss
    emergency_epoch = epoch

    # Finished epoch (LR already stepped): omit 'step' so that load_checkpoint
    # resumes at the NEXT epoch. Otherwise always record how many batches of
    # this epoch were trained - including 0 - so the resume re-enters this epoch
    # instead of skipping it.
    emergency_step = None if epoch_done else current_training_step

    # Save emergency checkpoint with whatever state we have (DP-safe)
    model_to_save = model.module if hasattr(model, 'module') else model
    checkpoint = {
        'epoch': emergency_epoch,
        'model_state_dict': model_to_save.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'lr_scheduler_state_dict': lr_scheduler.state_dict(),
        'scaler_state_dict': scaler.state_dict(),
        'train_loss': emergency_train_loss,
        'val_loss': emergency_val_loss,
        'interrupted': True,  # Flag to indicate this was an emergency save
        'data_parallel': hasattr(model, 'module'),
    }

    if emergency_step is not None:
        # Mid-epoch interrupt
        checkpoint['step'] = emergency_step
        emergency_filename = f'emergency_checkpoint_epoch_{emergency_epoch+1}_step_{emergency_step}.pth'
    else:
        # Interrupted after the epoch finished (e.g. while saving or plotting)
        emergency_filename = 'emergency_checkpoint.pth'

    emergency_path = os.path.join(CHECKPOINT_DIR, emergency_filename)
    torch.save(checkpoint, emergency_path)
    print(f"üíæ Emergency checkpoint saved: {emergency_path}")
    print(f"üìç Saved at: Epoch {checkpoint['epoch'] + 1}" +
          (f", Step {emergency_step}" if emergency_step is not None else ""))
    print(f"üîÑ Resume with: RESUME_FROM_CHECKPOINT = '{emergency_path}'")

# Final statistics
training_time = time.time() - training_start

print("\n" + "=" * 70)
print("‚ö†Ô∏è TRAINING INTERRUPTED" if interrupted else "üéâ TRAINING COMPLETE!")
print("=" * 70)
print(f"Total time: {training_time/3600:.2f} hours")
print(f"Best val loss: {history.best_val_loss:.4f} (epoch {history.best_epoch})")
print(f"Best model: {os.path.join(CHECKPOINT_DIR, 'best_model.pth')}")
print("=" * 70)

## 📊 Step 14: Final Training Curves

In [None]:
# Plot final curves and print summary statistics. Guarded because the history
# is empty if training was interrupted before any epoch finished (indexing
# train_losses[-1] would raise IndexError).
if history.epochs:
    history.plot()

    print("\n" + "=" * 70)
    print("üìà TRAINING SUMMARY")
    print("=" * 70)
    print(f"Best epoch: {history.best_epoch}")
    print(f"Best val loss: {history.best_val_loss:.4f}")
    print(f"Final train loss: {history.train_losses[-1]:.4f}")
    print(f"Final val loss: {history.val_losses[-1]:.4f}")
    print(f"Average epoch time: {np.mean(history.epoch_times)/3600:.2f} hours")
    print(f"Total training time: {sum(history.epoch_times)/3600:.2f} hours")
    print("=" * 70)
else:
    print("No epochs completed in this session - nothing to summarise.")

## 📄 Step 15: Save Training Summary

In [None]:
# Create comprehensive summary
summary = {
    'completed': str(datetime.now()),
    'resolution': f'{IMG_SIZE}√ó{IMG_SIZE}',
    'total_epochs': NUM_EPOCHS,
    'best_epoch': int(history.best_epoch),
    'best_val_loss': float(history.best_val_loss),
    'final_train_loss': float(history.train_losses[-1]),
    'final_val_loss': float(history.val_losses[-1]),
    'total_training_hours': sum(history.epoch_times) / 3600,
    'avg_epoch_hours': np.mean(history.epoch_times) / 3600,
    'model_parameters': total_params,
    'config': {
        'batch_size': BATCH_SIZE,
        'learning_rate': LEARNING_RATE,
        'num_timesteps': NUM_TRAIN_TIMESTEPS,
        'mixed_precision': USE_MIXED_PRECISION,
        'train_samples': len(train_dataset),
        'val_samples': len(val_dataset),
    }
}

summary_path = os.path.join(LOG_DIR, 'training_summary.json')
with open(summary_path, 'w') as f:
    json.dump(summary, f, indent=4)

print(f"‚úÖ Summary saved: {summary_path}")

print("\n" + "=" * 70)
print("üéâ ALL DONE!")
print("=" * 70)
print(f"\nüìÅ Output Files:")
print(f"  - Best model: {os.path.join(CHECKPOINT_DIR, 'best_model.pth')}")
print(f"  - Checkpoints: {CHECKPOINT_DIR}")
print(f"  - Training curves: {os.path.join(LOG_DIR, 'training_curves.png')}")
print(f"  - History: {os.path.join(LOG_DIR, 'training_history.json')}")
print(f"  - Summary: {summary_path}")
print("\nüí° Next: Use the best model for inference on new images!")
print("=" * 70)