In [1]:
import os
import itertools

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

import torch
from torch import nn, optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision import datasets, transforms
from torchvision.transforms import Compose, ToTensor, Normalize
from torchvision.datasets import MNIST

from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss, RunningAverage, ConfusionMatrix, Precision, Recall
from ignite.contrib.metrics import ROC_AUC
from ignite.handlers import ModelCheckpoint, EarlyStopping

try:
    from tensorboardX import SummaryWriter
except ImportError:
    try:
        from torch.utils.tensorboard import SummaryWriter
    except ImportError:
        raise RuntimeError(
            "This module requires either tensorboardX or torch >= 1.2.0. "
            "You may install tensorboardX with command: \n pip install tensorboardX \n"
            "or upgrade PyTorch using your package manager of choice (pip or conda)."
        )

In [2]:
class CNN(nn.Module):
    """Two conv blocks followed by three fully-connected layers; 10 output classes.

    Shapes for an ``(N, 1, 28, 28)`` input:
        convlayer1: conv 3x3 (pad 1) 1->32, BN, ReLU, 2x2 max-pool   -> (N, 32, 14, 14)
        convlayer2: conv 3x3 (no pad) 32->64, BN, ReLU, 2x2 max-pool -> (N, 64, 6, 6)
        flatten -> fc1 (2304->600) -> dropout(0.25) -> fc2 (600->120) -> fc3 (120->10)

    ``forward`` returns log-probabilities (``log_softmax`` over dim 1), so the
    matching loss is ``nn.NLLLoss`` / ``F.nll_loss``.

    NOTE(review): there is no non-linearity between fc1, fc2 and fc3, so apart
    from dropout they compose into a single affine map. Adding ReLUs would
    likely help, but it changes what existing checkpoints compute, so it is left as is.
    """

    def __init__(self):
        super(CNN, self).__init__()

        self.convlayer1 = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.convlayer2 = nn.Sequential(
            nn.Conv2d(32, 64, 3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

        self.fc1 = nn.Linear(64 * 6 * 6, 600)
        # Element-wise dropout: fc1's output is a 2-D (N, features) tensor.
        # nn.Dropout2d is channel-wise dropout meant for (N, C, H, W) / (C, H, W)
        # inputs. Feeding it 2-D input is deprecated: PyTorch warns and says it
        # will become an error. Dropout has no parameters, so existing state dicts
        # still load unchanged.
        self.drop = nn.Dropout(0.25)
        self.fc2 = nn.Linear(600, 120)
        self.fc3 = nn.Linear(120, 10)

    def forward(self, x):
        """Map a batch of images (assumed ``(N, 1, 28, 28)``) to ``(N, 10)`` log-probabilities."""
        x = self.convlayer1(x)
        x = self.convlayer2(x)
        # Flatten everything except the batch dimension. With view(-1, 2304), a
        # wrongly sized input would silently be re-split into a different batch
        # size; here it fails loudly in fc1 instead.
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return F.log_softmax(x, dim=1)

In [3]:
def create_summary_writer(model, data_loader, log_dir):
    """Open a TensorBoard ``SummaryWriter`` in ``log_dir`` and try to log the model graph.

    The graph is traced with the inputs of the first batch drawn from
    ``data_loader``; the batch labels are ignored. Any tracing failure is printed
    and otherwise swallowed, so the writer is always returned to the caller.
    """
    summary_writer = SummaryWriter(log_dir=log_dir)
    first_inputs, _ = next(iter(data_loader))
    try:
        summary_writer.add_graph(model, first_inputs)
    except Exception as err:
        # Graph export is a nice-to-have; never let it abort training.
        print("Failed to save model graph: {}".format(err))
    return summary_writer

In [4]:
def get_next_run_name(log_dir):
    """Return the path of the next unused ``run_<N>`` directory under ``log_dir``.

    Sub-directories named exactly ``run_<digits>`` are scanned, and the result is
    ``run_<max N + 1>``. The function returns ``run_0`` in two cases: when
    ``log_dir`` does not exist, or when it contains no such directory. Other
    entries are ignored, for example ``saved_models``, ``runs``, ``run_best``, or
    plain files. Nothing is created on disk.

    Args:
        log_dir: Parent directory that holds the run directories.

    Returns:
        ``os.path.join(log_dir, 'run_<N>')``.
    """
    prefix = 'run_'
    if not os.path.isdir(log_dir):
        return os.path.join(log_dir, prefix + '0')

    # Immediate sub-directories only. The default covers an unreadable directory,
    # for which os.walk yields nothing.
    _, subdirs, _ = next(os.walk(log_dir), (log_dir, [], []))
    # isdecimal() accepts exactly the digit characters int() parses, so names
    # such as 'runs' or 'run_best' are skipped instead of crashing int().
    run_ids = [
        int(name[len(prefix):])
        for name in subdirs
        if name.startswith(prefix) and name[len(prefix):].isdecimal()
    ]
    # default=-1 makes an existing-but-empty log_dir yield run_0 instead of raising.
    next_id = max(run_ids, default=-1) + 1
    return os.path.join(log_dir, prefix + str(next_id))

In [5]:
def activated_output_transform(output):
    """Apply a sigmoid to the prediction half of a ``(y_pred, y)`` pair.

    ``y`` is passed through untouched, and the result is returned as a tuple.
    Presumably meant as an ignite metric ``output_transform`` (e.g. for the
    imported ``ROC_AUC``) -- TODO confirm.

    NOTE(review): not referenced anywhere else in this notebook. Also, the CNN
    emits log-probabilities, for which a sigmoid is not a meaningful activation.
    """
    logits, target = output
    return torch.sigmoid(logits), target

In [6]:
def save_cm(cm, filename='', classes=None):
    """Draw a confusion matrix and save it as ``<filename>/confusion_matrix.jpg``.

    Args:
        cm: Square matrix with true labels on the rows and predicted labels on
            the columns (e.g. the value of ignite's ``ConfusionMatrix`` metric).
            Torch tensors are detached and copied to the CPU. Anything else must
            be convertible with ``np.asarray``.
        filename: Directory the image is written into. Despite the name, the
            file itself is always called ``confusion_matrix.jpg``. The default
            ``''`` writes into the current working directory.
        classes: Tick labels in row/column order. ``None`` or an empty sequence
            falls back to the class indices ``'0'`` .. ``'K-1'``.

    Non-zero cells are annotated with their value: exact integers for integer
    matrices, two decimals otherwise. Zero cells show ``'.'``. The figure is
    always closed, even if saving fails.
    """
    if torch.is_tensor(cm):
        cm = cm.detach().cpu().numpy()
    cm = np.asarray(cm)

    if classes is None or len(classes) == 0:
        classes = [str(c) for c in range(len(cm))]
    tick_marks = np.arange(len(classes))
    # format(x, 'd') raises ValueError for floats (e.g. a normalized matrix).
    cell_format = 'd' if np.issubdtype(cm.dtype, np.integer) else '.2f'

    fig = plt.figure(figsize=(5, 5), dpi=100, facecolor='w', edgecolor='k')
    try:
        ax = fig.add_subplot(1, 1, 1)
        ax.imshow(cm, cmap='Oranges')

        ax.set_xlabel('Predicted', fontsize=7)
        ax.set_xticks(tick_marks)
        ax.set_xticklabels(classes, fontsize=7, rotation=-90, ha='center')
        ax.xaxis.set_label_position('bottom')
        ax.xaxis.tick_bottom()

        ax.set_ylabel('True Label', fontsize=7)
        ax.set_yticks(tick_marks)
        ax.set_yticklabels(classes, fontsize=7, va='center')
        ax.yaxis.set_label_position('left')
        ax.yaxis.tick_left()

        for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
            value = cm[i, j]
            label = format(value, cell_format) if value != 0 else '.'
            ax.text(j, i, label, horizontalalignment="center", fontsize=6,
                    verticalalignment='center', color="black")

        # Explicit tight_layout() instead of set_tight_layout(True), which newer
        # Matplotlib discourages.
        fig.tight_layout()
        # Save this figure explicitly rather than whatever pyplot considers "current".
        fig.savefig(os.path.join(filename, 'confusion_matrix.jpg'))
    finally:
        # Release the figure even on failure. This runs once per epoch, so a
        # leak would accumulate.
        plt.close(fig)

In [7]:
def get_data_loaders(train_batch_size, val_batch_size, root='./data'):
    """Build Fashion-MNIST training and validation data loaders.

    Images are converted with ``ToTensor`` and normalized with mean 0.5 and std
    0.5, which maps ``[0, 1]`` pixel values to ``[-1, 1]``. Both splits are
    downloaded into ``root`` if missing.

    Args:
        train_batch_size: Batch size of the shuffled training loader.
        val_batch_size: Batch size of the validation loader (the ``train=False`` split).
        root: Directory the dataset is stored in / downloaded to.

    Returns:
        ``(train_loader, val_loader, dataset_name)``. ``dataset_name`` is the
        dataset class name (``'FashionMNIST'``), used to name log and
        checkpoint directories.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])

    dataset_cls = datasets.FashionMNIST
    # Derive the label from the dataset class so log and checkpoint directories
    # name the dataset that is actually loaded. A hard-coded 'MNIST' mislabeled
    # Fashion-MNIST runs.
    dataset_name = dataset_cls.__name__

    trainset = dataset_cls(root, download=True, train=True, transform=transform)
    train_loader = DataLoader(trainset, batch_size=train_batch_size, shuffle=True)

    validationset = dataset_cls(root, download=True, train=False, transform=transform)
    # The evaluation metrics are order-independent, so the validation set is not shuffled.
    val_loader = DataLoader(validationset, batch_size=val_batch_size, shuffle=False)

    return train_loader, val_loader, dataset_name

In [8]:
def run(params):
    """Train ``params['model']`` with ignite on the loaders from ``get_data_loaders``.

    ``params`` keys read here: train_batch_size, val_batch_size, log_dir, model,
    model_name, lr, lr_patience, lr_factor, min_lr, early_stopping,
    early_stopping_patience, log_interval, epochs, tensorboard, verbose, and:
        optimizer: optimizer *class*, called as ``optimizer(model.parameters(), lr=...)``.
        criterion: training loss passed to ``create_supervised_trainer``.

    Side effects:
        * A SummaryWriter (plus model graph) in ``<log_dir>/<dataset_name>/run_<N>``.
          Scalars are only written when ``params['tensorboard']`` is set.
        * The last 3 per-epoch model state dicts in
          ``<log_dir>/<dataset_name>/saved_models``.
        * ``confusion_matrix.jpg`` in the run directory, overwritten each epoch.

    After every epoch, the full training set and the validation set are
    evaluated. The validation NLL drives ``ReduceLROnPlateau`` and, if enabled,
    early stopping.

    Returns:
        The trained model (the same object as ``params['model']``), left on the
        training device.
    """
    train_loader, val_loader, dataset_name = get_data_loaders(params['train_batch_size'], params['val_batch_size'])
    model_dir = os.path.join(params['log_dir'], dataset_name)
    run_dir = get_next_run_name(model_dir)

    model = params['model']

    # Trace the graph before moving the model: the sample batch comes straight
    # from the DataLoader.
    writer = create_summary_writer(model, train_loader, run_dir)
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # ignite >= 0.4 uses `device` only to move batches and never moves the
        # model. Move it here, before the optimizer is built over its parameters.
        model.to(device)

        optimizer = params['optimizer'](model.parameters(), lr=params['lr'])
        criterion = params['criterion']
        scheduler = ReduceLROnPlateau(optimizer, patience=params['lr_patience'],
                                      factor=params['lr_factor'], min_lr=params['min_lr'])

        # Checkpoint label, e.g. 'LeNet_lr:0.001_Adam_NLLLoss'. It is built from
        # type names instead of parsing str(optimizer) / str(criterion); a plain
        # loss function falls back to its __name__.
        criterion_label = getattr(criterion, '__name__', type(criterion).__name__)
        model_name = '{}_lr:{}_{}_{}'.format(
            params['model_name'], params['lr'], type(optimizer).__name__, criterion_label
        )

        trainer = create_supervised_trainer(model, optimizer, criterion, device=device)

        train_evaluator = create_supervised_evaluator(
            model,
            metrics={
                "accuracy": Accuracy(),
                "nll": Loss(F.nll_loss),
            },
            device=device,
        )

        val_evaluator = create_supervised_evaluator(
            model,
            metrics={
                "accuracy": Accuracy(),
                "nll": Loss(F.nll_loss),
                "precision": Precision(),
                "recall": Recall(),
                "conf_matrix": ConfusionMatrix(10),
            },
            device=device,
        )

        def score_function(engine):
            # EarlyStopping treats a higher score as better, so negate the validation loss.
            return -engine.state.metrics['nll']

        if params['early_stopping']:
            handler = EarlyStopping(patience=params['early_stopping_patience'],
                                    score_function=score_function, trainer=trainer)
            val_evaluator.add_event_handler(Events.COMPLETED, handler)

        # Checkpoints live next to (not inside) the run directories, so they are
        # shared across runs on the same dataset.
        checkpointer = ModelCheckpoint(os.path.join(model_dir, 'saved_models'), dataset_name, n_saved=3,
                                       create_dir=True, save_as_state_dict=True, require_empty=False)
        trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {model_name: model})

        def get_lr(optimizer):
            # Learning rate of the first parameter group (the optimizer above has only one).
            for param_group in optimizer.param_groups:
                return param_group['lr']

        @trainer.on(Events.ITERATION_COMPLETED(every=params['log_interval']))
        def log_training_loss(engine):
            if params['tensorboard']:
                writer.add_scalar("training/batch_loss", engine.state.output, engine.state.iteration)

            if params['verbose']:
                # state.iteration counts across epochs; report the position within the current epoch.
                iter_in_epoch = (engine.state.iteration - 1) % len(train_loader) + 1
                print(
                    "Epoch[{}] Iteration[{}/{}] Loss: {:.2f}"
                    "".format(engine.state.epoch, iter_in_epoch, len(train_loader), engine.state.output)
                )

        @trainer.on(Events.EPOCH_COMPLETED)
        def log_training_results(engine):
            # Re-evaluates the whole training set after every epoch.
            train_evaluator.run(train_loader)
            metrics = train_evaluator.state.metrics
            avg_accuracy = metrics["accuracy"]
            avg_nll = metrics["nll"]

            if params['tensorboard']:
                writer.add_scalar("training/epoch_loss", avg_nll, engine.state.epoch)

            if params['verbose']:
                print(
                    "Training Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}".format(
                        engine.state.epoch, avg_accuracy, avg_nll
                    )
                )

        @trainer.on(Events.EPOCH_COMPLETED)
        def log_validation_results(engine):
            val_evaluator.run(val_loader)
            metrics = val_evaluator.state.metrics
            accuracy = metrics["accuracy"]
            precision = metrics["precision"]
            recall = metrics["recall"]
            nll = metrics["nll"]

            scheduler.step(nll)

            save_cm(metrics["conf_matrix"], filename=run_dir, classes=[])

            if params['tensorboard']:
                # NOTE(review): 'valdation' is misspelled. It is kept so these
                # curves stay comparable with runs already logged under that tag.
                writer.add_scalar("valdation/epoch_loss", nll, engine.state.epoch)
                writer.add_scalar("valdation/epoch_accuracy", accuracy, engine.state.epoch)
                # Precision/recall are per-class; log their unweighted (macro) mean.
                writer.add_scalar("valdation/epoch_precision", precision.sum() / len(precision), engine.state.epoch)
                writer.add_scalar("valdation/epoch_recall", recall.sum() / len(recall), engine.state.epoch)

                writer.add_scalar("valdation/learning_rate", get_lr(optimizer), engine.state.epoch)

            if params['verbose']:
                print(
                    "Validation Results - Epoch: {}  Accuracy: {:.2f}  loss: {:.2f}".format(
                        engine.state.epoch, accuracy, nll
                    )
                )

        trainer.run(train_loader, max_epochs=params['epochs'])
    finally:
        # Flush and close the event file even if training is interrupted or fails.
        writer.close()
    return model

In [9]:
# Seed PyTorch's global RNG so weight initialisation (CNN() below) and
# DataLoader shuffling are repeatable across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

params = {
    'train_batch_size': 64,
    'val_batch_size': 64,
    'epochs': 40,
    'lr': 1e-3,
    'verbose': False,
    'tensorboard': True,
    'log_interval': 100,            # iterations between batch-loss logs
    'log_dir': './tensorboard_logs',
    'model_name': 'LeNet',
    'model': CNN(),
    'criterion': nn.NLLLoss(),      # matches the CNN's log_softmax output
    'optimizer': optim.Adam,        # class; instantiated inside run()
    'early_stopping': True,
    'early_stopping_patience': 11,  # epochs without validation-NLL improvement
    'lr_patience': 5,               # ReduceLROnPlateau patience (epochs)
    'lr_factor': 0.25,
    'min_lr': 1e-5,
}

In [10]:
model = run(params)

In [41]:
# Continue training from a saved checkpoint at a 10x lower learning rate.
# NOTE(review): the trailing number in the file name is specific to one earlier
# run -- point this at the checkpoint you actually want to resume from.
CHECKPOINT_PATH = './tensorboard_logs/MNIST/saved_models/MNIST_LeNet_lr:0.001_Adam_NLLLoss_13132.pth'

model = CNN()
# map_location='cpu' lets a checkpoint saved during a GPU run load on a CPU-only machine.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location='cpu'))

# Same configuration as the first run; only the learning rate and the
# checkpoint-initialised model differ.
finetune_params = {**params, 'lr': 1e-4, 'model': model}

model = run(finetune_params)

In [11]:
# # classes of fashion mnist dataset
# classes = ['T-shirt/top','Trouser','Pullover','Dress','Coat','Sandal','Shirt','Sneaker','Bag','Ankle Boot']
# # creating iterator for iterating the dataset
# dataiter = iter(val_loader)
# images, labels = dataiter.next()
# images_arr = []
# labels_arr = []
# pred_arr = []
# # moving model to cpu for inference 
# model.to("cpu")
# # iterating on the dataset to predict the output
# for i in range(0,10):
#     images_arr.append(images[i].unsqueeze(0))
#     labels_arr.append(labels[i].item())
#     ps = torch.exp(model(images_arr[i]))
#     ps = ps.data.numpy().squeeze()
#     pred_arr.append(np.argmax(ps))
# # plotting the results
# fig = plt.figure(figsize=(25,4))
# for i in range(10):
#     ax = fig.add_subplot(2, 20/2, i+1, xticks=[], yticks=[])
#     ax.imshow(images_arr[i].resize_(1, 28, 28).numpy().squeeze())
#     ax.set_title("{} ({})".format(classes[pred_arr[i]], classes[labels_arr[i]]),
#                  color=("green" if pred_arr[i]==labels_arr[i] else "red"))