In [171]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import torchvision.transforms as T
from tqdm.notebook import tqdm

from torch.optim import SGD


from typing import List, Union
from enum import Enum

In [172]:
mnist_dataset = MNIST(root="./data", download=True, transform=T.Compose([T.ToTensor(), T.Normalize(0.1307, 0.3014)]))

In [173]:
mnist_data = DataLoader(mnist_dataset, batch_size=16)

In [174]:
layer_size = int

In [205]:
class FFLayer(nn.Module):
  # What should the learning objective be?
  class GoodnessMeasure(Enum):
    SUM_OF_SQUARED_ACTIVITIES = 1
    SUM_OF_ACTIVITIES = 2

  class MaxObjective(Enum):
    LIKELIHOOD = 1
    LOG_LIKELIHOOD = 2
    GOODNESS = 3
    NEGATIVE_GOODNESS = 4

  def __init__(self, in_features, out_features, threshold=2, 
              good_measure=GoodnessMeasure.SUM_OF_SQUARED_ACTIVITIES, max_obj=MaxObjective.LOG_LIKELIHOOD, lr=1e-3):

    super(FFLayer, self).__init__()
    self.linear = nn.Linear(in_features, out_features)
    self.activation = nn.ReLU(inplace=True)
    self.opt = SGD(self.linear.parameters(), lr=lr)

    self.good_measure = good_measure
    self.threshold = threshold
    self.max_obj = max_obj

  def prob_positive(self, x):
    return F.sigmoid(self.goodness(x) - self.threshold)

  def goodness(self, x):
    a = self.forward(x)
    if self.good_measure == self.GoodnessMeasure.SUM_OF_SQUARED_ACTIVITIES:
      return (a**2).sum()
    elif self.good_measure == self.GoodnessMeasure.SUM_OF_ACTIVITIES:
      return a.sum()

  def train(self, x):
    self.opt.zero_grad()
    loss = None
    if self.max_obj == self.MaxObjective.LIKELIHOOD:
      loss = -self.prob_positive(x)
    elif self.max_obj == self.MaxObjective.LOG_LIKELIHOOD:
      loss = -torch.log(self.prob_positive(x))
    elif self.max_obj == self.MaxObjective.GOODNESS:
      loss = -self.goodness(x)
    elif self.max_obj == self.MaxObjective.NEGATIVE_GOODNESS:
      loss = self.goodness(x)
    loss.backward()
    self.opt.step()

  def forward(self, x):
    n = F.normalize(x, 2)
    h = self.linear(n)
    a = self.activation(h)
    return a

In [206]:
class RecurrentFFModel(nn.Module):
  def __init__(self, input_size: int, layer_sizes: List[layer_size], 
                rollouts=8, lr=1e-3, max_obj=FFLayer.MaxObjective.NEGATIVE_GOODNESS):
                
    super(RecurrentFFModel, self).__init__()
    self.input_size = input_size
    self.rollouts = rollouts
    self.lr = lr

    self.max_obj = max_obj

    """
    Each layer depends on layers above and below, except top layer.
                   ___     ___           <--- Top layer
               ___/   \___/   \___       <--- Middle layer (could be many)
           ___/   \___/   \___/   \___   <--- Bottom layer
          /       /       /       /
     frame   frame   frame   frame

    """
    self.layers = []
    up_size = input_size
    for out_size, down_size in zip(layer_sizes, layer_sizes[1:]):
      layer = self.make_layer(up_size, down_size, out_size)
      self.layers.append(layer)
      up_size = out_size
    # Top layer only gets input from layer below, hence down_size=0.
    top_layer = self.make_layer(up_size=layer_sizes[-2], down_size=0, out_size=layer_sizes[-1])
    self.layers.append(top_layer)

  def make_layer(self, up_size, down_size, out_size):
    return FFLayer(up_size + down_size, out_size, lr=self.lr, max_obj=self.max_obj)

  def train(self, x):
    pass

  def forward(self, x):
    # Initialize layer activities
    activity = [x] + [torch.zeros(x.shape[0], l.out_features) for l in self.layers] + [torch.zeros(x.shape[0], 0)]
 
    # Rollout time
    for rollout in range(self.rollouts):
      new_activity = activity.copy()
      for i, layer in enumerate(self.layers): # This could be done in parallel
        j = i + 1 # Activity index

        # Concatenative activities from layers below and above
        # and prevent propagation of gradients using .detach()
        layer_input = torch.cat((activity[j-1], activity[j+1]), dim=1).detach()
        y = layer(layer_input)

        new_activity[j] = y

      activity = new_activity # Switch activities

    return activity[1:-1] # Only return hidden activities


In [204]:
model = RecurrentFFModel(input_size=28*28, layer_sizes=[400, 200, 100])

In [168]:
print(model.layers[0].weight.mean(), model.layers[0].weight.std())
print(model.layers[0].bias.mean(), model.layers[0].bias.std())

tensor(-4.1034e-05, grad_fn=<MeanBackward0>) tensor(0.0184, grad_fn=<StdBackward0>)
tensor(-0.0008, grad_fn=<MeanBackward0>) tensor(0.0186, grad_fn=<StdBackward0>)


In [146]:
for image, label in tqdm(mnist_data):
  image = image.flatten(start_dim=1) # Flatten image
  model(image)

  0%|          | 0/3750 [00:00<?, ?it/s]

In [148]:
print(model.layers[0].weight.mean(), model.layers[0].weight.std())
print(model.layers[0].bias.mean(), model.layers[0].bias.std())

tensor(2.3312e-05, grad_fn=<MeanBackward0>) tensor(0.0184, grad_fn=<StdBackward0>)
tensor(-0.0182, grad_fn=<MeanBackward0>) tensor(0.0060, grad_fn=<StdBackward0>)


In [None]:
# Instead of dividing the above into multiple layers, we could have one big matrix of connection weights