In [1]:
import torch
import torch.nn as nn
import math
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset

# --- Part 1: Original Toy CNN Experiment (for reference, not run by default) ---
# ... (Code from Part 1 is unchanged and omitted for brevity) ...


# --- Part 2: Final MNIST Experiment (Homogeneous Models Only) ---

# --- 2.1. MNIST Experiment Settings ---
MNIST_LEARNING_RATE = 1e-3  # SGD step size used by both the narrow and the wide network
MNIST_BATCH_SIZE = 64       # mini-batch size of the shuffled training loader
MNIST_EPOCHS = 1000         # number of full passes over the 3-vs-5 training subset

# --- 2.2. MNIST Data Loading and Preprocessing ---
def get_mnist_binary_loaders(batch_size=64, negative_digit=3, positive_digit=5,
                             eval_batch_size=512, root='./data'):
    """
    Load MNIST and build a binary classification task (3 vs 5 by default).

    Only the MNIST training split is used. Images of ``negative_digit`` are
    relabelled -1, images of ``positive_digit`` are relabelled +1, and all
    other digits are dropped. Pixels go through ``ToTensor`` and then
    ``Normalize((0.5,), (0.5,))``.

    Args:
        batch_size: Mini-batch size of the shuffled training loader.
        negative_digit: MNIST digit mapped to label -1.
        positive_digit: MNIST digit mapped to label +1.
        eval_batch_size: Batch size of the unshuffled evaluation loader.
        root: Directory MNIST is downloaded to / read from.

    Returns:
        ``(train_loader, eval_loader)``. Both iterate the SAME filtered
        training set, so metrics computed on ``eval_loader`` are training
        metrics, not held-out metrics.

    Raises:
        ValueError: If ``negative_digit == positive_digit``.
    """
    if negative_digit == positive_digit:
        raise ValueError(
            f"negative_digit and positive_digit must differ, got {negative_digit}"
        )

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])
    train_dataset = datasets.MNIST(root=root, train=True, download=True, transform=transform)

    # Build both masks from the original digit labels before relabelling, so
    # the outcome does not depend on the order in which the classes are mapped.
    is_negative = train_dataset.targets == negative_digit
    is_positive = train_dataset.targets == positive_digit
    keep = is_negative | is_positive

    labels = torch.full_like(train_dataset.targets[keep], -1)
    labels[is_positive[keep]] = 1
    train_dataset.data = train_dataset.data[keep]
    train_dataset.targets = labels

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    # Deterministic order for evaluation passes over the same data.
    eval_loader = DataLoader(train_dataset, batch_size=eval_batch_size, shuffle=False)
    return train_loader, eval_loader

# --- 2.3. MNIST Model Definitions (Homogeneous Only) ---
def leaky_relu(x, negative_slope=0.01):
    """Element-wise leaky ReLU: ``x`` where ``x > 0``, else ``negative_slope * x``.

    The function is piecewise linear through the origin. For any c > 0,
    ``leaky_relu(c * x) == c * leaky_relu(x)`` (positive homogeneity).
    """
    is_positive = x > 0
    return torch.where(is_positive, x, negative_slope * x)

# MLP Models
class SmallMLP(nn.Module):
    """Bias-free MLP 784 -> 128 -> 1 with a leaky-ReLU hidden layer.

    Neither layer has a bias. Scaling every weight by c > 0 therefore scales
    the output by c**2.
    """

    def __init__(self):
        super().__init__()
        in_features, hidden_units = 28 * 28, 128
        self.fc1 = nn.Linear(in_features, hidden_units, bias=False)
        self.fc2 = nn.Linear(hidden_units, 1, bias=False)

    def forward(self, x):
        # Flatten each image into a 784-vector; view(-1, ...) folds leading dims into the batch.
        flat = x.view(-1, 28 * 28)
        hidden = leaky_relu(self.fc1(flat))
        return self.fc2(hidden)

class LargeMLP(nn.Module):
    """Bias-free MLP 784 -> 256 -> 1: the twice-as-wide counterpart of ``SmallMLP``.

    The parameter names (``fc1.weight``, ``fc2.weight``) match ``SmallMLP``.
    This lets transformed small-net weights be copied in by name.
    """

    def __init__(self):
        super().__init__()
        in_features, hidden_units = 28 * 28, 256
        self.fc1 = nn.Linear(in_features, hidden_units, bias=False)
        self.fc2 = nn.Linear(hidden_units, 1, bias=False)

    def forward(self, x):
        # Flatten images to 784-vectors, apply the hidden layer, then the scalar read-out.
        hidden = leaky_relu(self.fc1(x.view(-1, 28 * 28)))
        return self.fc2(hidden)

# CNN Models
class SmallMNISTCNN(nn.Module):
    """Bias-free CNN: 5x5 conv (10 channels) -> leaky ReLU -> 2x2 avg-pool -> linear to 1.

    The hard-coded 12x12 feature map assumes 28x28 inputs: 28 - 5 + 1 = 24,
    which the pooling halves to 12.
    """

    def __init__(self):
        super().__init__()
        channels = 10
        self.conv1 = nn.Conv2d(1, channels, kernel_size=5, bias=False)
        self.pool = nn.AvgPool2d(2)
        self.fc1 = nn.Linear(channels * 12 * 12, 1, bias=False)

    def forward(self, x):
        features = self.pool(leaky_relu(self.conv1(x)))
        # Channel-major flatten into one 1440-vector per sample.
        return self.fc1(features.view(-1, 10 * 12 * 12))

class LargeMNISTCNN(nn.Module):
    """Twice-as-wide counterpart of ``SmallMNISTCNN`` (20 conv channels, no biases).

    The 12x12 feature map assumes 28x28 inputs: a 5x5 conv gives 24x24, and
    the 2x2 pooling halves it.
    """

    def __init__(self):
        super().__init__()
        channels = 20
        self.conv1 = nn.Conv2d(1, channels, kernel_size=5, bias=False)
        self.pool = nn.AvgPool2d(2)
        self.fc1 = nn.Linear(channels * 12 * 12, 1, bias=False)

    def forward(self, x):
        activated = leaky_relu(self.conv1(x))
        pooled = self.pool(activated)
        # Channel-major flatten: the first 1440 entries belong to channels 0-9.
        return self.fc1(pooled.view(-1, 20 * 12 * 12))

# --- 2.4. Transformation Functions (Homogeneous Only) ---
def _replicate_hidden_units(weight_in, weight_out, n_copies):
    """Tile a hidden layer ``n_copies`` times, scaling each copy by ``1/sqrt(n_copies)``.

    ``weight_in`` is concatenated along dim 0 (one row / output channel per
    hidden unit). ``weight_out`` is concatenated along dim 1 (one input
    column per hidden unit).

    Consider a bias-free net whose activation is positively homogeneous of
    degree 1, such as ``leaky_relu``. Each copy then contributes
    ``(1/n_copies) * f(x)``, so the wide net computes the same function as
    the narrow one.

    Raises:
        ValueError: If ``n_copies < 1``.
    """
    if n_copies < 1:
        raise ValueError(f"n_copies must be >= 1, got {n_copies}")
    # math.sqrt yields a plain Python float (no numpy-scalar * tensor mixing).
    scale = 1.0 / math.sqrt(n_copies)
    wide_in = torch.cat([scale * weight_in] * n_copies, dim=0)
    wide_out = torch.cat([scale * weight_out] * n_copies, dim=1)
    return wide_in, wide_out


def transform_mlp_params(small_model, n_copies=2):
    """Map ``SmallMLP``-style weights to an equivalent MLP ``n_copies`` times as wide.

    Args:
        small_model: Module with bias-free ``fc1`` and ``fc2`` linear layers.
        n_copies: Width multiplier. The default of 2 matches ``LargeMLP``.

    Returns:
        Dict with keys ``'fc1.weight'`` and ``'fc2.weight'``, holding the wide weights.
    """
    params = dict(small_model.named_parameters())
    fc1_wide, fc2_wide = _replicate_hidden_units(
        params['fc1.weight'], params['fc2.weight'], n_copies)
    return {'fc1.weight': fc1_wide, 'fc2.weight': fc2_wide}


def transform_cnn_mnist_params(small_model, n_copies=2):
    """Map ``SmallMNISTCNN``-style weights to an equivalent CNN with ``n_copies`` x channels.

    Conv output channels are tiled along dim 0. The read-out weights are
    tiled along dim 1. This is correct because the forward pass flattens the
    feature map channel-major, so each copy's channels occupy one contiguous
    slice of the read-out input.

    Args:
        small_model: Module with bias-free ``conv1`` and ``fc1`` layers.
        n_copies: Channel multiplier. The default of 2 matches ``LargeMNISTCNN``.

    Returns:
        Dict with keys ``'conv1.weight'`` and ``'fc1.weight'``, holding the wide weights.
    """
    params = dict(small_model.named_parameters())
    conv_wide, fc_wide = _replicate_hidden_units(
        params['conv1.weight'], params['fc1.weight'], n_copies)
    return {'conv1.weight': conv_wide, 'fc1.weight': fc_wide}

def exponential_loss(y_pred, y_true):
    """Mean exponential loss ``exp(-y_true * y_pred)`` over all elements.

    Args:
        y_pred: Raw model scores.
        y_true: Labels, which are -1 / +1 in this file. They must broadcast
            against ``y_pred``.

    Returns:
        A scalar tensor.
    """
    margins = y_true * y_pred
    return margins.neg().exp().mean()

# --- 2.5. New: Accuracy and Loss Calculation ---
def calculate_metrics(model, loader):
    """Return ``(avg_loss, accuracy_percent)`` of ``model`` over every batch of ``loader``.

    Inputs and targets are cast to ``torch.double``. Targets are reshaped to
    ``(N, 1)``, the shape of the model scores. The loss is
    ``exponential_loss`` averaged per sample. A prediction is
    ``torch.sign(score)``, so a score of exactly 0 never matches a -1/+1 label.

    The model is switched to eval mode for the pass and then returned to
    whatever mode it was in before, even if an exception is raised.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    was_training = model.training
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    try:
        with torch.no_grad():
            for data, target in loader:
                data = data.to(torch.double)
                target = target.to(torch.double).view(-1, 1)
                output = model(data)
                batch_size = target.size(0)

                # exponential_loss is a per-batch mean; weight it back to a sum.
                loss_sum += exponential_loss(output, target).item() * batch_size
                n_correct += (torch.sign(output) == target).sum().item()
                n_seen += batch_size
    finally:
        # Restore the caller's mode instead of unconditionally forcing train mode.
        model.train(was_training)

    if n_seen == 0:
        raise ValueError("calculate_metrics: loader yielded no samples")
    return loss_sum / n_seen, 100 * n_correct / n_seen

# --- 2.6. MNIST Training and Evaluation Loop ---
def _sgd_step(net, optimizer, data, target):
    """Take one optimizer step of ``net`` on a single batch under ``exponential_loss``."""
    optimizer.zero_grad()
    loss = exponential_loss(net(data), target)
    loss.backward()
    optimizer.step()


def _trajectory_error(small_net, large_net, transform_fn):
    """L2 distance between ``large_net``'s weights and ``transform_fn(small_net)``.

    All parameters are flattened and concatenated in name-sorted order before
    the norm is taken.
    """
    with torch.no_grad():
        predicted = transform_fn(small_net)
        named = sorted(large_net.named_parameters(), key=lambda item: item[0])
        actual_vec = torch.cat([param.flatten() for _, param in named])
        predicted_vec = torch.cat([predicted[name].flatten() for name, _ in named])
        return torch.norm(actual_vec - predicted_vec).item()


def run_mnist_experiment(model_type='mlp', epochs=None):
    """Train a narrow and a wide homogeneous net side by side and track their divergence.

    The wide net starts at ``transform_fn(small_net)``. Both nets then take
    SGD steps on identical batches of the MNIST 3-vs-5 task. After every step
    the function records the distance between the wide net's weights and the
    transform of the narrow net's weights.

    Args:
        model_type: ``'mlp'`` or ``'cnn'``.
        epochs: Number of passes over the training data. Defaults to ``MNIST_EPOCHS``.

    Returns:
        The maximum trajectory error over the initial state and every SGD step.

    Raises:
        ValueError: If ``model_type`` is not ``'mlp'`` or ``'cnn'``.
    """
    architectures = {
        'mlp': (SmallMLP, LargeMLP, transform_mlp_params),
        'cnn': (SmallMNISTCNN, LargeMNISTCNN, transform_cnn_mnist_params),
    }
    # Previously any unrecognised value (e.g. 'MLP') silently ran the CNN.
    if model_type not in architectures:
        raise ValueError(f"model_type must be 'mlp' or 'cnn', got {model_type!r}")
    if epochs is None:
        epochs = MNIST_EPOCHS

    print(f"\n--- Running MNIST Experiment: Homogeneous {model_type.upper()} ---")

    train_loader, eval_loader = get_mnist_binary_loaders(MNIST_BATCH_SIZE)

    small_cls, large_cls, transform_fn = architectures[model_type]
    small_net = small_cls().to(torch.double)
    large_net = large_cls().to(torch.double)

    with torch.no_grad():
        # Re-initialise the narrow net: Kaiming-normal weight matrices, zero 1-D params.
        for param in small_net.parameters():
            if param.dim() > 1:
                nn.init.kaiming_normal_(param, a=math.sqrt(5))
            else:
                nn.init.zeros_(param)
        # Start the wide net exactly at the transformed narrow weights.
        initial_large_params = transform_fn(small_net)
        for name, param in large_net.named_parameters():
            param.copy_(initial_large_params[name])

    optimizer_small = torch.optim.SGD(small_net.parameters(), lr=MNIST_LEARNING_RATE)
    optimizer_large = torch.optim.SGD(large_net.parameters(), lr=MNIST_LEARNING_RATE)

    # Seed with the initial error so the history is never empty, even when
    # epochs == 0 or the loader yields no batches.
    error_history = [_trajectory_error(small_net, large_net, transform_fn)]

    for epoch in range(epochs):
        for data, target in train_loader:
            data = data.to(torch.double)
            target = target.to(torch.double).view(-1, 1)
            _sgd_step(small_net, optimizer_small, data, target)
            _sgd_step(large_net, optimizer_large, data, target)
            error_history.append(_trajectory_error(small_net, large_net, transform_fn))

        avg_loss, accuracy = calculate_metrics(small_net, eval_loader)
        print(f"Epoch {epoch+1}/{epochs} | Avg Loss: {avg_loss:.4f} | Accuracy: {accuracy:.2f}% | Final Batch Error: {error_history[-1]:.4e}")

    max_error = max(error_history)
    print(f"Result: Max Trajectory Error = {max_error:.4e}")
    print("--------------------------------------------------")
    return max_error


if __name__ == '__main__':
    # Run the bias-free (homogeneous) MLP first, then the CNN; keep insertion order for the summary.
    results = {}
    for result_key, arch in (('mlp_no_bias', 'mlp'), ('cnn_no_bias', 'cnn')):
        results[result_key] = run_mnist_experiment(model_type=arch)

    print("\n========= MNIST Experiment Summary =========")
    for result_key, max_error in results.items():
        print(f"{result_key:<15}: Max Trajectory Error = {max_error:.4e}")
    print("==========================================")


--- Running MNIST Experiment: Homogeneous MLP ---
Epoch 1/1000 | Avg Loss: 0.7281 | Accuracy: 87.79% | Final Batch Error: 5.9833e-15
Epoch 2/1000 | Avg Loss: 0.5808 | Accuracy: 90.87% | Final Batch Error: 8.1784e-15
Epoch 3/1000 | Avg Loss: 0.4834 | Accuracy: 92.11% | Final Batch Error: 1.0178e-14
Epoch 4/1000 | Avg Loss: 0.4197 | Accuracy: 92.43% | Final Batch Error: 1.1848e-14
Epoch 5/1000 | Avg Loss: 0.3774 | Accuracy: 92.97% | Final Batch Error: 1.3313e-14
Epoch 6/1000 | Avg Loss: 0.3471 | Accuracy: 93.68% | Final Batch Error: 1.4564e-14
Epoch 7/1000 | Avg Loss: 0.3279 | Accuracy: 93.91% | Final Batch Error: 1.5561e-14
Epoch 8/1000 | Avg Loss: 0.3103 | Accuracy: 94.01% | Final Batch Error: 1.6832e-14
Epoch 9/1000 | Avg Loss: 0.2985 | Accuracy: 93.98% | Final Batch Error: 1.7777e-14
Epoch 10/1000 | Avg Loss: 0.2890 | Accuracy: 94.21% | Final Batch Error: 1.8848e-14
Epoch 11/1000 | Avg Loss: 0.2813 | Accuracy: 94.40% | Final Batch Error: 1.9703e-14
Epoch 12/1000 | Avg Loss: 0.2749 |