In [1]:
import torch
import time

from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Normalize, Lambda
from torch.utils.data import DataLoader

In [2]:
def one_hot_encode(img0, lab):
    img = img0.clone()
    img[:, :10] = img0.min()
    img[range(img0.shape[0]), lab] = img0.max()
    return img

In [3]:
#Load MNIST Data
train_loader = DataLoader(
    MNIST('./MNIST_data/', train=True,
    download=True,
    transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,)), Lambda(lambda x: torch.flatten(x))])),
    batch_size=60000)

test_loader = DataLoader(
    MNIST('./MNIST_data/', train=False,
    download=True,
    transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,)), Lambda(lambda x: torch.flatten(x))])),
    batch_size=10000)
    
dtype = torch.float

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('Using device:', device)

# Training images
img0, lab = next(iter(train_loader))
img0 = img0.to(device)

# Validation images
img0_tst, lab_tst = next(iter(test_loader))
img0_tst = img0_tst.to(device)

Using device: cuda:0


In [None]:
# Forward Forward Applied to a Single Perceptron for MNIST Classification
n_input, n_out = 784, 8000
batch_size, learning_rate = 10, 0.0003
g_threshold = 10
epochs = 550

perceptron = torch.nn.Sequential(torch.nn.Linear(n_input, n_out, bias = True),
                      torch.nn.ReLU())

perceptron.to(device)
optimizer = torch.optim.Adam(perceptron.parameters(), lr = learning_rate)

N_trn = img0.size(0) #Use all training images (60000)

tic = time.time()

for epoch in range(epochs):
    img = img0.clone()

    for i in range(N_trn): # Random jittering of training images up to 2 pixels
        dx, dy = torch.randint(-2, 2, (2,))
        img[i] = torch.roll(img0[i].reshape(28, 28), shifts=(dx, dy), dims=(0, 1)).flatten()

    
    perm = torch.randperm(N_trn)
    img_pos = one_hot_encode(img[perm], lab[perm]) # Good data (actual label)
    
    lab_neg = lab[perm] + torch.randint(low=1,high=10,size=(lab.size()))
    lab_neg = torch.where(lab_neg > 9, lab_neg - 10, lab_neg)
    img_neg = one_hot_encode(img[perm], lab_neg) # Bad data (random error in label)

    L_tot = 0

    for i in range(0, N_trn, batch_size):
        perceptron.zero_grad()

        # Goodness and loss for good data in batch
        img_pos_batch = img_pos[i:i+batch_size]
        g_pos = (perceptron(img_pos_batch)**2).mean(dim=1)
        loss = torch.log(1 + torch.exp(-(g_pos - g_threshold))).sum()

        # Goodness and loss for bad data in batch
        img_neg_batch = img_neg[i:i+batch_size]
        g_neg = (perceptron(img_neg_batch)**2).mean(dim=1)
        loss += torch.log(1 + torch.exp(g_neg - g_threshold)).sum()

        L_tot += loss.item()  # Accumulate total loss for epoch

        loss.backward()   # Compute gradients
        optimizer.step()  # Update parameters

    # Test model with validation set
    N_tst = img0_tst.size(0) # Use all test images (10000)
    
    #Evaluate goodness for all test images and labels 0...9
    g_tst = torch.zeros(10,N_tst).to(device)
    for n in range(10):
        img_tst = one_hot_encode(img0_tst, n)
        g_tst[n] = ((perceptron(img_tst[0:N_tst])**2).mean(dim=1)).detach()       
    predicted_label = g_tst.argmax(dim=0).cpu()

    # Count number of correctly classified images in validation set
    Ncorrect = (predicted_label == lab_tst).sum().cpu().numpy()

    print("Epoch ", epoch+1, ":\tLoss ", L_tot, " \tTime ", round(time.time() - tic), "s\tTest Error ", 100 - Ncorrect/N_tst*100, "%")


Epoch  1 :	Loss  74521.4822101593  	Time  24 s	Test Error  18.210000000000008 %
Epoch  2 :	Loss  47536.03710889816  	Time  48 s	Test Error  12.900000000000006 %
Epoch  3 :	Loss  36756.847930788994  	Time  72 s	Test Error  9.909999999999997 %
Epoch  4 :	Loss  30418.137512803078  	Time  96 s	Test Error  8.489999999999995 %
Epoch  5 :	Loss  26599.50651872158  	Time  120 s	Test Error  6.589999999999989 %
Epoch  6 :	Loss  23383.002599418163  	Time  144 s	Test Error  6.179999999999993 %
Epoch  7 :	Loss  21235.86532586813  	Time  167 s	Test Error  5.089999999999989 %
Epoch  8 :	Loss  19317.112773120403  	Time  191 s	Test Error  4.359999999999999 %
Epoch  9 :	Loss  17776.419402897358  	Time  215 s	Test Error  4.3799999999999955 %
Epoch  10 :	Loss  16425.179507106543  	Time  239 s	Test Error  4.010000000000005 %
Epoch  11 :	Loss  15284.44343227148  	Time  263 s	Test Error  3.480000000000004 %
Epoch  12 :	Loss  14694.135591208935  	Time  287 s	Test Error  3.280000000000001 %
Epoch  13 :	Loss  13