In [1]:
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import datasets, transforms
from torchvision.datasets import mnist
from sklearn.model_selection import KFold
from sklearn import metrics

import torch.nn.functional as TF
import torch.optim as optim
import os
import math
import matplotlib.pyplot as plt
import pickle

torch.set_num_threads(1)


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

In [13]:
# (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)

transform = transforms.Compose([transforms.Resize((32,32)),
                                transforms.ToTensor(),
                                transforms.Normalize((0.1307,), (0.3081,)),])
test_transform = transforms.Compose([transforms.Resize((32,32)),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.1307,), (0.3081,)),])

train_batch_size = 512
val_batch_size = 300
test_batch_size = 100
train_dataset = mnist.MNIST(root='data/MNIST/train', train=True,
                              download=True,transform=transform)
test_dataset = mnist.MNIST(root='data/MNIST/test', train=False,
                             download=True, transform=test_transform)
# train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=train_batch_size)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=test_batch_size)

In [14]:
class LeNet5(nn.Module):
    def __init__(self):
        super(LeNet5, self).__init__()
        self.feature_extractor = nn.Sequential(            
            nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1),
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2),
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1),
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2),
            nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5, stride=1),
            nn.Tanh()
        )

        self.classifier = nn.Sequential(
            nn.Linear(in_features=120, out_features=84),
            nn.Tanh(),
            nn.Linear(in_features=84, out_features=10),
        )


    def forward(self, x):
        x = self.feature_extractor(x)
        x = torch.flatten(x, 1)
        logits = self.classifier(x)
        return logits

In [4]:
LOG_DIR = 'Logs/'
SAVE_DIR = 'Models/'

for SEED in range(5):
    torch.manual_seed(SEED)
    train, val = random_split(train_dataset,[int(0.9*len(train_dataset)),int(0.1*len(train_dataset))])
    
#     all_index = np.arange(len(train_dataset))
#     np.random.shuffle(all_index)
#     train_index = all_index[0:int(0.9*len(train_dataset))]
#     val_index = all_index[int(0.9*len(train_dataset)):]
    
    train_loader = DataLoader(train, shuffle=True, batch_size=train_batch_size)
    val_loader = DataLoader(val, shuffle=True, batch_size=train_batch_size)


    model = LeNet5().double().to(device)
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)

    cost = nn.CrossEntropyLoss()
    epoch = 20
    train_loss = []
    val_loss = []
    LEARN_SEED = 42
    torch.manual_seed(LEARN_SEED)
    best_val_acc = 0.0
    for _epoch in range(epoch):
        for idx, (train_x, train_label) in enumerate(train_loader):
            train_x, train_label = train_x.double().to(device), train_label.to(device)
            optimizer.zero_grad()
            outputs = model(train_x)
            loss = cost(outputs, train_label)
            loss.backward()
            optimizer.step()
            if idx % 50 == 0:
                print('Epoch:%d, idx:%d, loss:%.6f'%(_epoch, idx, loss.sum().item()))
            train_loss.append(loss.sum().item())

        correct = 0
        _sum = 0

        for idx, (val_x, val_label) in enumerate(val_loader):
            val_x, val_label = val_x.double().to(device), val_label.to(device)
            outputs = model(val_x).detach()
            t_loss = cost(outputs, val_label)
            predict_ys = torch.argmax(outputs, axis=-1)
            _ = predict_ys.detach().data == val_label
            correct += torch.sum(_, axis=-1)
            _sum += _.shape[0]
            val_loss.append(t_loss.sum().item())
        val_acc = 100*correct / _sum
        print('Validation accuracy: %.4f'%val_acc)
    
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_ckpt = {'net':model.state_dict(),
                        'optim':optimizer.state_dict(),
                        'epoch':_epoch,
                        'val_acc':best_val_acc}
            best_save_path = SAVE_DIR + "MNIST_CNN_Val_SEED_%d_model"%SEED
            torch.save(best_ckpt, best_save_path)
            
    log_save_path = LOG_DIR + "MNIST_CNN_Val_SEED_%d_log"%SEED

    pickle.dump([train_loss, val_loss], open(log_save_path,"wb"))

Epoch:0, idx:0, loss:2.303470
Epoch:0, idx:50, loss:1.664604
Epoch:0, idx:100, loss:0.609361
Validation accuracy: 84.1500
Epoch:1, idx:0, loss:0.607857
Epoch:1, idx:50, loss:0.399100
Epoch:1, idx:100, loss:0.390668
Validation accuracy: 89.9667
Epoch:2, idx:0, loss:0.329201
Epoch:2, idx:50, loss:0.294618
Epoch:2, idx:100, loss:0.319652
Validation accuracy: 91.4500
Epoch:3, idx:0, loss:0.274068
Epoch:3, idx:50, loss:0.250187
Epoch:3, idx:100, loss:0.165481
Validation accuracy: 93.1333
Epoch:4, idx:0, loss:0.243066
Epoch:4, idx:50, loss:0.211686
Epoch:4, idx:100, loss:0.168831
Validation accuracy: 94.5833
Epoch:5, idx:0, loss:0.166868
Epoch:5, idx:50, loss:0.200319
Epoch:5, idx:100, loss:0.198834
Validation accuracy: 95.4333
Epoch:6, idx:0, loss:0.118862
Epoch:6, idx:50, loss:0.152230
Epoch:6, idx:100, loss:0.104994
Validation accuracy: 96.2500
Epoch:7, idx:0, loss:0.116940
Epoch:7, idx:50, loss:0.140235
Epoch:7, idx:100, loss:0.125654
Validation accuracy: 96.6167
Epoch:8, idx:0, loss:0.1

Epoch:6, idx:100, loss:0.129866
Validation accuracy: 95.5000
Epoch:7, idx:0, loss:0.117183
Epoch:7, idx:50, loss:0.130889
Epoch:7, idx:100, loss:0.122367
Validation accuracy: 95.8667
Epoch:8, idx:0, loss:0.178375
Epoch:8, idx:50, loss:0.128246
Epoch:8, idx:100, loss:0.139144
Validation accuracy: 96.3000
Epoch:9, idx:0, loss:0.094556
Epoch:9, idx:50, loss:0.091627
Epoch:9, idx:100, loss:0.103609
Validation accuracy: 96.6500
Epoch:10, idx:0, loss:0.084793
Epoch:10, idx:50, loss:0.089521
Epoch:10, idx:100, loss:0.065698
Validation accuracy: 96.8667
Epoch:11, idx:0, loss:0.090284
Epoch:11, idx:50, loss:0.068910
Epoch:11, idx:100, loss:0.088452
Validation accuracy: 97.1667
Epoch:12, idx:0, loss:0.072731
Epoch:12, idx:50, loss:0.060135
Epoch:12, idx:100, loss:0.085572
Validation accuracy: 97.2833
Epoch:13, idx:0, loss:0.116088
Epoch:13, idx:50, loss:0.081193
Epoch:13, idx:100, loss:0.074829
Validation accuracy: 97.4167
Epoch:14, idx:0, loss:0.069160
Epoch:14, idx:50, loss:0.053970
Epoch:14, 

In [13]:
for SEED in range(5):
    best_save_path = SAVE_DIR + "MNIST_CNN_Val_SEED_%d_model"%SEED
    print(torch.load(best_save_path)['val_acc'])

tensor(98.5333, device='cuda:0')
tensor(98.1167, device='cuda:0')
tensor(98.0667, device='cuda:0')
tensor(97.8333, device='cuda:0')
tensor(98.4500, device='cuda:0')


In [46]:
SAVE_DIR = 'Models/'
for SEED in range(5):
    torch.manual_seed(SEED)
    train, val = random_split(train_dataset,[int(0.9*len(train_dataset)),int(0.1*len(train_dataset))])
    
    train_loader = DataLoader(train, shuffle=True, batch_size=train_batch_size)
    val_loader = DataLoader(val, shuffle=True, batch_size=val_batch_size)
    
    best_save_path = SAVE_DIR + "MNIST_CNN_Val_SEED_%d_model"%SEED
    model = LeNet5().double().to(device)
    model.load_state_dict(torch.load(best_save_path)['net'])
    
    y_vals = []
    y_vals_onehot = []
    y_outputs = []
    y_preds = []
    for idx, (val_x, val_label) in enumerate(val_loader):
        val_x, val_label = val_x.double().to(device), val_label.to(device)
        y_vals.append(val_label.cpu().data)
        y_vals_onehot.append(TF.one_hot(val_label.cpu().data,10).numpy())
        outputs = model(val_x).detach()
        y_output = TF.softmax(outputs,-1)
        y_outputs.append(y_output.detach().cpu().data.numpy())
        y_pred = torch.argmax(outputs, axis=-1)
        y_preds.append(y_pred.detach().cpu().data.numpy())
        
    y_vals = torch.stack(y_vals,0).numpy()
    y_vals = np.array(y_vals).reshape([-1,1])
    y_vals_onehot = np.eye(10)[y_vals].reshape([-1,10])
    y_preds = np.array(y_preds).reshape([-1,1])
    y_outputs = np.array(y_outputs).reshape([-1,10])
    print(metrics.accuracy_score(y_vals,y_preds))
    print(metrics.f1_score(y_vals,y_preds,average=None))
    print(metrics.roc_auc_score(y_vals_onehot,y_outputs,average=None))

0.9853333333333333
[0.99078341 0.99157088 0.98426573 0.97807757 0.98572628 0.98217469
 0.99268887 0.98621249 0.97757848 0.98206661]
[0.99990411 0.99995074 0.99985258 0.99970153 0.99949382 0.99977813
 0.99997487 0.99989972 0.99971396 0.99977597]
0.9811666666666666
[0.98657718 0.98302583 0.98623064 0.98204668 0.98109966 0.97798165
 0.98913952 0.98057498 0.98260149 0.96266234]
[0.99988312 0.99988665 0.99990359 0.99936676 0.999822   0.99962842
 0.99992579 0.99981884 0.99981663 0.99906824]
0.9806666666666667
[0.98181818 0.99303944 0.9798995  0.97072419 0.9844898  0.98314108
 0.98836168 0.97848606 0.97670406 0.97026338]
[0.99984604 0.99981402 0.99983228 0.9983117  0.99944572 0.99847921
 0.99982612 0.99952137 0.9996972  0.99873196]
0.9783333333333334
[0.98537477 0.98459281 0.97978981 0.97838271 0.9788315  0.98084291
 0.98328936 0.97709924 0.96611642 0.96920583]
[0.99983525 0.99982354 0.99985391 0.99979578 0.99827092 0.99945016
 0.99952999 0.99967708 0.99960317 0.99888389]
0.9845
[0.98848684 0

In [48]:
y_te = test_dataset.targets.numpy()
LOG_DIR = 'Logs/'
SAVE_DIR = 'Models/'

SEED = 0
best_save_path = SAVE_DIR + "MNIST_CNN_Val_SEED_%d_model"%SEED
model = LeNet5().double().to(device)
model.load_state_dict(torch.load(best_save_path)['net'])
y_vals = []
y_vals_onehot = []
y_outputs = []
y_preds = []
for idx, (val_x, val_label) in enumerate(test_loader):
    val_x, val_label = val_x.double().to(device), val_label.to(device)
    y_vals.append(val_label.cpu().data)
    y_vals_onehot.append(TF.one_hot(val_label.cpu().data,10).numpy())
    outputs = model(val_x).detach()
    y_output = TF.softmax(outputs,-1)
    y_outputs.append(y_output.detach().cpu().data.numpy())
    y_pred = torch.argmax(outputs, axis=-1)
    y_preds.append(y_pred.detach().cpu().data.numpy())

y_vals = torch.stack(y_vals,0).numpy()
y_vals = np.array(y_vals).reshape([-1,1])
y_vals_onehot = np.eye(10)[y_vals].reshape([-1,10])
y_preds = np.array(y_preds).reshape([-1,1])
y_outputs = np.array(y_outputs).reshape([-1,10])
print(metrics.accuracy_score(y_vals,y_preds))
print(metrics.f1_score(y_vals,y_preds,average=None))
print(metrics.roc_auc_score(y_vals_onehot,y_outputs,average=None))

0.9848
[0.98528666 0.99164835 0.98742747 0.98422091 0.98729029 0.98378983
 0.98591549 0.98242188 0.98303342 0.97590361]
[0.99991923 0.99997287 0.999894   0.99987015 0.99992592 0.99984971
 0.99989979 0.99971084 0.99986771 0.99972332]
