## Imports

In [1]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Normalize, Lambda
from torch.utils.data import DataLoader
from tqdm import tqdm

device = "cuda" if torch.cuda.is_available() else "cpu"

## Dataset Loaders

In [3]:
def MNIST_loaders(train_batch_size=512, test_batch_size=1):

    transform = Compose([
        ToTensor(),
        Normalize((0.1307,), (0.3081,)),
        Lambda(lambda x: torch.flatten(x))])

    train_loader = DataLoader(
        MNIST('./data/', train=True,
              download=True,
              transform=transform),
        batch_size=train_batch_size, shuffle=True)

    test_loader = DataLoader(
        MNIST('./data/', train=False,
              download=True,
              transform=transform),
        batch_size=test_batch_size, shuffle=False)

    return train_loader, test_loader

## Forward-Forward network

In [11]:
class ff_layer(nn.Linear):
  def __init__(
                self,
                in_features, 
                out_features, 
                bias=True, 
                device=None, 
                dtype=None,
                eps=1e-7,
                threshold=3.,
                standard=True
              ):
  
    super(ff_layer, self).__init__(in_features, out_features, bias, device, dtype)

    self.relu = nn.ReLU()
    self.optimizer = torch.optim.Adam(self.parameters())
    if standard == False:
      self.criterion = nn.BCEWithLogitsLoss()
    self.eps = eps
    self.threshold = threshold
    self.standard = standard

  def forward(self, x):
    x_norm = x / (torch.norm(x, dim=1, keepdim=True) + self.eps)
    x_linear = super(ff_layer, self).forward(x_norm)
    x_relu = self.relu(x_linear)
    
    return x_relu

  def train_forward(self, x_pos, x_neg):
    for i in range(1000): 
      x_pos_relu, x_neg_relu = self.forward(x_pos), self.forward(x_neg)

      g_pos = x_pos_relu.pow(2).mean(1) 
      g_neg = x_neg_relu.pow(2).mean(1)

      if self.standard==True:
        loss = torch.log(1 + torch.exp(torch.cat([
                    -g_pos + self.threshold,
                    g_neg - self.threshold]))).mean()
      else:
        x = torch.cat((g_pos - self.threshold, g_neg + self.threshold))
        y = torch.cat((torch.ones(g_pos.shape[0]), torch.zeros(g_neg.shape[0]))).to(device)
        loss = self.criterion(x, y)
      
      self.optimizer.zero_grad()
      loss.backward()
      self.optimizer.step()

    return self.forward(x_pos).detach(), self.forward(x_neg).detach()


In [12]:
class MLP(nn.Module):
  def __init__(
                self,
                layers_size, 
                bias=True, 
                device=None, 
                dtype=None,
                eps=1e-7,
                threshold=3.
            ):
  
    super(MLP, self).__init__()
    self.layers = nn.ModuleList([
        ff_layer(layers_size[i], layers_size[i+1], threshold=threshold, eps=eps) for i in range(len(layers_size)-1)
    ])

  def forward(self, x_pos, x_neg):
    for i, layer in enumerate(self.layers):
      x_pos, x_neg = layer.train_forward(x_pos, x_neg)

  
  def predict(self, x, y):
    x_pos, x_neg = create_positive_negative(x, y)

    for i, layer in enumerate(self.layers):
      x_pos, x_neg = layer(x_pos), layer(x_neg)

      if i == 0:
        pos_goodness = x_pos.pow(2).mean(1).unsqueeze(-1)
        neg_goodness = x_neg.pow(2).mean(1).unsqueeze(-1)
        continue
  
      pos_goodness = torch.cat([pos_goodness, x_pos.pow(2).mean(1).unsqueeze(-1)], 1)
      neg_goodness = torch.cat([neg_goodness, x_neg.pow(2).mean(1).unsqueeze(-1)], 1)
    
    return pos_goodness[:, 1:].max() > neg_goodness[:, 1:].max()


## Functions to generate positives and negatives

In [13]:
def insert_y_on_x(x, y, n_labels=10):
    x_ = x.clone()
    x_[:, :n_labels] *= 0.0
    x_[range(x.shape[0]), y] = x.max()
    return x_

In [14]:
def create_positive_negative(x, y, n_labels=10):
  x_pos = insert_y_on_x(x, y)
  y_neg = list(range(10))
  y_neg.remove(y)
  x_neg = insert_y_on_x(x.repeat(n_labels-1, 1), y_neg)

  return x_pos, x_neg

## Load data and network

In [17]:
train_loader, test_loader = MNIST_loaders()
torch.manual_seed(1)
network = MLP([784, 128, 128]).to(device)

## Train model

In [18]:
network.train()
for images, y in tqdm(train_loader):
  images = images.to(device)
  y = y.to(device)
  y_neg = torch.randint(0, 10, (images.shape[0],)).to(device)

  for i in range(len(y_neg)):
    if y_neg[i] == y[i]:
      y_neg[i] = (y[i] + 1)%10


  images_pos = insert_y_on_x(images, y)
  images_neg = insert_y_on_x(images, y_neg)

  
  network(images_pos, images_neg)

100%|██████████| 118/118 [06:27<00:00,  3.28s/it]


## Compute accuracy

In [19]:
count_all = 0
count_true = 0
network.eval()
for images, y in tqdm(test_loader):

  images = images.to(device)
  count_all+=1
  if network.predict(images, y):
    count_true+=1

print(f"\nAccuracy: {count_true/count_all}")

100%|██████████| 10000/10000 [00:18<00:00, 527.53it/s]

Accuracy: 0.9



