Have a look at these as well:

https://github.com/Trel725/forward-forward/blob/main/forward-forward.ipynb

https://github.com/keras-team/keras-io/blob/master/examples/vision/forwardforward.py

https://github.com/ghadialhajj/FF_unsupervised

In [None]:
import os
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from scipy.signal import convolve2d
from sklearn.metrics import accuracy_score

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader, Subset
from torch import tensor, Tensor

import torchvision
from torchvision.transforms import Compose, Lambda, Normalize, ToTensor
from torchvision.datasets import MNIST

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

Supervised Version:

In [None]:
def load_data(train_batch_size=50000, test_batch_size=10000):
    """Build the MNIST train and test loaders used by the supervised example.

    Every image is converted to a tensor, normalised with mean 0.1307 and
    std 0.3081, and flattened to a vector. The dataset is downloaded into the
    current directory when it is not already there.

    Args:
        train_batch_size: batch size of the (shuffled) training loader.
        test_batch_size: batch size of the (unshuffled) test loader.

    Returns:
        A ``(train_loader, test_loader)`` tuple.
    """
    preprocess = Compose([
        ToTensor(),
        Normalize((0.1307,), (0.3081,)),
        Lambda(torch.flatten)])

    def build_loader(is_train, batch_size):
        # Only the training split is shuffled.
        dataset = MNIST('./', train=is_train, download=True, transform=preprocess)
        return DataLoader(dataset, batch_size=batch_size, shuffle=is_train)

    train_loader = build_loader(True, train_batch_size)
    test_loader = build_loader(False, test_batch_size)
    return train_loader, test_loader


def add_label(x, y):
    """Return a copy of ``x`` whose first 10 columns one-hot encode ``y``.

    The first 10 entries of every row are zeroed and the entry at the label
    index is set to the maximum value found anywhere in ``x``. The input is
    left untouched.

    Args:
        x: 2-D batch of flattened images, one row per sample.
        y: per-row label indices in 0-9, or one int applied to every row.
    """
    labelled = x.clone()
    # Wipe the pixels that are about to carry the label.
    labelled[:, :10].mul_(0.0)
    row_ids = range(x.shape[0])
    labelled[row_ids, y] = x.max()
    return labelled

    
def make_y_negative(y, num_classes=10):
    """Draw, for every label in ``y``, a uniformly random *different* label.

    Args:
        y: integer tensor of true class labels in ``[0, num_classes)``.
        num_classes: total number of classes (10 for MNIST digits).

    Returns:
        A tensor shaped like ``y`` holding wrong labels, moved to the
        module-level ``device``.
    """
    # Adding an offset from [1, num_classes - 1] modulo num_classes can never
    # land on the original label and reaches every other label with equal
    # probability. One vectorised draw replaces a Python loop over the batch.
    offsets = torch.randint(1, num_classes, y.shape, dtype=y.dtype, device=y.device)
    y_neg = (y + offsets) % num_classes
    return y_neg.to(device)


class ForwardLayer(nn.Linear):
    """A single Forward-Forward layer that owns its optimiser and trains locally.

    The layer length-normalises its input, applies the linear map and a ReLU.
    Its "goodness" is the mean squared activation of a sample; training pushes
    goodness above ``threshold`` for positive data and below it for negative
    data.

    Args:
        in_features, out_features, bias, device, dtype: forwarded to
            ``nn.Linear``.
        lr: learning rate of the layer's private Adam optimiser.
        threshold: goodness level separating positive from negative samples.
        num_epochs: number of full-batch optimisation steps run by ``train``.
    """

    def __init__(self, in_features, out_features,
                 bias=True, device=None, dtype=None,
                 lr=0.03, threshold=2.0, num_epochs=1000):
        super().__init__(in_features, out_features, bias, device, dtype)
        self.relu = torch.nn.ReLU()
        self.opt = Adam(self.parameters(), lr=lr)
        self.threshold = threshold
        self.num_epochs = num_epochs

    def forward(self, x):
        """Return the ReLU activations for a 2-D batch ``x``.

        Each row is divided by its L2 norm first, so only the direction of the
        previous layer's activity vector is passed on, not its length.
        """
        x_direction = x / (x.norm(2, 1, keepdim=True) + 1e-4)
        pre_activation = torch.mm(x_direction, self.weight.T)
        # ``bias`` is None when the layer was built with bias=False.
        if self.bias is not None:
            pre_activation = pre_activation + self.bias.unsqueeze(0)
        return self.relu(pre_activation)

    def _ff_loss(self, g_pos, g_neg):
        """Mean logistic loss of the per-sample goodness values around the threshold."""
        margins = torch.cat([self.threshold - g_pos, g_neg - self.threshold])
        # softplus(m) == log(1 + exp(m)) but does not overflow for large m.
        return nn.functional.softplus(margins).mean()

    def train(self, x_pos=True, x_neg=None):
        """Fit the layer on one batch, or switch the module's training mode.

        ``nn.Module.eval()`` and ``nn.Module.train(mode)`` invoke ``self.train``
        with a single boolean; a call without ``x_neg`` is therefore forwarded
        to the base class so those keep working.

        Args:
            x_pos: 2-D batch of positive samples (or the boolean mode).
            x_neg: 2-D batch of negative samples.

        Returns:
            ``(h_pos, h_neg)``: the layer's activations for both batches after
            training, detached from the autograd graph. For the mode-switching
            call, ``self``.
        """
        if x_neg is None:
            return super().train(x_pos)
        for _ in tqdm(range(self.num_epochs)):
            g_pos = self.forward(x_pos).pow(2).mean(1)
            g_neg = self.forward(x_neg).pow(2).mean(1)
            loss = self._ff_loss(g_pos, g_neg)
            self.opt.zero_grad()
            loss.backward()
            self.opt.step()
        # The outputs only feed the next layer, so no graph is needed for them.
        with torch.no_grad():
            return self.forward(x_pos), self.forward(x_neg)


class ForwardNet(torch.nn.Module):
    """Stack of ``ForwardLayer`` objects trained greedily, one layer at a time.

    Classification tries every label in turn and keeps the one whose labelled
    input produces the highest goodness summed over all layers.

    Args:
        dims: layer widths, e.g. ``[784, 500, 500]`` builds two layers.
    """

    def __init__(self, dims):
        super().__init__()
        # A ModuleList (not a plain list) registers the layers, so they show up
        # in parameters(), state_dict() and follow .to(...).
        self.layers = torch.nn.ModuleList(
            ForwardLayer(n_in, n_out).to(device)
            for n_in, n_out in zip(dims[:-1], dims[1:]))

    def predict(self, x):
        """Return the predicted label (0-9) for every row of ``x``."""
        goodness_per_label = []
        # Inference only: skip building autograd graphs for the ten passes.
        with torch.no_grad():
            for label in range(10):
                h = add_label(x, label)
                goodness = []
                for layer in self.layers:
                    h = layer(h)
                    goodness += [h.pow(2).mean(1)]
                goodness_per_label += [sum(goodness).unsqueeze(1)]
        goodness_per_label = torch.cat(goodness_per_label, 1)
        return goodness_per_label.argmax(1)

    def train(self, x_pos, x_neg):
        """Train every layer in order on the positive and negative batches.

        Each layer is trained to completion and its activations become the
        next layer's input.

        NOTE: this shadows ``nn.Module.train(mode)``; calling ``.eval()`` on
        this module is therefore not supported.
        """
        h_pos, h_neg = x_pos, x_neg
        for i, layer in enumerate(self.layers):
            print('training layer', i, '...')
            h_pos, h_neg = layer.train(h_pos, h_neg)

    
def sample_show(data, name='', idx=0):
    """Plot sample ``idx`` of ``data`` as a 28x28 greyscale image titled ``name``."""
    image = data[idx].cpu().reshape(28, 28)
    plt.figure(figsize=(4, 4))
    plt.imshow(image, cmap="gray")
    plt.title(name)
    plt.show()
    

In [None]:
# Seed torch's RNG, which drives loader shuffling, weight init and the
# negative-label sampling below.
torch.manual_seed(1234)
load_train, load_test = load_data()

net = ForwardNet([784, 500, 500])
# Train on the first shuffled batch only; its size is load_data's default.
x, y = next(iter(load_train))
x, y = x.to(device), y.to(device)
# Positive samples: each image carries its true label in the first 10 pixels.
x_pos = add_label(x, y)
# Negative samples: the same images paired with a randomly chosen wrong label.
# Commented-out alternative: shuffle the labels, y[torch.randperm(x.size(0))].
y_neg = make_y_negative(y)
x_neg = add_label(x, y_neg)

# Show the first sample of the raw, positive and negative batches.
for data, name in zip([x, x_pos, x_neg], ['orig', 'pos', 'neg']):
    sample_show(data, name)

net.train(x_pos, x_neg)

# Error rate on the batch the network was trained on.
print('train error:', 1.0 - net.predict(x).eq(y).float().mean().item())

# Error rate on the first batch of the test loader.
x_test, y_test = next(iter(load_test))
x_test, y_test = x_test.to(device), y_test.to(device)

print('test error:', 1.0 - net.predict(x_test).eq(y_test).float().mean().item())

Unsupervised Version

In [None]:
def create_mask(shape, iterations: int = 10):
    """Create a random binary mask made of smooth blobs.

    A random 0/1 image is blurred ``iterations`` times with a [0.25, 0.5, 0.25]
    kernel, first along rows and then along columns, and finally rounded back
    to 0/1.

    Args:
        shape: 2-D shape of the mask.
        iterations: number of horizontal-plus-vertical blur passes.

    Returns:
        A ``uint8`` tensor of the requested shape.
    """
    horizontal_kernel = np.array(((0, 0, 0), (0.25, 0.5, 0.25), (0, 0, 0)))
    vertical_kernel = horizontal_kernel.T
    blurred = np.random.randint(0, 2, size=shape)
    for _ in range(iterations):
        for kernel in (horizontal_kernel, vertical_kernel):
            blurred = np.abs(convolve2d(blurred, kernel, mode='same') / kernel.sum())
    return tensor(np.round(blurred).astype(np.uint8))


def create_negative_image(image_1: Tensor, image_2: Tensor):
    """Blend two equally shaped 2-D images through a random blob mask.

    Pixels where the mask is 1 come from ``image_1``; all other pixels come
    from ``image_2``.
    """
    assert image_1.shape == image_2.shape, "Incompatible images and mask shapes."
    height, width = image_1.shape[0], image_1.shape[1]
    mask = create_mask((height, width))
    return image_1 * mask + image_2 * (1 - mask)


def create_negative_batch(images: Tensor):
    """Build one hybrid (negative) image per sample of ``images``.

    Every hybrid mixes two images drawn at random, with replacement, from the
    batch. The result gets a singleton dimension inserted at position 1.
    """
    count = images.shape[0]
    hybrids = []
    for _ in range(count):
        first, second = np.random.randint(count, size=2)
        hybrids.append(
            create_negative_image(images[first].squeeze(), images[second].squeeze()))
    return torch.stack(hybrids).unsqueeze(1)


def prepare_data():
    """Create ``negatives.pt`` (one hybrid image per training sample) if missing.

    Both MNIST splits are opened from the current directory without
    downloading. When ``negatives.pt`` already exists nothing else happens.
    """
    to_tensor = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
    train_mnist_dataset = torchvision.datasets.MNIST(
        root="./", train=True, transform=to_tensor, download=False)
    n_train_samples = len(train_mnist_dataset)
    # The test split is opened as well, although nothing below reads it.
    test_mnist_dataset = torchvision.datasets.MNIST(
        root="./", train=False, transform=to_tensor, download=False)

    if os.path.exists("negatives.pt"):
        return

    # One random (with replacement) pair of training indices per hybrid image.
    index_pairs = np.random.randint(n_train_samples, size=[n_train_samples, 2])
    pair_list = [(row[0], row[1]) for row in index_pairs]

    negatives = []
    for first, second in tqdm(pair_list):
        image_a = train_mnist_dataset[first][0].squeeze()
        image_b = train_mnist_dataset[second][0].squeeze()
        negatives.append(create_negative_image(image_a, image_b))

    torch.save(negatives, 'negatives.pt')


In [None]:
class ForwardLayer_Unsupervised(nn.Linear):
    """Linear layer followed by LayerNorm, trained locally with its own Adam.

    Args:
        in_features: size of the last input dimension.
        out_features: size of the last output dimension.
        n_epochs: stored on the layer; not read by the layer itself.
        bias: whether the linear map has a bias term.
        device: device the layer is moved to.
        threshold: goodness level separating positive from negative samples.
    """

    def __init__(self, in_features: int, out_features: int, n_epochs: int, bias: bool, device, threshold=2.0):
        super().__init__(in_features, out_features, bias=bias)
        self.n_epochs = n_epochs
        self.threshold = threshold
        self.ln_layer = nn.LayerNorm(normalized_shape=[1, out_features])
        self.to(device)
        # Created last so it also covers the LayerNorm parameters.
        self.opt = torch.optim.Adam(self.parameters())

    def ff_train(self, pos_acts, neg_acts):
        """Take one optimiser step on this layer's own outputs.

        Goodness is the mean squared activation over the last dimension. The
        loss is the softplus of its signed distance to ``threshold``, so the
        threshold actually matters and the objective is bounded below.

        Args:
            pos_acts: this layer's output for positive data.
            neg_acts: this layer's output for negative data.
        """
        self.opt.zero_grad()
        pos_goodness = pos_acts.pow(2).mean(dim=-1).flatten()
        neg_goodness = neg_acts.pow(2).mean(dim=-1).flatten()
        margins = torch.cat([self.threshold - pos_goodness, neg_goodness - self.threshold])
        loss = nn.functional.softplus(margins).mean()
        loss.backward()
        self.opt.step()
        # NOTE(review): goodness is measured after LayerNorm, which fixes most
        # of the activation scale; consider measuring it before normalising.

    def forward(self, input):
        """Return ``LayerNorm(Linear(input))`` with the input cut from the graph.

        Detaching the input (rather than the linear output) keeps learning
        local to this layer while still letting gradients reach its weights.
        """
        return self.ln_layer(super().forward(input.detach()))


class ForwardNet_Unsupervised(nn.Module):
    """Stack of locally trained layers plus a linear read-out classifier.

    The ``ff_layers`` are trained on positive vs. negative images without
    labels. ``last_layer`` is then trained with cross-entropy on the
    concatenated outputs of the last ``n_hid_to_log`` layers.

    Args:
        n_layers: number of Forward-Forward layers.
        n_neurons: width of every Forward-Forward layer.
        input_size: number of pixels of a flattened input image.
        n_epochs: epochs used for both training phases.
        bias: whether the layers use a bias term.
        n_classes: number of output classes.
        n_hid_to_log: how many of the last layers feed the classifier.
        device: device that holds the model.
    """

    def __init__(self, n_layers: int = 4, n_neurons=2000, input_size: int = 28 * 28, n_epochs: int = 100,
                 bias: bool = True, n_classes: int = 10, n_hid_to_log: int = 3, device=device):
        super().__init__()
        self.n_hid_to_log = n_hid_to_log
        self.n_epochs = n_epochs
        self.device = device

        # A ModuleList registers the layers, so parameters(), state_dict(),
        # .to(...) and train()/eval() reach them.
        self.ff_layers = nn.ModuleList(
            ForwardLayer_Unsupervised(in_features=input_size if idx == 0 else n_neurons,
                                      out_features=n_neurons,
                                      n_epochs=n_epochs,
                                      bias=bias,
                                      device=device) for idx in range(n_layers))
        self.last_layer = nn.Linear(in_features=n_neurons * n_hid_to_log, out_features=n_classes, bias=bias)
        self.to(device)
        # Only the read-out layer is optimised here; each Forward-Forward
        # layer carries its own optimiser.
        self.opt = torch.optim.Adam(self.last_layer.parameters())
        self.loss = torch.nn.CrossEntropyLoss(reduction="mean")

    def train_ff_layers(self, pos_dataloader, neg_dataloader):
        """Train the Forward-Forward layers for ``n_epochs`` passes.

        Args:
            pos_dataloader: yields ``(images, labels)``; labels are ignored.
            neg_dataloader: yields batches of negative images.
        """
        outer_tqdm = tqdm(range(self.n_epochs), desc="Training FF Layers", position=0)
        for epoch in outer_tqdm:
            inner_tqdm = tqdm(zip(pos_dataloader, neg_dataloader), desc=f"Training FF Layers | Epoch {epoch}",
                              leave=False, position=1)
            for pos_data, neg_imgs in inner_tqdm:
                pos_imgs, _ = pos_data
                # Flatten every image to shape (batch, 1, pixels).
                pos_acts = torch.reshape(pos_imgs, (pos_imgs.shape[0], 1, -1)).to(self.device)
                neg_acts = torch.reshape(neg_imgs, (neg_imgs.shape[0], 1, -1)).to(self.device)

                # Each layer is updated on its own output, which then becomes
                # the next layer's input.
                for layer in self.ff_layers:
                    pos_acts = layer(pos_acts)
                    neg_acts = layer(neg_acts)
                    layer.ff_train(pos_acts, neg_acts)

    def train_last_layer(self, dataloader: DataLoader):
        """Train the read-out classifier with cross-entropy.

        Returns:
            One entry per epoch: the mean batch loss as a 0-d NumPy array.
        """
        num_batches = len(dataloader)
        outer_tqdm = tqdm(range(self.n_epochs), desc="Training Last Layer", position=0)
        loss_list = []
        for epoch in outer_tqdm:
            epoch_loss = 0
            inner_tqdm = tqdm(dataloader, desc=f"Training Last Layer | Epoch {epoch}", leave=False, position=1)
            for images, labels in inner_tqdm:
                images = images.to(self.device)
                labels = labels.to(self.device)
                self.opt.zero_grad()
                preds = self(images)
                loss = self.loss(preds, labels)
                # Accumulate the detached value so each batch's autograd graph
                # can be freed instead of living until the end of the epoch.
                epoch_loss += loss.detach()
                loss.backward()
                self.opt.step()
            loss_list.append(epoch_loss / num_batches)
        return [l.detach().cpu().numpy() for l in loss_list]

    def forward(self, image: torch.Tensor):
        """Return class logits of shape ``(batch, n_classes)`` for ``image``."""
        image = image.to(self.device)
        image = torch.reshape(image, (image.shape[0], 1, -1))
        first_logged = len(self.ff_layers) - self.n_hid_to_log
        concat_output = []
        for idx, layer in enumerate(self.ff_layers):
            image = layer(image)
            # Keep only the outputs of the last n_hid_to_log layers.
            if idx >= first_logged:
                concat_output.append(image)
        concat_output = torch.concat(concat_output, 2)
        logits = self.last_layer(concat_output)
        # Drop only the singleton middle axis; a bare squeeze() would also
        # drop the batch axis for a batch of one.
        return logits.squeeze(1)

    def evaluate(self, dataloader: DataLoader, dataset_type: str = "train"):
        """Print the classification accuracy over ``dataloader``.

        Switches the model to eval mode and leaves it there.
        """
        self.eval()
        inner_tqdm = tqdm(dataloader, desc="Evaluating model", leave=False, position=1)
        all_labels = []
        all_preds = []
        # Inference only: no autograd graphs are needed.
        with torch.no_grad():
            for images, labels in inner_tqdm:
                images = images.to(self.device)
                labels = labels.to(self.device)
                preds = self(images)
                preds = torch.argmax(preds, 1)
                all_labels.append(labels.detach().cpu())
                all_preds.append(preds.detach().cpu())
        all_labels = torch.concat(all_labels, 0).numpy()
        all_preds = torch.concat(all_preds, 0).numpy()
        acc = accuracy_score(all_labels, all_preds)
        metrics_dict = dict(accuracy_score=acc)
        print(f"{dataset_type} dataset scores: ", "\n".join([f"{key}: {value}" for key, value in metrics_dict.items()]))


def train(model: ForwardNet_Unsupervised, pos_dataloader: DataLoader, neg_dataloader: DataLoader):
    """Run both training phases and return the read-out layer's loss history.

    The model is put into training mode, its Forward-Forward layers are
    trained on positive vs. negative batches, and the classifier is then
    trained on the positive loader.
    """
    model.train()
    model.train_ff_layers(pos_dataloader, neg_dataloader)
    last_layer_losses = model.train_last_layer(pos_dataloader)
    return last_layer_losses


In [None]:
# Positive data: the MNIST training images (expected to be on disk already).
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
pos_dataset = torchvision.datasets.MNIST(root='./', download=False, transform=transform, train=True)
pos_dataloader = DataLoader(pos_dataset, batch_size=64, shuffle=True, num_workers=4)

# Negative data: hybrid images, generated once and cached in negatives.pt.
prepare_data()
neg_dataset = torch.load('negatives.pt')
neg_dataloader = DataLoader(neg_dataset, batch_size=64, shuffle=True, num_workers=4)

test_dataset = torchvision.datasets.MNIST(root='./', train=False, download=False, transform=transform)
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=True, num_workers=4)

# Train the Forward-Forward layers and the classifier, then report accuracy
# on the training and test loaders.
u_ff = ForwardNet_Unsupervised(device=device, n_epochs=2)
loss = train(u_ff, pos_dataloader, neg_dataloader)
u_ff.evaluate(pos_dataloader, dataset_type="Train")
u_ff.evaluate(test_dataloader, dataset_type="Test")

# fig = plt.figure()
# plt.plot(list(range(len(loss))), loss)
# plt.xlabel("Epochs")
# plt.ylabel("Loss")
# plt.title("Loss Plot")
# ## plt.savefig("Loss Plot.png")
# plt.show()