# 03. Model Optimization and Deployment for Gravitational Lensing

**Objective**: Optimize the trained SwinIR model for efficient deployment. 

In scientific applications (e.g., satellite-based processing), computational resources are often limited. This notebook demonstrates:
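1. **Weight Pruning**: Zeroing out low-magnitude weights to sparsify the model. Unstructured pruning leaves dense tensors in place, so on its own it does not shrink the checkpoint or speed up dense inference.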
2. **Quantization**: Converting model weights from FP32 to INT8 to reduce size and increase speed.
3. **ONNX Export**: Exporting the model for platform-independent, high-performance inference.

### 1. Setup and Model Loading

In [None]:
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import os
import sys
import time
import numpy as np

# Add parent directory to path to import local modules
sys.path.append(os.path.abspath('..'))

from model import SwinIR

DEVICE = 'cpu'  # Optimization is often verified on CPU for deployment targets

# Candidate checkpoints (relative to the notebook's working directory), latest
# epoch first; the first one that exists on disk is used.
MODEL_PATHS = [
    '../swinir_advanced_epoch_final.pth',
    '../swinir_advanced_epoch_20.pth',
    '../swinir_advanced_epoch_15.pth',
    '../swinir_advanced_epoch_10.pth',
    '../swinir_advanced_epoch_5.pth',
]

# Path of the preferred existing checkpoint, or None when none has been saved yet.
MODEL_PATH = next((candidate for candidate in MODEL_PATHS if os.path.exists(candidate)), None)

if MODEL_PATH is None:
    print("⚠️ WARNING: No model checkpoint found. Please train the model first using notebook 02.")
    print("Using default model architecture for demonstration.")

print("Optimization Pipeline Initialized.")
if MODEL_PATH:
    print(f"Will load model from: {MODEL_PATH}")

Optimization Pipeline Initialized.


### 2. Model Pruning (L1 Unstructured)
We zero out the 20% of weights with the smallest absolute value (L1 magnitude) in each convolutional (`nn.Conv2d`) layer.

In [None]:
# Load model first. The architecture hyperparameters must match the ones the
# checkpoint was trained with, or load_state_dict (strict by default) will fail.
# NOTE(review): confirm these values against the training notebook (02).
if MODEL_PATH is None:
    print("⚠️ No checkpoint found. Using untrained model for demonstration.")
else:
    print(f"Loading model from {MODEL_PATH}...")

model = SwinIR(embed_dim=60, depths=[4, 4, 4, 4], num_heads=[4, 4, 4, 4]).to(DEVICE)
if MODEL_PATH is not None:
    # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
    # Newer PyTorch versions accept weights_only=True to restrict this.
    model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))

# Switch to inference mode whether or not weights were loaded, so the untrained
# demo model behaves like a loaded one in the cells below.
model.eval()

print(f"Model Parameter Count: {sum(p.numel() for p in model.parameters()):,}")
if MODEL_PATH:
    print(f"Original Model Size: {os.path.getsize(MODEL_PATH) / (1024*1024):.2f} MB")
else:
    print("No checkpoint to measure")

def apply_pruning(model, amount=0.2, layer_types=(nn.Conv2d,)):
    """Apply permanent L1-magnitude unstructured pruning to selected layers.

    For every submodule of ``model`` (including ``model`` itself) that is an
    instance of ``layer_types``, the weights with the smallest absolute values
    are set to zero. The pruning re-parametrization is then removed, leaving a
    plain ``weight`` parameter. The model is modified in place.

    Args:
        model: The ``nn.Module`` to prune.
        amount: Either a float in [0, 1], the fraction of each layer's weights
            to prune, or an int, the absolute number of weights to prune per
            layer. These are the two forms accepted by
            ``torch.nn.utils.prune.l1_unstructured``.
        layer_types: Module class, or tuple of classes, to prune. Defaults to
            ``nn.Conv2d`` only. Pass e.g. ``(nn.Conv2d, nn.Linear)`` to also
            prune linear projections. NOTE(review): in a transformer-style model
            such as SwinIR, most parameters likely live in ``nn.Linear`` layers.

    Returns:
        The same ``model`` object, pruned in place.
    """
    if not isinstance(layer_types, tuple):
        layer_types = (layer_types,)
    type_names = "/".join(t.__name__ for t in layer_types)

    # prune treats a float as a fraction and an int as an absolute count,
    # so only a float is reported as a percentage.
    if isinstance(amount, float):
        print(f"Applying {amount*100}% L1 Unstructured Pruning to {type_names} layers...")
    else:
        print(f"Applying L1 Unstructured Pruning ({amount} weights per layer) to {type_names} layers...")

    for module in model.modules():
        if isinstance(module, layer_types):
            prune.l1_unstructured(module, name='weight', amount=amount)
            # Fold weight_orig * weight_mask back into a plain `weight` (make pruning permanent).
            prune.remove(module, 'weight')
    return model

# Example usage (uncomment to apply):
# model = apply_pruning(model, amount=0.2)

### 3. Post-Training Static Quantization (INT8)
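Storing weights as INT8 instead of FP32 cuts weight storage by roughly 4x. Static quantization also needs a calibration pass over representative inputs to estimate activation ranges.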

In [None]:
def quantize_model(model, calibration_data=None):
    """Apply eager-mode post-training static INT8 quantization.

    Static quantization has three steps. Observers are inserted (``prepare``).
    Representative inputs are run through the model so the observers record
    activation ranges (calibration). The model is then converted to INT8
    modules (``convert``). All of this happens on a deep copy, so the caller's
    float model keeps no ``qconfig`` or observers.

    Args:
        model: Float ``nn.Module``. It is switched to eval mode.
        calibration_data: Optional iterable of input tensors, representative of
            real inputs, used for calibration. If ``None`` (the default),
            quantization is skipped with a warning and ``model`` is returned
            unchanged.

    Returns:
        A new INT8-converted model when calibration data is given, otherwise
        the original ``model``.

    Raises:
        ValueError: if ``calibration_data`` yields no batches.

    NOTE(review): eager-mode static quantization expects the model to wrap its
    float inputs and outputs with QuantStub/DeQuantStub. Confirm SwinIR does
    this before running inference with the converted model.
    """
    import copy

    model.eval()
    if calibration_data is None:
        print("⚠️ WARNING: Calibration data not provided. Skipping quantization.")
        print("To properly quantize, pass representative inputs via calibration_data.")
        return model

    # Quantize a copy so the caller's float model is left untouched.
    model_fp32 = copy.deepcopy(model)
    model_fp32.qconfig = torch.quantization.get_default_qconfig('fbgemm')
    model_prepared = torch.quantization.prepare(model_fp32, inplace=True)

    num_batches = 0
    with torch.no_grad():
        for batch in calibration_data:
            model_prepared(batch)
            num_batches += 1
    if num_batches == 0:
        # Observers with no recorded statistics cannot produce valid scales.
        raise ValueError("calibration_data yielded no batches; cannot calibrate observers.")

    model_int8 = torch.quantization.convert(model_prepared, inplace=True)
    print(f"Model quantized to INT8 (calibrated on {num_batches} batches).")
    return model_int8

# Example usage:
# quantized_model = quantize_model(model)

### 4. ONNX Export
Export the model to ONNX so it can be served with OpenVINO, TensorRT, or ONNX Runtime.

In [None]:
def export_onnx(model, output_path="../swinir_lensing.onnx", input_shape=(1, 1, 64, 64), opset_version=12):
    """Export ``model`` to ONNX and report the resulting file size.

    The exporter runs the model (switched to eval mode) on a random dummy
    input of ``input_shape``. The input is placed on the same device as the
    model's parameters, or on CPU if the model has none. The batch dimension
    of input and output is marked dynamic; the other dimensions are fixed to
    ``input_shape``.

    Export is best-effort: failures are reported and ``None`` is returned
    instead of raising, so a notebook run can continue.

    Args:
        model: The ``nn.Module`` to export.
        output_path: Destination ``.onnx`` file. Its directory must exist.
        input_shape: Shape of the dummy input used during export.
        opset_version: ONNX opset to target.

    Returns:
        ``output_path`` on success, ``None`` on failure.
    """
    model.eval()
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    dummy_input = torch.randn(*input_shape, device=device)

    try:
        torch.onnx.export(model,
                          dummy_input,
                          output_path,
                          export_params=True,
                          opset_version=opset_version,
                          do_constant_folding=True,
                          input_names=['input'],
                          output_names=['output'],
                          dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}})
    except Exception as e:  # best-effort: exporter failures vary widely by op/opset
        print(f"❌ Error exporting ONNX: {e}")
        print(f"   Try a different opset_version (current: {opset_version}) or input_shape.")
        return None

    print(f"✅ Model exported to {output_path}")
    print(f"   File size: {os.path.getsize(output_path) / (1024*1024):.2f} MB")
    return output_path

# Example usage:
# export_onnx(model, output_path="../swinir_lensing.onnx")

### 5. Final Benchmarking
Compare inference speed and file size.

In [None]:
def get_model_size(path):
    """Return the size of the file at ``path`` in MB (1 MB = 1024 * 1024 bytes).

    Prints the size as a side effect. When ``path`` is ``None``, empty, or not
    an existing file, prints a warning and returns ``0.0`` instead of raising.
    For example, ``get_model_size(MODEL_PATH)`` is safe when no checkpoint was found.

    Args:
        path: Filesystem path to the model file, or ``None``.

    Returns:
        File size in MB as a float, or ``0.0`` if the file is unavailable.
    """
    # os.path.exists(None) raises TypeError, so reject missing paths explicitly.
    if not path or not os.path.isfile(path):
        print(f"⚠️ File not found: {path}")
        return 0.0
    size = os.path.getsize(path) / (1024 * 1024)
    print(f"Model Size: {size:.2f} MB")
    return size

def benchmark_inference(model, num_iterations=100, input_shape=(1, 1, 64, 64), warmup_iterations=10):
    """Measure the mean forward-pass latency of ``model`` in milliseconds.

    The model is switched to eval mode and run under ``torch.no_grad()`` on a
    random input of ``input_shape``. The input is placed on the device of the
    model's parameters, or on CPU if the model has none. Untimed warmup passes
    come first. On CUDA the device is synchronized around the timed loop so
    that asynchronous kernel launches are included in the measurement.

    Args:
        model: The ``nn.Module`` to benchmark.
        num_iterations: Number of timed forward passes; must be positive.
        input_shape: Shape of the random benchmark input.
        warmup_iterations: Number of untimed forward passes before timing.

    Returns:
        Average time per forward pass in milliseconds.

    Raises:
        ValueError: if ``num_iterations`` is not positive.
    """
    if num_iterations <= 0:
        raise ValueError(f"num_iterations must be positive, got {num_iterations}")

    model.eval()
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    is_cuda = device.type == 'cuda'  # also matches 'cuda:0', 'cuda:1', ...
    dummy_input = torch.randn(*input_shape, device=device)

    with torch.no_grad():
        for _ in range(warmup_iterations):
            model(dummy_input)

        if is_cuda:
            torch.cuda.synchronize(device)
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start_time = time.perf_counter()
        for _ in range(num_iterations):
            model(dummy_input)
        if is_cuda:
            torch.cuda.synchronize(device)
        elapsed_time = time.perf_counter() - start_time

    avg_time = elapsed_time / num_iterations * 1000  # ms
    print(f"Average inference time: {avg_time:.2f} ms ({num_iterations} iterations)")
    return avg_time

# Usage hints shown once the benchmarking helpers above are defined.
_usage_lines = (
    "  - Compare model sizes: get_model_size(MODEL_PATH)",
    "  - Benchmark inference: benchmark_inference(model)",
    "  - Export to ONNX: export_onnx(model)",
)
print("Benchmarking logic ready.")
print("\nExample usage:")
for _line in _usage_lines:
    print(_line)

Benchmarking logic ready.
