# The Forward-Forward Algorithm

Original paper: https://www.cs.toronto.edu/~hinton/FFA13.pdf

![Backpropagation vs. Forward-Forward](./media/backprop_vs_ff.png)

In [1]:
# Enable IPython's autoreload extension in mode 2: imported modules are
# reloaded before each cell runs, so edits to them are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

# !pip install utils

In [2]:
from pathlib import Path
from tqdm.notebook import tqdm

import torch
from dataset_utils import MNISTLoader, TrainingDatasetFF, CIFAR10Loader, EMNISTLoader, CIFAR100Loader
from models import MultiLayerPerceptron, FFCNN, BPCNN
from tools import base_loss, generate_positive_negative_samples_overlay
from torchvision.transforms import Compose, ToTensor, Lambda, Normalize
import matplotlib.pyplot as plt
import numpy as np

In [3]:
## -- Global configuration
# Seed PyTorch's RNG so runs are repeatable.
torch.manual_seed(0)

# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Directory that the data cell creates as the download folder.
PATH_DOWNLOAD = './tmp'

# Mini-batch sizes handed to the train/test loaders.
train_batch_size, test_batch_size = 256, 256

# Function used to generate positive/negative examples.
pos_gen_fn = generate_positive_negative_samples_overlay

# 1.0 Import Data

In [4]:
# `Path.mkdir` returns None, so keep the Path object itself in
# `download_folder` and create the directory in a separate statement.
download_folder = Path(PATH_DOWNLOAD)
download_folder.mkdir(parents=True, exist_ok=True)

# Convert images to tensors, then normalise the single channel with
# mean 0.1307 and standard deviation 0.3081.
transform = Compose([
    ToTensor(),
    Normalize((0.1307,), (0.3081,)),
])

# The same transform is applied to the training and the test split.
data_loader = MNISTLoader(train_transform=transform,
                          test_transform=transform)

hidden_dimensions = [1, 16, 32, 400, 140, 84]  # first is input size
kernel_size = 5

In [5]:
# Fetch the dataset first, then build the batched train and test loaders.
data_loader.download_dataset()
train_loader, test_loader = (
    data_loader.get_train_loader(train_batch_size),
    data_loader.get_test_loader(test_batch_size),
)


In [6]:
# Preparing the whole training dataset takes about 10 s.
# Each (X, Y) batch is moved to `device` and handed to `pos_gen_fn`, which is
# asked for both positive and negative samples (only_positive=False).
pos_neg_batches = (
    pos_gen_fn(X.to(device), Y.to(device), num_classes=10, only_positive=False)
    for X, Y in train_loader
)
ff_dataset = TrainingDatasetFF(pos_neg_batches)

# Re-batch the generated samples with the original batch size, shuffled.
train_loader_ff = torch.utils.data.DataLoader(
    ff_dataset,
    batch_size=train_loader.batch_size,
    shuffle=True,
)


# 2.0 Create Network

In [15]:
## -- Set some variables
# Activation module handed to the models; alternatives are kept below.
activation = torch.nn.ReLU()
#activation = torch.nn.Sigmoid()
#activation = torch.nn.LeakyReLU()
#activation = torch.nn.Softmax(dim=1)
#activation = torch.nn.LogSoftmax(dim=1)

# Optimizer class and learning rate passed to the Forward-Forward model.
layer_optim_learning_rate = 0.05
optimizer = torch.optim.Adam

threshold = 2.0
loss = base_loss
method = "MSE"

# The Forward-Forward model in this notebook is an FFCNN built from
# `hidden_dimensions` and `kernel_size`, so the back-propagation baseline in
# section 5 has to take the "CNN" branch (BPCNN). The previous value "MLP"
# selected MultiLayerPerceptron with those same convolution-style dimensions.
model_arch = "CNN"
num_classes = 10


In [None]:
# Build the Forward-Forward CNN from the settings above, then move it to `device`.
model = FFCNN(
    hidden_dimensions,
    activation,
    optimizer,
    layer_optim_learning_rate,
    threshold,
    loss,
    method,
    kernel_size,
    replace=True,
)
model = model.to(device)

In [9]:
def count_parameters(model):
    """Return the total number of elements in the trainable parameters of `model`.

    Only parameters whose `requires_grad` flag is set are counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Display the number of trainable parameters of the Forward-Forward model.
count_parameters(model)
# Disabled per-layer breakdown, kept for reference.
# NOTE(review): it refers to `mlp_model`, which this notebook never defines.
# for layer in mlp_model.layers:
#     # Count the parameters
#     num_params = count_parameters(layer)
#     print("Number of parameters:", num_params)

414880

# 3.0 Train Network

In [10]:
## -- Training schedule
# Number of epochs, and whether accuracy is reported on every 10th epoch.
n_epochs, print_every_10_epochs = 60, True

# Run exactly one of the training procedures below (3.1 or 3.2).

## 3.1 Train all layers at the same time

In [14]:
def _ff_accuracy(ff_model, loader, n_samples):
    """Return the fraction of `n_samples` that `ff_model` classifies correctly.

    Every batch from `loader` is moved to `device`, classified with
    `predict_accumulate_goodness`, and compared element-wise with its labels.
    """
    n_correct = 0
    for X, Y in tqdm(loader, total=len(loader)):
        X = X.to(device)
        Y = Y.to(device)
        predictions = ff_model.predict_accumulate_goodness(X, pos_gen_fn, n_class=num_classes)
        n_correct += predictions.eq(Y).sum()
    return n_correct / float(n_samples)


for epoch in tqdm(range(n_epochs)):
    # NOTE(review): the second item is presumably the negative samples built by
    # `pos_gen_fn` (it used to be called `Y_neg`) — confirm against TrainingDatasetFF.
    for X_pos, X_neg in train_loader_ff:
        layer_losses = model.train_batch(X_pos, X_neg, before=False)
        # Single status line, overwritten in place after every batch.
        print(", ".join('Layer {}: {}'.format(i, l) for i, l in enumerate(layer_losses)), end='\r')

    if epoch % 10 == 0 and print_every_10_epochs:
        # Terminate the '\r' status line first; otherwise "Epoch: N" is written
        # on top of it and the two texts end up interleaved.
        print()
        print("Epoch:", epoch)

        train_accuracy = _ff_accuracy(model, train_loader, len(train_loader.dataset))
        train_error = 1 - train_accuracy
        print("Overall Train Accuracy: {:.4%}".format(train_accuracy))
        print("Overall Train Error: {:.4%}".format(train_error))

        test_accuracy = _ff_accuracy(model, test_loader, len(data_loader.test_set))
        print(f"Accuracy: {test_accuracy:.4%}")
        print(f"Test error: {1 - test_accuracy:.4%}")
    

  0%|          | 0/60 [00:00<?, ?it/s]

Epoch: 0 1.389725923538208, Layer 1: 0, Layer 2: 9.000246047973633, Layer 3: 0, Layer 4: 1.3862942457199097, Layer 5: 1.3924598693847656, Layer 6: 1.38629674911499022


  0%|          | 0/235 [00:00<?, ?it/s]

Overall Train Accuracy: 10.2183%
Overall Train Error: 89.7817%


  0%|          | 0/40 [00:00<?, ?it/s]

Accuracy: 10.1000%
Test error: 89.9000%
Epoch: 101.3876748085021973, Layer 1: 0, Layer 2: 9.000246047973633, Layer 3: 0, Layer 4: 1.3862942457199097, Layer 5: 1.3862957954406738, Layer 6: 1.3862949609756478


  0%|          | 0/235 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Train accuracy of the Forward-Forward model, computed once after training.
acc = 0
for X_train, Y_train in tqdm(train_loader, total=len(train_loader)):
    X_train = X_train.to(device)
    Y_train = Y_train.to(device)

    # Use the same call as the per-epoch evaluation inside the training loop.
    # The previous spelling `predict_accomulate_goodness` does not match the
    # method name that cell calls.
    # NOTE(review): the old call also passed `method=method`; the model already
    # receives `method` at construction — confirm whether it is needed here.
    acc += (model.predict_accumulate_goodness(X_train, pos_gen_fn, n_class=num_classes).eq(Y_train).sum())

train_accuracy = acc / float(len(train_loader.dataset))
train_error = 1 - train_accuracy


print("Overall Train Accuracy: {:.4%}".format(train_accuracy))
print("Overall Train Error: {:.4%}".format(train_error))

## 3.2 Train one layer at a time

In [None]:
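# Disabled: progressive (one layer at a time) training.
# NOTE(review): `mlp_model` is not defined in this notebook; the Forward-Forward model is named `model`.
#mlp_model.train_batch_progressive(n_epochs, train_loader_ff)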

# 4.0 Test the Network

In [None]:
# Test accuracy of the Forward-Forward model.
acc = 0

for X_test, Y_test in tqdm(test_loader, total=len(test_loader)):
    X_test = X_test.to(device)
    Y_test = Y_test.to(device)

    # Use the same call as the per-epoch evaluation inside the training loop.
    # The previous spelling `predict_accomulate_goodness` does not match the
    # method name that cell calls.
    # NOTE(review): the old call also passed `method=method`; the model already
    # receives `method` at construction — confirm whether it is needed here.
    acc += (model.predict_accumulate_goodness(X_test,
            pos_gen_fn, n_class=num_classes).eq(Y_test).sum())

print(f"Accuracy: {acc/float(len(data_loader.test_set)):.4%}")
print(f"Test error: {1 - acc/float(len(data_loader.test_set)):.4%}")


# 5.0 Back Propagation

In [None]:
## -- Back-propagation baseline settings
# Disabled: per-dataset `hidden_dimensions` (first entry is the input size).
#   pick_dataset == "MNIST"   -> [784, 500, 500, 10]
#   pick_dataset == "CIFAR10" -> [3072, 500, 500, 10]
# Disabled: activation = torch.nn.ReLU()

# Classification loss and optimizer class for the baseline.
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam

# Number of training epochs.
n_epochs = 60



In [None]:
# Build the back-propagation baseline that matches `model_arch`.
if model_arch == "MLP":
    # The output layer size has to be part of the dimensions *before* the
    # model is built; appending it afterwards does not change the model.
    # A new list is used so `hidden_dimensions` is not mutated and re-running
    # this cell does not keep growing it.
    backprop_model = MultiLayerPerceptron(hidden_dimensions + [num_classes], activation).to(device)
elif model_arch == "CNN":
    backprop_model = BPCNN(hidden_dimensions, activation, num_classes, kernel_size).to(device)
else:
    raise ValueError(f"Unknown model_arch {model_arch!r}; expected 'MLP' or 'CNN'")

# `optimizer` holds the optimizer *class* when this cell first runs and is then
# rebound to an instance (the training cell uses that name). Recover the class
# when the cell is re-run so the second run does not try to call an instance.
optimizer_cls = optimizer if isinstance(optimizer, type) else type(optimizer)
optimizer = optimizer_cls(backprop_model.parameters())


In [None]:
def count_parameters(model):
    """Count the elements of every parameter of `model` that requires gradients."""
    trainable = [p for p in model.parameters() if p.requires_grad]
    return sum(map(lambda p: p.numel(), trainable))

# Display the number of trainable parameters of the back-propagation model.
count_parameters(backprop_model)
# Disabled per-layer breakdown, kept for reference.
# NOTE(review): it refers to `mlp_model`, which this notebook never defines.
# for layer in mlp_model.layers:
#     # Count the parameters
#     num_params = count_parameters(layer)
#     print("Number of parameters:", num_params)

In [None]:
def _bp_accuracy(bp_model, loader, n_samples):
    """Return the fraction of `n_samples` that `bp_model` classifies correctly.

    Runs under `torch.no_grad()` because evaluation never back-propagates, so
    no autograd graph needs to be built for these forward passes.
    """
    n_correct = 0
    with torch.no_grad():
        for X, Y in tqdm(loader, total=len(loader)):
            X = X.to(device)
            Y = Y.to(device)
            n_correct += torch.softmax(bp_model(X), 1).argmax(1).eq(Y).sum()
    return n_correct / float(n_samples)


for epoch in tqdm(range(n_epochs)):
    for X_train, Y_train in train_loader:
        X_train = X_train.to(device)
        Y_train = Y_train.to(device)

        Y_pred = backprop_model(X_train)

        # Deliberately not named `loss`: that global holds the loss function
        # passed to the Forward-Forward model and must not be overwritten.
        batch_loss = loss_fn(Y_pred, Y_train)

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        print(f"Loss: {batch_loss}", end='\r')

    if epoch % 10 == 0 and print_every_10_epochs:
        train_accuracy = _bp_accuracy(backprop_model, train_loader, len(data_loader.train_set))

        print("Epoch: ", epoch)
        print(f"Accuracy: {train_accuracy:.4%}")
        print(f"Train error: {1 - train_accuracy:.4%}")

        test_accuracy = _bp_accuracy(backprop_model, test_loader, len(data_loader.test_set))

        print(f"Accuracy: {test_accuracy:.4%}")
        print(f"Test error: {1 - test_accuracy:.4%}")



In [None]:
# Train accuracy of the back-propagation model.
acc = 0
for X_train, Y_train in tqdm(train_loader, total=len(train_loader)):
    X_train = X_train.to(device)
    Y_train = Y_train.to(device)

    # Predicted class = index of the largest softmax probability.
    acc += (torch.softmax(backprop_model(X_train), 1).argmax(1).eq(Y_train).sum())

print(f"Accuracy: {acc/float(len(data_loader.train_set)):.4%}")
# This figure is computed on the training set, so it is labelled "Train error"
# (it was previously printed as "Test error").
print(f"Train error: {1 - acc/float(len(data_loader.train_set)):.4%}")

In [None]:
# Test-set accuracy of the back-propagation model.
n_test = float(len(data_loader.test_set))
acc = 0
for X_test, Y_test in tqdm(test_loader, total=len(test_loader)):
    X_test, Y_test = X_test.to(device), Y_test.to(device)
    logits = backprop_model(X_test)
    # Predicted class = index of the largest softmax probability.
    predicted = torch.softmax(logits, 1).argmax(1)
    acc += predicted.eq(Y_test).sum()

test_accuracy = acc / n_test
print(f"Accuracy: {test_accuracy:.4%}")
print(f"Test error: {1 - test_accuracy:.4%}")