In [1]:
import torch

In [67]:
from torch import nn
from torch.utils import data
from torch.utils.data import DataLoader, Sampler
import torch.nn.functional as F

In [33]:
import numpy as np

In [83]:
import tqdm

In [84]:
from torchvision import datasets, transforms
import torchvision

In [270]:
class VggNetwork(nn.Module):  
    def __init__(self):
        super(VggNetwork, self).__init__()
        self.vgg = torchvision.models.vgg11(pretrained=True)
        self.adjust_conv = nn.Conv2d(1, 3, 1)
        self.features = self.vgg.features
        self.avgpool = self.vgg.avgpool
        self.classifier = nn.Sequential(nn.Linear(25088, 10))
    def forward(self, x):
#         Conv to adjust 1 channel MNIST images to RGB required by vgg pretrained (will be removed later)
        x = self.adjust_conv(x)
        x = self.features(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return F.log_softmax(x, dim=1)
    
    def trainable_params(self):
        return [p for p in self.parameters() if p.requires_grad]

In [271]:
net = VggNetwork()

In [272]:
mnist_train = datasets.MNIST('', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.Resize((224, 224)),
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ]))
train_loader = torch.utils.data.DataLoader(mnist_train,batch_size=4, shuffle=True)

mnist_test =  datasets.MNIST('', train=False, transform=transforms.Compose([
                           transforms.Resize((224, 224)),
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ]))
test_loader = torch.utils.data.DataLoader(mnist_test,batch_size=4, shuffle=True)

In [273]:
config = {
    'DEBUG': False,
    'CUDA': torch.cuda.is_available(),
    'DEVICE': torch.device("cuda" if torch.cuda.is_available() else "cpu")
}

In [284]:
class Trainer():
    def __init__(self, net, config):
        net.train()
        net.to(config['DEVICE'])
        self.device = config['DEVICE']
        self.net = net
        self.config = config
        self.optimizer = torch.optim.Adam(net.trainable_params(), lr=0.001)
    
    def run(self, dataloader, epoch=1):
        print(">> Running trainer")
        for epoch in range(epoch):
            print(">>> Epoch %s" % epoch)
            for idx, (image, target) in enumerate(tqdm.tqdm_notebook(dataloader, ascii=True)):
                image, target = image.to(self.device), target.to(self.device)
                self.optimizer.zero_grad()
                predict = self.net(image)
                loss = F.nll_loss(predict, target)
                loss.backward()
                self.optimizer.step()
                if self.config['DEBUG'] == True:
                    break
            print("Trainer epoch finished")

In [285]:
class Evaluation():
    def __init__(self, net, config):
        net.eval()
        net.to(config['DEVICE'])
        self.device = config['DEVICE']
        self.net = net
        self.config = config
    
    def run(self, dataloader):
        print(">> Running Evaluation")
        test_loss = 0
        correct = 0
        with torch.no_grad():
            for idx, (image, target) in enumerate(tqdm.tqdm_notebook(dataloader, ascii=True)):
                image, target = image.to(self.device), target.to(self.device)
                predict = self.net(image)
                test_loss += F.nll_loss(predict, target, reduction='sum').item() # sum up batch loss
                pred = predict.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()
                if self.config['DEBUG'] == True:
                    break
                    
        test_loss /= len(dataloader.dataset)
        print("Evaluation finished")
        return {
            'loss': test_loss,
            'accuracy': 100. * correct / len(dataloader.dataset)
        }

In [286]:
trainer = Trainer(net, config)

In [287]:
evaluation = Evaluation(net, config)

In [288]:
trainer.run(train_loader)

>> Running trainer
>>> Epoch 0


HBox(children=(IntProgress(value=0, max=15000), HTML(value='')))

KeyboardInterrupt: 

In [259]:
evaluation.run(test_loader)

>> Running Evaluation


HBox(children=(IntProgress(value=0, max=2500), HTML(value='')))

Evaluation finished


{'loss': 0.006412623596191406, 'accuracy': 0.01}