In [3]:
import torch
from torch import nn
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor, Lambda
import numpy as np
import matplotlib.pyplot as plt

In [13]:
# load in data from mnist
num_classes = 10

train_data = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
    target_transform=Lambda(lambda y: torch.tensor(y,dtype=torch.float))
)
test_data = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
    target_transform=Lambda(lambda y: torch.tensor(y,dtype=torch.float))
)

In [14]:
#create the dataloaders: 
batch_size = 1024

train_dataloader = DataLoader(train_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

for X,y in train_dataloader:
    print(f"X shape [NCHW]: {X.shape}")
    print(f"y shape: {y.shape}")
    break

X shape [NCHW]: torch.Size([1024, 1, 28, 28])
y shape: torch.Size([1024])


In [15]:
#define the device and the model
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device {device}")

alpha=4.5
beta=4.5

model = nn.Sequential(
    nn.Conv2d(1,8,5),
    nn.ReLU(),
    nn.MaxPool2d(2,2),
    nn.Conv2d(8,16,5),
    nn.ReLU(),
    nn.MaxPool2d(2,2),
    nn.Flatten(),
    nn.Linear(256,128),
    nn.ReLU(),
    nn.Linear(128,32),
    nn.ReLU(),
    nn.Linear(32,1),
)

output_tf = lambda mx: alpha*mx + beta

Using device cuda


In [16]:
#optimiazer, loss and training

optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
loss_fn = nn.MSELoss()

train_losses = []
train_accuracies =[]

def train(dataloader, model, loss_fn, optimizer):
    
    size = len(dataloader.dataset)
    model.train()
    
    losses,accuracies = [],[]
    for lvb, (X,y) in enumerate(dataloader):
        X.to(device), y.to(device)

        #forward pass
        mx = model(X)
        mx = mx.squeeze()
        mx = output_tf(mx)
        loss = loss_fn(mx,y)

        #gradients
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # append loss and accuracy for recording
        losses.append(loss.item())
        acc = (y == torch.round(mx)).mean(dtype = float).item()
        accuracies.append(acc)

        #print for following
        if lvb % 10 == 0:
            loss, current = loss.item(), lvb * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

    train_loss = np.mean(losses)
    train_acc = np.mean(accuracies)
    
    return train_loss, train_acc
            

In [17]:
def test(dataloader, model, loss_fn):
    
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    
    model.eval()
    
    with torch.no_grad():
        test_loss,correct = 0,0
        for lvb, (X,y) in enumerate(dataloader):
            X.to(device), y.to(device)
            
            #forward pass
            mx = model(X)
            mx = mx.squeeze()
            mx = output_tf(mx)
            loss = loss_fn(mx,y)
            
            # append loss and accuracy for recording
            test_loss += loss.item()
            correct += (y == torch.round(mx)).sum().item()
                
    test_loss /= num_batches
    acc = correct / size
    print(f"Test Error: \n Accuracy: {(100*acc):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    return test_loss, acc

In [18]:
#train:
num_epochs = 50
train_losses,train_accs, test_losses, test_accs = [],[],[],[]

print(f"PRE-TRAINING:\n-------------------------------")
#test + record
test_loss,test_acc = test(test_dataloader,model,loss_fn)
test_losses.append(test_loss)
test_accs.append(test_acc)

for lve in range(num_epochs):
    print(f"Epoch {lve+1}\n-------------------------------")
    #train + record
    train_loss, train_acc = train(train_dataloader,model,loss_fn,optimizer)
    train_losses.append(train_loss)
    train_accs.append(train_acc)
    
    #test + record
    test_loss,test_acc = test(test_dataloader,model,loss_fn)
    test_losses.append(test_loss)
    test_accs.append(test_acc)

PRE-TRAINING:
-------------------------------
Test Error: 
 Accuracy: 9.7%, Avg loss: 8.335621 

Epoch 1
-------------------------------
loss: 8.307939  [    0/60000]
loss: 7.552577  [10240/60000]
loss: 5.966996  [20480/60000]
loss: 4.125992  [30720/60000]
loss: 3.689331  [40960/60000]
loss: 3.047615  [51200/60000]
Test Error: 
 Accuracy: 25.4%, Avg loss: 3.023489 

Epoch 2
-------------------------------
loss: 3.036360  [    0/60000]
loss: 2.728499  [10240/60000]
loss: 2.564800  [20480/60000]
loss: 2.585215  [30720/60000]
loss: 2.332980  [40960/60000]
loss: 1.874583  [51200/60000]
Test Error: 
 Accuracy: 35.7%, Avg loss: 1.866092 

Epoch 3
-------------------------------
loss: 1.864504  [    0/60000]
loss: 1.724776  [10240/60000]
loss: 1.972403  [20480/60000]
loss: 1.863144  [30720/60000]
loss: 1.815627  [40960/60000]
loss: 1.392230  [51200/60000]
Test Error: 
 Accuracy: 43.7%, Avg loss: 1.418383 

Epoch 4
-------------------------------
loss: 1.409383  [    0/60000]
loss: 1.334795  [

Here we can see the training occurs a good deal faster than in the same case prior to the affine transformation - why?  My first hypothesis was that we wanted to scale the gradients to make it go faster, but I think Adam's adaptive optimisation should cancel this out anyways; now I'm fairly sure that the idea is to effectively re-scale the y-values (since it's functionally the same to do the inverse transformation on y, allowing that the specific form of the loss will change) -- this just moves the 'average' label to be zero instead of 4.5.  I would guess that this makes things easier just because the random initial outputs are distributed in some way around zero, which means gradients for getting different numbers wrong by the same amount have different magnitudes of loss - an error classifying 9 might hurt much more than the loss from misclassifying 0, and so presumably it trains much faster on 9s than on 0s, and this difference makes it take a lot longer overall.  I think it's fairly clear that any architecture which assumes "1" is closer to "2" than to "3" is working against the natural structure of the problem.