In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from matplotlib.patches import Ellipse, Circle


# Fix the NumPy and PyTorch RNG seeds so runs are repeatable.
seed = 1
np.random.seed(seed)
torch.manual_seed(seed)

# Use float64 as the default floating-point dtype for newly created tensors.
# torch.set_default_dtype replaces the deprecated torch.set_default_tensor_type
# and selects the same double-precision default.
torch.set_default_dtype(torch.float64)


class BioLinear(nn.Module):
    """An ``nn.Linear`` layer whose input and output units carry 1-D coordinates.

    The ``in_dim`` inputs are treated as ``in_fold`` groups of ``in_dim/in_fold``
    units (likewise for the outputs).  Within one group the units are placed at
    the centres of equal-width cells of the interval [0, 1]; the same positions
    are repeated for every group.  The coordinates are plain float32 tensors
    stored in ``in_coordinates`` / ``out_coordinates``.
    """

    def __init__(self, in_dim, out_dim, in_fold=1, out_fold=1):
        """Build the wrapped linear layer and the unit coordinates.

        Args:
            in_dim: number of input features; must be divisible by ``in_fold``.
            out_dim: number of output features; must be divisible by ``out_fold``.
            in_fold: number of groups the inputs are split into.
            out_fold: number of groups the outputs are split into.

        Raises:
            AssertionError: if a dimension is not divisible by its fold count.
        """
        super().__init__()
        self.in_dim, self.out_dim = in_dim, out_dim
        self.linear = nn.Linear(in_dim, out_dim)
        self.in_fold, self.out_fold = in_fold, out_fold
        assert in_dim % in_fold == 0
        assert out_dim % out_fold == 0
        # One coordinate per unit, shape (in_dim,) and (out_dim,) respectively.
        self.in_coordinates = self._folded_centres(in_dim, in_fold)
        self.out_coordinates = self._folded_centres(out_dim, out_fold)

    @staticmethod
    def _folded_centres(dim, fold):
        """Return cell-centre coordinates for one group, repeated ``fold`` times."""
        per_fold = int(dim / fold)
        half_cell = 1 / (2 * per_fold)
        centres = np.linspace(half_cell, 1 - half_cell, num=per_fold)
        return torch.tensor(list(centres) * fold, dtype=torch.float)

    def forward(self, x):
        """Apply the wrapped linear layer to ``x``."""
        return self.linear(x)


In [None]:
class BioMLP(nn.Module):
    """MLP built from ``BioLinear`` layers with a connection-cost penalty and neuron relocation.

    Every neuron has a 1-D coordinate stored on its ``BioLinear`` layer and
    consecutive layers are treated as being ``l0`` apart.  ``get_cc`` turns those
    distances into a penalty on weight magnitudes, and ``relocate`` swaps neurons
    inside a layer to reduce that penalty.  Swaps applied to the input or output
    layer are recorded in ``in_perm`` / ``out_perm`` and undone in ``forward`` so
    that callers keep seeing features and outputs in their original order.
    """

    def __init__(self, in_dim=2, out_dim=2, w=2, depth=2, shp=None, token_embedding=False, embedding_size=None):
        """Create the layers, the optional token embedding and the permutation state.

        Args:
            in_dim: input width (used only when ``shp`` is None).
            out_dim: output width (used only when ``shp`` is None).
            w: hidden width (used only when ``shp`` is None).
            depth: number of linear layers (used only when ``shp`` is None).
            shp: explicit list of layer widths ``[in, hidden..., out]``; when given
                it overrides ``in_dim``, ``out_dim``, ``w`` and ``depth``.
            token_embedding: if true, create a trainable ``embedding`` parameter.
            embedding_size: shape of the embedding (number of tokens, embedding
                dimension); required when ``token_embedding`` is true.
        """
        super(BioMLP, self).__init__()
        # `is None` (not `== None`) so array-like shapes do not trigger an
        # element-wise comparison.
        if shp is None:
            shp = [in_dim] + [w]*(depth-1) + [out_dim]
            self.in_dim = in_dim
            self.out_dim = out_dim
            self.depth = depth
        else:
            self.in_dim = shp[0]
            self.out_dim = shp[-1]
            self.depth = len(shp) - 1

        # Layers are created before the embedding so the order in which the
        # random number generator is consumed stays the same.
        linear_list = []
        for i in range(self.depth):
            if i == 0:
                # for modular addition and permutation
                linear_list.append(BioLinear(shp[i], shp[i+1], in_fold=2))
                # for regression
                #linear_list.append(BioLinear(shp[i], shp[i+1], in_fold=1))
            else:
                linear_list.append(BioLinear(shp[i], shp[i+1]))
        self.linears = nn.ModuleList(linear_list)

        if token_embedding:
            # embedding size: number of tokens * embedding dimension
            self.embedding = torch.nn.Parameter(torch.normal(0,1,size=embedding_size))

        self.shp = shp
        # parameters for the bio-inspired trick
        self.l0 = 0.5 # distance between two nearby layers
        # Permutations are stored as float parameters and cast to long wherever
        # they are used as indices; they start out as the identity.
        self.in_perm = nn.Parameter(torch.tensor(np.arange(int(self.in_dim/self.linears[0].in_fold)), dtype=torch.float))
        self.out_perm = nn.Parameter(torch.tensor(np.arange(int(self.out_dim/self.linears[-1].out_fold)), dtype=torch.float))
        self.top_k = 30
        self.token_embedding = token_embedding

    def forward(self, x):
        """Run the network on ``x`` of shape (batch, in_dim).

        The input features are reordered inside each fold according to
        ``in_perm`` before the first layer, and the outputs are mapped back
        through the inverse of ``out_perm`` after the last layer.
        """
        shp = x.shape
        in_fold = self.linears[0].in_fold
        x = x.reshape(shp[0], in_fold, shp[1]//in_fold)
        x = x[:,:,self.in_perm.long()]
        x = x.reshape(shp[0], shp[1])
        f = torch.nn.SiLU()
        for i in range(self.depth-1):
            x = f(self.linears[i](x))
        x = self.linears[-1](x)

        # Invert out_perm: out_perm_inv[out_perm[m]] = m.
        out_perm_inv = torch.zeros(self.out_dim, dtype=torch.long)
        out_perm_inv[self.out_perm.long()] = torch.arange(self.out_dim)
        x = x[:,out_perm_inv]

        return x

    def get_linear_layers(self):
        """Return the ``nn.ModuleList`` of ``BioLinear`` layers."""
        return self.linears

    def get_cc(self, weight_factor=2.0, bias_penalize=True, no_penalize_last=False):
        """Compute the connection cost of the whole model.

        For every layer the cost is ``mean(|W| * (weight_factor*dist + l0))``
        where ``dist[o, i]`` is the absolute difference between the coordinates
        of output unit ``o`` and input unit ``i``.  Biases (optional) and the
        token embedding (if present) contribute ``mean(|.| * l0)``.

        Args:
            weight_factor: multiplier applied to the distance term.
            bias_penalize: whether biases are included in the cost.
            no_penalize_last: if true, the distance term of the last layer is
                dropped (its weights are still penalised through ``l0``).

        Returns:
            A scalar tensor.
        """
        cc = 0
        num_linear = len(self.linears)
        for i in range(num_linear):
            if i == num_linear - 1 and no_penalize_last:
                weight_factor = 0.
            biolinear = self.linears[i]
            dist = torch.abs(biolinear.out_coordinates.unsqueeze(dim=1) - biolinear.in_coordinates.unsqueeze(dim=0))
            cc += torch.mean(torch.abs(biolinear.linear.weight)*(weight_factor*dist+self.l0))
            if bias_penalize == True:
                cc += torch.mean(torch.abs(biolinear.linear.bias)*(self.l0))
        if self.token_embedding:
            cc += torch.mean(torch.abs(self.embedding)*(self.l0))
        return cc

    def swap_weight(self, weights, j, k, swap_type="out"):
        """Swap, in place, columns (``"in"``) or rows (``"out"``) ``j`` and ``k`` of ``weights``.

        Raises:
            Exception: if ``swap_type`` is neither ``"in"`` nor ``"out"``.
        """
        with torch.no_grad():
            if swap_type == "in":
                temp = weights[:,j].clone()
                weights[:,j] = weights[:,k].clone()
                weights[:,k] = temp
            elif swap_type == "out":
                temp = weights[j].clone()
                weights[j] = weights[k].clone()
                weights[k] = temp
            else:
                raise Exception("Swap type {} is not recognized!".format(swap_type))

    def swap_bias(self, biases, j, k):
        """Swap, in place, entries ``j`` and ``k`` of the 1-D tensor ``biases``."""
        with torch.no_grad():
            temp = biases[j].clone()
            biases[j] = biases[k].clone()
            biases[k] = temp

    def swap(self, i, j, k):
        """Swap neurons ``j`` and ``k`` in the ``i``-th layer of neurons.

        Note: n layers of weights means n+1 layers of neurons.  The incoming
        weights, outgoing weights and biases attached to the two neurons are
        exchanged; for the input and output layers the corresponding
        permutation (``in_perm`` / ``out_perm``) is updated as well.
        """
        linears = self.get_linear_layers()
        num_linear = len(linears)
        if i == 0:
            # input layer, only has outgoing weights; update in_perm
            weights = linears[i].linear.weight
            infold = linears[i].in_fold
            fold_dim = int(weights.shape[1]/infold)
            # The same position is swapped in every fold.
            for l in range(infold):
                self.swap_weight(weights, j+fold_dim*l, k+fold_dim*l, swap_type="in")
            # change input_perm
            self.swap_bias(self.in_perm, j, k)
        elif i == num_linear:
            # output layer, only has incoming weights and biases; update out_perm
            weights = linears[i-1].linear.weight
            biases = linears[i-1].linear.bias
            self.swap_weight(weights, j, k, swap_type="out")
            self.swap_bias(biases, j, k)
            # change output_perm
            self.swap_bias(self.out_perm, j, k)
        else:
            # middle layer : (incoming, outgoing) * weights, and biases
            weights_in = linears[i-1].linear.weight
            weights_out = linears[i].linear.weight
            biases = linears[i-1].linear.bias
            self.swap_weight(weights_in, j, k, swap_type="out")
            self.swap_weight(weights_out, j, k, swap_type="in")
            self.swap_bias(biases, j, k)

    def get_top_id(self, i, top_k=20):
        """Return the indices of the ``top_k`` highest-scoring neurons of layer ``i``.

        A neuron's score is the sum of the absolute values of all weights
        attached to it (outgoing for the input layer, summed over folds;
        incoming for the output layer; both for hidden layers).
        """
        linears = self.get_linear_layers()
        num_linear = len(linears)
        if i == 0:
            # input layer
            weights = linears[i].linear.weight
            score = torch.sum(torch.abs(weights), dim=0)
            in_fold = linears[0].in_fold
            score = torch.sum(score.reshape(in_fold, int(score.shape[0]/in_fold)), dim=0)
        elif i == num_linear:
            # output layer
            weights = linears[i-1].linear.weight
            score = torch.sum(torch.abs(weights), dim=1)
        else:
            weights_in = linears[i-1].linear.weight
            weights_out = linears[i].linear.weight
            score = torch.sum(torch.abs(weights_out), dim=0) + torch.sum(torch.abs(weights_in), dim=1)
        top_index = torch.flip(torch.argsort(score),[0])[:top_k]
        return top_index

    def relocate_ij(self, i, j):
        """Move neuron ``j`` of layer ``i`` to the position with the lowest connection cost.

        Every candidate position ``k`` is tried by swapping ``j`` and ``k``,
        evaluating ``get_cc()`` and swapping back; the best swap is then applied.
        """
        linears = self.get_linear_layers()
        num_linear = len(linears)
        if i < num_linear:
            num_neuron = int(linears[i].linear.weight.shape[1]/linears[i].in_fold)
        else:
            num_neuron = linears[i-1].linear.weight.shape[0]
        ccs = []
        # Only the cost values are compared, so no autograd graph is needed.
        with torch.no_grad():
            for k in range(num_neuron):
                self.swap(i,j,k)
                ccs.append(self.get_cc())
                self.swap(i,j,k)
        k = int(torch.argmin(torch.stack(ccs)))
        self.swap(i,j,k)

    def relocate_i(self, i):
        """Relocate the ``self.top_k`` highest-scoring neurons of layer ``i``."""
        top_id = self.get_top_id(i, top_k=self.top_k)
        for j in top_id.tolist():
            self.relocate_ij(i,j)

    def relocate(self):
        """Relocate neurons in every layer of neurons, input and output included."""
        linears = self.get_linear_layers()
        num_linear = len(linears)
        for i in range(num_linear+1):
            self.relocate_i(i)

    def plot(self):
        """Draw the network on a new matplotlib figure.

        Neurons are black ellipses placed layer by layer (the folds of the input
        layer are drawn as separate, slightly offset rows).  Connections are
        lines whose width is proportional to the weight magnitude, normalised by
        the largest magnitude in the layer; positive weights are blue, the rest
        red.
        """
        fig, ax = plt.subplots(figsize=(6,6))
        shp = self.shp
        s = 1/(2*max(shp))
        height = s/10*((len(shp)-1)+0.4)
        for j in range(len(shp)):
            N = shp[j]
            if j == 0:
                in_fold = self.linears[j].in_fold
                N = int(N/in_fold)
                for i in range(N):
                    for fold in range(in_fold):
                        circle = Ellipse((1/(2*N)+i/N, 0.1*j+0.02*fold-0.01), s, height, color='black')
                        ax.add_patch(circle)
            else:
                # One ellipse per neuron; only the input layer has folds.
                for i in range(N):
                    circle = Ellipse((1/(2*N)+i/N, 0.1*j), s, height, color='black')
                    ax.add_patch(circle)

        ax.set_ylim(-0.02,0.1*(len(shp)-1)+0.02)
        ax.set_xlim(-0.02,1.02)

        linears = self.linears
        for ii in range(len(linears)):
            biolinear = linears[ii]
            # Convert once to NumPy instead of once per plotted line.
            p = biolinear.linear.weight.detach().cpu().numpy()
            p_shp = p.shape
            p = p/np.abs(p).max()
            in_fold = biolinear.in_fold
            fold_num = int(p_shp[1]/in_fold)
            for i in range(p_shp[0]):
                x_out = 1/(2*p_shp[0])+i/p_shp[0]
                if ii == 0:
                    for fold in range(in_fold):
                        for j in range(fold_num):
                            # Column fold*fold_num + j holds the weight from
                            # position j of this fold (same layout as in swap).
                            wij = p[i, fold*fold_num+j]
                            ax.plot([x_out, 1/(2*fold_num)+j/fold_num], [0.1*(ii+1),0.1*ii+0.02*fold-0.01], lw=1*np.abs(wij), color="blue" if wij>0 else "red")
                else:
                    for j in range(fold_num):
                        wij = p[i,j]
                        ax.plot([x_out, 1/(2*fold_num)+j/fold_num], [0.1*(ii+1),0.1*ii], lw=0.5*np.abs(wij), color="blue" if wij>0 else "red")

        ax.axis('off')
        

In [None]:
import math

### create model ###

# Re-seed NumPy and PyTorch so that the model initialisation and the random
# train/test split further down in this cell are repeatable when the cell is
# re-run on its own.
seed = 1
np.random.seed(seed)
torch.manual_seed(seed)

def permutation(lst):
    """Return all permutations of ``lst`` as a list of lists.

    The permutations are ordered by the position of the elements in ``lst``:
    all permutations starting with ``lst[0]`` come first, then those starting
    with ``lst[1]``, and so on recursively.

    Args:
        lst: a sequence of elements (duplicates are treated as distinct).

    Returns:
        A list with ``len(lst)!`` lists, or ``[]`` when ``lst`` is empty.
    """
    import itertools  # local import so this cell does not depend on another cell's imports

    # An empty input yields no permutations (itertools alone would yield one
    # empty tuple).
    if len(lst) == 0:
        return []

    return [list(order) for order in itertools.permutations(lst)]
 

# Enumerate every permutation of p0 elements and give each one an integer id,
# keyed by the string form of the permutation.
p0 = 4
data = list(np.arange(p0))
perms = list(permutation(data))
perm2id = {"{}".format(perm): id_ for id_, perm in enumerate(perms)}

num = math.factorial(p0)

# For every ordered pair (i, j) the label is the id of the permutation obtained
# by indexing perms[j] with perms[i].
data_id = [(i, j) for i in range(num) for j in range(num)]
labels = [
    perm2id["{}".format(list(np.array(perms[j])[np.array(perms[i])]))]
    for i, j in data_id
]

data_id = np.array(data_id)
labels = np.array(labels)


# Model: two embedded tokens of dimension d are concatenated at the input and
# there is one output per permutation.
p = math.factorial(p0)
d = 32
in_dim, out_dim = 2*d, p

shp = [in_dim, 50, 50, out_dim]
model = BioMLP(shp=shp, token_embedding=True, embedding_size=(p, d))

### create dataset ###
# All p*p ordered token pairs.  With the default meshgrid indexing, row k of
# data_id is (k % p, k // p), whereas labels[k] above was computed for
# i = k // p, j = k % p -- so column 0 holds the `j` index and column 1 the `i`
# index of the label loop.
x = np.arange(p)
y = np.arange(p)
XX, YY = np.meshgrid(x, y)
data_id = np.transpose([XX.reshape(-1,), YY.reshape(-1,)])
# as_tensor accepts both the NumPy array built above and an existing tensor,
# so re-running this part does not emit a copy-construct warning.
labels = torch.as_tensor(labels, dtype=torch.long)
print(labels.shape)
fraction = 0.8
train_num = int(p**2*fraction)
test_num = p**2 - train_num

# Random split: train_id is sampled without replacement; test_id is every
# remaining index, as a sorted integer array.
train_id = np.random.choice(p**2,train_num,replace=False)
test_id = np.setdiff1d(np.arange(p**2), train_id)
assert len(train_id) + len(test_id) == p**2

def get_data(id_):
    """Return (inputs, labels) for the samples selected by ``id_``.

    ``id_`` indexes rows of the module-level ``data_id``; each row holds two
    token indices.  The two tokens' rows of ``model.embedding`` are
    concatenated along dim 1 to form the input, and the matching entries of
    the module-level ``labels`` are returned alongside.
    """
    pairs = data_id[id_]
    first_tokens = model.embedding[pairs[:, 0]]
    second_tokens = model.embedding[pairs[:, 1]]
    inputs = torch.cat([first_tokens, second_tokens], dim=1)
    return inputs, labels[id_]


### train ###
optimizer = torch.optim.AdamW(model.parameters(), lr=0.002, weight_decay=0.0)
steps = 20000
log = 200        # print metrics every `log` steps
lamb = 1         # weight of the connection cost in the total loss
swap_log = 200   # relocate neurons every `swap_log` steps
plot_log = 200   # draw the network every `plot_log` steps

# The loss module holds no state, so it is built once instead of on every step.
CEL = nn.CrossEntropyLoss()

for step in range(steps):

    # Hooks for a schedule on lamb; both factors are 1, so lamb stays constant.
    if step == int(steps*1/4):
        lamb *= 1

    if step == int(steps*3/4):
        lamb *= 1

    optimizer.zero_grad()
    inputs_train, labels_train = get_data(train_id)
    pred  = model(inputs_train)
    loss = CEL(pred, labels_train)
    acc = torch.mean((torch.argmax(pred, dim=1) == labels_train).float())

    # Test metrics are only reported, never back-propagated, so skip autograd.
    with torch.no_grad():
        inputs_test, labels_test = get_data(test_id)
        pred_test  = model(inputs_test)
        loss_test = CEL(pred_test, labels_test)
        acc_test = torch.mean((torch.argmax(pred_test, dim=1) == labels_test).float())

    cc = model.get_cc(weight_factor=1.0, no_penalize_last=False)
    total_loss = loss + lamb*cc
    total_loss.backward()
    optimizer.step()


    if step % log == 0:
        print("step = %d | total loss: %.2e | train loss: %.2e | test loss %.2e | cc: %.2e | train acc: %.2e | test acc: %.2e "%(step, total_loss.detach().numpy(), loss.detach().numpy(), loss_test.detach().numpy(), cc.detach().numpy(), acc.detach().numpy(), acc_test.detach().numpy()))
    if (step+1) % swap_log == 0:
        model.relocate()

    if step % plot_log == 0:
        model.plot()
        plt.title("train: %.2f | test: %.2f" % (acc.detach().numpy(), acc_test.detach().numpy()))
        #plt.savefig("./video_figs/S4/{0:05d}.png".format(step))
        plt.show()
        # Release the figure so figures do not pile up on backends where
        # show() leaves them open.
        plt.close()


In [None]:
def permutation(lst):
    """Return all permutations of the list ``lst`` as a list of lists.

    Built recursively: each element in turn is taken as the head and combined
    with every permutation of the remaining elements, so the results are
    ordered by the position of the head in ``lst``.

    Returns ``[]`` for an empty list and ``[lst]`` (the same list object) for a
    single-element list.
    """
    size = len(lst)
    if size == 0:
        return []
    if size == 1:
        return [lst]

    return [
        [head] + tail
        for idx, head in enumerate(lst)
        for tail in permutation(lst[:idx] + lst[idx+1:])
    ]


p0 = 4
d = 32
data = list(np.arange(p0))
# Re-express every permutation of 0..p0-1 with 1-based symbols (1..p0) for
# display; the symbols are derived from p0 rather than hard-coded.
symbols = np.arange(1, p0+1)
perms = [symbols[perm] for perm in permutation(data)]

def arr2str(arr):
    """Concatenate the ``"{}"``-formatted elements of ``arr`` into one string."""
    return "".join("{}".format(arr[i]) for i in range(len(arr)))

def parity(perm):
    """Return ``"+"`` for an even permutation and ``"-"`` for an odd one.

    The parity is that of the number of inversions, i.e. pairs ``i < j`` with
    ``perm[j] < perm[i]``.  Works for a sequence of any length (the length is
    taken from ``perm`` instead of being fixed to 4).
    """
    n = len(perm)
    inversions = sum(int(perm[j] < perm[i]) for j in range(n) for i in range(j))
    return "+" if inversions % 2 == 0 else "-"

model.plot()

# Label every output neuron with the permutation it stands for and its parity.
p_shp = model.linears[-1].linear.weight.shape
# perms reordered by the learned output permutation; computed once instead of
# twice per loop iteration.
out_order = model.out_perm.detach().long().numpy()
perms_by_neuron = np.array(perms)[out_order]
# Just above the top layer, which model.plot() draws at y = 0.1 * (number of
# weight layers).
label_y = 0.1*(len(model.shp)-1)+0.01
for j in range(p_shp[0]):
    perm_j = perms_by_neuron[j]
    plt.text(1/(2*p_shp[0])+j/p_shp[0]-0.015, label_y, arr2str(perm_j)+"$("+"{}".format(parity(perm_j))+")$", fontsize=7,rotation=90)

# Number the input positions 1..d below the bottom layer.
for i in range(d):
    plt.text(i/d+0.007, -0.02, i+1, fontsize=6)

plt.title("Permutation "+r"$S_4$", y=1.05, fontsize=15)
#plt.show()
#plt.savefig('./fig/S4_graph.png', bbox_inches="tight")

In [None]:
# Input positions (0-based) whose embedding values are plotted, one subplot each.
arrs = [7,9,10,11,12,13,14,15,21]
num_rows = len(arrs)
num_tokens = model.embedding.shape[0]

# Tick labels "<permutation>(<parity>)", built once and shared by the top and
# bottom subplots.
token_labels = [arr2str(np.array(perms)[i])+"$("+"{}".format(parity(np.array(perms)[i]))+")$" for i in range(num_tokens)]

plt.figure(figsize=(10,10))

for j in range(num_rows):

    plt.subplot(num_rows,1,j+1)
    # Embedding dimension currently sitting at input position arrs[j], for all
    # tokens, scaled to a maximum magnitude of 1.
    data = model.embedding[:,model.in_perm.long()[[arrs[j]]]].detach().numpy()
    data = data/np.max(np.abs(data))
    plt.scatter(np.arange(num_tokens), data, color="blue", alpha=0.5)

    plt.ylim(-1.5,1.5)

    # Only the first (ticks on top) and last subplots carry the token labels.
    if j == 0:
        plt.gca().xaxis.tick_top()
    if j == 0 or j == num_rows-1:
        plt.xticks(np.arange(num_tokens), token_labels, rotation=90)
    plt.ylabel("neuron "+str(arrs[j]+1))

    # Faint vertical guide line at every token.
    for k in range(num_tokens):
        plt.plot([k,k],[-2,2], ls="--", color="black", alpha=0.2)

plt.subplots_adjust(wspace=0, hspace=0)

#plt.savefig('./fig/S4_embed.pdf', bbox_inches="tight")