# Load data

In [None]:
# Experiments with performance and improvements to Feedback Alignment (Lillicrap 2016) and Direct Feedback Alignment (DFA) (Nokland 2016)
import numpy as np
import matplotlib.pyplot as plt
import torch 
import torchvision
import torchvision.transforms as transforms
from copy import deepcopy
import torch.nn as nn
import torch.nn.functional as F
import math

# Run on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): num_batches and num_train_batches are not referenced anywhere
# else in the visible notebook -- confirm whether they are still needed.
num_batches= 10
num_train_batches=20
# Mini-batch size used by both the train and the test DataLoader below.
batch_size = 64

# Only ToTensor is applied; the Normalize step is commented out.
transform = transforms.Compose([transforms.ToTensor()])#, transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))])


# MNIST training split, downloaded into ./data on first use; batches are shuffled.
trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=1)

# MNIST test split, iterated in a fixed (unshuffled) order.
testset = torchvision.datasets.MNIST(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=1)


def imshow(img):
    """Display a channels-first image tensor with matplotlib.

    The tensor is rescaled with ``img / 2 + 0.5``, converted to a NumPy array,
    moved to channels-last layout and shown.

    NOTE(review): the rescaling undoes a Normalize(0.5, 0.5) transform, but the
    Normalize step in this notebook's transform is commented out -- confirm the
    rescaling is still wanted.
    """
    rescaled = img / 2 + 0.5     # unnormalize
    channels_last = np.transpose(rescaled.numpy(), (1, 2, 0))
    plt.imshow(channels_last)
    plt.show()


def onehot(x, num_classes=10):
    """One-hot encode integer class labels.

    Args:
        x: 1-D sequence or tensor of integer class indices.
        num_classes: width of the encoding; defaults to 10, the value that was
            previously hard-coded.

    Returns:
        A float32 tensor of shape ``(len(x), num_classes)`` on ``DEVICE`` with a
        single 1 per row at the label's index.
    """
    # Move the labels to the target device first so that the index tensor and
    # the indexed tensor always live on the same device.
    idx = torch.as_tensor(x, dtype=torch.long, device=DEVICE).reshape(-1)
    z = torch.zeros((idx.numel(), num_classes), dtype=torch.float32, device=DEVICE)
    # One vectorised assignment replaces the per-sample Python loop.
    z[torch.arange(idx.numel(), device=DEVICE), idx] = 1.0
    return z

# Materialise one full pass over the (shuffled) train loader as a list of
# (images, labels) batches, so the batch order is fixed from here on.
dataset = list(iter(trainloader))
for i,(img, label) in enumerate(dataset):
  # Flatten each image in the batch to a 784-vector and divide by 255.
  # NOTE(review): torchvision's ToTensor is documented to already scale pixels
  # to [0, 1]; the extra /255 shrinks inputs to roughly [0, 0.004] -- confirm
  # this is intended before changing it, since the learning rates were tuned with it.
  dataset[i] = (img.reshape(len(img),784) /255 ,label)

# Keep the first batch around; the numerical-check helpers read these globals.
images, labels = dataset[0]
print("IMAGES: ", images.shape)
print("LABELS: ", labels.shape)


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw
Processing...
Done!








IMAGES:  torch.Size([64, 784])
LABELS:  torch.Size([64])


# Functions

In [None]:

# Prefer the GPU when CUDA is available; everything else runs on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def set_tensor(xs):
  """Return ``xs`` as a float32 tensor placed on ``DEVICE``."""
  as_float = xs.to(torch.float32)
  return as_float.to(DEVICE)

def tanh(xs):
    """Element-wise hyperbolic tangent of the tensor ``xs``."""
    return xs.tanh()

def linear(x):
    """Identity activation: return ``x`` unchanged."""
    return x

def tanh_deriv(xs):
    """Derivative of tanh evaluated at ``xs``: ``1 - tanh(xs)^2``."""
    squared = torch.pow(torch.tanh(xs), 2.0)
    return 1.0 - squared

def linear_deriv(x):
    """Derivative of the identity activation.

    The argument is ignored; the result is a one-element tensor holding 1.0
    (passed through ``set_tensor``), which callers multiply into errors via
    broadcasting.
    """
    unit = torch.ones(1)
    return set_tensor(unit)

def relu(xs):
  """Rectified linear unit: clamp every element of ``xs`` below at zero."""
  return xs.clamp(min=0)

def relu_deriv(xs):
  """Derivative of ReLU at ``xs``: 1 where ``relu(xs)`` is positive, else its value (0)."""
  activated = relu(xs)
  # Replace every strictly positive activation with 1; other entries are kept.
  return activated.masked_fill(activated > 0, 1)

def softmax(xs):
  """Softmax of ``xs`` along its last dimension.

  The previous body called ``torch.nn.softmax``, which does not exist
  (``torch.nn`` only exposes the ``Softmax`` module class), so every call
  raised ``AttributeError``.

  Returns:
      A tensor of the same shape as ``xs`` whose entries along the last
      dimension are non-negative and sum to 1.
  """
  return torch.softmax(xs, dim=-1)

def sigmoid(xs):
  """Element-wise logistic sigmoid of ``xs``."""
  return torch.sigmoid(xs)

def sigmoid_deriv(xs):
  """Derivative of the logistic sigmoid at ``xs``: ``s * (1 - s)`` with ``s = sigmoid(xs)``."""
  s = torch.sigmoid(xs)
  complement = torch.ones_like(xs) - s
  return s * complement
   
def edge_zero_pad(img,d):
  """Surround a 4-D image batch with a zero border of width ``d``.

  ``img`` is unpacked as ``(N, C, h, w)``; the result has shape
  ``(N, C, h + 2*d, w + 2*d)``, lives on ``DEVICE``, and holds ``img`` in its
  centre.
  """
  batch, channels, height, width = img.shape
  padded = torch.zeros((batch, channels, height + 2 * d, width + 2 * d)).to(DEVICE)
  # Copy the original pixels into the interior; the border stays zero.
  padded[..., d:d + height, d:d + width] = img
  return padded


def accuracy(out, L):
  """Fraction of rows whose arg-max in ``out`` matches the arg-max in ``L``.

  Args:
      out: 2-D tensor of network outputs, one row per sample.
      L: 2-D tensor of targets (e.g. one-hot labels) with the same number of rows.

  Returns:
      A Python float in [0, 1].

  Raises:
      ZeroDivisionError: if ``out`` has zero rows (unchanged from before).
  """
  B, _ = out.shape
  # Compare all rows at once instead of looping in Python; on GPU this costs a
  # single host synchronisation rather than one per sample.
  matches = torch.argmax(out, dim=1) == torch.argmax(L[:B], dim=1)
  return int(matches.sum().item()) / B

#DEVICE="cpu"

# FC Layer

In [None]:

class FCLayer(object):
  """Fully connected layer with hand-written forward, backward and update rules.

  The layer computes ``f(x @ weights)``. Errors are sent backwards either
  through ``weights.T`` or through a separate ``backward_weights`` matrix
  (``use_backwards_weights=True``), optionally gated by ``df(activations)``
  (``use_backwards_nonlinearities=True``).

  Args:
      input_size: number of input features.
      output_size: number of output features.
      batch_size: stored on the instance; not otherwise used by this class.
      learning_rate: step size applied in the weight updates.
      f: activation function applied to the pre-activations.
      df: derivative of ``f``, evaluated at the pre-activations.
      use_backwards_weights: send errors back through ``backward_weights``
          instead of ``weights.T``.
      update_backwards_weights: also train ``backward_weights`` in
          ``update_weights``.
      use_backwards_nonlinearities: multiply errors by ``df(activations)``.
      device: device on which the weight tensors are created.
  """
  def __init__(self, input_size,output_size,batch_size, learning_rate,f,df,use_backwards_weights=True,update_backwards_weights=False, use_backwards_nonlinearities=True,device="cpu"):
    self.input_size = input_size
    self.output_size = output_size
    self.batch_size = batch_size
    self.learning_rate = learning_rate
    self.f = f
    self.df = df
    self.device = device
    self.use_backwards_weights = use_backwards_weights
    self.update_backwards_weights = update_backwards_weights
    self.use_backwards_nonlinearities = use_backwards_nonlinearities
    # Forward weights: (input_size, output_size), drawn from N(0, 0.05^2).
    self.weights = torch.empty([self.input_size,self.output_size]).normal_(mean=0.0,std=0.05).to(self.device)
    if self.use_backwards_weights:
      # Independent feedback weights: (output_size, input_size).
      self.backward_weights = torch.empty([self.output_size,self.input_size]).normal_(mean=0.0,std=0.05).to(self.device)

  def forward(self,x):
    """Store the input and pre-activations, and return ``f(x @ weights)``."""
    self.inp = x.clone()
    self.activations = torch.matmul(self.inp, self.weights)
    return self.f(self.activations)

  def backward(self,e):
    """Propagate the output-side error ``e`` to the layer input.

    Uses the activations stored by the last ``forward`` call. The result is
    clamped to [-50, 50].
    """
    self.fn_deriv = self.df(self.activations)
    local_e = e * self.fn_deriv if self.use_backwards_nonlinearities else e
    feedback = self.backward_weights if self.use_backwards_weights else self.weights.T
    return torch.clamp(torch.matmul(local_e, feedback),-50,50)

  def _step_weights(self, dw):
    """Apply one in-place step to ``weights`` using the clamped, doubled ``dw``."""
    self.weights -= self.learning_rate * torch.clamp(dw*2,-50,50)

  def DFA_update(self, e, update_weights=True):
    """Compute the weight gradient for the (already projected) DFA error ``e``.

    Args:
        e: error signal with the same shape as this layer's output.
        update_weights: apply the step to ``weights`` when True. Defaults to
            True so that direct callers keep the previous always-update
            behaviour.

    Returns:
        The unscaled gradient ``inp.T @ e`` (with ``e`` gated by ``df`` when
        ``use_backwards_nonlinearities`` is set).
    """
    if self.use_backwards_nonlinearities:
      self.fn_deriv = self.df(self.activations)
      dw = torch.matmul(self.inp.T, e * self.fn_deriv)
    else:
      dw = torch.matmul(self.inp.T, e)
    if update_weights:
      self._step_weights(dw)
    return dw

  def update_weights(self,e,update_weights=False,DFA=False):
    """Compute, and optionally apply, the weight gradient for error ``e``.

    Args:
        e: error signal with the same shape as this layer's output.
        update_weights: apply the step in place when True; otherwise the
            gradient is only returned.
        DFA: delegate to ``DFA_update``. The ``update_weights`` flag is now
            forwarded; previously it was ignored on this path, so the weights
            changed even when the caller passed ``update_weights=False``.

    Returns:
        The unscaled gradient ``dw`` for ``weights``.
    """
    if DFA is True:
      return self.DFA_update(e, update_weights=update_weights)
    self.fn_deriv = self.df(self.activations)
    local_e = e * self.fn_deriv if self.use_backwards_nonlinearities else e
    dw = torch.matmul(self.inp.T, local_e)
    if update_weights:
      self._step_weights(dw)
      if self.use_backwards_weights and self.update_backwards_weights:
        # The feedback matrix is the transpose shape of `weights`, so its
        # gradient is the transposed outer product.
        delta = torch.matmul(local_e.T, self.inp)
        self.backward_weights -= self.learning_rate * torch.clamp(delta*2,-50,50)
    return dw

  def get_true_weight_grad(self):
    """Return ``weights.grad``, as populated by autograd."""
    return self.weights.grad

  def set_weight_parameters(self):
    """Wrap ``weights`` in ``nn.Parameter`` so autograd tracks gradients for it."""
    self.weights = nn.Parameter(self.weights)


# PCNet

In [None]:
class PCNet(object):
  """Predictive-coding style network built from a list of layer objects.

  Each call to ``infer`` runs a forward pass, clamps the top-level state to the
  label, iteratively relaxes the per-layer states ``mus`` for a number of
  inference steps, and then updates the layer weights (and optionally the
  error weights) from the resulting prediction errors.

  Args:
      layers: layer objects exposing ``input_size``, ``forward``, ``backward``,
          ``update_weights`` and ``set_weight_parameters``.
      n_inference_steps_train: default number of inference iterations.
      inference_learning_rate: step size for the ``mus`` updates.
      weight_learning_rate: step size for the error-weight updates.
      use_error_weights: use randomly initialised, learned error weights
          instead of fixed identity matrices.
      device: stored on the instance; tensors are placed via the module-level
          ``set_tensor`` / ``DEVICE``.
  """
  def __init__(self, layers, n_inference_steps_train, inference_learning_rate, weight_learning_rate,use_error_weights=False,device='cpu'):
    self.layers= layers
    self.n_inference_steps_train = n_inference_steps_train
    self.inference_learning_rate = inference_learning_rate
    self.weight_learning_rate = weight_learning_rate
    self.device = device
    self.L = len(self.layers)
    # One slot per layer boundary (input, hidden..., output).
    self.outs = [[] for i in  range(self.L+1)]
    self.prediction_errors = [[] for i in range(self.L+1)]
    self.predictions = [[] for i in range(self.L+1)]
    self.mus = [[] for i in range(self.L+1)]
    self.use_error_weights = use_error_weights
    self.error_weights = []
    for l in self.layers:
      if self.use_error_weights:
        # Random N(0, 0.05^2) square matrix; the 0.0 * eye term contributes nothing.
        error_weight = (0.0 * set_tensor(torch.eye(l.input_size))) + set_tensor(torch.empty([l.input_size, l.input_size]).normal_(mean=0.0, std=0.05))
      else:
        error_weight = set_tensor(torch.eye(l.input_size))
      self.error_weights.append(error_weight)
    for l in self.layers:
      l.set_weight_parameters()

  def update_weights(self,print_weight_grads=True,get_errors=False):
    """Update every layer from the prediction error at its output.

    ``print_weight_grads`` and ``get_errors`` are accepted but unused.

    Returns:
        An empty list (the weight-difference diagnostics are disabled).
    """
    weight_diffs = []
    for (i,l) in enumerate(self.layers):
      l.update_weights(self.prediction_errors[i+1],update_weights=True)
    return weight_diffs

  def update_error_weights(self):
    """Update the error weights of every layer except the first, in place.

    The step is ``weight_learning_rate * clamp(mus[i].T @ prediction_errors[i], -50, 50)``.
    The previous debug print of each full error-weight matrix was removed; it
    ran once per layer per batch and flooded the cell output.
    """
    for (i,l) in enumerate(self.layers):
      if i != 0:
        error_connection_delta = torch.matmul(self.mus[i].T, self.prediction_errors[i])
        self.error_weights[i] -=  self.weight_learning_rate * torch.clamp(error_connection_delta,-50,50)

  def forward(self,x):
    """Run ``x`` through all layers and return the final output."""
    for l in self.layers:
      x = l.forward(x)
    return x

  def no_grad_forward(self,x):
    """Same as ``forward`` but executed under ``torch.no_grad()``."""
    with torch.no_grad():
      for l in self.layers:
        x = l.forward(x)
      return x

  def infer(self, inp,label,n_inference_steps=None):
    """Run inference on one batch and then update the weights.

    Args:
        inp: input batch.
        label: target batch; the top-level state is clamped to it.
        n_inference_steps: when not None, replaces (and persists as)
            ``n_inference_steps_train``.

    Returns:
        ``(L, acc, weight_diffs)``: the summed squared output prediction error,
        the accuracy of a fresh forward pass after the update, and the (empty)
        list returned by ``update_weights``.
    """
    self.n_inference_steps_train = n_inference_steps if n_inference_steps is not None else self.n_inference_steps_train
    with torch.no_grad():
      # Forward sweep: initialise every state with the feed-forward prediction.
      self.mus[0] = inp.clone()
      self.outs[0] = inp.clone()
      for i,l in enumerate(self.layers):
        self.mus[i+1] = l.forward(self.mus[i])
        self.outs[i+1] = self.mus[i+1].clone()
      # Clamp the top state to the label and form the output error.
      self.mus[-1] = label.clone()
      self.prediction_errors[-1] = self.mus[-1] - self.outs[-1]
      self.predictions[-1] = self.prediction_errors[-1].clone()
      for n in range(self.n_inference_steps_train):
        for j in reversed(range(len(self.layers))):
          # An earlier alternative assignment to prediction_errors[j] was
          # immediately overwritten by this one and has been dropped.
          self.prediction_errors[j] = torch.matmul(self.mus[j],self.error_weights[j]) -self.outs[j]
          self.predictions[j] = self.layers[j].backward(self.prediction_errors[j+1])
          dx_l = self.prediction_errors[j] - self.predictions[j]
          self.mus[j] -= self.inference_learning_rate * (2*dx_l)

      weight_diffs = self.update_weights()
      if self.use_error_weights:
        self.update_error_weights()
      L = torch.sum(self.prediction_errors[-1]**2).item()
      acc = accuracy(self.no_grad_forward(inp),label)
      return L,acc,weight_diffs

  def train(self,dataset,n_epochs,n_inference_steps):
    """Train for ``n_epochs`` passes over ``dataset``, printing loss and accuracy.

    Args:
        dataset: iterable of ``(inputs, integer_labels)`` batches.
        n_epochs: number of passes over ``dataset``.
        n_inference_steps: inference iterations per batch. This is now
            forwarded to ``infer``; previously the argument was silently
            ignored. Pass None to keep ``n_inference_steps_train``.
    """
    for epoch in range(n_epochs):
      print("Epoch: ", epoch)
      for i,(inp, label) in enumerate(dataset):
        L, acc,weight_diffs = self.infer(inp.to(DEVICE),onehot(label).to(DEVICE),n_inference_steps=n_inference_steps)
        print("Epoch: " + str(epoch) + " batch: " + str(i))
        print("Loss: ", L)
        print("Acc: ", acc)






# Backprop Net

In [None]:
class BackpropNet(object):
  """Feed-forward network trained with a hand-written backward pass.

  Args:
      layers: layer objects exposing ``forward``, ``backward`` and
          ``update_weights`` (plus ``set_weight_parameters`` and
          ``get_true_weight_grad`` for the numerical check).
      numerical_check: when True, wrap each layer's weights as parameters so
          autograd gradients can be compared with the manual ones.
      device: stored on the instance; ``train`` moves data to the module-level
          ``DEVICE``.
  """
  def __init__(self, layers,numerical_check=False,device="cpu"):
    self.layers = layers
    self.device = device
    # xs[i] is the input to layer i; xs[-1] is the network output.
    self.xs = [[] for i in range(len(self.layers)+1)]
    # e_ys[i+1] is the error at the output of layer i; e_ys[0] is the input error.
    self.e_ys = [[] for i in range(len(self.layers)+1)]
    if numerical_check:
      for l in self.layers:
        l.set_weight_parameters()

  def forward(self, inp):
    """Run ``inp`` through every layer, caching intermediate outputs in ``xs``."""
    self.xs[0] = inp
    for i,l in enumerate(self.layers):
      self.xs[i+1] = l.forward(self.xs[i])
    return self.xs[-1]

  def backward(self,e_y):
    """Propagate the output error ``e_y`` back through all layers.

    Returns:
        The error with respect to the network input (``e_ys[0]``).
    """
    self.e_ys[-1] = e_y
    for (i,l) in reversed(list(enumerate(self.layers))):
      self.e_ys[i] = l.backward(self.e_ys[i+1])
    return self.e_ys[0]

  def update_weights(self,print_weight_grads=False,update_weight=False):
    """Compute (and optionally apply) each layer's weight gradient.

    Args:
        print_weight_grads: print the manual gradient next to the autograd one.
        update_weight: apply the update in place when True.

    Returns:
        The list of per-layer gradients returned by the layers. Previously this
        method returned None even though callers assign its result.
    """
    grads = []
    for (i,l) in enumerate(self.layers):
      dW = l.update_weights(self.e_ys[i+1],update_weights=update_weight)
      if print_weight_grads:
        print("weight grads : ", i)
        print("dW: ", dW*2)
        print("weight grad: ",l.get_true_weight_grad())
      grads.append(dW)
    return grads

  def train(self, dataset,n_epochs):
    """Train for ``n_epochs`` passes over ``dataset``, printing loss and accuracy per batch.

    ``dataset`` is an iterable of ``(inputs, integer_labels)`` batches. The
    whole loop runs under ``torch.no_grad()`` because the updates are manual.
    """
    with torch.no_grad():
      for n in range(n_epochs):
        print("Epoch: ",n)
        for (inp,label) in dataset:
          out = self.forward(inp.to(DEVICE))
          label = onehot(label).to(DEVICE)
          e_y = out - label
          self.backward(e_y)
          self.update_weights(update_weight=True)
          print("Loss: ", torch.sum(e_y**2))
          print("Accuracy: ", accuracy(out,label))



def pytorch_tutorial_net_numerical_check():
  """Compare the hand-written backward pass of ``BackpropNet`` with autograd.

  Builds a 784-300-100-10 tanh network on the first cached batch
  (``images`` / ``labels``), back-propagates a summed squared error with
  autograd, then prints the autograd gradients next to the manual ones for the
  input and for every weight matrix. Nothing is returned and no weights are
  updated.
  """
  print("beginning numerical check")
  lr = 0.001
  inp = nn.Parameter(images.to(DEVICE))
  l1 = FCLayer(784,300,64,lr,tanh,tanh_deriv,use_backwards_weights= False, use_backwards_nonlinearities=True,device=DEVICE)
  l2 = FCLayer(300,100,64,lr,tanh,tanh_deriv,use_backwards_weights= False, use_backwards_nonlinearities=True,device=DEVICE)
  # The output layer applies tanh in the forward pass, so the manual derivative
  # must be tanh_deriv too; with linear_deriv the manual gradients could never
  # match what autograd computes for the same forward pass.
  l3 = FCLayer(100,10,64,lr,tanh,tanh_deriv,use_backwards_weights= False, use_backwards_nonlinearities=True,device=DEVICE)
  layers =[l1,l2,l3]
  net = BackpropNet(layers,numerical_check=True,device=DEVICE)
  out = net.forward(inp)
  print(out.shape)
  true_out = onehot(labels).to(DEVICE)
  print(true_out.shape)
  L = torch.sum((out - true_out)**2)
  L.backward()
  # The manual pass works on (out - target); the factor 2 from the squared
  # error is applied when printing.
  e_y = out.detach() - true_out
  dX = net.backward(e_y)
  print("true inp grad: ", inp.grad)
  print("net grad: ", dX*2)
  net.update_weights(print_weight_grads =True)


#pytorch_tutorial_net_numerical_check()

# DFA network

In [None]:
class DFANet(object):
  """Feed-forward network trained with Direct Feedback Alignment (DFA).

  The output error is projected straight to every layer through a fixed random
  matrix instead of being back-propagated layer by layer; the last layer uses
  the identity, i.e. it receives the output error itself.

  Args:
      layers: layer objects exposing ``output_size``, ``forward`` and
          ``update_weights`` (plus ``set_weight_parameters`` and
          ``get_true_weight_grad`` for the numerical check).
      numerical_check: when True, wrap each layer's weights as parameters.
      device: device on which the feedback matrices are created.
  """
  def __init__(self, layers,numerical_check=False,device="cpu"):
    self.layers = layers
    self.device = device
    self.xs = [[] for i in range(len(self.layers)+1)]
    self.e_ys = [[] for i in range(len(self.layers)+1)]
    # The feedback matrices map from the network's output width, taken from the
    # last layer rather than hard-coded to 10.
    n_out = self.layers[-1].output_size
    self.DFA_backwards_weights = [torch.empty([n_out,l.output_size]).normal_(mean=0, std=1).to(self.device) for l in self.layers]
    self.DFA_backwards_weights[-1] = torch.eye(n_out,n_out).to(self.device)
    if numerical_check:
      for l in self.layers:
        l.set_weight_parameters()

  def forward(self, inp):
    """Run ``inp`` through every layer, caching intermediate outputs in ``xs``."""
    self.xs[0] = inp
    for i,l in enumerate(self.layers):
      self.xs[i+1] = l.forward(self.xs[i])
    return self.xs[-1]

  def backward(self,e_y):
    """Store the output error; DFA needs no layer-by-layer propagation."""
    self.e_ys[-1] = e_y
    return self.e_ys[-1]

  def update_weights(self,print_weight_grads=False,update_weight=False,shallow=False):
    """Compute (and optionally apply) each layer's DFA weight gradient.

    Args:
        print_weight_grads: print the manual gradient next to the autograd one.
        update_weight: ask the layers to apply their update in place.
        shallow: when True, only the last layer is asked to update. The old
            test ``i <= len(layers) - 1`` was true for every layer, so shallow
            mode disabled the last layer as well.

    Returns:
        The list of per-layer gradients returned by the layers (previously None).
    """
    last = len(self.layers) - 1
    grads = []
    for (i,l) in enumerate(self.layers):
      # Decide per layer instead of overwriting the caller's flag inside the loop.
      layer_update = update_weight and (not shallow or i == last)
      projected_error = torch.matmul(self.e_ys[-1],self.DFA_backwards_weights[i])
      dW = l.update_weights(projected_error,update_weights=layer_update,DFA=True)
      if print_weight_grads:
        print("weight grads : ", i)
        print("dW: ", dW*2)
        print("weight grad: ",l.get_true_weight_grad())
      grads.append(dW)
    return grads

  def train(self, dataset,n_epochs):
    """Train for ``n_epochs`` passes over ``dataset``, printing loss and accuracy per batch.

    ``dataset`` is an iterable of ``(inputs, integer_labels)`` batches. The
    whole loop runs under ``torch.no_grad()`` because the updates are manual.
    """
    with torch.no_grad():
      for n in range(n_epochs):
        print("Epoch: ",n)
        for (inp,label) in dataset:
          out = self.forward(inp.to(DEVICE))
          label = onehot(label).to(DEVICE)
          e_y = out - label
          self.backward(e_y)
          self.update_weights(update_weight=True)
          print("Loss: ", torch.sum(e_y**2))
          print("Accuracy: ", accuracy(out,label))



def pytorch_tutorial_net_numerical_check():
  """Compare the hand-written backward pass of ``BackpropNet`` with autograd.

  Builds a 784-300-100-10 tanh network on the first cached batch
  (``images`` / ``labels``), back-propagates a summed squared error with
  autograd, then prints the autograd gradients next to the manual ones for the
  input and for every weight matrix. Nothing is returned and no weights are
  updated.

  NOTE(review): this redefines (and therefore shadows) the function of the same
  name from the "Backprop Net" cell. Although it sits in the DFA section it
  still builds a ``BackpropNet`` -- confirm whether a DFA-specific check was
  intended, or remove one of the two definitions.
  """
  print("beginning numerical check")
  lr = 0.001
  inp = nn.Parameter(images.to(DEVICE))
  l1 = FCLayer(784,300,64,lr,tanh,tanh_deriv,use_backwards_weights= False, use_backwards_nonlinearities=True,device=DEVICE)
  l2 = FCLayer(300,100,64,lr,tanh,tanh_deriv,use_backwards_weights= False, use_backwards_nonlinearities=True,device=DEVICE)
  # The output layer applies tanh in the forward pass, so the manual derivative
  # must be tanh_deriv too; with linear_deriv the manual gradients could never
  # match what autograd computes for the same forward pass.
  l3 = FCLayer(100,10,64,lr,tanh,tanh_deriv,use_backwards_weights= False, use_backwards_nonlinearities=True,device=DEVICE)
  layers =[l1,l2,l3]
  net = BackpropNet(layers,numerical_check=True,device=DEVICE)
  out = net.forward(inp)
  print(out.shape)
  true_out = onehot(labels).to(DEVICE)
  print(true_out.shape)
  L = torch.sum((out - true_out)**2)
  L.backward()
  # The manual pass works on (out - target); the factor 2 from the squared
  # error is applied when printing.
  e_y = out.detach() - true_out
  dX = net.backward(e_y)
  print("true inp grad: ", inp.grad)
  print("net grad: ", dX*2)
  net.update_weights(print_weight_grads =True)



# Training

In [None]:
# Feedback Alignment (FA) corresponds to use_backwards_weights=True together
# with update_backwards_weights=False.
lr = 0.0005
n_inference_steps_train = 100
inference_learning_rate =0.02
weight_learning_rate = 0.005
use_backwards_weights= False
use_backwards_nonlinearities=True
use_error_weights=True
update_backwards_weights = True
n_epochs = 10000

# Options shared by all three layers.
layer_options = dict(
    use_backwards_weights=use_backwards_weights,
    update_backwards_weights=update_backwards_weights,
    use_backwards_nonlinearities=use_backwards_nonlinearities,
    device=DEVICE,
)
# (input size, output size, activation derivative) per layer; every layer uses
# tanh in the forward pass and a nominal batch size of 64.
layer_specs = [(784, 300, tanh_deriv), (300, 100, tanh_deriv), (100, 10, linear_deriv)]
layers = [FCLayer(n_in, n_out, 64, lr, tanh, deriv, **layer_options) for (n_in, n_out, deriv) in layer_specs]
l1, l2, l3 = layers

# Alternative networks kept for quick switching:
#net = PCNet(layers,n_inference_steps_train,inference_learning_rate,weight_learning_rate,use_error_weights=use_error_weights,device=DEVICE)
#net = BackpropNet(layers)
net = DFANet(layers,device=DEVICE)
#net.train([dataset[0]],n_epochs)
# Train on every cached batch except the last two.
net.train(dataset[0:-2],n_epochs)

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Accuracy:  0.84375
Loss:  tensor(24.9885, device='cuda:0')
Accuracy:  0.875
Loss:  tensor(34.1909, device='cuda:0')
Accuracy:  0.734375
Loss:  tensor(25.5060, device='cuda:0')
Accuracy:  0.859375
Loss:  tensor(28.5756, device='cuda:0')
Accuracy:  0.828125
Loss:  tensor(33.3504, device='cuda:0')
Accuracy:  0.828125
Loss:  tensor(29.3493, device='cuda:0')
Accuracy:  0.859375
Loss:  tensor(28.6445, device='cuda:0')
Accuracy:  0.90625
Loss:  tensor(29.7140, device='cuda:0')
Accuracy:  0.8125
Loss:  tensor(25.3011, device='cuda:0')
Accuracy:  0.890625
Loss:  tensor(28.1274, device='cuda:0')
Accuracy:  0.875
Loss:  tensor(31.1144, device='cuda:0')
Accuracy:  0.8125
Loss:  tensor(34.3427, device='cuda:0')
Accuracy:  0.78125
Loss:  tensor(29.9062, device='cuda:0')
Accuracy:  0.859375
Loss:  tensor(30.5329, device='cuda:0')
Accuracy:  0.84375
Loss:  tensor(28.9314, device='cuda:0')
Accuracy:  0.859375
Loss:  tensor(26.9502, device

KeyboardInterrupt: ignored