In [149]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import torchvision.transforms as T
from tqdm.notebook import tqdm
from typing import List

In [150]:
mnist_dataset = MNIST(root="./data", download=True, transform=T.Compose([T.ToTensor(), T.Normalize(0.1307, 0.3014)]))

In [151]:
mnist_data = DataLoader(mnist_dataset, batch_size=16)

In [155]:
layer_size = int

In [166]:
def pad_list(list: List, value = 0, size: int = 1):
  return [value] * size + list + [value] * size

def window_list(list: List, size: int):
  return zip(*[list[i:] for i in range(size)])


class FFLinear(nn.Module):
  def __init__(self, in_features, out_features, threshold, *args, **kwargs):
    self.linear = nn.Linear(in_features, out_features)
    self.relu = nn.ReLU(inplace=True)
    self.threshold = threshold

  def prob_good(self, x):
    return F.sigmoid(self.goodness(x) - self.threshold)

  def goodness(self, x):
    n = F.normalize(x, 2)
    h = self.linear(n)
    a = self.relu(h)
    return (a**2).sum()

  def forward(self, x):
    n = F.normalize(x, 2)
    h = self.linear(n)
    a = self.relu(h)
    return a

class RecurrentFFModel(nn.Module):
  def __init__(self, input_size: int, layer_sizes: List[layer_size], rollouts=8,alpha=0.8, lr=1e-2):
    super(RecurrentFFModel, self).__init__()
    self.input_size = input_size
    self.rollouts = rollouts
    self.alpha = alpha
    self.lr = lr
    # TODO: Make this more readable.
    self.layers: List[FFLinear] = [
      FFLinear(up_input_size + down_input_size, output_size)
      for (up_input_size, output_size, down_input_size) in window_list([self.input_size] + layer_sizes + [0], 3)
    ]

  def forward(self, X: torch.Tensor):
    # Initialize layer activities
    activity = [X] + [torch.zeros(X.shape[0], l.out_features) for l in self.layers] + [torch.zeros(X.shape[0], 0)]

    # Rollout time
    for rollout in range(self.rollouts):
      new_activity = activity.copy()
      for i, layer in enumerate(self.layers): # This could be done in parallel
        # TODO: This should probably be moved to the linear layer class
        j = i + 1 # Activity index

        # For illustration purposes we don't rely on torch's autofiff, though it might be faster.
        with torch.no_grad(): 
          layer_input = torch.cat((activity[j-1], activity[j+1]), dim=1)
          norm_layer_input = layer_input / (1e-8 + torch.linalg.norm(layer_input, 2)) # Make sure vector has length 1.
          z = layer(norm_layer_input)
          y = F.relu(z)
          #loss = y**2

          # Compute gradient manually
          #                   dl/dW             dl/dy   dy/dz           dz/dW
          dldW = torch.einsum("na,na,nb->nab",  y,      ((z > 0) * 1),  norm_layer_input).mean(dim=0)
          #dl/db dl/dy dy/dz          dz/db
          dldb = (y * ((z > 0) * 1) * 1).mean(dim=0) #torch.einsum("na,na->na",  y,      ((z > 0) * 1)).mean(dim=0)
          layer.weight -= self.lr * dldW # Update the weights
          layer.bias -= self.lr * dldb # Update the weights

        new_activity[j] = y

      activity = new_activity # Switch activities

    return activity[1:-1] # Only return hidden activities


In [167]:
model = RecurrentFFModel(input_size=28*28, layer_sizes=[400, 200, 100])

In [168]:
print(model.layers[0].weight.mean(), model.layers[0].weight.std())
print(model.layers[0].bias.mean(), model.layers[0].bias.std())

tensor(-4.1034e-05, grad_fn=<MeanBackward0>) tensor(0.0184, grad_fn=<StdBackward0>)
tensor(-0.0008, grad_fn=<MeanBackward0>) tensor(0.0186, grad_fn=<StdBackward0>)


In [146]:
for image, label in tqdm(mnist_data):
  image = image.flatten(start_dim=1) # Flatten image
  model(image)

  0%|          | 0/3750 [00:00<?, ?it/s]

In [148]:
print(model.layers[0].weight.mean(), model.layers[0].weight.std())
print(model.layers[0].bias.mean(), model.layers[0].bias.std())

tensor(2.3312e-05, grad_fn=<MeanBackward0>) tensor(0.0184, grad_fn=<StdBackward0>)
tensor(-0.0182, grad_fn=<MeanBackward0>) tensor(0.0060, grad_fn=<StdBackward0>)


In [None]:
# Instead of dividing the above into multiple layers, we could have one big matrix of connection weights