In [1]:
from time import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

In [2]:
DATA_ROOT = './data/'  # where torchvision caches / looks for the MNIST download

# ToTensor yields pixel tensors, and Normalize(0.5, 0.5) then maps [0, 1] to [-1, 1].
# MNIST is single-channel, so mean/std are one-element tuples (the documented
# per-channel sequence form). Note that `(0.5)` is just a float, not a tuple.
mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5,), (0.5,)),
])
mnist_dataset = torchvision.datasets.MNIST(root=DATA_ROOT, train=True, download=True,
                                           transform=mnist_transform)
print(len(mnist_dataset))

60000


In [13]:
# Loader batch size, and the fraction of the dataset used for training
# (the remainder becomes the validation split).
BATCH_SIZE, TRAIN_RATIO = 1024, 0.5

In [14]:
# Seed both the split and the training shuffle, so that Restart & Run All
# reproduces the same train/validation partition and the same batch order.
SPLIT_SEED = 0

n_train = int(TRAIN_RATIO * len(mnist_dataset))
n_val = len(mnist_dataset) - n_train
train_dataset, val_dataset = torch.utils.data.random_split(
    mnist_dataset, [n_train, n_val],
    generator=torch.Generator().manual_seed(SPLIT_SEED))

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True,
    generator=torch.Generator().manual_seed(SPLIT_SEED))
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [15]:
class ModelTrainer:
    """Train and evaluate a classification model over two batch iterables.

    ``train_generator`` / ``test_generator`` are iterables of
    ``(x_batch, y_batch)`` pairs (e.g. ``torch.utils.data.DataLoader``). Each
    one is iterated once per epoch / per call, so it must be re-iterable.
    Batching is decided by these iterables, not by this class.
    """

    def __init__(self, train_generator, test_generator):
        self.__model = None
        self.__train_generator = train_generator
        self.__test_generator = test_generator

    def set_model(self, model):
        """Attach the ``nn.Module`` that ``train`` / ``test`` operate on."""
        self.__model = model

    def get_model(self):
        """Return the attached model, or ``None`` if none has been set."""
        return self.__model

    @staticmethod
    def _batch_loss(output, target):
        """Mean classification loss for one batch.

        ``cross_entropy`` is ``nll_loss(log_softmax(output))``. Because
        ``log_softmax`` is idempotent, this is correct both for models that
        emit raw logits and for models that already end in ``LogSoftmax``.
        For the latter it equals plain ``nll_loss``.
        """
        return F.cross_entropy(output, target)

    @staticmethod
    def _batch_accuracy(output, target):
        """Fraction of rows whose arg-max class equals ``target``, as a float."""
        return (output.argmax(dim=1) == target).float().mean().item()

    def train_epoch(self, optimizer, batch_size=32, cuda=True):
        """Run one optimisation pass over the training batches.

        Args:
            optimizer: optimizer over the attached model's parameters.
            batch_size: unused. Kept for backward compatibility, since the train
                generator decides the batch size.
            cuda: move each batch to the current CUDA device when True.

        Returns:
            ``(loss_log, acc_log, steps)``: per-batch losses and accuracies as
            Python floats, and the number of batches processed.
        """
        assert self.__model is not None

        model = self.__model
        loss_log, acc_log = [], []
        model.train()
        steps = 0
        for x_batch, y_batch in self.__train_generator:
            data = x_batch.cuda() if cuda else x_batch
            target = y_batch.cuda() if cuda else y_batch

            optimizer.zero_grad()
            output = model(data)
            loss = self._batch_loss(output, target)
            loss.backward()
            optimizer.step()

            loss_log.append(loss.item())
            acc_log.append(self._batch_accuracy(output, target))

            steps += 1
            print('Step {0}'.format(steps), flush=True, end='\r')

        return loss_log, acc_log, steps

    def train(self, n_epochs, batch_size=32, lr=1e-3, cuda=True, plot_history=None, clear_output=None,
              checkpoint_path='model_best.pth'):
        """Train for ``n_epochs`` with AdamW, validating after every epoch.

        Args:
            n_epochs: number of passes over the training batches.
            batch_size: forwarded to ``train_epoch``, which ignores it.
            lr: AdamW learning rate.
            cuda: train on the current CUDA device when True, else on the CPU.
            plot_history: optional ``f(train_values, val_points, title=...)``,
                called after each epoch for loss and then for accuracy.
            clear_output: optional zero-argument callable (e.g.
                ``IPython.display.clear_output``) invoked before plotting.
                ``None`` skips clearing.
            checkpoint_path: the whole model is ``torch.save``-d here whenever
                the mean validation accuracy improves. ``None`` disables saving.

        The model is moved back to the CPU when training ends, including when
        it is interrupted (e.g. by ``KeyboardInterrupt``).
        """
        assert self.__model is not None

        model = self.__model.cuda() if cuda else self.__model.cpu()
        self.__model = model
        opt = torch.optim.AdamW(model.parameters(), lr=lr)

        train_log, train_acc_log = [], []
        val_log, val_acc_log = [], []
        best_val_score = 0.
        total_steps = 0

        try:
            for epoch in range(n_epochs):
                epoch_begin = time()
                print("Epoch {0} of {1}".format(epoch, n_epochs))
                train_loss, train_acc, steps = self.train_epoch(opt, batch_size=batch_size, cuda=cuda)
                total_steps += steps

                val_loss, val_acc = self.test(cuda=cuda)
                mean_val_loss = np.mean(val_loss)
                mean_val_acc = np.mean(val_acc)

                train_log.extend(train_loss)
                train_acc_log.extend(train_acc)
                # The x-coordinate is the cumulative step count, so validation
                # points line up with the per-step training curve even if
                # epochs have different lengths.
                val_log.append((total_steps, mean_val_loss))
                val_acc_log.append((total_steps, mean_val_acc))

                if mean_val_acc > best_val_score:
                    best_val_score = mean_val_acc
                    if checkpoint_path is not None:
                        torch.save(model, checkpoint_path)

                if plot_history is not None:
                    if clear_output is not None:
                        clear_output()
                    plot_history(train_log, val_log)
                    plot_history(train_acc_log, val_acc_log, title='accuracy')

                epoch_time = time() - epoch_begin
                print("Epoch: {2}, val loss: {0}, val accuracy: {1}".format(mean_val_loss, mean_val_acc, epoch))
                print("Epoch: {2}, train loss: {0}, train accuracy: {1}".format(np.mean(train_loss), np.mean(train_acc), epoch))
                print('Epoch time: {0}'.format(epoch_time))
        finally:
            self.__model = model.cpu()

    def test(self, cuda=True):
        """Evaluate the attached model on the test batches without gradients.

        Args:
            cuda: move each batch to the current CUDA device when True.

        Returns:
            ``(loss_log, acc_log)``: per-batch losses and accuracies as floats.
        """
        assert self.__model is not None

        model = self.__model
        loss_log, acc_log = [], []
        model.eval()
        # Evaluation needs no autograd graph. Building one for every batch
        # only costs memory and time.
        with torch.no_grad():
            for x_batch, y_batch in self.__test_generator:
                data = x_batch.cuda() if cuda else x_batch
                target = y_batch.cuda() if cuda else y_batch

                output = model(data)
                loss_log.append(self._batch_loss(output, target).item())
                acc_log.append(self._batch_accuracy(output, target))

        return loss_log, acc_log

def plot_history(train_history, val_history, title='loss'):
    """Plot a per-step training curve with validation points overlaid.

    Args:
        train_history: sequence of per-step values, one per training batch.
        val_history: sequence of ``(step, value)`` pairs, one per evaluation.
            It may be empty, in which case no validation points are drawn.
        title: figure title and y-axis label (e.g. ``'loss'`` or ``'accuracy'``).
    """
    fig, ax = plt.subplots()
    ax.set_title('{}'.format(title))
    ax.plot(train_history, label='train', zorder=1)

    # reshape keeps an empty history as shape (0, 2) so the column slicing is safe.
    points = np.asarray(val_history, dtype=float).reshape(-1, 2)
    if len(points):
        ax.scatter(points[:, 0], points[:, 1], marker='+', s=180, c='orange', label='val', zorder=2)

    ax.set_xlabel('train steps')
    ax.set_ylabel(title)
    ax.legend(loc='best')
    ax.grid()

    plt.show()
    # Release the figure; otherwise one figure per call accumulates over epochs.
    plt.close(fig)

In [16]:
# Training batches come from train_loader; evaluation batches come from val_loader.
trainer = ModelTrainer(train_generator=train_loader, test_generator=val_loader)

In [17]:
base_channel = 16
# A ReLU follows every conv and the hidden linear layer. Without these
# non-linearities the whole stack collapses into a single affine map.
# The final LogSoftmax turns the 10 outputs into log-probabilities, which is
# what F.nll_loss expects.
model = nn.Sequential(  # input: 1 x 28 * 28
    nn.Conv2d(1, base_channel, kernel_size=3, stride=1, dilation=1),  # 26 * 26
    nn.ReLU(),
    nn.Conv2d(base_channel, base_channel, kernel_size=3, stride=1, dilation=1),  # 24 * 24
    nn.ReLU(),
    nn.Conv2d(base_channel, 2 * base_channel, kernel_size=3, stride=2, dilation=1, padding=1),  # 12 * 12
    nn.ReLU(),
    nn.Conv2d(2 * base_channel, 2 * base_channel, kernel_size=3, stride=1, dilation=1),  # 10 * 10
    nn.ReLU(),
    nn.Conv2d(2 * base_channel, 4 * base_channel, kernel_size=3, stride=1, dilation=1),  # 8 * 8
    nn.ReLU(),
    nn.Conv2d(4 * base_channel, 4 * base_channel, kernel_size=3, stride=1, dilation=1),  # 6 * 6
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(6 * 6 * 4 * base_channel, 64),
    nn.ReLU(),
    nn.Linear(64, 10),
    nn.LogSoftmax(dim=1),
)

In [18]:
# Hand the freshly built network to the trainer.
trainer.set_model(model=model)

In [19]:
# Use the GPU only when one is present, so the cell runs on any machine.
USE_CUDA = torch.cuda.is_available()
# ModelTrainer.train calls clear_output() before plotting whenever plot_history
# is given, so pass a callable rather than None. This no-op keeps earlier
# epochs' figures; swap in IPython.display.clear_output to redraw in place.
trainer.train(n_epochs=2, batch_size=BATCH_SIZE, lr=1e-3, cuda=USE_CUDA,
              plot_history=plot_history, clear_output=lambda: None)

Epoch 0 of 2
Step 6

KeyboardInterrupt: 