In [1]:
import torch

In [3]:
torch.__version__

'0.4.1'

In [5]:
!pip install torchnet

Collecting torchnet
  Downloading https://files.pythonhosted.org/packages/b7/b2/d7f70a85d3f6b0365517782632f150e3bbc2fb8e998cd69e27deba599aae/torchnet-0.0.4.tar.gz
Collecting visdom (from torchnet)
[?25l  Downloading https://files.pythonhosted.org/packages/c1/48/d90e1519768107811fd6e7760bea46fff9e9c9ffb490441684003ae634a9/visdom-0.1.8.5.tar.gz (248kB)
[K    100% |████████████████████████████████| 256kB 2.7MB/s ta 0:00:011
Collecting torchfile (from visdom->torchnet)
  Downloading https://files.pythonhosted.org/packages/91/af/5b305f86f2d218091af657ddb53f984ecbd9518ca9fe8ef4103a007252c9/torchfile-0.1.0.tar.gz
Collecting websocket-client (from visdom->torchnet)
[?25l  Downloading https://files.pythonhosted.org/packages/14/d4/6a8cd4e7f67da465108c7cc0a307a1c5da7e2cdf497330b682069b1d4758/websocket_client-0.53.0-py2.py3-none-any.whl (198kB)
[K    100% |████████████████████████████████| 204kB 316kB/s ta 0:00:011
Building wheels for collected packages: torchnet, visdom, torchfile
  Running s

In [None]:
#python -m visdom.server -port 8097 &

In [7]:
from tqdm import tqdm
import torch
import torch.optim
import torchnet as tnt
from torch.autograd import Variable
import torch.nn.functional as F
from torch.nn.init import kaiming_normal
from torchnet.engine import Engine
from torchnet.logger import VisdomPlotLogger, VisdomLogger
from torchvision.datasets.mnist import MNIST

In [8]:
def get_iterator(mode):
    ds = MNIST(root='./', download=True, train=mode)
    data = getattr(ds, 'train_data' if mode else 'test_data')
    labels = getattr(ds, 'train_labels' if mode else 'test_labels')
    tds = tnt.dataset.TensorDataset([data, labels])
    return tds.parallel(batch_size=128, num_workers=4, shuffle=mode)


def conv_init(ni, no, k):
    return kaiming_normal(torch.Tensor(no, ni, k, k))


def linear_init(ni, no):
    return kaiming_normal(torch.Tensor(no, ni))

In [9]:
def f(params, inputs, mode):
    o = inputs.view(inputs.size(0), 1, 28, 28)
    o = F.conv2d(o, params['conv0.weight'], params['conv0.bias'], stride=2)
    o = F.relu(o)
    o = F.conv2d(o, params['conv1.weight'], params['conv1.bias'], stride=2)
    o = F.relu(o)
    o = o.view(o.size(0), -1)
    o = F.linear(o, params['linear2.weight'], params['linear2.bias'])
    o = F.relu(o)
    o = F.linear(o, params['linear3.weight'], params['linear3.bias'])
    return o

In [11]:
params = {
        'conv0.weight': conv_init(1, 50, 5), 'conv0.bias': torch.zeros(50),
        'conv1.weight': conv_init(50, 50, 5), 'conv1.bias': torch.zeros(50),
        'linear2.weight': linear_init(800, 512), 'linear2.bias': torch.zeros(512),
        'linear3.weight': linear_init(512, 10), 'linear3.bias': torch.zeros(10),
    }

params = {k: Variable(v, requires_grad=True) for k, v in params.items()}

  # Remove the CWD from sys.path while we load stuff.
  


In [12]:
optimizer = torch.optim.SGD(
    params.values(), lr=0.01, momentum=0.9, weight_decay=0.0005)

In [13]:
engine = Engine()
meter_loss = tnt.meter.AverageValueMeter()
classerr = tnt.meter.ClassErrorMeter(accuracy=True)
confusion_meter = tnt.meter.ConfusionMeter(10, normalized=True)

port = 8097
train_loss_logger = VisdomPlotLogger(
        'line', port=port, opts={'title': 'Train Loss'})
train_err_logger = VisdomPlotLogger(
        'line', port=port, opts={'title': 'Train Class Error'})
test_loss_logger = VisdomPlotLogger(
        'line', port=port, opts={'title': 'Test Loss'})
test_err_logger = VisdomPlotLogger(
        'line', port=port, opts={'title': 'Test Class Error'})
confusion_logger = VisdomLogger('heatmap', port=port, opts={'title': 'Confusion matrix',
                                                                'columnnames': list(range(10)),
                                                                'rownames': list(range(10))})

In [14]:
def h(sample):
    inputs = Variable(sample[0].float() / 255.0)
    targets = Variable(torch.LongTensor(sample[1]))
    o = f(params, inputs, sample[2])
    return F.cross_entropy(o, targets), o

def reset_meters():
    classerr.reset()
    meter_loss.reset()
    confusion_meter.reset()

def on_sample(state):
    state['sample'].append(state['train'])

def on_forward(state):
    classerr.add(state['output'].data,
                 torch.LongTensor(state['sample'][1]))
    confusion_meter.add(state['output'].data,
                        torch.LongTensor(state['sample'][1]))
    meter_loss.add(state['loss'].data[0])

def on_start_epoch(state):
    reset_meters()
    state['iterator'] = tqdm(state['iterator'])

def on_end_epoch(state):
    print('Training loss: %.4f, accuracy: %.2f%%' % (meter_loss.value()[0], classerr.value()[0]))
    train_loss_logger.log(state['epoch'], meter_loss.value()[0])
    train_err_logger.log(state['epoch'], classerr.value()[0])

    # do validation at the end of each epoch
    reset_meters()
    engine.test(h, get_iterator(False))
    test_loss_logger.log(state['epoch'], meter_loss.value()[0])
    test_err_logger.log(state['epoch'], classerr.value()[0])
    confusion_logger.log(confusion_meter.value())
    print('Testing loss: %.4f, accuracy: %.2f%%' % (meter_loss.value()[0], classerr.value()[0]))

engine.hooks['on_sample'] = on_sample
engine.hooks['on_forward'] = on_forward
engine.hooks['on_start_epoch'] = on_start_epoch
engine.hooks['on_end_epoch'] = on_end_epoch
engine.train(h, get_iterator(True), maxepoch=10, optimizer=optimizer)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Processing...


  0%|          | 0/469 [00:00<?, ?it/s]

Done!


100%|██████████| 469/469 [00:38<00:00, 12.13it/s]

Training loss: 0.2601, accuracy: 92.14%



  0%|          | 0/469 [00:00<?, ?it/s]

Testing loss: 0.1080, accuracy: 96.77%


100%|██████████| 469/469 [00:41<00:00, 11.40it/s]

Training loss: 0.0935, accuracy: 97.16%



  0%|          | 0/469 [00:00<?, ?it/s]

Testing loss: 0.0763, accuracy: 97.45%


100%|██████████| 469/469 [00:38<00:00, 12.27it/s]

Training loss: 0.0646, accuracy: 98.02%



  0%|          | 0/469 [00:00<?, ?it/s]

Testing loss: 0.0620, accuracy: 98.07%


100%|██████████| 469/469 [00:40<00:00, 11.49it/s]

Training loss: 0.0492, accuracy: 98.53%



  0%|          | 0/469 [00:00<?, ?it/s]

Testing loss: 0.0566, accuracy: 98.09%


100%|██████████| 469/469 [00:41<00:00, 11.35it/s]


Training loss: 0.0400, accuracy: 98.77%


  0%|          | 0/469 [00:00<?, ?it/s]

Testing loss: 0.0436, accuracy: 98.43%


100%|██████████| 469/469 [00:44<00:00, 10.67it/s]


Training loss: 0.0343, accuracy: 98.98%


  0%|          | 0/469 [00:00<?, ?it/s]

Testing loss: 0.0437, accuracy: 98.49%


100%|██████████| 469/469 [00:44<00:00, 10.54it/s]

Training loss: 0.0292, accuracy: 99.14%



  0%|          | 0/469 [00:00<?, ?it/s]

Testing loss: 0.0430, accuracy: 98.45%


100%|██████████| 469/469 [00:41<00:00, 11.35it/s]

Training loss: 0.0253, accuracy: 99.25%



  0%|          | 0/469 [00:00<?, ?it/s]

Testing loss: 0.0367, accuracy: 98.74%


100%|██████████| 469/469 [00:41<00:00, 12.75it/s]


Training loss: 0.0219, accuracy: 99.38%


  0%|          | 0/469 [00:00<?, ?it/s]

Testing loss: 0.0370, accuracy: 98.69%


100%|██████████| 469/469 [00:43<00:00, 12.57it/s]


Training loss: 0.0200, accuracy: 99.43%
Testing loss: 0.0350, accuracy: 98.81%


{'network': <function __main__.h(sample)>,
 'iterator': 100%|██████████| 469/469 [00:45<00:00, 12.57it/s],
 'maxepoch': 10,
 'optimizer': SGD (
 Parameter Group 0
     dampening: 0
     lr: 0.01
     momentum: 0.9
     nesterov: False
     weight_decay: 0.0005
 ),
 'epoch': 10,
 't': 4690,
 'train': True,
 'sample': [tensor([[[0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0],
           ...,
           [0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0]],
  
          [[0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0],
           ...,
           [0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0]],
  
          [[0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0],
           ...,
           [0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0]