Required libraries

In [None]:
# This code is based on the following GitHub repository:
# https://github.com/mpezeshki/pytorch_forward_forward

In [None]:


import argparse
import torch
import torch.nn as nn
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Normalize, Lambda
from torch.utils.data import DataLoader
from torch.optim import Adam




Generating negative labels

In [None]:
def get_y_neg(y, classes=10):
    """Draw, for every label in ``y``, a random *incorrect* label.

    Each negative label is sampled uniformly from the ``classes - 1`` classes
    that differ from the true one; overlaying it on the image produces a
    negative example for Forward-Forward training.

    Args:
        y: integer tensor of true labels in ``[0, classes)``.
        classes: number of classes (default 10).

    Returns:
        Tensor with the same shape and dtype as ``y``, moved to the notebook's
        global ``device``.
    """
    # A random offset in [1, classes) added modulo `classes` can never land on
    # the true label and reaches each other class with equal probability. This
    # avoids a per-sample Python loop (and a device->host sync per .item()).
    offsets = torch.randint(1, classes, y.shape, dtype=y.dtype, device=y.device)
    y_neg = (y + offsets) % classes
    return y_neg.to(device)


def overlay_y_on_x(x, y, classes=10):
    """Write a one-hot label into the first ``classes`` columns of each row.

    A copy of ``x`` has its leading ``classes`` columns zeroed; then, in every
    row, the column selected by ``y`` is set to the largest value found anywhere
    in ``x``. The input tensor itself is left untouched.

    Args:
        x: batch with one flattened sample per row.
        y: per-row label tensor, or a single int label applied to every row.
        classes: number of leading columns reserved for the label.

    Returns:
        The label-overlaid copy of ``x``.
    """
    peak = x.max()
    overlaid = x.clone()
    overlaid[:, :classes].mul_(0.0)
    rows = range(x.shape[0])
    overlaid[rows, y] = peak
    return overlaid


Forward-Forward layer class

In [None]:
class Layer(nn.Linear):
    """Fully-connected ReLU layer trained locally with the Forward-Forward loss.

    ``train`` pushes the "goodness" (mean squared activation) of positive
    inputs above ``threshold`` and that of negative inputs below it. Each layer
    owns its own Adam optimizer over only its own parameters.

    Hyper-parameters not passed explicitly fall back to the notebook globals
    ``learning_rate``, ``threshold``, ``epochs`` and ``log_interval``, so
    ``Layer(784, 500)`` behaves as before. Passing them explicitly removes the
    dependency on those globals (and on cell execution order).
    """

    def __init__(self, in_features, out_features, bias=True, device=None, dtype=None,
                 *, lr=None, goodness_threshold=None, num_epochs=None, log_every=None):
        super().__init__(in_features, out_features, bias, device, dtype)
        self.relu = torch.nn.ReLU()
        self.opt = Adam(self.parameters(), lr=learning_rate if lr is None else lr)
        self.threshold = threshold if goodness_threshold is None else goodness_threshold
        self.num_epochs = epochs if num_epochs is None else num_epochs
        # None means "read the global log_interval at train time", as before.
        self.log_every = log_every

    def forward(self, x):
        """Apply the layer to row-wise L2-normalised input, followed by ReLU."""
        # 1e-4 guards against division by zero for all-zero rows.
        x_direction = x / (x.norm(2, 1, keepdim=True) + 1e-4)
        # F.linear computes x @ W.T + b and also handles bias=False (bias None).
        return self.relu(nn.functional.linear(x_direction, self.weight, self.bias))

    def train(self, x_pos, x_neg=None):
        """Run ``num_epochs`` full-batch Forward-Forward updates on this layer.

        Args:
            x_pos: positive (correctly labelled) inputs.
            x_neg: negative (wrongly labelled) inputs.

        Returns:
            ``(h_pos, h_neg)``: this layer's outputs for both inputs, detached
            so they can feed the next layer without linking autograd graphs.

        When called with a single bool, as ``nn.Module.eval()`` and
        ``nn.Module.train(mode)`` do, this defers to ``nn.Module.train`` and
        returns ``self``.
        """
        if x_neg is None:
            return super().train(x_pos)
        interval = log_interval if self.log_every is None else self.log_every
        for i in range(self.num_epochs):
            g_pos = self.forward(x_pos).pow(2).mean(1)
            g_neg = self.forward(x_neg).pow(2).mean(1)
            # softplus(z) == log(1 + exp(z)) without overflowing to inf (which
            # would turn the gradients into NaN) when z is large.
            loss = nn.functional.softplus(
                torch.cat([-g_pos + self.threshold, g_neg - self.threshold])
            ).mean()
            self.opt.zero_grad()
            loss.backward()
            self.opt.step()
            if i % interval == 0:
                print("Loss: ", loss.item())
        return self.forward(x_pos).detach(), self.forward(x_neg).detach()


Network class

In [None]:
class Net(torch.nn.Module):
    """A stack of Forward-Forward ``Layer``s trained greedily, one layer at a time.

    Args:
        dims: layer widths; e.g. ``[784, 500, 500]`` builds two layers
            (784 -> 500 and 500 -> 500).
    """

    def __init__(self, dims):
        super().__init__()
        # nn.ModuleList (rather than a plain Python list) registers each layer
        # as a submodule, so parameters(), state_dict() and .to() include it.
        self.layers = nn.ModuleList()
        for d in range(len(dims) - 1):
            self.layers.append(Layer(dims[d], dims[d + 1]).to(device))

    def predict(self, x, num_classes=10):
        """Classify each row of ``x`` by trying every candidate label overlay.

        For each label, the input is overlaid with that label and passed through
        all layers. The label whose goodness (mean squared activation), summed
        over layers, is largest is the prediction.

        Returns:
            Tensor of predicted labels, one per row of ``x``.
        """
        goodness_per_label = []
        # Inference only: don't build autograd graphs for num_classes full
        # forward passes, which would keep every activation alive in memory.
        with torch.no_grad():
            for label in range(num_classes):
                h = overlay_y_on_x(x, label, classes=num_classes)
                goodness = []
                for layer in self.layers:
                    h = layer(h)
                    goodness.append(h.pow(2).mean(1))
                goodness_per_label.append(sum(goodness).unsqueeze(1))
        return torch.cat(goodness_per_label, 1).argmax(1)

    def train(self, x_pos, x_neg=None):
        """Train each layer in turn on the previous layer's outputs.

        Args:
            x_pos: positive (correctly labelled) inputs for the first layer.
            x_neg: negative (wrongly labelled) inputs for the first layer.

        When called with a single bool, as ``nn.Module.eval()`` and
        ``nn.Module.train(mode)`` do, this delegates to ``nn.Module.train``
        (which also forwards the mode to each child module).
        """
        if x_neg is None:
            return super().train(x_pos)
        h_pos, h_neg = x_pos, x_neg
        for i, layer in enumerate(self.layers):
            print("training layer: ", i)
            h_pos, h_neg = layer.train(h_pos, h_neg)

Setting the hyper-parameters and loading the data

In [None]:
# Hyper-parameters and run configuration (read as globals by later cells).
epochs = 1000  # optimisation steps per layer in Layer.train
learning_rate = 0.05  # Adam learning rate used by every Layer
no_cuda = False  # set True to skip CUDA even if it is available
no_mps = False  # set True to skip Apple MPS even if it is available
random_seed = 1234
save_model = False  # whether to persist the trained weights
train_size = 50000  # batch size of the training loader (one batch is used)
threshold = 2  # goodness threshold separating positive from negative data
test_size = 10000  # batch size of the test loader (one batch is used)
log_interval = 10  # print the layer loss every `log_interval` steps

# Seed torch's RNG (weight init, negative-label sampling, loader shuffling) so
# that re-running the notebook reproduces the same results.
torch.manual_seed(random_seed)

# Prefer CUDA, then MPS, then fall back to CPU.
use_cuda = not no_cuda and torch.cuda.is_available()
use_mps = not no_mps and torch.backends.mps.is_available()

if use_cuda:
    device = torch.device("cuda")
elif use_mps:
    device = torch.device("mps")
else:
    device = torch.device("cpu")

train_kwargs = {"batch_size": train_size}
test_kwargs = {"batch_size": test_size}

# NOTE: shuffling is only enabled on CUDA, so on CPU/MPS the first
# `train_size` training images are always the ones used.
if use_cuda:
    cuda_kwargs = {"num_workers": 1, "pin_memory": True, "shuffle": True}
    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)


# Normalise with the commonly quoted MNIST pixel mean/std, then flatten each
# 28x28 image into a 784-vector to match the first layer's width.
transform = Compose(
    [
        ToTensor(),
        Normalize((0.1307,), (0.3081,)),
        Lambda(lambda x: torch.flatten(x)),
    ]
)

train_loader = DataLoader(
    MNIST("./data/", train=True, download=True, transform=transform), **train_kwargs
)
test_loader = DataLoader(
    MNIST("./data/", train=False, download=True, transform=transform), **test_kwargs
)


Building the network, training it, and reporting the train/test error

In [None]:
# 784 inputs = one flattened 28x28 MNIST image; two hidden layers of 500 units.
net = Net([784, 500, 500])

# Positive data: images overlaid with their true label.
# Negative data: the same images overlaid with a random wrong label.
x, y = next(iter(train_loader))
x, y = x.to(device), y.to(device)
x_pos = overlay_y_on_x(x, y)
y_neg = get_y_neg(y)
x_neg = overlay_y_on_x(x, y_neg)

net.train(x_pos, x_neg)
print("train error:", 1.0 - net.predict(x).eq(y).float().mean().item())

x_te, y_te = next(iter(test_loader))
x_te, y_te = x_te.to(device), y_te.to(device)

if save_model:
    torch.save(net.state_dict(), "mnist_ff.pt")

print("test error:", 1.0 - net.predict(x_te).eq(y_te).float().mean().item())


training layer:  0
Loss:  1.126759648323059
Loss:  0.7035413980484009
Loss:  0.6998260021209717
Loss:  0.6971884965896606
Loss:  0.692277193069458
Loss:  0.6863791942596436
Loss:  0.6811739802360535
Loss:  0.6726803779602051
Loss:  0.6613404750823975
Loss:  0.6469516754150391
Loss:  0.6306005120277405
Loss:  0.6136147975921631
Loss:  0.5969079732894897
Loss:  0.5807739496231079
Loss:  0.5652912855148315
Loss:  0.5504934191703796
Loss:  0.5363852977752686
Loss:  0.5229905843734741
Loss:  0.5103220343589783
Loss:  0.4983688294887543
Loss:  0.48709607124328613
Loss:  0.47646650671958923
Loss:  0.4664372205734253
Loss:  0.45696982741355896
Loss:  0.4480300545692444
Loss:  0.43958255648612976
Loss:  0.4315928518772125
Loss:  0.42402708530426025
Loss:  0.4168485105037689
Loss:  0.41002514958381653
Loss:  0.4035312831401825
Loss:  0.3973410129547119
Loss:  0.3914314806461334
Loss:  0.38578560948371887
Loss:  0.3803861737251282
Loss:  0.37521445751190186
Loss:  0.3702540397644043
Loss:  0.3654