# MNIST

Training of a small MLP on MNIST with a custom loss function

In [2]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [3]:
import numpy as np
from matplotlib import pyplot as plt

In [4]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

In [5]:
# Training settings
batch_size=64
epochs=10
lr=0.01
momentum=0.0
seed=1101

din=784
dout=10
dh=50

nbatches=30 # n.batches to evaluate statistics

In [6]:
class Net(nn.Module):
    def __init__(self,din=784, dh=30, dout=10):
        super(Net, self).__init__()
        
        self.lin1 = nn.Linear(din, dh)
        self.lin2 = nn.Linear(dh, dout)
    
    def forward(self, x):    
        x = torch.sigmoid(self.lin1(x)) 
        return self.lin2(x)

### New loss function

Try to use a Minkowski loss with parameter q

$$
d_q(x,y) = \left( \sum_i (x_i - y_i )^q \right)^{1/q}
$$

the only goal here is to learn how to define a new loss.
The class MinkowskiLoss receives a minibatch of data and 
returns the average Minkowski distance over a minibatch.

In [9]:
class MinkowskiLoss(torch.nn.Module):
    def __init__(self,q):
        super(MinkowskiLoss,self).__init__()
        self.q = q
        
    def forward(self,x,y):
        
        return loss

In [10]:
q = 6
criterion = MinkowskiLoss(q)

In [11]:
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    

    for batch_idx, (data, target) in enumerate(train_loader):
        
       
        data, target = data.to(device), target.to(device)
                
        data = data.view(-1,784) # *
        target = target.view(-1,1) # *
        onehot = torch.FloatTensor(data.shape[0], 10).zero_().to(device) # *
        onehot.scatter_(1, target, 1) # *
        target = onehot # *
    
        optimizer.zero_grad()
        output = model(data)    
        loss = criterion(output, target)
    
        
        loss.backward()
        optimizer.step()

In [12]:
def stats(model, device, loader, nsamples=batch_size*nbatches):
    
    model.eval() 
    
    loss = 0
    correct = 0
    count = 0
    
    origin = 'train' if loader.dataset.train else 'test'
    
    with torch.no_grad():
        
        for data, target in loader:
            
            count += data.shape[0]
                
            if count > nsamples:
                break
                
            data, target = data.to(device), target.to(device)     
            
            
            data = data.view(-1,784) # *
            target = target.view(-1,1) # *
            onehot = torch.FloatTensor(data.shape[0], 10).zero_().to(device) # *
            onehot.scatter_(1, target, 1) # *
            
    
            output = model(data)
            
            loss += criterion(output, onehot).item()
            pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    loss /= count  
    acc = 100. * correct / count
    

    print('\n{} set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(origin,
        loss, correct, count,
        acc ) )
    
    
    return loss,acc

In [13]:
use_cuda = torch.cuda.is_available()
torch.manual_seed(seed)
device = torch.device("cuda" if use_cuda else "cpu")
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

In [14]:
train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False, transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=batch_size, shuffle=True, **kwargs)

In [15]:
model = Net(din,dh,dout).to(device)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

In [None]:
train_stats = []
test_stats = []

for epoch in range(1, epochs + 1):
    
    train(model, device, train_loader, optimizer, epoch)
    
    train_stats.append(stats(model, device, train_loader))
    test_stats.append(stats(model, device, test_loader))
    
train_stats = np.array(train_stats)
test_stats = np.array(test_stats)

tensor(1.0635, grad_fn=<DivBackward0>)
tensor(0.9861, grad_fn=<DivBackward0>)
tensor(0.9902, grad_fn=<DivBackward0>)
tensor(0.9447, grad_fn=<DivBackward0>)
tensor(0.9765, grad_fn=<DivBackward0>)
tensor(0.9390, grad_fn=<DivBackward0>)
tensor(0.9399, grad_fn=<DivBackward0>)
tensor(0.8941, grad_fn=<DivBackward0>)
tensor(0.8631, grad_fn=<DivBackward0>)
tensor(0.9092, grad_fn=<DivBackward0>)
tensor(0.8322, grad_fn=<DivBackward0>)
tensor(0.7905, grad_fn=<DivBackward0>)
tensor(0.8054, grad_fn=<DivBackward0>)
tensor(0.8234, grad_fn=<DivBackward0>)
tensor(0.7773, grad_fn=<DivBackward0>)
tensor(0.7876, grad_fn=<DivBackward0>)
tensor(0.7845, grad_fn=<DivBackward0>)
tensor(0.7778, grad_fn=<DivBackward0>)
tensor(0.7234, grad_fn=<DivBackward0>)
tensor(0.7241, grad_fn=<DivBackward0>)
tensor(0.7481, grad_fn=<DivBackward0>)
tensor(0.7226, grad_fn=<DivBackward0>)
tensor(0.7066, grad_fn=<DivBackward0>)
tensor(0.7083, grad_fn=<DivBackward0>)
tensor(0.6901, grad_fn=<DivBackward0>)
tensor(0.6945, grad_fn=<D

tensor(0.5292, grad_fn=<DivBackward0>)
tensor(0.5547, grad_fn=<DivBackward0>)
tensor(0.5447, grad_fn=<DivBackward0>)
tensor(0.5512, grad_fn=<DivBackward0>)
tensor(0.5305, grad_fn=<DivBackward0>)
tensor(0.5480, grad_fn=<DivBackward0>)
tensor(0.5382, grad_fn=<DivBackward0>)
tensor(0.5467, grad_fn=<DivBackward0>)
tensor(0.5481, grad_fn=<DivBackward0>)
tensor(0.5435, grad_fn=<DivBackward0>)
tensor(0.5439, grad_fn=<DivBackward0>)
tensor(0.5308, grad_fn=<DivBackward0>)
tensor(0.5355, grad_fn=<DivBackward0>)
tensor(0.5541, grad_fn=<DivBackward0>)
tensor(0.5687, grad_fn=<DivBackward0>)
tensor(0.5491, grad_fn=<DivBackward0>)
tensor(0.5258, grad_fn=<DivBackward0>)
tensor(0.5454, grad_fn=<DivBackward0>)
tensor(0.5251, grad_fn=<DivBackward0>)
tensor(0.5437, grad_fn=<DivBackward0>)
tensor(0.5414, grad_fn=<DivBackward0>)
tensor(0.5332, grad_fn=<DivBackward0>)
tensor(0.5488, grad_fn=<DivBackward0>)
tensor(0.5245, grad_fn=<DivBackward0>)
tensor(0.5422, grad_fn=<DivBackward0>)
tensor(0.5434, grad_fn=<D

tensor(0.5003, grad_fn=<DivBackward0>)
tensor(0.5075, grad_fn=<DivBackward0>)
tensor(0.4806, grad_fn=<DivBackward0>)
tensor(0.4615, grad_fn=<DivBackward0>)
tensor(0.4794, grad_fn=<DivBackward0>)
tensor(0.4679, grad_fn=<DivBackward0>)
tensor(0.5171, grad_fn=<DivBackward0>)
tensor(0.4722, grad_fn=<DivBackward0>)
tensor(0.4863, grad_fn=<DivBackward0>)
tensor(0.4609, grad_fn=<DivBackward0>)
tensor(0.4797, grad_fn=<DivBackward0>)
tensor(0.4825, grad_fn=<DivBackward0>)
tensor(0.4941, grad_fn=<DivBackward0>)
tensor(0.4969, grad_fn=<DivBackward0>)
tensor(0.4763, grad_fn=<DivBackward0>)
tensor(0.4758, grad_fn=<DivBackward0>)
tensor(0.4908, grad_fn=<DivBackward0>)
tensor(0.4804, grad_fn=<DivBackward0>)
tensor(0.5054, grad_fn=<DivBackward0>)
tensor(0.4808, grad_fn=<DivBackward0>)
tensor(0.5151, grad_fn=<DivBackward0>)
tensor(0.4840, grad_fn=<DivBackward0>)
tensor(0.4316, grad_fn=<DivBackward0>)
tensor(0.4885, grad_fn=<DivBackward0>)
tensor(0.4993, grad_fn=<DivBackward0>)
tensor(0.4633, grad_fn=<D

tensor(0.4764, grad_fn=<DivBackward0>)
tensor(0.4643, grad_fn=<DivBackward0>)
tensor(0.4770, grad_fn=<DivBackward0>)
tensor(0.4688, grad_fn=<DivBackward0>)
tensor(0.4375, grad_fn=<DivBackward0>)
tensor(0.4473, grad_fn=<DivBackward0>)
tensor(0.4540, grad_fn=<DivBackward0>)
tensor(0.4894, grad_fn=<DivBackward0>)
tensor(0.4584, grad_fn=<DivBackward0>)
tensor(0.4863, grad_fn=<DivBackward0>)
tensor(0.4756, grad_fn=<DivBackward0>)
tensor(0.4542, grad_fn=<DivBackward0>)
tensor(0.4528, grad_fn=<DivBackward0>)
tensor(0.4746, grad_fn=<DivBackward0>)
tensor(0.4356, grad_fn=<DivBackward0>)
tensor(0.4386, grad_fn=<DivBackward0>)
tensor(0.4865, grad_fn=<DivBackward0>)
tensor(0.4299, grad_fn=<DivBackward0>)
tensor(0.4274, grad_fn=<DivBackward0>)
tensor(0.4651, grad_fn=<DivBackward0>)
tensor(0.4815, grad_fn=<DivBackward0>)
tensor(0.4784, grad_fn=<DivBackward0>)
tensor(0.4296, grad_fn=<DivBackward0>)
tensor(0.4173, grad_fn=<DivBackward0>)
tensor(0.4493, grad_fn=<DivBackward0>)
tensor(0.4654, grad_fn=<D

tensor(0.4138, grad_fn=<DivBackward0>)
tensor(0.4303, grad_fn=<DivBackward0>)
tensor(0.4773, grad_fn=<DivBackward0>)
tensor(0.4407, grad_fn=<DivBackward0>)
tensor(0.4562, grad_fn=<DivBackward0>)
tensor(0.4539, grad_fn=<DivBackward0>)
tensor(0.4192, grad_fn=<DivBackward0>)
tensor(0.4032, grad_fn=<DivBackward0>)
tensor(0.4500, grad_fn=<DivBackward0>)
tensor(0.4475, grad_fn=<DivBackward0>)
tensor(0.4537, grad_fn=<DivBackward0>)
tensor(0.4240, grad_fn=<DivBackward0>)
tensor(0.4174, grad_fn=<DivBackward0>)
tensor(0.4182, grad_fn=<DivBackward0>)
tensor(0.3842, grad_fn=<DivBackward0>)
tensor(0.4014, grad_fn=<DivBackward0>)
tensor(0.4285, grad_fn=<DivBackward0>)
tensor(0.4344, grad_fn=<DivBackward0>)
tensor(0.4302, grad_fn=<DivBackward0>)
tensor(0.4652, grad_fn=<DivBackward0>)
tensor(0.4421, grad_fn=<DivBackward0>)
tensor(0.3908, grad_fn=<DivBackward0>)
tensor(0.4185, grad_fn=<DivBackward0>)
tensor(0.4389, grad_fn=<DivBackward0>)
tensor(0.4541, grad_fn=<DivBackward0>)
tensor(0.4260, grad_fn=<D

In [None]:
fig=plt.figure(figsize=(10,5))

plt.subplot(121)
plt.plot(train_stats[:,0],'-k',label='train loss')
plt.plot(test_stats[:,0],'-r',label='test loss')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend()

plt.subplot(122)
plt.plot(train_stats[:,1],'-k',label='train acc')
plt.plot(test_stats[:,1],'-r',label='test acc')
plt.xlabel('epoch')
plt.ylabel('acc')
plt.legend()


plt.show()