In [50]:
import torch
from torch import nn
import numpy as np
from torchvision import datasets, transforms

In [2]:
# Define the model
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [12]:
def forward_forward(model, X, y):
    # Split the input data into two streams
    X_pos = X[y == 1]
    X_neg = X[y == 0]
    
    # Check if either tensor is empty
    if X_pos.nelement() == 0 or X_neg.nelement() == 0:
        return torch.tensor(0.0)
    
    # Compute the forward pass of the model on both streams of data
    out_pos = model(X_pos)
    out_neg = model(X_neg)
    
    # Compute the loss
    loss = torch.mean((out_pos - out_neg)**2)
    return loss

In [4]:
# Load the CIFAR-10 dataset
transform = transforms.Compose([transforms.ToTensor()])
trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:05<00:00, 29307781.48it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data


In [8]:
# Initialize the model and optimizer
model = Net()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

In [19]:
# Train the model using the Forward-Forward Algorithm
for epoch in range(100):
    running_loss = 0.0
    for i, data in enumerate(trainloader):
        inputs, labels = data
        optimizer.zero_grad()
        loss = forward_forward(model, inputs, labels)
        optimizer.step()
        running_loss += loss.item()
    print(f'Epoch: {epoch} Loss: {running_loss / len(trainloader)}')

Epoch: 0 Loss: 2.9085383558231116e-06
Epoch: 1 Loss: 3.0060529933689396e-06
Epoch: 2 Loss: 3.0055480017108493e-06
Epoch: 3 Loss: 3.0801914619951278e-06
Epoch: 4 Loss: 3.0712073797985794e-06
Epoch: 5 Loss: 2.834570900449762e-06
Epoch: 6 Loss: 2.9115084275690606e-06
Epoch: 7 Loss: 2.9461141034062167e-06
Epoch: 8 Loss: 2.903673986929789e-06
Epoch: 9 Loss: 2.891479670179251e-06
Epoch: 10 Loss: 3.0996296259945667e-06
Epoch: 11 Loss: 2.8977382081211544e-06
Epoch: 12 Loss: 3.0660015215107705e-06
Epoch: 13 Loss: 3.1314957226459227e-06
Epoch: 14 Loss: 2.8681984919057866e-06
Epoch: 15 Loss: 2.8208324821116547e-06
Epoch: 16 Loss: 2.9298895262854785e-06
Epoch: 17 Loss: 2.9589574939564046e-06
Epoch: 18 Loss: 2.930693306370813e-06
Epoch: 19 Loss: 2.925326023632806e-06
Epoch: 20 Loss: 2.9053859716441366e-06
Epoch: 21 Loss: 2.8297916074279784e-06
Epoch: 22 Loss: 3.0158187213601195e-06
Epoch: 23 Loss: 2.9126313975393712e-06
Epoch: 24 Loss: 2.9123269974661523e-06
Epoch: 25 Loss: 3.1293613135676423e-06
E

In [None]:
# Load the test dataset
testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False)


In [49]:
# Evaluate the model on the test dataset
pos_distances = []
neg_distances = []
with torch.no_grad():
    for data in testloader:
        inputs, labels = data
        outputs = model(inputs)
        pos_outputs = outputs[labels == 1]
        neg_outputs = outputs[labels == 0]
        if pos_outputs.nelement() > 0 and neg_outputs.nelement() > 0:
            pos_distance = torch.mean(torch.cdist(pos_outputs, pos_outputs))
            neg_distance = torch.mean(torch.cdist(neg_outputs, neg_outputs))
            pos_distances.append(pos_distance.item())
            neg_distances.append(neg_distance.item())

print(f'Average positive distance: {np.mean(pos_distances)}')
print(f'Average negative distance: {np.mean(neg_distances)}')

Average positive distance: 0.0011169865735218116
Average negative distance: 0.0009342479406540364
