# The Forward-Forward Algorithm

Original paper: https://www.cs.toronto.edu/~hinton/FFA13.pdf

![Backpropagation vs. Forward-Forward](./media/backprop_vs_ff.png)

In [21]:
# Automatically reload imported modules (e.g. dataset_utils, models, tools)
# when their source files change, so edits are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

# NOTE(review): no import in this notebook uses a package named `utils`;
# this commented-out install looks stale.
# !pip install utils

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [22]:
from pathlib import Path
from tqdm.notebook import tqdm

import torch
from dataset_utils import MNISTLoader, TrainingDatasetFF, CIFAR10Loader, EMNISTLoader, CIFAR100Loader
from models import FFMultiLayerPerceptron, MultiLayerPerceptron, FFCNN, BPCNN
from tools import base_loss, generate_positive_negative_samples_overlay
from torchvision.transforms import Compose, ToTensor, Lambda, Normalize
import matplotlib.pyplot as plt
import numpy as np

In [23]:
## -- Global configuration
PATH_DOWNLOAD = './tmp'        # root folder that the dataset download step writes into
torch.manual_seed(0)           # seed torch's RNG for reproducible runs
train_batch_size = test_batch_size = 1024
# Function used to build positive/negative Forward-Forward samples from (X, Y) batches.
pos_gen_fn = generate_positive_negative_samples_overlay

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# 1.0 Import Data

In [24]:
# Create the download folder. `Path.mkdir` returns None, so keep the Path object itself.
download_folder = Path(PATH_DOWNLOAD)
download_folder.mkdir(parents=True, exist_ok=True)
pick_dataset = "MNIST"  # one of: "MNIST", "CIFAR10", "CIFAR100", "EMNIST"

# Per-dataset settings:
# (loader class, normalization mean, normalization std,
#  layer sizes with the flattened input size first, number of classes)
DATASET_CONFIGS = {
    "MNIST": (MNISTLoader, (0.1307,), (0.3081,), [784, 500, 500, 500], 10),
    "CIFAR10": (CIFAR10Loader, (0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261),
                [3072, 500, 500, 500], 10),
    "CIFAR100": (CIFAR100Loader, (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761),
                 [3072, 500, 500, 500], 100),
    # NOTE(review): EMNIST uses one hidden layer fewer than the others -- confirm intentional.
    "EMNIST": (EMNISTLoader, (0.5,), (0.5,), [784, 500, 500], 62),
}
if pick_dataset not in DATASET_CONFIGS:
    raise ValueError(f"Unknown dataset {pick_dataset!r}; choose one of {sorted(DATASET_CONFIGS)}")

loader_cls, norm_mean, norm_std, hidden_dimensions, num_classes = DATASET_CONFIGS[pick_dataset]
hidden_dimensions = list(hidden_dimensions)  # fresh copy so later edits cannot alter the table
kernel_size = 3

# Normalize each channel, then flatten every image into a 1-D vector for the MLPs.
transform = Compose([
    ToTensor(),
    Normalize(norm_mean, norm_std),
    Lambda(lambda x: torch.flatten(x))])

data_loader = loader_cls(train_transform=transform, test_transform=transform)
data_loader.download_dataset()
train_loader = data_loader.get_train_loader(train_batch_size)
test_loader = data_loader.get_test_loader(test_batch_size)



Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./tmp/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 102174720.68it/s]


Extracting ./tmp/MNIST/raw/train-images-idx3-ubyte.gz to ./tmp/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./tmp/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 69220396.47it/s]

Extracting ./tmp/MNIST/raw/train-labels-idx1-ubyte.gz to ./tmp/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./tmp/MNIST/raw/t10k-images-idx3-ubyte.gz



100%|██████████| 1648877/1648877 [00:00<00:00, 27867331.51it/s]


Extracting ./tmp/MNIST/raw/t10k-images-idx3-ubyte.gz to ./tmp/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./tmp/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 9119448.91it/s]


Extracting ./tmp/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./tmp/MNIST/raw



In [25]:
# Pre-compute positive/negative Forward-Forward samples for every training batch
# (roughly 10 s). Each batch is moved to `device` before sample generation.
ff_training_samples = (
    pos_gen_fn(X.to(device), Y.to(device), False, num_classes)
    for X, Y in train_loader
)
train_loader_ff = torch.utils.data.DataLoader(
    TrainingDatasetFF(ff_training_samples),
    batch_size=train_loader.batch_size,
    shuffle=True,
)


KeyboardInterrupt: ignored

# 2.0 Create Network

In [None]:
## -- Forward-Forward model settings
# Other activations tried: Sigmoid(), LeakyReLU(), Softmax(dim=1), LogSoftmax(dim=1).
activation = torch.nn.ReLU()
layer_optim_learning_rate = 0.09  # passed to the model as the layer-optimizer learning rate
optimizer = torch.optim.Adam      # optimizer class (not an instance), given to the model constructor
threshold = 9.0                   # presumably the FF goodness threshold -- see models.py
loss = base_loss
method = "MSE"
model_arch = "MLP"                # "MLP" or "CNN"



In [None]:
# Instantiate the Forward-Forward model selected by `model_arch`.
if model_arch == "MLP":
    model = FFMultiLayerPerceptron(hidden_dimensions, activation, optimizer,
                                   layer_optim_learning_rate, threshold,
                                   loss, method).to(device)
elif model_arch == "CNN":
    model = FFCNN(hidden_dimensions, activation, optimizer,
                  layer_optim_learning_rate, threshold, loss,
                  method, num_classes, kernel_size).to(device)
else:
    # Fail here rather than with a confusing NameError on `model` in later cells.
    raise ValueError(f"Unknown model_arch {model_arch!r}; expected 'MLP' or 'CNN'")

In [None]:
def count_parameters(model, trainable_only=True):
    """Return the number of scalar parameters in a torch module.

    Args:
        model: any ``torch.nn.Module``.
        trainable_only: if True (default), count only parameters with
            ``requires_grad=True``; if False, count every parameter.

    Returns:
        int: total number of parameter elements.
    """
    return sum(p.numel() for p in model.parameters()
               if p.requires_grad or not trainable_only)

# Number of trainable parameters in the Forward-Forward model.
count_parameters(model)

In [None]:
def deepLIFT(model, x, baseline):
    """Gradient-based importance scores for input ``x`` relative to ``baseline``.

    Runs ``model`` on ``x`` and on ``baseline`` and back-propagates the output
    difference ``model(x) - model(baseline)`` (summed over all elements) to ``x``.
    Because ``baseline`` does not depend on ``x``, the result equals the
    gradient of ``model(x).sum()`` w.r.t. ``x``.
    NOTE(review): full DeepLIFT-style attributions would additionally scale
    this by ``x - baseline``; confirm which quantity is wanted.

    Args:
        model: a ``torch.nn.Module`` whose forward is differentiable.
        x: input tensor. It is copied, so the caller's tensor is not modified.
        baseline: reference input accepted by ``model``.

    Returns:
        A tensor shaped like ``x`` holding the gradient described above.

    Side effects:
        Zeroes the model's parameter gradients, then accumulates new ones.
        The model's train/eval mode is restored before returning.
    """
    was_training = model.training
    model.eval()

    # Use a detached leaf copy that tracks gradients. A plain input tensor has
    # requires_grad=False, in which case `.grad` would stay None.
    x = x.detach().clone().requires_grad_(True)
    try:
        output_actual = model(x)
        output_baseline = model(baseline)
        delta = output_actual - output_baseline

        model.zero_grad()
        # Equivalent to back-propagating delta.sum().
        delta.backward(torch.ones_like(delta))
    finally:
        model.train(was_training)

    return x.grad

def visualize_importance_scores(image, importance_scores):
    """Show ``image`` beside a copy overlaid with an importance heat-map.

    Args:
        image: 2-D array-like accepted by ``imshow``, drawn in grayscale.
        importance_scores: array-like drawn with the 'hot' colormap at
            60% opacity on top of the right-hand copy of ``image``.
    """
    fig, (ax_original, ax_overlay) = plt.subplots(1, 2, figsize=(8, 4))

    # Left panel: the untouched input.
    ax_original.imshow(image, cmap='gray')
    ax_original.set_title('Original Image')

    # Right panel: same input with the heat-map blended over it.
    ax_overlay.imshow(image, cmap='gray')
    ax_overlay.imshow(importance_scores, cmap='hot', alpha=0.6)
    ax_overlay.set_title('Importance Scores Overlay')

    for panel in (ax_original, ax_overlay):
        panel.axis('off')

    plt.tight_layout()
    plt.show()

def normalize_scores(scores, baseline=None, eps=1e-12):
    """Center importance scores and rescale them to unit standard deviation.

    Args:
        scores: tensor of importance scores (e.g. the output of ``deepLIFT``).
        baseline: value subtracted before rescaling. Defaults to the mean of
            ``scores``; pass e.g. a per-pixel dataset-mean tensor to center on it.
        eps: if the (unbiased) std of the centered scores is at or below this,
            the centered scores are returned without rescaling.

    Returns:
        The normalized tensor (shaped like ``scores`` broadcast with ``baseline``).
    """
    center = scores.mean() if baseline is None else baseline
    centered_scores = scores - center
    # A single element has no unbiased spread; avoid dividing by a NaN std.
    if centered_scores.numel() < 2:
        return centered_scores
    std = centered_scores.std()
    if std <= eps:
        return centered_scores
    return centered_scores / std

# 3.0 Train Network

In [None]:
## -- Training-loop settings
# print_every_10_epochs: when True, report train/test accuracy on epochs 0, 10, 20, ...
n_epochs, print_every_10_epochs = 60, True

Choose **one** of the following training procedures: 3.1 (all layers at the same time) or 3.2 (one layer at a time).

## 3.1 Train all layers at the same time

In [None]:
def ff_accuracy(loader):
    """Return the fraction of samples in ``loader`` classified correctly by ``model``.

    Predictions come from ``model.predict_accomulate_goodness``; evaluation
    runs under ``torch.no_grad`` because no gradients are needed.
    """
    correct = 0
    with torch.no_grad():
        for X_eval, Y_eval in tqdm(loader, total=len(loader)):
            X_eval = X_eval.to(device)
            Y_eval = Y_eval.to(device)
            predictions = model.predict_accomulate_goodness(
                X_eval, pos_gen_fn, n_class=num_classes, method=method)
            correct += predictions.eq(Y_eval).sum().item()
    return correct / float(len(loader.dataset))


for epoch in tqdm(range(n_epochs)):
    # NOTE(review): second item presumably holds the negative samples built by
    # `pos_gen_fn` -- confirm against TrainingDatasetFF.
    for X_pos, X_neg in train_loader_ff:
        layer_losses = model.train_batch(X_pos, X_neg, before=False, method=method)
        print(", ".join(f"Layer {i}: {l}" for i, l in enumerate(layer_losses)), end='\r')

    if print_every_10_epochs and epoch % 10 == 0:
        print("Epoch:", epoch)

        train_accuracy = ff_accuracy(train_loader)
        train_error = 1 - train_accuracy
        print("Overall Train Accuracy: {:.4%}".format(train_accuracy))
        print("Overall Train Error: {:.4%}".format(train_error))

        test_accuracy = ff_accuracy(test_loader)
        print(f"Test Accuracy: {test_accuracy:.4%}")
        print(f"Test error: {1 - test_accuracy:.4%}")
    

In [None]:
# Final FF train-set accuracy, measured once training has finished.
acc = 0
with torch.no_grad():  # evaluation only -- no autograd graph needed
    for X_train, Y_train in tqdm(train_loader, total=len(train_loader)):
        X_train = X_train.to(device)
        Y_train = Y_train.to(device)
        predictions = model.predict_accomulate_goodness(
            X_train, pos_gen_fn, n_class=num_classes, method=method)
        acc += predictions.eq(Y_train).sum().item()

train_accuracy = acc / float(len(train_loader.dataset))
train_error = 1 - train_accuracy

print("Overall Train Accuracy: {:.4%}".format(train_accuracy))
print("Overall Train Error: {:.4%}".format(train_error))

## 3.2 Train one layer at a time

In [None]:
# Alternative to 3.1 (run only one of them): train the FF layers one at a time.
# NOTE(review): the FF model in this notebook is bound to `model`; confirm its class
# provides `train_batch_progressive` before uncommenting.
#model.train_batch_progressive(n_epochs, train_loader_ff)

# 4.0 Test the Network

In [None]:
# FF test-set accuracy (evaluation only, so no autograd graph is built).
acc = 0
with torch.no_grad():
    for X_test, Y_test in tqdm(test_loader, total=len(test_loader)):
        X_test = X_test.to(device)
        Y_test = Y_test.to(device)
        predictions = model.predict_accomulate_goodness(
            X_test, pos_gen_fn, n_class=num_classes, method=method)
        acc += predictions.eq(Y_test).sum().item()

test_accuracy = acc / float(len(data_loader.test_set))
print(f"Accuracy: {test_accuracy:.4%}")
print(f"Test error: {1 - test_accuracy:.4%}")


# 5.0 Backpropagation Baseline

In [None]:
## -- Backpropagation baseline settings
# The baseline reuses `hidden_dimensions` (section 1) and `activation` (section 2).
n_epochs = 60
optimizer = torch.optim.Adam           # optimizer class; instantiated once the model exists
loss_fn = torch.nn.CrossEntropyLoss()



In [None]:
# Build the backprop baseline matching the FF architecture family.
if model_arch == "MLP":
    # Pass a fresh list whose last entry is the class count; `hidden_dimensions`
    # itself is left unmodified so re-running this cell is safe.
    # NOTE(review): assumes MultiLayerPerceptron uses the last entry as its
    # output width -- confirm in models.py.
    backprop_model = MultiLayerPerceptron(hidden_dimensions + [num_classes], activation).to(device)
elif model_arch == "CNN":
    backprop_model = BPCNN(hidden_dimensions, activation, num_classes, kernel_size).to(device)
else:
    raise ValueError(f"Unknown model_arch {model_arch!r}; expected 'MLP' or 'CNN'")

# `optimizer` starts as the optimizer class from the settings cell. After this cell
# has run once it is an instance, so recover its class to keep the cell re-runnable.
optimizer_cls = optimizer if isinstance(optimizer, type) else type(optimizer)
optimizer = optimizer_cls(backprop_model.parameters())


In [None]:
# `count_parameters` is already defined in section 2; reuse it rather than redefining it.
count_parameters(backprop_model)

In [None]:
def bp_accuracy(loader):
    """Return the fraction of samples in ``loader`` that ``backprop_model`` classifies correctly.

    Runs in eval mode without autograd; the model's previous mode is restored.
    """
    was_training = backprop_model.training
    backprop_model.eval()
    correct = 0
    with torch.no_grad():
        for X_eval, Y_eval in tqdm(loader, total=len(loader)):
            X_eval = X_eval.to(device)
            Y_eval = Y_eval.to(device)
            # argmax over the logits picks the same class as argmax over their softmax.
            correct += backprop_model(X_eval).argmax(1).eq(Y_eval).sum().item()
    backprop_model.train(was_training)
    return correct / float(len(loader.dataset))


for epoch in tqdm(range(n_epochs)):
    backprop_model.train()
    for X_train, Y_train in train_loader:
        X_train = X_train.to(device)
        Y_train = Y_train.to(device)

        # Named `batch_loss`, not `loss`: `loss` holds the FF loss function from section 2.
        batch_loss = loss_fn(backprop_model(X_train), Y_train)

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        print(f"Loss: {batch_loss.item():.4f}", end='\r')

    if print_every_10_epochs and epoch % 10 == 0:
        bp_train_acc = bp_accuracy(train_loader)
        print("Epoch: ", epoch)
        print(f"Train accuracy: {bp_train_acc:.4%}")
        print(f"Train error: {1 - bp_train_acc:.4%}")

        bp_test_acc = bp_accuracy(test_loader)
        print(f"Test accuracy: {bp_test_acc:.4%}")
        print(f"Test error: {1 - bp_test_acc:.4%}")



In [None]:
# Final train-set accuracy of the backprop baseline (evaluation only).
was_training = backprop_model.training
backprop_model.eval()
acc = 0
with torch.no_grad():
    for X_train, Y_train in tqdm(train_loader, total=len(train_loader)):
        X_train = X_train.to(device)
        Y_train = Y_train.to(device)
        # argmax over the logits picks the same class as argmax over their softmax.
        acc += backprop_model(X_train).argmax(1).eq(Y_train).sum().item()
backprop_model.train(was_training)

bp_train_accuracy = acc / float(len(data_loader.train_set))
print(f"Train accuracy: {bp_train_accuracy:.4%}")
print(f"Train error: {1 - bp_train_accuracy:.4%}")

In [None]:
# Final test-set accuracy of the backprop baseline (evaluation only).
was_training = backprop_model.training
backprop_model.eval()
acc = 0
with torch.no_grad():
    for X_test, Y_test in tqdm(test_loader, total=len(test_loader)):
        X_test = X_test.to(device)
        Y_test = Y_test.to(device)
        # argmax over the logits picks the same class as argmax over their softmax.
        acc += backprop_model(X_test).argmax(1).eq(Y_test).sum().item()
backprop_model.train(was_training)

bp_test_accuracy = acc / float(len(data_loader.test_set))
print(f"Accuracy: {bp_test_accuracy:.4%}")
print(f"Test error: {1 - bp_test_accuracy:.4%}")