# Sweep: Baseline vs ResNet at 16x Compression

Fair comparison of the baseline and ResNet autoencoders at 16x compression, with both trained under the same corrected hyperparameters.

**Previous failures:** LR too high (7e-3), base_channels too small (32), OneCycleLR overshoot.  
**This sweep:** LR=1e-4, AdamW, ReduceLROnPlateau, base_channels=64.

**Monitor with TensorBoard:**
```bash
tensorboard --logdir=runs
```

## 1. Setup

In [1]:
import gc
import sys
import time
from pathlib import Path

# The notebook is launched from a subdirectory of the repository, so the
# repository root (which contains the `src` package) is the parent of cwd.
project_root = Path.cwd().parent
root_entry = str(project_root)
if root_entry not in sys.path:
    # Prepend so the local `src` package is found before anything else.
    sys.path.insert(0, root_entry)

# These imports must come after the sys.path tweak above.
import torch
from tqdm.auto import tqdm

from src.data.datamodule import SARDataModule
from src.models.autoencoder import SARAutoencoder
from src.models.resnet_autoencoder import ResNetAutoencoder
from src.losses.combined import CombinedLoss
from src.training.trainer import Trainer

# Environment report.
has_cuda = torch.cuda.is_available()
print(f"Project root: {project_root}")
print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {has_cuda}")
if has_cuda:
    gpu_name = torch.cuda.get_device_name(0)
    vram_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
    print(f"GPU: {gpu_name}")
    print(f"VRAM: {vram_gb:.1f} GB")

Project root: d:\Projects\CNNAutoencoderProject
PyTorch: 2.5.1+cu121
CUDA: True
GPU: NVIDIA GeForce RTX 3070
VRAM: 8.0 GB


## 2. Sweep Configuration

In [2]:
# ============================================================
# SWEEP CONFIGURATION
# ============================================================

# Data
DATA_PATH = "D:/Projects/CNNAutoencoderProject/data/patches/metadata.npy"
BATCH_SIZE = 16
NUM_WORKERS = 4
VAL_FRACTION = 0.1
TRAIN_SUBSET = 0.10

# Geometry used only for the compression-ratio calculation: PATCH_SIZE x
# PATCH_SIZE single-channel inputs encoded to a LATENT_SIZE x LATENT_SIZE x
# LATENT_CHANNELS latent.
# NOTE(review): values reproduce the previously hard-coded formula — confirm
# they match the models' actual input size and downsampling factor.
PATCH_SIZE = 256
LATENT_SIZE = 16

# Fixed training params (proven with baseline@16x -> 20.47 dB)
LATENT_CHANNELS = 16  # 16x compression for all
EPOCHS = 35
LEARNING_RATE = 1e-4
EARLY_STOPPING_PATIENCE = 12
LR_PATIENCE = 10
LR_FACTOR = 0.5
MSE_WEIGHT = 0.5
SSIM_WEIGHT = 0.5

# Sweep variable: model architecture
SWEEP_CONFIGS = [
    {'name': 'baseline',  'cls': SARAutoencoder,       'base_channels': 64, 'extra_kwargs': {}},
    {'name': 'resnet',    'cls': ResNetAutoencoder,     'base_channels': 64, 'extra_kwargs': {'in_channels': 1}},
]

# ============================================================

# Input values per latent value (pixels in / latent elements out).
compression_ratio = (PATCH_SIZE * PATCH_SIZE) / (LATENT_SIZE * LATENT_SIZE * LATENT_CHANNELS)
print(f"Compression: {compression_ratio:.0f}x (LC={LATENT_CHANNELS})")
print(f"LR={LEARNING_RATE}, Epochs={EPOCHS}, Patience={EARLY_STOPPING_PATIENCE}")
print(f"Data: {TRAIN_SUBSET*100:.0f}% subset, batch_size={BATCH_SIZE}")
print()
print("Sweep Plan:")
for cfg in SWEEP_CONFIGS:
    # Must match the run_name built in the sweep loop (used to find checkpoints).
    name = f"{cfg['name']}_c{LATENT_CHANNELS}_b{cfg['base_channels']}_cr{int(compression_ratio)}x"
    print(f"  {name}")
print(f"\nTotal runs: {len(SWEEP_CONFIGS)}")

Compression: 16x (LC=16)
LR=0.0001, Epochs=35, Patience=12
Data: 10% subset, batch_size=16

Sweep Plan:
  baseline_c16_b64_cr16x
  resnet_c16_b64_cr16x

Total runs: 2


## 3. Load Data

In [3]:
import random

# Fixed seed for subset sampling, so re-running the notebook (or resuming a
# partially finished sweep in a new session) reuses the same train/val subset
# and every architecture is trained/evaluated on the same data.
SUBSET_SEED = 42


def _random_subset(dataset, fraction, rng):
    """Return ``(subset, subset_size, full_size)`` for a random *fraction* of *dataset*.

    At least one item is kept whenever *dataset* is non-empty, so a small
    split never produces an empty loader (which would break metric averaging).
    """
    full_size = len(dataset)
    subset_size = min(full_size, max(1, int(full_size * fraction)))
    indices = rng.sample(range(full_size), subset_size)
    return torch.utils.data.Subset(dataset, indices), subset_size, full_size


print("Loading data...")
dm = SARDataModule(
    patches_path=DATA_PATH,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    val_fraction=VAL_FRACTION,
)

if TRAIN_SUBSET < 1.0:
    # One seeded generator for both splits; the global `random` state is untouched.
    # The same fraction is applied to the validation split.
    rng = random.Random(SUBSET_SEED)
    dm.train_dataset, train_subset_size, full_train_size = _random_subset(
        dm.train_dataset, TRAIN_SUBSET, rng)
    dm.val_dataset, val_subset_size, full_val_size = _random_subset(
        dm.val_dataset, TRAIN_SUBSET, rng)

    print(f"Using {TRAIN_SUBSET*100:.0f}% subset:")
    print(f"  Train: {train_subset_size:,} of {full_train_size:,}")
    print(f"  Val: {val_subset_size:,} of {full_val_size:,}")

print(f"\nPreprocessing params: {dm.preprocessing_params}")

Loading data...
Loading metadata from D:\Projects\CNNAutoencoderProject\data\patches\metadata.npy
Total patches: 696277
Train: 626650, Val: 69627
Using 10% subset:
  Train: 62,665 of 626,650
  Val: 6,962 of 69,627

Preprocessing params: {'vmin': np.float32(14.768799), 'vmax': np.float32(24.54073)}


## 4. Run Sweep

In [None]:
import glob

results = []


def _build_model(cfg):
    """Instantiate the architecture described by one SWEEP_CONFIGS entry.

    All runs share LATENT_CHANNELS; ``extra_kwargs`` carries per-architecture
    constructor arguments (e.g. ``in_channels`` for the ResNet).
    """
    return cfg['cls'](
        latent_channels=LATENT_CHANNELS,
        base_channels=cfg['base_channels'],
        **cfg['extra_kwargs'],
    )


def _find_latest_checkpoint(run_name):
    """Return the newest ``best.pth`` path for *run_name*, or None if none exists.

    ``glob.glob`` returns matches in arbitrary filesystem order, so sort
    explicitly. Run directories end in a ``_YYYYMMDD_HHMMSS`` timestamp, so
    lexicographic order is chronological and the last entry is the newest.
    """
    matches = sorted(glob.glob(f'checkpoints/{run_name}_*/best.pth'))
    return matches[-1] if matches else None


def _validate(model, loader, loss_fn):
    """Run one no-grad pass over *loader* and return mean ``(loss, psnr, ssim)``.

    Each batch counts equally in the means. Returns ``(None, None, None)`` when
    the loader yields no batches instead of dividing by zero.
    """
    losses, psnr_vals, ssim_vals = [], [], []
    with torch.no_grad():
        for batch in loader:
            batch = batch.cuda()
            output, _ = model(batch)
            loss, metrics = loss_fn(output, batch)
            losses.append(loss.item())
            psnr_vals.append(metrics['psnr'])
            ssim_vals.append(metrics['ssim'])
    if not losses:
        return None, None, None
    n = len(losses)
    return sum(losses) / n, sum(psnr_vals) / n, sum(ssim_vals) / n


sweep_pbar = tqdm(SWEEP_CONFIGS, desc="Sweep Progress", unit="run")

for cfg in sweep_pbar:
    model_name = cfg['name']
    base_channels = cfg['base_channels']
    run_name = f"{model_name}_c{LATENT_CHANNELS}_b{base_channels}_cr{int(compression_ratio)}x"
    sweep_pbar.set_postfix_str(run_name)

    # Check for existing checkpoint — skip training if this config was already run
    checkpoint_path = _find_latest_checkpoint(run_name)
    if checkpoint_path is not None:
        print(f"\n{'=' * 70}")
        print(f"  SKIPPING {run_name} — checkpoint exists")
        print(f"  {checkpoint_path}")
        print(f"{'=' * 70}")

        # map_location='cpu' keeps the loaded checkpoint dict off the GPU; any
        # CUDA tensors it holds would otherwise stay resident while the next
        # model trains. weights_only=False: this is our own trusted checkpoint.
        ckpt = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
        ckpt_dir = Path(checkpoint_path).parent
        # Trainer uses the same run directory name under runs/ and checkpoints/.
        log_dir = str(Path('runs') / ckpt_dir.name)

        # Rebuild the model and run a quick validation pass to get metrics
        model = _build_model(cfg)
        model.load_state_dict(ckpt['model_state_dict'])
        model.eval().cuda()
        params = model.count_parameters()
        epochs_trained = ckpt.get('epoch', '?')
        del ckpt  # only the weights and epoch were needed

        loss_fn_tmp = CombinedLoss(mse_weight=MSE_WEIGHT, ssim_weight=SSIM_WEIGHT)
        val_loader = dm.val_dataloader()
        # NOTE(review): skipped runs are scored by this fresh pass, while trained
        # runs report what Trainer recorded at their best epoch; the two protocols
        # may differ slightly (e.g. AMP) — confirm before comparing closely.
        best_loss, best_psnr, best_ssim = _validate(model, val_loader, loss_fn_tmp)

        result = {
            'run_name': ckpt_dir.name,
            'model': model_name,
            'base_channels': base_channels,
            'parameters': params['total'],
            'epochs_trained': epochs_trained,
            'elapsed_min': 0,
            'checkpoint': checkpoint_path,
            'log_dir': log_dir,
            'best_val_loss': best_loss,
            'best_psnr': best_psnr,
            'best_ssim': best_ssim,
            'skipped': True,
        }
        results.append(result)
        if best_psnr is None:
            print("  Validated: no validation batches — metrics unavailable")
        else:
            print(f"  Validated: PSNR={best_psnr:.2f} dB, SSIM={best_ssim:.4f}")

        del model, loss_fn_tmp, val_loader
        gc.collect()
        torch.cuda.empty_cache()
        continue

    print(f"\n{'=' * 70}")
    print(f"  {run_name}")
    print(f"{'=' * 70}")

    # Create model
    model = _build_model(cfg)
    params = model.count_parameters()
    print(f"  Parameters: {params['total']:,}")

    # Loss
    loss_fn = CombinedLoss(mse_weight=MSE_WEIGHT, ssim_weight=SSIM_WEIGHT)

    # Trainer config (also logged by Trainer for reproducibility)
    config = {
        'learning_rate': LEARNING_RATE,
        'optimizer': 'adamw',
        'scheduler': 'plateau',
        'lr_patience': LR_PATIENCE,
        'lr_factor': LR_FACTOR,
        'max_grad_norm': 1.0,
        'use_amp': True,
        'notebook': True,
        'run_name': run_name,
        'preprocessing_params': dm.preprocessing_params,
        'model_type': model_name,
        'latent_channels': LATENT_CHANNELS,
        'base_channels': base_channels,
        'mse_weight': MSE_WEIGHT,
        'ssim_weight': SSIM_WEIGHT,
        'batch_size': BATCH_SIZE,
        'compression_ratio': compression_ratio,
    }

    train_loader = dm.train_dataloader()
    val_loader = dm.val_dataloader()

    trainer = Trainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        loss_fn=loss_fn,
        config=config,
    )

    # Train
    t0 = time.time()
    history = trainer.train(
        epochs=EPOCHS,
        early_stopping_patience=EARLY_STOPPING_PATIENCE,
    )
    elapsed = time.time() - t0

    # Collect results
    result = {
        'run_name': trainer.run_name,
        'model': model_name,
        'base_channels': base_channels,
        'parameters': params['total'],
        'epochs_trained': len(history),
        'elapsed_min': elapsed / 60,
        'checkpoint': str(trainer.checkpoint_dir / 'best.pth'),
        'log_dir': str(trainer.log_dir),
        'skipped': False,
    }

    if history:
        # Report the epoch with the lowest validation loss (missing -> never chosen).
        best_epoch = min(history, key=lambda h: h.get('val_loss', float('inf')))
        result['best_val_loss'] = best_epoch.get('val_loss')
        result['best_psnr'] = best_epoch.get('val_psnr')
        result['best_ssim'] = best_epoch.get('val_ssim')

    results.append(result)

    run_psnr = result.get('best_psnr')
    psnr_str = f"{run_psnr:.2f} dB" if run_psnr is not None else "N/A"
    print(f"\n  Done: {psnr_str} | {elapsed/60:.1f} min | {trainer.checkpoint_dir / 'best.pth'}")

    # Cleanup GPU
    del model, trainer, loss_fn, train_loader, val_loader
    gc.collect()
    torch.cuda.empty_cache()

print("\nSweep complete!")

Sweep Progress:   0%|          | 0/2 [00:00<?, ?run/s]


  SKIPPING baseline_c16_b64_cr16x — checkpoint exists
  checkpoints\baseline_c16_b64_cr16x_20260127_231730\best.pth
  Validated: PSNR=19.09 dB, SSIM=0.5721

  resnet_c16_b64_cr16x
  Parameters: 22,395,873
Using device: cuda


2026-01-28 00:39:26,730 - Log directory: runs\resnet_c16_b64_cr16x_20260128_003926
2026-01-28 00:39:26,730 - Checkpoint directory: checkpoints\resnet_c16_b64_cr16x_20260128_003926
2026-01-28 00:39:26,730 - Mixed Precision (AMP): enabled
2026-01-28 00:39:26,732 - Starting training for 35 epochs
2026-01-28 00:39:26,733 - Model: ResNetAutoencoder
2026-01-28 00:39:26,733 - Config: {'learning_rate': 0.0001, 'optimizer': 'adamw', 'scheduler': 'plateau', 'lr_patience': 10, 'lr_factor': 0.5, 'max_grad_norm': 1.0, 'use_amp': True, 'notebook': True, 'run_name': 'resnet_c16_b64_cr16x', 'preprocessing_params': {'vmin': np.float32(14.768799), 'vmax': np.float32(24.54073)}, 'model_type': 'resnet', 'latent_channels': 16, 'base_channels': 64, 'mse_weight': 0.5, 'ssim_weight': 0.5, 'batch_size': 16, 'compression_ratio': 16.0}


GPU memory: 634MB / 8192MB (8% used, 7.4 GB free)
Using AdamW optimizer with weight_decay=1e-05
Using ReduceLROnPlateau: patience=10
Mixed Precision (AMP) enabled - ~2x training speedup


Epoch 1 [Train]:   0%|          | 0/3916 [00:00<?, ?it/s]

## 5. Results Summary

In [None]:
# One row per sweep run; skipped runs show 0m elapsed.
header = f"{'Run':<40} {'Params':>10} {'PSNR':>10} {'SSIM':>10} {'Epochs':>8} {'Time':>8}"
print(header)
print("-" * len(header))  # separator spans the full header width

for r in results:
    psnr_val = r.get('best_psnr')
    ssim_val = r.get('best_ssim')
    # `is not None` rather than truthiness so a genuine 0.0 metric still prints.
    psnr_str = f"{psnr_val:.2f} dB" if psnr_val is not None else "N/A"
    ssim_str = f"{ssim_val:.4f}" if ssim_val is not None else "N/A"
    params_str = f"{r['parameters']/1e6:.1f}M"
    epochs_str = str(r['epochs_trained'])
    mins_str = f"{r['elapsed_min']:.0f}m"
    print(f"{r['run_name']:<40} {params_str:>10} {psnr_str:>10} {ssim_str:>10} "
          f"{epochs_str:>8} {mins_str:>8}")

## 6. Architecture Comparison Chart

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from datetime import datetime
from pathlib import Path

# Select runs once, requiring both metrics, so models/psnrs/ssims/param_counts
# stay index-aligned (separate filters could drop different runs per list).
plotted = [
    r for r in results
    if r.get('best_psnr') is not None and r.get('best_ssim') is not None
]
models = [r['model'] for r in plotted]
psnrs = [r['best_psnr'] for r in plotted]
ssims = [r['best_ssim'] for r in plotted]
param_counts = [r['parameters'] / 1e6 for r in plotted]

fig, axes = plt.subplots(1, 3, figsize=(18, 6))

# Cycle the palette so the colour list always has one entry per run
# (scatter's `c=` rejects a list of the wrong length).
palette = ['#2196F3', '#4CAF50']
colors = [palette[i % len(palette)] for i in range(len(models))]

# PSNR comparison
ax = axes[0]
bars = ax.bar(models, psnrs, color=colors)
ax.set_ylabel('PSNR (dB)', fontsize=12)
ax.set_title('PSNR by Architecture', fontsize=14)
ax.axhline(y=25, color='r', linestyle='--', alpha=0.5, label='Target')
for bar, val in zip(bars, psnrs):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.2,
            f'{val:.2f}', ha='center', fontsize=10)
ax.legend()
ax.grid(axis='y', alpha=0.3)

# SSIM comparison
ax = axes[1]
bars = ax.bar(models, ssims, color=colors)
ax.set_ylabel('SSIM', fontsize=12)
ax.set_title('SSIM by Architecture', fontsize=14)
ax.axhline(y=0.85, color='r', linestyle='--', alpha=0.5, label='Target')
for bar, val in zip(bars, ssims):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
            f'{val:.4f}', ha='center', fontsize=10)
ax.legend()
ax.grid(axis='y', alpha=0.3)

# PSNR vs Parameters (efficiency)
ax = axes[2]
ax.scatter(param_counts, psnrs, c=colors, s=150, zorder=5)
for m, p, psnr in zip(models, param_counts, psnrs):
    ax.annotate(m, (p, psnr), textcoords='offset points',
                xytext=(5, 5), fontsize=10)
ax.set_xlabel('Parameters (M)', fontsize=12)
ax.set_ylabel('PSNR (dB)', fontsize=12)
ax.set_title('Quality vs Model Size', fontsize=14)
ax.grid(True, alpha=0.3)

plt.suptitle('Architecture Comparison at 16x Compression', fontsize=16, y=1.02)
plt.tight_layout()

datestamp = datetime.now().strftime('%Y%m%d')
Path('runs').mkdir(exist_ok=True)  # savefig does not create missing directories
save_path = f'runs/architecture_comparison_16x_{datestamp}.png'
plt.savefig(save_path, dpi=150, bbox_inches='tight')
plt.show()

print(f"Saved to: {save_path}")

In [None]:
# Load baseline ratios sweep and append this sweep's results for a combined R-D curve
import json
import math
import matplotlib.pyplot as plt
from datetime import datetime
from pathlib import Path


def _has_rd_metrics(r):
    """Return True when a result row carries both PSNR and SSIM."""
    return r.get('best_psnr') is not None and r.get('best_ssim') is not None


def _plot_rd_panel(ax, ratios, values, resnet_value, resnet_ratio, ticks, fmt,
                   target, target_label, ylabel, title):
    """Draw one rate-distortion panel.

    Plots the baseline curve (*ratios*, *values*) with value labels, the
    optional ResNet point at *resnet_ratio*, and a dashed target line. The
    x-axis is log2-scaled and inverted so higher compression sits on the left.
    """
    ax.plot(ratios, values, 'bo-', markersize=8, linewidth=2, label='Baseline')
    for r, v in zip(ratios, values):
        ax.annotate(format(v, fmt), (r, v), textcoords='offset points',
                    xytext=(0, 10), ha='center', fontsize=9)

    if resnet_value is not None:
        ax.plot(resnet_ratio, resnet_value, 'gs', markersize=10,
                label=f'ResNet @{resnet_ratio:.0f}x')
        ax.annotate(format(resnet_value, fmt), (resnet_ratio, resnet_value),
                    textcoords='offset points', xytext=(10, -5), ha='left',
                    fontsize=9, color='green')

    ax.axhline(y=target, color='r', linestyle='--', alpha=0.5, label=target_label)
    ax.set_xlabel('Compression Ratio', fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.set_title(title, fontsize=14)
    ax.set_xscale('log', base=2)
    ax.set_xticks(ticks)
    ax.set_xticklabels([f'{int(r)}x' for r in ticks])
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.invert_xaxis()


# Load previous baseline ratios sweep (most recently modified file wins)
ratios_json = sorted(Path('runs').glob('sweep_baseline_ratios*.json'), key=lambda p: p.stat().st_mtime)
if ratios_json:
    with open(ratios_json[-1], encoding='utf-8') as f:
        prev_sweep = json.load(f)
    prev_results = prev_sweep['results']
    # Build the labels outside the f-string: a backslash inside an f-string
    # replacement field is a SyntaxError before Python 3.12 (PEP 701).
    ratio_labels = [f"{r['compression_ratio']:.0f}x" for r in prev_results]
    print(f"Loaded previous sweep: {ratios_json[-1].name}")
    print(f"  Ratios: {ratio_labels}")
else:
    prev_results = []
    print("No previous sweep found — plotting 16x only")

# Get this sweep's results (both metrics required, since both panels use them)
baseline_16x = next((r for r in results if r['model'] == 'baseline' and _has_rd_metrics(r)), None)
resnet_16x = next((r for r in results if r['model'] == 'resnet' and _has_rd_metrics(r)), None)

# (ratio, psnr, ssim) triples, filtered once so the three series stay aligned.
bl_points = [
    (r['compression_ratio'], r['best_psnr'], r['best_ssim'])
    for r in prev_results
    if _has_rd_metrics(r)
]
if baseline_16x:
    # This sweep's result supersedes any previous point at the same ratio,
    # which would otherwise duplicate an x tick and zig-zag the curve.
    bl_points = [p for p in bl_points if not math.isclose(p[0], compression_ratio)]
    bl_points.append((compression_ratio, baseline_16x['best_psnr'], baseline_16x['best_ssim']))

# Sort by ratio
bl_points.sort(key=lambda p: p[0])
bl_ratios = [p[0] for p in bl_points]
bl_psnrs = [p[1] for p in bl_points]
bl_ssims = [p[2] for p in bl_points]

# Tick at every baseline ratio, plus the ResNet ratio if it is plotted.
tick_ratios = sorted(set(bl_ratios) | ({compression_ratio} if resnet_16x else set()))

# Plot
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

_plot_rd_panel(ax1, bl_ratios, bl_psnrs,
               resnet_16x['best_psnr'] if resnet_16x else None,
               compression_ratio, tick_ratios, '.1f',
               25, 'Target (25 dB)', 'PSNR (dB)', 'Rate-Distortion: PSNR')
_plot_rd_panel(ax2, bl_ratios, bl_ssims,
               resnet_16x['best_ssim'] if resnet_16x else None,
               compression_ratio, tick_ratios, '.3f',
               0.85, 'Target (0.85)', 'SSIM', 'Rate-Distortion: SSIM')

if tick_ratios:
    suptitle = f'Combined Rate-Distortion ({int(min(tick_ratios))}x → {int(max(tick_ratios))}x)'
else:
    suptitle = 'Combined Rate-Distortion'
plt.suptitle(suptitle, fontsize=16, y=1.02)
plt.tight_layout()

datestamp = datetime.now().strftime('%Y%m%d')
save_path = f'runs/combined_rate_distortion_{datestamp}.png'
plt.savefig(save_path, dpi=150, bbox_inches='tight')
plt.show()

# Print combined table
table_header = f"{'Ratio':<8} {'Model':<12} {'PSNR':>10} {'SSIM':>10}"
print(f"\n{table_header}")
print("-" * len(table_header))
for r, p, s in bl_points:
    ratio_label = f"{int(r)}x"
    print(f"{ratio_label:<8} {'baseline':<12} {p:>10.2f} {s:>10.4f}")
if resnet_16x:
    ratio_label = f"{int(compression_ratio)}x"
    print(f"{ratio_label:<8} {'resnet':<12} {resnet_16x['best_psnr']:>10.2f} {resnet_16x['best_ssim']:>10.4f}")

print(f"\nSaved to: {save_path}")

## 7. Save Results to JSON

In [None]:
import json
from datetime import datetime
from pathlib import Path

# Take a single timestamp so the filename date and the recorded timestamp
# agree even if the cell runs across midnight.
now = datetime.now()
datestamp = now.strftime('%Y%m%d')

output = {
    'sweep_type': 'architecture_comparison_16x',
    'timestamp': now.isoformat(),
    'config': {
        'latent_channels': LATENT_CHANNELS,
        'compression_ratio': compression_ratio,
        'learning_rate': LEARNING_RATE,
        'epochs': EPOCHS,
        'batch_size': BATCH_SIZE,
        'train_subset': TRAIN_SUBSET,
        # Remaining sweep hyperparameters, so the run can be reproduced from
        # this JSON alone.
        'val_fraction': VAL_FRACTION,
        'early_stopping_patience': EARLY_STOPPING_PATIENCE,
        'lr_patience': LR_PATIENCE,
        'lr_factor': LR_FACTOR,
        'mse_weight': MSE_WEIGHT,
        'ssim_weight': SSIM_WEIGHT,
    },
    'results': results,
}

output_path = Path(f'runs/sweep_all_16x_{datestamp}.json')
output_path.parent.mkdir(parents=True, exist_ok=True)
# default=str: result rows may hold values json can't encode natively
# (e.g. numpy/torch scalars); they are written as their string form.
with open(output_path, 'w', encoding='utf-8') as f:
    json.dump(output, f, indent=2, default=str)

print(f"Results saved to: {output_path}")

---

## Done!

**Next steps:**
1. View TensorBoard: `tensorboard --logdir=runs`
2. Compare with baseline rate-distortion curves from `sweep_baseline_ratios.ipynb`
3. Pick the best architecture and sweep its compression ratios