# Linearization Framework - Sandwich Architecture

This notebook demonstrates the Linearizer framework using the **sandwich architecture**:
- f(x) = g⁻¹ᵧ(Agₓ(x))
- gₓ: Image → Latent (invertible network)
- A: Linear operator in latent space
- g⁻¹ᵧ: Latent → Embedding (invertible network)

Based on: Berman et al. "Who Said Neural Networks Aren't Linear?" (2025)

## Objectives
1. Load InsightFace buffalo_l model
2. Create Linearizer with sandwich architecture
3. Train invertible networks (gₓ and g⁻¹ᵧ) and linear operator A
4. Evaluate reconstruction quality


In [None]:
import sys
import os
sys.path.append(os.path.join(os.path.dirname(os.getcwd()), 'src'))

import torch
import numpy as np
import matplotlib.pyplot as plt
import yaml
from tqdm import tqdm

from utils.model_loader import load_model_from_config
from linearizer.linearizer import Linearizer
from data.dataloader import get_ms1mv2_dataloader

# Load configuration
# The path is relative to the notebook's working directory. The encoding is
# pinned so the YAML file is decoded identically on every platform instead of
# falling back to the locale's default encoding (which is not UTF-8 everywhere).
with open('../config.yaml', 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

# Use the GPU when PyTorch reports one; later cells move the model, the
# linearizer and every batch to this device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")


## 1. Load Face Recognition Model

In [None]:
# Build the original face recognition model from the config, move it to the
# selected device and put it in eval mode before it is used as a reference.
model = load_model_from_config(config).to(device)
model.eval()
print("Original model loaded")

# Embedding dimensionality comes from the 'model' section of the config;
# 512 is used when the key is absent.
model_config = config['model']
embedding_size = model_config.get('embedding_size', 512)
print(f"Embedding size: {embedding_size}")


## 2. Create Linearizer

In [None]:
# Linearizer hyperparameters are read from the 'linearizer' section of the
# config; the literal defaults below apply whenever a key is missing.
linearizer_config = config['linearizer']
latent_dim = linearizer_config.get('latent_dim', 512)
image_size = (112, 112)  # Standard face recognition image size
num_blocks = linearizer_config.get('num_blocks', 4)

# Collect the constructor arguments in one place, then build the module and
# move it to the selected device.
linearizer_kwargs = dict(
    model=model,
    embedding_size=embedding_size,
    latent_dim=latent_dim,
    num_blocks=num_blocks,
    hidden_dim=linearizer_config.get('hidden_dim', 1024),
    num_layers=linearizer_config.get('num_layers', 3),
    image_size=image_size,
)
linearizer = Linearizer(**linearizer_kwargs).to(device)

# Summarize the three stages of the sandwich: g_x, A and g_y^-1.
print("Linearizer created successfully with sandwich architecture!")
print(f"  - gₓ: Image ({image_size}) → Latent ({latent_dim})")
print(f"  - A: Linear operator ({latent_dim} × {latent_dim})")
print(f"  - g⁻¹ᵧ: Latent ({latent_dim}) → Embedding ({embedding_size})")
print(f"  - Invertible blocks per network: {num_blocks}")

# Count parameters in a single pass: every parameter contributes to the total,
# and those with requires_grad set also contribute to the trainable count.
total_params = 0
trainable_params = 0
for param in linearizer.parameters():
    param_count = param.numel()
    total_params += param_count
    if param.requires_grad:
        trainable_params += param_count
print(f"\nTotal parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")


## 3. Train Linearizer

In [None]:
# Training data: the MS1MV2 location is taken from the config, the batch size
# from the linearizer section (64 when unset).
ms1mv2_path = config['data']['ms1mv2']['path']
train_batch_size = linearizer_config.get('batch_size', 64)
dataloader = get_ms1mv2_dataloader(
    ms1mv2_path,
    batch_size=train_batch_size,
    num_workers=4,
    is_training=True,
)

# Fit the linearizer; epoch count and learning rate fall back to 100 and 1e-4
# when the config does not provide them.
print("Training linearizer...")
training_options = {
    'num_epochs': linearizer_config.get('num_epochs', 100),
    'lr': linearizer_config.get('learning_rate', 0.0001),
    'device': device,
}
linearizer.train_linearizer(dataloader, **training_options)

print("Training completed!")


## 4. Evaluate Reconstruction Quality

Test the linearized model's ability to reconstruct original embeddings:

## 5. Visualize Sandwich Architecture Flow

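The last code cell of this notebook (after the reconstruction-quality cell for section 4) traces a single image through the sandwich architecture step by step, printing the tensor shape and norm after each stage:
- Input image → gₓ → Latent space → A (linear op) → g⁻¹ᵧ → Embedding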

In [None]:
# Test reconstruction quality
# Compares the embeddings produced by the original model with those produced
# by the linearizer on the same images, first as an aggregate MSE over the
# whole loader and then per image for a handful of examples.
linearizer.eval()
test_dataloader = get_ms1mv2_dataloader(
    ms1mv2_path,
    batch_size=32,
    is_training=False
)

# Per-batch mean MSE plus the matching batch sizes. The sizes are needed so
# that a smaller final batch does not get the same weight as a full one.
reconstruction_errors = []
batch_sizes = []
with torch.no_grad():
    for images, _ in tqdm(test_dataloader, desc="Testing"):
        images = images.to(device)

        # Original embeddings
        original_emb = model.extract_features(images)

        # Linearized embeddings
        linearized_emb = linearizer(images)

        # Compute reconstruction error (mean over every element of the batch)
        error = torch.nn.functional.mse_loss(original_emb, linearized_emb)
        reconstruction_errors.append(error.item())
        batch_sizes.append(images.shape[0])

# Weight each batch mean by its sample count so the result is the mean over
# all samples rather than the mean of batch means. An empty loader yields NaN
# instead of a division error.
if reconstruction_errors:
    avg_error = np.average(reconstruction_errors, weights=batch_sizes)
else:
    avg_error = float('nan')
print(f"Average reconstruction error: {avg_error:.6f}")

# Visualize some examples: top row shows the input images, bottom row the
# first 20 components of (original - linearized) embedding for each image.
fig, axes = plt.subplots(2, 5, figsize=(15, 6))
with torch.no_grad():
    images, _ = next(iter(test_dataloader))
    images = images[:5].to(device)

    original_emb = model.extract_features(images)
    linearized_emb = linearizer(images)

    for i in range(5):
        # Show image
        # assumes channel-first tensors normalized with mean 0.5 / std 0.5,
        # so * 0.5 + 0.5 maps them back to [0, 1] — TODO confirm against the dataloader
        axes[0, i].imshow(images[i].cpu().permute(1, 2, 0) * 0.5 + 0.5)
        axes[0, i].set_title("Input Image")
        axes[0, i].axis('off')

        # Show embedding difference
        diff = (original_emb[i] - linearized_emb[i]).cpu().numpy()
        axes[1, i].bar(range(min(20, len(diff))), diff[:20])
        # .item() turns the 0-dim loss tensor into a plain float before formatting
        sample_mse = torch.nn.functional.mse_loss(original_emb[i], linearized_emb[i]).item()
        axes[1, i].set_title(f"Embedding Diff\nMSE: {sample_mse:.6f}")
        # Fixed y-range keeps the five panels comparable; larger differences are cut off
        axes[1, i].set_ylim(-0.1, 0.1)

plt.tight_layout()
plt.show()


In [None]:
# Trace one test image through the three stages of the sandwich, printing the
# tensor shape and norm after each stage, then compare with the original model.
linearizer.eval()
with torch.no_grad():
    # First image of the first test batch; the slice keeps the batch dimension
    sample_images, _ = next(iter(test_dataloader))
    images = sample_images[:1].to(device)

    # Header
    print("Sandwich Architecture Flow:")
    print("=" * 50)

    # Stage 1: g_x maps the image into the latent space
    z = linearizer.g_x(images, reverse=False)
    print(f"1. Input image shape: {images.shape}")
    print(f"2. After gₓ (latent): {z.shape}")
    latent_norm = torch.norm(z).item()
    print(f"   Latent vector norm: {latent_norm:.4f}")

    # Stage 2: the linear operator A acts on the latent vector
    z_transformed = linearizer.linear_op(z)
    print(f"3. After A (linear op): {z_transformed.shape}")
    transformed_norm = torch.norm(z_transformed).item()
    print(f"   Transformed latent norm: {transformed_norm:.4f}")

    # Stage 3: g_y^-1 maps the transformed latent to an embedding
    embedding = linearizer.g_y_inv(z_transformed, reverse=False)
    print(f"4. After g⁻¹ᵧ (embedding): {embedding.shape}")
    embedding_norm = torch.norm(embedding).item()
    print(f"   Embedding norm: {embedding_norm:.4f}")

    # Reference: embedding of the same image from the original model
    original_emb = model.extract_features(images)
    print(f"\n5. Original embedding: {original_emb.shape}")
    original_norm = torch.norm(original_emb).item()
    print(f"   Original embedding norm: {original_norm:.4f}")

    # Agreement between the two embeddings: cosine similarity and MSE
    similarity = torch.nn.functional.cosine_similarity(embedding, original_emb, dim=1)
    print(f"\nCosine similarity: {similarity.item():.6f}")
    embedding_mse = torch.nn.functional.mse_loss(embedding, original_emb).item()
    print(f"MSE: {embedding_mse:.6f}")