# Forward-Forward Explained with XOR
A horribly over-engineered XOR solver to demonstrate the forward-forward unsupervised learning method described by Geoffrey Hinton in [The Forward-Forward Algorithm: Some Preliminary Investigations](https://arxiv.org/pdf/2212.13345.pdf). *(To my knowledge this is correct, but I may have missed something obvious or misunderstood something.)*

\- [aicrumb](https://twitter.com/aicrumb)
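For reference: in Hinton's paper, a layer's "goodness" is the sum of its squared activities $y_j$. The probability that the layer assigns to an input being positive (real) is a logistic function of that goodness minus a threshold $\theta$.

The training loop below makes two changes. It uses the mean over the $d = 8$ hidden units instead of the sum, and it sets $\theta = 2$. Minimizing the negative log of that probability for real inputs, and of one minus it for fake inputs, gives exactly the loop's `torch.log(1 + torch.exp(...))` loss, which is a softplus:

$$
g(x) = \frac{1}{d}\sum_{j=1}^{d} y_j(x)^2, \qquad p(\text{positive} \mid x) = \sigma\big(g(x) - \theta\big)
$$

$$
\mathcal{L} = \frac{1}{N_{pos} + N_{neg}}\left[\sum_{x \in \text{pos}} \log\left(1 + e^{-(g(x) - \theta)}\right) + \sum_{x \in \text{neg}} \log\left(1 + e^{\,g(x) - \theta}\right)\right]
$$

This follows because $-\log\sigma(z) = \log(1 + e^{-z})$ and $-\log(1 - \sigma(z)) = \log(1 + e^{z})$.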

In [1]:
import torch
from torch import nn, optim
import torch.nn.functional as F
from tqdm.auto import trange

# create the unlabelled data to train our embedding model with

X = torch.tensor([[0, 1], [1, 1], [1, 0], [0, 0]])

# as well as our little model and an optimizer

encoder = nn.Sequential(
    nn.Linear(2, 8),
    nn.ReLU()
)
opt = optim.Adam(encoder.parameters(), 0.05)

In [2]:
for i in trange(2_000):

    # create the positive (real) and negative (random) data and feed it through the model

    x_pos = X.clone()
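    x_neg = torch.randn(*x_pos.shape) / 2 + 0.5  # gaussian noise, mean 0.5, std 0.5
    # (randn_like was failing here, likely because X is an integer tensor and randn
    # only supports floating-point dtypes; torch.randn_like(x_pos.float()) should work)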

    x_pos = encoder(x_pos.float())
    x_neg = encoder(x_neg.float())

    # we want the "goodness" (mean squared activation) to be high, above a
    # threshold of 2, for real samples and low, below 2, for fake samples
    # i stole this fn from https://github.com/mohammadpz/pytorch_forward_forward

    g_pos = x_pos.pow(2).mean(1)
    g_neg = x_neg.pow(2).mean(1)

    loss = torch.log(1 + torch.exp(
        torch.cat([
            -g_pos + 2.,
            g_neg - 2.
        ])
    )).mean()

    # do a lil step. this *technically* uses backprop, but only within this one
    # layer: no gradient flows between layers, which is the whole point of
    # forward-forward. the derivatives for w/b here are simple enough to work out
    # by hand without the torch api, but im lazy, so we do it this way

    loss.backward()
    opt.step()
    opt.zero_grad()

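The comment in the loop says the derivatives for this one layer are simple enough to work out without the torch api. The sketch below shows what that means. It assumes the notebook's setup: one `nn.Linear(2, 8)` + ReLU, goodness = mean of squared activations, θ = 2, and the same noise.

It computes the same loss with the gradients for W and b written out by the chain rule, then checks them against autograd. `ff_loss_and_manual_grads` is a hypothetical helper name, not something from the notebook or any library.

```python
import torch
from torch import nn

torch.manual_seed(0)
THETA = 2.0  # goodness threshold, same as the notebook

def ff_loss_and_manual_grads(layer, x_pos, x_neg, theta=THETA):
    """Forward-forward loss for one Linear+ReLU layer and its gradients, derived by hand."""
    W, b = layer.weight.detach(), layer.bias.detach()
    x = torch.cat([x_pos, x_neg])                       # (N, in)
    # -1 for positive rows, +1 for negative rows, so s matches cat([-g_pos + θ, g_neg - θ])
    sign = torch.cat([-torch.ones(len(x_pos)), torch.ones(len(x_neg))])
    z = x @ W.T + b                                     # pre-activations (N, out)
    h = z.clamp(min=0)                                  # ReLU
    g = h.pow(2).mean(1)                                # goodness per sample (N,)
    s = sign * (g - theta)
    loss = torch.log1p(torch.exp(s)).mean()             # softplus(s), averaged

    n, d = h.shape
    dL_ds = torch.sigmoid(s) / n                        # softplus'(s) = sigmoid(s); 1/n from the mean
    dL_dg = dL_ds * sign
    dL_dh = dL_dg[:, None] * 2 * h / d                  # d/dh of mean(h^2) is 2h/d
    dL_dz = dL_dh * (z > 0).float()                     # ReLU passes gradient only where z > 0
    grad_W = dL_dz.T @ x                                # (out, in), same shape as layer.weight
    grad_b = dL_dz.sum(0)                               # (out,)
    return loss, grad_W, grad_b

# compare against autograd on the same kind of data and the notebook's loss
layer = nn.Linear(2, 8)
x_pos = torch.tensor([[0., 1.], [1., 1.], [1., 0.], [0., 0.]])
x_neg = torch.randn(4, 2) / 2 + 0.5

loss_manual, grad_W, grad_b = ff_loss_and_manual_grads(layer, x_pos, x_neg)

g_pos = torch.relu(layer(x_pos)).pow(2).mean(1)
g_neg = torch.relu(layer(x_neg)).pow(2).mean(1)
loss_auto = torch.log(1 + torch.exp(torch.cat([-g_pos + THETA, g_neg - THETA]))).mean()
loss_auto.backward()

print(torch.allclose(grad_W, layer.weight.grad, atol=1e-5))
print(torch.allclose(grad_b, layer.bias.grad, atol=1e-5))
```

Every factor in that chain is local to the layer: the softplus, the squared-mean goodness, the ReLU mask and the layer's own inputs. That is why forward-forward can train each layer without a backward pass through the whole network.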

In [3]:
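# here is where we would feed the outputs of the first layer into another layer
# and train it the exact same way. (in the paper, each layer's output is
# normalized to unit length first, so the next layer can't just reuse the
# previous layer's goodness.) since this is just an example i won't be doing
# that here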
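For anyone who does want the second layer, here is a minimal sketch of greedy layer-wise forward-forward training. It assumes the notebook's loss, noise, θ = 2 and Adam at 0.05. The length normalization between layers follows the paper.

The helper names `ff_layer_loss`, `train_ff_layer`, `normalize` and `layer1_out` are hypothetical, and the second layer's size (8 units) is an arbitrary choice.

```python
import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(0)
THETA = 2.0  # same goodness threshold as the notebook

def ff_layer_loss(layer, x_pos, x_neg, theta=THETA):
    # the notebook's loss; F.softplus(z) computes log(1 + exp(z)) stably
    g_pos = layer(x_pos).pow(2).mean(1)
    g_neg = layer(x_neg).pow(2).mean(1)
    return F.softplus(torch.cat([-g_pos + theta, g_neg - theta])).mean()

def train_ff_layer(layer, make_pos, make_neg, steps=2000, lr=0.05):
    # layer-local training: the optimizer only holds this layer's parameters
    opt = torch.optim.Adam(layer.parameters(), lr)
    for _ in range(steps):
        loss = ff_layer_loss(layer, make_pos(), make_neg())
        opt.zero_grad()
        loss.backward()
        opt.step()
    return layer

def normalize(h, eps=1e-8):
    # scale each hidden vector to unit length so the next layer only sees the
    # pattern of activity, not how large (how "good") it was
    return h / (h.norm(dim=1, keepdim=True) + eps)

X = torch.tensor([[0., 1.], [1., 1.], [1., 0.], [0., 0.]])

def make_neg():
    return torch.randn(4, 2) / 2 + 0.5  # same noise as the notebook

layer1 = nn.Sequential(nn.Linear(2, 8), nn.ReLU())
layer2 = nn.Sequential(nn.Linear(8, 8), nn.ReLU())

# layer 1: trained like the notebook's encoder
train_ff_layer(layer1, lambda: X, make_neg)

def layer1_out(x):
    # frozen, length-normalized layer 1 output; no_grad keeps layer 2's loss from reaching layer 1
    with torch.no_grad():
        return normalize(layer1(x))

# layer 2: same objective, applied to layer 1's normalized outputs for real and fake inputs
train_ff_layer(layer2, lambda: layer1_out(X), lambda: layer1_out(make_neg()))
```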

In [4]:
# this is what our net embeds each example as. we're going to train
# one classifier layer on top of it afterward, separately from the
# forward-forward process

embedded_x = encoder(X.float())
embedded_x

tensor([[3.1239, 0.0000, 0.0000, 1.9045, 1.8605, 1.2976, 0.0000, 0.0000],
        [3.2355, 0.0000, 0.0000, 2.0181, 0.0000, 1.1044, 0.0000, 0.0000],
        [3.2319, 0.0000, 0.0000, 1.9207, 0.0000, 1.3022, 0.0000, 0.0000],
        [3.1203, 0.0000, 0.0000, 1.8071, 1.9119, 1.4954, 0.0000, 0.0366]],
       grad_fn=<ReluBackward0>)
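Besides eyeballing the embeddings, a quick check is whether the encoder does what the loss asks. The four real inputs should have goodness above θ = 2, and random inputs should fall below it.

The sketch below assumes the same noise distribution as the training loop. `goodness` and `goodness_report` are hypothetical helper names. The demo uses an untrained stand-in encoder so the block runs on its own; in the notebook you would pass `encoder` instead.

```python
import torch
from torch import nn

def goodness(encoder, x):
    """Mean squared activation per sample: the quantity the training loop pushes above/below θ."""
    with torch.no_grad():
        return encoder(x.float()).pow(2).mean(1)

def goodness_report(encoder, x_pos, n_neg=1000, theta=2.0):
    # same noise as the training loop: gaussian, mean 0.5, std 0.5
    x_neg = torch.randn(n_neg, x_pos.shape[1]) / 2 + 0.5
    g_pos = goodness(encoder, x_pos)
    g_neg = goodness(encoder, x_neg)
    return {
        "pos_goodness": g_pos,
        "frac_pos_above_theta": (g_pos > theta).float().mean().item(),
        "frac_neg_below_theta": (g_neg < theta).float().mean().item(),
    }

torch.manual_seed(0)
X = torch.tensor([[0, 1], [1, 1], [1, 0], [0, 0]])
demo_encoder = nn.Sequential(nn.Linear(2, 8), nn.ReLU())  # untrained stand-in for `encoder`
print(goodness_report(demo_encoder, X))
```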

In [5]:
# we'll now create our classifier layer and an optimizer for it, and also our labeled data

X = torch.tensor([[0, 1], [1, 1], [1, 0], [0, 0]])
y = torch.tensor([1, 0, 1, 0])

classifier = nn.Sequential(
    nn.Linear(8, 1), 
    nn.Sigmoid(),
)

opt = optim.Adam(classifier.parameters(), 0.05)

In [6]:
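# first we'll encode our data. detach() cuts the result off from the encoder's
# graph, so the classifier's loss.backward() never reaches the encoder

X = encoder(X.float()).detach()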

for i in trange(2_000):

    # then compute the mean squared error between the classifier's outputs
    # and the true labels

    pred_labels = classifier(X).flatten()
    loss = nn.MSELoss()(pred_labels, y.float())

    # backward and step
    loss.backward()
    opt.step()
    opt.zero_grad()


In [7]:
# now we can chain our models together into one classifier

model = nn.Sequential(
    encoder,
    classifier
)

# and use it to predict xor labels

print("0, 1 -", model(torch.tensor([[0., 1.]])).item())
print("1, 1 -", model(torch.tensor([[1., 1.]])).item())
print("1, 0 -", model(torch.tensor([[1., 0.]])).item())
print("0, 0 -", model(torch.tensor([[0., 0.]])).item())

0, 1 - 0.6876693367958069
1, 1 - 0.31053829193115234
1, 0 - 0.6948422193527222
0, 0 - 0.30689385533332825
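To get hard XOR labels instead of probabilities, one option is to run all four inputs as one batch and threshold the sigmoid output at 0.5. Note that `X` was overwritten with the embedded features in cell 6, so the raw inputs need to be rebuilt first.

This is a sketch. `predict_labels` and `X_raw` are hypothetical names, and the demo uses an untrained stand-in with the notebook's architecture so the block runs on its own. In the notebook you would pass `model`.

```python
import torch
from torch import nn

def predict_labels(model, x, threshold=0.5):
    """Run all inputs through the model in one batch; turn sigmoid outputs into 0/1 labels."""
    with torch.no_grad():
        probs = model(x.float()).flatten()
    return (probs > threshold).long(), probs

X_raw = torch.tensor([[0, 1], [1, 1], [1, 0], [0, 0]])
y = torch.tensor([1, 0, 1, 0])

# untrained stand-in with the notebook's architecture (encoder + classifier)
stand_in = nn.Sequential(nn.Linear(2, 8), nn.ReLU(), nn.Linear(8, 1), nn.Sigmoid())
labels, probs = predict_labels(stand_in, X_raw)
accuracy = (labels == y).float().mean().item()
print(labels, probs, accuracy)
```

Applied to the four outputs printed above, a 0.5 threshold gives 1, 0, 1, 0, which matches `y`.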
