# Setup


In [None]:
#Imports
import math

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from google.colab import drive
from matplotlib.patches import Rectangle

#Mount Google Drive so the generated figures can be saved under img_dir
drive.mount("/content/drive")

#Use the GPU when one is available, otherwise fall back to the CPU so the notebook still runs
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#Directory on the mounted Drive that the accuracy graphs and heatmaps are written to
img_dir = "/content/drive/MyDrive/ColabNotebooks/RelaxedOriginalCode/MyGraphs"

In [None]:
#Function to download the data and sort it into batches which can be used to train our model
def import_data(data_name):
  """Download a dataset and pre-batch it for training and testing.

  Args:
    data_name: One of "cifar", "fashion" or "mnist".

  Returns:
    A tuple (train_set, test_set, display). train_set and test_set are lists of
    (images, labels) batches of 64 samples in which every image has been
    flattened to a single row. display holds the same test batches with the
    images left in the shape the DataLoader produced.

  Raises:
    ValueError: If data_name is not one of the supported dataset names.
  """
  #Name of the torchvision.datasets class that provides each supported dataset
  dataset_classes = {"cifar": "CIFAR10", "fashion": "FashionMNIST", "mnist": "MNIST"}
  if data_name not in dataset_classes:
    raise ValueError(f"Unknown dataset {data_name!r}; expected one of {sorted(dataset_classes)}")
  dataset_cls = getattr(torchvision.datasets, dataset_classes[data_name])

  #NOTE(review): Drive is mounted at /content/drive in the setup cell, so this root does not appear to be on Drive - confirm that is intended
  data_root = "/content/gdrive/ColabNotebooks/Relaxed Original Code"
  tensor_transform = transforms.Compose([transforms.ToTensor()])
  train_data = dataset_cls(data_root, train=True, download=True, transform=tensor_transform)
  test_data = dataset_cls(data_root, train=False, download=True, transform=tensor_transform)

  #Shuffle into batches of 64; drop_last discards the final incomplete batch so every batch has the same size
  train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=64, shuffle=True, drop_last = True)
  test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=64, shuffle=True, drop_last = True)
  train_batches = list(train_dataloader)
  test_batches = list(test_dataloader)

  #Keep the un-flattened test batches for display before the images are reshaped
  display = list(test_batches)
  #Flatten every image to a 1 dimensional row, one row per sample
  train_set = [(img.reshape(len(img), -1), label) for img, label in train_batches]
  test_set = [(img.reshape(len(img), -1), label) for img, label in test_batches]
  return train_set, test_set, display

In [None]:
#Onehots an array so each class is its own individual index
def onehot(label, class_size):
  """Build a one-hot matrix with one row per entry of label.

  Row i holds a 1 in column label[i] and 0 everywhere else. The result has
  len(label) rows and class_size columns, is cast to float and moved to the
  global device.
  """
  encoded = torch.zeros([len(label), class_size])
  for row, cls in enumerate(label):
    encoded[row, cls] = 1
  return encoded.float().to(device)

In [431]:
#Activation Functions to use for our model
def tanh(matrix):
  """Hyperbolic tangent activation, applied element-wise to a tensor."""
  return matrix.tanh()

def d_tanh(matrix):
  """Derivative of tanh, 1 - tanh(x)^2, applied element-wise."""
  tanh_squared = torch.tanh(matrix).pow(2)
  return 1.0 - tanh_squared

def relu(matrix):
  """Rectified linear activation: negative entries become 0, the rest pass through.

  Returns a new tensor; the input is left unmodified.
  """
  return torch.nn.functional.relu(matrix)

def d_relu(matrix):
  """Derivative of relu: 1 where the rectified input is positive, otherwise the rectified value itself (0)."""
  rectified = relu(matrix)
  #Positive entries become 1; every other entry keeps its rectified value
  return torch.where(rectified > 0, torch.ones_like(rectified), rectified)

def sigm(matrix):
  """Logistic sigmoid activation, applied element-wise."""
  return torch.sigmoid(matrix)

def d_sigm(matrix):
  """Derivative of the sigmoid, s * (1 - s) where s = sigm(matrix)."""
  activated = sigm(matrix)
  complement = 1 - activated
  return activated * complement

#Casts a tensor to float and moves it onto the device
def set_tensor(matrix):
  """Return matrix as a float tensor on the global device."""
  as_float = matrix.float()
  return as_float.to(device)

#Takes the onehotted predicted and true labels and compares them: a row counts as correct when the max value of the prediction is at the same index as the max value of the label
def accuracy(prediction, label):
  """Fraction of rows whose predicted class matches the true class.

  Args:
    prediction: 2-D tensor with one row of class scores per sample.
    label: 2-D tensor with one (one-hot) row per sample, with the same number of rows as prediction.

  Returns:
    The number of matching rows divided by the number of rows, as a Python float.
  """
  batch_size, _ = prediction.shape
  #Compare the arg-max of every row in one vectorised step instead of looping over the batch
  matches = torch.argmax(prediction, dim=1) == torch.argmax(label, dim=1)
  return int(matches.sum()) / batch_size

#Produces a graph of test_data accuracy over a number of epochs and saves it to the directory
def depict(results, func, dset, v_name, o_res):
    """Plot a run's accuracies next to a baseline and save the figure under img_dir.

    Args:
      results: Accuracy values for the run being plotted (presumably one per epoch, as returned by train_net).
      func: Activation function name; used in the saved file name.
      dset: Dataset name; used in the saved file name.
      v_name: Variant name; used in the plot title and the saved file name.
      o_res: pandas object with the baseline accuracies; concatenated as the column(s) after results.
    """
    accuracy_df = pd.concat((pd.DataFrame(results), o_res), axis=1)
    sns.set(rc={"figure.figsize":(5,3)})
    accuracy_df.plot()
    plt.title(f"{v_name} Accuracy")
    #DataFrame.plot puts the row index (the epoch number) on the x axis and the accuracy values on the y axis
    plt.xlabel("Epochs")
    plt.ylabel("Accuracy")
    plt.legend(["Relaxed", "Original"])
    plt.savefig(f"{img_dir}/{v_name}_{dset}_{func}_Accuracy_Graph.png")
    plt.show()

In [None]:
class Pred_Layer():
  """One layer of the predictive coding network.

  Holds a forward weight matrix, a separate backward weight matrix and a bias.
  forward predicts the layer above, backward passes errors down to the layer
  below, and update_weights applies a local update to all three parameters.
  """

  def __init__(self, input_size, output_size, batch_size=64, learning_rate=0.0005, activation_function=tanh, act_func_deriv=d_tanh, use_back_weights = True, use_non_linear = False):
    """Create the layer parameters on the global device.

    Args:
      input_size: Width of the layer's input.
      output_size: Width of the layer's output.
      batch_size: Number of rows in the bias tensor.
      learning_rate: Step size used by update_weights.
      activation_function: Function applied to the pre-activation in forward.
      act_func_deriv: Derivative of activation_function.
      use_back_weights: If True, backward uses back_weights; otherwise the transposed forward weights.
      use_non_linear: If True, backward scales the errors by the activation derivative first.
    """
    #Prepare all layer specific variables
    self.input_size = input_size
    self.output_size = output_size
    self.learning_rate = learning_rate
    self.activation_function = activation_function
    self.act_func_deriv = act_func_deriv
    #NOTE(review): the bias has one row per batch position rather than one shared row - confirm this is intended
    self.bias = torch.zeros([batch_size, self.output_size]).to(device)
    self.use_back_weights = use_back_weights
    self.use_non_linear = use_non_linear
    #Prepare the random weights
    self.weights = torch.empty([self.input_size,self.output_size]).normal_(mean=0.0,std=0.05).to(device) #Forward fed weight matrix
    self.back_weights = torch.empty([self.output_size, self.input_size]).normal_(mean=0.0,std=0.05).to(device) #Backward fed weight matrix

  def forward(self, inputs):
    """Calculate the prediction for the next layer.

    Stores a copy of the inputs and the pre-activation on the layer because
    backward and update_weights read them afterwards.
    """
    self.inp = inputs.clone()
    self.activation = torch.matmul(self.inp,self.weights)
    return self.activation_function(self.activation) + self.bias

  def backward(self, errors):
    """Pass errors from the layer above back into the current layer using weights.

    Uses the pre-activation stored by the most recent forward call when
    use_non_linear is set.
    """
    #Only evaluate the activation derivative when it is actually applied
    signal = errors*self.act_func_deriv(self.activation) if self.use_non_linear else errors
    feedback = self.back_weights if self.use_back_weights else self.weights.T
    return torch.matmul(signal, feedback)

  def update_weights(self, errors):
    """Update the forward weights, backward weights and bias in place from errors.

    Relies on the inputs and pre-activation stored by the most recent forward call.
    """
    #Errors scaled by the activation derivative, shared by both weight updates
    delta = errors*self.act_func_deriv(self.activation)
    DW_back = torch.matmul(delta.T,self.inp) #change in backward weights
    DW_forward = torch.matmul(self.inp.T, delta) #change in forward weights
    self.weights += self.learning_rate * torch.clamp(DW_forward*2,-50, 50) #clamped so that it doesn't grow too far
    self.back_weights += self.learning_rate * torch.clamp(DW_back*2,-50, 50)
    #The bias step uses the raw errors, not the derivative-scaled ones
    self.bias += self.learning_rate * torch.clamp(errors,-50,50)

In [436]:
class Relaxed_Predictive_Net():
  """Predictive coding network assembled from a list of layer objects.

  The network keeps one value tensor (mu) per layer. For every batch it runs an
  iterative inference phase that moves the hidden mus using the prediction
  errors, then applies the layers' local weight updates. When fully_connected
  is True each layer's prediction is multiplied by a square "error weight"
  matrix before the error is taken, and those matrices are updated as well.
  """

  def __init__(self, layers, fn, dataset, n_steps=500, inf_learning_rate=0.01, weight_learning_rate=0.0005, fully_connected = True):
    """Store the configuration and create the error weight matrices.

    Args:
      layers: List of layer objects, ordered from input to output.
      fn: Activation function name; used in saved heatmap file names.
      dataset: Dataset name; used in saved heatmap file names.
      n_steps: Number of inference iterations run per batch.
      inf_learning_rate: Step size for the mu updates during inference.
      weight_learning_rate: Step size for the error weight updates.
      fully_connected: If True, use and train the error weight matrices.
    """
    self.layers = layers
    self.n_steps = n_steps
    self.inf_learning_rate = inf_learning_rate
    self.weight_learning_rate = weight_learning_rate
    self.fully_connected = fully_connected
    self.l_num = len(self.layers)
    #Per-layer state; index 0 is the input layer, index l_num the output layer
    self.guess = [[] for i in range(self.l_num+1)]
    self.pred_e = [[] for i in range(self.l_num+1)]
    self.down_error = [[] for i in range(self.l_num+1)]
    self.mus = [[] for i in range(self.l_num+1)]
    self.fn = fn
    self.dataset = dataset
    self.error_weights = []
    #Generate the full connection matrices. The identity term is scaled by 0.0, so each matrix starts as the absolute value of N(0, 0.05) noise
    for i,l in enumerate(self.layers):
      error_weight = ((0.0 * set_tensor(torch.eye(l.input_size))) + set_tensor(torch.empty([l.input_size, l.input_size]).normal_(mean=0.0, std=0.05))).to(device)
      error_weight = torch.abs(error_weight)
      self.error_weights.append(error_weight)
    #The output layer uses a plain identity matrix
    self.error_weights.append(set_tensor(torch.eye(self.layers[-1].output_size)))

  def forward_pass(self, x):
    """Get a prediction for the given input by feeding it through every layer in order."""
    with torch.no_grad():
      for layer in self.layers:
        x = layer.forward(x)
      return x

  def heat_map(self, preds, epoch, correct_vals, name):
    """Draw, save and show a heatmap of the predictions for one batch of images.

    Args:
      preds: CPU tensor with one row of class scores per image.
      epoch: Zero-based epoch index; shown as epoch+1 in the title and used in the file name.
      correct_vals: The correct class index for each image, in the same order as preds.
      name: Variant name; used in the title and the file name.

    A cyan rectangle marks the predicted class of every image and a green
    rectangle marks the correct class.
    """
    dframe = pd.DataFrame(preds.numpy())
    dframe.index += 1
    dframe = dframe.T
    sns.set(rc={"figure.figsize":(32,4.5)})
    ax = sns.heatmap(dframe, vmin = 0)
    plt.xlabel("Images")
    plt.ylabel("Predictions for labels")
    plt.title("Predictions for a batch after " + str(epoch+1) + " Epochs with " + name)
    pred_vals = torch.argmax(preds[:], dim=1)
    for num, val in enumerate(pred_vals):
      ax.add_patch(Rectangle((num, val), 1,1, fill=False, edgecolor='cyan', lw=5))
    for num, val in enumerate(correct_vals):
      ax.add_patch(Rectangle((num, val), 1,1, fill=False, edgecolor='green', lw=3))
    plt.savefig(f"{img_dir}/{name}_{self.dataset}_{self.fn}_{epoch}_Heatmap.png")
    plt.show()

  def test_accuracy(self, test):
    """Return (mean accuracy over all batches in test, predictions for the last batch)."""
    accs = []
    #Read the class count from this network's own layers rather than a notebook global
    n_classes = self.layers[-1].output_size
    for inputs, label in test:
      pred = self.forward_pass(inputs.to(device))
      accs.append(accuracy(pred, onehot(label, n_classes)))
    return np.mean(np.array(accs)), pred

  def update_error_weights(self):
    """Update the fully connected error weight matrices of the hidden layers."""
    for i in range(1,self.l_num):
      d_error_connection = torch.matmul(self.guess[i].T,self.pred_e[i])
      #NOTE(review): this step is subtracted whereas Pred_Layer.update_weights adds its analogous term - confirm the sign is intended
      self.error_weights[i] = self.error_weights[i] - (self.weight_learning_rate * torch.clamp(d_error_connection*2,-1000, 1000))

  def update_weights(self):
    """Update the weights on each layer using the prediction error of the layer above it."""
    for (i, l) in enumerate(self.layers):
      l.update_weights(self.pred_e[i+1])

  def _update_prediction_errors(self):
    """Recompute the prediction error of every layer except the input layer."""
    for i in range(1, self.l_num+1):
      if self.fully_connected:
        #Error between the current layer mu and guess * connection weights
        self.pred_e[i] = self.mus[i] - (self.guess[i] @ self.error_weights[i])
      else:
        self.pred_e[i] = self.mus[i] - self.guess[i]

  def equalise_net(self):
    """Perform inference for the value neurons, refreshing the prediction errors every step.

    The guesses are computed once from the initial mus and are not recomputed
    while the mus are being updated.
    """
    #Initial error setup
    for i in range(1, self.l_num+1): #Not including input layer
      self.guess[i] = self.layers[i-1].forward(self.mus[i-1]) #Guess for current layer from layer below
    self._update_prediction_errors()
    for n in range(self.n_steps):
      for layer in range(1, self.l_num): #Hidden layers only; input and output mus stay fixed
        self.down_error[layer] = self.layers[layer].backward(self.pred_e[layer+1])
        #Difference between layers
        d_l = self.pred_e[layer] - self.down_error[layer]
        self.mus[layer] -= self.inf_learning_rate * torch.clamp(d_l*2,-1000, 1000)
      self._update_prediction_errors()

  def learn(self, input, label, epoch):
    """Train the network on a single batch.

    Args:
      input: Batch of flattened images.
      label: One-hot labels for the batch.
      epoch: Current epoch index (not used by this method).
    """
    with torch.no_grad():
      #Initialise the input to the network as the image
      self.mus[0] = input.clone()
      #Pass the mus up the layers to initialise
      for i in range(1,self.l_num):
        self.mus[i] = self.layers[i-1].forward(self.mus[i-1])
      #Set the top layer equal to the correct outputs
      self.mus[-1] = label.clone()
      self.equalise_net()
      #Update the weights after each batch
      self.update_weights()
      if self.fully_connected:
        self.update_error_weights()

  def train_net(self, training, testing, test_name, n_epochs=1):
    """Train on pre-batched data and return the test accuracy after every epoch.

    Args:
      training: List of (inputs, labels) training batches.
      testing: List of (inputs, labels) test batches.
      test_name: Variant name passed on to heat_map.
      n_epochs: Number of passes over the training batches.

    Returns:
      A list with the mean test accuracy after each epoch (empty when n_epochs <= 0).
    """
    with torch.no_grad():
      test_accs = []
      if n_epochs <= 0:
        #Nothing was trained, so there are no predictions to draw a final heatmap from
        return test_accs
      #Read the class count from this network's own layers rather than a notebook global
      n_classes = self.layers[-1].output_size
      for e in range(n_epochs):
        #Decay all learning rates by 0.75 at the start of every epoch
        self.weight_learning_rate = self.weight_learning_rate*0.75
        for l in self.layers:
          l.learning_rate = l.learning_rate*0.75
        print(e)
        for inp, label in training: #Trains the model on each batch
          self.learn(inp.to(device), onehot(label, n_classes).to(device), e)
        test_mean, predictions_hm = self.test_accuracy(testing) #Test the accuracy of the model using the test data after each epoch
        test_accs.append(test_mean)
        if e%5 == 0:
            self.heat_map(predictions_hm.to("cpu"), e, testing[-1][1], test_name) #Generates heatmap of the last batch of the test data every 5 epochs
      self.heat_map(predictions_hm.to("cpu"), e, testing[-1][1], test_name) #Generates heatmap at the end of the training.
      return test_accs

# Testing

Trains and evaluates every variant of the predictive coding model on all three datasets, using each of the three activation functions we chose.

In [None]:
#Each entry: (activation function, its derivative, display name)
functions = [(tanh, d_tanh, "Tanh"), (relu, d_relu, "reLU") , (sigm, d_sigm, "Sigmoid")]
datasets = ["cifar", "fashion","mnist"]
#Format: (Separate backward weights, Backward non-linearities, Full connections, Display name)
variables = [(False, True, False, "Original Predictive Network"), (False, False, False, "No backward non-linearitys"), (True, True, False, "Seperate Backward Weights"), (False, True, True, "Fully connected layers"), (True, False, True, "Combined Relaxed Network")]
output_size = 10
num_epochs = 30
#Length of one flattened image for each dataset
input_sizes = {"cifar": 3072, "fashion": 784, "mnist": 784}
origin_pred = []
#Accuracies of the original network keyed by (dataset, activation name), so every relaxed
#variant is compared against the baseline trained on the same dataset with the same activation
origin_preds = {}
for v in variables:
  for d in datasets:
    if d not in input_sizes:
      print("Not a valid dataset!")
      break
    input_size = input_sizes[d]
    training, testing, display = import_data(d)
    for (f, d_f, name) in functions:
      #Initialise layers
      l1 = Pred_Layer(input_size, 500, activation_function=f, act_func_deriv=d_f, use_back_weights = v[0], use_non_linear = v[1])
      l2 = Pred_Layer(500, 200, activation_function=f, act_func_deriv=d_f, use_back_weights = v[0], use_non_linear = v[1])
      l3 = Pred_Layer(200, 100, activation_function=f, act_func_deriv=d_f, use_back_weights = v[0], use_non_linear = v[1])
      l4 = Pred_Layer(100, output_size, activation_function=f, act_func_deriv=d_f, use_back_weights = v[0], use_non_linear = v[1])
      layers = [l1,l2,l3,l4]
      net = Relaxed_Predictive_Net(layers, name, d, fully_connected = v[2])
      results = net.train_net(training, testing, v[3], num_epochs)
      #Record or plot inside the activation loop so no trained run is discarded
      if v[3] == "Original Predictive Network":
        origin_pred = pd.DataFrame(results)
        origin_preds[(d, name)] = origin_pred
      else:
        depict(results, name, d, v[3], origin_preds[(d, name)])