Have a look at these as well:

https://github.com/Trel725/forward-forward/blob/main/forward-forward.ipynb

https://github.com/keras-team/keras-io/blob/master/examples/vision/forwardforward.py

https://github.com/ghadialhajj/FF_unsupervised

In [1]:
from tqdm import tqdm
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torchvision.transforms import Compose, Lambda, Normalize, ToTensor
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST

# Device string used by the rest of the notebook: the GPU when CUDA is
# available, otherwise the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [2]:
def load_data(train_batch_size=50000, test_batch_size=10000):
    """Build the MNIST training and test data loaders.

    Each image is converted to a tensor, normalised with mean 0.1307 and
    standard deviation 0.3081, and flattened to a 1-D vector. The dataset is
    rooted at ``./`` and requested with ``download=True``.

    Args:
        train_batch_size: batch size of the shuffled training loader.
        test_batch_size: batch size of the unshuffled test loader.

    Returns:
        A ``(train_loader, test_loader)`` tuple of ``DataLoader`` objects.
    """
    preprocess = Compose([
        ToTensor(),
        Normalize((0.1307,), (0.3081,)),
        Lambda(lambda image: image.flatten()),
    ])

    def mnist_split(is_train):
        # Both splits share the same root directory and preprocessing.
        return MNIST('./', train=is_train, download=True,
                     transform=preprocess)

    train_loader = DataLoader(
        mnist_split(True), batch_size=train_batch_size, shuffle=True)
    test_loader = DataLoader(
        mnist_split(False), batch_size=test_batch_size, shuffle=False)

    return train_loader, test_loader


def add_label(x, y, num_classes=10):
    """Overlay a one-hot label on the first ``num_classes`` entries of each row.

    Args:
        x: batch of flattened samples, indexed as ``x[sample, feature]``.
            The input tensor itself is left unmodified.
        y: label(s) to encode: either a single int applied to every row or
            an integer tensor holding one label per row.
        num_classes: number of leading entries reserved for the label
            (default 10, i.e. the digits 0-9).

    Returns:
        A copy of ``x`` whose first ``num_classes`` columns are zero, except
        column ``y`` of each row, which is set to ``x.max()``.
    """
    x_ = x.clone()
    # Assign rather than multiply by zero: ``inf * 0`` and ``nan * 0`` are
    # NaN, so a multiplication would not reliably clear the label slots.
    x_[:, :num_classes] = 0.0
    # Row indices are built on the same device as ``x`` instead of passing a
    # Python ``range`` through the indexing machinery.
    rows = torch.arange(x.shape[0], device=x.device)
    x_[rows, y] = x.max()
    return x_

    
def make_y_negative(y):
    """Draw, for every label in ``y``, a random digit label that differs from it.

    Args:
        y: integer tensor of labels in the range 0-9.

    Returns:
        A tensor with the same shape as ``y``, moved to the module-level
        ``device``, in which every entry is a uniformly chosen digit in 0-9
        other than the corresponding entry of ``y``.
    """
    # Adding an offset drawn uniformly from 1..9 and wrapping modulo 10 hits
    # each of the nine other digits with equal probability and can never
    # return the original label. This replaces a per-sample Python loop that
    # called ``.item()`` and rebuilt a candidate list for every element.
    offset = torch.randint(1, 10, y.shape, device=y.device, dtype=y.dtype)
    y_neg = (y + offset) % 10
    return y_neg.to(device)



class ForwardLayer(nn.Linear):
    """Linear + ReLU layer that is trained locally with its own optimiser.

    Every input row is scaled to (approximately) unit L2 norm before the
    affine map, so only the direction of the input is used. The "goodness"
    of a sample is the mean of its squared activations; training pushes the
    goodness of positive samples above ``threshold`` and that of negative
    samples below it.

    NOTE: ``train`` here takes ``(x_pos, x_neg)`` and therefore shadows
    ``nn.Module.train(mode)``. ``nn.Module.eval()`` relies on
    ``train(mode)``, so switching modes is not available on this class.

    Args:
        in_features: size of each input row.
        out_features: size of each output row.
        bias: whether the layer has an additive bias.
        device: forwarded to ``nn.Linear``.
        dtype: forwarded to ``nn.Linear``.
        lr: learning rate of the layer's Adam optimiser.
        threshold: goodness value separating positive from negative samples.
        num_epochs: number of optimisation steps performed by ``train``.
    """

    def __init__(self, in_features, out_features,
                 bias=True, device=None, dtype=None,
                 lr=0.03, threshold=2.0, num_epochs=1000):
        super().__init__(in_features, out_features, bias, device, dtype)
        self.relu = torch.nn.ReLU()
        self.opt = Adam(self.parameters(), lr=lr)
        self.threshold = threshold
        self.num_epochs = num_epochs

    def forward(self, x):
        """Return the ReLU activations for the row-normalised input ``x``."""
        # The small constant keeps all-zero rows from dividing by zero.
        x_direction = x / (x.norm(2, 1, keepdim=True) + 1e-4)
        # ``linear`` handles a missing bias (``bias=False``), which the
        # explicit ``self.bias.unsqueeze(0)`` form could not.
        return self.relu(
            nn.functional.linear(x_direction, self.weight, self.bias))

    def train(self, x_pos, x_neg):
        """Fit the layer on positive/negative samples for ``num_epochs`` steps.

        Args:
            x_pos: batch of positive samples (goodness is pushed up).
            x_neg: batch of negative samples (goodness is pushed down).

        Returns:
            ``(h_pos, h_neg)``: the layer's activations for both batches
            after training, without gradient history.
        """
        for _ in tqdm(range(self.num_epochs)):
            g_pos = self.forward(x_pos).pow(2).mean(1)
            g_neg = self.forward(x_neg).pow(2).mean(1)
            logits = torch.cat([
                self.threshold - g_pos,
                g_neg - self.threshold])
            # softplus(z) == log(1 + exp(z)) but does not overflow to
            # inf/NaN when the goodness grows large.
            loss = nn.functional.softplus(logits).mean()
            self.opt.zero_grad()
            loss.backward()
            self.opt.step()
        # The outputs only feed the next layer, so no graph is needed.
        with torch.no_grad():
            return self.forward(x_pos), self.forward(x_neg)


class ForwardNet(torch.nn.Module):
    """Stack of ``ForwardLayer`` objects trained one layer after another.

    Classification tries every candidate label, embeds it into the input
    with ``add_label`` and picks the label whose summed per-layer goodness
    (mean squared activation) is largest.

    NOTE: ``train`` here takes ``(x_pos, x_neg)`` and therefore shadows
    ``nn.Module.train(mode)``.

    Args:
        dims: layer sizes; consecutive pairs ``(dims[i], dims[i + 1])``
            define the input/output size of each layer.
    """

    def __init__(self, dims):
        super().__init__()
        # A ModuleList (rather than a plain list) registers the layers as
        # sub-modules, so parameters(), state_dict() and .to() include them.
        self.layers = torch.nn.ModuleList(
            ForwardLayer(d_in, d_out).to(device)
            for d_in, d_out in zip(dims, dims[1:]))

    def predict(self, x):
        """Return, for each row of ``x``, the label 0-9 with the highest goodness."""
        goodness_per_label = []
        # Inference only: skip building autograd graphs for the ten passes.
        with torch.no_grad():
            for label in range(10):
                h = add_label(x, label)
                goodness = []
                for layer in self.layers:
                    h = layer(h)
                    goodness.append(h.pow(2).mean(1))
                goodness_per_label.append(sum(goodness))
            # One column per candidate label, one row per sample.
            return torch.stack(goodness_per_label, dim=1).argmax(1)

    def train(self, x_pos, x_neg):
        """Train every layer in order on the outputs of the previous layer.

        Args:
            x_pos: batch of positive samples fed to the first layer.
            x_neg: batch of negative samples fed to the first layer.
        """
        h_pos, h_neg = x_pos, x_neg
        for i, layer in enumerate(self.layers):
            print('training layer', i, '...')
            # Each layer returns its post-training activations, which become
            # the next layer's training inputs.
            h_pos, h_neg = layer.train(h_pos, h_neg)

    
def sample_show(data, name='', idx=0):
    """Display sample ``idx`` of ``data`` as a 28x28 grayscale image.

    Args:
        data: indexable batch of tensors; ``data[idx]`` is moved to the CPU
            and reshaped to 28x28.
        name: title drawn above the image.
        idx: index of the sample to show.
    """
    image = data[idx].cpu().reshape(28, 28)
    plt.figure(figsize=(4, 4))
    plt.imshow(image, cmap="gray")
    plt.title(name)
    plt.show()
    

In [None]:
# Seed the global RNG before anything random happens (weight init, batch
# shuffling, negative-label sampling).
torch.manual_seed(1234)
load_train, load_test = load_data()

net = ForwardNet([784, 500, 500])

# Training uses only the first batch from the loader (50000 samples with the
# default batch size of load_data).
x, y = (tensor.to(device) for tensor in next(iter(load_train)))

# Positive samples carry the true label; negative samples carry a random
# wrong one. Alternative kept from the original notebook: permute the labels
# instead, i.e. y[torch.randperm(x.size(0))].
x_pos = add_label(x, y)
y_neg = make_y_negative(y)
x_neg = add_label(x, y_neg)

for name, data in (('orig', x), ('pos', x_pos), ('neg', x_neg)):
    sample_show(data, name)

net.train(x_pos, x_neg)

train_accuracy = net.predict(x).eq(y).float().mean().item()
print('train error:', 1.0 - train_accuracy)

# Evaluate on the first (and, with the default batch size, only requested)
# test batch.
x_test, y_test = (tensor.to(device) for tensor in next(iter(load_test)))

test_accuracy = net.predict(x_test).eq(y_test).float().mean().item()
print('test error:', 1.0 - test_accuracy)