# The Forward-Forward Algorithm

Original paper: https://www.cs.toronto.edu/~hinton/FFA13.pdf

![Backpropagation vs. Forward-Forward](./media/backprop_vs_ff.png)

In [1]:
# Enable IPython's autoreload extension in mode 2: imported modules are reloaded
# before each cell executes, so edits to the local `utils` package are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from pathlib import Path
from tqdm.notebook import tqdm

import torch
from utils.dataset_utils import MNISTLoader, TrainingDatasetFF
from utils.models import FFMultiLayerPerceptron, MultiLayerPerceptron
from utils.tools import base_loss, generate_positive_negative_samples_overlay
from torchvision.transforms import Compose, ToTensor, Lambda, Normalize

In [3]:
## -- Global settings shared by the rest of the notebook
PATH_DOWNLOAD = './tmp'

# Mini-batch sizes for the train and test data loaders.
train_batch_size = 1024
test_batch_size = 1024

# Function used to generate the positive / negative examples.
pos_gen_fn = generate_positive_negative_samples_overlay

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# 1.0 Import Data

In [4]:
# `Path.mkdir` returns None, so keep the Path object itself and create the
# directory in a separate step; chaining the two would leave `download_folder`
# bound to None.
download_folder = Path(PATH_DOWNLOAD)
download_folder.mkdir(parents=True, exist_ok=True)

# Per-image preprocessing: convert to a tensor, normalise with mean 0.1307 and
# std 0.3081, then flatten into a 1-D vector.
transform = Compose([
    ToTensor(),
    Normalize((0.1307,), (0.3081,)),
    # `torch.flatten` is already a one-argument callable; no lambda wrapper needed.
    Lambda(torch.flatten),
])

# The same transform is applied to both the training and the test split.
mnist_loader = MNISTLoader(train_transform=transform,
                           test_transform=transform)

mnist_loader.download_dataset()


In [5]:
# Batched loaders over the MNIST splits. Batches are unpacked as (X, Y) pairs
# further down: `train_loader` feeds both the Forward-Forward sample generation
# and the backprop baseline, `test_loader` is used by both evaluation cells.
train_loader = mnist_loader.get_train_loader(train_batch_size)
test_loader = mnist_loader.get_test_loader(test_batch_size)

In [6]:
# Preparing the whole Forward-Forward training set takes roughly 10 s.
# Each (X, Y) batch of `train_loader` is moved to `device` and passed to
# `pos_gen_fn`; the results are streamed lazily into `TrainingDatasetFF`.
ff_sample_batches = (
    pos_gen_fn(X.to(device), Y.to(device), False)
    for X, Y in train_loader
)

# Re-batch the generated samples with the same batch size, shuffling each epoch.
train_loader_ff = torch.utils.data.DataLoader(
    TrainingDatasetFF(ff_sample_batches),
    batch_size=train_loader.batch_size,
    shuffle=True,
)


# 2.0 Create Network

In [7]:
## -- Set some variables
# Hyper-parameters of the Forward-Forward network; all of them are passed
# positionally to `FFMultiLayerPerceptron` in the next cell.
hidden_dimensions = [784, 512, 512] # first is input size
activation = torch.nn.ReLU()
layer_optim_learning_rate = 0.09
optimizer = torch.optim.Adam  # the optimizer class itself, not an instance
threshold = 9.0  # presumably the goodness threshold of the FF objective -- confirm in utils.models
loss = base_loss  # loss function imported from utils.tools



In [8]:
# Build the Forward-Forward network from the settings defined above
# (argument order matters: they are passed positionally) and move it to `device`.
ff_model_args = (
    hidden_dimensions,
    activation,
    optimizer,
    layer_optim_learning_rate,
    threshold,
    loss,
)
mlp_model = FFMultiLayerPerceptron(*ff_model_args).to(device)

# 3.0 Train Network

In [9]:
## -- Set some variables
n_epochs = 60  # epoch count used by both training procedures below

# choose one of the following training procedures.
# NOTE(review): sections 3.1 and 3.2 both train `mlp_model`; a plain "Run All"
# executes them one after the other. Run only the cell you want.

## 3.1 Train all layers at the same time

In [None]:
# Joint training: every batch of the Forward-Forward loader is handed to
# `train_batch`, whose return value is reported as one entry per layer.
for epoch in tqdm(range(n_epochs)):
    for pos_batch, neg_batch in train_loader_ff:
        layer_losses = mlp_model.train_batch(pos_batch, neg_batch, before=False)
        # Overwrite the same console line with the latest per-layer values.
        progress = ", ".join(
            'Layer {}: {}'.format(idx, value)
            for idx, value in enumerate(layer_losses)
        )
        print(progress, end='\r')

## 3.2 Train one layer at a time

In [12]:
# Alternative to section 3.1: a single call that receives the epoch count and
# the Forward-Forward data loader. The recorded output below reports a loss per
# layer at the final epoch.
mlp_model.train_batch_progressive(n_epochs, train_loader_ff)

  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: 60/60, Layer 0: 0.26429873704910287

Epoch: 60/60, Layer 1: 0.22680601477622986



# 4.0 Test the Network

In [33]:
# Count the correct predictions of the Forward-Forward model on the test split.
# Nothing is back-propagated here, so autograd tracking is switched off to save
# memory and time.
# NOTE(review): assumes `predict_accomulate_goodness` does not rely on autograd
# internally -- confirm in utils.models.
acc = 0

with torch.no_grad():
    for X_test, Y_test in tqdm(test_loader, total=len(test_loader)):
        X_test = X_test.to(device)
        Y_test = Y_test.to(device)

        predictions = mlp_model.predict_accomulate_goodness(X_test,
                                                            pos_gen_fn, n_class=10)
        # `.item()` turns the 0-dim tensor into a Python int, so `acc` stays a
        # plain integer instead of a tensor living on `device`.
        acc += predictions.eq(Y_test).sum().item()

# Compute the fraction once and derive both reported figures from it.
accuracy = acc / len(mnist_loader.test_set)
print(f"Accuracy: {accuracy:.4%}")
print(f"Test error: {1 - accuracy:.4%}")


  0%|          | 0/10 [00:00<?, ?it/s]

Accuracy: 97.0800%
Test error: 2.9200%


# 5.0 Back Propagation

In [None]:
## -- Set some variables
# NOTE(review): `n_epochs`, `hidden_dimensions`, `activation` and `optimizer`
# reuse the names of the Forward-Forward settings defined earlier in the
# notebook, so those earlier values are overwritten once this cell runs.
n_epochs= 20
hidden_dimensions = [784, 512, 512, 10]  # last entry is the output size -- presumably one logit per digit class
activation = torch.nn.ReLU()
optimizer = torch.optim.Adam  # the optimizer class itself; it is instantiated in the next cell
loss_fn = torch.nn.CrossEntropyLoss()



In [None]:
# Baseline network trained with ordinary back-propagation.
mlp_backprop_model = MultiLayerPerceptron(hidden_dimensions, activation).to(device)

# `optimizer` holds the optimizer *class* the first time this cell runs and is
# rebound to an instance below. Resolving the class explicitly keeps the cell
# re-runnable: otherwise a second run would try to call the existing optimizer
# instance and fail with "'Adam' object is not callable".
optimizer_cls = optimizer if isinstance(optimizer, type) else type(optimizer)
optimizer = optimizer_cls(mlp_backprop_model.parameters())


In [None]:
# Standard supervised training loop for the back-propagation baseline.
for epoch in tqdm(range(n_epochs)):
    for X_train, Y_train in train_loader:
        X_train = X_train.to(device)
        Y_train = Y_train.to(device)

        Y_pred = mlp_backprop_model(X_train)

        # Deliberately not called `loss`: that notebook-level name holds the
        # Forward-Forward loss function (`base_loss`). Overwriting it with a
        # tensor would break re-creating the FF model after this cell has run.
        batch_loss = loss_fn(Y_pred, Y_train)

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        # Overwrite the same console line with the latest batch loss.
        print(f"Loss: {batch_loss.item()}", end='\r')

In [None]:
# Count the correct predictions of the back-propagation baseline on the test split.
# NOTE(review): if `MultiLayerPerceptron` contains dropout or batch-norm layers,
# it should also be switched to eval mode here -- confirm in utils.models.
acc = 0

# Inference only: disable autograd so no computation graph is built per batch.
with torch.no_grad():
    for X_test, Y_test in tqdm(test_loader, total=len(test_loader)):
        X_test = X_test.to(device)
        Y_test = Y_test.to(device)

        logits = mlp_backprop_model(X_test)
        # softmax is monotonic, so the argmax over the raw logits selects the
        # same class; `.item()` keeps `acc` a plain Python int.
        acc += logits.argmax(1).eq(Y_test).sum().item()

# Compute the fraction once and derive both reported figures from it.
accuracy = acc / len(mnist_loader.test_set)
print(f"Accuracy: {accuracy:.4%}")
print(f"Test error: {1 - accuracy:.4%}")