In [2]:
import matplotlib.pyplot as plt

import torch
from torch import nn, optim
from torch.nn import functional as F
from pytorch_model_summary import summary

In [3]:
import torchvision
import torchvision.datasets as datasets

# Build the MNIST training and test dataset objects rooted at ./data.
# download=True is passed so torchvision fetches the files when needed;
# transform=None means no transform is attached to the samples.
mnist_trainset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=None,
)
mnist_testset = datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=None,
)

In [4]:
# Pull the image tensor and the label tensor out of each dataset object,
# keeping each split's inputs and targets side by side.
mnist_train, mnist_trainlabel = mnist_trainset.data, mnist_trainset.targets
mnist_test, mnist_testlabel = mnist_testset.data, mnist_testset.targets

In [5]:
# Sanity check: show the shape of every input/label tensor, one per line.
print(
    mnist_train.shape,
    mnist_test.shape,
    mnist_trainlabel.shape,
    mnist_testlabel.shape,
    sep='\n',
)

torch.Size([60000, 28, 28])
torch.Size([10000, 28, 28])
torch.Size([60000])
torch.Size([10000])


In [6]:
from dfw import DFW
from dfw.losses import MultiClassHingeLoss

In [50]:
class MLP(nn.Module):
    """Fully connected classifier with two ReLU hidden layers.

    With the default arguments the network is 784 -> 100 -> 50 -> 10, i.e. it
    takes inputs that flatten to 784 features per sample (such as 28x28
    images) and returns 10 raw, un-normalised class scores per sample.

    Args:
        in_features: number of features per sample after flattening.
        hidden1: width of the first hidden layer.
        hidden2: width of the second hidden layer.
        num_classes: number of output scores per sample.
    """

    def __init__(self, in_features=784, hidden1=100, hidden2=50, num_classes=10):
        super().__init__()

        # Attribute names are kept as fc1/fc2/fc3 so existing state dicts still load.
        self.fc1 = nn.Linear(in_features, hidden1)
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.fc3 = nn.Linear(hidden2, num_classes)

    def forward(self, x):
        """Flatten ``x`` to ``(-1, in_features)`` and return the class scores.

        Args:
            x: tensor whose total size is a multiple of ``in_features``.

        Returns:
            Tensor of shape ``(batch, num_classes)`` holding the output of the
            last linear layer (no softmax is applied).
        """
        # reshape (not view) so non-contiguous inputs, e.g. transposed or
        # permuted batches, are flattened instead of raising a RuntimeError.
        x = x.reshape(-1, self.fc1.in_features)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)

        return x
# Print a per-layer summary (input shapes and parameter counts) for a fresh
# MLP, traced with a single dummy 1x28x28 input.
dummy_batch = torch.zeros((1, 1, 28, 28))
print(summary(MLP(), dummy_batch, show_input=True))

-----------------------------------------------------------------------
      Layer (type)         Input Shape         Param #     Tr. Param #
          Linear-1            [1, 784]          78,500          78,500
          Linear-2            [1, 100]           5,050           5,050
          Linear-3             [1, 50]             510             510
Total params: 84,060
Trainable params: 84,060
Non-trainable params: 0
-----------------------------------------------------------------------


In [66]:
def train(model, train_input, train_target, test_input, test_target,
                mini_batch_size, nb_epochs = 100,eta = 1e-3, verbose=1):
    """Train ``model`` with the DFW optimizer on a multi-class hinge loss.

    The training data is visited in order (no shuffling) in consecutive
    mini-batches. Every 10th epoch (starting with epoch 0) a progress line is
    printed.

    NOTE: the accuracy report relies on the notebook-level ``accuracy``
    helper, which must already be defined when ``train`` is called.

    Args:
        model: module to optimise in place.
        train_input: training samples, indexed along dimension 0.
        train_target: class indices aligned with ``train_input``.
        test_input: held-out samples, used only for the accuracy report.
        test_target: class indices aligned with ``test_input``.
        mini_batch_size: number of samples per optimisation step; the last
            batch of an epoch may be smaller.
        nb_epochs: number of passes over ``train_input``.
        eta: value forwarded to ``DFW(..., eta=eta)``.
        verbose: if truthy, the progress line shows train/validation
            accuracy; otherwise it shows the epoch index and the summed
            mini-batch loss of that epoch.

    Returns:
        None. ``model`` is updated in place.
    """
    # Cast once so the training loop and the evaluation see the same dtype
    # (the linear layers cannot consume integer-typed image tensors).
    train_input = train_input.float()
    test_input = test_input.float()

    svm = MultiClassHingeLoss()
    optimizer = DFW(model.parameters(), eta=eta)
    nb_samples = train_input.size(0)

    for e in range(nb_epochs):
        total_loss = 0

        for b in range(0, nb_samples, mini_batch_size):
            # Slicing clamps at the end of the tensor, so a trailing partial
            # batch is handled; narrow() would raise when
            # b + mini_batch_size exceeds the number of samples.
            batch_input = train_input[b:b + mini_batch_size]
            batch_target = train_target[b:b + mini_batch_size]

            output = model(batch_input)
            loss = svm(output, batch_target)
            total_loss = total_loss + loss.item()

            optimizer.zero_grad()
            loss.backward()
            # DFW's step() takes a closure returning the current loss value.
            optimizer.step(lambda: float(loss))

        if e%10==0:
            if not verbose :
                print(e, total_loss)
            else:
                # Evaluate without building an autograd graph over the full
                # datasets, and in eval mode; restore the previous mode after.
                was_training = model.training
                model.eval()
                try:
                    with torch.no_grad():
                        accuracy_train = accuracy(model(train_input), train_target)
                        accuracy_val = accuracy(model(test_input), test_target)
                finally:
                    model.train(was_training)
                print('epoch {}: acc -> {} | acc_val -> {}'.format(e, accuracy_train,accuracy_val))

In [68]:
# Instantiate a fresh network and fit it on MNIST: float-cast images,
# batches of 50, 100 epochs, accuracy reporting switched on.
mlp = MLP()
train(
    mlp,
    mnist_train.float(),
    mnist_trainlabel,
    mnist_test.float(),
    mnist_testlabel,
    mini_batch_size=50,
    nb_epochs=100,
    eta=1e-3,
    verbose=True,
)

epoch 0: acc -> 0.9238166809082031 | acc_val -> 0.9203000068664551
epoch 10: acc -> 0.9802500009536743 | acc_val -> 0.9585999846458435
epoch 20: acc -> 0.9905833601951599 | acc_val -> 0.9639000296592712
epoch 30: acc -> 0.9953500032424927 | acc_val -> 0.9621000289916992
epoch 40: acc -> 0.9968166947364807 | acc_val -> 0.9621000289916992
epoch 50: acc -> 0.998283326625824 | acc_val -> 0.9621000289916992
epoch 60: acc -> 0.9993000030517578 | acc_val -> 0.9610999822616577
epoch 70: acc -> 0.9996500015258789 | acc_val -> 0.9617000222206116
epoch 80: acc -> 0.9998166561126709 | acc_val -> 0.9617000222206116
epoch 90: acc -> 0.9999499917030334 | acc_val -> 0.9620000123977661


In [64]:
def accuracy(y_pred,y_ground):
    """Return the fraction of rows of ``y_pred`` whose arg-max equals the label.

    Args:
        y_pred: 2-D tensor of per-class scores, one row per sample.
        y_ground: tensor of class indices, one per row of ``y_pred``.

    Returns:
        Python float: number of matches divided by ``y_ground.size(0)``.
    """
    predicted_classes = torch.argmax(y_pred, dim=1)
    nb_correct = (predicted_classes == y_ground).sum()
    fraction_correct = nb_correct / y_ground.size(0)
    return fraction_correct.item()

In [None]:
def one_hot_embedding(labels, num_classes):
    """Turn integer class labels into one-hot rows.

    Args:
      labels: (LongTensor) class labels, sized [N,].
      num_classes: (int) number of classes.

    Returns:
      (tensor) one-hot encoded labels, sized [N, #classes].
    """
    identity = torch.eye(num_classes)
    # Row k of the identity matrix is exactly the one-hot vector for class k,
    # so indexing with the labels selects one such row per label.
    one_hot = identity[labels]
    return one_hot

# onehot_train = one_hot_embedding(mnist_trainlabel,10)
# onehot_test = one_hot_embedding(mnist_testlabel,10)