In [1]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torch.nn.utils.rnn import pad_sequence
from torch.autograd import Variable
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
import random
import math
import copy

In [2]:
#This class generates correlated memory vectors as described in (Benna, Fusi; 2021)
class CorrelatedPatterns():
    """Generate correlated binary memory vectors (ancestors and their noisy children).

    ``p`` random binary ancestor vectors of length ``L`` are drawn, and every ancestor
    spawns ``k`` descendants. Each bit of a descendant is copied from its ancestor with
    probability ``gamma``; otherwise it is replaced by a fresh fair coin flip (which may
    happen to coincide with the ancestor's bit).

    All random draws come from ``np.random``, so seeding NumPy makes the patterns
    reproducible.

    Attributes:
        ancestors: list of ``p`` integer NumPy arrays of shape ``(L,)``.
        descendants: ``descendants[a][d]`` is the d-th child of ancestor ``a``, a float
            tensor of shape ``(L,)``.
        differences: ``differences[a][d]`` is ancestor ``a`` minus ``descendants[a][d]``.
        descendants_singlelist: all children stacked into one ``(p*k, L)`` tensor,
            ordered ancestor by ancestor.
        differences_singlelist: all differences stacked the same way.
    """
    def __init__(self, 
                 L, #Length of each memory vector 
                 p, #Number of ancestors
                 k, #Number of children per ancestor
                 gamma): #Average overlap between child and ancestor. A value of one means each child is identical to its ancestor,
                        #while a value of zero means each child is completely different from its ancestor.
        self.L = L
        self.p = p
        self.k = k
        self.gamma = gamma

        #Nested (per-ancestor) storage for the ancestor, descendant (child) and difference vectors
        self.ancestors = []
        self.descendants = []
        self.differences = []

        #Flat rows of shape (1, L), stacked into single tensors at the end for PyTorch dataset creation
        descendant_rows = []
        difference_rows = []

        #Same dtype the children had when they were grown from torch.tensor([])
        float_dtype = torch.get_default_dtype()

        for _ancestorIndex in range(p):

            #Each ancestor is initialized randomly
            ancestor = np.random.choice((0,1), size=(L))
            self.ancestors.append(np.array(ancestor))
            ancestor_tensor = torch.tensor(ancestor, dtype=float_dtype)

            #Initialize k descendants, drawing all L bits of a child at once instead of
            #concatenating one-element tensors bit by bit
            children = []
            for _descendantIndex in range(k):
                #With probability 1-gamma a bit is corrupted, i.e. replaced by a coin flip;
                #otherwise the ancestor's bit is copied.
                corrupted = np.random.uniform(0, 1, size=L) < 1-gamma
                coin_flips = np.random.choice((0,1), size=(L))
                bits = np.where(corrupted, coin_flips, ancestor)

                descendant = torch.tensor(bits, dtype=float_dtype)
                children.append(descendant)
                descendant_rows.append(descendant.reshape(1,-1))
            self.descendants.append(children)

            #Calculate the differences between the ancestor vector and each child vector
            deltas = [ancestor_tensor - child for child in children]
            self.differences.append(deltas)
            difference_rows.extend(delta.reshape(1,-1) for delta in deltas)

        if descendant_rows:
            self.descendants_singlelist = torch.cat(descendant_rows)
            self.differences_singlelist = torch.cat(difference_rows)
        else:
            #torch.cat rejects an empty list, so return empty (0, L) tensors when p or k is zero
            self.descendants_singlelist = torch.empty((0, L), dtype=float_dtype)
            self.differences_singlelist = torch.empty((0, L), dtype=float_dtype)

#This subclass inherits the PyTorch Dataset class in order to create datasets of correlated memory.
class SensoryData(Dataset):
    """PyTorch ``Dataset`` of correlated memories in which every sample is its own target.

    Args:
        L: Length of each sample.
        p: Number of parents.
        k: Number of children per parent.
        gamma: Overlap between parent and children (1=identical, 0=no overlap).
    """
    def __init__(self, L, p, k, gamma):
        super().__init__()

        #Generate the patterns and keep only the flat tensor with one row per child
        patterns = CorrelatedPatterns(L, p, k, gamma).descendants_singlelist

        #Inputs and targets are the very same tensor (autoencoding task)
        self.data = self.x = self.y = patterns
        self.n_samples = patterns.shape[0]

    def __len__(self):
        """Number of stored samples."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the ``(input, target)`` pair stored at ``index``."""
        sample = self.x[index]
        target = self.y[index]
        return sample, target

In [3]:
#Load the cached dataset if it exists; otherwise generate it once and cache it, so that the
#notebook also runs from a fresh checkout (Restart Kernel -> Run All).
#NOTE(review): torch.load unpickles the file, which can execute arbitrary code -- only load a
#dataset.pt that you created yourself. Recent PyTorch releases default to weights_only=True,
#which refuses pickled custom classes such as SensoryData; if loading fails for that reason,
#pass weights_only=False here (trusted files only).
try:
    data = torch.load("dataset.pt")
except FileNotFoundError:
    data = SensoryData(100, 100, 1000, 0.6)
    torch.save(data, "dataset.pt")

In [4]:
class Autoencoder(nn.Module):
    """Single-hidden-layer autoencoder whose decoder reuses the transposed encoder weights.

    Args:
        n_inputs: Number of input units.
        n_hiddens: Number of hidden units.
    """
    def __init__(self, n_inputs, n_hiddens):
        super().__init__()

        #The only trainable tensor: an (n_hiddens, n_inputs) matrix drawn uniformly from [0, 1)
        starting_weights = torch.rand(n_hiddens, n_inputs)
        self.eweight = nn.Parameter(starting_weights, requires_grad=True)

        #Independent snapshot of the freshly initialized weights, kept as a plain attribute
        self.initial_state_dict = copy.deepcopy(self.state_dict())

        self.n_inputs, self.n_hiddens = n_inputs, n_hiddens

    def forward(self, X):
        """Encode and decode ``X``; returns ``(reconstruction, hidden activations)``.

        The intermediate tensors are also stored on the module as ``encoded``,
        ``hidden_activations`` and ``decoded``.
        """
        #Collapse everything after the batch dimension into one feature axis
        flat_input = torch.flatten(X, start_dim=1)

        pre_activation = F.linear(flat_input, self.eweight)
        self.encoded = pre_activation

        hidden = F.relu(pre_activation)
        self.hidden_activations = hidden

        #Decode with the transpose of the encoder matrix (tied weights)
        reconstruction = F.linear(hidden, self.eweight.T)
        self.decoded = reconstruction

        return reconstruction, hidden
    


In [5]:
#Tied-weight autoencoder with 100 input units and 40 hidden units
model = Autoencoder(n_inputs=100, n_hiddens=40)
#Adam with PyTorch's default hyperparameters
optimizer = optim.Adam(model.parameters())

In [6]:
#Serve the dataset one sample per batch, in shuffled order
loader = DataLoader(dataset=data, shuffle=True, batch_size=1)

In [7]:
#Sliding window ("focus array") holding the 30 most recently drawn batches.
#Every next(iter(loader)) builds a fresh iterator over the shuffling loader and takes
#its first batch, so each entry is an independently drawn batch.
focus_array = []
for i in range(30):
    focus_array.append(next(iter(loader)))
    
def update_focus_array():
    """Slide the focus window forward by one batch.

    Drops the oldest entry of the module-level ``focus_array``, draws one new batch
    from ``loader``, appends it to the window and returns it.

    Returns:
        The newly drawn batch, exactly as produced by ``loader``.

    Raises:
        IndexError: if ``focus_array`` is empty.
    """
    #Discard the oldest entry first so the window length stays constant
    focus_array.pop(0)
    push_item = next(iter(loader))
    
    focus_array.append(push_item)
    
    return push_item

In [9]:
#Function written by Huidi Li
def compute_gradmask(model, grads, ratio=0.1): 
    """Build one boolean mask per model parameter marking its smallest-magnitude gradients.

    For parameter ``i`` the threshold is the element at position ``int(ratio * N)`` of
    the ascending-sorted ``|grads[i]|`` (``N`` = number of elements). An entry is masked
    when its magnitude is strictly below that threshold, so roughly the lowest ``ratio``
    fraction is selected and entries tied with the threshold are left out.

    Args:
        model: Object whose ``parameters()`` yields the parameters; only their count
            and order are used.
        grads: Sequence of tensors, ``grads[i]`` belonging to the i-th parameter.
        ratio: Fraction of entries to select. Values <= 0 select nothing and values
            >= 1 select every entry.

    Returns:
        List of boolean tensors, each with the shape of the matching ``grads[i]``.
    """
    masks = []
    for i, _ in enumerate(model.parameters()):
        magnitudes = torch.abs(grads[i])
        n_entries = magnitudes.numel()

        #Clamp the cut position so out-of-range ratios can neither index past the end
        #of the sorted values nor wrap around from the back
        cutoff = min(max(int(ratio * n_entries), 0), n_entries)

        if cutoff == n_entries:
            #ratio >= 1 (or an empty tensor): there is no threshold element, select everything
            mask = torch.ones_like(magnitudes, dtype=torch.bool)
        else:
            threshold = torch.sort(magnitudes.flatten())[0][cutoff]
            mask = magnitudes < threshold
        masks.append(mask)
    return masks



In [10]:
def train(
            model,
            optimizer,
            loss_function,
            outer_loop_epochs=1000,
            inner_loop_epochs=100, 
            alpha=0.5,
            reset_ratio=0.2
        ):
    """Train ``model`` on a sliding window of samples, partially resetting weights each round.

    Every outer epoch:
      1. slides the module-level ``focus_array`` forward by one sample
         (``update_focus_array``) and stacks the whole window into one batch;
      2. runs ``inner_loop_epochs`` optimizer steps on that batch while tracking an
         exponential moving average of the absolute gradient of every parameter;
      3. records the weights before and after those steps;
      4. restores the entries selected by ``compute_gradmask`` (those with the smallest
         averaged gradient magnitude) to their values in ``model.initial_state_dict``.

    Args:
        model: Module returning ``(prediction, hidden_activity)`` from its forward pass
            and exposing ``initial_state_dict`` keyed by parameter name.
        optimizer: Optimizer over ``model.parameters()``. Its internal state is not
            reset when weights are restored.
        loss_function: Callable ``loss_function(prediction, target)``.
        outer_loop_epochs: Number of window shifts.
        inner_loop_epochs: Optimizer steps per window. With 0, no gradients are
            collected and the reset step is skipped.
        alpha: Weight of the newest gradient magnitude in the moving average.
        reset_ratio: Forwarded to ``compute_gradmask`` as ``ratio``.

    Returns:
        One dict per outer epoch with keys ``'W_new'`` (deep copies of the parameters
        after the inner loop, before the reset), ``'e_item'`` (the sample pushed into
        the window) and ``'W_old'`` (deep copies of the parameters before the inner loop).
    """
    
    model.train()
    
    weight_history = []
    
    for outer_epoch in range(outer_loop_epochs):
        
        #Slide the window by one sample
        e_item = update_focus_array()
        
        #Stack the window into a single batch; each entry is an (inputs, targets) pair
        Xs = torch.cat([item[0] for item in focus_array])
        Ys = torch.cat([item[1] for item in focus_array])
        
        #Weights before training on this window
        W_old = [copy.deepcopy(param) for param in model.parameters()]
        
        #Moving average of |gradient|, one tensor per parameter
        model_gradients_ma = []
        
        for inner_epoch in range(inner_loop_epochs):
            
            predicted_y, _ = model(Xs)
            loss = loss_function(predicted_y, Ys)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            #param.grad still holds the gradient of this step's backward pass
            for param_ind, param in enumerate(model.parameters()):
                grad_magnitude = torch.abs(param.grad)
                if inner_epoch == 0:
                    model_gradients_ma.append(grad_magnitude)
                else:
                    model_gradients_ma[param_ind] = alpha*grad_magnitude + (1-alpha)*model_gradients_ma[param_ind]
        
        #Weights after training on this window (before the reset below)
        W_new = [copy.deepcopy(param) for param in model.parameters()]
        
        #Save weights
        weight_history.append({'W_new': W_new, 
                               'e_item': e_item, 
                               'W_old': W_old})
        
        #Reset weights: entries with the smallest averaged gradient go back to their initial values.
        #Masks are paired with named_parameters() because compute_gradmask walks model.parameters()
        #in that same order; state_dict() may additionally contain buffers.
        if model_gradients_ma:
            masks = compute_gradmask(model, model_gradients_ma, ratio=reset_ratio)
            with torch.no_grad():
                for (param_name, param), param_mask in zip(model.named_parameters(), masks):
                    param[param_mask] = model.initial_state_dict[param_name][param_mask]
        
    return weight_history

                    

In [11]:
#Mean squared error as the reconstruction loss, and a new Adam optimizer with an explicit learning rate
loss_function = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [12]:
#Short run: 10 window shifts with 64 optimizer steps each
output_data = train(
    model,
    optimizer,
    loss_function,
    outer_loop_epochs=10,
    inner_loop_epochs=64,
)

In [13]:
output_data[0]

{'W_new': [Parameter containing:
  tensor([[ 0.6643,  0.0258,  0.2156,  ...,  0.6712,  0.6269,  0.6567],
          [ 0.3130,  0.2565,  0.4083,  ...,  0.5354,  0.2436,  0.7186],
          [ 0.3074,  0.1699,  0.6967,  ...,  0.3992,  0.5833,  0.3688],
          ...,
          [ 0.6793,  0.5628,  0.0343,  ...,  0.4536,  0.1579, -0.0436],
          [ 0.8813,  0.7844,  0.5189,  ...,  0.7373,  0.7919,  0.7031],
          [ 0.6702,  0.2972,  0.2329,  ...,  0.9228,  0.6751,  0.6576]],
         requires_grad=True)],
 'e_item': [tensor([[0., 0., 1., 1., 1., 1., 0., 1., 1., 0., 0., 1., 0., 0., 1., 1., 0., 1.,
           1., 1., 0., 0., 0., 0., 1., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1.,
           1., 0., 0., 1., 1., 0., 0., 1., 0., 0., 1., 1., 0., 1., 1., 0., 0., 1.,
           0., 1., 0., 1., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 1., 1.,
           0., 1., 1., 0., 0., 1., 0., 1., 1., 1., 1., 1., 1., 0., 0., 0., 1., 0.,
           1., 0., 1., 0., 0., 1., 1., 0., 1., 0.]]),
  tensor(

In [14]:
#Wrap the list of per-epoch history dicts returned by train() in a NumPy array
#(presumably an object array, one dict per element -- TODO confirm) so it can be written with np.save
arr = np.array(output_data)

In [15]:
#np.save('weight_history', arr)

In [None]:
'''

Rough work starts below

'''

In [333]:
output_data[1][0][0]

Parameter containing:
tensor([[ 0.1183,  0.4015,  0.0382,  ...,  0.5464,  0.6362,  0.6177],
        [ 0.0605,  0.8527,  0.6117,  ...,  0.3146, -0.0758,  0.0246],
        [ 0.7188,  0.5335,  0.7015,  ...,  0.4799,  0.1936,  0.7696],
        ...,
        [ 0.2808, -0.0324,  0.6383,  ...,  0.4616,  0.0205,  0.3546],
        [-0.2984, -0.0422,  0.3502,  ..., -0.0270,  0.1248,  0.0226],
        [ 0.7495,  0.8494, -0.0614,  ...,  0.2171, -0.3807,  0.0447]],
       requires_grad=True)

In [375]:
b = np.load('weight_history.npy', allow_pickle=True)

In [335]:
model.initial_state_dict

OrderedDict([('eweight',
              tensor([[0.4730, 0.4598, 0.1763,  ..., 0.7096, 0.9873, 0.7175],
                      [0.1610, 0.9123, 0.6327,  ..., 0.3353, 0.0253, 0.0451],
                      [0.8188, 0.5927, 0.7222,  ..., 0.5002, 0.2939, 0.7897],
                      ...,
                      [0.3417, 0.0275, 0.6592,  ..., 0.4821, 0.1452, 0.3750],
                      [0.0950, 0.3385, 0.3902,  ..., 0.1427, 0.1639, 0.0621],
                      [0.8307, 0.9285, 0.4003,  ..., 0.3872, 0.0826, 0.0645]]))])

In [349]:
torch.count_nonzero(output_data[0][0][0] == output_data[1][2][0])

tensor(3200)

In [347]:
output_data[0][0] 

[Parameter containing:
 tensor([[ 1.3762e-01,  4.2061e-01,  5.8512e-02,  ...,  5.6629e-01,
           6.5550e-01,  6.3742e-01],
         [ 8.0628e-02,  8.7264e-01,  5.0915e-01,  ...,  1.6365e-01,
          -5.5672e-02,  4.1318e-03],
         [ 7.3861e-01,  5.5311e-01,  5.9947e-01,  ...,  2.9128e-01,
           2.1330e-01,  7.0604e-01],
         ...,
         [ 3.0080e-01, -1.2610e-02,  2.8195e-01,  ...,  3.2624e-01,
           4.0471e-02,  3.3377e-01],
         [-2.7908e-01, -2.3132e-02,  3.7044e-01,  ..., -7.1716e-03,
           1.4411e-01,  4.2338e-02],
         [ 7.6895e-01,  8.6863e-01, -4.1126e-02,  ...,  2.3707e-01,
          -3.6131e-01,  7.9244e-04]], requires_grad=True)]

In [286]:
#Scratch version of the weight-reset step that also appears at the end of train().
#NOTE(review): relies on `grads` from a cell that appears further down in this notebook,
#so it only runs with leftover kernel state, not under Restart Kernel -> Run All.
mask = compute_gradmask(model, grads)
state_dict = model.state_dict()
param_index = 0
for param_name, param_value in state_dict.items():
    #Overwrite the masked entries with the values stored at model construction time
    state_dict[param_name][mask[param_index]] = model.initial_state_dict[param_name][mask[param_index]]
    param_index += 1

model.load_state_dict(state_dict)


<All keys matched successfully>

In [278]:
model.state_dict()

OrderedDict([('eweight',
              tensor([[0.4730, 0.4598, 0.1763,  ..., 0.7096, 0.9873, 0.7175],
                      [0.1610, 0.9123, 0.6327,  ..., 0.3353, 0.0253, 0.0451],
                      [0.8188, 0.5927, 0.7222,  ..., 0.5002, 0.2939, 0.7897],
                      ...,
                      [0.3417, 0.0275, 0.6592,  ..., 0.4821, 0.1452, 0.3750],
                      [0.0950, 0.3385, 0.3902,  ..., 0.1427, 0.1639, 0.0621],
                      [0.8307, 0.9285, 0.4003,  ..., 0.3872, 0.0826, 0.0645]]))])

In [209]:
sd = model.state_dict()

In [212]:
sd['eweight'][0][0] = 0

In [213]:
sd

OrderedDict([('eweight',
              tensor([[0.0000, 0.8409, 0.0066,  ..., 0.8660, 0.1656, 0.7450],
                      [0.7339, 0.6887, 0.8721,  ..., 0.6556, 0.5670, 0.9647],
                      [0.2318, 0.6499, 0.2764,  ..., 0.2299, 0.2761, 0.7626],
                      ...,
                      [0.9376, 0.8365, 0.7383,  ..., 0.6984, 0.7721, 0.1448],
                      [0.4372, 0.6897, 0.2718,  ..., 0.3009, 0.0138, 0.2842],
                      [0.6665, 0.7511, 0.4186,  ..., 0.6096, 0.6348, 0.4669]]))])

In [214]:
model.load_state_dict(sd)

<All keys matched successfully>

In [215]:
model.state_dict()

OrderedDict([('eweight',
              tensor([[0.0000, 0.8409, 0.0066,  ..., 0.8660, 0.1656, 0.7450],
                      [0.7339, 0.6887, 0.8721,  ..., 0.6556, 0.5670, 0.9647],
                      [0.2318, 0.6499, 0.2764,  ..., 0.2299, 0.2761, 0.7626],
                      ...,
                      [0.9376, 0.8365, 0.7383,  ..., 0.6984, 0.7721, 0.1448],
                      [0.4372, 0.6897, 0.2718,  ..., 0.3009, 0.0138, 0.2842],
                      [0.6665, 0.7511, 0.4186,  ..., 0.6096, 0.6348, 0.4669]]))])

In [129]:
sample = next(iter(loader))
loss_function = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [130]:
#Scratch: run a single optimization step on one sample and collect the gradients.
#model(...) returns (reconstruction, hidden activations); only the reconstruction enters the loss.
predicted_y = model(sample[0])
y = sample[1]
loss = loss_function(predicted_y[0], y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

#One gradient tensor per model parameter, in model.parameters() order
grads = []

for p_i, p in enumerate(model.parameters()):
    print(p_i, '\n', p.grad)
    #Stores a reference to p.grad, not a copy
    grads.append(p.grad)

0 
 tensor([[825.4267, 341.2976, 790.4587,  ..., 844.1107, 304.6584, 813.0291],
        [812.0034, 306.0449, 780.6472,  ..., 828.7575, 273.1901, 800.8862],
        [822.4730, 327.8063, 788.8873,  ..., 840.4185, 292.6153, 810.5654],
        ...,
        [800.7516, 337.6969, 766.1525,  ..., 819.2385, 301.4442, 788.4847],
        [832.0778, 329.2175, 798.3475,  ..., 850.1005, 293.8751, 820.1189],
        [904.3927, 356.7924, 867.8372,  ..., 923.9250, 318.4898, 891.4323]])


In [235]:
grads

[tensor([[825.4267, 341.2976, 790.4587,  ..., 844.1107, 304.6584, 813.0291],
         [812.0034, 306.0449, 780.6472,  ..., 828.7575, 273.1901, 800.8862],
         [822.4730, 327.8063, 788.8873,  ..., 840.4185, 292.6153, 810.5654],
         ...,
         [800.7516, 337.6969, 766.1525,  ..., 819.2385, 301.4442, 788.4847],
         [832.0778, 329.2175, 798.3475,  ..., 850.1005, 293.8751, 820.1189],
         [904.3927, 356.7924, 867.8372,  ..., 923.9250, 318.4898, 891.4323]]),
 tensor([[825.4267, 341.2976, 790.4587,  ..., 844.1107, 304.6584, 813.0291],
         [812.0034, 306.0449, 780.6472,  ..., 828.7575, 273.1901, 800.8862],
         [822.4730, 327.8063, 788.8873,  ..., 840.4185, 292.6153, 810.5654],
         ...,
         [800.7516, 337.6969, 766.1525,  ..., 819.2385, 301.4442, 788.4847],
         [832.0778, 329.2175, 798.3475,  ..., 850.1005, 293.8751, 820.1189],
         [904.3927, 356.7924, 867.8372,  ..., 923.9250, 318.4898, 891.4323]])]

In [236]:
mask = compute_gradmask(model, grads)

In [237]:
mask

[tensor([[False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         ...,
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False]])]

In [219]:
sd['eweight'][mask[0]] = model.initial_weights[mask[0]]

In [223]:
sd['eweight'][mask[0]] = torch.zeros(40,100)[mask[0]]

In [226]:
torch.count_nonzero(sd['eweight'])

tensor(3599)

In [114]:
model.state_dict()['eweight'][mask[0]] = model.state_dict()['initial_weights'][mask[0]]

In [151]:
with torch.no_grad():
    model.state_dict()['initial_weights'] = 0

In [125]:
model.state_dict()['initial_weights'].zero_()

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [152]:
model.state_dict()['initial_weights']

tensor([[0.5966, 0.5105, 0.3442,  ..., 0.2259, 0.5022, 0.2717],
        [0.0634, 0.5499, 0.6019,  ..., 0.1523, 0.7488, 0.3010],
        [0.1003, 0.1526, 0.2227,  ..., 0.7062, 0.9699, 0.6520],
        ...,
        [0.7800, 0.7819, 0.4833,  ..., 0.8788, 0.1001, 0.7249],
        [0.0936, 0.9844, 0.5971,  ..., 0.1692, 0.5801, 0.8796],
        [0.4447, 0.9624, 0.7964,  ..., 0.8942, 0.3114, 0.9893]])

In [115]:
model.state_dict()

OrderedDict([('eweight',
              tensor([[0.2358, 0.1197, 0.3148,  ..., 0.5743, 0.9482, 0.7433],
                      [0.4676, 0.1410, 0.6624,  ..., 0.3996, 0.1133, 0.0365],
                      [0.4272, 0.5696, 0.3790,  ..., 0.2302, 0.1283, 0.0144],
                      ...,
                      [0.4981, 0.5285, 0.7370,  ..., 0.6440, 0.8810, 0.5669],
                      [0.2353, 0.5997, 0.6826,  ..., 0.5670, 0.8261, 0.3660],
                      [0.3919, 0.8457, 0.1804,  ..., 0.6236, 0.0643, 0.2066]])),
             ('initial_weights',
              tensor([[0.2358, 0.1197, 0.3148,  ..., 0.5743, 0.9482, 0.7433],
                      [0.4676, 0.1410, 0.6624,  ..., 0.3996, 0.1133, 0.0365],
                      [0.4272, 0.5696, 0.3790,  ..., 0.2302, 0.1283, 0.0144],
                      ...,
                      [0.4981, 0.5285, 0.7370,  ..., 0.6440, 0.8810, 0.5669],
                      [0.2353, 0.5997, 0.6826,  ..., 0.5670, 0.8261, 0.3660],
                      [0.39

In [103]:
model.parameters()[0]

TypeError: 'generator' object is not subscriptable

In [None]:
for p_i, p in enumerate(model.parameters()):
    print(p_i, '\n', p.grad)