In [1]:
import numpy as np
import torch
import torchvision
import sys
sys.path.append("netrep")
from netrep.metrics import LinearMetric
sys.path.append("..")
from Procrustes import ProcrustesDistance

In [2]:
random_state = torch.manual_seed(17)
angles = torch.linspace(0, 2*torch.pi, steps=181)

def givens_rotation(p, i, j, theta):
    """Return a p×p rotation in the (i,j) plane by angle theta."""
    G = torch.eye(p)
    c, s = torch.cos(theta), torch.sin(theta)
    G[i, i] = c; G[j, j] = c
    G[i, j] = -s; G[j, i] = s
    return G

def compare_metrics(dimensions=256, neurons=32):
    """Compare LinearMetric and ProcrustesDistance on rotated data."""
    X = torch.randn(dimensions, neurons, requires_grad=True)
    lookup = {}
    for theta in angles:
        Q = givens_rotation(neurons, 7, 15, theta)   # rotate in plane (0,1)
        Y = X @ Q + torch.randn(dimensions, neurons) * 0.1
        proc_metric = LinearMetric(alpha=1.0, center_columns=True, score_method='euclidean')
        proc_metric.fit(X.detach().numpy(), Y.detach().numpy())
        dist = proc_metric.score(X.detach().numpy(), Y.detach().numpy())
        diff_metric = ProcrustesDistance()
        loss = diff_metric(X, Y)
        lookup[theta.item()] = (dist, loss.item())
    return lookup

In [3]:
random_state = torch.manual_seed(17)
for dimensions in [64, 128, 256]:
    for neurons in [16, 32, 64]:
        print(f"Dimensions: {dimensions}, Neurons: {neurons}")
        lookup = compare_metrics(dimensions=dimensions, neurons=neurons)
        for theta in angles:
            print(f"DxN: {dimensions}x{neurons}, Angle: {theta.item():.2f}, LinearMetric: {lookup[theta.item()][0]:.4f}, ProcrustesDistance: {lookup[theta.item()][1]:.4f}, Diff: {abs(lookup[theta.item()][0]-lookup[theta.item()][1]):.4f}") if abs(lookup[theta.item()][0]-lookup[theta.item()][1]) > 1e-4 else None
        print("\n")

Dimensions: 64, Neurons: 16


Dimensions: 64, Neurons: 32


Dimensions: 64, Neurons: 64


Dimensions: 128, Neurons: 16


Dimensions: 128, Neurons: 32


Dimensions: 128, Neurons: 64


Dimensions: 256, Neurons: 16


Dimensions: 256, Neurons: 32


Dimensions: 256, Neurons: 64




In [4]:
for theta in angles:
    print(f"Angle: {theta.item():.2f}, LinearMetric: {lookup[theta.item()][0]:.4f}, ProcrustesDistance: {lookup[theta.item()][1]:.4f}, Diff: {abs(lookup[theta.item()][0]-lookup[theta.item()][1]):.4f}")

Angle: 0.00, LinearMetric: 11.9621, ProcrustesDistance: 11.9621, Diff: 0.0000
Angle: 0.03, LinearMetric: 11.9358, ProcrustesDistance: 11.9358, Diff: 0.0000
Angle: 0.07, LinearMetric: 11.8997, ProcrustesDistance: 11.8997, Diff: 0.0000
Angle: 0.10, LinearMetric: 11.9011, ProcrustesDistance: 11.9011, Diff: 0.0000
Angle: 0.14, LinearMetric: 11.9615, ProcrustesDistance: 11.9615, Diff: 0.0000
Angle: 0.17, LinearMetric: 11.7043, ProcrustesDistance: 11.7043, Diff: 0.0000
Angle: 0.21, LinearMetric: 12.1138, ProcrustesDistance: 12.1138, Diff: 0.0000
Angle: 0.24, LinearMetric: 11.9534, ProcrustesDistance: 11.9534, Diff: 0.0000
Angle: 0.28, LinearMetric: 12.0672, ProcrustesDistance: 12.0672, Diff: 0.0000
Angle: 0.31, LinearMetric: 11.9173, ProcrustesDistance: 11.9173, Diff: 0.0000
Angle: 0.35, LinearMetric: 11.9462, ProcrustesDistance: 11.9462, Diff: 0.0000
Angle: 0.38, LinearMetric: 12.0599, ProcrustesDistance: 12.0599, Diff: 0.0000
Angle: 0.42, LinearMetric: 12.0153, ProcrustesDistance: 12.0153,

In [7]:
X = torch.randn(32, 32, requires_grad=True)
Y = X @ givens_rotation(32, 7, 15, angles[4]) + torch.randn(32, 32) * 0.1
metric = ProcrustesDistance()
loss = metric(X, Y)
print(f"Procrustes Distance: {loss.item():.4f}")

Procrustes Distance: 2.3323


In [8]:
loss.backward(retain_graph=True)
print(X.grad, Y.grad)

tensor([[ 0.0011, -0.0026,  0.0010,  ..., -0.0024, -0.0049,  0.0045],
        [-0.0021,  0.0022, -0.0023,  ...,  0.0032,  0.0024, -0.0078],
        [-0.0021, -0.0003, -0.0006,  ...,  0.0025, -0.0052, -0.0006],
        ...,
        [ 0.0026,  0.0009,  0.0040,  ..., -0.0090, -0.0015,  0.0026],
        [-0.0005, -0.0033, -0.0005,  ...,  0.0029,  0.0022,  0.0003],
        [ 0.0007,  0.0021,  0.0004,  ...,  0.0041,  0.0054, -0.0033]]) None


  print(X.grad, Y.grad)


In [32]:
import torchvision.datasets as datasets
import torchvision.transforms as transforms

datasets.MNIST.mirrors = ["https://ossci-datasets.s3.amazonaws.com/mnist"]
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())
print(f"MNIST Trainset Size: {len(mnist_trainset)}, Testset Size: {len(mnist_testset)}")

MNIST Trainset Size: 60000, Testset Size: 10000


In [56]:
# Basic 2 layer MLP
import torch.nn as nn
import torch.nn.functional as F
class SimpleMLP(nn.Module):
    def __init__(self, input_size=28*28, output_size=10):
        super(SimpleMLP, self).__init__()
        self.fc1 = nn.Linear(input_size, output_size)

    def forward(self, x):
        x = x.view(-1, 28*28)  # Flatten the input
        x = self.fc1(x)
        return x

In [57]:
# Train the MLP on MNIST
from torch.utils.data import DataLoader, random_split
import torch.nn.functional as F

train_loader = DataLoader(mnist_trainset, batch_size=64, shuffle=True)

# Split test set into validation and test sets
val_size = 2000
test_size = len(mnist_testset) - val_size
val_set, test_set = random_split(mnist_testset, [val_size, test_size])

val_loader = DataLoader(val_set, batch_size=1000, shuffle=False)
test_loader = DataLoader(test_set, batch_size=1000, shuffle=False)

model = SimpleMLP()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

def calculate_accuracy(loader, model):
    """Calculate accuracy on a dataset."""
    model.eval()
    correct = 0
    total = 0
    total_loss = 0
    with torch.no_grad():
        for data, target in loader:
            output = model(data)
            loss = criterion(output, target)
            total_loss += loss.item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            total += target.size(0)
    return total_loss / len(loader), 100. * correct / total

# Training loop
for epoch in range(1, 6):  # 5 epochs
    model.train()
    train_loss = 0
    train_correct = 0
    train_total = 0
    
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
        pred = output.argmax(dim=1, keepdim=True)
        train_correct += pred.eq(target.view_as(pred)).sum().item()
        train_total += target.size(0)
    
    # Calculate average training metrics
    avg_train_loss = train_loss / len(train_loader)
    train_accuracy = 100. * train_correct / train_total
    
    # Calculate validation metrics
    val_loss, val_accuracy = calculate_accuracy(val_loader, model)
    
    print(f"Epoch {epoch}, Train Loss: {avg_train_loss:.4f}, Train Acc: {train_accuracy:.2f}%, Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.2f}%")

# Testing loop
test_loss, test_accuracy = calculate_accuracy(test_loader, model)
print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%")

Epoch 1, Train Loss: 0.5423, Train Acc: 86.98%, Val Loss: 0.3317, Val Acc: 91.10%
Epoch 2, Train Loss: 0.3234, Train Acc: 91.03%, Val Loss: 0.2960, Val Acc: 91.50%
Epoch 3, Train Loss: 0.2953, Train Acc: 91.84%, Val Loss: 0.2805, Val Acc: 92.35%
Epoch 4, Train Loss: 0.2819, Train Acc: 92.13%, Val Loss: 0.2795, Val Acc: 92.25%
Epoch 5, Train Loss: 0.2733, Train Acc: 92.37%, Val Loss: 0.2716, Val Acc: 92.15%
Test Loss: 0.2694, Test Accuracy: 92.49%


In [58]:
print(model)
W, b = model.fc1.weight, model.fc1.bias
print(W.shape, b.shape)
B = torch.cat([W, b.unsqueeze(1)], dim=1)  # Combine weights and bias
print(B.shape)  # Should be [10, 785]

# Recover
W_recovered = B[:, :-1]  # All columns except the last one
b_recovered = B[:, -1]   # Only the last column
assert torch.allclose(W, W_recovered)
assert torch.allclose(b, b_recovered)

SimpleMLP(
  (fc1): Linear(in_features=784, out_features=10, bias=True)
)
torch.Size([10, 784]) torch.Size([10])
torch.Size([10, 785])


In [60]:
A = torch.randn(10, 785, requires_grad=True)
gradient = torch.optim.Adam([A], lr=0.01)
loss_fn = ProcrustesDistance()

# Make A look like B
for _ in range(700):
    gradient.zero_grad()
    loss = loss_fn(A, B)
    if _ % 100 == 0:
        proc_metric = LinearMetric(alpha=1.0, center_columns=True, score_method='euclidean')
        proc_metric.fit(A.detach().numpy(), B.detach().numpy())
        dist = proc_metric.score(A.detach().numpy(), B.detach().numpy())
    loss.backward()
    gradient.step()
    if _ % 100 == 0:
        print(f"Procrustes Distance: {loss.item():.4f}, LinearMetric Distance: {dist:.4f}")
print(f"Final Procrustes Distance between A and B: {loss_fn(A, B).item():.4f}")

Procrustes Distance: 68.2172, LinearMetric Distance: 68.2172
Procrustes Distance: 12.8499, LinearMetric Distance: 12.8499
Procrustes Distance: 0.0448, LinearMetric Distance: 0.0448
Procrustes Distance: 0.0302, LinearMetric Distance: 0.0302
Procrustes Distance: 0.0306, LinearMetric Distance: 0.0306
Procrustes Distance: 0.0391, LinearMetric Distance: 0.0391
Procrustes Distance: 0.0239, LinearMetric Distance: 0.0239
Final Procrustes Distance between A and B: 0.0258


In [62]:
print(A.shape)
W_prime, b_prime = A[:, :-1], A[:, -1]
new_model = SimpleMLP()
new_model.fc1.weight.data = W_prime
new_model.fc1.bias.data = b_prime
new_test_loss, new_test_accuracy = calculate_accuracy(test_loader, new_model)
print(f"New Model Test Loss: {new_test_loss:.4f}, New Model Test Accuracy: {new_test_accuracy:.2f}%")

torch.Size([10, 785])
New Model Test Loss: 3.3069, New Model Test Accuracy: 9.69%
