# Feedback Alignment

##### Imports and helper functions

In [11]:
import numpy as np
import torch as th
import torchvision
from tqdm import tqdm
from sklearn.model_selection import train_test_split

In [4]:
def get_loaders(batch_size, fashion=False):
    """Build train and test DataLoaders for MNIST or Fashion-MNIST.

    Args:
        batch_size: Number of samples per batch for both loaders.
        fashion: When true, use ``torchvision.datasets.FashionMNIST``;
            otherwise use ``torchvision.datasets.MNIST``.

    Returns:
        A ``(trainloader, testloader)`` tuple. The training loader shuffles
        its samples, the test loader does not. Both datasets are stored under
        ``./data`` (downloaded if missing) and converted with ``ToTensor``.
    """
    dataset_cls = (torchvision.datasets.FashionMNIST
                   if fashion else torchvision.datasets.MNIST)
    to_tensor = torchvision.transforms.Compose(
        [torchvision.transforms.ToTensor()])

    def _make_loader(is_train):
        # The train split is shuffled, the test split is served in order.
        split = dataset_cls(
            root="./data", train=is_train, download=True, transform=to_tensor)
        return th.utils.data.DataLoader(
            split, batch_size=batch_size, shuffle=is_train, num_workers=2)

    # Tuple elements are evaluated left to right: train first, then test.
    return _make_loader(True), _make_loader(False)

## Linear function approximation

<img src="images/linear.png">
We are considering a three-layer network of linear neurons, shown above. The network's output is $\boldsymbol{y}=W\boldsymbol{h}$, where $\boldsymbol{h}$ is the hidden-unit activity vector, given by $\boldsymbol{h}=W_0\boldsymbol{x}$, where $\boldsymbol{x}$ is the input to the network. $W_0$ is the matrix of synaptic weights from $\boldsymbol{x}$ to $\boldsymbol{h}$, and $W$ is the matrix of weights from $\boldsymbol{h}$ to $\boldsymbol{y}$. The network learns to approximate a linear function, $T$ (for 'target'). Its goal is to reduce the squared error, or loss, $\mathcal{L}=\frac{1}{2}\boldsymbol{e}^{T}\boldsymbol{e}$, where the error $\boldsymbol{e}=\boldsymbol{y^*}-\boldsymbol{y}=T\boldsymbol{x}-\boldsymbol{y}$. To train this network, the feedback alignment algorithm adjusts $W$ in the same way as backprop, i.e. $\Delta W\propto-\frac{\partial\mathcal{L}}{\partial W}=\boldsymbol{e}\boldsymbol{h}^T$, but for $W_0$ it uses a simpler formula, which needs no information about $W$ or any other synapses and instead sends $\boldsymbol{e}$ through a fixed random matrix $B$:
$$\Delta W_0\propto B\boldsymbol{e}\boldsymbol{x}^T$$

### Data Generation

The target linear function $T$ maps vectors from a $30$- to a $10$-dimensional space, i.e. $T$ has a shape of $10\times 30$. The elements of $T$ are drawn at random, that is, uniformly, from the range $[-1,1]$. Once chosen, the target matrix is fixed, so that each algorithm (i.e. feedback alignment and backpropagation) tries to learn the same function. Moreover, all algorithms are trained on the same sequence of input/output pairs, with $x\sim{}\mathcal{N}(\mu=0,\Sigma=I)$, $y^*=Tx$. We have chosen the **number of inputs** to be **100**.

##### linear function, inputs, and output generation

In [14]:
# Seed NumPy's global RNG so that Restart & Run All reproduces the same target
# matrix and inputs. Later cells that draw from the global RNG are repeatable
# too, provided the notebook is executed top to bottom.
SEED = 0
np.random.seed(SEED)

num_inputs = 100

# Target linear map T (10 x 30), entries drawn uniformly between -1 and 1.
T = np.random.uniform(low=-1.0, high=1.0, size=(10, 30))
# One input vector per column, standard-normal entries: 30 x num_inputs.
x_data = np.random.randn(30, num_inputs)
# Noise-free targets y* = T x: 10 x num_inputs.
y_data = T @ x_data

`x_data`: `30 x num_inputs`

`y_data`: `10 x num_inputs`

##### Weights and biases random initialization

The elements of $B$ are drawn from the uniform distribution over $[-0.5, 0.5)$, while the elements of the network weight matrices, $W_0$ and $W$, are drawn uniformly from the range $[-0.01,0.01)$.

In [43]:
# Forward weights and biases: small values drawn uniformly from [-a, a).
a = 0.01
W0 = np.random.uniform(low=-a, high=a, size=(20, 30))  # input (30) -> hidden (20)
b0 = np.random.uniform(low=-a, high=a, size=(20,))     # hidden-layer bias

W = np.random.uniform(low=-a, high=a, size=(10, 20))   # hidden (20) -> output (10)
b = np.random.uniform(low=-a, high=a, size=(10,))      # output-layer bias

# Random feedback matrix B (20 x 10), drawn from the wider range [-a, a)
# with a = 0.5; it maps a 10-dimensional output error to the 20 hidden units.
a = 0.5
B = np.random.uniform(low=-a, high=a, size=(20, 10))

##### Training and Test dataset preparation

In [44]:
# train_test_split splits along the first axis, so samples are moved to the
# rows for the split and each resulting part is transposed back so that
# samples are columns again (features x samples).
x_train, x_test, y_train, y_test = (
    part.T
    for part in train_test_split(
        x_data.T, y_data.T, test_size=0.25, shuffle=True)
)

In [45]:
# Sanity check: display the held-out input shape as (features, samples).
x_test.shape

(30, 25)

##### Hyperparameters

In [46]:
eta = 0.1          # learning rate
num_epochs = 1000  # number of passes over the training set
batch_size = 20    # requested number of samples per mini-batch
# A batch count must be an integer; true division (`/`) yields a float such as
# 75 / 20 = 3.75. Floor division keeps it integral, and max(1, ...) guards the
# case of fewer training samples than batch_size. When the training-set size
# is not a multiple of batch_size, the batches end up larger than batch_size.
num_batches = max(1, x_train.shape[1] // batch_size)

##### Creating batches

In [47]:
# Split the samples (axis 1) into mini-batches. np.array_split, unlike
# np.split, does not require the sample count to be an exact multiple of the
# number of sections, and int(...) makes sure the section count is an integer.
x_train_batches = np.array_split(x_train, int(num_batches), axis=1)
y_train_batches = np.array_split(y_train, int(num_batches), axis=1)

In [48]:
# Feedback-alignment training. Updates are accumulated over every mini-batch
# and applied once per epoch, scaled by the learning rate divided by the
# number of training samples.
step_size = eta / x_train.shape[1]

for epoch in range(num_epochs):
    print(f"      Epoch {epoch}")
    delta_W0, delta_W, loss = 0, 0, 0
    for training_data, training_labels in zip(x_train_batches, y_train_batches):
        # Forward pass through the two linear layers.
        h = W0 @ training_data
        y = W @ h

        # Error e = y* - y and squared-error loss summed over the batch.
        e = training_labels - y
        loss += 0.5 * np.square(np.linalg.norm(e))

        # Feedback alignment: the error reaches the hidden layer through the
        # random matrix B (backprop would use W.T @ e here instead).
        delta_FA = B @ e

        delta_W = delta_W + e @ h.T
        delta_W0 = delta_W0 + delta_FA @ training_data.T

    # Both weight matrices are updated once per epoch.
    W = W + step_size * delta_W
    W0 = W0 + step_size * delta_W0
    print(f"      Current loss: {loss}")

      Epoch 0
      Current loss: 3627.638941072857
      Epoch 1
      Current loss: 3628.8345961051755
      Epoch 2
      Current loss: 2917.074192361744
      Epoch 3
      Current loss: 1602.7455428185147
      Epoch 4
      Current loss: 1004.9426339814108
      Epoch 5
      Current loss: 733.4197100787438
      Epoch 6
      Current loss: 555.59120941355
      Epoch 7
      Current loss: 428.6674136419941
      Epoch 8
      Current loss: 334.61745005024846
      Epoch 9
      Current loss: 264.2657557705443
      Epoch 10
      Current loss: 211.67181984598557
      Epoch 11
      Current loss: 172.13897690925074
      Epoch 12
      Current loss: 141.9988071792099
      Epoch 13
      Current loss: 118.60681827631439
      Epoch 14
      Current loss: 100.13930347540257
      Epoch 15
      Current loss: 85.34610250107208
      Epoch 16
      Current loss: 73.35301178783354
      Epoch 17
      Current loss: 63.52962285421383
      Epoch 18
      Current loss: 55.407788420391

      Current loss: 8.210621681522045e-06
      Epoch 326
      Current loss: 7.88145561252479e-06
      Epoch 327
      Current loss: 7.5655258484944745e-06
      Epoch 328
      Current loss: 7.262298294225259e-06
      Epoch 329
      Current loss: 6.971260491082417e-06
      Epoch 330
      Current loss: 6.6919207365239704e-06
      Epoch 331
      Current loss: 6.423807239595815e-06
      Epoch 332
      Current loss: 6.166467310965328e-06
      Epoch 333
      Current loss: 5.9194665861158135e-06
      Epoch 334
      Current loss: 5.682388280257704e-06
      Epoch 335
      Current loss: 5.454832473705201e-06
      Epoch 336
      Current loss: 5.236415426429309e-06
      Epoch 337
      Current loss: 5.026768920671693e-06
      Epoch 338
      Current loss: 4.825539630306398e-06
      Epoch 339
      Current loss: 4.632388516045538e-06
      Epoch 340
      Current loss: 4.446990245264781e-06
      Epoch 341
      Current loss: 4.269032635528073e-06
      Epoch 342
      Curren

      Current loss: 3.746944752542958e-11
      Epoch 630
      Current loss: 3.599272214475138e-11
      Epoch 631
      Current loss: 3.4574218688721824e-11
      Epoch 632
      Current loss: 3.3211640630680644e-11
      Epoch 633
      Current loss: 3.1902782058826556e-11
      Epoch 634
      Current loss: 3.064552413667814e-11
      Epoch 635
      Current loss: 2.9437831632031367e-11
      Epoch 636
      Current loss: 2.827774962311307e-11
      Epoch 637
      Current loss: 2.716340037225983e-11
      Epoch 638
      Current loss: 2.60929801867542e-11
      Epoch 639
      Current loss: 2.5064756598503893e-11
      Epoch 640
      Current loss: 2.4077065452989188e-11
      Epoch 641
      Current loss: 2.3128308302154868e-11
      Epoch 642
      Current loss: 2.221694974128057e-11
      Epoch 643
      Current loss: 2.1341514957485446e-11
      Epoch 644
      Current loss: 2.050058731812097e-11
      Epoch 645
      Current loss: 1.9692806088602297e-11
      Epoch 646
      

      Current loss: 7.695928618948622e-17
      Epoch 956
      Current loss: 7.392884880552922e-17
      Epoch 957
      Current loss: 7.101776036348416e-17
      Epoch 958
      Current loss: 6.822127398158903e-17
      Epoch 959
      Current loss: 6.553489113663128e-17
      Epoch 960
      Current loss: 6.29542951007262e-17
      Epoch 961
      Current loss: 6.047528430476727e-17
      Epoch 962
      Current loss: 5.809391254990541e-17
      Epoch 963
      Current loss: 5.580628799600565e-17
      Epoch 964
      Current loss: 5.3608729920638995e-17
      Epoch 965
      Current loss: 5.1497707779902904e-17
      Epoch 966
      Current loss: 4.946981172530453e-17
      Epoch 967
      Current loss: 4.752175622720727e-17
      Epoch 968
      Current loss: 4.565042577113092e-17
      Epoch 969
      Current loss: 4.385274783913164e-17
      Epoch 970
      Current loss: 4.21258710622019e-17
      Epoch 971
      Current loss: 4.0467000492301435e-17
      Epoch 972
      Current

In [49]:
# Forward pass of the trained network on the held-out inputs.
h = np.matmul(W0, x_test)
y = np.matmul(W, h)

# Same loss as in training: half the squared norm of the error matrix.
e = y_test - y
test_error = 0.5 * np.square(np.linalg.norm(e))
test_error

3.448067064856725e-17