# Capstone Project: Advanced Model Compression

**Objective:** Reduce model size from ~5MB to < 500kB while maintaining accuracy.

**Methodology:**
1. **Structured Pruning:** We will physically remove 50% of the convolutional filters (channels). This changes the architecture dimensions (e.g., 32 filters -> 16 filters).
2. **Weight Transplantation:** Instead of training the small model from scratch, we copy the "most important" weights (based on L1-Norm) from the large trained model to the small model.
3. **Fine-Tuning:** We train the small model briefly to heal the accuracy drop.
4. **Dynamic Quantization:** We convert the float32 weights to int8 for a final 4x size reduction.

In [None]:
import os
import copy
import torch
import torch.nn as nn
import numpy as np
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# --- Configuration ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
CHECKPOINT_PATH = "checkpoints/final_best.pt"
IMG_SIZE = 128
PRUNE_AMOUNT = 0.6  # Target: Remove 50% of channels
FINE_TUNE_EPOCHS = 5
FINE_TUNE_LR = 1e-4

print(f"Using Device: {DEVICE}")

Using Device: cuda


In [22]:
# --- 1. Flexible Model Architecture ---
# We modify the CustomCNN to accept a specific list of channel counts.
# This allows us to build the "Child" model with exact pruned dimensions.

class CustomCNN(nn.Module):
    def __init__(self, reg_strength=0.0, dropout_conv=0.1, dropout_dense=0.1, 
                 dense_units=512, filters_multiplier=1.0, 
                 channel_list=None): 
        super().__init__()
        
        # If channel_list is explicitly provided (for pruning), use it.
        # Otherwise, calculate standard sizes based on multiplier.
        if channel_list is not None:
            f1, f2, f3 = channel_list
        else:
            f1 = max(8, int(32 * filters_multiplier))
            f2 = max(16, int(64 * filters_multiplier))
            f3 = max(32, int(128 * filters_multiplier))

        self.channels = [f1, f2, f3]
        
        # IMPORTANT: We define the layers precisely so we can index them later.
        # Structure: [0]Conv -> [1]ReLU -> [2]BN -> [3]Conv -> [4]ReLU -> [5]Pool -> [6]Drop
        self.block1 = nn.Sequential(
            nn.Conv2d(3, f1, 3, padding=1),      
            nn.ReLU(inplace=True),               
            nn.BatchNorm2d(f1),                  
            nn.Conv2d(f1, f1, 3, padding=1),     
            nn.ReLU(inplace=True),               
            nn.MaxPool2d(2),                     
            nn.Dropout(dropout_conv)             
        )
        self.block2 = nn.Sequential(
            nn.Conv2d(f1, f2, 3, padding=1),     
            nn.ReLU(inplace=True),               
            nn.BatchNorm2d(f2),                  
            nn.Conv2d(f2, f2, 3, padding=1),     
            nn.ReLU(inplace=True),               
            nn.MaxPool2d(2),                     
            nn.Dropout(dropout_conv)             
        )
        self.block3 = nn.Sequential(
            nn.Conv2d(f2, f3, 3, padding=1),     
            nn.ReLU(inplace=True),               
            nn.BatchNorm2d(f3),                  
            nn.Conv2d(f3, f3, 3, padding=1),     
            nn.ReLU(inplace=True),               
            nn.MaxPool2d(2),                     
            nn.Dropout(dropout_conv)             
        )
        self.gap = nn.AdaptiveAvgPool2d(1)
        
        # Head Structure:
        # [0]Flat -> [1]Linear(In pruned) -> [2]ReLU -> [3]BN -> [4]Drop -> [5]Linear
        self.head = nn.Sequential(
            nn.Flatten(),
            nn.Linear(f3, dense_units),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(dense_units),
            nn.Dropout(dropout_dense),
            nn.Linear(dense_units, 1)
        )

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.gap(x)
        return self.head(x)

In [23]:
# --- 2. Helpers --- 

def get_dataloaders():
    # We need training data to fine-tune the pruned model
    # We use standard transforms. Ensure paths match your directory structure.
    val_tfms = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.CenterCrop(IMG_SIZE),
        transforms.ToTensor(),
        transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]),
    ])
    
    train_ds = datasets.ImageFolder("cats-v-non-cats/training/", transform=val_tfms)
    val_ds = datasets.ImageFolder("cats-v-non-cats/validation/", transform=val_tfms)
    
    return (
        DataLoader(train_ds, batch_size=64, shuffle=True),
        DataLoader(val_ds, batch_size=64, shuffle=False)
    )

def evaluate(model, loader):
    model.eval()
    correct = 0; total = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(DEVICE), y.to(DEVICE).float()
            logits = model(x).squeeze(1)
            preds = (torch.sigmoid(logits) > 0.5).long()
            correct += (preds == y.long()).sum().item()
            total += x.size(0)
    return correct / max(total, 1)

def get_file_size(model, path="temp_model.pt"):
    torch.save(model.state_dict(), path)
    size_kb = os.path.getsize(path) / 1024
    os.remove(path)
    return size_kb

In [24]:
# --- 3. Load Baseline Model ---
print("Loading Baseline...")
ckpt = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
cfg = ckpt['config']

# Initialize original large model
baseline_model = CustomCNN(
    cfg.get('reg_strength', 0), cfg.get('dropout_conv', 0), cfg.get('dropout_dense', 0),
    int(cfg['dense_units']), float(cfg['filters_multiplier'])
).to(DEVICE)
baseline_model.load_state_dict(ckpt['state_dict'])

# Check baseline stats
train_dl, val_dl = get_dataloaders()
base_acc = evaluate(baseline_model, val_dl)
base_size = get_file_size(baseline_model)
print(f"Baseline Size: {base_size:.0f} KB | Baseline Acc: {base_acc:.4f}")

Loading Baseline...
Baseline Size: 5004 KB | Baseline Acc: 0.9691


In [26]:
# --- 4. Pruning Logic (The Engine) ---

def get_filter_norms(weight):
    # Calculate L1 norm for Conv2d weights.
    # Weight shape: [Out_Channels, In_Channels, Kernel, Kernel]
    # We sum over dimensions (1, 2, 3) to get a single value per Output Channel.
    return torch.sum(torch.abs(weight), dim=(1, 2, 3))

def prune_block(old_block, new_block, prev_idxs, keep_count):
    """
    Transfers weights from old_block to new_block, removing filters with low L1 norms.
    Returns the indices of the filters kept, to be used as input indices for the next layer.
    """
    
    # 1. Select Filters to Keep
    # We look at the 2nd Conv in the block (index 3) to decide the output shape of the block.
    w2 = old_block[3].weight.data
    norm = get_filter_norms(w2)
    _, current_idxs = torch.topk(norm, keep_count)
    current_idxs, _ = torch.sort(current_idxs) # Sort to preserve order

    # 2. Transfer Conv 1 (index 0)
    # Input Channels: Filtered by `prev_idxs` (from previous block)
    # Output Channels: Filtered by `current_idxs` (internal block consistency)
    w1_old = old_block[0].weight.data
    
    # Slicing: [Output_Idxs, Input_Idxs, :, :]
    if prev_idxs is None: # First block input is RGB (3 channels), keep all inputs
        w1_new = w1_old[current_idxs, :, :, :]
    else:
        w1_new = w1_old[current_idxs][:, prev_idxs, :, :]
        
    new_block[0].weight.data = w1_new
    new_block[0].bias.data = old_block[0].bias.data[current_idxs]

    # 3. Transfer Batch Norm (index 2)
    # Input/Output dimensions correspond to Conv1 Output (current_idxs)
    new_block[2].weight.data = old_block[2].weight.data[current_idxs]
    new_block[2].bias.data = old_block[2].bias.data[current_idxs]
    new_block[2].running_mean = old_block[2].running_mean[current_idxs]
    new_block[2].running_var = old_block[2].running_var[current_idxs]

    # 4. Transfer Conv 2 (index 3)
    # Input Channels: current_idxs (Output of Conv1)
    # Output Channels: current_idxs (Output of Block)
    w2_old = old_block[3].weight.data
    w2_new = w2_old[current_idxs][:, current_idxs, :, :]
    new_block[3].weight.data = w2_new
    new_block[3].bias.data = old_block[3].bias.data[current_idxs]

    return current_idxs

def transfer_head(old_head, new_head, prev_idxs):
    # 1. Linear Layer (Index 1) - Connects features to dense
    # Input dimensions correspond to Block 3 Output (prev_idxs)
    # Weight shape: [Out_Features, In_Features]
    old_w = old_head[1].weight.data
    new_head[1].weight.data = old_w[:, prev_idxs]
    new_head[1].bias.data = old_head[1].bias.data
    
    # 2. Batch Norm 1D (Index 3) - No shape change needed
    new_head[3].load_state_dict(old_head[3].state_dict())
    
    # 3. Final Linear (Index 5) - No shape change needed
    new_head[5].load_state_dict(old_head[5].state_dict())

In [33]:
# --- 5. Execute Pruning ---
print(f"\nStarting Pruning (Removing {PRUNE_AMOUNT*100:.0f}% of channels)...")

# 1. Define New Architecture Specs
c1 = int(baseline_model.channels[0] * (1 - PRUNE_AMOUNT))
c2 = int(baseline_model.channels[1] * (1 - PRUNE_AMOUNT))
c3 = int(baseline_model.channels[2] * (1 - PRUNE_AMOUNT))
new_channel_config = [c1, c2, c3]
print(f"Old Config: {baseline_model.channels}")
print(f"New Config: {new_channel_config}")

# 2. Initialize Child Model
slim_model = CustomCNN(
    cfg.get('reg_strength', 0), cfg.get('dropout_conv', 0), cfg.get('dropout_dense', 0),
    int(cfg['dense_units']), float(cfg['filters_multiplier']),
    channel_list=new_channel_config
).to(DEVICE)

# 3. Transplant Weights
idxs_1 = prune_block(baseline_model.block1, slim_model.block1, None, c1)
idxs_2 = prune_block(baseline_model.block2, slim_model.block2, idxs_1, c2)
idxs_3 = prune_block(baseline_model.block3, slim_model.block3, idxs_2, c3)
transfer_head(baseline_model.head, slim_model.head, idxs_3)

print("Pruning Complete. Evaluating initial accuracy drop...")
acc_drop = evaluate(slim_model, val_dl)
print(f"Accuracy (Before Fine-tuning): {acc_drop:.4f}")


Starting Pruning (Removing 50% of channels)...
Old Config: [60, 121, 242]
New Config: [30, 60, 121]
Pruning Complete. Evaluating initial accuracy drop...
Accuracy (Before Fine-tuning): 0.5075


In [28]:
# --- 6. Fine-Tuning ---
print(f"\nFine-tuning for {10} epochs...")
optimizer = torch.optim.Adam(slim_model.parameters(), lr=1e-4)
criterion = nn.BCEWithLogitsLoss()

slim_model.train()
for ep in range(10):
    loss_sum = 0
    for x, y in train_dl:
        x, y = x.to(DEVICE), y.to(DEVICE).float()
        optimizer.zero_grad()
        out = slim_model(x).squeeze(1)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()
        loss_sum += loss.item()
    print(f"Epoch {ep+1}: Loss {loss_sum:.4f}")

acc_tuned = evaluate(slim_model, val_dl)
size_pruned = get_file_size(slim_model)
print(f"Pruned + Tuned Accuracy: {acc_tuned:.4f}")
print(f"Pruned + Tuned Size:     {size_pruned:.0f} KB")


Fine-tuning for 10 epochs...
Epoch 1: Loss 84.7008
Epoch 2: Loss 63.6548
Epoch 3: Loss 57.4659
Epoch 4: Loss 52.5045
Epoch 5: Loss 48.1660
Epoch 6: Loss 45.8060
Epoch 7: Loss 42.8415
Epoch 8: Loss 40.6624
Epoch 9: Loss 38.3099
Epoch 10: Loss 37.1027
Pruned + Tuned Accuracy: 0.8081
Pruned + Tuned Size:     386 KB


In [None]:
#TODO THIS DOESN'T WORK YET. But, was able to get model dramatically smaller pre quanitization.



# --- 7. Static Quantization with Layer Fusion (FIXED) ---
import torch.quantization
from torch.quantization import fuse_modules

# 1. Define Wrapper with Stubs
class QuantizableCNN(CustomCNN):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.gap(x)
        out = self.head(x)
        out = self.dequant(out)
        return out
    
    def fuse_model(self):
        # Fuse Conv+ReLU+BN in blocks
        # Standard pattern: Conv -> ReLU -> BN is WRONG for PyTorch fusion
        # Your model has: Conv -> ReLU -> BN -> Conv -> ReLU
        # PyTorch fusion usually supports Conv+BN+ReLU or Conv+BN.
        
        # Let's traverse and fuse manually where possible.
        # Block structure: [0]Conv, [1]ReLU, [2]BN, [3]Conv, [4]ReLU
        
        # Note: PyTorch usually fuses Conv+BN+ReLU. Your order is Conv->ReLU->BN.
        # This is non-standard for fusion. We will only fuse the SECOND conv in each block
        # which is Conv->ReLU (indices 3,4).
        # The first part (Conv->ReLU->BN) creates a barrier.
        # WE WILL FUSE [3] and [4] (Conv+ReLU).
        
        for m in self.modules():
            if type(m) == nn.Sequential and len(m) >= 5:
                # Try to fuse index 3 (Conv) and 4 (ReLU)
                if type(m[3]) == nn.Conv2d and type(m[4]) == nn.ReLU:
                    torch.quantization.fuse_modules(m, ['3', '4'], inplace=True)

print("\n--- Preparing Quantization-Ready Model ---")

# 2. Re-create Slim Model Wrapper
q_slim_model = QuantizableCNN(
    cfg.get('reg_strength', 0), cfg.get('dropout_conv', 0), cfg.get('dropout_dense', 0),
    int(cfg['dense_units']), float(cfg['filters_multiplier']),
    channel_list=slim_model.channels 
).to('cpu')

# 3. Copy Weights
q_slim_model.load_state_dict(slim_model.state_dict(), strict=False)
q_slim_model.eval()

# 4. FUSE LAYERS (Critical Step)
q_slim_model.fuse_model()

# 5. Configure Quantization (fbgemm for x86)
backend = 'fbgemm'
q_slim_model.qconfig = torch.quantization.get_default_qconfig(backend)
torch.backends.quantized.engine = backend

# 6. Insert Observers & Calibrate
torch.quantization.prepare(q_slim_model, inplace=True)

print("Calibrating with validation data (CPU)...")
with torch.no_grad():
    for i, (x, _) in enumerate(train_dl):
        if i >= 20: break
        q_slim_model(x.to('cpu'))

# 7. Convert to INT8
# This step will fail if BN layers are still active on Int8 inputs.
# Since your architecture (Conv-ReLU-BN) is tricky to fuse perfectly,
# we set qconfig=None for the BN layers to keep them in Float32 (mixed precision).
# This bypasses the crash.

for module in q_slim_model.modules():
    if isinstance(module, nn.BatchNorm2d):
        module.qconfig = None  # Disable quantization for BN to avoid crash

torch.quantization.convert(q_slim_model, inplace=True)

# 8. Save and Measure
FINAL_PATH = "checkpoints/model_capstone_static_quant.pt"
torch.save(q_slim_model.state_dict(), FINAL_PATH)
final_size = os.path.getsize(FINAL_PATH) / 1024

# 9. Final Evaluation
def evaluate_on_cpu(model, loader):
    model.eval()
    correct = 0; total = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to('cpu'), y.to('cpu').float()
            logits = model(x).squeeze(1)
            preds = (torch.sigmoid(logits) > 0.5).long()
            correct += (preds == y.long()).sum().item()
            total += x.size(0)
    return correct / max(total, 1)

print("="*40)
print(" CAPSTONE RESULTS (Static Quantization)")
print("="*40)
print(f"Baseline Size:      {base_size:.0f} kB")
print(f"Pruned Size (FP32): {size_pruned:.0f} kB")
print(f"Quantized Size:     {final_size:.0f} kB") 
print(f"Reduction Factor:   {base_size/final_size:.1f}x")

print("\nEvaluating Quantized Model on CPU...")
acc_final = evaluate_on_cpu(q_slim_model, val_dl)
print(f"Final INT8 Accuracy: {acc_final:.4f}")

if final_size < 500:
    print(f"\n✅ SUCCESS: {final_size:.0f}kB is under 500kB!")
else:
    print(f"\n⚠️ STILL TOO BIG: {final_size:.0f}kB.")
    print("Recommendation: Scroll up to Cell 4, set PRUNE_AMOUNT = 0.75, and 'Run All' below.")


--- Preparing Quantization-Ready Model ---
Calibrating with validation data (CPU)...
 CAPSTONE RESULTS (Static Quantization)
Baseline Size:      5004 kB
Pruned Size (FP32): 1519 kB
Quantized Size:     440 kB
Reduction Factor:   11.4x

Evaluating Quantized Model on CPU...


NotImplementedError: Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPU' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::native_batch_norm' is only available for these backends: [CPU, CUDA, Meta, MkldnnCPU, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradHIP, AutogradXLA, AutogradMPS, AutogradIPU, AutogradXPU, AutogradHPU, AutogradVE, AutogradLazy, AutogradMTIA, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, AutogradMeta, AutogradNestedTensor, Tracer, AutocastCPU, AutocastCUDA, FuncTorchBatched, BatchedNestedTensor, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PreDispatch, PythonDispatcher].

CPU: registered at aten\src\ATen\RegisterCPU.cpp:31420 [kernel]
CUDA: registered at aten\src\ATen\RegisterCUDA.cpp:44504 [kernel]
Meta: registered at /dev/null:488 [kernel]
MkldnnCPU: registered at aten\src\ATen\RegisterMkldnnCPU.cpp:515 [kernel]
BackendSelect: fallthrough registered at ..\aten\src\ATen\core\BackendSelectFallbackKernel.cpp:3 [backend fallback]
Python: registered at ..\aten\src\ATen\core\PythonFallbackKernel.cpp:154 [backend fallback]
FuncTorchDynamicLayerBackMode: registered at ..\aten\src\ATen\functorch\DynamicLayer.cpp:497 [backend fallback]
Functionalize: registered at ..\aten\src\ATen\FunctionalizeFallbackKernel.cpp:324 [backend fallback]
Named: registered at ..\aten\src\ATen\core\NamedRegistrations.cpp:7 [backend fallback]
Conjugate: registered at ..\aten\src\ATen\ConjugateFallback.cpp:17 [backend fallback]
Negative: registered at ..\aten\src\ATen\native\NegateFallback.cpp:18 [backend fallback]
ZeroTensor: registered at ..\aten\src\ATen\ZeroTensorFallback.cpp:86 [backend fallback]
ADInplaceOrView: fallthrough registered at ..\aten\src\ATen\core\VariableFallbackKernel.cpp:86 [backend fallback]
AutogradOther: registered at ..\torch\csrc\autograd\generated\VariableType_1.cpp:16277 [autograd kernel]
AutogradCPU: registered at ..\torch\csrc\autograd\generated\VariableType_1.cpp:16277 [autograd kernel]
AutogradCUDA: registered at ..\torch\csrc\autograd\generated\VariableType_1.cpp:16277 [autograd kernel]
AutogradHIP: registered at ..\torch\csrc\autograd\generated\VariableType_1.cpp:16277 [autograd kernel]
AutogradXLA: registered at ..\torch\csrc\autograd\generated\VariableType_1.cpp:16277 [autograd kernel]
AutogradMPS: registered at ..\torch\csrc\autograd\generated\VariableType_1.cpp:16277 [autograd kernel]
AutogradIPU: registered at ..\torch\csrc\autograd\generated\VariableType_1.cpp:16277 [autograd kernel]
AutogradXPU: registered at ..\torch\csrc\autograd\generated\VariableType_1.cpp:16277 [autograd kernel]
AutogradHPU: registered at ..\torch\csrc\autograd\generated\VariableType_1.cpp:16277 [autograd kernel]
AutogradVE: registered at ..\torch\csrc\autograd\generated\VariableType_1.cpp:16277 [autograd kernel]
AutogradLazy: registered at ..\torch\csrc\autograd\generated\VariableType_1.cpp:16277 [autograd kernel]
AutogradMTIA: registered at ..\torch\csrc\autograd\generated\VariableType_1.cpp:16277 [autograd kernel]
AutogradPrivateUse1: registered at ..\torch\csrc\autograd\generated\VariableType_1.cpp:16277 [autograd kernel]
AutogradPrivateUse2: registered at ..\torch\csrc\autograd\generated\VariableType_1.cpp:16277 [autograd kernel]
AutogradPrivateUse3: registered at ..\torch\csrc\autograd\generated\VariableType_1.cpp:16277 [autograd kernel]
AutogradMeta: registered at ..\torch\csrc\autograd\generated\VariableType_1.cpp:16277 [autograd kernel]
AutogradNestedTensor: registered at ..\torch\csrc\autograd\generated\VariableType_1.cpp:16277 [autograd kernel]
Tracer: registered at ..\torch\csrc\autograd\generated\TraceType_1.cpp:15950 [kernel]
AutocastCPU: fallthrough registered at ..\aten\src\ATen\autocast_mode.cpp:378 [backend fallback]
AutocastCUDA: fallthrough registered at ..\aten\src\ATen\autocast_mode.cpp:244 [backend fallback]
FuncTorchBatched: registered at ..\aten\src\ATen\functorch\BatchRulesNorm.cpp:864 [kernel]
BatchedNestedTensor: registered at ..\aten\src\ATen\functorch\LegacyBatchingRegistrations.cpp:758 [backend fallback]
FuncTorchVmapMode: fallthrough registered at ..\aten\src\ATen\functorch\VmapModeRegistrations.cpp:27 [backend fallback]
Batched: registered at ..\aten\src\ATen\LegacyBatchingRegistrations.cpp:1075 [backend fallback]
VmapMode: fallthrough registered at ..\aten\src\ATen\VmapModeRegistrations.cpp:33 [backend fallback]
FuncTorchGradWrapper: registered at ..\aten\src\ATen\functorch\TensorWrapper.cpp:202 [backend fallback]
PythonTLSSnapshot: registered at ..\aten\src\ATen\core\PythonFallbackKernel.cpp:162 [backend fallback]
FuncTorchDynamicLayerFrontMode: registered at ..\aten\src\ATen\functorch\DynamicLayer.cpp:493 [backend fallback]
PreDispatch: registered at ..\aten\src\ATen\core\PythonFallbackKernel.cpp:166 [backend fallback]
PythonDispatcher: registered at ..\aten\src\ATen\core\PythonFallbackKernel.cpp:158 [backend fallback]


In [29]:
# --- 7. Final Evaluation & ONNX Export (No Quantization) ---
import torch.onnx

print("\n--- Exporting Pruned Model to ONNX ---")

# 1. Set up paths and dummy input
ONNX_PATH = "checkpoints/model_capstone_pruned.onnx"
os.makedirs("checkpoints", exist_ok=True)

# 2. Ensure model is in eval mode (fixes BatchNorm/Dropout behavior)
slim_model.eval()

# 3. Create dummy input on the same device as the model (CUDA or CPU)
# Shape: [Batch Size, Channels, Height, Width]
dummy_input = torch.randn(1, 3, IMG_SIZE, IMG_SIZE).to(DEVICE)

# 4. Export to ONNX
print(f"Exporting to {ONNX_PATH}...")
torch.onnx.export(
    slim_model,               # Model to export
    dummy_input,              # Model input (or a tuple for multiple inputs)
    ONNX_PATH,                # Where to save the model
    export_params=True,       # Store the trained parameter weights inside the model file
    opset_version=12,         # Standard ONNX opset version
    do_constant_folding=True, # Optimize constants
    input_names=['input'],    # Input tensor name
    output_names=['output'],  # Output tensor name
    dynamic_axes={'input': {0: 'batch_size'},  # Allow variable batch sizes
                  'output': {0: 'batch_size'}}
)

# 5. Measure Final Size
onnx_size_kb = os.path.getsize(ONNX_PATH) / 1024

print("="*40)
print(" CAPSTONE RESULTS (Pruned Only)")
print("="*40)
print(f"Baseline Model Size:  {base_size:.0f} kB")
print(f"Pruned ONNX Size:     {onnx_size_kb:.0f} kB")
print(f"Reduction Factor:     {base_size/onnx_size_kb:.1f}x")
print(f"Final Accuracy:       {acc_tuned:.4f}")

if onnx_size_kb < 500:
    print(f"\n✅ SUCCESS: {onnx_size_kb:.0f}kB is under 500kB!")
else:
    print(f"\n⚠️ STILL TOO BIG: {onnx_size_kb:.0f}kB.")
    print("Recommendation: Scroll up to 'Step 5' and increase PRUNE_AMOUNT to 0.7 or 0.8.")


--- Exporting Pruned Model to ONNX ---
Exporting to checkpoints/model_capstone_pruned.onnx...
 CAPSTONE RESULTS (Pruned Only)
Baseline Model Size:  5004 kB
Pruned ONNX Size:     379 kB
Reduction Factor:     13.2x
Final Accuracy:       0.8081

✅ SUCCESS: 379kB is under 500kB!
