<a href="https://colab.research.google.com/github/abinjoabraham/DCFNet_pytorch/blob/master/TCN_on_Sequential_MNIST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import torch
from torch.autograd import Variable
import torch.optim as optim
import torch.nn.functional as F
from sample_data.utils import data_generator
from sample_data.model import TCN
import numpy as np
import argparse
from tensorboardX import SummaryWriter




In [0]:
# ---------------------------------------------------------------------------
# Experiment configuration: everything a reader might want to tune.
# ---------------------------------------------------------------------------

# Seed every RNG the notebook draws from. The pixel permutation built in the
# data/model cell comes from NumPy, so seeding torch alone is not enough for
# reproducible runs.
seed = 1111
torch.manual_seed(seed)
np.random.seed(seed)

# Data
root = './_sample_data/mnist'
batch_size = 64
n_classes = 10
input_channels = 1
seq_length = 784 // input_channels  # 784 pixel values split across the input channels
permuter = False                    # True -> index the sequence axis with `permute`

# Model
nhid = 25
levels = 8

# Optimisation
epochs = 2
optimi = 'Adam'   # name of a class in torch.optim
lr = 2e-3
clip = -1         # values <= 0 disable gradient-norm clipping

# Logging / bookkeeping
loginterval = 100  # batches between progress lines
steps = 0          # running counter advanced by train()

# Use the GPU only when one is actually present; a hard-coded True makes the
# train/test cells call .cuda() on CPU-only machines and crash.
cuda = torch.cuda.is_available()

In [3]:
# Installs tensorboardX, which the import cell at the top of this notebook
# requires; on a fresh kernel this must run before that import can succeed.
!pip install tensorboardX


Collecting tensorboardX
[?25l  Downloading https://files.pythonhosted.org/packages/5c/76/89dd44458eb976347e5a6e75eb79fecf8facd46c1ce259bad54e0044ea35/tensorboardX-1.6-py2.py3-none-any.whl (129kB)
[K     |████████████████████████████████| 133kB 2.9MB/s 
Installing collected packages: tensorboardX
Successfully installed tensorboardX-1.6


In [0]:
# Build the data loaders, the pixel permutation and the TCN classifier.
train_loader, test_loader = data_generator(root, batch_size)

# Random ordering of the sequence positions. It is only applied when
# `permuter` is enabled, and it must have exactly `seq_length` entries because
# train()/test() use it to index the last axis of a
# (-1, input_channels, seq_length) view.
permute = torch.from_numpy(np.random.permutation(seq_length)).long()

# Derive the per-level widths from the configuration cell instead of repeating
# the literals, so that changing `nhid` or `levels` actually takes effect.
channel_sizes = [nhid] * levels
kernel_size = 7
model = TCN(input_channels, n_classes, channel_sizes, kernel_size=kernel_size, dropout=0.05)





In [0]:
# Reset the base learning rate here as well: the epoch loop divides `lr` in
# place, so re-running this cell restores the starting value.
lr = 2e-3

# When a GPU is present, move both the permutation index and the model onto it
# before the optimizer captures the model's parameters.
if torch.cuda.is_available():
    permute = permute.cuda()
    model.cuda()

# `optimi` names a class in torch.optim (e.g. 'Adam'); look it up and build it.
optimizer = getattr(optim, optimi)(
    model.parameters(),
    lr=lr,
)

In [0]:
def train(ep):
    """Run one training epoch over ``train_loader`` and print progress.

    Each batch is viewed as ``(-1, input_channels, seq_length)``, optionally
    re-indexed along the last axis with ``permute`` (when ``permuter`` is
    set), passed through ``model`` and optimised against the negative
    log-likelihood loss. Whenever ``batch_idx`` is a positive multiple of
    ``loginterval`` the accumulated loss divided by ``loginterval`` is printed
    and the accumulator is reset.

    Args:
        ep: Epoch number; used only in the progress message.

    Side effects:
        Updates ``model`` through ``optimizer`` and advances the global
        ``steps`` counter by ``seq_length`` for every batch.
    """
    global steps
    model.train()

    # Send batches to wherever the model's parameters live rather than trusting
    # a separate flag: a flag that disagrees with the model's real device makes
    # the forward pass fail with a device mismatch.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    train_loss = 0.0
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.to(device).view(-1, input_channels, seq_length)
        target = target.to(device)
        if permuter:
            data = data[:, :, permute]

        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        if clip > 0:
            # Use the notebook-level `clip` setting (there is no `args` object
            # in this notebook) and the in-place, non-deprecated API.
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        # Accumulate a plain float so the per-batch autograd graphs are freed.
        train_loss += loss.item()
        steps += seq_length
        if batch_idx > 0 and batch_idx % loginterval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tSteps: {}'.format(
                ep, batch_idx * batch_size, len(train_loader.dataset),
                100. * batch_idx / len(train_loader), train_loss / loginterval, steps))
            train_loss = 0.0


In [0]:
def test():
    model.eval()
    test_loss = 0
    correct = 0
    for data, target in test_loader:
        if cuda:
            data, target = data.cuda(), target.cuda()
        data = data.view(-1, input_channels, seq_length)
        if permuter:
            data = data[:, :, permute]
        data, target = Variable(data, volatile=True), Variable(target)
        output = model(data)
        test_loss += F.nll_loss(output, target, size_average=False).item()
        pred = output.data.max(1, keepdim=True)[1]
        correct += pred.eq(target.data.view_as(pred)).cpu().sum()

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    return test_loss


In [7]:
    # Alternate one training pass and one evaluation pass per epoch. After
    # every tenth epoch, shrink the learning rate tenfold and push the new
    # value into each of the optimizer's parameter groups.
    for epoch in range(1, epochs + 1):
        train(epoch)
        test()
        if epoch % 10 != 0:
            continue
        lr = lr / 10
        for param_group in optimizer.param_groups:
            param_group.update(lr=lr)



  # This is added back by InteractiveShellApp.init_path()



Test set: Average loss: 0.1504, Accuracy: 9509/10000 (95%)


Test set: Average loss: 0.0918, Accuracy: 9696/10000 (96%)

