# Day 2: Model Implementation (No References)

**Time:** 4-5 hours

This version has NO reference implementations. You'll need to figure it out from:
- The docstrings and hints
- The test cells (they tell you expected behavior)
- Your Day 1 knowledge

## Architecture Overview

```
ENCODER (compresses):
256×256×1 → 128×128×64 → 64×64×128 → 32×32×256 → 16×16×64

DECODER (reconstructs):
16×16×64 → 32×32×256 → 64×64×128 → 128×128×64 → 256×256×1
```

---

## Setup

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from tqdm import tqdm

# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


---

# Part 1: Encoder (45 minutes)

## 1.1 ConvBlock

Build a block that does: **Conv2d → BatchNorm → LeakyReLU**

Hints:
- `nn.Conv2d(in_ch, out_ch, kernel_size, stride, padding, bias=...)`
- `nn.BatchNorm2d(num_features)`
- `nn.LeakyReLU(negative_slope)`
- If using BatchNorm, set `bias=False` in Conv2d (BN has its own bias)

In [2]:
class ConvBlock(nn.Module):
    """
    Downsampling unit: Conv2d → BatchNorm2d (optional) → LeakyReLU(0.2).

    With the defaults (kernel_size=5, stride=2, padding=2) the spatial size
    is halved: H_out = floor((H - 1) / 2) + 1, e.g. 256 → 128.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution.
        kernel_size, stride, padding: forwarded to nn.Conv2d.
        use_bn: when True a BatchNorm2d follows the conv and the conv bias is
            dropped (BN's learned shift makes it redundant); when False the
            conv keeps its bias and nn.Identity stands in for BN.
    """

    def __init__(self, in_channels, out_channels, kernel_size=5,
                 stride=2, padding=2, use_bn=True):
        super().__init__()
        conv_has_bias = not use_bn
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=conv_has_bias,
        )
        if use_bn:
            self.bn = nn.BatchNorm2d(out_channels)
        else:
            self.bn = nn.Identity()
        self.activation = nn.LeakyReLU(negative_slope=0.2)

    def forward(self, x):
        return self.activation(self.bn(self.conv(x)))

In [3]:
# TEST
def test_conv_block():
    block = ConvBlock(1, 64)
    x = torch.randn(2, 1, 256, 256)
    y = block(x)
    assert y.shape == (2, 64, 128, 128), f"Expected (2,64,128,128), got {y.shape}"
    print(f"✓ ConvBlock: {x.shape} → {y.shape}")

test_conv_block()

✓ ConvBlock: torch.Size([2, 1, 256, 256]) → torch.Size([2, 64, 128, 128])


## 1.2 SAREncoder

Stack 4 layers: three ConvBlocks followed by a plain Conv2d as the last layer (no BN, no activation).

```
Layer 1: 1 → 64 channels
Layer 2: 64 → 128 channels  
Layer 3: 128 → 256 channels
Layer 4: 256 → latent_channels (just Conv2d, no activation)
```

In [4]:
class SAREncoder(nn.Module):
    """
    Encoder: 256×256×1 → 16×16×latent_channels

    Four stride-2 convolutions (kernel 5, padding 2), each halving H and W:
      layer1-3: ConvBlock (Conv2d → BatchNorm → LeakyReLU(0.2)) with widths
                base_channels, 2·base_channels, 4·base_channels
      layer4:   plain Conv2d to latent_channels with no BN/activation, so
                the latent is unbounded.

    Args:
        in_channels: channels of the input image (default 1).
        latent_channels: channels of the latent map.
        base_channels: width of the first stage.
        use_bn: whether layers 1-3 use BatchNorm.
    """

    # Negative slope of the LeakyReLU used inside ConvBlock. The He init
    # must be told this slope to compute the matching gain.
    _LEAKY_SLOPE = 0.2

    def __init__(self, in_channels=1, latent_channels=64, base_channels=64, use_bn=True):
        super().__init__()
        self.in_channels = in_channels
        self.latent_channels = latent_channels
        channels = [in_channels, base_channels, base_channels * 2, base_channels * 4]

        self.layer1 = ConvBlock(
            channels[0], channels[1],
            kernel_size=5, stride=2, padding=2, use_bn=use_bn
        )
        self.layer2 = ConvBlock(
            channels[1], channels[2],
            kernel_size=5, stride=2, padding=2, use_bn=use_bn
        )
        self.layer3 = ConvBlock(
            channels[2], channels[3],
            kernel_size=5, stride=2, padding=2, use_bn=use_bn
        )
        # Final projection: no BN / activation so the latent stays unbounded.
        self.layer4 = nn.Conv2d(
            channels[3], latent_channels,
            kernel_size=5, stride=2, padding=2
        )

        self._initialise_weights()

    def _initialise_weights(self):
        """He-normal init for convs (LeakyReLU gain), zero biases, identity BN."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # `a` is the LeakyReLU negative slope; leaving it at the
                # default 0 would apply the plain-ReLU gain sqrt(2).
                nn.init.kaiming_normal_(m.weight, a=self._LEAKY_SLOPE,
                                        mode="fan_out", nonlinearity="leaky_relu")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        """Map (B, in_channels, H, W) to (B, latent_channels, H/16, W/16)."""
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x

    def calcreceptivefield(self):
        """Receptive field, in input pixels, of one latent cell.

        Uses RF = 1 + Σ_i (k_i − 1) · Π_{j<i} s_j over the four convolutions,
        reading kernel sizes and strides from the layers themselves, so the
        result stays correct if the layer hyper-parameters change (61 for
        the defaults).
        """
        convs = (self.layer1.conv, self.layer2.conv, self.layer3.conv, self.layer4)
        rf, jump = 1, 1
        for conv in convs:
            rf += (conv.kernel_size[0] - 1) * jump
            jump *= conv.stride[0]
        return rf

In [5]:
# TEST
def test_encoder():
    encoder = SAREncoder(latent_channels=64)
    x = torch.randn(2, 1, 256, 256)
    z = encoder(x)
    print(f"\nArchitecture:")
    print(f"  Input shape:  {tuple(x.shape)}")
    print(f"  Output shape: {tuple(z.shape)}")
    print(f"  Receptive field: {encoder.calcreceptivefield()}")

    assert z.shape == (2, 64, 16, 16), f"Expected (2,64,16,16), got {z.shape}"
    print(f"✓ Encoder: {x.shape} → {z.shape}")
    
    # Check gradients
    z.mean().backward()
    grad_norms = []
    for name, param in encoder.named_parameters():
        if param.grad is not None:
            grad_norms.append((name, param.grad.norm().item()))
    print(f"\nGradient norms (should be non-zero, similar magnitude):")
    for name, norm in grad_norms[:4]:  # First 4
        print(f"  {name}: {norm:.6f}")
    print("✓ Gradients flow")


    
    params = sum(p.numel() for p in encoder.parameters())
    print(f"✓ Parameters: {params:,}")

test_encoder()


Architecture:
  Input shape:  (2, 1, 256, 256)
  Output shape: (2, 64, 16, 16)
  Receptive field: 61
✓ Encoder: torch.Size([2, 1, 256, 256]) → torch.Size([2, 64, 16, 16])

Gradient norms (should be non-zero, similar magnitude):
  layer1.conv.weight: 0.371410
  layer1.bn.weight: 0.011589
  layer1.bn.bias: 0.008325
  layer2.conv.weight: 0.446748
✓ Gradients flow
✓ Parameters: 1,436,160


---

# Part 2: Decoder (45 minutes)

## 2.1 DeconvBlock

Like ConvBlock but uses `nn.ConvTranspose2d` to upsample.

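Key difference: with kernel 5, stride 2 and padding 2, `ConvTranspose2d` produces $H_{out} = (H-1)\cdot 2 - 2\cdot 2 + 5 = 2H - 1$, so `output_padding=1` is needed to get exact 2× upsampling.

Use ReLU (not LeakyReLU) for the decoder.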

In [6]:
class DeconvBlock(nn.Module):
    """
    Upsampling unit: ConvTranspose2d → BatchNorm2d (optional) → ReLU.

    With the defaults (kernel 5, stride 2, padding 2, output_padding 1):
        H_out = (H - 1)·2 - 2·2 + 5 + 1 = 2H
    so every call doubles the spatial size.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the transposed convolution.
        kernel_size, stride, padding, output_padding: forwarded to
            nn.ConvTranspose2d.
        use_bn: add BatchNorm2d after the deconv and drop the deconv bias
            (BN's shift makes it redundant); otherwise keep the bias and use
            nn.Identity in place of BN.
    """

    def __init__(self, in_channels, out_channels, kernel_size=5,
                 stride=2, padding=2, output_padding=1, use_bn=True):
        super().__init__()
        deconv_kwargs = dict(
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
        )
        self.deconv = nn.ConvTranspose2d(
            in_channels, out_channels, bias=not use_bn, **deconv_kwargs
        )
        if use_bn:
            self.bn = nn.BatchNorm2d(out_channels)
        else:
            self.bn = nn.Identity()
        self.activation = nn.ReLU()

    def forward(self, x):
        for stage in (self.deconv, self.bn, self.activation):
            x = stage(x)
        return x


In [7]:
# TEST
def test_deconv_block():
    block = DeconvBlock(64, 32)
    x = torch.randn(2, 64, 16, 16)
    y = block(x)
    assert y.shape == (2, 32, 32, 32), f"Expected (2,32,32,32), got {y.shape}"
    print(f"✓ DeconvBlock: {x.shape} → {y.shape}")

test_deconv_block()

✓ DeconvBlock: torch.Size([2, 64, 16, 16]) → torch.Size([2, 32, 32, 32])


## 2.2 SARDecoder

Mirror the encoder. The last layer is a plain ConvTranspose2d (no BN, no ReLU) followed by a sigmoid, so the output lies in [0, 1].

In [8]:
class SARDecoder(nn.Module):
    """
    Decoder: 16×16×latent → 256×256×1

    Mirrors SAREncoder with four ×2 transposed convolutions:
      layer1-3: DeconvBlock (ConvTranspose2d → BN → ReLU) with widths
                4·base_channels, 2·base_channels, base_channels
      layer4:   plain ConvTranspose2d to out_channels,
    followed by a sigmoid so the output lies in [0, 1].

    Args:
        out_channels: channels of the reconstructed image (default 1).
        latent_channels: channels of the incoming latent map.
        base_channels: width of the last hidden stage.
        use_bn: whether layers 1-3 use BatchNorm.
    """

    def __init__(self, out_channels=1, latent_channels=64, base_channels=64, use_bn = True):
        super().__init__()
        self.out_channels = out_channels
        self.latent_channels = latent_channels

        def upsample(c_in, c_out):
            # Shared hyper-parameters of the three hidden ×2 stages.
            return DeconvBlock(c_in, c_out, kernel_size=5, stride=2, padding=2,
                               output_padding=1, use_bn=use_bn)

        self.layer1 = upsample(latent_channels, base_channels * 4)
        self.layer2 = upsample(base_channels * 4, base_channels * 2)
        self.layer3 = upsample(base_channels * 2, base_channels)
        # Output projection: no BN / ReLU; the sigmoid in forward() bounds it.
        self.layer4 = nn.ConvTranspose2d(
            base_channels, out_channels,
            kernel_size=5, stride=2, padding=2, output_padding=1
        )

        self._initialize_weights()

    def _initialize_weights(self):
        """He-normal (ReLU gain) transposed-conv weights, zero biases, identity BN."""
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
                continue
            if not isinstance(module, nn.ConvTranspose2d):
                continue
            nn.init.kaiming_normal_(module.weight, mode='fan_out',
                                    nonlinearity='relu')
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(self, z):
        """Upsample (B, latent, h, w) by 16× and squash to [0, 1]."""
        h = z
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            h = stage(h)
        return torch.sigmoid(h)

In [9]:
# TEST
def test_decoder():
    decoder = SARDecoder(latent_channels=64)
    z = torch.randn(2, 64, 16, 16)
    x_hat = decoder(z)
    
    assert x_hat.shape == (2, 1, 256, 256), f"Expected (2,1,256,256), got {x_hat.shape}"
    print(f"✓ Decoder: {z.shape} → {x_hat.shape}")
    
    assert x_hat.min() >= 0 and x_hat.max() <= 1, "Output should be [0,1]"
    print(f"✓ Output range: [{x_hat.min():.3f}, {x_hat.max():.3f}]")

test_decoder()

✓ Decoder: torch.Size([2, 64, 16, 16]) → torch.Size([2, 1, 256, 256])
✓ Output range: [0.060, 0.953]


---

# Part 3: Autoencoder (30 minutes)

In [10]:
class SARAutoencoder(nn.Module):
    """
    Complete autoencoder: SAREncoder followed by SARDecoder.

    forward() returns (x_hat, z)

    Args:
        latent_channels: channels of the latent map (controls compression).
        base_channels: width of the first encoder stage; later stages use
            2× and 4× this value, mirrored in the decoder.
        use_bn: forwarded to both encoder and decoder.
    """

    # Number of stride-2 convolutions in the encoder (each halves H and W).
    _NUM_DOWNSAMPLES = 4

    def __init__(self, latent_channels=64, base_channels=64, use_bn=True):
        super().__init__()
        self.latent_channels = latent_channels
        self.base_channels = base_channels

        self.encoder = SAREncoder(
            in_channels=1,
            latent_channels=latent_channels,
            base_channels=base_channels,
            use_bn=use_bn
        )

        self.decoder = SARDecoder(
            out_channels=1,
            latent_channels=latent_channels,
            base_channels=base_channels,
            use_bn=use_bn
        )

    def forward(self, x):
        """Returns (reconstruction, latent)"""
        z = self.encoder(x)
        x_hat = self.decoder(z)
        return x_hat, z

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Run only the encoder."""
        return self.encoder(x)

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Run only the decoder."""
        return self.decoder(z)

    def get_compression_ratio(self, input_size: int = 256):
        """Ratio of input elements to latent elements for a square 1-channel input.

        With the default 256×256 input this is 256² / (16² × latent_channels).
        """
        channels, height, width = self.get_latent_size(input_size)
        return (input_size * input_size * 1) / (height * width * channels)

    def get_latent_size(self, input_size: int = 256):
        """Return (C, H, W) of the latent for a square input of side input_size.

        Each encoder conv (kernel 5, stride 2, padding 2) maps a side n to
        (n + 1) // 2, i.e. it rounds up. The latent side is therefore
        ceil(input_size / 16), not input_size // 16, which under-reports for
        sizes that are not multiples of 16 (e.g. 250 → 16, not 15).
        """
        side = input_size
        for _ in range(self._NUM_DOWNSAMPLES):
            side = (side + 1) // 2
        return (self.latent_channels, side, side)

    def count_parameters(self):
        """Count parameters in encoder and decoder."""
        encoder_params = sum(p.numel() for p in self.encoder.parameters())
        decoder_params = sum(p.numel() for p in self.decoder.parameters())
        return {
            'encoder': encoder_params,
            'decoder': decoder_params,
            'total': encoder_params + decoder_params
        }

    def analyze_latent(self, z: torch.Tensor):
        """
        Analyze latent representation statistics.

        Useful for monitoring training and diagnosing issues.

        Returns a dict of Python floats: global mean/std/min/max, 'sparsity'
        (fraction of entries with |z| < 0.01) and 'channel_stds' (std of
        each of the first 8 channels, indexing z as (B, C, H, W)).
        """
        # detach() so latents that are still part of an autograd graph can
        # be converted to NumPy regardless of the current grad mode.
        z_np = z.detach().cpu().numpy()

        return {
            'mean': float(np.mean(z_np)),
            'std': float(np.std(z_np)),
            'min': float(np.min(z_np)),
            'max': float(np.max(z_np)),
            'sparsity': float(np.mean(np.abs(z_np) < 0.01)),  # Near-zero fraction
            'channel_stds': [float(np.std(z_np[:, c, :, :]))
                             for c in range(min(z_np.shape[1], 8))],
        }

In [11]:
# TEST
def test_autoencoder():
    configs = [
        {'latent_channels': 64, 'name': 'Standard (4x compression)'},
        {'latent_channels': 32, 'name': 'High compression (8x)'},
        {'latent_channels': 128, 'name': 'Low compression (2x)'},
    ]
    for config in configs:
        print(f"\n--- {config['name']} ---")
        
        model = SARAutoencoder(latent_channels=config['latent_channels'])
        x = torch.randn(2, 1, 256, 256)
        x = torch.sigmoid(x)
    
        x_hat, z = model(x)
    
        assert x_hat.shape == x.shape, f"Shape mismatch: {x_hat.shape} vs {x.shape}"
        expected_latent = model.get_latent_size(256)
        actual_latent = tuple(z.shape[1:])
        assert actual_latent == expected_latent, f"Latent shape mismatch"

        params = model.count_parameters()
        compression = model.get_compression_ratio()
        latent_stats = model.analyze_latent(z)

        print(f"  Compression ratio: {compression:.1f}x")
        print(f"  Parameters: {params['total']:,}")
        print(f"    Encoder: {params['encoder']:,}")
        print(f"    Decoder: {params['decoder']:,}")
        print(f"  Latent shape: {tuple(z.shape)}")
        print(f"  Latent mean: {latent_stats['mean']:.4f}")
        print(f"  Latent std: {latent_stats['std']:.4f}")
        print(f"  Output range: [{x_hat.min():.3f}, {x_hat.max():.3f}]")

    print("\n--- Gradient Flow Test ---")
    model = SARAutoencoder(latent_channels=64)
    x = torch.randn(2, 1, 256, 256, requires_grad=True)
    x_hat, z = model(x)
    loss = F.mse_loss(x_hat, torch.zeros_like(x_hat))
    loss.backward()

    for name, param in model.named_parameters():
        if param.grad is not None:
            grad_norm = param.grad.norm().item()
            if grad_norm == 0:
                print(f"  WARNING: Zero gradient in {name}")
            elif grad_norm > 100:
                print(f"  WARNING: Large gradient in {name}: {grad_norm:.2f}")
    
    print("\n✓ All autoencoder tests passed!")
    
    # print(f"✓ Autoencoder: {x.shape} → {z.shape} → {x_hat.shape}")
    # print(f"✓ Compression: {model.get_compression_ratio():.1f}x")
    # print(f"✓ Parameters: {sum(p.numel() for p in model.parameters()):,}")

test_autoencoder()


--- Standard (4x compression) ---
  Compression ratio: 4.0x
  Parameters: 2,872,257
    Encoder: 1,436,160
    Decoder: 1,436,097
  Latent shape: (2, 64, 16, 16)
  Latent mean: 0.1596
  Latent std: 1.9508
  Output range: [0.098, 0.918]

--- High compression (8x) ---
  Compression ratio: 8.0x
  Parameters: 2,462,625
    Encoder: 1,231,328
    Decoder: 1,231,297
  Latent shape: (2, 32, 16, 16)
  Latent mean: -0.2214
  Latent std: 2.8458
  Output range: [0.153, 0.964]

--- Low compression (2x) ---
  Compression ratio: 2.0x
  Parameters: 3,691,521
    Encoder: 1,845,824
    Decoder: 1,845,697
  Latent shape: (2, 128, 16, 16)
  Latent mean: -0.0645
  Latent std: 1.3746
  Output range: [0.096, 0.929]

--- Gradient Flow Test ---

✓ All autoencoder tests passed!


---

# Part 4: Loss Function (30 minutes)

## 4.1 SSIM Loss

Steps:
1. Create Gaussian window (for local averaging)
2. Compute local means: μx, μy
3. Compute local variances: σx², σy² = E[X²] - E[X]²
4. Compute covariance: σxy = E[XY] - E[X]E[Y]
5. Apply SSIM formula
6. Return 1 - mean(SSIM)

In [12]:
class SSIMLoss(nn.Module):
    """
    SSIM Loss: returns 1 - SSIM

    Lower is better (0 = identical images).

    Local means, variances and covariance are Gaussian-weighted averages
    computed with a depthwise conv (one window per channel). Zero padding of
    window_size // 2 keeps the SSIM map the same size as the input when
    window_size is odd.

    Args:
        window_size: side length of the Gaussian window; should be odd.
        sigma: standard deviation of the Gaussian, in pixels.
        data_range: value range of the images (1.0 for [0, 1] data). Sets the
            stabilising constants C1 = (0.01·L)² and C2 = (0.03·L)².
        channel: number of image channels forward() expects.
    """

    def __init__(self, window_size=11, sigma=1.5, data_range=1.0, channel=1):
        super().__init__()

        self.window_size = window_size
        self.sigma = sigma
        self.data_range = data_range
        self.channel = channel
        self.C1 = (0.01 * data_range) ** 2
        self.C2 = (0.03 * data_range) ** 2

        # Buffer (not a parameter): follows .to(device) but is never trained.
        self.register_buffer('window', self._create_window(window_size, sigma, channel))

    def _create_window(self, window_size, sigma, channel):
        """Build a normalised 2-D Gaussian of shape (channel, 1, window_size, window_size)."""
        coords = torch.arange(window_size, dtype=torch.float32)
        coords -= window_size // 2
        g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
        g /= g.sum()

        # Outer product of the normalised 1-D kernel with itself also sums to 1.
        window = g.unsqueeze(1) @ g.unsqueeze(0)

        # Expand to (channel, 1, window_size, window_size) for a grouped conv.
        window = window.unsqueeze(0).unsqueeze(0)
        window = window.expand(channel, 1, window_size, window_size).contiguous()

        return window

    def _local_mean(self, t, window):
        """Gaussian-weighted local average of t, computed per channel."""
        return F.conv2d(t, window, padding=self.window_size // 2, groups=self.channel)

    def forward(self, x_hat, x):
        """
        Compute 1 - mean(SSIM map) between x_hat and x.

        Args:
            x_hat: reconstruction of shape (B, channel, H, W).
            x: reference image with the same shape as x_hat.

        Returns:
            Scalar tensor, ≈ 0 when x_hat equals x.
        """
        # The buffer is float32; cast it to the input's dtype/device so that
        # float64 (or half-precision) inputs do not hit a conv2d dtype mismatch.
        window = self.window.to(device=x.device, dtype=x.dtype)

        mu_x = self._local_mean(x, window)
        mu_y = self._local_mean(x_hat, window)

        mu_x_sq = mu_x ** 2
        mu_y_sq = mu_y ** 2
        mu_xy = mu_x * mu_y

        # Local (co)variances via E[XY] - E[X]E[Y]
        sigma_x_sq = self._local_mean(x ** 2, window) - mu_x_sq
        sigma_y_sq = self._local_mean(x_hat ** 2, window) - mu_y_sq
        sigma_xy = self._local_mean(x * x_hat, window) - mu_xy

        # SSIM formula
        numerator = (2 * mu_xy + self.C1) * (2 * sigma_xy + self.C2)
        denominator = (mu_x_sq + mu_y_sq + self.C1) * (sigma_x_sq + sigma_y_sq + self.C2)

        # Extra epsilon guards the division beyond what C1/C2 already provide.
        ssim_map = numerator / (denominator + 1e-8)

        # Return 1 - mean SSIM as loss
        return 1 - ssim_map.mean()

In [13]:
# TEST
def test_ssim():
    ssim_loss = SSIMLoss()
    x = torch.rand(2, 1, 64, 64)
    
    loss_same = ssim_loss(x, x)
    loss_diff = ssim_loss(torch.rand_like(x), x)
    
    assert loss_same < 0.01, f"Same images should have ~0 loss, got {loss_same:.4f}"
    assert loss_diff > loss_same, "Different images should have higher loss"
    print(f"✓ SSIM(same): {loss_same:.6f}")
    print(f"✓ SSIM(diff): {loss_diff:.4f}")

test_ssim()


✓ SSIM(same): 0.000000
✓ SSIM(diff): 0.9300


In [14]:
import sys

# Assumes the notebook runs from <project>/learningnotebooks/phase4_sar_codec/,
# so the project root is two directory levels up.
project_root = Path.cwd().parents[1]
print(project_root)
src_dir = str(project_root / "src")
# Guard so re-running this cell does not stack duplicate sys.path entries.
if src_dir not in sys.path:
    sys.path.insert(0, src_dir)
from losses.mse import MSELoss

d:\Projects\CNNAutoencoderProject


## 4.2 Combined Loss

In [15]:
class CombinedLoss(nn.Module):
    """
    Weighted sum of pixel MSE and structural dissimilarity:

        loss = mse_weight * MSE + ssim_weight * (1 - SSIM)

    forward(x_hat, x) returns (loss_tensor, metrics_dict). metrics_dict holds
    plain floats under 'loss', 'mse', 'ssim' (the SSIM similarity, 1 - SSIM
    loss) and 'psnr' (10·log10(1 / MSE), i.e. assuming a data range of 1).

    Args:
        mse_weight: weight of the MSE term.
        ssim_weight: weight of the (1 - SSIM) term.
        window_size: Gaussian window size passed to SSIMLoss.
    """

    def __init__(self, mse_weight=1.0, ssim_weight=0.1, window_size = 11):
        super().__init__()
        self.mse_weight = mse_weight
        self.ssim_weight = ssim_weight

        self.mse_loss = MSELoss()
        self.ssim_loss = SSIMLoss(window_size=window_size)

    def forward(self, x_hat, x):
        mse = self.mse_loss(x_hat, x)
        ssim_l = self.ssim_loss(x_hat, x)

        mse_term = self.mse_weight * mse
        ssim_term = self.ssim_weight * ssim_l
        loss = mse_term + ssim_term

        # Logging-only quantities: keep them out of the autograd graph.
        with torch.no_grad():
            # The 1e-10 keeps PSNR finite (100 dB) for a perfect reconstruction.
            psnr = 10 * torch.log10(1.0 / (mse + 1e-10))
            ssim = 1 - ssim_l

        return loss, {
            'loss': loss.item(),
            'mse': mse.item(),
            'ssim': ssim.item(),
            'psnr': psnr.item(),
        }

In [16]:
# TEST
def test_combined_loss():
    loss_fn = CombinedLoss()
    x = torch.rand(2, 1, 64, 64)
    x_noisy = (x + 0.1 * torch.randn_like(x)).clamp(0, 1)
    
    loss, metrics = loss_fn(x_noisy, x)
    
    assert all(k in metrics for k in ['loss', 'mse', 'ssim', 'psnr'])
    print(f"✓ Loss: {metrics['loss']:.4f}")
    print(f"✓ PSNR: {metrics['psnr']:.2f} dB")
    print(f"✓ SSIM: {metrics['ssim']:.4f}")

test_combined_loss()

✓ Loss: 0.0140
✓ PSNR: 20.50 dB
✓ SSIM: 0.9493


In [17]:
class EdgePreservingLoss(nn.Module):
    """
    MSE + SSIM loss with an extra penalty on Sobel edge-magnitude differences:

        loss = mse_weight * MSE
               + ssim_weight * (1 - SSIM)
               + edge_weight * MSE(|∇x_hat|, |∇x|)

    |∇·| is the Sobel gradient magnitude. The Sobel kernels have shape
    (1, 1, 3, 3), so inputs are expected to be single-channel (B, 1, H, W).

    forward() returns (loss_tensor, metrics_dict) with float entries 'loss',
    'mse', 'ssim' (similarity), 'edge' and 'psnr' (assumes data range 1).
    """

    def __init__(self, mse_weight: float = 1.0, ssim_weight: float = 0.1,
                 edge_weight: float = 0.1):
        super().__init__()

        self.mse_weight = mse_weight
        self.ssim_weight = ssim_weight
        self.edge_weight = edge_weight

        self.mse_loss = MSELoss()
        self.ssim_loss = SSIMLoss()

        # Horizontal-gradient Sobel kernel; its transpose is the vertical one.
        kernel = torch.tensor([[-1, 0, 1],
                               [-2, 0, 2],
                               [-1, 0, 1]], dtype=torch.float32)
        self.register_buffer('sobel_x', kernel.view(1, 1, 3, 3))
        self.register_buffer('sobel_y', kernel.t().contiguous().view(1, 1, 3, 3))

    def _compute_edges(self, x: torch.Tensor):
        """Sobel gradient magnitude; the 1e-8 avoids an infinite sqrt gradient at 0."""
        grad_x = F.conv2d(x, self.sobel_x, padding=1)
        grad_y = F.conv2d(x, self.sobel_y, padding=1)
        return torch.sqrt(grad_x ** 2 + grad_y ** 2 + 1e-8)

    def forward(self, x_hat: torch.Tensor, x: torch.Tensor):
        """Return (weighted loss, metrics dict) for reconstruction x_hat of x."""
        mse = self.mse_loss(x_hat, x)
        ssim_l = self.ssim_loss(x_hat, x)

        target_edges = self._compute_edges(x)
        recon_edges = self._compute_edges(x_hat)
        edge_loss = F.mse_loss(recon_edges, target_edges)

        loss = (self.mse_weight * mse
                + self.ssim_weight * ssim_l
                + self.edge_weight * edge_loss)

        with torch.no_grad():
            psnr = 10 * torch.log10(1.0 / (mse + 1e-10))

        return loss, {
            'loss': loss.item(),
            'mse': mse.item(),
            'ssim': (1 - ssim_l).item(),
            'edge': edge_loss.item(),
            'psnr': psnr.item(),
        }

In [18]:
def test_losses():
    """Compare every loss on perfect, noisy and blurred reconstructions.

    Prints each loss per reconstruction type and asserts that a perfect
    reconstruction never scores worse than a degraded one.
    """
    print("=" * 60)
    print("LOSS FUNCTION TESTS")
    print("=" * 60)

    # Create test images
    x = torch.rand(4, 1, 64, 64)

    # Perfect reconstruction
    x_hat_perfect = x.clone()

    # Noisy reconstruction
    x_hat_noisy = x + 0.1 * torch.randn_like(x)
    x_hat_noisy = x_hat_noisy.clamp(0, 1)

    # Blurry reconstruction
    x_hat_blur = F.avg_pool2d(x, 3, stride=1, padding=1)

    # Test each loss
    losses_to_test = [
        ('MSE', MSELoss()),
        ('SSIM', SSIMLoss()),
        ('Combined', CombinedLoss()),
        ('EdgePreserving', EdgePreservingLoss()),
    ]
    reconstructions = [
        ('Perfect', x_hat_perfect),
        ('Noisy', x_hat_noisy),
        ('Blurry', x_hat_blur),
    ]

    print("\nLoss values for different reconstruction types:")
    print("-" * 60)

    for loss_name, loss_fn in losses_to_test:
        print(f"\n{loss_name}:")
        values = {}

        for recon_name, x_hat in reconstructions:
            out = loss_fn(x_hat, x)
            # Composite losses return (loss, metrics); plain ones a tensor.
            if isinstance(out, tuple):
                loss, metrics = out
                psnr = metrics.get('psnr')
                psnr_str = 'N/A' if psnr is None else f"{psnr:.2f} dB"
                print(f"  {recon_name}: loss={metrics['loss']:.4f}, psnr={psnr_str}")
            else:
                loss = out
                print(f"  {recon_name}: loss={loss.item():.4f}")
            values[recon_name] = loss.item()

        best_degraded = min(values['Noisy'], values['Blurry'])
        assert values['Perfect'] <= best_degraded, (
            f"{loss_name}: perfect reconstruction scored {values['Perfect']:.4f}, "
            f"worse than a degraded one ({best_degraded:.4f})"
        )

    print("\n✓ Loss function tests passed!")

In [19]:
# Run the loss-function comparison on synthetic reconstructions.
test_losses()

LOSS FUNCTION TESTS

Loss values for different reconstruction types:
------------------------------------------------------------

MSE:
  Perfect: loss=0.0000
  Noisy: loss=0.0090
  Blurry: loss=0.0758

SSIM:
  Perfect: loss=0.0000
  Noisy: loss=0.0513
  Blurry: loss=0.7742

Combined:
  Perfect: loss=0.0000, psnr=100.0
  Noisy: loss=0.0141, psnr=20.47527313232422
  Blurry: loss=0.1532, psnr=11.203914642333984

EdgePreserving:
  Perfect: loss=0.0000, psnr=100.0
  Noisy: loss=0.0237, psnr=20.47527313232422
  Blurry: loss=0.2404, psnr=11.203914642333984

✓ Loss function tests passed!


---

# Part 5: Training (60 minutes)

## Dataset (provided)

In [20]:
import random

In [21]:
class SARPatchDataset(Dataset):
    """
    Map-style dataset over an in-memory stack of single-channel patches.

    Args:
        patches: NumPy array of shape (N, H, W) with values in [0, 1]; a
            float32 copy is stored.
        augment: when True, __getitem__ applies a random flip/rotation.
        normalisestats: kept on the instance; not used by this class.

    Each item is a float32 tensor of shape (1, H, W).
    """

    def __init__(self, patches, augment=True, normalisestats = None):
        self.patches = patches.astype(np.float32)
        self.augment = augment
        self.normalisestats = normalisestats

        ndim = len(patches.shape)
        assert ndim == 3, f"Expected (N, H, W), got {patches.shape}"
        lo, hi = patches.min(), patches.max()
        assert lo >= 0 and hi <= 1, \
            f"Expected [0,1] range, got [{lo}, {hi}]"

    def __len__(self):
        return len(self.patches)

    def __getitem__(self, idx):
        sample = self.patches[idx].copy()
        if self.augment:
            sample = self._augment(sample)
        # (H, W) -> (1, H, W): prepend the channel axis expected by Conv2d.
        return torch.from_numpy(sample)[None]

    def _augment(self, patch: np.ndarray) -> np.ndarray:
        """
        Apply a random element of the flip/rotation group to a 2-D patch.

        For SAR, safe augmentations are:
        - Horizontal flip
        - Vertical flip
        - 90° rotations

        Draws, in order: horizontal-flip coin, vertical-flip coin, number of
        quarter turns in {0, 1, 2, 3}.
        """
        hflip = random.random() > 0.5
        vflip = random.random() > 0.5
        quarter_turns = random.randint(0, 3)

        if hflip:
            patch = np.fliplr(patch)
        if vflip:
            patch = np.flipud(patch)
        if quarter_turns:
            patch = np.rot90(patch, quarter_turns)

        # Materialise the strided views so torch.from_numpy gets a plain array.
        return np.ascontiguousarray(patch)

In [22]:
class SARDataModule:
    """
    Data module for managing train/validation data.

    Handles:
    - Loading patches from disk (a single .npy array, loaded fully in __init__)
    - Splitting into train/validation with a seeded, reproducible shuffle
    - Creating DataLoaders

    Args:
        patches_path: path to a .npy file of patches (expected (N, H, W) in
            [0, 1], as required by SARPatchDataset).
        val_fraction: fraction of patches held out for validation, in [0, 1).
        batch_size: batch size for both loaders.
        num_workers: DataLoader worker processes.
        augment_train: apply random flips/rotations to training patches.
        seed: seed for the train/validation split.

    Raises:
        ValueError: if val_fraction is outside [0, 1).
    """

    def __init__(self, patches_path, val_fraction=0.1, batch_size=16, num_workers=4,
                 augment_train=True, seed=42):
        # A negative fraction would silently mis-split via negative slicing;
        # a fraction >= 1 leaves no training data.
        if not 0 <= val_fraction < 1:
            raise ValueError(f"val_fraction must be in [0, 1), got {val_fraction}")

        self.batch_size = batch_size
        self.num_workers = num_workers
        # Pinned host memory only helps host→GPU copies; skip it when no CUDA
        # device is available, where PyTorch warns that pinning is unavailable.
        self.pin_memory = torch.cuda.is_available()

        print(f"Loading patches from {patches_path}")
        all_patches = np.load(patches_path)
        print(f"Loaded {len(all_patches)} patches of shape {all_patches.shape[1:]}")

        # Split into train/validation. A private RandomState yields the same
        # permutation as np.random.seed(seed) + np.random.permutation, without
        # reseeding NumPy's global RNG as a side effect.
        rng = np.random.RandomState(seed)
        indices = rng.permutation(len(all_patches))
        val_size = int(len(all_patches) * val_fraction)

        val_indices = indices[:val_size]
        train_indices = indices[val_size:]

        self.train_patches = all_patches[train_indices]
        self.val_patches = all_patches[val_indices]

        print(f"Train: {len(self.train_patches)}, Validation: {len(self.val_patches)}")

        # Create datasets
        self.train_dataset = SARPatchDataset(self.train_patches, augment=augment_train)
        self.val_dataset = SARPatchDataset(self.val_patches, augment=False)

    def train_dataloader(self) -> DataLoader:
        """Get training DataLoader (shuffled, incomplete last batch dropped)."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            drop_last=True,  # Drop incomplete batches for stable BN
        )

    def val_dataloader(self) -> DataLoader:
        """Get validation DataLoader (fixed order, all samples kept)."""
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
        )

    def get_sample_batch(self, split: str = 'train') -> torch.Tensor:
        """Get a sample batch for visualization ('train', anything else → validation)."""
        loader = self.train_dataloader() if split == 'train' else self.val_dataloader()
        return next(iter(loader))

In [23]:
def test_dataset():
    """Smoke-test SARDataModule and SARPatchDataset on synthetic patches.

    The .npy fixture is written to a temporary directory that is removed
    even if loading fails. It is not left behind in the working directory.
    """
    import tempfile

    print("=" * 60)
    print("DATASET TEST")
    print("=" * 60)

    # Create synthetic patches for testing
    np.random.seed(42)
    test_patches = np.random.rand(100, 256, 256).astype(np.float32)

    with tempfile.TemporaryDirectory() as tmp_dir:
        patches_path = Path(tmp_dir) / 'test_patches.npy'
        np.save(patches_path, test_patches)

        # SARDataModule np.load()s the whole array in __init__, so the file
        # is no longer needed once the constructor returns.
        data_module = SARDataModule(
            patches_path=str(patches_path),
            val_fraction=0.2,
            batch_size=8,
            num_workers=0,  # 0 for testing
        )

    # Test train loader
    train_loader = data_module.train_dataloader()
    train_batch = next(iter(train_loader))

    print(f"\nTrain batch shape: {train_batch.shape}")
    print(f"Train batch range: [{train_batch.min():.3f}, {train_batch.max():.3f}]")

    # Test validation loader
    val_loader = data_module.val_dataloader()
    val_batch = next(iter(val_loader))

    print(f"Val batch shape: {val_batch.shape}")
    print(f"Val batch range: [{val_batch.min():.3f}, {val_batch.max():.3f}]")

    # Verify augmentation creates variety. Each draw lands on one of 8
    # flip/rotation outcomes, so a single re-draw matches about 1/8 of the
    # time; only warn if several consecutive re-draws all match.
    dataset = data_module.train_dataset
    reference = dataset[0]
    if all(torch.allclose(reference, dataset[0]) for _ in range(7)):
        print("\nWARNING: Augmentation may not be working")
    else:
        print("\nAugmentation verified: same index gives different patches")

    print("\n✓ Dataset test passed!")


test_dataset()

DATASET TEST
Loading patches from test_patches.npy
Loaded 100 patches of shape (256, 256)
Train: 80, Validation: 20

Train batch shape: torch.Size([8, 1, 256, 256])
Train batch range: [0.000, 1.000]
Val batch shape: torch.Size([8, 1, 256, 256])
Val batch range: [0.000, 1.000]

Augmentation verified: same index gives different patches

✓ Dataset test passed!


## Training Functions

In [44]:
from datetime import datetime
from typing import Dict, Optional, Callable
import json
from torch.utils.tensorboard import SummaryWriter
class Trainer:
    """Train an autoencoder with validation, LR scheduling, TensorBoard logging
    and checkpointing.

    Contract with the collaborators, as used below:

    * The loaders yield input tensors directly (no labels).
    * ``model(x)`` returns a 2-tuple ``(x_hat, z)``; only ``x_hat`` is used
      for the loss.
    * ``loss_fn(x_hat, x)`` returns ``(loss, metrics)``: ``loss`` is
      back-propagated and ``metrics`` is a dict containing at least
      ``'loss'``; its ``'mse'``, ``'ssim'`` and ``'psnr'`` entries are
      averaged per epoch when present.

    Args:
        model: Module to train; moved to ``device``.
        train_loader: Training batches.
        val_loader: Validation batches.
        loss_fn: Loss module; moved to ``device``.
        config: Hyperparameters and paths. Keys read here (with defaults):
            ``learning_rate`` (1e-4), ``betas`` ((0.9, 0.999)),
            ``weight_decay`` (0), ``lr_factor`` (0.5), ``lr_patience`` (10),
            ``max_grad_norm`` (1.0), ``log_dir`` ('runs'),
            ``checkpoint_dir`` ('checkpoints').
        device: Device string or ``torch.device``; CUDA if available when None.
    """

    # Metrics averaged over an epoch; other keys returned by loss_fn are ignored.
    METRIC_KEYS = ('loss', 'mse', 'ssim', 'psnr')

    def __init__(self,
                 model,
                 train_loader,
                 val_loader,
                 loss_fn,
                 config,
                 device=None):

        # Device setup
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)
        print(f"Using device: {self.device}")

        # Model, data, loss, config
        self.model = model.to(self.device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.loss_fn = loss_fn.to(self.device)
        self.config = config

        # Optimizer
        self.optimizer = optim.Adam(
            self.model.parameters(),
            lr=config.get('learning_rate', 1e-4),
            betas=config.get('betas', (0.9, 0.999)),
            weight_decay=config.get('weight_decay', 0),
        )

        # Halve (by default) the LR when validation loss plateaus.
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            mode='min',
            factor=config.get('lr_factor', 0.5),
            patience=config.get('lr_patience', 10),
        )

        # Logging: one timestamped sub-directory per run.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.log_dir = Path(config.get('log_dir', 'runs')) / timestamp
        self.log_dir.mkdir(parents=True, exist_ok=True)
        self.writer = SummaryWriter(self.log_dir)

        # Record the run's config next to its logs. default=str keeps values
        # JSON cannot encode (e.g. Path objects) from aborting construction.
        with open(self.log_dir / 'config.json', 'w', encoding='utf-8') as f:
            json.dump(config, f, indent=2, default=str)

        # Checkpointing (parents=True so nested paths work, like log_dir).
        self.checkpoint_dir = Path(config.get('checkpoint_dir', 'checkpoints'))
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # Training state
        self.epoch = 0
        self.global_step = 0
        self.best_val_loss = float('inf')
        self.epochs_without_improvement = 0
        # First epoch index train() runs; load_checkpoint() advances it so a
        # resumed run continues the epoch numbering instead of restarting at 0.
        self._start_epoch = 0

    @staticmethod
    def _accumulate(totals, metrics):
        """Add the tracked entries of one batch's ``metrics`` into ``totals`` in place."""
        for key in totals:
            if key in metrics:
                totals[key] += metrics[key]

    @staticmethod
    def _average(totals, num_batches, source):
        """Return ``totals`` divided by ``num_batches``.

        Raises:
            ValueError: if ``num_batches`` is 0 (``source`` names the empty loader).
        """
        if num_batches == 0:
            raise ValueError(
                f"{source} produced no batches; cannot average metrics "
                "(check dataset size against batch_size / drop_last)"
            )
        return {key: value / num_batches for key, value in totals.items()}

    def train_epoch(self) -> Dict[str, float]:
        """Train for one epoch over ``train_loader``.

        Returns:
            Per-batch averages of the metrics listed in ``METRIC_KEYS``.

        Raises:
            ValueError: if ``train_loader`` yields no batches.
        """
        self.model.train()

        totals = dict.fromkeys(self.METRIC_KEYS, 0.0)
        num_batches = 0
        # Read once per epoch instead of once per batch.
        max_grad_norm = self.config.get('max_grad_norm', 1.0)

        pbar = tqdm(self.train_loader, desc=f"Epoch {self.epoch+1} [Train]")

        for batch in pbar:
            x = batch.to(self.device)

            # Forward pass
            self.optimizer.zero_grad()
            x_hat, _ = self.model(x)
            loss, metrics = self.loss_fn(x_hat, x)

            # Backward pass, with global gradient-norm clipping for stability
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)
            self.optimizer.step()

            self._accumulate(totals, metrics)
            num_batches += 1

            pbar.set_postfix({
                'loss': f"{metrics['loss']:.4f}",
                'psnr': f"{metrics.get('psnr', 0):.1f}",
            })

            # Per-step loss curve in TensorBoard
            self.writer.add_scalar('train/loss_step', metrics['loss'], self.global_step)
            self.global_step += 1

        return self._average(totals, num_batches, 'train_loader')

    @torch.no_grad()
    def validate(self) -> Dict[str, float]:
        """Evaluate on ``val_loader`` without gradient tracking.

        Returns:
            Per-batch averages of the metrics listed in ``METRIC_KEYS``.

        Raises:
            ValueError: if ``val_loader`` yields no batches.
        """
        self.model.eval()

        totals = dict.fromkeys(self.METRIC_KEYS, 0.0)
        num_batches = 0

        for batch in tqdm(self.val_loader, desc=f"Epoch {self.epoch+1} [Val]"):
            x = batch.to(self.device)
            x_hat, _ = self.model(x)
            _, metrics = self.loss_fn(x_hat, x)

            self._accumulate(totals, metrics)
            num_batches += 1

        return self._average(totals, num_batches, 'val_loader')

    def save_checkpoint(self, is_best: bool = False):
        """Write ``latest.pth`` (and ``best.pth`` when ``is_best``) to ``checkpoint_dir``.

        ``epoch`` stored in the checkpoint is the 0-based index of the last
        completed epoch.
        """
        checkpoint = {
            'epoch': self.epoch,
            'global_step': self.global_step,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'best_val_loss': self.best_val_loss,
            'epochs_without_improvement': self.epochs_without_improvement,
            'config': self.config,
        }

        torch.save(checkpoint, self.checkpoint_dir / 'latest.pth')

        if is_best:
            torch.save(checkpoint, self.checkpoint_dir / 'best.pth')
            print(f"  ✓ New best model saved (val_loss: {self.best_val_loss:.4f})")

    def load_checkpoint(self, path: str):
        """Restore model/optimizer/scheduler and training state from ``path``.

        A subsequent ``train()`` call resumes at the epoch after the
        checkpointed one. Checkpoints written before early-stopping state was
        saved restore ``epochs_without_improvement`` as 0.
        """
        checkpoint = torch.load(path, map_location=self.device)

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.epoch = checkpoint['epoch']
        self.global_step = checkpoint['global_step']
        self.best_val_loss = checkpoint['best_val_loss']
        self.epochs_without_improvement = checkpoint.get('epochs_without_improvement', 0)
        self._start_epoch = self.epoch + 1

        print(f"Loaded checkpoint from epoch {self.epoch}")

    @torch.no_grad()
    def log_images(self, num_images: int = 4):
        """Log originals, reconstructions and |difference| for the first
        ``num_images`` items of the first validation batch to TensorBoard.

        The difference images are divided by their maximum (when non-zero)
        so small errors stay visible.
        """
        self.model.eval()

        batch = next(iter(self.val_loader))[:num_images].to(self.device)
        x_hat, _ = self.model(batch)

        self.writer.add_images('val/original', batch, self.epoch)
        self.writer.add_images('val/reconstructed', x_hat, self.epoch)

        diff = (batch - x_hat).abs()
        peak = diff.max()
        if peak > 0:
            diff = diff / peak
        self.writer.add_images('val/difference', diff, self.epoch)

    def train(self, epochs: int, early_stopping_patience: int = 20):
        """
        Main training loop.

        Each epoch trains, validates, logs scalars (and sample images every
        5th epoch) to TensorBoard, steps the LR scheduler on validation loss
        and writes ``latest.pth`` (plus ``best.pth`` on improvement). The
        TensorBoard writer is closed when the loop ends, including when it
        raises or is interrupted.

        Args:
            epochs: Total number of epochs. After ``load_checkpoint`` training
                resumes at the epoch following the checkpoint and stops at this
                total, so epoch numbers in logs and checkpoints continue.
            early_stopping_patience: Stop if no improvement for this many epochs
        """
        print("\n" + "=" * 70)
        print("TRAINING START")
        print("=" * 70)
        print(f"Epochs: {epochs}")
        if self._start_epoch:
            print(f"Resuming from epoch {self._start_epoch + 1}")
        print(f"Device: {self.device}")
        print(f"Training samples: {len(self.train_loader.dataset)}")
        print(f"Validation samples: {len(self.val_loader.dataset)}")
        print(f"Batch size: {self.train_loader.batch_size}")
        print(f"Log directory: {self.log_dir}")
        print("=" * 70 + "\n")

        try:
            for epoch in range(self._start_epoch, epochs):
                self.epoch = epoch

                train_metrics = self.train_epoch()
                val_metrics = self.validate()

                # Epoch-level scalars
                for key, value in train_metrics.items():
                    self.writer.add_scalar(f'train/{key}', value, epoch)
                for key, value in val_metrics.items():
                    self.writer.add_scalar(f'val/{key}', value, epoch)

                lr = self.optimizer.param_groups[0]['lr']
                self.writer.add_scalar('train/learning_rate', lr, epoch)

                if epoch % 5 == 0:
                    self.log_images()

                self.scheduler.step(val_metrics['loss'])

                # Track improvement for checkpointing and early stopping
                is_best = val_metrics['loss'] < self.best_val_loss
                if is_best:
                    self.best_val_loss = val_metrics['loss']
                    self.epochs_without_improvement = 0
                else:
                    self.epochs_without_improvement += 1

                self.save_checkpoint(is_best=is_best)

                print(f"\nEpoch {epoch+1}/{epochs}")
                print(f"  Train: loss={train_metrics['loss']:.4f}, "
                      f"psnr={train_metrics.get('psnr', 0):.2f}, "
                      f"ssim={train_metrics.get('ssim', 0):.4f}")
                print(f"  Val:   loss={val_metrics['loss']:.4f}, "
                      f"psnr={val_metrics.get('psnr', 0):.2f}, "
                      f"ssim={val_metrics.get('ssim', 0):.4f}")

                if self.epochs_without_improvement >= early_stopping_patience:
                    print(f"\nEarly stopping: no improvement for {early_stopping_patience} epochs")
                    break

            print("\n" + "=" * 70)
            print("TRAINING COMPLETE")
            print(f"Best validation loss: {self.best_val_loss:.4f}")
            print("=" * 70)
        finally:
            # Flush and close event files even on KeyboardInterrupt or errors.
            self.writer.close()


def main():
    """Smoke-test the training pipeline end to end on synthetic data.

    Writes 500 random 256x256 float32 patches to ``test_patches.npy`` in the
    current working directory, trains ``SARAutoencoder`` on them for up to
    5 epochs, and always removes the patch file afterwards, even if training
    raises or is interrupted.

    NOTE(review): ``log_dir`` / ``checkpoint_dir`` below are relative to the
    current working directory; if that is the project root, this run writes
    ``latest.pth`` / ``best.pth`` into the same ``checkpoints/`` directory as
    real training — confirm that is intended.
    """
    # Configuration
    config = {
        'latent_channels': 64,
        'batch_size': 16,
        'learning_rate': 1e-4,
        'mse_weight': 1.0,
        'ssim_weight': 0.1,
        'max_grad_norm': 1.0,
        'lr_factor': 0.5,
        'lr_patience': 10,
        'log_dir': 'runs',
        'checkpoint_dir': 'checkpoints',
    }

    # Create synthetic data for testing
    print("Creating synthetic test data...")
    np.random.seed(42)
    patches_path = Path('test_patches.npy')
    np.save(patches_path, np.random.rand(500, 256, 256).astype(np.float32))

    try:
        data_module = SARDataModule(
            patches_path=str(patches_path),
            val_fraction=0.1,
            batch_size=config['batch_size'],
            num_workers=0,
        )

        model = SARAutoencoder(latent_channels=config['latent_channels'])

        loss_fn = CombinedLoss(
            mse_weight=config['mse_weight'],
            ssim_weight=config['ssim_weight']
        )

        trainer = Trainer(
            model=model,
            train_loader=data_module.train_dataloader(),
            val_loader=data_module.val_dataloader(),
            loss_fn=loss_fn,
            config=config,
        )

        trainer.train(epochs=5, early_stopping_patience=10)
    finally:
        # Remove the synthetic file even if training fails or is interrupted.
        patches_path.unlink(missing_ok=True)


main()

Creating synthetic test data...
Loading patches from test_patches.npy
Loaded 500 patches of shape (256, 256)
Train: 450, Validation: 50
Using device: cuda

TRAINING START
Epochs: 5
Device: cuda
Training samples: 450
Validation samples: 50
Batch size: 16
Log directory: runs\20260118_112047



Epoch 1 [Train]: 100%|██████████| 28/28 [00:01<00:00, 19.66it/s, loss=0.1767, psnr=10.8]
Epoch 1 [Val]: 100%|██████████| 4/4 [00:00<00:00, 69.21it/s]


  ✓ New best model saved (val_loss: 0.1794)

Epoch 1/5
  Train: loss=0.1811, psnr=10.67, ssim=0.0478
  Val:   loss=0.1794, psnr=10.79, ssim=0.0393


Epoch 2 [Train]: 100%|██████████| 28/28 [00:01<00:00, 25.43it/s, loss=0.1681, psnr=10.9]
Epoch 2 [Val]: 100%|██████████| 4/4 [00:00<00:00, 61.49it/s]


  ✓ New best model saved (val_loss: 0.1710)

Epoch 2/5
  Train: loss=0.1726, psnr=10.89, ssim=0.0892
  Val:   loss=0.1710, psnr=10.95, ssim=0.0945


Epoch 3 [Train]: 100%|██████████| 28/28 [00:01<00:00, 25.42it/s, loss=0.1585, psnr=11.1]
Epoch 3 [Val]: 100%|██████████| 4/4 [00:00<00:00, 60.18it/s]


  ✓ New best model saved (val_loss: 0.1584)

Epoch 3/5
  Train: loss=0.1629, psnr=11.02, ssim=0.1616
  Val:   loss=0.1584, psnr=11.08, ssim=0.1954


Epoch 4 [Train]: 100%|██████████| 28/28 [00:01<00:00, 25.37it/s, loss=0.1511, psnr=11.2]
Epoch 4 [Val]: 100%|██████████| 4/4 [00:00<00:00, 60.65it/s]


  ✓ New best model saved (val_loss: 0.1510)

Epoch 4/5
  Train: loss=0.1543, psnr=11.15, ssim=0.2250
  Val:   loss=0.1510, psnr=11.19, ssim=0.2507


Epoch 5 [Train]: 100%|██████████| 28/28 [00:01<00:00, 25.57it/s, loss=0.1455, psnr=11.3]
Epoch 5 [Val]: 100%|██████████| 4/4 [00:00<00:00, 60.25it/s]


  ✓ New best model saved (val_loss: 0.1458)

Epoch 5/5
  Train: loss=0.1478, psnr=11.25, ssim=0.2715
  Val:   loss=0.1458, psnr=11.28, ssim=0.2877

TRAINING COMPLETE
Best validation loss: 0.1458


---

# Part 6: Run Training

In [45]:
import struct

def get_npy_shape(filepath):
    """Return the shape of the array stored in a ``.npy`` file without loading its data.

    Only the magic string and header are read, using NumPy's own header
    parser. This handles both header layouts NumPy writes: version 1.0 with a
    2-byte header length, and 2.0/3.0 with a 4-byte one. It also handles
    headers whose dtype description itself contains parentheses (structured
    dtypes), which naive text parsing of the header misreads.

    Args:
        filepath: str or Path of a ``.npy`` file.

    Returns:
        tuple of ints: the array shape (``()`` for a 0-d array).

    Raises:
        ValueError: if the file is not a .npy file or uses an unsupported version.
    """
    # Aliased so the module does not shadow the builtin ``format``.
    from numpy.lib import format as npy_format

    with open(filepath, 'rb') as f:
        version = npy_format.read_magic(f)
        if version == (1, 0):
            shape, _fortran_order, _dtype = npy_format.read_array_header_1_0(f)
        elif version in ((2, 0), (3, 0)):
            shape, _fortran_order, _dtype = npy_format.read_array_header_2_0(f)
        else:
            raise ValueError(f"{filepath}: unsupported .npy format version {version}")
    return tuple(shape)

In [46]:

class LazyPatchDataset(Dataset):
    """Map-style dataset that reads individual patches on demand from ``.npy`` files.

    Patches are numbered globally across ``metadata['file_index']`` in order:
    file 0 holds global indices ``[0, n0)``, file 1 holds ``[n0, n0 + n1)``,
    and so on. Files are opened memory-mapped (``mmap_mode='r'``), so fetching
    a patch reads only that patch from disk rather than the whole file. This
    matters when ``shuffle_idx`` or a shuffling DataLoader jumps between files
    on nearly every item.

    Args:
        metadata: dict whose ``'file_index'`` entry is a sequence of
            ``(path, num_patches)`` pairs.
        shuffle_idx: sequence of global patch indices; item ``i`` of this
            dataset is global patch ``shuffle_idx[i]``.
        augment: if True, randomly flip left/right, flip up/down and rotate by
            a multiple of 90 degrees, using the stdlib ``random`` module.

    Each item is a float32 tensor of shape ``(1, *patch_shape)``.
    """

    def __init__(self, metadata, shuffle_idx, augment=True):
        self.files = [Path(f) for f, _ in metadata['file_index']]
        # cumsum[k] is the global index of the first patch in file k;
        # cumsum[-1] is the total number of patches.
        self.cumsum = [0]
        for _, n in metadata['file_index']:
            self.cumsum.append(self.cumsum[-1] + n)
        # Array copy of the offsets for O(log F) file lookup via searchsorted.
        self._offsets = np.asarray(self.cumsum)
        self.shuffle_idx = shuffle_idx
        self.augment = augment
        # Single-entry cache: the most recently opened (memory-mapped) file.
        self._cache_idx = None
        self._cache_data = None

    def __len__(self):
        return len(self.shuffle_idx)

    def __getstate__(self):
        # Drop the open memmap when pickled (e.g. sent to DataLoader worker
        # processes); each process reopens files lazily on first access.
        state = self.__dict__.copy()
        state['_cache_idx'] = None
        state['_cache_data'] = None
        return state

    def _locate(self, real_idx):
        """Map a global patch index to ``(file_idx, local_idx)``.

        Raises:
            IndexError: if ``real_idx`` is outside ``[0, total_patches)``.
        """
        total = self.cumsum[-1]
        if not 0 <= real_idx < total:
            raise IndexError(f"patch index {real_idx} out of range for {total} patches")
        # side='right' skips over files that contain zero patches.
        file_idx = int(np.searchsorted(self._offsets, real_idx, side='right')) - 1
        return file_idx, real_idx - self.cumsum[file_idx]

    def __getitem__(self, idx):
        real_idx = int(self.shuffle_idx[idx])
        file_idx, local_idx = self._locate(real_idx)

        if self._cache_idx != file_idx:
            # Load first, then record the index, so a failed load never leaves
            # the cache claiming a file whose data it does not hold.
            self._cache_data = np.load(self.files[file_idx], mmap_mode='r')
            self._cache_idx = file_idx

        # np.array() materializes the patch as a plain in-memory ndarray.
        patch = np.array(self._cache_data[local_idx])

        if self.augment:
            if random.random() > 0.5: patch = np.fliplr(patch).copy()
            if random.random() > 0.5: patch = np.flipud(patch).copy()
            k = random.randint(0, 3)
            if k: patch = np.rot90(patch, k).copy()

        return torch.from_numpy(patch).unsqueeze(0).float()

In [47]:
# Patch-set metadata (a pickled dict) plus the precomputed global shuffle order.
outputdir = project_root / "data" / "patches"
# NOTE: allow_pickle=True unpickles arbitrary objects; only load trusted files.
metadata = np.load(outputdir / 'metadata.npy', allow_pickle=True).item()
shuffle_idx = np.load(outputdir / "shuffle_idx.npy")

print(
    f"Total patches: {metadata['num_patches']:,}\n"
    f"Global bounds: [{metadata['vmin']:.2f}, {metadata['vmax']:.2f}] dB"
)

Total patches: 696,277
Global bounds: [14.77, 24.54] dB


In [39]:
# Quick-experiment subset: the first `train_size` shuffled indices train, the
# next `val_size` validate, so the two sets never overlap.
val_size = 200
train_size = 1000
print("Loading datasets...")
print("Creating train dataset...")
train_dataset = LazyPatchDataset(metadata, shuffle_idx[:train_size], augment=True)
val_dataset = LazyPatchDataset(metadata, shuffle_idx[train_size:train_size + val_size], augment=False)

print(f"Train: {len(train_dataset):,}, Val: {len(val_dataset):,}")

batch_size = 16
print("Creating dataloaders...")
# Pinned host memory only speeds up host-to-GPU copies; on a CPU-only machine
# it is unused (and recent PyTorch versions warn about it).
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                          num_workers=0, pin_memory=torch.cuda.is_available(), drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                        num_workers=0, pin_memory=torch.cuda.is_available())

# Sanity-check one training batch: shape and value range.
batch = next(iter(train_loader))
print(f"Batch shape: {batch.shape}, range: [{batch.min():.3f}, {batch.max():.3f}]")



Loading datasets...
Creating train dataset...
Train: 1,000, Val: 200
Creating dataloaders...
Batch shape: torch.Size([16, 1, 256, 256]), range: [0.000, 1.000]


In [48]:
model = SARAutoencoder(latent_channels=64).to(device)
loss_fn = CombinedLoss(mse_weight=1.0, ssim_weight=0.1).to(device)

num_params = sum(param.numel() for param in model.parameters())
print(f"Parameters: {num_params:,}")
print(f"Compression: {model.get_compression_ratio():.1f}x")

# Baseline metrics for the untrained model on the test batch from the loader
# cell. The model was already moved to `device` when it was created above.
model.eval()
with torch.no_grad():
    x = batch.to(device)
    x_hat, z = model(x)
    _, metrics = loss_fn(x_hat, x)

print("\nBaseline (untrained):")
print(f"  Loss: {metrics['loss']:.4f}")
print(f"  PSNR: {metrics['psnr']:.2f} dB")
print(f"  SSIM: {metrics['ssim']:.4f}")
print(f"  Latent shape: {z.shape}")

Parameters: 2,872,257
Compression: 4.0x

Baseline (untrained):
  Loss: 0.1371
  PSNR: 13.20 dB
  SSIM: 0.1073
  Latent shape: torch.Size([16, 64, 16, 16])


In [49]:
# Hyperparameters and output locations handed to the Trainer.
config = dict(
    latent_channels=64,
    batch_size=batch_size,
    learning_rate=1e-4,
    mse_weight=1.0,
    ssim_weight=0.1,
    max_grad_norm=1.0,
    lr_factor=0.5,
    lr_patience=5,
    log_dir=str(project_root / 'runs'),
    checkpoint_dir=str(project_root / 'checkpoints'),
)

In [50]:
# Report whether a GPU is visible and, if so, which one.
cuda_ok = torch.cuda.is_available()
print(f"CUDA available: {cuda_ok}")
if cuda_ok:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

CUDA available: True
GPU: NVIDIA GeForce RTX 3070


In [43]:
# Short run on the small subset to check the full loop works.
trainer = Trainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    loss_fn=loss_fn,
    config=config,
)
trainer.train(epochs=3, early_stopping_patience=5)

Using device: cuda

TRAINING START
Epochs: 3
Device: cuda
Training samples: 1000
Validation samples: 200
Batch size: 16
Log directory: d:\Projects\CNNAutoencoderProject\runs\20260118_105754



Epoch 1 [Train]:   0%|          | 0/62 [00:17<?, ?it/s]


KeyboardInterrupt: 

In [None]:
# Full-dataset run: the first 10% of the shuffled indices are held out for
# validation; the remainder is used for training.
val_size = metadata['num_patches'] // 10
train_dataset = LazyPatchDataset(metadata, shuffle_idx[val_size:], augment=True)
val_dataset = LazyPatchDataset(metadata, shuffle_idx[:val_size], augment=False)

print(f"Train: {len(train_dataset):,}, Val: {len(val_dataset):,}")

# Per-run copy of the shared config, so this cell does not silently change
# `config` for other cells and the logged config.json matches this run.
full_config = {**config, 'lr_patience': 10}

# Batch size and model/loss settings come from the config that gets logged.
# Pinned host memory only helps when batches are copied to a CUDA device.
train_loader = DataLoader(train_dataset, batch_size=full_config['batch_size'], shuffle=True,
                          num_workers=0, pin_memory=torch.cuda.is_available(), drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=full_config['batch_size'], shuffle=False,
                        num_workers=0, pin_memory=torch.cuda.is_available())

model = SARAutoencoder(latent_channels=full_config['latent_channels'])
loss_fn = CombinedLoss(mse_weight=full_config['mse_weight'],
                       ssim_weight=full_config['ssim_weight'])

trainer = Trainer(model, train_loader, val_loader, loss_fn, full_config)
trainer.train(epochs=50, early_stopping_patience=15)

Train: 626,650, Val: 69,627
Using device: cuda

TRAINING START
Epochs: 50
Device: cuda
Training samples: 626650
Validation samples: 69627
Batch size: 16
Log directory: d:\Projects\CNNAutoencoderProject\runs\20260118_112127



Epoch 1 [Train]:   0%|          | 1/39165 [01:05<715:23:40, 65.76s/it, loss=0.1527, psnr=12.2]

In [None]:
best_checkpoint_path = project_root / 'checkpoints' / 'best.pth'
checkpoint = torch.load(best_checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

# The checkpoint stores the 0-based index of the last completed epoch.
best_epoch = checkpoint['epoch'] + 1
print(f"Loaded best model from epoch {best_epoch}")
print(f"Best val loss: {checkpoint['best_val_loss']:.4f}")

In [None]:
# Reconstruct the first 8 patches of a validation batch and score them.
model.eval()
model.to(device)

# Keep the metric module on the same device as the tensors it compares,
# consistent with how the other cells build CombinedLoss (.to(device)).
loss_fn_eval = CombinedLoss(mse_weight=1.0, ssim_weight=0.1).to(device)

with torch.no_grad():
    sample = next(iter(val_loader))[:8].to(device)
    recon, latent = model(sample)
    _, metrics = loss_fn_eval(recon, sample)

print(f"Sample metrics: PSNR={metrics['psnr']:.2f} dB, SSIM={metrics['ssim']:.4f}")

In [None]:
# outputdir = project_root / "data" / "patches"
# metadata = np.load(outputdir / 'metadata.npy', allow_pickle=True).item()
# shuffle_idx = np.load(outputdir / "shuffle_idx.npy")

# print(f"Total patches: {metadata['num_patches']:,}")

# # Split
# val_size = metadata['num_patches'] // 10

# train_dataset = LazyPatchDataset(metadata, shuffle_idx[:1000], augment=True)
# val_dataset = LazyPatchDataset(metadata, shuffle_idx[1000:1200], augment=False)

# print(f"Train: {len(train_dataset):,}, Val: {len(val_dataset):,}")

# # Dataloaders
# batch_size = 16
# train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, 
#                           num_workers=0, pin_memory=True, drop_last=True)
# val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
#                         num_workers=0, pin_memory=True)

# # Model
# model = SARAutoencoder(latent_channels=64)
# loss_fn = CombinedLoss(mse_weight=1.0, ssim_weight=0.1)
# print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

# # Config
# config = {
#     'latent_channels': 64,
#     'batch_size': batch_size,
#     'learning_rate': 1e-4,
#     'mse_weight': 1.0,
#     'ssim_weight': 0.1,
#     'max_grad_norm': 1.0,
#     'lr_factor': 0.5,
#     'lr_patience': 10,
#     'log_dir': str(project_root / 'runs'),
#     'checkpoint_dir': str(project_root / 'checkpoints'),
# }

# # Train
# trainer = Trainer(model, train_loader, val_loader, loss_fn, config)
# trainer.train(epochs=3, early_stopping_patience=5)

Total patches: 696,277
Train: 1,000, Val: 200
Parameters: 2,872,257
Using device: cpu

TRAINING START
Epochs: 3
Device: cpu
Training samples: 1000
Validation samples: 200
Batch size: 16
Log directory: d:\Projects\CNNAutoencoderProject\runs\20260118_094427



Epoch 1 [Train]:  15%|█▍        | 9/62 [11:41<1:08:51, 77.95s/it, loss=0.1186, psnr=15.4]


KeyboardInterrupt: 

In [None]:
# Rows: original, reconstruction, absolute error (hot colormap, clipped at 0.2).
row_labels = ('Original', 'Reconstructed', 'Difference')
fig, axes = plt.subplots(3, 8, figsize=(16, 6))

for i in range(8):
    original = sample[i, 0].cpu()
    reconstructed = recon[i, 0].cpu()
    axes[0, i].imshow(original, cmap='gray', vmin=0, vmax=1)
    axes[1, i].imshow(reconstructed, cmap='gray', vmin=0, vmax=1)
    axes[2, i].imshow(torch.abs(original - reconstructed), cmap='hot', vmin=0, vmax=0.2)

# Hide ticks and frames instead of calling ax.axis('off'): axis('off') also
# suppresses the y-axis label, so the row labels would never be drawn.
for ax in axes.flat:
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)
for ax, label in zip(axes[:, 0], row_labels):
    ax.set_ylabel(label, fontsize=12)

plt.suptitle(f'Reconstruction Results (PSNR: {metrics["psnr"]:.1f} dB, SSIM: {metrics["ssim"]:.3f})', fontsize=14)
plt.tight_layout()
figure_dir = project_root / 'checkpoints'
figure_dir.mkdir(parents=True, exist_ok=True)
plt.savefig(figure_dir / 'reconstruction_samples.png', dpi=150)
plt.show()

In [None]:
print(f"Latent shape: {latent.shape}")
print(f"Latent stats: mean={latent.mean():.3f}, std={latent.std():.3f}")

# Show up to the first 16 channels of the first sample's latent code in a 2x8
# grid; the channel count is read from the tensor rather than assumed to be 64.
num_channels = latent.shape[1]
fig, axes = plt.subplots(2, 8, figsize=(16, 4))
for ch, ax in enumerate(axes.flat):
    ax.axis('off')
    if ch < num_channels:
        ax.imshow(latent[0, ch].cpu(), cmap='viridis')
        ax.set_title(f'Ch {ch}')

num_shown = min(axes.size, num_channels)
plt.suptitle(f'Latent Space Channels (first {num_shown} of {num_channels})', fontsize=12)
plt.tight_layout()
plt.show()

---

# Done!

If all tests pass and you can train, copy your implementations to:
- `src/models/blocks.py` → ConvBlock, DeconvBlock
- `src/models/encoder.py` → SAREncoder
- `src/models/decoder.py` → SARDecoder
- `src/models/autoencoder.py` → SARAutoencoder
- `src/losses/ssim.py` → SSIMLoss
- `src/losses/combined.py` → CombinedLoss
- `src/training/trainer.py` → `Trainer` class