# ResNet Autoencoder Training

Train ResNet autoencoder at 4x and 8x compression ratios to complete the rate-distortion curve.

**Note:** 16x already trained in sweep_all_16x.ipynb (21.13 dB PSNR, 0.739 SSIM)

**Proven hyperparameters** (from sweep_all_16x):
- LR=1e-4, AdamW, ReduceLROnPlateau
- base_channels=64 (22.4M params)
- 35 epochs, patience=12

**Monitor with TensorBoard:**
```bash
tensorboard --logdir=runs
```

## 1. Setup and Imports

In [None]:
import sys
import gc
import time
import json
import random
from pathlib import Path
from datetime import datetime

# Put the parent of the working directory at the front of the import path
# (once only) so the `src.*` project imports below can be resolved.
project_root = Path.cwd().parent
_root_entry = str(project_root)
if _root_entry not in sys.path:
    sys.path.insert(0, _root_entry)

import torch
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

# Project imports
from src.data.datamodule import SARDataModule
from src.models.resnet_autoencoder import ResNetAutoencoder
from src.losses.combined import CombinedLoss
from src.training.trainer import Trainer

# Environment report: import root, framework version and visible accelerator.
cuda_ready = torch.cuda.is_available()

print(f"Project root: {project_root}")
print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {cuda_ready}")

if cuda_ready:
    # Device 0 is the only GPU inspected here.
    gpu_name = torch.cuda.get_device_name(0)
    vram_gib = torch.cuda.get_device_properties(0).total_memory / 1024**3
    print(f"GPU: {gpu_name}")
    print(f"VRAM: {vram_gib:.1f} GB")

## 2. Configuration

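**Training Modes** (selected with the `MULTI_RATIO` flag):
- `MULTI_RATIO = False`: Train at one specific compression ratio, given by `TARGET_RATIO` (default: 8x)
- `MULTI_RATIO = True`: Train at 4x and 8x to extend the rate-distortion curve (16x was already trained in sweep_all_16x)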

In [None]:
# ============================================================
# CONFIGURATION - ResNet Autoencoder (Proven Parameters)
# ============================================================
# All tunables for this notebook live in this cell; later cells only read them.

# =========================
# TRAINING MODE
# =========================
MULTI_RATIO = True   # True = train 4x and 8x | False = train TARGET_RATIO only
TARGET_RATIO = 8     # Compression ratio used when MULTI_RATIO=False

# Data settings
# NOTE(review): absolute, machine-specific path - adjust when running elsewhere
DATA_PATH = "D:/Projects/CNNAutoencoderProject/data/patches/metadata.npy"
BATCH_SIZE = 16       # ResNet b=64 needs smaller batch (OOM at 32+)
NUM_WORKERS = 4       # Parallel data loading
VAL_FRACTION = 0.1    # 10% validation split
TRAIN_SUBSET = 0.10   # Use 10% of data for faster iteration (applied to train AND val splits)

# Model settings - PROVEN from sweep_all_16x
BASE_CHANNELS = 64    # Full ResNet (22.4M params) - best results

# Loss settings (weights passed to CombinedLoss)
MSE_WEIGHT = 0.5
SSIM_WEIGHT = 0.5

# Training settings - PROVEN from sweep_all_16x
EPOCHS = 35
LEARNING_RATE = 1e-4          # Proven optimal
OPTIMIZER = 'adamw'           # AdamW > Adam
SCHEDULER = 'plateau'         # ReduceLROnPlateau
EARLY_STOPPING_PATIENCE = 12  # Passed to Trainer.train(early_stopping_patience=...)
LR_PATIENCE = 10              # Passed to the trainer config as 'lr_patience'
LR_FACTOR = 0.5               # Passed to the trainer config as 'lr_factor'
MAX_GRAD_NORM = 1.0           # Passed to the trainer config as 'max_grad_norm'
USE_AMP = True                # Mixed precision

# Compression ratios to train (16x already done in sweep_all_16x).
# Each entry pairs a compression ratio with the latent channel count that
# realises it: 64 -> 4x, 32 -> 8x, 16 -> 16x.
RATIO_TO_LATENT_CHANNELS = {4: 64, 8: 32, 16: 16}

if MULTI_RATIO:
    SWEEP_CONFIGS = [
        {'latent_channels': 64, 'ratio': 4},
        {'latent_channels': 32, 'ratio': 8},
    ]
else:
    # Single ratio training.
    # Fail loudly on an unsupported ratio: silently falling back to some
    # default channel count would train one model and label it (run name,
    # results, plots) as a different compression ratio.
    if TARGET_RATIO not in RATIO_TO_LATENT_CHANNELS:
        raise ValueError(
            f"Unsupported TARGET_RATIO={TARGET_RATIO!r}; "
            f"choose one of {sorted(RATIO_TO_LATENT_CHANNELS)}"
        )
    lc = RATIO_TO_LATENT_CHANNELS[TARGET_RATIO]
    SWEEP_CONFIGS = [{'latent_channels': lc, 'ratio': TARGET_RATIO}]

# Output folder for plots and JSON summaries (created next to the notebook).
RESULTS_DIR = Path('results')
RESULTS_DIR.mkdir(exist_ok=True)

# Human-readable recap of the configuration chosen above.
banner = '=' * 60
if MULTI_RATIO:
    mode_label = 'MULTI-RATIO (4x, 8x)'
else:
    mode_label = f'SINGLE RATIO ({TARGET_RATIO}x)'

print(banner)
print(f"MODE: {mode_label}")
print(banner)
print(f"Model: ResNetAutoencoder (base_channels={BASE_CHANNELS})")
print(f"Training: {TRAIN_SUBSET*100:.0f}% data, {EPOCHS} epochs, batch={BATCH_SIZE}")
print(f"Optimizer: {OPTIMIZER.upper()}, LR={LEARNING_RATE}, scheduler={SCHEDULER}")
print(f"Loss: {MSE_WEIGHT} MSE + {SSIM_WEIGHT} SSIM")

# One line per planned run, using the same naming scheme as the training cell.
print(f"\nRuns planned:")
for sweep_cfg in SWEEP_CONFIGS:
    planned_channels = sweep_cfg['latent_channels']
    planned_ratio = sweep_cfg['ratio']
    print(f"  resnet_c{planned_channels}_b{BASE_CHANNELS}_cr{planned_ratio}x")
print(f"\nNote: 16x already trained in sweep_all_16x.ipynb")

MODE: MULTI-RATIO (4x, 8x)
Model: ResNetAutoencoder (base_channels=64)
Training: 10% data, 35 epochs, batch=16
Optimizer: ADAMW, LR=0.0001, scheduler=plateau
Loss: 0.5 MSE + 0.5 SSIM

Runs planned:
  resnet_c64_b64_cr4x
  resnet_c32_b64_cr8x

Note: 16x already trained in sweep_all_16x.ipynb


## 3. Load Data

In [None]:
# Build the data module, then optionally shrink both splits to TRAIN_SUBSET.
print("Loading data...")
dm = SARDataModule(
    patches_path=DATA_PATH,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    val_fraction=VAL_FRACTION,
)

# Fixed seed for the subset draw. With an unseeded draw every kernel session
# picks different patches, so metrics re-computed later for an existing
# checkpoint would be measured on a different validation subset than before.
SUBSET_SEED = 42

if TRAIN_SUBSET < 1.0:
    # Dedicated generator: the global `random` state is left untouched.
    # NOTE(review): reproducibility also assumes SARDataModule builds its
    # train/val split deterministically - confirm.
    subset_rng = random.Random(SUBSET_SEED)

    full_train_size = len(dm.train_dataset)
    train_subset_size = int(full_train_size * TRAIN_SUBSET)
    train_indices = subset_rng.sample(range(full_train_size), train_subset_size)
    dm.train_dataset = torch.utils.data.Subset(dm.train_dataset, train_indices)

    # The same fraction is applied to the validation split.
    full_val_size = len(dm.val_dataset)
    val_subset_size = int(full_val_size * TRAIN_SUBSET)
    val_indices = subset_rng.sample(range(full_val_size), val_subset_size)
    dm.val_dataset = torch.utils.data.Subset(dm.val_dataset, val_indices)

    print(f"Using {TRAIN_SUBSET*100:.0f}% subset:")
    print(f"  Train: {train_subset_size:,} of {full_train_size:,}")
    print(f"  Val: {val_subset_size:,} of {full_val_size:,}")

print(f"\nPreprocessing params: {dm.preprocessing_params}")

Loading data...
Loading metadata from D:\Projects\CNNAutoencoderProject\data\patches\metadata.npy
Total patches: 696277
Train: 626650, Val: 69627
Using 10% subset:
  Train: 62,665 of 626,650
  Val: 6,962 of 69,627

Preprocessing params: {'vmin': np.float32(14.768799), 'vmax': np.float32(24.54073)}


In [None]:
# Sanity check: pull one training batch and report its shape, dtype and range.
# `sample_batch` is reused by the optional LR-finder cell.
sample_batch = next(iter(dm.train_dataloader()))
batch_min = sample_batch.min()
batch_max = sample_batch.max()
for field in ("shape", "dtype"):
    print(f"Sample batch {field}: {getattr(sample_batch, field)}")
print(f"Sample batch range: [{batch_min:.4f}, {batch_max:.4f}]")

## 3b. Find Optimal Learning Rate (Optional)

Run a learning rate range test to find the optimal LR. This sweeps from 1e-7 to 1 and plots loss vs LR.

**Skip this if using proven LR (1e-4 from sweep_all_16x).**

In [None]:
# LR Finder - Run this to find optimal learning rate for a new architecture
# Skip if using proven LR (1e-4)

RUN_LR_FINDER = False  # Set to True to run LR finder

if RUN_LR_FINDER:
    # The progress bar is the `tqdm` already imported from `tqdm.auto` in the
    # setup cell. It is deliberately NOT re-imported from plain `tqdm` here:
    # that would rebind the notebook-wide name and change the progress bars
    # used by every later cell.

    def find_lr_minimal(model, sample_batch, loss_fn, device='cuda',
                        start_lr=1e-7, end_lr=1, num_iter=100):
        """
        Minimal LR range test. Sweeps LR and records smoothed loss.

        The same batch is fed to the model at every step while the learning
        rate grows geometrically from ``start_lr`` towards ``end_lr``. The
        model weights are snapshotted first and restored before returning.

        Args:
            model: Module called as ``model(batch)`` and returning a pair
                whose first element is the reconstruction.
            sample_batch: Tensor used both as input and as target.
            loss_fn: Callable ``loss_fn(output, target)`` returning a pair
                whose first element is the scalar loss.
            device: Device the model and batch are moved to.
            start_lr: Learning rate of the first step.
            end_lr: Learning rate the sweep would reach after ``num_iter`` steps.
            num_iter: Maximum number of optimisation steps.

        Returns:
            (lrs, losses): NumPy arrays with the learning rate of each recorded
            step and the exponentially smoothed loss at that step.
        """
        print("Saving initial state...")
        initial_state = {k: v.clone() for k, v in model.state_dict().items()}

        print("Moving model to device...")
        model.train()
        model.to(device)

        print("Moving batch to device...")
        batch = sample_batch.to(device)
        print(f"Batch on device: {batch.device}, shape: {batch.shape}")

        print("Creating optimizer...")
        optimizer = torch.optim.AdamW(model.parameters(), lr=start_lr, weight_decay=1e-4)
        # Per-step multiplier that takes start_lr to end_lr in num_iter steps.
        gamma = (end_lr / start_lr) ** (1 / num_iter)

        lrs, losses = [], []
        smoothed_loss = None
        best_loss = float('inf')

        print("Starting LR sweep...")
        for i in tqdm(range(num_iter), desc="LR sweep"):
            optimizer.zero_grad()
            output, _ = model(batch)
            loss, _ = loss_fn(output, batch)

            # Stop once the loss diverges: NaN, or raw loss above 4x the best
            # smoothed loss seen so far.
            if torch.isnan(loss) or loss.item() > 4 * best_loss:
                print(f"\nStopped at iter {i}: lr={optimizer.param_groups[0]['lr']:.2e}")
                break

            lrs.append(optimizer.param_groups[0]['lr'])

            # Exponential moving average (10% new value, 90% history).
            if smoothed_loss is None:
                smoothed_loss = loss.item()
            else:
                smoothed_loss = 0.1 * loss.item() + 0.9 * smoothed_loss
            losses.append(smoothed_loss)
            best_loss = min(best_loss, smoothed_loss)

            loss.backward()
            optimizer.step()

            for pg in optimizer.param_groups:
                pg['lr'] *= gamma

        print("Restoring initial state...")
        model.load_state_dict(initial_state)
        return np.array(lrs), np.array(losses)

    # Create fresh model for LR finder
    lr_model = ResNetAutoencoder(
        latent_channels=SWEEP_CONFIGS[0]['latent_channels'],
        base_channels=BASE_CHANNELS,
        in_channels=1,
    )
    lr_loss_fn = CombinedLoss(mse_weight=MSE_WEIGHT, ssim_weight=SSIM_WEIGHT)

    # Use the GPU when there is one instead of assuming CUDA is present.
    lr_device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Run LR finder with small batch
    lrs, losses = find_lr_minimal(
        model=lr_model,
        sample_batch=sample_batch[:4],  # Just 4 samples
        loss_fn=lr_loss_fn,
        device=lr_device,
        num_iter=100
    )

    # Plot results
    plt.figure(figsize=(10, 6))
    plt.semilogx(lrs, losses)
    plt.xlabel('Learning Rate')
    plt.ylabel('Loss (smoothed)')
    plt.title('LR Range Test - ResNet Autoencoder')
    plt.axvline(x=1e-4, color='r', linestyle='--', label='1e-4 (proven)')
    plt.axvline(x=3e-4, color='g', linestyle='--', label='3e-4')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()

    # Find suggested LR (steepest descent); needs enough points for the
    # finite-difference gradient to be meaningful.
    if len(losses) > 10:
        gradients = np.gradient(losses)
        min_grad_idx = np.argmin(gradients)
        suggested_lr = lrs[min_grad_idx]
        print(f"\nSuggested LR (steepest descent): {suggested_lr:.2e}")
        print(f"Proven LR from sweep: 1e-4")

    del lr_model, lr_loss_fn
    gc.collect()
    torch.cuda.empty_cache()
else:
    print("Skipping LR finder - using proven LR=1e-4 from sweep_all_16x")

## 4. Train Model(s)

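Trains one model per configured compression ratio. If a `best.pth` checkpoint already exists for a run, training is skipped and the checkpoint is re-validated on the current validation set to recover its metrics.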

In [None]:
import glob

# Train (or re-validate) one model per entry in SWEEP_CONFIGS.
# `results` collects one summary dict per run; every later cell reads it.
results = []
trainers = {}  # Store trainers for later analysis
# NOTE(review): each stored trainer presumably keeps its model (and data
# loaders) referenced, so the `del model` below may not actually free GPU
# memory while `trainers` holds it - confirm against Trainer.

# Device used when re-validating an existing checkpoint.
eval_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

for cfg in tqdm(SWEEP_CONFIGS, desc="Training Progress", unit="run"):
    lc = cfg['latent_channels']
    ratio = cfg['ratio']
    run_name = f"resnet_c{lc}_b{BASE_CHANNELS}_cr{ratio}x"

    # Check for existing checkpoint
    existing = glob.glob(f'checkpoints/{run_name}_*/best.pth')
    if existing:
        # Run directories end in a YYYYMMDD_HHMMSS stamp, so the
        # lexicographically last path is the most recent run.
        checkpoint_path = sorted(existing)[-1]
        print(f"\n{'='*70}")
        print(f"  SKIPPING {run_name} - checkpoint exists")
        print(f"  {checkpoint_path}")
        print(f"{'='*70}")

        # Load checkpoint to extract metrics.
        # map_location='cpu': without it the tensors are restored onto the
        # device they were saved from (normally the GPU) and would occupy
        # VRAM next to the model for as long as `ckpt` is alive.
        # NOTE: weights_only=False unpickles arbitrary objects - only load
        # checkpoints produced by this project.
        ckpt = torch.load(checkpoint_path, map_location='cpu', weights_only=False)

        # Quick validation to get current metrics
        model = ResNetAutoencoder(latent_channels=lc, base_channels=BASE_CHANNELS, in_channels=1)
        model.load_state_dict(ckpt['model_state_dict'])
        model.eval().to(eval_device)
        params = model.count_parameters()
        epochs_trained = ckpt.get('epoch', 0) + 1
        del ckpt  # only the epoch counter is needed from here on

        loss_fn = CombinedLoss(mse_weight=MSE_WEIGHT, ssim_weight=SSIM_WEIGHT)
        val_loader = dm.val_dataloader()

        val_losses, val_psnrs, val_ssims = [], [], []
        with torch.no_grad():
            for batch in val_loader:
                batch = batch.to(eval_device)
                output, _ = model(batch)
                loss, metrics = loss_fn(output, batch)
                val_losses.append(loss.item())
                val_psnrs.append(metrics['psnr'])
                val_ssims.append(metrics['ssim'])

        # Metrics are plain (unweighted) means over validation batches.
        result = {
            'run_name': Path(checkpoint_path).parent.name,
            'latent_channels': lc,
            'compression_ratio': ratio,
            'parameters': params['total'],
            'epochs_trained': epochs_trained,
            'checkpoint': checkpoint_path,
            'best_val_loss': sum(val_losses) / len(val_losses),
            'best_psnr': sum(val_psnrs) / len(val_psnrs),
            'best_ssim': sum(val_ssims) / len(val_ssims),
            'skipped': True,
        }
        results.append(result)
        print(f"  Validated: PSNR={result['best_psnr']:.2f} dB, SSIM={result['best_ssim']:.4f}")

        del model, loss_fn, val_loader
        gc.collect()
        torch.cuda.empty_cache()
        continue

    print(f"\n{'='*70}")
    print(f"  {run_name} | {ratio}x compression")
    print(f"{'='*70}")

    # Create model
    model = ResNetAutoencoder(
        latent_channels=lc,
        base_channels=BASE_CHANNELS,
        in_channels=1,
    )
    params = model.count_parameters()
    print(f"  Parameters: {params['total']:,}")

    # Loss
    loss_fn = CombinedLoss(mse_weight=MSE_WEIGHT, ssim_weight=SSIM_WEIGHT)

    # Trainer config: optimisation settings plus metadata describing the run
    config = {
        'learning_rate': LEARNING_RATE,
        'optimizer': OPTIMIZER,
        'scheduler': SCHEDULER,
        'lr_patience': LR_PATIENCE,
        'lr_factor': LR_FACTOR,
        'max_grad_norm': MAX_GRAD_NORM,
        'use_amp': USE_AMP,
        'notebook': True,
        'run_name': run_name,
        'preprocessing_params': dm.preprocessing_params,
        'model_type': 'resnet',
        'latent_channels': lc,
        'base_channels': BASE_CHANNELS,
        'mse_weight': MSE_WEIGHT,
        'ssim_weight': SSIM_WEIGHT,
        'batch_size': BATCH_SIZE,
        'compression_ratio': ratio,
    }

    train_loader = dm.train_dataloader()
    val_loader = dm.val_dataloader()

    trainer = Trainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        loss_fn=loss_fn,
        config=config,
    )

    # Train
    t0 = time.time()
    history = trainer.train(
        epochs=EPOCHS,
        early_stopping_patience=EARLY_STOPPING_PATIENCE,
    )
    elapsed = time.time() - t0

    # Collect results
    result = {
        'run_name': trainer.run_name,
        'latent_channels': lc,
        'compression_ratio': ratio,
        'parameters': params['total'],
        'epochs_trained': len(history),
        'elapsed_min': elapsed / 60,
        'checkpoint': str(trainer.checkpoint_dir / 'best.pth'),
        'log_dir': str(trainer.log_dir),
        'skipped': False,
    }

    # The reported metrics come from the epoch with the lowest validation loss.
    if history:
        best_epoch = min(history, key=lambda h: h.get('val_loss', float('inf')))
        result['best_val_loss'] = best_epoch.get('val_loss')
        result['best_psnr'] = best_epoch.get('val_psnr')
        result['best_ssim'] = best_epoch.get('val_ssim')
        result['history'] = history

    results.append(result)
    trainers[run_name] = trainer

    psnr_str = f"{result.get('best_psnr', 0):.2f} dB" if result.get('best_psnr') else "N/A"
    print(f"\n  Done: {psnr_str} | {elapsed/60:.1f} min")

    # Cleanup GPU
    del model, loss_fn, train_loader, val_loader
    gc.collect()
    torch.cuda.empty_cache()

print("\nTraining complete!")

Training Progress:   0%|          | 0/2 [00:00<?, ?run/s]


  resnet_c64_b64_cr4x | 4x compression


2026-01-28 18:04:24,472 - Log directory: runs\resnet_c64_b64_cr4x_20260128_180424
2026-01-28 18:04:24,473 - Checkpoint directory: checkpoints\resnet_c64_b64_cr4x_20260128_180424
2026-01-28 18:04:24,474 - Mixed Precision (AMP): enabled
2026-01-28 18:04:24,474 - Starting training for 35 epochs
2026-01-28 18:04:24,475 - Model: ResNetAutoencoder
2026-01-28 18:04:24,475 - Config: {'learning_rate': 0.0001, 'optimizer': 'adamw', 'scheduler': 'plateau', 'lr_patience': 10, 'lr_factor': 0.5, 'max_grad_norm': 1.0, 'use_amp': True, 'notebook': True, 'run_name': 'resnet_c64_b64_cr4x', 'preprocessing_params': {'vmin': np.float32(14.768799), 'vmax': np.float32(24.54073)}, 'model_type': 'resnet', 'latent_channels': 64, 'base_channels': 64, 'mse_weight': 0.5, 'ssim_weight': 0.5, 'batch_size': 16, 'compression_ratio': 4}


  Parameters: 22,922,241
Using device: cuda
GPU memory: 1880MB / 8192MB (23% used, 6.2 GB free)
Using AdamW optimizer with weight_decay=1e-05
Using ReduceLROnPlateau: patience=10
Mixed Precision (AMP) enabled - ~2x training speedup


Epoch 1 [Train]:   0%|          | 0/3916 [00:00<?, ?it/s]

KeyboardInterrupt: 

## 5. Results Summary

In [None]:
# Tabular recap of every run collected in `results`.
print(f"\n{'Run':<45} {'Ratio':>6} {'Params':>10} {'PSNR':>10} {'SSIM':>10}")
print("-" * 85)

for r in results:
    name = r['run_name']
    ratio = f"{r['compression_ratio']}x"
    params_str = f"{r['parameters']/1e6:.1f}M"
    # A metric is missing only when it is absent/None; a legitimate 0.0 must
    # still be printed, so test against None rather than truthiness.
    best_psnr = r.get('best_psnr')
    best_ssim = r.get('best_ssim')
    psnr = f"{best_psnr:.2f} dB" if best_psnr is not None else "N/A"
    ssim = f"{best_ssim:.4f}" if best_ssim is not None else "N/A"
    print(f"{name:<45} {ratio:>6} {params_str:>10} {psnr:>10} {ssim:>10}")

## 6. Training History Visualization

In [None]:
# Plot training curves (loss, PSNR, SSIM, learning rate) for each run and
# save one figure per model under RESULTS_DIR/<clean model name>/.
for r in results:
    # Prefer the history recorded in this session; otherwise fall back to the
    # one stored inside the checkpoint.
    history = r.get('history')
    if not history:
        # map_location='cpu': only metadata is needed here, so do not restore
        # the saved weights onto the GPU.
        # NOTE: weights_only=False unpickles arbitrary objects - only load
        # checkpoints produced by this project.
        ckpt = torch.load(r['checkpoint'], map_location='cpu', weights_only=False)
        history = ckpt.get('history', [])
        del ckpt

    if not history:
        print(f"No history available for {r['run_name']}")
        continue

    # History epochs are 0-based; plot them 1-based.
    epochs = [h['epoch'] + 1 for h in history]

    fig, axes = plt.subplots(2, 2, figsize=(12, 10))

    # Train/val panels: (axis, history key suffix, y-label, title,
    # target value or None, target legend label).
    metric_panels = [
        (axes[0, 0], 'loss', 'Loss', 'Training Loss', None, None),
        (axes[0, 1], 'psnr', 'PSNR (dB)', 'PSNR', 25, 'Target (25 dB)'),
        (axes[1, 0], 'ssim', 'SSIM', 'SSIM', 0.85, 'Target (0.85)'),
    ]
    for ax, key, ylabel, title, target, target_label in metric_panels:
        ax.plot(epochs, [h[f'train_{key}'] for h in history], label='Train')
        ax.plot(epochs, [h[f'val_{key}'] for h in history], label='Val')
        if target is not None:
            ax.axhline(y=target, color='r', linestyle='--', label=target_label)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.legend()
        ax.grid(True)

    # Learning Rate (single curve, log scale, no legend)
    ax = axes[1, 1]
    ax.plot(epochs, [h['learning_rate'] for h in history])
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Learning Rate')
    ax.set_title('Learning Rate Schedule')
    ax.set_yscale('log')
    ax.grid(True)

    plt.suptitle(f"ResNet Training: {r['compression_ratio']}x Compression", fontsize=14)
    plt.tight_layout()

    # Save to results folder with clean model name (no timestamp)
    clean_name = f"resnet_c{r['latent_channels']}_b{BASE_CHANNELS}_cr{r['compression_ratio']}x"
    model_results_dir = RESULTS_DIR / clean_name
    model_results_dir.mkdir(exist_ok=True)
    save_path = model_results_dir / 'training_curves.png'
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()
    print(f"Saved: {save_path}")

## 7. Checkpoint Verification

In [None]:
# For every run: inspect the saved checkpoint, prove it loads into a fresh
# model and runs a forward pass, then write a small JSON description of it.
for r in results:
    checkpoint_path = r['checkpoint']

    if not Path(checkpoint_path).exists():
        print(f"Checkpoint not found: {checkpoint_path}")
        continue

    # The test model lives on the CPU, so load the tensors there as well.
    # NOTE: weights_only=False unpickles arbitrary objects - only load
    # checkpoints produced by this project.
    ckpt = torch.load(checkpoint_path, map_location='cpu', weights_only=False)

    # Format the loss only when it is present: applying ':.4f' to the
    # 'N/A' placeholder string raises ValueError.
    best_val_loss = ckpt.get('best_val_loss')
    best_val_loss_str = f"{best_val_loss:.4f}" if best_val_loss is not None else 'N/A'

    print(f"\n{'='*60}")
    print(f"Checkpoint: {r['run_name']}")
    print(f"{'='*60}")
    print(f"  Path: {checkpoint_path}")
    print(f"  Keys: {list(ckpt.keys())}")
    print(f"  Epoch: {ckpt.get('epoch', 'N/A')}")
    print(f"  Best val loss: {best_val_loss_str}")
    print(f"  Preprocessing params: {ckpt.get('preprocessing_params', 'MISSING')}")

    # Test loading into fresh model
    test_model = ResNetAutoencoder(
        latent_channels=r['latent_channels'],
        base_channels=BASE_CHANNELS,
        in_channels=1,
    )
    test_model.load_state_dict(ckpt['model_state_dict'])
    test_model.eval()

    # Forward pass on one random single-channel 256x256 input.
    x = torch.rand(1, 1, 256, 256)
    with torch.no_grad():
        x_hat, z = test_model(x)

    print(f"  Inference: {x.shape} -> {z.shape} -> {x_hat.shape}")
    # Only report PASS when the reconstruction really has the input's shape.
    if x_hat.shape == x.shape:
        print(f"  Verification: PASS")
    else:
        print(f"  Verification: FAIL (reconstruction shape differs from input)")

    # Save checkpoint info to results with clean model name
    clean_name = f"resnet_c{r['latent_channels']}_b{BASE_CHANNELS}_cr{r['compression_ratio']}x"
    model_results_dir = RESULTS_DIR / clean_name
    model_results_dir.mkdir(exist_ok=True)

    # `or {}` also covers a checkpoint that stores the key with a None value.
    preprocessing_params = ckpt.get('preprocessing_params') or {}
    info = {
        'checkpoint_path': str(checkpoint_path),
        'epoch': ckpt.get('epoch'),
        'best_val_loss': best_val_loss,
        # Cast to plain floats (values are NumPy scalars in this project).
        'preprocessing_params': {k: float(v) for k, v in preprocessing_params.items()},
        'config': ckpt.get('config', {}),
    }

    # default=str keeps the dump from failing on non-JSON values in `config`.
    with open(model_results_dir / 'checkpoint_info.json', 'w') as f:
        json.dump(info, f, indent=2, default=str)

    del test_model, ckpt

## 8. Sample Reconstructions

In [None]:
# Show originals, reconstructions and absolute differences for up to four
# validation patches per run, and save the figure per model.
device = 'cuda' if torch.cuda.is_available() else 'cpu'  # same for every run
MAX_SAMPLES = 4

for r in results:
    # Load model. Tensors are restored on the CPU first; the model is moved
    # to `device` afterwards, so the checkpoint itself never occupies VRAM.
    # NOTE: weights_only=False unpickles arbitrary objects - only load
    # checkpoints produced by this project.
    ckpt = torch.load(r['checkpoint'], map_location='cpu', weights_only=False)
    model = ResNetAutoencoder(
        latent_channels=r['latent_channels'],
        base_channels=BASE_CHANNELS,
        in_channels=1,
    )
    model.load_state_dict(ckpt['model_state_dict'])
    del ckpt
    model.eval()
    model = model.to(device)

    # Get validation batch
    val_batch = next(iter(dm.val_dataloader()))[:MAX_SAMPLES].to(device)

    with torch.no_grad():
        reconstructed, latent = model(val_batch)

    originals = val_batch.cpu().numpy()
    reconstructions = reconstructed.cpu().numpy()
    # The batch may hold fewer than MAX_SAMPLES patches.
    n_shown = originals.shape[0]

    fig, axes = plt.subplots(3, MAX_SAMPLES, figsize=(16, 12))

    for i in range(MAX_SAMPLES):
        if i >= n_shown:
            # No sample for this column: just hide its three axes.
            for row in range(3):
                axes[row, i].axis('off')
            continue

        axes[0, i].imshow(originals[i, 0], cmap='gray')
        axes[0, i].set_title(f'Original {i+1}')
        axes[0, i].axis('off')

        axes[1, i].imshow(reconstructions[i, 0], cmap='gray')
        axes[1, i].set_title(f'Reconstructed {i+1}')
        axes[1, i].axis('off')

        diff = np.abs(originals[i, 0] - reconstructions[i, 0])
        axes[2, i].imshow(diff, cmap='hot', vmin=0, vmax=0.5)
        axes[2, i].set_title(f'Difference {i+1}')
        axes[2, i].axis('off')

    axes[0, 0].set_ylabel('Original', fontsize=12)
    axes[1, 0].set_ylabel('Reconstructed', fontsize=12)
    axes[2, 0].set_ylabel('Difference', fontsize=12)

    # Runs without recorded metrics carry no 'best_psnr'.
    best_psnr = r.get('best_psnr')
    psnr_label = f"{best_psnr:.2f} dB" if best_psnr is not None else "N/A"
    plt.suptitle(f"ResNet Reconstructions ({r['compression_ratio']}x, PSNR: {psnr_label})", fontsize=14)
    plt.tight_layout()

    # Save with clean model name
    clean_name = f"resnet_c{r['latent_channels']}_b{BASE_CHANNELS}_cr{r['compression_ratio']}x"
    model_results_dir = RESULTS_DIR / clean_name
    model_results_dir.mkdir(exist_ok=True)
    save_path = model_results_dir / 'sample_reconstructions.png'
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()
    print(f"Saved: {save_path}")

    del model
    gc.collect()
    torch.cuda.empty_cache()

## 9. Rate-Distortion Curve (Multi-Ratio Mode)

In [None]:
# Rate-distortion plot: PSNR and SSIM against compression ratio.
# Only runs that actually carry both metrics can be plotted.
plottable_results = []
for r in results:
    if r.get('best_psnr') is not None and r.get('best_ssim') is not None:
        plottable_results.append(r)
    else:
        print(f"Skipping {r['run_name']} in R-D curve - no validation metrics recorded")

if len(plottable_results) > 1:
    # Sort by compression ratio
    sorted_results = sorted(plottable_results, key=lambda x: x['compression_ratio'])

    ratios = [r['compression_ratio'] for r in sorted_results]
    psnrs = [r['best_psnr'] for r in sorted_results]
    ssims = [r['best_ssim'] for r in sorted_results]

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

    # One panel per metric: (axis, values, annotation format, target value,
    # target legend label, y-label, title).
    rd_panels = [
        (ax1, psnrs, '{:.1f}', 25, 'Target (25 dB)', 'PSNR (dB)', 'Rate-Distortion: PSNR'),
        (ax2, ssims, '{:.3f}', 0.85, 'Target (0.85)', 'SSIM', 'Rate-Distortion: SSIM'),
    ]
    for ax, values, value_fmt, target, target_label, ylabel, title in rd_panels:
        ax.plot(ratios, values, 'go-', markersize=10, linewidth=2, label='ResNet')
        for r_val, value in zip(ratios, values):
            ax.annotate(value_fmt.format(value), (r_val, value), textcoords='offset points',
                        xytext=(0, 10), ha='center', fontsize=9)
        ax.axhline(y=target, color='r', linestyle='--', alpha=0.5, label=target_label)
        ax.set_xlabel('Compression Ratio', fontsize=12)
        ax.set_ylabel(ylabel, fontsize=12)
        ax.set_title(title, fontsize=14)
        # Ratios are powers of two; the axis is inverted so that compression
        # decreases from left to right.
        ax.set_xscale('log', base=2)
        ax.set_xticks(ratios)
        ax.set_xticklabels([f'{int(r)}x' for r in ratios])
        ax.legend()
        ax.grid(True, alpha=0.3)
        ax.invert_xaxis()

    plt.suptitle('ResNet Rate-Distortion Curves', fontsize=16, y=1.02)
    plt.tight_layout()

    datestamp = datetime.now().strftime('%Y%m%d')
    save_path = RESULTS_DIR / f'resnet_rate_distortion_{datestamp}.png'
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()
    print(f"Saved: {save_path}")
else:
    print("Single ratio mode - no R-D curve to plot")

## 10. Save Results to JSON

In [None]:
# Persist the sweep summary (configuration + per-run results) as dated JSON.
datestamp = datetime.now().strftime('%Y%m%d')

# Per-epoch histories are dropped from the export (too large).
results_for_json = [
    {key: value for key, value in run.items() if key != 'history'}
    for run in results
]

sweep_config = {
    'model': 'resnet',
    'base_channels': BASE_CHANNELS,
    'learning_rate': LEARNING_RATE,
    'optimizer': OPTIMIZER,
    'scheduler': SCHEDULER,
    'epochs': EPOCHS,
    'batch_size': BATCH_SIZE,
    'train_subset': TRAIN_SUBSET,
}

output = {
    'sweep_type': 'resnet_training',
    'timestamp': datetime.now().isoformat(),
    'config': sweep_config,
    'results': results_for_json,
}

output_path = RESULTS_DIR / f'resnet_training_{datestamp}.json'
# default=str stringifies any value the JSON encoder cannot handle natively.
serialized = json.dumps(output, indent=2, default=str)
with open(output_path, 'w') as f:
    f.write(serialized)

print(f"Results saved to: {output_path}")

## 11. Comparison with Baseline

Load baseline results and compare.

In [None]:
# Compare each ResNet run with the baseline run at the same compression
# ratio, using the most recent baseline sweep JSON found under runs/.
baseline_json = sorted(Path('runs').glob('sweep_baseline_ratios*.json'))
if baseline_json:
    with open(baseline_json[-1]) as f:
        baseline_sweep = json.load(f)
    baseline_results = baseline_sweep['results']

    print(f"\n{'='*80}")
    print("BASELINE vs RESNET COMPARISON")
    print(f"{'='*80}")
    print(f"{'Ratio':<8} {'Baseline PSNR':>15} {'ResNet PSNR':>15} {'Delta':>12}")
    print("-" * 55)

    for r in results:
        ratio = r['compression_ratio']
        # Either side may lack a PSNR (e.g. a run that recorded no metrics),
        # so read both with .get() and format defensively.
        resnet_psnr = r.get('best_psnr')

        # Find matching baseline (ratios within 1 of each other count as equal)
        bl = next((b for b in baseline_results if abs(b['compression_ratio'] - ratio) < 1), None)
        bl_psnr = bl.get('best_psnr') if bl else None

        bl_col = f"{bl_psnr:.2f}" if bl_psnr is not None else 'N/A'
        resnet_col = f"{resnet_psnr:.2f}" if resnet_psnr is not None else 'N/A'
        if bl_psnr is not None and resnet_psnr is not None:
            delta_col = f"{resnet_psnr - bl_psnr:+.2f}"
        else:
            delta_col = '-'
        print(f"{ratio}x{'':<5} {bl_col:>15} {resnet_col:>15} {delta_col:>12}")
else:
    print("No baseline results found. Run sweep_baseline_ratios.ipynb first.")

## 15. Best & Worst Reconstructions

Find the best and worst reconstruction cases to understand model strengths and failure modes.

In [None]:
# Find best and worst reconstruction cases for every run in `results`.
# The cell rebuilds everything it needs (model, device, validation loader,
# output folder) because earlier cells delete their models and loaders.
from src.evaluation.evaluator import Evaluator


def _plot_case_grid(cases, n_cols, title, diff_vmax, save_path):
    """Plot originals / reconstructions / |difference| for `cases` and save.

    Each case is a dict with 'original' and 'reconstructed' tensors and an
    'mse' value. At most `n_cols` cases are drawn, one per column.
    """
    fig, axes = plt.subplots(3, n_cols, figsize=(2.5 * n_cols, 8))
    fig.suptitle(title, fontsize=14)
    for i, case in enumerate(cases[:n_cols]):
        orig = case['original'].numpy().squeeze()
        recon = case['reconstructed'].numpy().squeeze()
        diff = np.abs(orig - recon)

        axes[0, i].imshow(orig, cmap='gray', vmin=0, vmax=1)
        axes[0, i].axis('off')
        axes[0, i].set_title(f"MSE: {case['mse']:.4f}", fontsize=8)

        axes[1, i].imshow(recon, cmap='gray', vmin=0, vmax=1)
        axes[1, i].axis('off')

        axes[2, i].imshow(diff, cmap='hot', vmin=0, vmax=diff_vmax)
        axes[2, i].axis('off')
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()


device = 'cuda' if torch.cuda.is_available() else 'cpu'
n_cases = 8

for r in results:
    # Per-model output folder, named like the other result folders.
    clean_name = f"resnet_c{r['latent_channels']}_b{BASE_CHANNELS}_cr{r['compression_ratio']}x"
    MODEL_RESULTS_DIR = RESULTS_DIR / clean_name
    MODEL_RESULTS_DIR.mkdir(exist_ok=True)

    # Rebuild the trained model from its checkpoint.
    # NOTE: weights_only=False unpickles arbitrary objects - only load
    # checkpoints produced by this project.
    ckpt = torch.load(r['checkpoint'], map_location='cpu', weights_only=False)
    model = ResNetAutoencoder(
        latent_channels=r['latent_channels'],
        base_channels=BASE_CHANNELS,
        in_channels=1,
    )
    model.load_state_dict(ckpt['model_state_dict'])
    del ckpt
    model = model.eval().to(device)
    val_loader = dm.val_dataloader()

    print("Finding best and worst reconstructions...")

    # Create evaluator
    evaluator = Evaluator(model, device=device)

    # Find best and worst cases
    worst_cases = evaluator.find_failure_cases(val_loader, n_worst=n_cases)
    best_cases = evaluator.find_best_cases(val_loader, n_best=n_cases)

    print(f"Worst reconstructions (highest MSE):")
    for i, case in enumerate(worst_cases[:5]):
        print(f"  {i+1}. MSE = {case['mse']:.6f}")

    print(f"Best reconstructions (lowest MSE):")
    for i, case in enumerate(best_cases[:5]):
        print(f"  {i+1}. MSE = {case['mse']:.6f}")

    # Plot best reconstructions
    _plot_case_grid(
        best_cases,
        n_cases,
        'Best Reconstructions (Top: Original, Middle: Reconstructed, Bottom: Difference)',
        0.2,
        MODEL_RESULTS_DIR / 'best_reconstructions.png',
    )

    # Plot worst reconstructions (wider colour range for the larger errors)
    _plot_case_grid(
        worst_cases,
        n_cases,
        'Worst Reconstructions (Failure Cases)',
        0.3,
        MODEL_RESULTS_DIR / 'worst_reconstructions.png',
    )

    print(f"Saved: {MODEL_RESULTS_DIR / 'best_reconstructions.png'}")
    print(f"Saved: {MODEL_RESULTS_DIR / 'worst_reconstructions.png'}")

## 16. Latent Channel Visualization

Visualize activations in the latent space to understand what the model encodes.

In [None]:
# Visualize latent channel activations for every run in `results`.
# The cell rebuilds everything it needs (model, device, validation loader,
# output folder) because earlier cells delete their models and loaders.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

for r in results:
    # Per-model output folder, named like the other result folders.
    clean_name = f"resnet_c{r['latent_channels']}_b{BASE_CHANNELS}_cr{r['compression_ratio']}x"
    MODEL_RESULTS_DIR = RESULTS_DIR / clean_name
    MODEL_RESULTS_DIR.mkdir(exist_ok=True)

    # Rebuild the trained model from its checkpoint.
    # NOTE: weights_only=False unpickles arbitrary objects - only load
    # checkpoints produced by this project.
    ckpt = torch.load(r['checkpoint'], map_location='cpu', weights_only=False)
    model = ResNetAutoencoder(
        latent_channels=r['latent_channels'],
        base_channels=BASE_CHANNELS,
        in_channels=1,
    )
    model.load_state_dict(ckpt['model_state_dict'])
    del ckpt
    model = model.eval().to(device)
    val_loader = dm.val_dataloader()

    print("Latent space visualization (first sample):")
    with torch.no_grad():
        # Encode only the first sample of the first validation batch.
        sample = next(iter(val_loader))[:1].to(device)
        _, latent = model(sample)
        latent_np = latent[0].cpu().numpy()

    # The 4x4 grid shows at most the first 16 channels.
    n_channels = min(16, latent_np.shape[0])
    fig, axes = plt.subplots(4, 4, figsize=(12, 12))
    fig.suptitle(f'Latent Channel Activations ({latent_np.shape[1]}x{latent_np.shape[2]} spatial, {latent_np.shape[0]} channels)', fontsize=14)
    for i, ax in enumerate(axes.flatten()):
        if i < n_channels:
            channel = latent_np[i]
            # Symmetric colour range around zero, never narrower than +/-0.1.
            vmax = max(abs(channel.min()), abs(channel.max()), 0.1)
            ax.imshow(channel, cmap='RdBu_r', vmin=-vmax, vmax=vmax)
            ax.set_title(f'Ch {i}: std={channel.std():.2f}', fontsize=8)
        ax.axis('off')
    plt.tight_layout()
    plt.savefig(MODEL_RESULTS_DIR / 'latent_channels.png', dpi=150, bbox_inches='tight')
    plt.show()

    # Compute latent statistics: a channel counts as active when its
    # activations vary (std > 0.01) over the spatial grid.
    active_channels = sum(1 for channel in latent_np if channel.std() > 0.01)
    print(f"Latent statistics:")
    print(f"  Shape: {latent_np.shape}")
    print(f"  Active channels (std > 0.01): {active_channels}/{latent_np.shape[0]}")
    print(f"  Mean activation: {latent_np.mean():.4f}")
    print(f"  Std activation: {latent_np.std():.4f}")
    print(f"Saved: {MODEL_RESULTS_DIR / 'latent_channels.png'}")

---

## Done!

**Results saved to:** `results/resnet_*/`

**View in TensorBoard:**
```bash
tensorboard --logdir=runs
```

**Next steps:**
1. Compare with baseline at same ratios
2. Run full evaluation with SAR metrics
3. Proceed to Phase 6 (Final Experiments)