In [1]:
import os
import torch
from torch import nn
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from torch.optim import Adam
from collections import defaultdict
import numpy.ma as ma
from scipy.optimize import linear_sum_assignment
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

from nn.models import DWSModel
from typing import Tuple, NamedTuple
from nn.layers import BN, DWSLayer,InvariantLayer, Dropout, ReLU
from nn.layers.base import BaseLayer,GeneralSetLayer

# Pick the compute device: CUDA when PyTorch can see a GPU, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)
# Switch cuDNN off for the whole session.
torch.backends.cudnn.enabled = False

cpu


In [18]:
class SimpleMLP(nn.Module):
    """Small fully connected network: two ReLU hidden layers and a linear head.

    Args:
        input_dim: Number of input features.
        hidden_dim: Width of both hidden layers.
        output_dim: Number of output features.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # The attribute names double as the state_dict key prefixes.
        self.hidden1 = nn.Linear(input_dim, hidden_dim)
        self.hidden2 = nn.Linear(hidden_dim, hidden_dim)
        self.output = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Run ``x`` through both hidden layers (ReLU after each), then the head."""
        activations = x
        for hidden in (self.hidden1, self.hidden2):
            activations = hidden(activations).relu()
        return self.output(activations)

def permute_layer_weights(layer, perm):
    """Reorder the output units of ``layer`` in place.

    Row ``i`` of the new weight (and entry ``i`` of the new bias) is row/entry
    ``perm[i]`` of the old one.

    Args:
        layer: Module with a ``weight`` whose first dimension is the output
            dimension (e.g. ``nn.Linear``). Its ``bias`` may be ``None``, as
            for ``nn.Linear(..., bias=False)``.
        perm: Indices used to index dimension 0 of the weight and bias.

    Returns:
        The same ``layer`` object, modified in place.
    """
    # Indexing with an index tensor already yields a fresh tensor, so no
    # explicit clone is needed before reassigning ``.data``.
    layer.weight.data = layer.weight.data[perm, :]

    # Bias-free layers expose ``bias`` as None; there is nothing to permute.
    if getattr(layer, 'bias', None) is not None:
        layer.bias.data = layer.bias.data[perm]
    return layer

def permute_model(model, permutations, inplace=True):
    """Permute the hidden units of a chain of ``nn.Linear`` layers.

    For the k-th linear layer, ``permutations[k]`` reorders its output units
    (weight rows and bias) and the matching input columns of the next linear
    layer, so the network computes the same function as before.

    Assumes the model's ``nn.Linear`` layers are applied in the order returned
    by ``model.modules()`` and that only element-wise operations (e.g. ReLU)
    sit between them.

    Args:
        model: ``nn.Module`` containing the linear layers.
        permutations: One index permutation per hidden layer, i.e. one fewer
            than the number of ``nn.Linear`` layers.
        inplace: When True (default) ``model`` itself is modified. When False
            a deep copy is permuted and ``model`` is left untouched.

    Returns:
        The permuted model: ``model`` itself if ``inplace`` is True, otherwise
        the permuted copy.

    Raises:
        ValueError: If the number of permutations does not match the number
            of hidden layers.
    """
    if not inplace:
        import copy  # only needed on the copying path
        model = copy.deepcopy(model)

    linear_layers = [m for m in model.modules() if isinstance(m, nn.Linear)]
    expected = max(len(linear_layers) - 1, 0)
    # Validate before touching any weights so a bad call leaves the model intact.
    if len(permutations) != expected:
        raise ValueError(
            f"expected {expected} permutation(s) for {len(linear_layers)} "
            f"linear layer(s), got {len(permutations)}"
        )

    for layer, next_layer, perm in zip(linear_layers, linear_layers[1:], permutations):
        # Reorder this layer's output units ...
        layer.weight.data = layer.weight.data[perm, :]
        if layer.bias is not None:
            layer.bias.data = layer.bias.data[perm]
        # ... and make the next layer read its inputs in the new order.
        next_layer.weight.data = next_layer.weight.data[:, perm]
    return model

def generate_permutation_matrix(hidden_dim):
    """Draw a random permutation of the indices ``0 .. hidden_dim - 1``.

    Despite the name, the result is a 1-D ``torch.long`` index tensor, not a
    2-D permutation matrix. Randomness comes from NumPy's global RNG.
    """
    shuffled_indices = np.random.permutation(hidden_dim)
    return torch.as_tensor(shuffled_indices, dtype=torch.long)

# Example Usage
input_dim = 10
hidden_dim = 16
output_dim = 5

# Initialize two identical models
model_a = SimpleMLP(input_dim, hidden_dim, output_dim)
model_b = SimpleMLP(input_dim, hidden_dim, output_dim)
model_b.load_state_dict(model_a.state_dict())  # Ensure identical initial weights

# Draw an independent random permutation for each hidden layer
perm1 = generate_permutation_matrix(hidden_dim)
perm2 = generate_permutation_matrix(hidden_dim)

# Permute model B. permute_model edits its argument in place and returns it, so
# model_c is the same object as model_b; model_a is the untouched reference.
# Using two different permutations also checks that each one is applied to the
# right pair of layers.
model_c = permute_model(model_b, [perm1, perm2])

# Verify functionality is preserved
x = torch.rand(1, input_dim)
output_a = model_a(x)
output_b = model_b(x)  # same (permuted) model as model_c
output_c = model_c(x)

# A small absolute tolerance absorbs the change in floating-point summation
# order caused by reordering the hidden units.
print("Outputs are identical:", torch.allclose(output_a, output_c, atol=1e-6))

Outputs are identical: True


In [19]:
# Side-by-side view of the three forward passes computed above.
(output_a, output_b, output_c)

(tensor([[-0.2754,  0.1488,  0.1450,  0.0545,  0.0156]],
        grad_fn=<AddmmBackward0>),
 tensor([[-0.2754,  0.1488,  0.1450,  0.0545,  0.0156]],
        grad_fn=<AddmmBackward0>),
 tensor([[-0.2754,  0.1488,  0.1450,  0.0545,  0.0156]],
        grad_fn=<AddmmBackward0>))

In [20]:
# First-layer weights of the reference model.
print(model_a.hidden1.weight)

Parameter containing:
tensor([[ 0.1146,  0.1517,  0.0849, -0.1838, -0.2705,  0.3162,  0.2603,  0.0251,
          0.3071, -0.2898],
        [-0.1530,  0.2477,  0.1027,  0.2589, -0.2825,  0.2156, -0.0958,  0.2503,
          0.1518, -0.3014],
        [-0.0836, -0.1677,  0.2058, -0.2556, -0.1140,  0.1546,  0.2420,  0.0414,
         -0.0385, -0.0745],
        [ 0.2134,  0.2503, -0.0340,  0.2803, -0.2951, -0.2324, -0.2870, -0.0131,
          0.0861, -0.2488],
        [-0.0777, -0.1802,  0.0702, -0.0501,  0.2609,  0.0517,  0.0555, -0.1036,
         -0.2032, -0.2833],
        [-0.0096,  0.1308,  0.3123,  0.1388,  0.0583, -0.0722,  0.1712, -0.0098,
         -0.2590, -0.1907],
        [-0.0919, -0.1051, -0.2803, -0.0928, -0.0947,  0.1869, -0.1440, -0.2760,
         -0.2720,  0.2293],
        [ 0.0651, -0.0656, -0.1989,  0.0404,  0.2409, -0.1033,  0.0823,  0.1227,
          0.2240,  0.0412],
        [-0.0895, -0.0282,  0.2695, -0.2857,  0.1852, -0.0996, -0.2727, -0.1151,
          0.1198, -0.1558

In [21]:
# First-layer weights of the permuted model, for comparison with model_a's.
print(model_b.hidden1.weight)

Parameter containing:
tensor([[-0.0836, -0.1677,  0.2058, -0.2556, -0.1140,  0.1546,  0.2420,  0.0414,
         -0.0385, -0.0745],
        [-0.0777, -0.1802,  0.0702, -0.0501,  0.2609,  0.0517,  0.0555, -0.1036,
         -0.2032, -0.2833],
        [ 0.2267, -0.1168, -0.2147, -0.2599,  0.0416,  0.0828, -0.2591, -0.0569,
         -0.1012, -0.2804],
        [-0.0096,  0.1308,  0.3123,  0.1388,  0.0583, -0.0722,  0.1712, -0.0098,
         -0.2590, -0.1907],
        [ 0.1146,  0.1517,  0.0849, -0.1838, -0.2705,  0.3162,  0.2603,  0.0251,
          0.3071, -0.2898],
        [-0.0919, -0.1051, -0.2803, -0.0928, -0.0947,  0.1869, -0.1440, -0.2760,
         -0.2720,  0.2293],
        [ 0.2134,  0.2503, -0.0340,  0.2803, -0.2951, -0.2324, -0.2870, -0.0131,
          0.0861, -0.2488],
        [-0.1530,  0.2477,  0.1027,  0.2589, -0.2825,  0.2156, -0.0958,  0.2503,
          0.1518, -0.3014],
        [ 0.1899, -0.1269,  0.1942, -0.0556, -0.0682, -0.0782, -0.1732, -0.0059,
          0.2157, -0.0096

# Apply the permutation check to an MLP trained on MNIST

In [28]:
# Simple MLP class for MNIST classification

class MLP(nn.Module):
    """Three-layer fully connected ReLU classifier; defaults are sized for MNIST.

    Args:
        init_type: ``'xavier'`` (Xavier uniform), ``'he'`` (Kaiming uniform for
            ReLU); any other value selects a standard normal initialization.
        seed: If given, passed to ``torch.manual_seed`` before the weights are
            initialized. This seeds torch's global RNG.
        input_dim: Number of input features per sample (default 784).
        hidden_dim: Width of both hidden layers (default 32).
        output_dim: Number of output logits (default 10).
    """

    def __init__(self, init_type='xavier', seed=None,
                 input_dim=784, hidden_dim=32, output_dim=10):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)

        if seed is not None:
            # Seeding happens after the layers are built but before
            # init_weights overwrites every weight, so the final parameters
            # depend only on the seed.
            torch.manual_seed(seed)

        self.init_weights(init_type)

    def init_weights(self, init_type):
        """Re-initialize all weights according to ``init_type`` and zero all biases."""
        layers = (self.fc1, self.fc2, self.fc3)

        # Weights are drawn in fc1, fc2, fc3 order so seeded runs are reproducible.
        for layer in layers:
            if init_type == 'xavier':
                nn.init.xavier_uniform_(layer.weight)
            elif init_type == 'he':
                nn.init.kaiming_uniform_(layer.weight, nonlinearity='relu')
            else:
                nn.init.normal_(layer.weight)

        for layer in layers:
            nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Flatten ``x`` to ``(-1, input_dim)`` and return the output logits."""
        # reshape (unlike view) also accepts non-contiguous inputs.
        x = x.reshape(-1, self.fc1.in_features)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x


In [38]:
# MNIST input pipeline: convert images to tensors, then normalize with mean 0.5 / std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Both splits are downloaded into the current directory if not already present.
train_data = datasets.MNIST('.', train=True, download=True, transform=transform)
test_data = datasets.MNIST('.', train=False, download=True, transform=transform)

# Training batches are shuffled; test batches keep the dataset order.
train_loader = torch.utils.data.DataLoader(train_data, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=64, shuffle=False)

# MLP management functions: 

def train_mlp(model, epochs=3, loader=None, lr=0.001):
    """Train ``model`` with Adam and cross-entropy loss.

    Args:
        model: Classifier returning logits for a batch of inputs.
        epochs: Number of full passes over ``loader``.
        loader: Iterable of ``(data, target)`` batches. Defaults to the
            notebook-level ``train_loader`` when omitted.
        lr: Adam learning rate.

    Returns:
        The same ``model`` object, left in training mode.
    """
    if loader is None:
        # Backward-compatible default: the notebook's global training loader.
        loader = train_loader

    criterion = nn.CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=lr)
    # Batches are moved to wherever the model's parameters live.
    model_device = next(model.parameters()).device

    model.train()
    for _ in range(epochs):
        for data, target in loader:
            data, target = data.to(model_device), target.to(model_device)
            optimizer.zero_grad()
            loss = criterion(model(data), target)
            loss.backward()
            optimizer.step()
    return model
    
    
def test_mlp(model, test_loader):
    model.eval()  # Set the model to evaluation mode
    correct = 0
    total = 0
    with torch.no_grad():  # No need to compute gradients for evaluation
        for data, target in test_loader:
            output = model(data)
            _, predicted = torch.max(output, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    
    accuracy = 100 * correct / total
    return accuracy

# Build a default MLP and train it with train_mlp's default settings.
model = train_mlp(MLP())

In [39]:
def permute_layer_weights(layer, perm):
    """Reorder the output units of ``layer`` in place.

    Row ``i`` of the new weight (and entry ``i`` of the new bias) is row/entry
    ``perm[i]`` of the old one.

    Args:
        layer: Module with a ``weight`` whose first dimension is the output
            dimension (e.g. ``nn.Linear``). Its ``bias`` may be ``None``, as
            for ``nn.Linear(..., bias=False)``.
        perm: Indices used to index dimension 0 of the weight and bias.

    Returns:
        The same ``layer`` object, modified in place.
    """
    # Indexing with an index tensor already yields a fresh tensor, so no
    # explicit clone is needed before reassigning ``.data``.
    layer.weight.data = layer.weight.data[perm, :]

    # Bias-free layers expose ``bias`` as None; there is nothing to permute.
    if getattr(layer, 'bias', None) is not None:
        layer.bias.data = layer.bias.data[perm]
    return layer

def permute_model(model, permutations, inplace=True):
    """Permute the hidden units of a chain of ``nn.Linear`` layers.

    For the k-th linear layer, ``permutations[k]`` reorders its output units
    (weight rows and bias) and the matching input columns of the next linear
    layer, so the network computes the same function as before.

    Assumes the model's ``nn.Linear`` layers are applied in the order returned
    by ``model.modules()`` and that only element-wise operations (e.g. ReLU)
    sit between them.

    Args:
        model: ``nn.Module`` containing the linear layers.
        permutations: One index permutation per hidden layer, i.e. one fewer
            than the number of ``nn.Linear`` layers.
        inplace: When True (default) ``model`` itself is modified. When False
            a deep copy is permuted and ``model`` is left untouched.

    Returns:
        The permuted model: ``model`` itself if ``inplace`` is True, otherwise
        the permuted copy.

    Raises:
        ValueError: If the number of permutations does not match the number
            of hidden layers.
    """
    if not inplace:
        import copy  # only needed on the copying path
        model = copy.deepcopy(model)

    linear_layers = [m for m in model.modules() if isinstance(m, nn.Linear)]
    expected = max(len(linear_layers) - 1, 0)
    # Validate before touching any weights so a bad call leaves the model intact.
    if len(permutations) != expected:
        raise ValueError(
            f"expected {expected} permutation(s) for {len(linear_layers)} "
            f"linear layer(s), got {len(permutations)}"
        )

    for layer, next_layer, perm in zip(linear_layers, linear_layers[1:], permutations):
        # Reorder this layer's output units ...
        layer.weight.data = layer.weight.data[perm, :]
        if layer.bias is not None:
            layer.bias.data = layer.bias.data[perm]
        # ... and make the next layer read its inputs in the new order.
        next_layer.weight.data = next_layer.weight.data[:, perm]
    return model

def generate_permutation_matrix(hidden_dim):
    """Draw a random permutation of the indices ``0 .. hidden_dim - 1``.

    Despite the name, the result is a 1-D ``torch.long`` index tensor, not a
    2-D permutation matrix. Randomness comes from NumPy's global RNG.
    """
    shuffled_indices = np.random.permutation(hidden_dim)
    return torch.as_tensor(shuffled_indices, dtype=torch.long)

In [42]:
# Example usage on the trained MNIST MLP
input_dim = 784
hidden_dim = 32
output_dim = 10

# Draw an independent random permutation for each hidden layer
perm1 = generate_permutation_matrix(hidden_dim)
perm2 = generate_permutation_matrix(hidden_dim)

# permute_model rewrites the model it is given and returns that same object.
# Permute a separate copy of the trained weights; otherwise `model` and
# `model_c` would be one object and the check below would be trivially True.
model_c = MLP()
model_c.load_state_dict(model.state_dict())
model_c = permute_model(model_c, [perm1, perm2])

# Verify functionality is preserved
x = torch.ones(1, input_dim)
with torch.no_grad():  # inference only
    output_a = model(x)
    output_c = model_c(x)

# A small absolute tolerance absorbs the change in floating-point summation
# order caused by reordering the hidden units.
print("Outputs are identical:", torch.allclose(output_a, output_c, atol=1e-6))

Outputs are identical: True


In [41]:
# Test-set accuracy (in percent) of `model` and of `model_c`.
accuracy_model = test_mlp(model, test_loader)
accuracy_model_c = test_mlp(model_c, test_loader)
(accuracy_model, accuracy_model_c)

(94.64, 94.64)