In [171]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import torchvision.transforms as T
from tqdm.notebook import tqdm

from torch.optim import SGD


from typing import List, Union
from enum import Enum

In [172]:
# Standard MNIST training-set pixel statistics (pixels scaled to [0, 1] by ToTensor).
# The std was previously 0.3014; the commonly used value is 0.3081.
MNIST_MEAN, MNIST_STD = (0.1307,), (0.3081,)
mnist_dataset = MNIST(
  root="./data",
  download=True,
  transform=T.Compose([T.ToTensor(), T.Normalize(MNIST_MEAN, MNIST_STD)]),
)

In [173]:
# Batches of 16 images in dataset order (shuffle=False is DataLoader's default).
mnist_data = DataLoader(dataset=mnist_dataset, batch_size=16, shuffle=False)

In [174]:
# Readability alias for annotations: a layer width (number of units) is just an int.
layer_size = int

In [239]:
class FFLayer(nn.Module):
  """One Forward-Forward layer: L2-normalise the input, then apply Linear and ReLU.

  Each layer owns a private SGD optimiser and is trained locally: `train_pos` /
  `train_neg` take one optimisation step on a batch of positive / negative data and
  return the layer's activity for that batch.
  """

  class GoodnessMeasure(Enum):
    """How a layer's activity is reduced to a goodness score."""
    SUM_OF_SQUARED_ACTIVITIES = 1
    SUM_OF_ACTIVITIES = 2

  class MaxObjective(Enum):
    """Quantity the positive phase maximises (the negative phase minimises it)."""
    LIKELIHOOD = 1
    LOG_LIKELIHOOD = 2
    GOODNESS = 3
    NEGATIVE_GOODNESS = 4

  def __init__(self, in_features, out_features, threshold=2,
               good_measure=GoodnessMeasure.SUM_OF_SQUARED_ACTIVITIES, max_obj=MaxObjective.LOG_LIKELIHOOD, lr=1e-3):
    """
    Args:
      in_features: width of the input given to `forward`.
      out_features: number of units in this layer.
      threshold: offset in p(positive) = sigmoid(goodness - threshold).
      good_measure: a `GoodnessMeasure` member.
      max_obj: a `MaxObjective` member selecting the training loss.
      lr: learning rate of this layer's own SGD optimiser.
    """
    super().__init__()
    self.in_features = in_features
    self.out_features = out_features
    self.linear = nn.Linear(in_features, out_features)
    self.activation = nn.ReLU(inplace=True)
    # Training is local to the layer, so each layer has its own optimiser.
    self.opt = SGD(self.linear.parameters(), lr=lr)

    self.good_measure = good_measure
    self.threshold = threshold
    self.max_obj = max_obj

  def _unit_goodness(self, activity):
    """Per-unit contribution to goodness, according to `self.good_measure`."""
    if self.good_measure == self.GoodnessMeasure.SUM_OF_SQUARED_ACTIVITIES:
      return activity**2
    if self.good_measure == self.GoodnessMeasure.SUM_OF_ACTIVITIES:
      return activity
    raise ValueError(f"Unknown goodness measure: {self.good_measure!r}")

  def prob_positive(self, x, per_sample=False):
    """Return (sigmoid(goodness - threshold), activity); see `goodness` for `per_sample`."""
    goodness, activity = self.goodness(x, per_sample=per_sample)
    return torch.sigmoid(goodness - self.threshold), activity

  def goodness(self, x, per_sample=False):
    """Run the layer on `x` and score the resulting activity.

    Returns (goodness, activity). By default goodness is a scalar summed over the
    whole batch; with per_sample=True it is summed over the last dimension only,
    giving one value per example.

    Raises:
      ValueError: if `good_measure` is not a `GoodnessMeasure` member.
    """
    activity = self.forward(x)
    contributions = self._unit_goodness(activity)
    goodness = contributions.sum(dim=-1) if per_sample else contributions.sum()
    return goodness, activity

  def train(self, x=True, negative=False):
    """Take one local optimisation step on batch `x`, or set train/eval mode if `x` is a bool.

    nn.Module.eval() calls self.train(False), and a parent module's train(mode)
    calls train(mode) on its children, so a bool is forwarded to nn.Module.train
    instead of being treated as data.

    Args:
      x: input batch (or a bool mode flag, see above).
      negative: if True, minimise the objective (negative data) instead of maximising it.

    Returns:
      The layer's activity for `x`, computed before the parameter update
      (or `self` when called with a bool, like nn.Module.train).

    Raises:
      ValueError: if `max_obj` or `good_measure` is not a recognised enum member.
    """
    if isinstance(x, bool):
      return super().train(x)

    likelihood_objectives = (self.MaxObjective.LIKELIHOOD, self.MaxObjective.LOG_LIKELIHOOD)
    if self.max_obj in likelihood_objectives:
      # `threshold` is a per-example offset, so p(positive) is computed per example and
      # the loss averaged; applied to the batch-summed goodness the sigmoid saturates.
      goodness, activity = self.goodness(x, per_sample=True)
      logits = goodness - self.threshold
      if self.max_obj == self.MaxObjective.LIKELIHOOD:
        prob_pos = torch.sigmoid(logits)
        loss = (prob_pos if negative else -prob_pos).mean()
      else:
        # Positive: -log p.  Negative: -log(1 - p) = -log sigmoid(-logits).
        # logsigmoid stays finite where log(sigmoid(.)) would reach log(0) = -inf.
        loss = -F.logsigmoid(-logits if negative else logits).mean()
    elif self.max_obj == self.MaxObjective.GOODNESS:
      goodness, activity = self.goodness(x)
      loss = goodness if negative else -goodness
    elif self.max_obj == self.MaxObjective.NEGATIVE_GOODNESS:
      goodness, activity = self.goodness(x)
      loss = -goodness if negative else goodness
    else:
      raise ValueError(f"Unknown objective: {self.max_obj!r}")

    self.opt.zero_grad()
    loss.backward()
    self.opt.step()
    return activity

  def train_pos(self, x):
    """One optimisation step that raises the objective on positive data `x`."""
    return self.train(x, negative=False)

  def train_neg(self, x):
    """One optimisation step that lowers the objective on negative data `x`."""
    return self.train(x, negative=True)

  def forward(self, x):
    """L2-normalise `x` along dim 1, then apply Linear and ReLU."""
    normalized = F.normalize(x, p=2, dim=1)
    return self.activation(self.linear(normalized))

In [240]:
class RecurrentFFModel(nn.Module):
  r"""Stack of FFLayers unrolled for `rollouts` steps, feeding the same input `x` at every step.

  Each layer depends on layers above and below, except top layer.
                   ___     ___           <--- Top layer
               ___/   \___/   \___       <--- Middle layer (could be many)
           ___/   \___/   \___/   \___   <--- Bottom layer
          /       /       /       /
     frame   frame   frame   frame

  At every step, layer i receives concat(activity below it, activity above it) from
  the previous step. For the bottom layer "below" is the input, and for the top layer
  "above" is an empty (width-0) tensor. All layers are updated simultaneously.
  """

  def __init__(self, input_size: int, layer_sizes: List[layer_size],
               rollouts=8, lr=1e-3, max_obj=FFLayer.MaxObjective.NEGATIVE_GOODNESS):
    """
    Args:
      input_size: width of each input row.
      layer_sizes: widths of the hidden layers, bottom to top (at least one).
      rollouts: number of unrolled update steps per call.
      lr: learning rate passed to every FFLayer.
      max_obj: FFLayer.MaxObjective used by every layer.
    """
    super().__init__()
    self.input_size = input_size
    self.rollouts = rollouts
    self.lr = lr
    self.max_obj = max_obj

    # nn.ModuleList (not a plain list) registers the layers as submodules, so that
    # parameters(), state_dict(), .to(device) and train()/eval() reach them.
    self.layers = nn.ModuleList()
    up_size = input_size
    for out_size, down_size in zip(layer_sizes, layer_sizes[1:]):
      self.layers.append(self.make_layer(up_size, down_size, out_size))
      up_size = out_size
    # Top layer only gets input from the layer below (the raw input when there is a
    # single layer), hence down_size=0. `up_size` is input_size or layer_sizes[-2].
    self.layers.append(self.make_layer(up_size=up_size, down_size=0, out_size=layer_sizes[-1]))

  def make_layer(self, up_size, down_size, out_size):
    """Build an FFLayer whose input is the concatenation of `up_size` and `down_size` features."""
    return FFLayer(up_size + down_size, out_size, lr=self.lr, max_obj=self.max_obj)

  def test(self, x, labels, num_classes):
    """Unfinished evaluation stub: computes the top-layer representation of `x` and returns None."""
    activities = self.forward(x)
    top_activities = activities[-1]  # Latent representation of the digit.
    # TODO: cluster top_activities (e.g. k-means with k=num_classes) and score against `labels`.

  def _initial_activity(self, x):
    """Zero activity for every layer plus the empty top-down input, matching x's batch, dtype and device."""
    batch_size = x.shape[0]
    hidden = [x.new_zeros(batch_size, layer.out_features) for layer in self.layers]
    # The top layer has nothing above it; a width-0 tensor makes its concat a no-op.
    return hidden + [x.new_zeros(batch_size, 0)]

  def _rollout(self, activity, step_fn):
    """Update all hidden slots of `activity` `self.rollouts` times with step_fn(layer, layer_input).

    `activity` is [input, layer_0, ..., layer_{L-1}, top_down]; returns the hidden part.
    """
    for _ in range(self.rollouts):
      # The comprehension reads every previous-step activity before the slice is
      # overwritten, so all layers update simultaneously. detach() keeps training local:
      # no gradient flows into other layers or earlier steps.
      activity[1:-1] = [
        step_fn(layer, torch.cat((activity[i], activity[i + 2]), dim=1).detach())
        for i, layer in enumerate(self.layers)  # TODO: Make this parallel
      ]
    return activity[1:-1]

  def train(self, x=True):
    """Run positive then negative FF training on batch `x`, or set train/eval mode if `x` is a bool.

    Positive data is `x`; negative data is Gaussian noise shaped like `x`. A single noise
    sample is reused for every rollout step.
    TODO: Could take labels for supervised learning.

    Returns:
      (positive hidden activities, negative hidden activities): lists with one tensor per
      layer from the last step. Returns `self` when called with a bool, like nn.Module.train.
    """
    if isinstance(x, bool):
      # nn.Module.eval() calls self.train(False); forward it rather than treating it as data.
      return super().train(x)
    pos_activity = [x] + self._initial_activity(x)
    neg_activity = [torch.randn_like(x)] + self._initial_activity(x)
    positive = self._rollout(pos_activity, lambda layer, layer_input: layer.train_pos(layer_input))
    negative = self._rollout(neg_activity, lambda layer, layer_input: layer.train_neg(layer_input))
    return positive, negative

  def forward(self, x):
    """Unroll the network on `x` without training; return the last step's activity of each layer."""
    activity = [x] + self._initial_activity(x)
    return self._rollout(activity, lambda layer, layer_input: layer(layer_input))


In [241]:
# 28 * 28 = 784 flattened pixels in; hidden layers listed bottom to top (100 is the top layer).
model = RecurrentFFModel(layer_sizes=[400, 200, 100], input_size=28 * 28)

In [242]:
for image, label in tqdm(mnist_data):
  # Collapse every dimension after the batch one so each image becomes a flat pixel row.
  image = torch.flatten(image, start_dim=1)
  model.train(image)

  0%|          | 0/3750 [00:00<?, ?it/s]