In [1]:
import forward
from util import *
import torch
import torch.nn as nn
from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import numpy as np
from torch.optim import SGD

In [2]:
# Seed every RNG this notebook may draw from so Restart & Run All reproduces
# the same initialisation and DataLoader shuffles.
# NOTE(review): numpy is seeded in case util helpers (e.g. FwFw_Dataset's label
# sampling) use np.random — confirm which RNG they actually use.
SEED = 1234
torch.manual_seed(SEED)
np.random.seed(SEED)
train_loader, test_loader = MNIST_loaders()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")




Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



In [3]:
# Grab the first batch from the training loader and move it to the compute
# device. `x` / `y` are reused below to build the forward-forward dataset and
# to measure training accuracy.
# NOTE(review): whether this single batch covers the whole training set depends
# on the batch size chosen inside MNIST_loaders — confirm in util.
x_raw, y_raw = next(iter(train_loader))
x = x_raw.to(device)
y = y_raw.to(device)


In [4]:
import torch.nn.functional as F

# Wrap the cached batch in the forward-forward dataset, with labels converted
# by OneHot. Iterate it in shuffled mini-batches of 256.
# The training loop below unpacks each batch as (x, candidate label, target),
# where the candidate label may be a true or a false label.
y_onehot = OneHot(y)
dataset = FwFw_Dataset(x, y_onehot)
train_dataloader = DataLoader(dataset, shuffle=True, batch_size=256)

In [16]:
# Train a forward-forward FCNet (784 -> 4 x 2000, 10 label classes) with SGD +
# momentum on train_dataloader, then save the weights to modelpath('hinton.ptc').
net = forward.FCNet([784, 2000, 2000, 2000, 2000], y_classes = 10, dropout=0).to(device)

opt = SGD(net.parameters(), lr=0.005, momentum=0.9)
n_epoch = 300
EVAL_EVERY = 10  # epochs between training-accuracy evaluations on (x, y)
lossfn = fwfw_loss
pbar = tqdm(range(n_epoch), desc="Epoch 0")
for i in pbar:
    running_loss = 0
    for batch in train_dataloader:
        # x and associated y (which may be true or false labels); batch[2] is the target
        cur_x = (batch[0].to(device), batch[1].to(device))
        cur_y = batch[2].to(device)
        opt.zero_grad()
        # Call the module instead of .forward() so any registered hooks run.
        res = net(cur_x)
        loss = lossfn(cur_y.type(torch.float64), res)
        loss.backward()
        opt.step()
        # NOTE(review): accumulates loss.item() / len(dataset) per batch; if
        # fwfw_loss already averages over the batch this is not a per-sample
        # mean — confirm the reduction used by fwfw_loss.
        running_loss += (loss.item() / len(dataset))
    # Evaluate periodically AND on the final epoch, so the accuracy shown at the
    # end belongs to the model that is saved below (between evaluations the bar
    # shows the most recent measured value).
    if i % EVAL_EVERY == 0 or i == n_epoch - 1:
        with torch.no_grad():
            acc = net.predict(x).eq(y).float().mean().item()
    pbar.set_description(f'Epoch {i}, train loss {running_loss}, train acc: {acc}')
torch.save(net.state_dict(), modelpath('hinton.ptc'))



Epoch 299, train loss 0.007325633278655434, train acc: 0.9332999587059021: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 300/300 [32:39<00:00,  6.53s/it]


In [None]:
# Rebuild the network with exactly the architecture used for training
# (including y_classes = 10) so the saved state_dict matches, then report
# train error on the cached batch (x, y) and test error on the first test batch.
net = forward.FCNet([784, 2000, 2000, 2000, 2000], y_classes = 10, dropout=0).to(device)
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
net.load_state_dict(torch.load(modelpath('hinton.ptc'), map_location=device))
# NOTE(review): only the first batch of test_loader is evaluated; this is the
# full test set only if MNIST_loaders uses a batch size >= 10000 — confirm.
x_te, y_te = next(iter(test_loader))
x_te, y_te = x_te.to(device), y_te.to(device)
with torch.no_grad():
    print('train error:', 1.0 - net.predict(x).eq(y).float().mean().item())
    print('test error:', 1.0 - net.predict(x_te).eq(y_te).float().mean().item())


The train/test errors above are somewhat higher than the paper's results. This is partly because the model was trained for 300 epochs (`n_epoch = 300` above) rather than 1000. The only other differences I can think of are:
- the paper uses an adapter layer on top of the probability outputs, which I haven't yet figured out how to fit elegantly into the project structure
- the training effectively has label noise ("label damping"), because the dataloader has some chance of picking a positive example with label=0