In [12]:
import torch
import torch.nn as nn
from torch.optim import SGD
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import torchvision.transforms as T
import numpy as np

from tqdm.notebook import tqdm

from typing import List, Union
from enum import Enum

In [35]:
mnist_train_dataset = MNIST(root="./data", train=True, download=True, transform=T.Compose([T.ToTensor(), T.Normalize(0.1307, 0.3014)]))
mnist_test_dataset = MNIST(root="./data", train=False, download=True, transform=T.Compose([T.ToTensor(), T.Normalize(0.1307, 0.3014)]))

In [36]:
mnist_train_data = DataLoader(mnist_train_dataset, batch_size=16)
mnist_test_data = DataLoader(mnist_test_dataset, batch_size=16)

In [38]:
class FFLayer(nn.Module):
  class GoodnessMeasure(Enum):
    SUM_OF_SQUARED_ACTIVITIES = 1
    SUM_OF_ACTIVITIES = 2

  class MaxObjective(Enum):
    LIKELIHOOD = 1
    LOG_LIKELIHOOD = 2
    GOODNESS = 3
    NEGATIVE_GOODNESS = 4

  def __init__(self, in_features, out_features, threshold=2, 
              good_measure=GoodnessMeasure.SUM_OF_SQUARED_ACTIVITIES, max_obj=MaxObjective.LOG_LIKELIHOOD, lr=1e-3):

    super(FFLayer, self).__init__()
    self.in_features = in_features
    self.out_features = out_features
    self.linear = nn.Linear(in_features, out_features)
    self.activation = nn.ReLU(inplace=True)
    self.opt = SGD(self.linear.parameters(), lr=lr)

    self.good_measure = good_measure
    self.threshold = threshold
    self.max_obj = max_obj

  def prob_positive(self, x):
    goodness, activity = self.goodness(x)
    return torch.sigmoid(goodness - self.threshold), activity

  def goodness(self, x):
    activity = self.forward(x)
    if self.good_measure == self.GoodnessMeasure.SUM_OF_SQUARED_ACTIVITIES:
      goodness = (activity**2).sum()
    elif self.good_measure == self.GoodnessMeasure.SUM_OF_ACTIVITIES:
      goodness = activity.sum()
    return goodness, activity

  def train(self, x, negative=False):
    sign = -1 if negative else 1
    loss = None
    activity = None

    if self.max_obj == self.MaxObjective.LIKELIHOOD:
      prob_pos, activity = self.prob_positive(x)
      loss = sign * -prob_pos
  
    elif self.max_obj == self.MaxObjective.LOG_LIKELIHOOD:
      prob_pos, activity = self.prob_positive(x)
      loss = sign * -torch.log(prob_pos)

    elif self.max_obj == self.MaxObjective.GOODNESS:
      goodness, activity = self.goodness(x)
      loss = sign * -goodness

    elif self.max_obj == self.MaxObjective.NEGATIVE_GOODNESS:
      goodness, activity = self.goodness(x)
      loss = sign * goodness

    self.opt.zero_grad()
    loss.backward()
    self.opt.step()
    return activity

  def train_pos(self, x):
    return self.train(x, negative=False)

  def train_neg(self, x):
    return self.train(x, negative=True)

  def forward(self, x):
    n = F.normalize(x, 2) 
    h = self.linear(n)
    a = self.activation(h)
    return a

In [278]:
class RecurrentFFModel(nn.Module):
  def __init__(self, input_size: int, layer_sizes: List[int], 
                rollouts=8, lr=1e-3, max_obj=FFLayer.MaxObjective.LOG_LIKELIHOOD):
                
    super(RecurrentFFModel, self).__init__()
    self.input_size = input_size
    self.rollouts = rollouts
    self.lr = lr

    self.max_obj = max_obj

    """
    Each layer depends on layers above and below, except top layer.
                   ___     ___           <--- Top layer
               ___/   \___/   \___       <--- Middle layer (could be many)
           ___/   \___/   \___/   \___   <--- Bottom layer
          /       /       /       /
     frame   frame   frame   frame

    """
    self.layers = []
    up_size = input_size
    for out_size, down_size in zip(layer_sizes, layer_sizes[1:]):
      layer = self.make_layer(up_size, down_size, out_size)
      self.layers.append(layer)
      up_size = out_size
    # Top layer only gets input from layer below, hence down_size=0.
    top_layer = self.make_layer(up_size=layer_sizes[-2], down_size=0, out_size=layer_sizes[-1])
    self.layers.append(top_layer)

  def make_layer(self, up_size, down_size, out_size):
    return FFLayer(up_size + down_size, out_size, lr=self.lr, max_obj=self.max_obj)

  def init_layers_activity(self, input_shape):
    # TODO: Could take labels for supervised learning
    return [torch.zeros(*input_shape[:-1], l.out_features) for l in self.layers] + [torch.zeros(*input_shape[:-1], 0)]

  def rollout(self, x, fns, **kwargs):
    activity = [x] + self.init_layers_activity(input_shape=x.shape)

    for _ in range(self.rollouts):
      activity[1:-1] = [
        fn(torch.cat((activity[i], activity[i+2]), dim=-1).detach(), **kwargs)
        for i, fn in enumerate(fns) # TODO: Make this parallel
      ]

    return activity[1:-1]

  def train(self, x, negative=False):
    # This does not work:
    # return self.rollout(x, [lambda y: layer.train(y, negative=negative) for layer in self.layers])
    # However, this does?:
    return self.rollout(x, [layer.train for layer in self.layers], negative=negative)

  def train_pos(self, x):
    return self.train(x, negative=False)

  def train_neg(self, x):
    return self.train(x, negative=True)
  
  def train_pos_neg(self, x_pos, x_neg):
    pos_activity = self.train_pos(x_pos)
    neg_activity = self.train_neg(x_neg)
    return pos_activity, neg_activity

  def forward(self, x):
    return self.rollout(x, [layer.forward for layer in self.layers])


In [279]:
model = RecurrentFFModel(input_size=28*28, layer_sizes=[400, 200, 100])

In [332]:
def train(model, train_data, epochs):
  for epoch in range(10):
    for pos_data, neg_data in tqdm(train_data):
      model.train_pos_neg(pos_data, neg_data)

def test(model, test_data):
  num_predictions = 0
  num_correct = 0
  for image, label in tqdm(test_data):
    activity = model.forward(image)
    mean_activities = torch.stack([(a**2).mean(dim=-1) for a in activity]).mean(dim=0)
    preds = mean_activities.argmax(dim=-1)
    correct = preds == label
    num_predictions += len(preds)
    num_correct += correct.sum()

  accuracy = num_correct / num_predictions
  return accuracy

def make_mnist_train_data(overlay_label=False):
  for image, label in mnist_train_data:
    image = image.flatten(start_dim=1) # Flatten image
    random_image = torch.randn_like(image)
    if overlay_label:
      one_hot_label = F.one_hot(label, num_classes=10)
      image[:, :10] = one_hot_label
    yield image, random_image

def make_mnist_test_data(overlay_label=False):
  for image, label in mnist_test_data:
    image = image.flatten(start_dim=1) # Flatten image

    image = image.expand(10, -1, -1).permute(1, 0, 2).clone() # [batch, 10, pixels]
    image[:, :, :10] = F.one_hot(torch.arange(10))
    
    yield image, label

In [333]:
test(model, make_mnist_test_data())

0it [00:00, ?it/s]

tensor(0.1017)

In [334]:
train(model, make_mnist_train_data(overlay_label=True), epochs=1)

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

In [335]:
test(model, make_mnist_test_data())

0it [00:00, ?it/s]

tensor(0.0975)