In [1]:
import os

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import Compose, ToTensor, Normalize
from torchvision.datasets import MNIST

from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss, RunningAverage, ConfusionMatrix
from ignite.handlers import ModelCheckpoint, EarlyStopping

try:
    from tensorboardX import SummaryWriter
except ImportError:
    try:
        from torch.utils.tensorboard import SummaryWriter
    except ImportError:
        raise RuntimeError(
            "This module requires either tensorboardX or torch >= 1.2.0. "
            "You may install tensorboardX with command: \n pip install tensorboardX \n"
            "or upgrade PyTorch using your package manager of choice (pip or conda)."
        )

In [2]:
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        
        self.convlayer1 = nn.Sequential(
            nn.Conv2d(1, 32, 3,padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        
        self.convlayer2 = nn.Sequential(
            nn.Conv2d(32,64,3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        
        self.fc1 = nn.Linear(64*6*6,600)
        self.drop = nn.Dropout2d(0.25)
        self.fc2 = nn.Linear(600, 120)
        self.fc3 = nn.Linear(120, 10)
        
    def forward(self, x):
        x = self.convlayer1(x)
        x = self.convlayer2(x)
        x = x.view(-1,64*6*6)
        x = self.fc1(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.fc3(x)
        
        return F.log_softmax(x,dim=1)

In [3]:
def create_summary_writer(model, data_loader, log_dir):
    writer = SummaryWriter(log_dir=log_dir)
    data_loader_iter = iter(data_loader)
    x, y = next(data_loader_iter)
    try:
        writer.add_graph(model, x)
    except Exception as e:
        print("Failed to save model graph: {}".format(e))
    return writer

In [4]:
def get_next_run_name(log_dir):
    if not os.path.isdir(log_dir):
        return os.path.join(log_dir, 'run_0')
    else:
        runs = [int(d[4:]) for d in next(os.walk(log_dir))[1] if d[:3] == 'run']
        return os.path.join(log_dir, 'run_' + str(max(runs) + 1))

In [5]:
def run(params):
    run_dir = get_next_run_name(params['log_dir'])
    train_loader, val_loader = get_data_loaders(params['train_batch_size'], params['val_batch_size'])
    model = params['model']
    
    writer = create_summary_writer(model, train_loader, run_dir)
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"

    optimizer = params['optimizer'](model.parameters(), lr=params['lr'])
    criterion = params['criterion']
    
    trainer = create_supervised_trainer(model, optimizer, criterion, device=device)
    evaluator = create_supervised_evaluator(
        model, metrics={"accuracy": Accuracy(), "nll": Loss(F.nll_loss)}, device=device
    )

    @trainer.on(Events.ITERATION_COMPLETED(every=params['log_interval']))
    def log_training_loss(engine):
        if params['tensorboard']:
            writer.add_scalar("training/loss", engine.state.output, engine.state.iteration)
        
        if params['verbose']:
            print(
                "Epoch[{}] Iteration[{}/{}] Loss: {:.2f}"
                "".format(engine.state.epoch, engine.state.iteration, len(train_loader), engine.state.output)
            )

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        evaluator.run(train_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics["accuracy"]
        avg_nll = metrics["nll"]
        
        if params['tensorboard']:
            writer.add_scalar("training/avg_loss", avg_nll, engine.state.epoch)
            writer.add_scalar("training/avg_accuracy", avg_accuracy, engine.state.epoch)
        
        if params['verbose']:
            print(
                "Training Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}".format(
                    engine.state.epoch, avg_accuracy, avg_nll
                )
            )

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics["accuracy"]
        avg_nll = metrics["nll"]
        
        if params['tensorboard']:
            writer.add_scalar("valdation/avg_loss", avg_nll, engine.state.epoch)
            writer.add_scalar("valdation/avg_accuracy", avg_accuracy, engine.state.epoch)
        
        if params['verbose']:
            print(
                "Validation Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}".format(
                    engine.state.epoch, avg_accuracy, avg_nll
                )
            )
        
    trainer.run(train_loader, max_epochs=params['epochs'])
    writer.close()    
    return model

In [6]:
def get_data_loaders(train_batch_size, val_batch_size):
    transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])
    
    trainset = datasets.FashionMNIST('./data', download=True, train=True, transform=transform)
    train_loader = DataLoader(trainset, batch_size=train_batch_size, shuffle=True)

    # Download and load the test data
    validationset = datasets.FashionMNIST('./data', download=True, train=False, transform=transform)
    val_loader = DataLoader(validationset, batch_size=val_batch_size, shuffle=True)
    
    return train_loader, val_loader

In [7]:
params = {
    'train_batch_size': 64,
    'val_batch_size': 64,
    'epochs': 10,
    'lr': 0.001,
    'verbose': True,
    'tensorboard': True,
    'log_interval': 100,
    'log_dir': './tensorboard_logs',
    'model': CNN(),
    'criterion': nn.NLLLoss(),
    'optimizer': optim.Adam,
    
}

In [8]:
model = run(params)

Epoch[1] Iteration[100/938] Loss: 0.76
Epoch[1] Iteration[200/938] Loss: 0.45
Epoch[1] Iteration[300/938] Loss: 0.37
Epoch[1] Iteration[400/938] Loss: 0.60
Epoch[1] Iteration[500/938] Loss: 0.39
Epoch[1] Iteration[600/938] Loss: 0.29
Epoch[1] Iteration[700/938] Loss: 0.30
Epoch[1] Iteration[800/938] Loss: 0.31
Epoch[1] Iteration[900/938] Loss: 0.49
Training Results - Epoch: 1  Avg accuracy: 0.90 Avg loss: 0.27
Validation Results - Epoch: 1  Avg accuracy: 0.89 Avg loss: 0.31
Epoch[2] Iteration[1000/938] Loss: 0.18
Epoch[2] Iteration[1100/938] Loss: 0.26
Epoch[2] Iteration[1200/938] Loss: 0.23
Epoch[2] Iteration[1300/938] Loss: 0.40
Epoch[2] Iteration[1400/938] Loss: 0.26
Epoch[2] Iteration[1500/938] Loss: 0.39
Epoch[2] Iteration[1600/938] Loss: 0.18
Epoch[2] Iteration[1700/938] Loss: 0.34
Epoch[2] Iteration[1800/938] Loss: 0.26
Training Results - Epoch: 2  Avg accuracy: 0.91 Avg loss: 0.26
Validation Results - Epoch: 2  Avg accuracy: 0.89 Avg loss: 0.30
Epoch[3] Iteration[1900/938] Loss

In [10]:
# # classes of fashion mnist dataset
# classes = ['T-shirt/top','Trouser','Pullover','Dress','Coat','Sandal','Shirt','Sneaker','Bag','Ankle Boot']
# # creating iterator for iterating the dataset
# dataiter = iter(val_loader)
# images, labels = dataiter.next()
# images_arr = []
# labels_arr = []
# pred_arr = []
# # moving model to cpu for inference 
# model.to("cpu")
# # iterating on the dataset to predict the output
# for i in range(0,10):
#     images_arr.append(images[i].unsqueeze(0))
#     labels_arr.append(labels[i].item())
#     ps = torch.exp(model(images_arr[i]))
#     ps = ps.data.numpy().squeeze()
#     pred_arr.append(np.argmax(ps))
# # plotting the results
# fig = plt.figure(figsize=(25,4))
# for i in range(10):
#     ax = fig.add_subplot(2, 20/2, i+1, xticks=[], yticks=[])
#     ax.imshow(images_arr[i].resize_(1, 28, 28).numpy().squeeze())
#     ax.set_title("{} ({})".format(classes[pred_arr[i]], classes[labels_arr[i]]),
#                  color=("green" if pred_arr[i]==labels_arr[i] else "red"))