# The Forward-Forward Algorithm: Some Preliminary Investigations


> **The aim of this repository is to implement the Forward-Forward training algorithm (called FFX in the code) using PyTorch.**


## Abstract

The aim of this paper is to introduce a new learning procedure for neural networks and to demonstrate that it works well enough on a few small problems to be worth
serious investigation. The Forward-Forward algorithm replaces the forward and backward passes of backpropagation by two forward passes, one with positive
(i.e. real) data and the other with negative data which could be generated by the network itself. Each layer has its own objective function which is simply to have
high goodness for positive data and low goodness for negative data. The sum of the squared activities in a layer can be used as the goodness but there are many other
possibilities, including minus the sum of the squared activities. If the positive and negative passes can be separated in time, the negative passes can be done offline,
which makes the learning much simpler in the positive pass and allows video to be pipelined through the network without ever storing activities or stopping to
propagate derivatives.

## Implementation

### Import Libraries


In [1]:
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

In [2]:
import torch
import torch.nn as nn
from torch.optim import Adam, AdamW
from torch.utils.data import DataLoader

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Normalize, Lambda

### Data Preprocessing

In [68]:
class MNIST_Loader:
    """Build MNIST DataLoaders and the label-overlaid inputs used for Forward-Forward training.

    Images are converted to tensors, normalised with the MNIST mean/std and
    flattened to 1-D. ``overlay`` writes a one-hot label into the first 10
    features of each flattened image.
    """

    def __init__(self, batch_size_train=50000, batch_size_test=10000, batch_size_eval=10000, path='./data/'):
        """Store batch sizes and the dataset path; loaders are created by ``data_load``.

        Args:
            batch_size_train: batch size of the (shuffled) training loader.
            batch_size_test: batch size of the test loader.
            batch_size_eval: batch size of the non-shuffled evaluation loader.
            path: directory the MNIST files are read from / downloaded to.
        """
        self.batch_size_train = batch_size_train
        self.batch_size_test = batch_size_test
        self.batch_size_eval = batch_size_eval

        self.path = path

        # Populated by data_load().
        self.dl_train = None
        self.dl_test = None
        self.dl_eval = None

        self.transform = Compose([
            ToTensor(),
            Normalize((0.1307,), (0.3081,)),
            Lambda(lambda x: torch.flatten(x))])

    def data_load(self, download=True, shuffle=True):
        """Create and store the train/test/eval DataLoaders for ``self.path``."""
        self.dl_train, self.dl_test, self.dl_eval = self.data_loader(dpath=self.path, download=download, shuffle=shuffle)

    def data_loader(self, download=True, shuffle=True, dpath='./data/'):
        """Return ``(dl_train, dl_test, dl_eval)`` DataLoaders over MNIST at ``dpath``.

        ``dl_test`` and ``dl_eval`` both iterate the MNIST test split; ``dl_eval``
        is never shuffled and uses ``batch_size_eval``.
        """
        train_set = MNIST(dpath, train=True, download=download, transform=self.transform)
        # One dataset object is enough for both test-split loaders.
        test_set = MNIST(dpath, train=False, download=download, transform=self.transform)

        dl_train = DataLoader(train_set, batch_size=self.batch_size_train, shuffle=shuffle)
        dl_test = DataLoader(test_set, batch_size=self.batch_size_test, shuffle=shuffle)
        dl_eval = DataLoader(test_set, batch_size=self.batch_size_eval, shuffle=False)
        return dl_train, dl_test, dl_eval

    def overlay(self, images, labels):
        """Return a copy of ``images`` with ``labels`` one-hot encoded in the first 10 features.

        Args:
            images: 2-D tensor, one flattened image per row (not modified).
            labels: an int (same label for every row) or an integer tensor with
                one label per row.

        Returns:
            A new tensor whose first 10 columns are zero except column ``label``,
            which is set to the maximum value of ``images``.
        """
        size = images.shape[0]
        data = images.clone()
        data[:, :10] = 0.0
        data[range(0, size), labels] = images.max()
        return data

    def data_positive(self, images, labels):
        """Overlay the true labels (positive data)."""
        return self.overlay(images, labels)

    def data_negative(self, images, labels):
        """Overlay a random *wrong* label on every image (negative data).

        Adding a random offset in 1..9 modulo 10 gives a label drawn uniformly
        from the nine classes other than the true one. The draw happens on
        ``labels``' own device, so this works on CPU and GPU alike.
        """
        offsets = torch.randint(1, 10, labels.shape, device=labels.device)
        neg_labels = (labels + offsets) % 10
        return self.overlay(images, neg_labels)

    @staticmethod
    def visualize(data, name='', idx=0):
        """Show row ``idx`` of ``data`` as a 28x28 grayscale image titled ``name``."""
        img = data[idx].cpu().reshape(28, 28)
        plt.figure(figsize=(4, 4))
        plt.title(name)
        plt.imshow(img, cmap="gray")
        plt.show()

### FFX Layer Model



In [69]:
class FFLayer(nn.Linear):
    """Fully-connected ReLU layer trained locally with the Forward-Forward objective.

    The layer owns its own AdamW optimiser. ``train(xpos, xneg)`` takes one
    optimisation step that pushes the goodness (mean squared activation) of
    positive inputs above ``threshold`` and that of negative inputs below it.
    """

    def __init__(self,
                 in_features,
                 out_features,
                 bias=True,
                 device=None,
                 dtype=None,
                 lr=0.02,
                 threshold=2.0):
        """Create the layer and its optimiser.

        Args:
            in_features, out_features, bias, dtype: as for ``nn.Linear``.
            device: device for the weights; defaults to CUDA when available,
                otherwise CPU.
            lr: AdamW learning rate.
            threshold: goodness value separating positive from negative data.
        """
        # Resolve the device *before* allocating the weights, so parameters,
        # inputs and outputs all live on the same device.
        if not device:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        super().__init__(in_features, out_features, bias, device, dtype)
        self.device = device

        self.relu = nn.ReLU()
        # Alternative activations kept as attributes; forward() only uses ReLU.
        self.sigm = nn.Sigmoid()
        self.tanh = nn.Tanh()
        self.lrelu = nn.LeakyReLU()
        self.rrelu = nn.RReLU()
        self.gelu = nn.GELU()
        self.opt = AdamW(self.parameters(), lr=lr)
        self.threshold = threshold

    def forward(self, x):
        """Length-normalise each row of ``x``, then apply the linear map and ReLU."""
        # Only the direction of the previous layer's activity is passed on, so a
        # layer cannot solve its task using the previous layer's goodness.
        # 1e-4 guards against division by zero.
        x_direction = x / (x.norm(2, 1, keepdim=True) + 1e-4)
        # nn.Linear.forward handles bias=None, unlike a manual mm + bias.
        return self.relu(super().forward(x_direction))

    def train(self, xpos=True, xneg=None):
        """Take one Forward-Forward step on a positive and a negative batch.

        When called with a single bool (as ``nn.Module.eval()`` and parent
        modules do), this behaves like ``nn.Module.train(mode)`` instead.

        Args:
            xpos: positive input batch (or a bool training-mode flag).
            xneg: negative input batch.

        Returns:
            ``(h_pos, h_neg, loss)``: the layer outputs for both batches after
            the update (no autograd history) and the detached pre-update loss.
        """
        if xneg is None:
            if not isinstance(xpos, bool):
                raise TypeError("FFLayer.train() needs both xpos and xneg tensors")
            return super().train(xpos)

        g_pos = self.forward(xpos).pow(2).mean(1)
        g_neg = self.forward(xneg).pow(2).mean(1)
        # Pushes pos (neg) goodness above (below) self.threshold.
        # softplus(z) == log(1 + exp(z)), but it does not overflow for large z.
        loss = nn.functional.softplus(
            torch.cat([-g_pos + self.threshold, g_neg - self.threshold])).mean()
        self.opt.zero_grad()
        # backward() only differentiates this layer's local loss; no gradient
        # flows into earlier layers (their outputs were detached).
        loss.backward()
        self.opt.step()
        with torch.no_grad():
            return self.forward(xpos), self.forward(xneg), loss.detach()

### Forward-Forward Training

In [72]:
class FFX(torch.nn.Module):
    """Multi-layer network trained with the Forward-Forward algorithm.

    Args:
        DataLoader: data source, e.g. an ``MNIST_Loader``. It must expose
            ``dl_train`` / ``dl_test`` / ``dl_eval`` and the methods
            ``data_positive``, ``data_negative`` and ``overlay``.
        dims: layer widths. ``dims[0]`` is the input size, and each consecutive
            pair ``(dims[i], dims[i + 1])`` becomes one ``FFLayer``.
        epochs: number of Forward-Forward steps taken on each training batch.
        device: device for layers and data; defaults to CUDA when available,
            otherwise CPU.

    There are two approaches for batch training:
    1. Iterate batches for all layers. ---> easy
    2. Iterate batches for each layer. ---> need to create new batches for next layer input
    ``train`` uses approach 1.
    """

    def __init__(self, DataLoader, dims, epochs=50, device=None):
        super().__init__()
        self.epochs = epochs
        self.DL = DataLoader
        if not device:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device
        # nn.ModuleList (rather than a plain list) registers the layers as
        # sub-modules, so .parameters() and .state_dict() include them.
        self.layers = nn.ModuleList(
            FFLayer(d_in, d_out, device=self.device)
            for d_in, d_out in zip(dims[:-1], dims[1:])
        )

    def train(self, mode=None):
        """Run Forward-Forward training over ``self.DL.dl_train`` (approach 1).

        For every mini-batch, the positive and negative inputs are built once.
        Then, for ``self.epochs`` iterations, each layer takes one optimiser step
        and passes its detached outputs on to the next layer.

        Args:
            mode: if a bool is given, behave like ``nn.Module.train(mode)``
                (only toggle the training flag; this is what ``.eval()`` calls).
                Leave as ``None`` to run Forward-Forward training.
        """
        if mode is not None:
            return super().train(mode)

        data_loader = self.DL.dl_train
        for batch_i, (x_batch, y_batch) in enumerate(data_loader):
            print("Training Batch (Size:", str(x_batch.size(dim=0)) + ')', '#', batch_i + 1, '/', len(data_loader))
            batch_pos = self.DL.data_positive(x_batch, y_batch).to(self.device)
            batch_neg = self.DL.data_negative(x_batch, y_batch).to(self.device)
            for _ in tqdm(range(self.epochs)):
                h_batch_pos, h_batch_neg = batch_pos, batch_neg
                for layer in self.layers:
                    h_batch_pos, h_batch_neg, _ = layer.train(h_batch_pos, h_batch_neg)

# NOTE(review): the two alternative training methods below call bare
# data_positive / data_negative, which exist only as MNIST_Loader methods;
# adapt them to self.DL.* before re-enabling.
#     def train_2(self, data_loader):
#         """
#         Train method 2: train all epochs for each layer for each batch.
#         """
#         for batch_i, (x_batch, y_batch) in enumerate(data_loader):
#             batch_loss = 0
#             print("Training Batch (Size:", str(x_batch.size(dim=0)) + ')', '#', batch_i + 1, '/', len(data_loader))
#             h_batch_pos, h_batch_neg = data_positive(x_batch, y_batch), data_negative(x_batch, y_batch)
#             h_batch_pos, h_batch_neg = h_batch_pos.to(self.device), h_batch_neg.to(self.device)
#             for layer_i, layer in enumerate(tqdm(self.layers)):
#                 for epoch in range(self.epochs):
#                     h_batch_pos_epoch, h_batch_neg_epoch, loss = layer.train(h_batch_pos, h_batch_neg)
#                     batch_loss += loss.item()
#                 h_batch_pos, h_batch_neg = h_batch_pos_epoch, h_batch_neg_epoch
#             print('batch {} loss: {}'.format(batch_i + 1, batch_loss))

#     def train_3(self, data_loader):
#         """
#         Train method 3: train all layers for each batch for each epoch. [Slow but better?]
#         """
#         cached_data = []
#         for epoch in tqdm(range(self.epochs)):
#             epoch_loss = 0
#             for batch_i, (x_batch, y_batch) in enumerate(data_loader):
#                 # print("Training Batch (Size:", str(x_batch.size(dim=0)) + ')', '#', batch_i + 1, '/', len(data_loader))
#                 if (epoch + 1) == 1:
#                     h_batch_pos, h_batch_neg = data_positive(x_batch, y_batch), data_negative(x_batch, y_batch)
#                     h_batch_pos, h_batch_neg = h_batch_pos.to(self.device), h_batch_neg.to(self.device)
#                     cached_data.append((h_batch_pos, h_batch_neg))
#                 else:
#                     h_batch_pos, h_batch_neg = cached_data[batch_i]
#                 for layer_i, layer in enumerate(self.layers):
#                     h_batch_pos_epoch, h_batch_neg_epoch, loss = layer.train(h_batch_pos, h_batch_neg)
#                     epoch_loss += loss.item()
#                     h_batch_pos, h_batch_neg = h_batch_pos_epoch, h_batch_neg_epoch
#             print('   epoch {} loss: {}'.format(epoch + 1, epoch_loss))

    @torch.no_grad()
    def predict(self, ds='test', dl=None):
        """Return the classification accuracy (a fraction in [0, 1]) on a dataset.

        Each image is overlaid with every candidate label 0..9. The goodness
        (mean squared activation) is summed over all layers, and the label with
        the highest total goodness is the prediction.

        Args:
            ds: ``'test'``, ``'eval'`` or ``'train'`` to use the matching loader
                of ``self.DL``; any other value means "use ``dl``".
            dl: explicit DataLoader, used when ``ds`` is not one of the above.

        Returns:
            The accuracy as a float; NaN if the loader yields no samples.

        Raises:
            ValueError: if no DataLoader is available for the request.
        """
        if ds == 'test':
            data_loader = self.DL.dl_test
        elif ds == 'eval':
            data_loader = self.DL.dl_eval
        elif ds == 'train':
            data_loader = self.DL.dl_train
        else:
            data_loader = dl
        if data_loader is None:
            raise ValueError(f"no DataLoader available for ds={ds!r}; call data_load() or pass dl")

        correct, total = 0, 0
        for batch_i, (x_batch, y_batch) in enumerate(data_loader):
            print("Evaluation Batch (Size:", str(x_batch.size(dim=0)) + ')', '#', batch_i + 1, '/', len(data_loader))
            x_batch, y_batch = x_batch.to(self.device), y_batch.to(self.device)
            goodness_per_label = []
            for label in range(10):
                h_batch = self.DL.overlay(x_batch, label)
                goodness = 0
                for layer in self.layers:
                    h_batch = layer(h_batch)
                    goodness = goodness + h_batch.pow(2).mean(1)
                goodness_per_label.append(goodness.unsqueeze(1))
            predictions = torch.cat(goodness_per_label, 1).argmax(1)
            correct += predictions.eq(y_batch).sum().item()
            total += y_batch.numel()
        # Mirror the mean-of-empty behaviour (NaN) instead of dividing by zero.
        return correct / total if total else float('nan')



### Runs

In [45]:
import time

In [46]:
torch.manual_seed(42)
# Also seed NumPy's global RNG: negative labels may be drawn via np.random,
# which torch.manual_seed does not cover.
np.random.seed(42)


<torch._C.Generator at 0x7f648d189070>

In [55]:
# Loader with default batch sizes and data path; data_load() then builds the
# train/test/eval DataLoaders (download=True is passed on to torchvision MNIST).
DL = MNIST_Loader()
DL.data_load(download=True, shuffle=True)

In [73]:
model = FFX(
    DataLoader=DL,
    dims=[784, 2000, 2000, 2000, 2000],  # flattened 28x28 input, then four hidden layers
    epochs=5,
    device='cpu',
)

DEVIVE


In [63]:
# time.perf_counter is a monotonic, high-resolution clock intended for
# measuring durations; time.time() can jump if the system clock is adjusted.
train_tic = time.perf_counter()
model.train()
train_toc = time.perf_counter()
train_duration = round(train_toc - train_tic, 2)
print(train_duration)

<__main__.MNIST_Loader object at 0x7f63ff3dcfd0>
Training Batch (Size: 50000) # 1 / 2


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [01:50<00:00, 22.12s/it]


Training Batch (Size: 10000) # 2 / 2


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:22<00:00,  4.59s/it]

145.04





In [64]:
# Report the measured wall time of the training cell, in seconds.
print("Training time: " + str(train_duration) + "s")

Training time: 145.04s


In [75]:
# predict() returns an accuracy fraction; report its complement as a percentage.
train_accuracy = model.predict(ds='train')
print('train error:', f"{round((1.0 - train_accuracy) * 100, 2)}%")

Evaluation Batch (Size: 50000) # 1 / 2
Evaluation Batch (Size: 10000) # 2 / 2
train error: 85.78%


In [76]:
# predict() returns an accuracy fraction; report its complement as a percentage.
test_accuracy = model.predict(ds='test')
print('test error:', f"{round((1.0 - test_accuracy) * 100, 2)}%")

Evaluation Batch (Size: 10000) # 1 / 1
test error: 85.39%


In [77]:
# predict() returns an accuracy fraction; report its complement as a percentage.
eval_accuracy = model.predict(ds='eval')
print('eval error:', f"{round((1.0 - eval_accuracy) * 100, 2)}%")

Evaluation Batch (Size: 10000) # 1 / 1
eval error: 85.39%
