In [19]:
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch import nn, optim
from torch.nn import functional as F
from pytorch_model_summary import summary

In [1]:
import torchvision
import torchvision.datasets as datasets

#### Goal: composite optimizers — Frank–Wolfe / SGD / Adam

### Part 1: DNN (MLP) on MNIST — Frank–Wolfe grid search

In [3]:
# Fetch (downloading into ./data if needed) the MNIST train and test splits; no transform applied.
mnist_trainset, mnist_testset = (
    datasets.MNIST(root='./data', train=is_train, download=True, transform=None)
    for is_train in (True, False)
)

In [4]:
# Pull the whole image tensor and label tensor out of each split (shapes are printed below).
mnist_train, mnist_trainlabel = mnist_trainset.data, mnist_trainset.targets
mnist_test, mnist_testlabel = mnist_testset.data, mnist_testset.targets

In [5]:
# One shape per line: train images, test images, train labels, test labels.
print(*(t.shape for t in (mnist_train, mnist_test, mnist_trainlabel, mnist_testlabel)), sep='\n')

torch.Size([60000, 28, 28])
torch.Size([10000, 28, 28])
torch.Size([60000])
torch.Size([10000])


In [21]:
from dfw import DFW
from dfw.losses import MultiClassHingeLoss

In [6]:
class MLP(nn.Module):
    """Fully connected classifier 784 -> 128 -> 64 -> 10 (ReLU between layers, raw logits out).

    Any input whose trailing elements form 784-long rows (e.g. (N, 28, 28) or
    (N, 1, 28, 28)) is flattened to (N, 784) before the first layer.
    """

    def __init__(self):
        super().__init__()
        # Creation order fc1 -> fc2 -> fc3 fixes the order in which the seeded
        # RNG initializes the weights.
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        hidden = x.view(-1, 784)  # flatten each image into a 784-vector
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        return self.fc3(hidden)
# Parameter table for a fresh MLP, traced with one dummy 1x28x28 image.
print(summary(MLP(), torch.zeros(1, 1, 28, 28), show_input=True))

-----------------------------------------------------------------------
      Layer (type)         Input Shape         Param #     Tr. Param #
          Linear-1            [1, 784]         100,480         100,480
          Linear-2            [1, 128]           8,256           8,256
          Linear-3             [1, 64]             650             650
Total params: 109,386
Trainable params: 109,386
Non-trainable params: 0
-----------------------------------------------------------------------


In [7]:
def accuracy(y_pred, y_ground):
    """Fraction of rows whose arg-max over dim 1 of `y_pred` equals `y_ground`.

    Args:
        y_pred: 2-D tensor of per-class scores, one row per sample.
        y_ground: 1-D tensor of integer class labels, same length as `y_pred`.

    Returns:
        The accuracy as a Python float.
    """
    n_correct = (y_pred.argmax(dim=1) == y_ground).sum()
    return (n_correct / y_ground.size(0)).item()

In [8]:
def _predict(model, inputs, batch_size=1000):
    """Return `model(inputs)` computed chunk by chunk with autograd disabled.

    Chunking bounds peak memory on full-dataset passes; `no_grad` keeps the result
    free of a computation graph so it can be stored cheaply.
    """
    with torch.no_grad():
        return torch.cat([model(chunk) for chunk in inputs.split(batch_size)])


def train(model, train_input, train_target, test_input, test_target,
          mini_batch_size=50, nb_epochs=100, eta=1e-3, verbose=1,
          eval_every=10, eval_batch_size=1000, drop_last=False):
    """Train `model` with the `DFW` optimizer and a `MultiClassHingeLoss` criterion.

    Mini-batches are taken sequentially from `train_input` (no shuffling). Every
    `eval_every` epochs (starting at epoch 0) loss and accuracy on the full train
    and test sets are recorded.

    Args:
        model: the network to train (updated in place).
        train_input, test_input: input tensors; cast to float once up front.
        train_target, test_target: integer class-label tensors.
        mini_batch_size: samples per optimizer step.
        nb_epochs: number of passes over the training set.
        eta: DFW step-size hyper-parameter.
        verbose: if truthy, print accuracies at each evaluation.
        eval_every: evaluate every this many epochs.
        eval_batch_size: chunk size used for the (gradient-free) evaluation passes.
        drop_last: if True, skip a final mini-batch smaller than `mini_batch_size`.

    Returns:
        dict with lists 'loss', 'loss_val', 'acc', 'acc_val' (Python floats) and
        'epoch' (the epoch index of each evaluation).
    """
    svm = MultiClassHingeLoss()
    optimizer = DFW(model.parameters(), eta=eta)

    # Cast once so integer (e.g. uint8) image tensors work for training, not just evaluation.
    train_input = train_input.float()
    test_input = test_input.float()

    history = {'loss': [], 'loss_val': [], 'acc': [], 'acc_val': [], 'epoch': []}
    n_train = train_input.size(0)

    for e in range(nb_epochs):
        model.train()
        for b in range(0, n_train, mini_batch_size):
            # The final batch may be shorter; narrow() raises if asked for more rows than remain.
            size = min(mini_batch_size, n_train - b)
            if drop_last and size < mini_batch_size:
                break
            output = model(train_input.narrow(0, b, size))
            loss = svm(output, train_target.narrow(0, b, size))

            optimizer.zero_grad()
            loss.backward()
            # DFW.step is given a closure returning the current loss value.
            optimizer.step(lambda: float(loss))

        if e % eval_every == 0:
            model.eval()
            output_train = _predict(model, train_input, eval_batch_size)
            output_val = _predict(model, test_input, eval_batch_size)
            with torch.no_grad():
                # .item() stores plain floats, so no computation graph is kept alive in history.
                loss_train = svm(output_train, train_target).item()
                loss_val = svm(output_val, test_target).item()
            accuracy_train = accuracy(output_train, train_target)
            accuracy_val = accuracy(output_val, test_target)

            history['loss'].append(loss_train)
            history['loss_val'].append(loss_val)
            history['acc'].append(accuracy_train)
            history['acc_val'].append(accuracy_val)
            history['epoch'].append(e)

            if verbose:
                print('epoch {}: acc -> {} | acc_val -> {}'.format(e, accuracy_train, accuracy_val))

    return history

In [68]:
SEED = 123456789  # seeding for weight initialization and train

# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")

# Dedicated result lists so the MLP results survive later cells that rebind `histories`;
# `histories` stays bound to these results for compatibility.
mlp_histories = []
mlp_configs = []  # (batch_size, learning_rate) of each entry in mlp_histories, same order
histories = mlp_histories

bss = [50, 64, 128]
lrs = [1e-2, 1e-3, 5e-3, 1e-4]

# Cast/move the data once instead of on every grid point.
# NOTE(review): no pixel normalization is applied anywhere in this notebook — TODO consider scaling.
mnist_train_x = mnist_train.float().to(device)
mnist_test_x = mnist_test.float().to(device)
mnist_train_y = mnist_trainlabel.to(device)
mnist_test_y = mnist_testlabel.to(device)

# Outer loop over learning rates, inner over batch sizes: the same visiting order
# as np.meshgrid(bss, lrs)[i].ravel().
for lr in lrs:
    for bs in bss:
        print('batch-size: {} | learning-rate: {}'.format(bs, lr))
        # Re-seed so every configuration starts from identical initial weights.
        torch.manual_seed(SEED)
        mlp = MLP().to(device)

        history = train(mlp, mnist_train_x, mnist_train_y,
                        mnist_test_x, mnist_test_y,
                        mini_batch_size=bs, nb_epochs=100, eta=lr, verbose=False)

        mlp_histories.append(history)
        mlp_configs.append((bs, lr))

epoch 0: acc -> 0.9238166809082031 | acc_val -> 0.9203000068664551
epoch 10: acc -> 0.9802500009536743 | acc_val -> 0.9585999846458435
epoch 20: acc -> 0.9905833601951599 | acc_val -> 0.9639000296592712
epoch 30: acc -> 0.9953500032424927 | acc_val -> 0.9621000289916992
epoch 40: acc -> 0.9968166947364807 | acc_val -> 0.9621000289916992
epoch 50: acc -> 0.998283326625824 | acc_val -> 0.9621000289916992
epoch 60: acc -> 0.9993000030517578 | acc_val -> 0.9610999822616577
epoch 70: acc -> 0.9996500015258789 | acc_val -> 0.9617000222206116
epoch 80: acc -> 0.9998166561126709 | acc_val -> 0.9617000222206116
epoch 90: acc -> 0.9999499917030334 | acc_val -> 0.9620000123977661


### Part 2: CNN on CIFAR-10 — Frank–Wolfe grid search

In [12]:
# Fetch (downloading into ./data if needed) the CIFAR-10 train and test splits; no transform applied.
trainset, testset = (
    torchvision.datasets.CIFAR10(root='./data', train=is_train, download=True)
    for is_train in (True, False)
)

Using downloaded and verified file: ./data/cifar-10-python.tar.gz
Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [40]:
class CNN(nn.Module):
    """LeNet-style classifier for 3x32x32 images producing 10 logits.

    Feature-map sizes: 3x32x32 -conv5-> 6x28x28 -pool-> 6x14x14
    -conv5-> 16x10x10 -pool-> 16x5x5 -> 400 -> 120 -> 84 -> 10.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in this exact order so seeded initialization is unchanged.
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        x = x.view(-1, 16 * 5 * 5)  # flatten the 16x5x5 feature map
        for fc in (self.fc1, self.fc2):
            x = F.relu(fc(x))
        return self.fc3(x)


cnn = CNN()
# Parameter table for a fresh CNN, traced with one dummy 3x32x32 image.
print(summary(CNN(), torch.zeros(1, 3, 32, 32), show_input=True))

-----------------------------------------------------------------------
      Layer (type)         Input Shape         Param #     Tr. Param #
          Conv2d-1      [1, 3, 32, 32]             456             456
       MaxPool2d-2      [1, 6, 28, 28]               0               0
          Conv2d-3      [1, 6, 14, 14]           2,416           2,416
          Linear-4            [1, 400]          48,120          48,120
          Linear-5            [1, 120]          10,164          10,164
          Linear-6             [1, 84]             850             850
Total params: 62,006
Trainable params: 62,006
Non-trainable params: 0
-----------------------------------------------------------------------


In [77]:
def reverse_channel(img):
    """Move the channel axis of an (H, W, C) image to the front: (H, W, C) -> (3, H, W).

    Despite the name, channel *order* is preserved (channel 0 stays first); only
    the first three channels are used. Returns a new array of the same dtype.
    """
    channel_planes = [img[:, :, c] for c in range(3)]
    return np.stack(channel_planes)

In [80]:
# (N, H, W, C) image arrays -> (N, C, H, W) tensors with identical values and dtype to
# stacking reverse_channel(img) per image. One permute over the whole array replaces
# building a tensor from a Python list of per-image arrays, which is very slow.
# `.contiguous()` materializes the new layout (a copy) so later `.view` calls work.
cifar10_train = torch.from_numpy(trainset.data).permute(0, 3, 1, 2).contiguous()
cifar10_test = torch.from_numpy(testset.data).permute(0, 3, 1, 2).contiguous()

In [88]:
# Wrap each split's `targets` labels as a tensor (train first, then test).
cifar10_trainlab, cifar10_testlab = (torch.tensor(split.targets) for split in (trainset, testset))

In [None]:
SEED = 123456789  # seeding for weight initialization and train

# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")

# Dedicated result lists so this cell does not discard earlier `histories` contents
# held under another name; `histories` is rebound to the CNN results for compatibility.
cnn_histories = []
cnn_configs = []  # (batch_size, learning_rate) of each entry in cnn_histories, same order
histories = cnn_histories

bss = [50, 64, 128]
lrs = [1e-2, 1e-3, 5e-3, 1e-4]

# Cast/move the data once instead of on every grid point.
# NOTE(review): no pixel normalization is applied anywhere in this notebook — TODO consider scaling.
cifar10_train_x = cifar10_train.float().to(device)
cifar10_test_x = cifar10_test.float().to(device)
cifar10_train_y = cifar10_trainlab.to(device)
cifar10_test_y = cifar10_testlab.to(device)

# Outer loop over learning rates, inner over batch sizes: the same visiting order
# as np.meshgrid(bss, lrs)[i].ravel().
for lr in lrs:
    for bs in bss:
        print('batch-size: {} | learning-rate: {}'.format(bs, lr))
        # Re-seed so every configuration starts from identical initial weights.
        torch.manual_seed(SEED)
        cnn = CNN().to(device)

        history = train(cnn, cifar10_train_x, cifar10_train_y,
                        cifar10_test_x, cifar10_test_y,
                        mini_batch_size=bs, nb_epochs=100, eta=lr, verbose=False)

        cnn_histories.append(history)
        cnn_configs.append((bs, lr))

batch-size: 50 | learning-rate: 0.01
