In [20]:
from tqdm import tqdm_notebook as tqdm
import torch
import torch.optim
import torchnet as tnt
from torchvision.datasets.mnist import MNIST
from torchnet.engine import Engine
from torch.autograd import Variable
import torch.nn.functional as F
from torch.nn.init import kaiming_normal


def get_iterator(mode, batch_size=128, num_workers=4, root='./'):
    """Build a batched, parallel loader over the MNIST train or test split.

    Args:
        mode: True selects the training split (and enables shuffling);
            False selects the test split.
        batch_size: samples per batch (default 128).
        num_workers: worker processes used by the loader (default 4).
        root: directory MNIST is downloaded to / read from (default './').

    Returns:
        The loader produced by ``tnt.dataset.TensorDataset.parallel`` over the
        raw image tensor and its label tensor.
    """
    ds = MNIST(root=root, download=True, train=mode)
    # Newer torchvision names the split tensors `data` / `targets` and keeps
    # `train_data` / `test_data` / `*_labels` only as deprecated aliases that
    # warn on access. Fall back to the old names for versions predating the
    # rename.
    if hasattr(ds, 'data') and hasattr(ds, 'targets'):
        data, labels = ds.data, ds.targets
    else:
        data = getattr(ds, 'train_data' if mode else 'test_data')
        labels = getattr(ds, 'train_labels' if mode else 'test_labels')
    tds = tnt.dataset.TensorDataset([data, labels])
    return tds.parallel(batch_size=batch_size, num_workers=num_workers,
                        shuffle=mode)


def conv_init(ni, no, k):
    """Return a Kaiming-normal initialised conv weight of shape (no, ni, k, k).

    Args:
        ni: number of input channels.
        no: number of output channels.
        k: side length of the square kernel.

    Returns:
        A new float tensor of shape (no, ni, k, k).
    """
    # `kaiming_normal_` is the in-place, non-deprecated replacement for
    # `kaiming_normal`. `torch.empty` makes explicit that the initial contents
    # are irrelevant because they are overwritten.
    return torch.nn.init.kaiming_normal_(torch.empty(no, ni, k, k))


def linear_init(ni, no):
    """Return a Kaiming-normal initialised linear weight of shape (no, ni).

    Args:
        ni: input features.
        no: output features.

    Returns:
        A new float tensor of shape (no, ni), the layout `F.linear` expects.
    """
    # Non-deprecated in-place initializer on an explicitly uninitialised tensor.
    return torch.nn.init.kaiming_normal_(torch.empty(no, ni))


def f(params, inputs, mode):
    """Functional forward pass: two stride-2 convs, then two linear layers.

    Args:
        params: dict of weight/bias tensors under the keys 'conv0.*',
            'conv1.*', 'linear2.*' and 'linear3.*'.
        inputs: batch of images; each sample is reshaped to (1, 28, 28), so it
            must contain 784 values.
        mode: train/test flag passed by the caller; unused by this network.

    Returns:
        Unnormalised scores from the 'linear3' layer, one row per sample.
    """
    x = inputs.view(inputs.size(0), 1, 28, 28)
    # Conv -> ReLU, twice, both with stride 2.
    for weight_key, bias_key in (('conv0.weight', 'conv0.bias'),
                                 ('conv1.weight', 'conv1.bias')):
        x = F.relu(F.conv2d(x, params[weight_key], params[bias_key], stride=2))
    # Flatten the feature maps, apply the hidden layer + ReLU, then the head.
    flat = x.view(x.size(0), -1)
    hidden = F.relu(F.linear(flat, params['linear2.weight'], params['linear2.bias']))
    return F.linear(hidden, params['linear3.weight'], params['linear3.bias'])


def main(max_epochs=10, seed=None):
    """Train the functional MNIST network with torchnet's Engine.

    Builds the parameter dict consumed by ``f``, an SGD optimizer over it, and
    loss / accuracy meters. It then registers Engine hooks that push each
    batch through ``h``, update the meters, and run a test pass after every
    epoch. Train and test loss/accuracy are printed at the end of each epoch.

    Args:
        max_epochs: number of training epochs (default 10).
        seed: optional value for ``torch.manual_seed`` so weight initialisation
            is reproducible; ``None`` (default) leaves the RNG untouched.
    """
    if seed is not None:
        torch.manual_seed(seed)

    params = {
        'conv0.weight': conv_init(1, 50, 5), 'conv0.bias': torch.zeros(50),
        'conv1.weight': conv_init(50, 50, 5), 'conv1.bias': torch.zeros(50),
        # Two stride-2 5x5 convs take 28x28 -> 12x12 -> 4x4, so 50 * 4 * 4 = 800.
        'linear2.weight': linear_init(800, 512), 'linear2.bias': torch.zeros(512),
        'linear3.weight': linear_init(512, 10), 'linear3.bias': torch.zeros(10),
    }
    params = {k: Variable(v, requires_grad=True) for k, v in params.items()}

    optimizer = torch.optim.SGD(
        params.values(), lr=0.01, momentum=0.9, weight_decay=0.0005)

    engine = Engine()
    meter_loss = tnt.meter.AverageValueMeter()
    classerr = tnt.meter.ClassErrorMeter(accuracy=True)

    # Build both loaders once. The train loader is re-wrapped with a fresh
    # progress bar each epoch; the test loader is reused for every evaluation
    # instead of reconstructing the MNIST dataset at each epoch end.
    train_iterator = get_iterator(True)
    test_iterator = get_iterator(False)

    def h(sample):
        # sample = [images, labels, train_flag]; on_sample appends the flag.
        inputs = Variable(sample[0].float() / 255.0)
        targets = Variable(sample[1].long())
        o = f(params, inputs, sample[2])
        return F.cross_entropy(o, targets), o

    def reset_meters():
        classerr.reset()
        meter_loss.reset()

    def on_sample(state):
        # Forward the engine's train/test flag to the network via the sample.
        state['sample'].append(state['train'])

    def on_forward(state):
        classerr.add(state['output'].data, state['sample'][1].long())
        meter_loss.add(state['loss'].item())

    def on_start_epoch(state):
        reset_meters()
        # Wrap the base loader rather than state['iterator'], which after the
        # first epoch is the previous epoch's progress bar; wrapping that would
        # stack one more tqdm layer every epoch.
        state['iterator'] = tqdm(
            train_iterator,
            desc="[Epoch %d / %d]" % (state['epoch'] + 1, max_epochs))

    def on_end_epoch(state):
        print('Training loss: %.4f, accuracy: %.2f%%' % (meter_loss.value()[0], classerr.value()[0]))
        # do validation at the end of each epoch
        reset_meters()
        # Evaluation never calls backward, so skip building autograd graphs.
        with torch.no_grad():
            engine.test(h, test_iterator)
        print('Testing loss: %.4f, accuracy: %.2f%%' % (meter_loss.value()[0], classerr.value()[0]))

    engine.hooks['on_sample'] = on_sample
    engine.hooks['on_forward'] = on_forward
    engine.hooks['on_start_epoch'] = on_start_epoch
    engine.hooks['on_end_epoch'] = on_end_epoch
    engine.train(h, train_iterator, maxepoch=max_epochs, optimizer=optimizer)

In [21]:
# Run training with the default settings: 10 epochs, printing train and test
# loss/accuracy after each one.
# NOTE(review): the saved progress-bar labels below read "[Epoch 0/ 10]" ..
# "[Epoch 9/ 10]". They do not match the 1-based "[Epoch N / 10]" label that
# on_start_epoch builds, and execution counts start at 20. These outputs
# likely come from an earlier revision or a non-fresh kernel, so
# Restart & Run All before relying on the numbers.
main()



HBox(children=(IntProgress(value=0, description='[Epoch 0/ 10]', max=469, style=ProgressStyle(description_widt…

Training loss: 0.2467, accuracy: 92.62%
Testing loss: 0.0988, accuracy: 97.06%


HBox(children=(IntProgress(value=0, description='[Epoch 1/ 10]', max=469, style=ProgressStyle(description_widt…

Training loss: 0.0851, accuracy: 97.42%
Testing loss: 0.0751, accuracy: 97.68%


HBox(children=(IntProgress(value=0, description='[Epoch 2/ 10]', max=469, style=ProgressStyle(description_widt…

Training loss: 0.0620, accuracy: 98.10%
Testing loss: 0.0561, accuracy: 98.08%


HBox(children=(IntProgress(value=0, description='[Epoch 3/ 10]', max=469, style=ProgressStyle(description_widt…

Training loss: 0.0460, accuracy: 98.63%
Testing loss: 0.0553, accuracy: 98.20%


HBox(children=(IntProgress(value=0, description='[Epoch 4/ 10]', max=469, style=ProgressStyle(description_widt…

Training loss: 0.0393, accuracy: 98.82%
Testing loss: 0.0427, accuracy: 98.66%


HBox(children=(IntProgress(value=0, description='[Epoch 5/ 10]', max=469, style=ProgressStyle(description_widt…

Training loss: 0.0323, accuracy: 99.01%
Testing loss: 0.0427, accuracy: 98.59%


HBox(children=(IntProgress(value=0, description='[Epoch 6/ 10]', max=469, style=ProgressStyle(description_widt…

Training loss: 0.0273, accuracy: 99.19%
Testing loss: 0.0428, accuracy: 98.64%


HBox(children=(IntProgress(value=0, description='[Epoch 7/ 10]', max=469, style=ProgressStyle(description_widt…

Training loss: 0.0240, accuracy: 99.29%
Testing loss: 0.0384, accuracy: 98.79%


HBox(children=(IntProgress(value=0, description='[Epoch 8/ 10]', max=469, style=ProgressStyle(description_widt…

Training loss: 0.0204, accuracy: 99.41%
Testing loss: 0.0377, accuracy: 98.76%


HBox(children=(IntProgress(value=0, description='[Epoch 9/ 10]', max=469, style=ProgressStyle(description_widt…

Training loss: 0.0186, accuracy: 99.48%
Testing loss: 0.0395, accuracy: 98.66%
