In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from matplotlib.patches import Ellipse, Circle
import time
from itertools import islice
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import torch
import torch.nn as nn
import torchvision
from mpl_toolkits.mplot3d import Axes3D
from collections import defaultdict
from itertools import groupby
from operator import itemgetter
import json 

seed = 1
np.random.seed(seed)
torch.manual_seed(seed)


steps = 10000

def cycle(iterable):
    while True:
        for x in iterable:
            yield x

device = "mps" if torch.backends.mps.is_available() else "cpu"
print("device is:", device)

seed = 1
np.random.seed(seed)
torch.manual_seed(seed)

#torch.set_default_tensor_type(torch.DoubleTensor)


class BioLinear2D(nn.Module):

    def __init__(self, in_dim, out_dim, in_fold=1, out_fold=1, out_ring=False):
        super(BioLinear2D, self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.linear = nn.Linear(in_dim, out_dim)
        self.in_fold = in_fold
        self.out_fold = out_fold
        assert in_dim % in_fold == 0
        assert out_dim % out_fold == 0

        #compute in_cor, shape: (in_dim_sqrt, in_dim_sqrt)
        in_dim_fold = int(in_dim/in_fold)
        out_dim_fold = int(out_dim/out_fold)
        in_dim_sqrt = int(np.sqrt(in_dim_fold))
        out_dim_sqrt = int(np.sqrt(out_dim_fold))
        x = np.linspace(1/(2*in_dim_sqrt), 1-1/(2*in_dim_sqrt), num=in_dim_sqrt)
        X, Y = np.meshgrid(x, x)
        self.in_coordinates = torch.tensor(np.transpose(np.array([X.reshape(-1,), Y.reshape(-1,)])), dtype=torch.float)

        # compute out_cor, shape: (out_dim_sqrt, out_dim_sqrt)
        if out_ring:
            thetas = np.linspace(1/(2*out_dim_fold)*2*np.pi, (1-1/(2*out_dim_fold))*2*np.pi, num=out_dim_fold)
            self.out_coordinates = 0.5+torch.tensor(np.transpose(np.array([np.cos(thetas), np.sin(thetas)]))/4, dtype=torch.float)
        else:
            x = np.linspace(1/(2*out_dim_sqrt), 1-1/(2*out_dim_sqrt), num=out_dim_sqrt)
            X, Y = np.meshgrid(x, x)
            self.out_coordinates = torch.tensor(np.transpose(np.array([X.reshape(-1,), Y.reshape(-1,)])), dtype=torch.float)


    def forward(self, x):
        return self.linear(x)


class BioMLP2D(nn.Module):
    def __init__(self, in_dim=2, out_dim=2, w=2, depth=2, shp=None, token_embedding=False, embedding_size=None):
        super(BioMLP2D, self).__init__()
        if shp == None:
            shp = [in_dim] + [w]*(depth-1) + [out_dim]
            self.in_dim = in_dim
            self.out_dim = out_dim
            self.depth = depth

        else:
            self.in_dim = shp[0]
            self.out_dim = shp[-1]
            self.depth = len(shp) - 1
        linear_list = []
        for i in range(self.depth):
            if i == 0:
                # for modular addition
                #linear_list.append(BioLinear(shp[i], shp[i+1], in_fold=2))
                # for regression
                linear_list.append(BioLinear2D(shp[i], shp[i+1], in_fold=1).to(device))
            elif i == self.depth - 1:
                linear_list.append(BioLinear2D(shp[i], shp[i+1], in_fold=1, out_ring=True).to(device))
            else:
                linear_list.append(BioLinear2D(shp[i], shp[i+1]).to(device))
        self.linears = nn.ModuleList(linear_list)


        if token_embedding == True:
            # embedding size: number of tokens * embedding dimension
            self.embedding = torch.nn.Parameter(torch.normal(0,1,size=embedding_size)).to(device)

        self.shp = shp
        # parameters for the bio-inspired trick
        self.l0 = 0.5 # distance between two nearby layers
        self.in_perm = nn.Parameter(torch.tensor(np.arange(int(self.in_dim/self.linears[0].in_fold)), dtype=torch.float))
        self.out_perm = nn.Parameter(torch.tensor(np.arange(int(self.out_dim/self.linears[-1].out_fold)), dtype=torch.float))
        self.top_k = 30
        self.token_embedding = token_embedding

    def forward(self, x):
        shp = x.shape
        x = x.reshape(shp[0],-1)
        shp = x.shape
        in_fold = self.linears[0].in_fold
        x = x.reshape(shp[0], in_fold, int(shp[1]/in_fold))
        x = x[:,:,self.in_perm.long()]
        x = x.reshape(shp[0], shp[1])
        f = torch.nn.SiLU()
        for i in range(self.depth-1):
            x = f(self.linears[i](x))
        x = self.linears[-1](x)

        out_perm_inv = torch.zeros(self.out_dim, dtype=torch.long)
        out_perm_inv[self.out_perm.long()] = torch.arange(self.out_dim)
        x = x[:,out_perm_inv]
        #x = x[:,self.out_perm]

        return x

    def get_linear_layers(self):
        return self.linears

    def get_cc(self, weight_factor=2.0, bias_penalize=True, no_penalize_last=False):
        # compute connection cost
        cc = 0
        num_linear = len(self.linears)
        for i in range(num_linear):
            if i == num_linear - 1 and no_penalize_last:
                weight_factor = 0.
            biolinear = self.linears[i].to(device)
            dist = torch.sum(torch.abs(biolinear.out_coordinates.unsqueeze(dim=1) - biolinear.in_coordinates.unsqueeze(dim=0)),dim=2).to(device)
            # print("yooo")
            # print(biolinear.linear.weight.device)
            # print(dist.device)
            # print(self.l0)
            cc += torch.mean(torch.abs(biolinear.linear.weight)*(weight_factor*dist+self.l0))
            if bias_penalize == True:
                cc += torch.mean(torch.abs(biolinear.linear.bias)*(self.l0))
        if self.token_embedding:
            cc += torch.mean(torch.abs(self.embedding)*(self.l0))
            #pass
        return cc

    def swap_weight(self, weights, j, k, swap_type="out"):
        with torch.no_grad():
            if swap_type == "in":
                temp = weights[:,j].clone()
                weights[:,j] = weights[:,k].clone()
                weights[:,k] = temp
            elif swap_type == "out":
                temp = weights[j].clone()
                weights[j] = weights[k].clone()
                weights[k] = temp
            else:
                raise Exception("Swap type {} is not recognized!".format(swap_type))

    def swap_bias(self, biases, j, k):
        with torch.no_grad():
            temp = biases[j].clone()
            biases[j] = biases[k].clone()
            biases[k] = temp

    def swap(self, i, j, k):
        # in the ith layer (of neurons), swap the jth and the kth neuron.
        # Note: n layers of weights means n+1 layers of neurons.
        # (incoming, outgoing) * weights + biases are swapped.
        linears = self.get_linear_layers()
        num_linear = len(linears)
        if i == 0:
            return
            # for images, do not allow input_perm
            # input layer, only has outgoing weights; update in_perm
            weights = linears[i].linear.weight
            infold = linears[i].in_fold
            fold_dim = int(weights.shape[1]/infold)
            for l in range(infold):
                self.swap_weight(weights, j+fold_dim*l, k+fold_dim*l, swap_type="in")
            # change input_perm. do not allow input_perm for images
            self.swap_bias(self.in_perm, j, k)
        elif i == num_linear:
            # output layer, only has incoming weights and biases; update out_perm
            weights = linears[i-1].linear.weight
            biases = linears[i-1].linear.bias
            self.swap_weight(weights, j, k, swap_type="out")
            self.swap_bias(biases, j, k)
            # change output_perm
            self.swap_bias(self.out_perm, j, k)
        else:
            # middle layer : (incoming, outgoing) * weights, and biases
            weights_in = linears[i-1].linear.weight
            weights_out = linears[i].linear.weight
            biases = linears[i-1].linear.bias
            self.swap_weight(weights_in, j, k, swap_type="out")
            self.swap_weight(weights_out, j, k, swap_type="in")
            self.swap_bias(biases, j, k)

    def get_top_id(self, i, top_k=20):
        linears = self.get_linear_layers()
        num_linear = len(linears)
        if i == 0:
            # input layer
            weights = linears[i].linear.weight
            score = torch.sum(torch.abs(weights), dim=0)
            in_fold = linears[0].in_fold
            #print(score.shape)
            score = torch.sum(score.reshape(in_fold, int(score.shape[0]/in_fold)), dim=0)
        elif i == num_linear:
            # output layer
            weights = linears[i-1].linear.weight
            score = torch.sum(torch.abs(weights), dim=1)
        else:
            weights_in = linears[i-1].linear.weight
            weights_out = linears[i].linear.weight
            score = torch.sum(torch.abs(weights_out), dim=0) + torch.sum(torch.abs(weights_in), dim=1)
        #print(score.shape)
        top_index = torch.flip(torch.argsort(score),[0])[:top_k]
        return top_index, score

    def relocate_ij(self, i, j):
        # In the ith layer (of neurons), relocate the jth neuron
        linears = self.get_linear_layers()
        num_linear = len(linears)
        if i < num_linear:
            num_neuron = int(linears[i].linear.weight.shape[1]/linears[i].in_fold)
        else:
            num_neuron = linears[i-1].linear.weight.shape[0]
        ccs = []
        for k in range(num_neuron):
            self.swap(i,j,k)
            ccs.append(self.get_cc())
            self.swap(i,j,k)
        k = torch.argmin(torch.stack(ccs))
        self.swap(i,j,k)

    def relocate_i(self, i):
        # Relocate neurons in the ith layer
        top_id = self.get_top_id(i, top_k=self.top_k)
        for j in top_id[0]:
            self.relocate_ij(i,j)

    def relocate(self):
        # Relocate neurons in the whole model
        linears = self.get_linear_layers()
        num_linear = len(linears)
        for i in range(num_linear+1):
            self.relocate_i(i)



  from .autonotebook import tqdm as notebook_tqdm


device is: mps


In [2]:
train = torchvision.datasets.MNIST(root="/tmp", train=True, transform=torchvision.transforms.ToTensor(), download=True)
test = torchvision.datasets.MNIST(root="/tmp", train=False, transform=torchvision.transforms.ToTensor(), download=True)
data_size = 60000
train = torch.utils.data.Subset(train, range(data_size))
train_loader = torch.utils.data.DataLoader(train, batch_size=100, shuffle=True)

In [3]:
def extract_activations(layer, input, output):
    return output

activations = []

def hook_fn(module, input, output):
    activations.append(output)

hooks = []

def measure_improvement(patched_logits, original_logits):
    # Compute L2 difference
    difference_L2 = torch.norm(patched_logits - original_logits, p=2).item()
    return difference_L2

def patch_neuron(activation, neuron_idx, new_value):
    patched_activation = activation.clone()
    # print("To be patched: ", patched_activation[0, neuron_idx])
    # print("patched to:", new_value)
    patched_activation[0, neuron_idx] = new_value
    return patched_activation

def sparsify2circuit_left(wt_tensor, indices_tensor):
    zero_tensor = torch.zeros_like(wt_tensor)
    zero_tensor[indices_tensor, :] = wt_tensor[indices_tensor, :]
    return zero_tensor

def sparsify2circuit_right(wt_tensor, indices_tensor):
    zero_tensor = torch.zeros_like(wt_tensor)
    zero_tensor[:, indices_tensor] = wt_tensor[:, indices_tensor]
    return zero_tensor

def get_top_n_from_each_layer(layer_dict, n):
    # top_n = {}
    top_try = {}
    for layer in layer_dict:
        # top_n[layer] = layer_dict[layer][:n]
        top_try[layer] = [z[0][1] for z in layer_dict[layer][:n]]
    return top_try

def circuit_discovery(model, clean_tensor, corr_tensor):

    # Attach hooks to all BioLinear2D layers
    for layer in model.get_linear_layers():
        hooks.append(layer.register_forward_hook(hook_fn))

    model.eval() 
    original_output = model(clean_tensor)
    clean_activations = activations.copy()  # Store activations after passing the clean tensor
    activations.clear()
    # print(clean_activations)
    model(corr_tensor)
    corrupted_activations = activations  # Store activations after passing the corrupted tensor
    # print(corrupted_activations)

    # Clear the hooks
    for hook in hooks:
        hook.remove()

    improvements = []

    # Iterate through layers
    for layer_idx, (clean_act, corrupted_act) in enumerate(zip(clean_activations, corrupted_activations)):
        num_neurons = clean_act.shape[1]
        
        for neuron_idx in range(num_neurons):
            # Patch the neuron in the corrupted activations
            patched_act = patch_neuron(corrupted_act, neuron_idx, clean_act[0, neuron_idx])

            # Store the patched activations and keep others as-is
            current_activations = [a.clone() for a in corrupted_activations]
            current_activations[layer_idx] = patched_act
            model.eval() 

            for inn_idx, sub_layer in enumerate(model.get_linear_layers()[layer_idx+1:]):
                current_activations[layer_idx+inn_idx+1] = sub_layer(current_activations[layer_idx+inn_idx])

            improvement = measure_improvement(current_activations[-1], original_output)  # Implement measure_improvement as required
            improvements.append((layer_idx, neuron_idx, improvement))

    sorted_neurons = sorted(improvements, key=lambda x: x[2], reverse=False)
    return sorted_neurons


In [7]:
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader

# Load the MNIST dataset
train_dataset = torchvision.datasets.MNIST(root="/tmp", train=True, transform=torchvision.transforms.ToTensor(), download=True)

# Create datasets for each class
class_3_dataset = [(image, label) for image, label in train_dataset if label == 3]
class_5_dataset = [(image, label) for image, label in train_dataset if label == 5]
class_6_dataset = [(image, label) for image, label in train_dataset if label == 6]
class_8_dataset = [(image, label) for image, label in train_dataset if label == 8]

class_1_dataset = [(image, label) for image, label in train_dataset if label == 1]
class_4_dataset = [(image, label) for image, label in train_dataset if label == 4]
class_7_dataset = [(image, label) for image, label in train_dataset if label == 7]
class_9_dataset = [(image, label) for image, label in train_dataset if label == 9]

# 1 3

# 4 9

# 7 9

# 1 2 3 4 5 6 7 8 9 0
# class_8_dataset = [(image, label) for image, label in train_dataset if label == 8]

# Custom dataset for paired data
class PairedDataset(Dataset):
    def __init__(self, dataset1, dataset2):
        self.dataset1 = dataset1
        self.dataset2 = dataset2

    def __len__(self):
        # Ensure both datasets are of the same length
        return min(len(self.dataset1), len(self.dataset2))

    def __getitem__(self, idx):
        image1, _ = self.dataset1[idx % len(self.dataset1)]
        image2, _ = self.dataset2[idx % len(self.dataset2)]
        return image1, image2

paired_dataset_1_3 = PairedDataset(class_3_dataset, class_1_dataset)
paired_dataset_4_9 = PairedDataset(class_9_dataset, class_4_dataset)
paired_dataset_7_9 = PairedDataset(class_9_dataset, class_7_dataset)
# Create the paired datasets
paired_dataset_3_8 = PairedDataset(class_3_dataset, class_8_dataset)
paired_dataset_5_6 = PairedDataset(class_5_dataset, class_6_dataset)

# Combine the paired datasets
final_dataset = paired_dataset_3_8 + paired_dataset_5_6

# Create a DataLoader for the final dataset
final_loader = DataLoader(final_dataset, batch_size=1, shuffle=True)
# loader_3_8 = DataLoader(paired_dataset_3_8, batch_size=1, shuffle=True)

final_st = paired_dataset_1_3 + paired_dataset_4_9 + paired_dataset_7_9
st_loader = DataLoader(final_st, batch_size=1, shuffle=True)


class_3_and_8_dataset = torch.utils.data.Subset(train_dataset, [i for i, (_, label) in enumerate(train_dataset) if label in [3, 8]])

# Create a DataLoader for the filtered dataset
acc_loader = DataLoader(class_3_and_8_dataset, batch_size=1, shuffle=True)

# # Example: Iterate over the DataLoader
# for i, (image1, image2) in enumerate(final_loader):
#     print(f"Batch {i}:")
#     print("Image from Class 3/5:", image1.shape)
#     print("Image from Class 6/8:", image2.shape)
#     if i == 5:  # Stop after 6 iterations for demonstration
#         break


In [5]:
import json

# Replace 'path_to_file.json' with the path to your JSON file
file_path = 'model_results.json'

# Open the JSON file and load its contents into a Python object
with open(file_path, 'r') as file:
    data = json.load(file)

# Now 'data' is a Python object containing the data from the JSON file
print(data)

{'0': {'top_15_neurons': [{'2': [7, 3, 6, 9, 5, 1, 4, 0, 8, 2], '1': [25, 63, 36, 46, 66, 39, 30, 26, 56, 29, 40, 93, 31, 84, 0, 32, 23, 71, 5, 7, 22, 43, 96, 6, 45], '0': [55, 44, 42, 66, 54, 67, 73, 51, 23, 11, 75, 33, 88, 89, 1, 41, 29, 83, 81, 50, 17, 4, 14, 91, 39]}, {'2': [3, 7, 6, 0, 5, 4, 9, 1, 2, 8], '1': [63, 25, 66, 26, 39, 56, 23, 30, 22, 85, 93, 29, 7, 2, 0, 48, 71, 40, 79, 6, 76, 84, 91, 95, 31], '0': [55, 44, 67, 66, 42, 84, 51, 73, 11, 23, 46, 31, 5, 91, 10, 14, 30, 48, 58, 59, 95, 81, 2, 1, 62]}, {'2': [3, 7, 6, 2, 0, 1, 9, 4, 5, 8], '1': [63, 25, 66, 46, 39, 56, 71, 23, 26, 2, 50, 93, 35, 21, 79, 95, 44, 7, 89, 84, 96, 5, 43, 6, 31], '0': [55, 44, 46, 66, 36, 73, 51, 67, 23, 42, 65, 54, 11, 33, 47, 2, 10, 48, 29, 20, 71, 91, 12, 17, 19]}, {'2': [3, 7, 6, 8, 9, 5, 2, 0, 1, 4], '1': [63, 36, 25, 46, 66, 26, 39, 23, 93, 56, 31, 6, 7, 71, 30, 40, 95, 22, 51, 75, 59, 29, 14, 35, 5], '0': [55, 54, 67, 75, 66, 33, 42, 46, 84, 51, 23, 10, 72, 4, 22, 9, 20, 29, 48, 61, 70, 91,

In [10]:
data

{'0': {'top_15_neurons': [{'2': [7, 3, 6, 9, 5, 1, 4, 0, 8, 2],
    '1': [25,
     63,
     36,
     46,
     66,
     39,
     30,
     26,
     56,
     29,
     40,
     93,
     31,
     84,
     0,
     32,
     23,
     71,
     5,
     7,
     22,
     43,
     96,
     6,
     45],
    '0': [55,
     44,
     42,
     66,
     54,
     67,
     73,
     51,
     23,
     11,
     75,
     33,
     88,
     89,
     1,
     41,
     29,
     83,
     81,
     50,
     17,
     4,
     14,
     91,
     39]},
   {'2': [3, 7, 6, 0, 5, 4, 9, 1, 2, 8],
    '1': [63,
     25,
     66,
     26,
     39,
     56,
     23,
     30,
     22,
     85,
     93,
     29,
     7,
     2,
     0,
     48,
     71,
     40,
     79,
     6,
     76,
     84,
     91,
     95,
     31],
    '0': [55,
     44,
     67,
     66,
     42,
     84,
     51,
     73,
     11,
     23,
     46,
     31,
     5,
     91,
     10,
     14,
     30,
     48,
     58,
     59,
     95,
     81,
     2,
 

In [11]:
data[0]

KeyError: 0

In [38]:

models = ["fivemodels/bimt.pt", "fivemodels/l1local.pt", "fivemodels/l1only.pt", "fivemodels/l1swap.pt", "fivemodels/fully_dense.pt"]


from torch.utils.data import DataLoader, Subset
import random
import gc

# Assuming you have a function `circuit_discovery` that returns top_15_neurons, avg_logit_diff, stop_time, sparsity
# def circuit_discovery(model, data_loader): ...

# Your DataLoader and models setup remains the same
# final_loader = DataLoader(final_dataset, batch_size=1, shuffle=True)
# models = ["fivemodels/bimt.pt", "fivemodels/l1local.pt", "fivemodels/l1only.pt", "fivemodels/l1swap.pt", "fivemodels/fully_dense.pt"]
models = ["fivemodels/bimt.pt", "fivemodels/l1local.pt", "fivemodels/l1only.pt", "fivemodels/l1swap.pt", "fivemodels/fully_dense.pt"]

# def accuracy(network, data_loader, device, N=2000, batch_size=50):
#     # dataset_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
#     correct = 0
#     total = 0
#     for x, labels in islice(data_loader, N // batch_size):
#         #print(x.shape)
#         logits = network(x.to(device))
#         predicted_labels = torch.argmax(logits, dim=1)
#         correct += torch.sum(predicted_labels == labels.to(device))
#         total += x.size(0)
#     return correct / total

# del activations
activations = []
circuit_acc = []
for val, model_path in enumerate(models):
    mlp = BioMLP2D(shp=(784,100,100,10)).to("cpu")
    mlp.load_state_dict(torch.load(model_path))
    correct = 0
    total = 0
    avg_logit_diff = 0
    for i, (image2, labels) in enumerate(acc_loader):
        clean_tensor = image2
        # Attach hooks to all BioLinear2D layers
        for layer in mlp.get_linear_layers():
            hooks.append(layer.register_forward_hook(hook_fn))
        mlp.eval()
        # with torch.no_grad():
        og_op = mlp(clean_tensor)
        # print(og_op)
        # print(activations)
        clean_activations = activations.copy()
        # print(clean_activations)
        for hook in hooks:
            hook.remove()
        sub_graph_act = [a.clone() for a in clean_activations]
        # print(len(sub_graph_act))
        top_15 = data[str(val)]['top_15_neurons']
        for inn_idx, sub_layer in enumerate(mlp.get_linear_layers()[1:]):
            # print(inn_idx)
            zero_tensor = torch.zeros_like(sub_graph_act[inn_idx])
            zero_tensor[0, torch.tensor(top_15[0][str(inn_idx)])] = sub_graph_act[inn_idx][0, torch.tensor(top_15[0][str(inn_idx)])]
            sub_graph_act[inn_idx] = zero_tensor
            # print(sub_graph_act[inn_idx].shape)
            sub_graph_act[inn_idx+1] = sub_layer(sub_graph_act[inn_idx])
        
        predicted_labels = torch.argmax(sub_graph_act[-1], dim=1)
        correct += torch.sum(predicted_labels == labels)
        total += 1
        if i>100:
            break
    print(correct/total)
    circuit_acc.append(correct/total)
        # avg_logit_diff += measure_improvement(sub_graph_act[-1], og_op)
        # break
    # avg_logit_diff = avg_logit_diff/len(subset_loader)
    # print("Average Logit Diff: ", avg_logit_diff)

print(circuit_acc)
# for i in data.keys():
#     print(data[i]['top_15_neurons'])

tensor(0.9804)
tensor(0.9314)
tensor(0.9608)
tensor(0.5098)
tensor(1.)
[tensor(0.9804), tensor(0.9314), tensor(0.9608), tensor(0.5098), tensor(1.)]


In [26]:
top_15 = data[str(1)]['top_15_neurons']

In [29]:
len(top_15)

20