# In [None]:
# Training a simple neural network on MNIST with confusion matrix visualization

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T
import torchvision.datasets as datasets
import torch.optim as optim
from torch.utils import data
from torch.autograd import variable

import seaborn as sn
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import itertools

# Confusion Matrix Plot Function
def plot_confusion_matrix(cm, target_names=None, cmap=None, normalize=True, labels=True, title='Confusion matrix'):
    """Draw a confusion matrix as an annotated heat map and show it.

    Args:
        cm: Square 2-D array (or nested sequence) of counts. Rows are drawn
            on the y axis, labelled 'True label'; columns on the x axis,
            labelled 'Predicted label'.
        target_names: Optional sequence of class names used as tick labels
            on both axes.
        cmap: Matplotlib colormap; defaults to 'Blues' when None.
        normalize: If True, each row is divided by its row sum before
            plotting. Rows that sum to zero are drawn as all zeros.
        labels: If True, the value of every cell is written inside the cell.
        title: Figure title.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    # Accept nested lists as well as ndarrays.
    cm = np.asarray(cm)

    # Accuracy is computed from the raw counts, before any normalization.
    # An all-zero matrix has no defined accuracy; report NaN without
    # triggering a divide-by-zero warning.
    total = float(np.sum(cm))
    accuracy = np.trace(cm) / total if total else float('nan')
    misclass = 1 - accuracy

    if cmap is None:
        cmap = plt.get_cmap('Blues')

    if normalize:
        # Rows without any samples stay at zero instead of turning into NaN,
        # which would otherwise also poison cm.max() and the text threshold.
        row_sums = cm.sum(axis=1, keepdims=True)
        cm = np.divide(
            cm.astype('float'),
            row_sums,
            out=np.zeros(cm.shape, dtype=float),
            where=row_sums != 0,
        )

    plt.figure(figsize=(10, 15))
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()

    # Cells darker than this threshold get white text, the rest black.
    thresh = cm.max() / 1.5 if normalize else cm.max() / 2
    if target_names is not None:
        tick_marks = np.arange(len(target_names))
        plt.xticks(tick_marks, target_names)
        plt.yticks(tick_marks, target_names)

    if labels:
        for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
            val = "{:0.4f}".format(cm[i, j]) if normalize else "{:,}".format(cm[i, j])
            # imshow puts column j on the x axis and row i on the y axis.
            plt.text(j, i, val, horizontalalignment='center',
                     color='white' if cm[i, j] > thresh else 'black')

    plt.ylabel('True label')
    plt.xlabel('Predicted label\naccuracy={:0.4f}; misclass={:0.4f}'.format(accuracy, misclass))
    # Lay out after the axis labels exist so the two-line x label is not clipped.
    plt.tight_layout()
    plt.show()

# Load MNIST dataset
def MNIST_DATA(root='./data', train=True, transforms=None, download=True, batch_size=32, num_worker=1):
    """Create the MNIST train/test datasets and a DataLoader for each.

    Args:
        root: Directory passed to ``datasets.MNIST`` as the data location.
        train: Unused; both splits are always built. Kept so existing
            callers that pass it keep working.
        transforms: Transform applied to the images of both splits.
            Defaults to ``T.ToTensor()`` when None.
        download: Passed to ``datasets.MNIST`` for both splits.
        batch_size: Batch size of both loaders.
        num_worker: ``num_workers`` value of both loaders.

    Returns:
        Tuple ``(mnist_train, mnist_test, train_loader, test_loader)``.
        The train loader shuffles; the test loader does not.
    """
    print("[+] Get the MNIST DATA")
    # Honour a caller-supplied transform instead of silently ignoring it.
    transform = T.ToTensor() if transforms is None else transforms
    mnist_train = datasets.MNIST(root=root, train=True, transform=transform, download=download)
    # The test split follows the same `download` flag as the train split.
    mnist_test = datasets.MNIST(root=root, train=False, transform=transform, download=download)

    train_loader = data.DataLoader(dataset=mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_worker)
    test_loader = data.DataLoader(dataset=mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_worker)

    print("[+] MNIST DATA Loaded")
    return mnist_train, mnist_test, train_loader, test_loader

# Build the datasets and loaders once; the experiments below reuse the loaders.
(
    mnist_train,
    mnist_test,
    train_loader,
    test_loader,
) = MNIST_DATA(batch_size=32)

# Define Trainer class
class Trainer():
    """Small train / evaluate harness for a classification network.

    Batches are moved to whichever device the network's parameters are on,
    so the same code runs on CPU or GPU.

    Args:
        trainloader: Iterable yielding ``(inputs, labels)`` training batches.
        testloader: Loader yielding ``(inputs, labels)`` evaluation batches;
            ``test`` additionally reads ``len(testloader.dataset)``.
        net: The ``nn.Module`` to train.
        optimizer: Optimizer built over ``net``'s parameters.
        criterion: Loss called as ``criterion(output, labels)``.
    """

    def __init__(self, trainloader, testloader, net, optimizer, criterion):
        self.trainloader = trainloader
        self.testloader = testloader
        self.net = net
        self.optimizer = optimizer
        self.criterion = criterion

    def _device(self):
        """Return the device of the network's first parameter (CPU if it has none)."""
        first_param = next(self.net.parameters(), None)
        return first_param.device if first_param is not None else torch.device('cpu')

    def train(self, epoch=100):
        """Train for ``epoch`` passes over the training loader.

        Every 500 batches (including the first batch of each epoch) the mean
        loss since the previous report is printed and ``test`` is run.
        """
        device = self._device()
        self.net.train()
        for e in range(epoch):
            running_loss = 0.0
            batches_since_report = 0
            for i, batch in enumerate(self.trainloader, 0):
                inputs, labels = batch[0].to(device), batch[1].to(device)
                self.optimizer.zero_grad()
                output = self.net(inputs)
                loss = self.criterion(output, labels)
                loss.backward()
                self.optimizer.step()
                running_loss += loss.item()
                batches_since_report += 1
                if i % 500 == 0:
                    # Average over the batches actually accumulated; the
                    # first report of an epoch covers a single batch.
                    print('[%d, %5d] loss: %.3f' % (e + 1, i + 1, running_loss / batches_since_report))
                    running_loss = 0.0
                    batches_since_report = 0
                    self.test()
                    # test() switches the network to eval mode; switch back
                    # so the remaining batches are trained in train mode.
                    self.net.train()
        print('Finished Training')

    def test(self):
        """Print the accuracy on the test loader. Leaves the network in eval mode."""
        device = self._device()
        self.net.eval()
        correct = 0
        # No gradients are needed for evaluation.
        with torch.no_grad():
            for inputs, labels in self.testloader:
                inputs, labels = inputs.to(device), labels.to(device)
                output = self.net(inputs)
                pred = output.max(1, keepdim=True)[1]
                correct += pred.eq(labels.view_as(pred)).sum().item()
        total = len(self.testloader.dataset)
        percent = 100. * correct / total if total else 0.0
        print('\nTest set: Accuracy: {}/{} ({:.0f}%)\n'.format(correct, total, percent))

    def get_conf(self, num_classes=10):
        """Return the confusion matrix over the test loader.

        Args:
            num_classes: Number of classes; sets the matrix size.

        Returns:
            A CPU float tensor of shape ``(num_classes, num_classes)`` where
            entry ``[t, p]`` counts the samples with true label ``t`` that
            were predicted as ``p`` (rows = true, columns = predicted).
            Leaves the network in eval mode.
        """
        device = self._device()
        self.net.eval()
        conf_matrix = torch.zeros(num_classes, num_classes)
        with torch.no_grad():
            for inputs, labels in self.testloader:
                output = self.net(inputs.to(device))
                pred = torch.argmax(output, dim=1).cpu()
                true = labels.view(-1).long().cpu()
                # Encode each (true, pred) pair as one flat index and count
                # all pairs of the batch in a single bincount call.
                flat_index = true * num_classes + pred
                counts = torch.bincount(flat_index, minlength=num_classes * num_classes)
                conf_matrix += counts.view(num_classes, num_classes).to(conf_matrix.dtype)
        return conf_matrix

# Define MNIST Network (Sigmoid)
class MNIST_Net(nn.Module):
    """Fully connected classifier with one sigmoid hidden layer.

    The input is reshaped to rows of 28 * 28 = 784 features, mapped to 30
    hidden units, passed through a sigmoid, and projected to 10 output
    scores. No softmax is applied in ``forward``.
    """

    def __init__(self):
        super().__init__()
        self.fc0 = nn.Linear(28 * 28, 30)
        self.fc1 = nn.Linear(30, 10)
        self.act = nn.Sigmoid()

    def forward(self, x):
        """Return the 10 raw output scores for each flattened input row."""
        flat = x.view(-1, 28 * 28)
        hidden = self.act(self.fc0(flat))
        return self.fc1(hidden)

# Train using Sigmoid network
# Sigmoid experiment: plain SGD (lr=0.01) with cross-entropy loss for
# 10 epochs, then a final evaluation and a confusion-matrix plot.
mnist_net = MNIST_Net()
mnist_net.cuda()  # Module.cuda() moves the parameters in place
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(params=mnist_net.parameters(), lr=0.01)

trainer = Trainer(train_loader, test_loader, mnist_net, optimizer, criterion)
trainer.train(10)
trainer.test()
plot_confusion_matrix(cm=trainer.get_conf().cpu().numpy())

# Define and train the ReLU network (replaces the sigmoid MNIST_Net above)
class MNIST_Net(nn.Module):
    """Fully connected classifier with one ReLU hidden layer.

    Layer sizes are 28 * 28 -> 30 -> 10; ``forward`` returns the raw
    scores of the last linear layer.
    """

    def __init__(self):
        super().__init__()
        self.fc0 = nn.Linear(28 * 28, 30)
        self.fc1 = nn.Linear(30, 10)
        self.act = nn.ReLU()

    def forward(self, x):
        """Reshape ``x`` to rows of 784 features and run linear -> ReLU -> linear."""
        return self.fc1(self.act(self.fc0(x.view(-1, 28 * 28))))

# ReLU experiment: same pipeline as before but with lr=0.001.
mnist_net = MNIST_Net().cuda()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    mnist_net.parameters(),
    lr=0.001,
)

trainer = Trainer(
    trainloader=train_loader,
    testloader=test_loader,
    net=mnist_net,
    optimizer=optimizer,
    criterion=criterion,
)
trainer.train(epoch=10)
trainer.test()
# get_conf() returns a torch tensor; the plot helper is given a NumPy array.
plot_confusion_matrix(
    trainer.get_conf().cpu().numpy()
)

# Count trainable parameters
def count_parameters(model):
    """Return the total element count of ``model``'s trainable parameters.

    A parameter is counted only when its ``requires_grad`` flag is set.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Trainable-parameter count via the helper; the return value is discarded here.
count_parameters(mnist_net)

# The same count done inline, one parameter tensor at a time.
num = 0
for param in mnist_net.parameters():
    num += param.numel() if param.requires_grad else 0
print(num)