In [2]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import time
import math
import numpy as np
import random
import torch
from torch.nn.parameter import Parameter
from torch.utils.data import Dataset, DataLoader

In [3]:
#This class generates correlated memory vectors as described in (Benna, Fusi; 2021)
class CorrelatedPatterns():
    """Generate families of correlated binary memory vectors.

    ``p`` ancestor vectors of length ``L`` are drawn uniformly from {0, 1}.
    Each ancestor spawns ``k`` descendants: every bit of a descendant is, with
    probability ``1 - gamma``, resampled uniformly from {0, 1}, and otherwise
    copied from the ancestor.

    Attributes:
        ancestors: list of ``p`` NumPy integer arrays, each of shape ``(L,)``.
        descendants: nested list; ``descendants[a][d]`` is a float tensor of
            shape ``(L,)`` holding descendant ``d`` of ancestor ``a``.
        differences: nested list; ``differences[a][d]`` is ancestor ``a``
            minus ``descendants[a][d]``.
        descendants_singlelist: tensor of shape ``(p*k, L)`` with all
            descendants stacked row-wise (ancestor-major order).
        differences_singlelist: tensor of shape ``(p*k, L)`` with the matching
            difference vectors.
    """

    def __init__(self,
                 L, #Length of each memory vector
                 p, #Number of ancestors
                 k, #Number of children per ancestor
                 gamma): #Average overlap between child and ancestor. A value of one means each child is identical to its ancestor,
                        #while a value of zero means each child is completely different from its ancestor.
        self.L = L
        self.p = p
        self.k = k
        self.gamma = gamma

        #Nested (per-ancestor) storage for ancestors, descendants and differences
        self.ancestors = []
        self.descendants = []
        self.differences = []

        #Flat row lists, stacked into single tensors at the end for Dataset use
        descendant_rows = []
        difference_rows = []

        #Descendants are stored in PyTorch's default floating-point dtype
        float_dtype = torch.get_default_dtype()

        for _ancestorIndex in range(p):

            #Each ancestor is initialized randomly
            ancestor = np.random.choice((0,1), size=(L))
            self.ancestors.append(np.array(ancestor))
            ancestor_tensor = torch.tensor(ancestor)

            family_descendants = []
            family_differences = []
            for _descendantIndex in range(k):
                #Collect the bits in a plain list and build the tensor once;
                #growing a tensor with torch.cat per bit is quadratic in L.
                bits = []
                for ancestor_bit in ancestor:
                    #With probability 1-gamma, the descendant memory is corrupted at this bit.
                    if random.uniform(0,1) < 1-gamma:
                        bits.append(0 if random.uniform(0,1) < 0.5 else 1)
                    else: #Otherwise, the ancestor's memory at this bit is copied to the descendant.
                        bits.append(int(ancestor_bit))
                descendant = torch.tensor(bits, dtype=float_dtype)

                #Difference between the ancestor and this child, computed once
                difference = ancestor_tensor - descendant

                family_descendants.append(descendant)
                family_differences.append(difference)
                descendant_rows.append(descendant)
                difference_rows.append(difference)

            self.descendants.append(family_descendants)
            self.differences.append(family_differences)

        if descendant_rows:
            #torch.stack copies, so the rows do not alias the nested lists
            self.descendants_singlelist = torch.stack(descendant_rows)
            self.differences_singlelist = torch.stack(difference_rows)
        else:
            #p == 0 or k == 0: return well-formed empty tensors instead of
            #failing on an empty concatenation.
            self.descendants_singlelist = torch.empty((0, L), dtype=float_dtype)
            self.differences_singlelist = torch.empty((0, L), dtype=float_dtype)

#This subclass inherits the PyTorch Dataset class in order to create datasets of correlated memory.
class SensoryData(Dataset):
    """PyTorch dataset of correlated memories for reconstruction training.

    The samples are the stacked descendants produced by
    ``CorrelatedPatterns(L, p, k, gamma)``. Indexing returns an
    ``(input, target)`` pair in which both entries come from the same stored
    tensor, i.e. the target is the input itself.
    """

    def __init__(self,
                 L,      #Length of each sample
                 p,      #Number of parents
                 k,      #Number of children per parent
                 gamma   #Overlap between parent and children (1=identical, 0=no overlap)
                ):
        super().__init__()

        #Generate the memories and keep only the flat (row-per-sample) tensor
        patterns = CorrelatedPatterns(L, p, k, gamma)
        samples = patterns.descendants_singlelist

        #Inputs and targets are deliberately the same tensor
        self.data = self.x = self.y = samples
        self.n_samples = samples.shape[0]

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the ``(input, target)`` pair stored at ``index``."""
        sample = self.x[index]
        target = self.y[index]
        return sample, target

Ideas to try:
- Fast weights that modify a fixed subset of the main weights only
- Meta-meta fast weights
- Fast weights that gate plasticity instead of set weights (connect with Dendritic Gated Networks)
- Different fast weights modulating the encoder and decoder
- Fast weights that modulate themselves

In [7]:
class FastWeightAE(nn.Module):
    """Autoencoder whose weight matrix is generated from the input itself.

    A learned projection (``fast_weights``) maps each input vector to a full
    ``(n_hiddens, n_inputs)`` weight matrix. That generated matrix encodes the
    same input (followed by ReLU), and its transpose decodes it.

    Args:
        n_inputs: length of each input vector.
        n_hiddens: size of the hidden (encoded) representation.
    """

    def __init__(self,
                 n_inputs=10,
                 n_hiddens=5,
                ):
        super().__init__()

        self.fast_weights = Parameter(torch.rand(n_inputs, n_inputs*n_hiddens), requires_grad=True)
        #Not used in forward, so it never receives a gradient; kept (and
        #initialized in the same order) so existing state dicts, attribute
        #access and seeded initialization stay unchanged.
        self.eweight = Parameter(torch.rand(n_hiddens, n_inputs), requires_grad=True)
        self.n_inputs = n_inputs
        self.n_hiddens = n_hiddens

    def forward(self, X):
        """Reconstruct ``X`` using weights generated from ``X``.

        Args:
            X: tensor whose last dimension is ``n_inputs``. Any leading
                dimensions (none, a batch dimension, ...) are treated as
                independent samples, each with its own generated weights.

        Returns:
            Tensor with the same shape as ``X``.
        """
        batch_shape = X.shape[:-1]

        #Generate one (n_hiddens, n_inputs) weight matrix per sample
        generated_weights = torch.matmul(X, self.fast_weights)
        generated_weights = generated_weights.reshape(*batch_shape, self.n_hiddens, self.n_inputs)

        #Encode: per-sample W @ x, done as a batched matrix-vector product
        encoded = torch.matmul(generated_weights, X.unsqueeze(-1)).squeeze(-1)
        encoded = F.relu(encoded)

        #Decode with the transposed generated weights: per-sample h @ W
        decoded = torch.matmul(encoded.unsqueeze(-2), generated_weights).squeeze(-2)

        return decoded

In [8]:
def train(loader, #Dataloader
          model,  #Model to be trained
          loss_function, #Loss function
          optimizer, #Optimizer
          n_epochs=100 #number of epochs
         ):
    """Train ``model`` and print the mean batch loss after every epoch.

    Args:
        loader: iterable yielding ``(X, y)`` batches; it is iterated once per
            epoch.
        model: module called as ``model(X)``; switched to training mode.
        loss_function: callable ``loss_function(prediction, y)`` returning a
            scalar tensor.
        optimizer: optimizer over the model's parameters.
        n_epochs: number of passes over ``loader``.

    Returns:
        None. One ``Average loss: ...`` line is printed per epoch.
    """

    #Toggle to training mode
    model.train()

    for epoch in range(n_epochs):
        total_loss = 0.0
        n_batches = 0

        #Iterate through the DataLoader's batches
        for X, y in loader:
            #Get the model's prediction of the input
            predicted_y = model(X)

            #Calculate the loss
            loss = loss_function(predicted_y, y)

            #Reset gradients
            optimizer.zero_grad()

            #Backpropagation
            loss.backward()

            #Apply the parameter update
            optimizer.step()

            #Accumulate a plain Python float: summing the loss tensor itself
            #would keep every batch's autograd history alive for the epoch.
            total_loss += loss.item()
            n_batches += 1

        #An empty loader has no meaningful average; report NaN rather than
        #raising ZeroDivisionError.
        average_loss = total_loss / n_batches if n_batches else float('nan')
        print(f'Average loss: {average_loss}')

In [9]:
#Baseline experiment: 5 ancestors x 7 children, vectors of length 10, overlap 0.5
dataset = SensoryData(L=10, p=5, k=7, gamma=0.5)
#One sample per batch
loader = DataLoader(dataset=dataset, batch_size=1)
model = FastWeightAE()
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)
#Reconstruction error between the decoded output and the input
loss_function = nn.MSELoss()

In [None]:
#Train the baseline fast-weight autoencoder
train(
    loader=loader,
    model=model,
    loss_function=loss_function,
    optimizer=optimizer,
    n_epochs=1000,
)

## Experiment 1: Having separate fast weights for the encoder and the decoder

In [18]:
class FastWeightAE_sep(nn.Module):
    """Fast-weight autoencoder with separate encoder and decoder generators.

    Two learned projections map each input vector to two
    ``(n_hiddens, n_inputs)`` matrices: one is used to encode the input
    (followed by ReLU), and the transpose of the other is used to decode.

    Args:
        n_inputs: length of each input vector.
        n_hiddens: size of the hidden (encoded) representation.
    """

    def __init__(self,
                 n_inputs=10,
                 n_hiddens=5,
                ):
        super().__init__()

        self.encoder_fast_weights = Parameter(torch.rand(n_inputs, n_inputs*n_hiddens), requires_grad=True)
        self.decoder_fast_weights = Parameter(torch.rand(n_inputs, n_hiddens*n_inputs), requires_grad=True)
        #Not used in forward, so it never receives a gradient; kept (and
        #initialized in the same order) so existing state dicts, attribute
        #access and seeded initialization stay unchanged.
        self.eweight = Parameter(torch.rand(n_hiddens, n_inputs), requires_grad=True)
        self.n_inputs = n_inputs
        self.n_hiddens = n_hiddens

    def forward(self, X):
        """Reconstruct ``X`` using encoder/decoder weights generated from ``X``.

        Args:
            X: tensor whose last dimension is ``n_inputs``. Any leading
                dimensions are treated as independent samples, each with its
                own generated weights.

        Returns:
            Tensor with the same shape as ``X``.
        """
        weight_shape = (*X.shape[:-1], self.n_hiddens, self.n_inputs)

        #Generate one encoder and one decoder matrix per sample
        encoder_generated_weights = torch.matmul(X, self.encoder_fast_weights).reshape(weight_shape)
        decoder_generated_weights = torch.matmul(X, self.decoder_fast_weights).reshape(weight_shape)

        #Encode: per-sample W_enc @ x, done as a batched matrix-vector product
        encoded = torch.matmul(encoder_generated_weights, X.unsqueeze(-1)).squeeze(-1)
        encoded = F.relu(encoded)

        #Decode with the transposed decoder weights: per-sample h @ W_dec
        decoded = torch.matmul(encoded.unsqueeze(-2), decoder_generated_weights).squeeze(-2)

        return decoded

In [19]:
#Experiment 1 setup: same data configuration, separate encoder/decoder fast weights
dataset = SensoryData(L=10, p=5, k=7, gamma=0.5)
#One sample per batch
loader = DataLoader(dataset=dataset, batch_size=1)
model = FastWeightAE_sep()
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)
#Reconstruction error between the decoded output and the input
loss_function = nn.MSELoss()

In [None]:
#Train the separate encoder/decoder model
train(
    loader=loader,
    model=model,
    loss_function=loss_function,
    optimizer=optimizer,
    n_epochs=1000,
)

## Experiment 2: three networks
Idea:
To avoid exponential growth in the number of parameters, adjust the model so that the fast weights in 
each layer modify only a subset of the weights in the following layer.

In [18]:
class FastWeightAE_sep(nn.Module):
    """Three-level fast-weight autoencoder.

    The input is projected by ``fast_weights_l2`` into a matrix with the shape
    of ``fast_weights_l1``; that generated matrix projects the same input into
    the ``(n_hiddens, n_inputs)`` main weights, which encode the input
    (followed by ReLU) and, transposed, decode it.

    NOTE(review): the previous ``forward`` read ``encoder_fast_weights`` and
    ``decoder_fast_weights``, which this class never defines, so every call
    raised ``AttributeError``. The wiring above is inferred from the parameter
    shapes and from how the generated weights stand in for ``eweight`` in the
    sibling models - confirm it is the intended design.

    Args:
        n_inputs: length of each input vector.
        n_hiddens: size of the hidden (encoded) representation.
    """

    def __init__(self,
                 n_inputs=10,
                 n_hiddens=5,
                ):
        super().__init__()

        self.fast_weights_l2 = Parameter(torch.rand(n_inputs, n_inputs*n_inputs*n_hiddens), requires_grad=True)
        #fast_weights_l1 and eweight fix the shapes of the two generated weight
        #levels; forward replaces both with generated values, so only
        #fast_weights_l2 receives a gradient.
        self.fast_weights_l1 = Parameter(torch.rand(n_inputs, n_inputs*n_hiddens), requires_grad=True)
        self.eweight = Parameter(torch.rand(n_hiddens, n_inputs), requires_grad=True)
        self.n_inputs = n_inputs
        self.n_hiddens = n_hiddens

    def forward(self, X):
        """Reconstruct ``X`` through the generated weight hierarchy.

        Args:
            X: tensor whose last dimension is ``n_inputs``. Any leading
                dimensions are treated as independent samples, each with its
                own generated weights.

        Returns:
            Tensor with the same shape as ``X``.
        """
        batch_shape = X.shape[:-1]
        n_inputs, n_hiddens = self.n_inputs, self.n_hiddens

        #Level 2 -> level 1: one (n_inputs, n_inputs*n_hiddens) matrix per sample
        generated_l1 = torch.matmul(X, self.fast_weights_l2)
        generated_l1 = generated_l1.reshape(*batch_shape, n_inputs, n_inputs*n_hiddens)

        #Level 1 -> main weights: one (n_hiddens, n_inputs) matrix per sample
        generated_weights = torch.matmul(X.unsqueeze(-2), generated_l1).squeeze(-2)
        generated_weights = generated_weights.reshape(*batch_shape, n_hiddens, n_inputs)

        #Run the input through the main network
        encoded = torch.matmul(generated_weights, X.unsqueeze(-1)).squeeze(-1)
        encoded = F.relu(encoded)

        decoded = torch.matmul(encoded.unsqueeze(-2), generated_weights).squeeze(-2)

        return decoded

In [19]:
#Experiment 2 setup: same data configuration, three-level fast-weight model
dataset = SensoryData(L=10, p=5, k=7, gamma=0.5)
#One sample per batch
loader = DataLoader(dataset=dataset, batch_size=1)
model = FastWeightAE_sep()
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)
#Reconstruction error between the decoded output and the input
loss_function = nn.MSELoss()

In [None]:
#Train the three-level model
train(
    loader=loader,
    model=model,
    loss_function=loss_function,
    optimizer=optimizer,
    n_epochs=1000,
)