# Classifier utils

In [None]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

import os

# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
def load_data(data_dir, dataset):
    """Download (if needed) and return the train and test splits of a dataset.

    Args:
        data_dir: Directory holding the dataset files; created when missing.
        dataset: Name of the dataset. Only ``'MNIST'`` is supported.

    Returns:
        A ``(train_set, test_set)`` tuple of torchvision datasets, both using
        ``transforms.ToTensor()`` as their image transform.

    Raises:
        ValueError: If ``dataset`` is not a supported name.
    """
    # Fail fast with a clear message: an unknown name used to fall through to
    # the return statement and surface as an UnboundLocalError.
    if dataset != 'MNIST':
        raise ValueError(
            "Unsupported dataset: {!r} (expected 'MNIST')".format(dataset))

    # exist_ok avoids the race between checking for the directory and creating it.
    os.makedirs(data_dir, exist_ok=True)

    train_set = torchvision.datasets.MNIST(root=data_dir, train=True,
                                           transform=transforms.ToTensor(),
                                           download=True)
    test_set = torchvision.datasets.MNIST(root=data_dir, train=False,
                                          transform=transforms.ToTensor(),
                                          download=True)
    return train_set, test_set

In [None]:
def plot_images(images, targets):
    """Show the first ten images in a 2x5 grid, each titled with its target.

    Args:
        images: Indexable collection with at least ten entries, each one an
            image accepted by ``imshow``.
        targets: Indexable collection with at least ten entries, each one
            exposing ``.item()``; the returned value becomes the tile title.
    """
    n_rows, n_cols = 2, 5
    fig, ax = plt.subplots(n_rows, n_cols, figsize=(10, 5))
    # Remove every margin and gap so the tiles touch each other.
    fig.subplots_adjust(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)
    plt.rcParams.update({'font.size': 20})

    # ax.flat walks the grid row by row, so position k is row k // 5, column k % 5.
    for k, tile in enumerate(ax.flat):
        tile.imshow(images[k], cmap='gray')
        tile.axis('off')
        tile.set_title(str(targets[k].item()))
    plt.show()

In [None]:
class Net(nn.Module):
    """Fully connected classifier with four hidden ReLU layers.

    The output activation depends on ``output_dim``:

    * ``output_dim == 1``: ``Sigmoid``, i.e. one probability per sample
      (the input expected by ``binary_cross_entropy``).
    * otherwise: ``LogSoftmax(dim=1)``, i.e. per-class log-probabilities
      (the input expected by ``cross_entropy``).

    Args:
        input_dim: Number of input features per sample.
        hidden_dim: Width of every hidden layer.
        output_dim: Number of output units.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super(Net, self).__init__()
        self.input_dim = input_dim

        # LogSoftmax over a single column is identically 0 (x - logsumexp(x)),
        # which makes a one-unit network output a constant and blocks all
        # gradients. A single output unit therefore gets a sigmoid instead.
        if output_dim == 1:
            output_activation = nn.Sigmoid()
        else:
            output_activation = nn.LogSoftmax(dim=1)

        self.layers = nn.Sequential(nn.Linear(input_dim, hidden_dim),
                                    nn.ReLU(),
                                    nn.Linear(hidden_dim, hidden_dim),
                                    nn.ReLU(),
                                    nn.Linear(hidden_dim, hidden_dim),
                                    nn.ReLU(),
                                    nn.Linear(hidden_dim, hidden_dim),
                                    nn.ReLU(),
                                    nn.Linear(hidden_dim, output_dim),
                                    output_activation)

    def forward(self, x):
        """Apply the layer stack to ``x`` of shape ``(batch, input_dim)``."""
        return self.layers(x)

In [None]:
def init_weights(Layer):
    """Re-initialise a module in place if its class is named ``Linear``.

    Weights are drawn from a normal distribution (mean 0, std 0.02) and the
    bias, when the layer has one, is set to zero. Modules of any other class
    are left untouched. Intended for use with ``Module.apply``.

    Args:
        Layer: The module to (possibly) re-initialise.
    """
    # Guard clause: only layers whose class name is exactly 'Linear' are touched.
    if type(Layer).__name__ != 'Linear':
        return

    torch.nn.init.normal_(Layer.weight, mean=0, std=0.02)
    bias = Layer.bias
    if bias is not None:
        torch.nn.init.constant_(bias, 0)

In [None]:
def binary_cross_entropy(probs, targets, eps=1e-6):
    """Return the mean binary cross-entropy of ``probs`` against ``targets``.

    Args:
        probs: Tensor of predicted probabilities for the positive class.
        targets: Tensor of targets broadcastable against ``probs``
            (1 for the positive class, 0 for the negative class).
        eps: Small constant added inside both logarithms so that
            probabilities of exactly 0 or 1 do not produce ``log(0)``.

    Returns:
        A scalar tensor: the negated mean of the element-wise log-likelihood.
    """
    positive_term = targets * torch.log(probs + eps)
    negative_term = (1 - targets) * torch.log(1 - probs + eps)
    log_likelihood = positive_term + negative_term
    return -log_likelihood.mean()

In [None]:
def cross_entropy(logprobs, oh, eps=1e-6):
    """Return the mean cross-entropy given log-probabilities and one-hot targets.

    Args:
        logprobs: Tensor of log-probabilities with classes along dimension 1.
        oh: Tensor of the same shape as ``logprobs`` holding one-hot targets.
        eps: Unused; kept so existing callers passing it keep working.

    Returns:
        A scalar tensor: the negated mean over samples of the log-probability
        selected by the one-hot mask.
    """
    # Multiplying by the one-hot mask and summing over dimension 1 picks out
    # the log-probability of the target class for each sample.
    picked = (oh * logprobs).sum(dim=1)
    return picked.mean().neg()

# MNIST Binary Classifier

## Load data

In [None]:
mnist_train, mnist_test = load_data(r'datasets/', 'MNIST')

# Pixels are kept as raw float values; the normalisation below is disabled.
#mnist_train.data = (mnist_train.data.float() / 255. - 0.1307) / 0.3081
mnist_train.data = mnist_train.data.float()

# Keep only the digits 0 and 1 for the binary task. One vectorised boolean
# mask, computed before either tensor is overwritten, replaces two
# Python-level loops over every label.
train_mask = (mnist_train.targets == 0) | (mnist_train.targets == 1)
mnist_train.data = mnist_train.data[train_mask]
mnist_train.targets = mnist_train.targets[train_mask]

dataloader = DataLoader(mnist_train, batch_size=128, shuffle=True)

In [None]:
# Pixels are kept as raw float values; the normalisation below is disabled.
#mnist_test.data = (mnist_test.data.float() / 255. - 0.1307) / 0.3081
mnist_test.data = mnist_test.data.float()

# Keep only the digits 0 and 1. One vectorised boolean mask, computed before
# either tensor is overwritten, replaces two Python-level loops over every label.
test_mask = (mnist_test.targets == 0) | (mnist_test.targets == 1)
mnist_test.data = mnist_test.data[test_mask]
mnist_test.targets = mnist_test.targets[test_mask]

# The test split is evaluated in one pass, so it is moved to the device up front.
mnist_test.data = mnist_test.data.to(device)
mnist_test.targets = mnist_test.targets.int().to(device)

In [None]:
# Preview the first ten training digits left after filtering, with their labels.
plot_images(mnist_train.data[:10], mnist_train.targets[:10])

## Init model

In [None]:
# Training hyper-parameters.
n_epochs = 1000
l_rate = 1e-4

# Network dimensions: flattened 28x28 images in, a single output unit out.
input_dim = 28 * 28
hidden_dim = 512
output_dim = 1

# Build the model on the target device, then re-initialise its Linear layers.
net = Net(input_dim, hidden_dim, output_dim)
net = net.to(device)
net.apply(init_weights)

# Adam with a step decay: the learning rate is multiplied by 0.8 every
# 50 scheduler steps.
optimizer = optim.Adam(net.parameters(), lr=l_rate)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.8)

In [None]:
# Echo the model so the notebook displays its layer structure.
net

## Train model

In [None]:
net.train()

# Full-batch training: the whole filtered training set is flattened and moved
# to the device once. These tensors never change, so the reshapes and device
# transfers are not repeated inside the epoch loop.
x = mnist_train.data.view(mnist_train.data.shape[0], -1).to(device)
targets = mnist_train.targets.view(x.shape[0], 1).float().to(device)
targets_int = targets.int()  # integer copy used only for the accuracy check

# NOTE: binary_cross_entropy expects probabilities, so the network's output
# layer must produce values in (0, 1) (e.g. a sigmoid).
losses = []
for i in range(n_epochs):
    probs = net(x)
    loss = binary_cross_entropy(probs, targets)
    losses.append(loss.item())

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()

    # Training-set accuracy measured after this epoch's parameter update.
    with torch.no_grad():
        probs = net(x)
        preds = (probs >= 0.5).int()
        accuracy = torch.sum(preds == targets_int).float() / preds.shape[0] * 100
        print('Epoch: {}/{} Loss: {:.4f} Accuracy (training set): {:.2f}%'.format(i+1, n_epochs, loss.item(), accuracy))

## Test model

In [None]:
# Plot the training loss recorded at every epoch.
plt.figure(figsize=(10, 6))
plt.rcParams['font.size'] = 10
epoch_index = range(len(losses))
plt.plot(epoch_index, losses)

In [None]:
net.eval()

# Evaluate on the filtered test split in a single forward pass.
with torch.no_grad():
    n_test = mnist_test.data.shape[0]
    data = mnist_test.data.view(n_test, -1)
    probs = net(data)
    # Threshold the output at 0.5 to obtain the predicted class.
    preds = (probs >= 0.5).int()
    n_correct = torch.sum(preds == mnist_test.targets.view(-1, 1))
    accuracy = n_correct.float() / preds.shape[0] * 100
    print('Accuracy on the test set: {:.2f}%'.format(accuracy))

In [None]:
# Show the first ten test digits, each titled with the model's prediction.
plot_images(data[:10].view(-1,28,28).cpu(), preds[:10].cpu())

In [None]:
# Boolean column marking the test samples whose prediction differs from the label.
missed_vec = preds.ne(mnist_test.targets.view(-1, 1))
# Flatten the mask to one entry per sample and select the misclassified images.
errors = mnist_test.data[missed_vec.flatten()]

In [None]:
# Shape of the misclassified-image tensor; its first dimension is the number of test errors.
errors.shape

In [None]:
# Display every misclassified test image, one figure per image.
# The whole batch is copied to the CPU once instead of image by image.
for wrong_image in errors.cpu():
    plt.imshow(wrong_image, cmap='gray')
    plt.axis('off')
    plt.show()

# MNIST Classifier

## Load data

In [None]:
# Reload the full ten-digit MNIST splits (the binary section filtered them in place).
mnist_train, mnist_test = load_data(r'datasets/', 'MNIST')

# Mini-batches of 128 shuffled training samples.
dataloader = DataLoader(mnist_train, batch_size=128, shuffle=True)

# The test split is evaluated in one pass, so it is moved to the device up front.
mnist_test.data = mnist_test.data.to(device)
mnist_test.targets = mnist_test.targets.to(device).int()

In [None]:
# Preview the first ten training digits with their labels.
plot_images(mnist_train.data[:10], mnist_train.targets[:10])

## Init model

In [None]:
# Training hyper-parameters.
n_epochs = 1000
l_rate = 1e-3

# Network dimensions: flattened 28x28 images in, one output per digit class.
input_dim = 28 * 28
hidden_dim = 512
output_dim = 10

# Build the model on the target device, then re-initialise its Linear layers.
net = Net(input_dim, hidden_dim, output_dim)
net = net.to(device)
net.apply(init_weights)

# Adam with a step decay: the learning rate is multiplied by 0.8 every
# 50 scheduler steps.
optimizer = optim.Adam(net.parameters(), lr=l_rate)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.8)

In [None]:
# Echo the model so the notebook displays its layer structure.
net

## Train model

In [None]:
net.train()

# Training batches pass through ToTensor(), which scales pixels to [0, 1].
# mnist_test.data holds the raw 0-255 pixel values, so it is scaled the same
# way here; otherwise the model is evaluated on inputs 255x larger than
# anything it was trained on. Prepared once, outside the epoch loop.
data = mnist_test.data.float().view(mnist_test.data.shape[0], -1) / 255.

losses = []
for i in range(n_epochs):
    loss_ac = 0
    for x, targets in dataloader:
        x = x.view(x.shape[0], -1).float().to(device)

        targets = targets.view(-1, 1).to(device)

        # One-hot encode the labels for the cross_entropy helper.
        oh = torch.zeros(targets.shape[0], 10, device=device)
        oh.scatter_(1, targets.long(), 1)

        probs = net(x)

        loss = cross_entropy(probs, oh)
        losses.append(loss.item())
        loss_ac += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # The learning-rate schedule advances once per epoch, not once per batch.
    scheduler.step()

    # Report test accuracy every fifth epoch.
    if i % 5 == 0:
        with torch.no_grad():
            probs = net(data)
            preds = torch.argmax(probs, dim=1).int()
            accuracy = torch.sum(preds == mnist_test.targets).float() / preds.shape[0] * 100
            print('Accuracy on the test set: {:.2f}%'.format(accuracy))
    # Mean batch loss over the epoch.
    print('Epoch: {}/{} Loss: {:.4f}'.format(i+1, n_epochs, loss_ac / len(dataloader)))

## Test model

In [None]:
# Plot the training loss recorded at every mini-batch step.
plt.figure(figsize=(10, 6))
plt.rcParams['font.size'] = 10
step_index = range(len(losses))
plt.plot(step_index, losses)

In [None]:
net.eval()

with torch.no_grad():
    # The model was trained on ToTensor() output (pixels in [0, 1]), while
    # mnist_test.data holds raw 0-255 values; scale it to the same range.
    data = mnist_test.data.float().view(mnist_test.data.shape[0], -1) / 255.
    probs = net(data)
    # The predicted class is the one with the highest output.
    preds = torch.argmax(probs, dim=1).int()
    accuracy = torch.sum(preds == mnist_test.targets).float() / preds.shape[0] * 100
    print('Accuracy on the test set: {:.2f}%'.format(accuracy))

In [None]:
# Show the first ten test digits, each titled with the model's prediction.
plot_images(data[:10].view(-1,28,28).cpu(), preds[:10].cpu())

In [None]:
# Boolean mask of the test samples whose prediction differs from the label.
missed_vec = preds.ne(mnist_test.targets)
# Misclassified images and the (wrong) class predicted for each of them.
errors = mnist_test.data[missed_vec.flatten()]
mispreds = preds[missed_vec]

In [None]:
# Shape of the misclassified-image tensor; its first dimension is the number of test errors.
errors.shape

In [None]:
# Display up to five misclassified test images, each titled with the wrong
# prediction. The images are copied to the CPU once, before the loop.
first_errors = errors[:5].cpu()
first_mispreds = mispreds[:5]
for wrong_image, wrong_label in zip(first_errors, first_mispreds):
    plt.imshow(wrong_image, cmap='gray')
    plt.title(str(wrong_label.item()))
    plt.axis('off')
    plt.show()