In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import random
import os
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Configurazione Device e Riproducibilità
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device in uso: {device}")

torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

Device in uso: cpu


In [3]:
# Caricamento Dati MNIST
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST('./', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

test_dataset = datasets.MNIST('./', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False)

# Definizione Modello (SimpleMLP)
class SimpleMLP(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=512, output_dim=10):
        super().__init__()
        self.layer1 = nn.Linear(input_dim, hidden_dim)
        self.layer2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = x.view(x.size(0), -1) # Flatten
        x = F.relu(self.layer1(x))
        x = self.layer2(x)
        return x

In [4]:
def train_and_save_model(model_id, data_loader, num_epochs=3):
    #Addestra un modello con un seed specifico e lo salva
    # Seed diverso per ogni modello = inizializzazione diversa
    torch.manual_seed(42 + model_id)

    model = SimpleMLP().to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Addestramento
    model.train()
    for epoch in range(num_epochs):
        for data, target in data_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = F.cross_entropy(output, target)
            loss.backward()
            optimizer.step()

    path = f'model_checkpoint_{model_id}.pth'
    torch.save(model.state_dict(), path)
    return path

In [5]:
def align_models(reference_state_dict, target_state_dict):

  # Allinea i neuroni del target al reference risolvendo la Permutation Invariance.

    new_target_dict = target_state_dict.copy()

    # Lavoriamo sullo stesso device
    ref_w = reference_state_dict['layer1.weight'].to(device)
    tgt_w = target_state_dict['layer1.weight'].to(device)

    # Calcolo Similarità (Prodotto Scalare)
    similarity = torch.matmul(ref_w, tgt_w.T)

    # Trova la permutazione ottimale (Argmax Greedy)
    # Per ogni neurone del ref, troviamo il neurone del target più simile
    permutation = torch.argmax(similarity, dim=1)

    # Applica la permutazione ai pesi
    # Layer 1 (Output): Riordiniamo le righe
    new_target_dict['layer1.weight'] = target_state_dict['layer1.weight'][permutation]
    new_target_dict['layer1.bias'] = target_state_dict['layer1.bias'][permutation]

    # Layer 2 (Input): Riordiniamo le colonne
    new_target_dict['layer2.weight'] = target_state_dict['layer2.weight'][:, permutation]

    return new_target_dict

In [6]:
def load_align_and_vectorize(N, device):

    # Carica -> Allinea al Modello 1 -> Vettorizza -> Normalizza.

    aligned_vectors = []

    # Carichiamo il Modello 1 come Riferimento
    path_ref = 'model_checkpoint_1.pth'
    if not os.path.exists(path_ref): return None
    ref_state_dict = torch.load(path_ref, map_location=device)

    # Chiavi per appiattire i pesi in ordine
    keys = list(ref_state_dict.keys())

    for i in range(1, N + 1):
        path = f'model_checkpoint_{i}.pth'
        curr_state_dict = torch.load(path, map_location=device)

        # FASE DI ALLINEAMENTO
        if i > 1:

            final_dict = align_models(ref_state_dict, curr_state_dict)
        else:
            final_dict = curr_state_dict # Il modello 1 è già "allineato a se stesso"

        # Vettorizzazione
        tensors = [final_dict[k].view(-1) for k in keys]
        v_flat = torch.cat(tensors)

        # Normalizzazione Sferica (L2)
        v_norm = F.normalize(v_flat, p=2, dim=0)
        aligned_vectors.append(v_norm)

    return torch.stack(aligned_vectors)

In [7]:
def spherical_loss(merged_vector, original_vectors, weights):
    # Prodotto scalare e clamp per stabilità numerica
    dot = torch.einsum('d, nd -> n', merged_vector, original_vectors)
    dot = torch.clamp(dot, -1.0 + 1e-7, 1.0 - 1e-7)
    # Distanza geodetica
    dist = torch.acos(dot)
    return 0.5 * torch.sum(weights * (dist ** 2))

def find_karcher_mean(vectors, weights, lr=0.5, max_iter=200):
    # Inizializzazione: Media Euclidea proiettata
    mean_init = torch.mean(vectors, dim=0)
    m = F.normalize(mean_init, p=2, dim=0).clone().detach().requires_grad_(True)

    optimizer = optim.SGD([m], lr=lr)

    for _ in range(max_iter):
        optimizer.zero_grad()
        loss = spherical_loss(m, vectors, weights)
        loss.backward()
        optimizer.step()

        # Proietta di nuovo sulla sfera dopo il passo
        with torch.no_grad():
            m.data = F.normalize(m.data, p=2, dim=0)

    return m.detach()

In [9]:
def inverse_vectorize(vector, ref_model): # Dal vettore piatto torna al state_dict
    state_dict = ref_model.state_dict()
    new_dict = {}
    idx = 0
    for k, v in state_dict.items():
        num = v.numel()
        new_dict[k] = vector[idx : idx+num].view(v.shape)
        idx += num
    return new_dict

def evaluate(weights):
    model = SimpleMLP().to(device)
    model.load_state_dict(weights)
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
            total += len(data)
    return correct / total

In [11]:
# Lista degli esperimenti
N_values = [3, 5, 10]
results_table = []

print("INIZIO ESPERIMENTI")

for N in N_values:
    print(f"Processing N={N}")

    # Training
    for i in range(1, N + 1):
        if not os.path.exists(f'model_checkpoint_{i}.pth'):
            train_and_save_model(i, train_loader)

    # Caricamento, Allineamento e Vettorizzazione
    V = load_align_and_vectorize(N, device)
    weights = torch.ones(N).to(device) / N # Pesi uniformi

    # Calcolo Karcher Mean
    karcher_vec = find_karcher_mean(V, weights)

    # Calcolo Media Euclidea
    euclidean_vec = torch.mean(V, dim=0)

    # Valutazione
    ref_model = SimpleMLP().to('cpu')

    # Eval Karcher
    w_karcher = inverse_vectorize(karcher_vec.cpu(), ref_model)
    acc_karcher = evaluate(w_karcher)

    # Eval Euclidean
    w_euclid = inverse_vectorize(euclidean_vec.cpu(), ref_model)
    acc_euclid = evaluate(w_euclid)

    # Save results
    gap = acc_karcher - acc_euclid
    results_table.append((N, acc_karcher, acc_euclid, gap))
    print(f"Karcher: {acc_karcher:.4f} | Euclidean: {acc_euclid:.4f}")

# --- STAMPA TABELLA FINALE ---
print("FINAL RESULTS TABLE")
print("|  N  | Karcher Acc | Euclidean Acc |   Gap   |")
print("|-----|-------------|---------------|---------|")
for row in results_table:
    print(f"| {row[0]:<3} |   {row[1]*100:.2f}%    |    {row[2]*100:.2f}%     | {row[3]*100:+.2f}%  |")

INIZIO ESPERIMENTI
Processing N=3
Karcher: 0.9672 | Euclidean: 0.9665
Processing N=5
Karcher: 0.9632 | Euclidean: 0.9614
Processing N=10
Karcher: 0.9564 | Euclidean: 0.9543
Processing N=20
Karcher: 0.9603 | Euclidean: 0.9596
FINAL RESULTS TABLE
|  N  | Karcher Acc | Euclidean Acc |   Gap   |
|-----|-------------|---------------|---------|
| 3   |   96.72%    |    96.65%     | +0.07%  |
| 5   |   96.32%    |    96.14%     | +0.18%  |
| 10  |   95.64%    |    95.43%     | +0.21%  |
| 20  |   96.03%    |    95.96%     | +0.07%  |
