In [2]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torch.nn.utils.rnn import pad_sequence
from torch.autograd import Variable
import torch.optim as optim
import numpy as np
import random
import math
import pickle
import numpy as np
import copy

In [3]:
#This class generates correlated memory vectors as described in (Benna, Fusi; 2021)
class CorrelatedPatterns():
    """Two-level family of +/-1 memory vectors: `p` random ancestors with `k` noisy children each.

    Every child is a copy of its ancestor in which each bit is sign-flipped
    independently with probability ``1 - gamma``.

    Attributes:
        L, p, k, gamma: the constructor arguments, stored unchanged.
        ancestors: list of `p` NumPy integer arrays of length `L` with entries in {-1, +1}.
        descendants: ``descendants[a][c]`` is child `c` of ancestor `a`, a float32 tensor of shape (L,).
        differences: ``differences[a][c]`` is ``ancestor - child`` (entries are 0 or +/-2), same layout.
        descendants_singlelist: all children stacked into one (p*k, L) float32 tensor,
            ordered ancestor by ancestor; convenient for building a PyTorch dataset.
        differences_singlelist: all difference vectors stacked the same way.
    """
    def __init__(self,
                 L, #Length of each memory vector
                 p, #Number of ancestors
                 k, #Number of children per ancestor
                 gamma): #Probability that a child bit equals the ancestor bit (flip probability is 1-gamma).
                        #gamma=1: children identical to the ancestor; gamma=0.5: children statistically
                        #independent of it; gamma=0: children are the exact negation of the ancestor.
                        #The expected normalised overlap between child and ancestor is 2*gamma-1.
        self.L = L
        self.p = p
        self.k = k
        self.gamma = gamma

        #Nested containers: one inner list per ancestor
        self.ancestors = []
        self.descendants = []
        self.differences = []

        #Flat row collections, stacked into 2-D tensors once everything has been generated
        descendant_rows = []
        difference_rows = []

        for _ancestor_index in range(p):

            #Each ancestor is initialized randomly
            ancestor = np.random.choice((-1,1), size=(L))
            self.ancestors.append(ancestor)
            ancestor_tensor = torch.tensor(ancestor, dtype=torch.float32)

            children = []
            child_differences = []
            for _child_index in range(k):
                #One uniform draw per bit, in bit order, so a seeded `random` module yields
                #the same patterns as a bit-by-bit construction would.
                signs = torch.tensor(
                    [-1.0 if random.uniform(0,1) < 1-gamma else 1.0 for _bit in range(L)],
                    dtype=torch.float32)

                #Build the whole child at once instead of growing a tensor one bit at a time
                child = ancestor_tensor * signs
                children.append(child)
                child_differences.append(ancestor_tensor - child)

            self.descendants.append(children)
            self.differences.append(child_differences)
            descendant_rows.extend(children)
            difference_rows.extend(child_differences)

        self.descendants_singlelist = self._stack_rows(descendant_rows, L)
        self.differences_singlelist = self._stack_rows(difference_rows, L)

    @staticmethod
    def _stack_rows(rows, L):
        """Stack 1-D tensors of length `L` into a (len(rows), L) tensor; empty input gives shape (0, L)."""
        if not rows:
            return torch.empty((0, L), dtype=torch.float32)
        return torch.stack(rows)

#PyTorch Dataset wrapper around the child patterns produced by CorrelatedPatterns.
class SensoryData(Dataset):
    """Autoencoder-style dataset in which every sample is its own target.

    Args:
        L: length of each sample.
        p: number of parents (ancestors).
        k: number of children per parent.
        gamma: passed straight through to CorrelatedPatterns, where it controls
            how closely children match their parent.

    Attributes:
        data, x, y: all refer to the same (p*k, L) tensor of child patterns.
        n_samples: number of rows in that tensor.
    """
    def __init__(self, L, p, k, gamma):
        super().__init__()
        patterns = CorrelatedPatterns(L, p, k, gamma)

        #Inputs and targets are one and the same tensor
        self.data = self.x = self.y = patterns.descendants_singlelist
        self.n_samples = self.data.shape[0]

    def __len__(self):
        """Number of samples in the dataset."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the (input, target) pair stored at `index`."""
        sample, target = self.x[index], self.y[index]
        return sample, target

In [4]:
# Low-correlation dataset: 100 parents x 500 children of length 10, gamma=0.5.
# It is regenerated and written to disk every time this cell runs.
low_corr_data2 = SensoryData(L=10, p=100, k=500, gamma=0.5)
torch.save(low_corr_data2, "low_corr_dataset_len10.pt")
# Alternative: reuse a previously saved dataset instead of regenerating it.
#low_corr_data2 = torch.load("low_corr_dataset2.pt")

In [5]:
# High-correlation dataset (gamma=0.9). Reuse the cached copy when it exists;
# otherwise build it once and cache it, so the notebook runs from a fresh checkout
# and does not generate a dataset only to discard it.
HIGH_CORR_DATASET_PATH = "high_corr_dataset_len10.pt"
try:
    # weights_only=False is required because the file holds a pickled SensoryData object,
    # not bare tensors. Only load files you created yourself: unpickling can run arbitrary code.
    high_corr_data2 = torch.load(HIGH_CORR_DATASET_PATH, weights_only=False)
except FileNotFoundError:
    high_corr_data2 = SensoryData(10, 100, 500, 0.9)
    torch.save(high_corr_data2, HIGH_CORR_DATASET_PATH)

In [6]:
class Autoencoder(nn.Module):
    """Single-hidden-layer autoencoder with tied weights and no biases.

    The same matrix `eweight` (n_hiddens x n_inputs) encodes the input and,
    transposed, decodes the ReLU hidden activity.

    Args:
        n_inputs: number of input units.
        n_hiddens: number of hidden units.

    Attributes:
        eweight: the shared weight matrix, initialised uniformly in [0, 1).
        initial_state_dict: deep copy of the state dict taken right after
            initialisation, so weights can later be restored to their starting values.
        encoded, hidden_activations, decoded: intermediate tensors of the most
            recent forward pass.
    """
    def __init__(self, n_inputs, n_hiddens):
        super().__init__()
        self.n_inputs, self.n_hiddens = n_inputs, n_hiddens

        starting_weights = torch.rand(n_hiddens, n_inputs)
        self.eweight = nn.Parameter(starting_weights, requires_grad=True)

        # Snapshot taken before any training step has touched the weights
        self.initial_state_dict = copy.deepcopy(self.state_dict())

    def forward(self, X):
        """Return (reconstruction, hidden activity) for a batch; X is flattened to (batch, -1)."""
        flat_input = X.flatten(start_dim=1)

        self.encoded = F.linear(flat_input, self.eweight)
        self.hidden_activations = torch.relu(self.encoded)
        # Decoding reuses the encoder matrix, transposed
        self.decoded = F.linear(self.hidden_activations, self.eweight.t())

        return self.decoded, self.hidden_activations
    


## High correlation training set (shuffled, ARRAY_LEN=10, correlation=0.9)

In [7]:
# Single-sample shuffled loader over the high-correlation dataset.
loader = DataLoader(dataset=high_corr_data2, batch_size=1, shuffle=True)

# Sliding window ("focus array") of the ARRAY_LEN most recently drawn samples.
# Every `iter(loader)` starts a freshly shuffled pass, so each draw is an
# independent random sample from the dataset.
ARRAY_LEN = 10
focus_array = [next(iter(loader)) for _ in range(ARRAY_LEN)]
    
def update_focus_array():
    """Slide the focus window forward by one sample.

    Drops the oldest entry of the module-level `focus_array`, draws one new
    batch from the module-level `loader`, appends it, and returns it.
    """
    # Oldest entry leaves first, then the replacement is drawn
    del focus_array[0]
    newest_item = next(iter(loader))
    focus_array.append(newest_item)
    return newest_item

In [8]:
#Function written by Huidi Li
def compute_gradmask(model, grads, ratio=0.1):
    """Build, per parameter, a boolean mask selecting the smallest-magnitude gradient entries.

    For each parameter the absolute gradient values are sorted and the value at
    position ``int(ratio * n)`` becomes the threshold; entries strictly below it
    are marked True. Because the comparison is strict, ties at the threshold are
    not selected, so slightly fewer than ``ratio * n`` entries may be marked.

    Args:
        model: object whose ``parameters()`` yields one item per entry of `grads`
            (only the count is used).
        grads: sequence of gradient(-like) tensors, one per parameter, in
            ``model.parameters()`` order.
        ratio: fraction of entries to select. ``ratio >= 1`` selects every entry.

    Returns:
        List of bool tensors, each with the shape of the corresponding `grads` entry.
    """
    masks = []
    for index, _ in enumerate(model.parameters()):
        magnitude = torch.abs(grads[index])
        sorted_magnitude, _ = torch.sort(magnitude.flatten())
        n_entries = sorted_magnitude.numel()
        cutoff = int(ratio * n_entries)
        if cutoff >= n_entries:
            # ratio >= 1 (or an empty gradient): there is no threshold element to
            # index, so select everything instead of raising IndexError.
            mask = torch.ones_like(magnitude, dtype=torch.bool)
        else:
            mask = magnitude < sorted_magnitude[cutoff]
        masks.append(mask)
    return masks



In [9]:
def AddNoise(array, bit_flip_chance):
    """Return a copy of `array` with randomly chosen entries set to zero.

    Each entry is zeroed independently with probability `bit_flip_chance` and
    kept unchanged otherwise. Despite the parameter name, entries are zeroed
    (masked), not sign-flipped.
    """
    survivors = torch.rand(array.shape).gt(bit_flip_chance)
    return array.mul(survivors)

def train(
            model,
            optimizer,
            loss_function,
            outer_loop_epochs=1000,
            inner_loop_epochs=100, 
            alpha=0.5,
            reset_ratio=0.3,
            bit_flip_chance=0.2
        ):
    """Train `model` as a denoising autoencoder on a sliding window of samples, with partial weight resets.

    Relies on module-level state and helpers: `update_focus_array`, `focus_array`,
    `AddNoise` and `compute_gradmask`; `model` must expose `initial_state_dict`.

    Each outer epoch:
      1. slides the focus window by one sample and stacks the window into one batch;
      2. runs `inner_loop_epochs` optimizer steps, each on a freshly noised copy of
         the clean batch, with the clean batch's targets as the reconstruction goal;
      3. tracks an exponential moving average of |gradient| per parameter;
      4. records the weights before and after the inner loop;
      5. restores to their initial values the weights whose averaged gradient
         magnitude falls in the lowest `reset_ratio` fraction.

    Args:
        model: network returning ``(prediction, hidden_activity)`` from its forward pass.
        optimizer: optimizer over ``model.parameters()``. Its internal state is not reset.
        loss_function: callable ``(prediction, target) -> scalar loss``.
        outer_loop_epochs: number of window shifts.
        inner_loop_epochs: optimizer steps per window position.
        alpha: weight of the newest |gradient| in the moving average.
        reset_ratio: `ratio` passed to `compute_gradmask`.
        bit_flip_chance: per-entry noise probability passed to `AddNoise`.

    Returns:
        dict mapping outer epoch -> ``{'e_item': newly added sample,
        'W_old': weights before the inner loop, 'W_new': weights after it (before the reset)}``.
    """
    model.train()

    weight_history = {}

    for outer_epoch in range(outer_loop_epochs):

        e_item = update_focus_array()

        # Stack the window into one clean batch (a single cat instead of growing tensors)
        Xs = torch.cat([item[0] for item in focus_array])
        Ys = torch.cat([item[1] for item in focus_array])

        W_old = [copy.deepcopy(param) for param in model.parameters()]

        model_gradients_ma = []

        for inner_epoch in range(inner_loop_epochs):

            # Noise is always applied to the CLEAN batch. Re-noising the previous noisy
            # batch would compound the corruption and drive the inputs to all zeros
            # within a few dozen steps.
            noisy_Xs = AddNoise(Xs, bit_flip_chance=bit_flip_chance)

            predicted_y, _ = model(noisy_Xs)
            loss = loss_function(predicted_y, Ys)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Exponential moving average of the gradient magnitude, per parameter
            for param_ind, param in enumerate(model.parameters()):
                grad_magnitude = torch.abs(param.grad)
                if inner_epoch == 0:
                    model_gradients_ma.append(grad_magnitude)
                else:
                    model_gradients_ma[param_ind] = alpha*grad_magnitude + (1-alpha)*model_gradients_ma[param_ind]

        W_new = [copy.deepcopy(param) for param in model.parameters()]

        #Save weights
        weight_history[outer_epoch] = {'e_item': e_item, 
                               'W_old': W_old,
                               'W_new': W_new, 
                               }

        #Reset weights (skipped when no inner step ran, since there are no gradients to rank)
        if model_gradients_ma:
            mask = compute_gradmask(model, model_gradients_ma, ratio=reset_ratio)
            state_dict = model.state_dict()
            # state_dict order is assumed to match model.parameters() order (true when the
            # model has parameters only, no buffers).
            for param_index, param_name in enumerate(state_dict):
                selected = mask[param_index]
                state_dict[param_name][selected] = model.initial_state_dict[param_name][selected]

            model.load_state_dict(state_dict)

    return weight_history

                    

In [32]:
# Size the input layer from the dataset behind `loader` instead of hard-coding it:
# a mismatch between the sample length and n_inputs fails in the first forward pass.
n_inputs = loader.dataset.x.shape[1]
model = Autoencoder(n_inputs, 20)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_function = nn.MSELoss()


In [33]:
# High-correlation training run: 50000 window shifts, 32 optimizer steps per shift.
output_data = train(
    model=model,
    optimizer=optimizer,
    loss_function=loss_function,
    outer_loop_epochs=50000,
    inner_loop_epochs=32,
)

In [34]:
# Persist the weight history of the high-correlation training run.
with open('high_correlation_train.pkl', 'wb') as history_file:
    pickle.dump(output_data, history_file)

In [35]:
# Fresh shuffled single-sample loader over the high-correlation dataset for the second ("test") run.
loader = DataLoader(
    dataset=high_corr_data2,
    batch_size=1,
    shuffle=True,
)

In [36]:
# New, untrained model for the "test" run. The input layer is sized from the dataset
# behind `loader` instead of a hard-coded value, so it cannot mismatch the sample length.
n_inputs = loader.dataset.x.shape[1]
model = Autoencoder(n_inputs, 20)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_function = nn.MSELoss()


In [37]:
# High-correlation "test" run: 10000 window shifts, 32 optimizer steps per shift.
output_data = train(
    model=model,
    optimizer=optimizer,
    loss_function=loss_function,
    outer_loop_epochs=10000,
    inner_loop_epochs=32,
)

In [38]:
# Persist the weight history of the high-correlation "test" run.
with open('high_correlation_test.pkl', 'wb') as history_file:
    pickle.dump(output_data, history_file)

In [10]:
# Switch to the low-correlation dataset.
loader = DataLoader(low_corr_data2, batch_size=1, shuffle=True)

# The focus window still holds samples drawn from the previous dataset. Refill it from
# the new loader so none of them leak into the first ARRAY_LEN training steps (and so
# the window never mixes samples of different lengths).
focus_array = [next(iter(loader)) for _ in range(ARRAY_LEN)]

In [11]:
# Size the input layer from the dataset behind `loader`: the low-correlation samples
# built in this notebook have length 10, so a hard-coded 30 fails in the first forward pass.
n_inputs = loader.dataset.x.shape[1]
model = Autoencoder(n_inputs, 20)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_function = nn.MSELoss()


In [12]:
# Low-correlation training run: 50000 window shifts, 32 optimizer steps per shift.
output_data = train(
    model=model,
    optimizer=optimizer,
    loss_function=loss_function,
    outer_loop_epochs=50000,
    inner_loop_epochs=32,
)

In [13]:
# Persist the weight history of the low-correlation training run.
with open('low_correlation_train.pkl', 'wb') as history_file:
    pickle.dump(output_data, history_file)

In [14]:
# Fresh shuffled single-sample loader over the low-correlation dataset for the second ("test") run.
loader = DataLoader(
    dataset=low_corr_data2,
    batch_size=1,
    shuffle=True,
)

In [15]:
# New, untrained model for the low-correlation "test" run. The input layer is sized from
# the dataset behind `loader` instead of a hard-coded value, so it cannot mismatch the sample length.
n_inputs = loader.dataset.x.shape[1]
model = Autoencoder(n_inputs, 20)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_function = nn.MSELoss()


In [16]:
# Low-correlation "test" run: 10000 window shifts, 32 optimizer steps per shift.
output_data = train(
    model=model,
    optimizer=optimizer,
    loss_function=loss_function,
    outer_loop_epochs=10000,
    inner_loop_epochs=32,
)

In [17]:
# Persist the weight history of the low-correlation "test" run.
with open('low_correlation_test.pkl', 'wb') as history_file:
    pickle.dump(output_data, history_file)

In [None]:
# Inspect the record for outer epoch 0 of the most recent run
# (a dict with keys 'e_item', 'W_old' and 'W_new', as assembled in train()).
output_data[0]