# Building the Vesuvius 3D Surface Detector

This notebook builds a 3D Residual U-Net for surface detection, adapted to two constraints:

1.  **3D Volumetric Data**: Using `Conv3d` layers instead of `Conv2d`.
2.  **Low Memory Constraints (8GB VRAM)**: Using `GroupNorm` (stable at batch size 1) and Gradient Checkpointing.
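
To see why 8 GB is tight for 3D data, here is a rough estimate of activation memory. It assumes float32 (4 bytes per value) and uses this notebook's shapes (`init_features=16`, a 32x128x128 crop). The helper `activation_mib` is illustrative only and is not used anywhere else.

```python
def activation_mib(channels, depth, height, width, batch=1, bytes_per_value=4):
    """Size in MiB of one activation tensor of shape (N, C, D, H, W), assuming float32 by default."""
    n_values = batch * channels * depth * height * width
    return n_values * bytes_per_value / 2**20

# One 16-channel feature map at full resolution (e.g. the output of the first encoder block)
full_3d = activation_mib(16, 32, 128, 128)
print(full_3d)  # 32.0

# A 2D network working on a single 128x128 slice stores 32x less per feature map
single_slice = activation_mib(16, 1, 128, 128)
print(single_slice)  # 1.0
```

Each residual block keeps several tensors of this size for backprop (convolution, normalization, and activation outputs), and the decoder stores more at the same resolution, so in a 3D U-Net activations, not weights, typically dominate VRAM.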

In [None]:
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 1. The Building Block: Residual Layer

A standard U-Net uses plain convolution blocks. We upgrade these to **Residual Blocks** (ResNet style): the shortcut connection gives gradients a direct path backward through each block, which makes deeper networks easier to train.
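
A short sketch of why the shortcut helps. Write the block's output before its final activation as the sum of the residual branch $F$ (the two convolutions with their normalizations) and the shortcut $S$, with $\mathcal{L}$ the training loss:

$$
z = F(x) + S(x), \qquad
\frac{\partial \mathcal{L}}{\partial x}
= \frac{\partial \mathcal{L}}{\partial z}\left(\frac{\partial F}{\partial x} + \frac{\partial S}{\partial x}\right)
= \frac{\partial \mathcal{L}}{\partial z}\left(\frac{\partial F}{\partial x} + I\right) \quad \text{when } S \text{ is the identity.}
$$

Even if the Jacobian of the residual branch becomes small, the identity term $I$ still passes the upstream gradient straight through. (In `ResidualBlock` below, a final LeakyReLU is applied to $z$, and $S$ is a 1x1x1 convolution when the channel counts differ.)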

### Why GroupNorm?
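*   **BatchNorm**: Computes per-channel statistics across the batch (and spatial) dimensions, so its estimates become noisy with small batches; it generally needs reasonably large batches (e.g., 16+) to be reliable. We can only fit Batch Size = 1.
*   **GroupNorm**: Splits the channels into groups and normalizes each sample independently over each group (its channels plus all spatial positions). Because the statistics never mix samples, it behaves the same at Batch Size = 1.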
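
A small check of the claims above on toy sizes (16 channels and 8 groups as in `ResidualBlock`, but a tiny 4x8x8 volume). It verifies that GroupNorm treats each sample independently, that BatchNorm in training mode does not, and that GroupNorm matches a manual per-group normalization. The manual version assumes the defaults of a freshly built `nn.GroupNorm` (`eps=1e-5`, affine weight 1 and bias 0).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(4, 16, 4, 8, 8)  # (N, C, D, H, W): four small 3D volumes

gn = nn.GroupNorm(num_groups=8, num_channels=16)
bn = nn.BatchNorm3d(16)  # a fresh module is in training mode, so it uses batch statistics

# GroupNorm: sample 0 normalized alone == sample 0 normalized inside the batch
gn_same = torch.allclose(gn(x[:1]), gn(x)[:1], atol=1e-5)

# BatchNorm: sample 0's output changes depending on its batch-mates
bn_same = torch.allclose(bn(x[:1]), bn(x)[:1], atol=1e-5)

# Manual GroupNorm: each group of 16 / 8 = 2 consecutive channels (plus all voxels)
# gets its own mean and biased variance, computed separately for every sample
groups = x.reshape(4, 8, -1)
mean = groups.mean(dim=-1, keepdim=True)
var = ((groups - mean) ** 2).mean(dim=-1, keepdim=True)
manual = ((groups - mean) / torch.sqrt(var + 1e-5)).reshape_as(x)
manual_matches = torch.allclose(manual, gn(x), atol=1e-5)

print(gn_same, bn_same, manual_matches)  # True False True
```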

We create a block that applies `Conv3D -> GroupNorm -> LeakyReLU -> Conv3D -> GroupNorm`, then adds the original input (`x`) back (projected by a 1x1x1 convolution when the channel counts differ) and applies a final LeakyReLU.

In [None]:
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # 1st Convolution
        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1)
        self.gn1 = nn.GroupNorm(8, out_channels) # GroupNorm is key for small batches
        self.act1 = nn.LeakyReLU(inplace=True)
        
        # 2nd Convolution
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1)
        self.gn2 = nn.GroupNorm(8, out_channels)
        self.act2 = nn.LeakyReLU(inplace=True)

        # Shortcut connection (to match dimensions if in != out)
        if in_channels != out_channels:
            self.shortcut = nn.Conv3d(in_channels, out_channels, kernel_size=1)
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        original = x
        
        # Pass through layers
        x = self.conv1(x)
        x = self.gn1(x)
        x = self.act1(x)
        x = self.conv2(x)
        x = self.gn2(x)
        
        # Add residual (the "Res" in ResNet)
        residual = self.shortcut(original)
        x += residual
        return self.act2(x)
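
Two shape rules are baked into `ResidualBlock`, sketched here on toy tensors. First, the residual addition needs matching shapes, which is why a 1x1x1 convolution projects the input when the channel counts differ. Second, `nn.GroupNorm(8, C)` needs `C` divisible by 8, so `init_features` should be a multiple of 8 (the default 16 gives 16/32/64/128). The `ValueError` at construction is what recent PyTorch versions raise; `branch_out` is a random stand-in for the real `conv2 -> gn2` output.

```python
import torch
import torch.nn as nn

x = torch.randn(1, 16, 8, 16, 16)           # block input with 16 channels
branch_out = torch.randn(1, 32, 8, 16, 16)  # stand-in for the conv2 -> gn2 output (32 channels)

# An identity shortcut cannot be added: channel counts differ (32 vs 16)
try:
    _ = branch_out + x
    identity_ok = True
except RuntimeError:
    identity_ok = False

# A 1x1x1 convolution only mixes channels, leaving D, H, W untouched
shortcut = nn.Conv3d(16, 32, kernel_size=1)
summed = branch_out + shortcut(x)
print(identity_ok, tuple(summed.shape))  # False (1, 32, 8, 16, 16)

# GroupNorm with 8 groups rejects channel counts that are not multiples of 8
try:
    nn.GroupNorm(8, 12)
    gn12_ok = True
except ValueError:
    gn12_ok = False
print(gn12_ok)  # False
```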

## 2. The Architecture: U-Net Encoder & Decoder

Now we assemble the blocks.

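*   **Encoder (Down)**: Three 2x max-pooling steps shrink the spatial size (Height/Width 128->64->32->16 at the bottleneck, Depth 32->16->8->4) while the number of feature channels grows (16->32->64, then 128 in the bottleneck). Captures context.
*   **Decoder (Up)**: Three transposed convolutions restore the spatial size (16->32->64->128) while the channel count shrinks back (128->64->32->16). Refines details.
*   **Skip Connections**: In `forward`, each decoder level concatenates its upsampled features with the encoder features of the same resolution (along the channel dimension). This reinjects fine-grained detail that pooling discarded.
*   **Gradient Checkpointing**: A memory trick! Instead of storing every intermediate activation of a block for backprop, we keep only the block's input and recompute the rest during the backward pass. Here it is applied to `enc2` and `enc3` during training. Training becomes **slower (roughly 30%)** but uses **much less VRAM (roughly 50%)**; these figures are rough and depend on which blocks are checkpointed and on the input size.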
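
The shape bookkeeping in the list above can be traced in plain Python. `trace_unet_shapes` is a hypothetical helper, not part of the model: it mirrors `UNet3D` (three 2x poolings, channels doubling per level) and shows why every spatial dimension must be divisible by 2^3 = 8.

```python
def trace_unet_shapes(spatial, init_features=16, levels=3):
    """Return (stage, channels, D, H, W) for each stage of a UNet3D-style network."""
    factor = 2 ** levels
    for s in spatial:
        if s % factor != 0:
            raise ValueError(f"spatial size {s} is not divisible by {factor}")

    size, ch, stages = list(spatial), init_features, []
    # Encoder: each block keeps the resolution, then max-pooling halves it
    for level in range(1, levels + 1):
        stages.append((f"enc{level}", ch, *size))
        size = [s // 2 for s in size]
        ch *= 2
    stages.append(("bottleneck", ch, *size))
    # Decoder: each transposed conv doubles the resolution and halves the channels
    for level in range(levels, 0, -1):
        size = [s * 2 for s in size]
        ch //= 2
        stages.append((f"dec{level}", ch, *size))
    return stages

stages = trace_unet_shapes((32, 128, 128))
for stage in stages:
    print(stage)
# ('enc1', 16, 32, 128, 128)
# ('enc2', 32, 16, 64, 64)
# ('enc3', 64, 8, 32, 32)
# ('bottleneck', 128, 4, 16, 16)
# ('dec3', 64, 8, 32, 32)
# ('dec2', 32, 16, 64, 64)
# ('dec1', 16, 32, 128, 128)
```

Each decoder block receives twice its output channels at its input: `dec3` sees 64 + 64 = 128 channels after concatenating `u3` and `e3`, which is why it is built as `ResidualBlock(init_features * 8, init_features * 4)`.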

In [None]:
class UNet3D(nn.Module):
    def __init__(self, in_channels=1, out_channels=1, init_features=16):
        super().__init__()
        
        # --- ENCODER ---
        self.enc1 = ResidualBlock(in_channels, init_features)
        self.pool1 = nn.MaxPool3d(2, 2)
        
        self.enc2 = ResidualBlock(init_features, init_features * 2)
        self.pool2 = nn.MaxPool3d(2, 2)
        
        self.enc3 = ResidualBlock(init_features * 2, init_features * 4)
        self.pool3 = nn.MaxPool3d(2, 2)
        
        # --- BOTTLENECK ---
        self.bottleneck = ResidualBlock(init_features * 4, init_features * 8)
        
        # --- DECODER ---
        # Uses Transpose Conv to upsample
        self.up3 = nn.ConvTranspose3d(init_features * 8, init_features * 4, kernel_size=2, stride=2)
        self.dec3 = ResidualBlock(init_features * 8, init_features * 4)
        
        self.up2 = nn.ConvTranspose3d(init_features * 4, init_features * 2, kernel_size=2, stride=2)
        self.dec2 = ResidualBlock(init_features * 4, init_features * 2)
        
        self.up1 = nn.ConvTranspose3d(init_features * 2, init_features, kernel_size=2, stride=2)
        self.dec1 = ResidualBlock(init_features * 2, init_features)
        
        # --- FINAL --- 
        self.final = nn.Conv3d(init_features, out_channels, kernel_size=1)
        self.activation = nn.Sigmoid() # Output 0-1 probability

    def forward(self, x):
        # Encoder L1
        e1 = self.enc1(x)
        p1 = self.pool1(e1)
        
        # Encoder L2 (checkpointed during training to save VRAM)
        if self.training:
            e2 = checkpoint.checkpoint(self.run_enc2, p1, use_reentrant=False)
        else:
            e2 = self.enc2(p1)
        p2 = self.pool2(e2)
        
        # Encoder L3 (Checkpoint here too)
        if self.training:
            e3 = checkpoint.checkpoint(self.run_enc3, p2, use_reentrant=False)
        else:
            e3 = self.enc3(p2)
        p3 = self.pool3(e3)
        
        # Bottleneck
        b = self.bottleneck(p3)
        
        # Decoder L3
        u3 = self.up3(b)
        # Skip Connection: Concatenate u3 with e3
        d3 = self.dec3(torch.cat((u3, e3), dim=1))
        
        # Decoder L2
        u2 = self.up2(d3)
        d2 = self.dec2(torch.cat((u2, e2), dim=1))
        
        # Decoder L1
        u1 = self.up1(d2)
        d1 = self.dec1(torch.cat((u1, e1), dim=1))
        
        return self.activation(self.final(d1))
    
    # Helpers for checkpointing
    def run_enc2(self, x): return self.enc2(x)
    def run_enc3(self, x): return self.enc3(x)
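
One way to convince yourself that checkpointing changes memory use but not the math: run the same toy block with and without `torch.utils.checkpoint.checkpoint` and compare the input gradients. The toy block (one `Conv3d`, `GroupNorm`, `LeakyReLU`) and the sizes are arbitrary choices for this sketch.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
block = nn.Sequential(
    nn.Conv3d(4, 8, kernel_size=3, padding=1),
    nn.GroupNorm(4, 8),
    nn.LeakyReLU(),
)
x = torch.randn(1, 4, 8, 16, 16, requires_grad=True)

# Plain forward: autograd stores the block's intermediate activations
block(x).sum().backward()
grad_plain = x.grad.clone()
x.grad = None

# Checkpointed forward: only the input is stored; the block is re-run during backward
checkpoint(block, x, use_reentrant=False).sum().backward()
grad_ckpt = x.grad.clone()

print(torch.allclose(grad_plain, grad_ckpt, atol=1e-5))  # True
```

Since `checkpoint` accepts any callable, `self.enc2` could also be passed to it directly instead of going through the `run_enc2` helper.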

## 3. Sanity Check

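Let's pass a dummy 3D volume through the model and check that the output shape matches the input. Each spatial dimension must be divisible by 8, because the encoder halves it three times.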

*   **Input Shape**: `(Batch=1, Channel=1, Depth=32, Height=128, Width=128)`
*   **Expected Output**: `(1, 1, 32, 128, 128)`

In [None]:
# Instantiate Model
model = UNet3D(init_features=16).to(device)
print(f"Model created successfully on {device}.")

# Create Dummy Input
dummy_input = torch.randn(1, 1, 32, 128, 128).to(device)
print(f"Input Shape: {dummy_input.shape}")

# Forward Pass in eval mode (skips the checkpointing branch, which only matters for training)
model.eval()
with torch.no_grad():
    output = model(dummy_input)
    
print(f"Output Shape: {output.shape}")

if output.shape == dummy_input.shape:
    print("✅ Architecture verification passed! Output matches Input dimensions.")
else:
    print("❌ Mismatch! Something is wrong with padding or strides.")
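
The check passes because 32 and 128 are multiples of 8. For crops whose sizes are not, one option (an assumption, not something this notebook does) is to zero-pad up to the next multiple of 8 and crop the prediction back afterwards. `pad_to_multiple` is a hypothetical helper; note that `torch.nn.functional.pad` lists its padding amounts starting from the last dimension.

```python
import torch
import torch.nn.functional as F

def pad_to_multiple(volume, multiple=8):
    """Zero-pad the end of D, H, W so each is divisible by `multiple`; also return the original size."""
    d, h, w = volume.shape[-3:]
    pad_d, pad_h, pad_w = ((-s) % multiple for s in (d, h, w))
    # F.pad takes (W_left, W_right, H_left, H_right, D_left, D_right): last dimension first
    padded = F.pad(volume, (0, pad_w, 0, pad_h, 0, pad_d))
    return padded, (d, h, w)

vol = torch.randn(1, 1, 30, 100, 128)
padded, (d, h, w) = pad_to_multiple(vol)
print(tuple(padded.shape))  # (1, 1, 32, 104, 128)

# After inference on `padded`, crop the prediction back to the original size
prediction = torch.sigmoid(padded)  # stand-in for model(padded)
cropped = prediction[..., :d, :h, :w]
print(tuple(cropped.shape))  # (1, 1, 30, 100, 128)
```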