In [76]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn.functional as F
import numpy as np


#If cuda device exists, use that. If not, default to CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [77]:
#Hyperparameters
num_epochs = 15
num_classes = 10
batch_size = 128
learning_rate = 0.001

In [78]:
#Load MNIST data set
train_dataset = torchvision.datasets.MNIST(root='../../data/',
                                           train=True, 
                                           transform=transforms.ToTensor(),
                                           download=True)

test_dataset = torchvision.datasets.MNIST(root='../../data/',
                                          train=False, 
                                          transform=transforms.ToTensor())

#Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size, 
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size, 
                                          shuffle=False)

In [84]:
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dense1 = nn.Linear(12*12*64, 128)
        self.dense2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.dropout(x, 0.25)
        x = x.view(-1, 12*12*64)
        x = F.relu(self.dense1(x))
        x = F.dropout(x, 0.5)
        x = self.dense2(x)
        return F.log_softmax(x, dim=1)
    


In [85]:
model = CNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

total_step = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)
        
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if (i+1) % 100 == 0:
            print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' 
                   .format(epoch+1, num_epochs, i+1, total_step, loss.item()))


Epoch [1/15], Step [100/469], Loss: 0.1666
Epoch [1/15], Step [200/469], Loss: 0.2753
Epoch [1/15], Step [300/469], Loss: 0.2462
Epoch [1/15], Step [400/469], Loss: 0.1169
Epoch [2/15], Step [100/469], Loss: 0.0327
Epoch [2/15], Step [200/469], Loss: 0.0238
Epoch [2/15], Step [300/469], Loss: 0.0293
Epoch [2/15], Step [400/469], Loss: 0.0598
Epoch [3/15], Step [100/469], Loss: 0.0179
Epoch [3/15], Step [200/469], Loss: 0.0577
Epoch [3/15], Step [300/469], Loss: 0.0275
Epoch [3/15], Step [400/469], Loss: 0.0228
Epoch [4/15], Step [100/469], Loss: 0.0051
Epoch [4/15], Step [200/469], Loss: 0.0139
Epoch [4/15], Step [300/469], Loss: 0.0048
Epoch [4/15], Step [400/469], Loss: 0.0033
Epoch [5/15], Step [100/469], Loss: 0.0081
Epoch [5/15], Step [200/469], Loss: 0.0044
Epoch [5/15], Step [300/469], Loss: 0.0084
Epoch [5/15], Step [400/469], Loss: 0.0011
Epoch [6/15], Step [100/469], Loss: 0.0077
Epoch [6/15], Step [200/469], Loss: 0.0072
Epoch [6/15], Step [300/469], Loss: 0.0446
Epoch [6/15

In [87]:
from sklearn.metrics import roc_auc_score

preds = []
y_true = []
# Test the model
model.eval()  # Set model to evaluation mode.
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()    
        detached_pred = predicted.detach().cpu().numpy()
        detached_label = labels.detach().cpu().numpy()
        for f in range(0, len(detached_pred)):
            preds.append(detached_pred[f])
            y_true.append(detached_label[f])
        
    print('Test Accuracy of the model on the 10000 test images: {:.2%}'.format(correct / total))
    
    preds = np.eye(num_classes)[preds]
    y_true = np.eye(num_classes)[y_true]    
    auc = roc_auc_score(preds, y_true)
    print("AUC: {:.2%}".format (auc))
# Save the model checkpoint
torch.save(model.state_dict(), 'pytorch_mnist_cnn.ckpt')

Test Accuracy of the model on the 10000 test images: 99.07%
AUC: 99.48%


In [72]:
with torch.no_grad():
    correct = 0
    total = 0
    predss = []
    matches = []
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        
        for f in range(0, len(predicted.detach().cpu().numpy())):
            predss.append(predicted.detach().cpu().numpy()[f])
            matches.append(labels.detach().cpu().numpy()[f])
            
#     preds = np.eye(num_classes)[predicted.detach().cpu().numpy()]
#     y_true = np.eye(num_classes)[labels.detach().cpu().numpy()]
#     print(predss, matches)
    preds = np.eye(num_classes)[predss]
    y_true = np.eye(num_classes)[matches]
    
    print(preds == y_true)
    auc = roc_auc_score(preds, y_true)
    print("AUC: {:.2%}".format (auc))

125 128
[7 2 1 0 4 1 4 9 6 9 0 6 9 0 1 5 9 7 5 4 9 6 6 5 4 0 7 4 0 1 3 1 3 4 7 2 7
 1 2 1 1 7 4 2 3 5 1 2 4 4 6 3 5 5 6 0 4 1 9 5 7 8 5 3 7 4 6 4 3 0 7 0 2 9
 1 7 3 2 9 7 7 6 2 7 8 4 7 3 6 1 3 6 9 3 1 4 1 7 6 9 6 0 5 4 9 9 2 1 9 4 8
 7 3 9 7 4 4 4 9 2 5 4 7 6 7 9 0 5]
[7 2 1 0 4 1 4 9 5 9 0 6 9 0 1 5 9 7 3 4 9 6 6 5 4 0 7 4 0 1 3 1 3 4 7 2 7
 1 2 1 1 7 4 2 3 5 1 2 4 4 6 3 5 5 6 0 4 1 9 5 7 8 9 3 7 4 6 4 3 0 7 0 2 9
 1 7 3 2 9 7 7 6 2 7 8 4 7 3 6 1 3 6 9 3 1 4 1 7 6 9 6 0 5 4 9 9 2 1 9 4 8
 7 3 9 7 4 4 4 9 2 5 4 7 6 7 9 0 5]
126 128
[8 5 6 6 5 7 8 1 0 1 6 4 6 7 3 1 7 1 8 2 0 2 9 8 5 5 1 5 6 0 3 4 8 6 5 4 6
 5 4 5 1 4 4 7 2 3 2 7 1 8 1 8 1 8 5 0 8 9 2 5 0 1 1 1 0 9 0 3 1 6 4 2 3 6
 1 1 1 3 9 5 2 9 4 5 9 3 9 0 3 6 5 5 7 2 2 7 1 2 8 4 1 7 3 3 8 8 7 9 2 2 4
 1 5 9 8 7 2 3 0 4 4 2 4 1 9 5 7 7]
[8 5 6 6 5 7 8 1 0 1 6 4 6 7 3 1 7 1 8 2 0 2 9 9 5 5 1 5 6 0 3 4 4 6 5 4 6
 5 4 5 1 4 4 7 2 3 2 7 1 8 1 8 1 8 5 0 8 9 2 5 0 1 1 1 0 9 0 3 1 6 4 2 3 6
 1 1 1 3 9 5 2 9 4 5 9 3 9 0 3 6 5 5 7 2 2 7 1 2 8 

 3 3 0 1 3 7 5 8 6 6 1 0 4 2 3 4 6]
[4 6 1 3 4 7 3 2 3 4 2 5 2 7 1 7 2 6 4 1 5 7 8 6 0 1 8 2 5 7 7 6 9 3 5 8 4
 2 4 0 8 8 3 4 9 2 7 5 8 6 5 6 0 8 6 7 3 6 4 9 4 6 6 3 2 4 1 0 1 4 6 2 9 1
 1 0 6 3 9 5 6 5 6 5 8 4 6 4 3 9 1 3 4 1 9 1 7 1 1 9 3 5 4 0 7 3 6 1 7 5 5
 3 3 0 1 5 7 5 8 6 5 1 0 4 2 3 4 6]
124 128
[7 9 8 1 8 4 9 2 8 6 2 7 0 0 6 7 5 8 6 0 9 3 7 1 3 5 4 3 3 5 5 6 3 0 2 3 4
 2 3 0 9 9 4 7 2 8 4 7 0 6 2 8 5 2 8 5 7 3 0 8 2 3 2 8 2 5 5 7 6 4 0 8 4 8
 2 7 4 5 2 0 3 8 9 6 7 2 5 1 1 1 2 3 6 7 8 7 6 4 8 9 4 8 6 3 8 3 1 0 6 2 2
 5 6 9 5 8 1 4 1 7 8 4 6 1 8 4 3 1]
[7 9 8 1 8 4 9 2 8 6 2 7 0 0 6 7 5 8 6 0 9 3 7 1 3 5 4 3 3 5 5 6 3 0 2 3 4
 2 3 0 9 9 4 7 2 8 4 7 0 6 2 8 5 2 8 5 7 3 0 8 2 3 2 8 2 5 5 7 6 4 6 8 4 8
 2 7 4 5 2 0 3 9 4 6 7 2 5 6 1 1 2 3 6 7 8 7 6 4 8 9 4 8 6 3 8 3 1 0 6 2 2
 5 6 9 5 8 1 4 1 7 8 4 6 1 8 4 3 1]
125 128
[2 8 0 8 5 9 1 4 2 0 2 7 0 8 0 2 5 7 6 7 9 4 2 6 2 4 4 8 0 4 4 5 8 0 6 8 9
 8 5 6 9 0 4 8 7 1 3 4 5 8 0 9 1 3 3 6 9 8 7 1 0 5 7 1 7 5 2 7 9 1 8 5 2 4
 9 4 7 2 2 3 4 

 8 5 3 8 5 4 5 2 0 5 6 3 2 8 3 9 9]
127 128
[5 7 9 4 6 7 1 3 7 3 6 6 0 9 0 1 9 9 2 8 8 0 1 6 9 7 5 3 4 7 4 9 9 4 3 6 3
 1 1 7 6 9 1 8 4 1 1 9 9 4 3 6 8 1 6 0 4 1 3 7 7 4 9 5 1 0 0 1 1 6 2 1 9 8
 4 0 3 6 4 9 0 7 1 6 5 7 5 2 5 1 8 5 4 7 0 6 7 0 2 5 8 1 0 4 5 7 1 8 5 1 3
 0 0 6 0 7 3 1 8 3 9 7 0 0 8 9 5 9]
[5 7 9 4 6 7 1 3 7 3 6 6 0 9 0 1 9 9 2 8 8 0 1 6 9 7 5 3 4 7 4 9 9 4 3 6 3
 1 1 7 6 9 1 8 4 1 1 9 9 4 3 6 8 1 6 0 4 1 3 7 7 4 9 5 1 0 0 1 1 6 2 1 9 8
 4 0 3 6 4 9 0 7 1 6 5 7 5 2 5 1 8 5 4 7 0 6 7 0 2 5 8 1 0 4 5 7 1 8 5 1 9
 0 0 6 0 7 3 1 8 3 9 7 0 0 8 9 5 9]
126 128
[8 3 2 7 2 9 7 2 1 1 3 7 5 3 1 9 8 2 2 2 8 8 5 7 3 8 9 8 8 6 8 2 3 9 7 5 6
 2 9 2 8 8 1 6 8 8 7 9 1 8 0 1 7 2 0 7 5 1 9 0 2 0 9 8 6 2 3 9 3 8 0 2 1 1
 1 1 4 2 9 7 7 5 1 1 2 1 9 9 9 1 0 2 0 2 1 1 4 6 4 1 5 4 9 7 7 7 5 6 2 2 2
 8 0 6 9 6 1 9 7 7 1 4 8 5 3 4 3 4]
[8 3 2 7 2 9 7 2 1 1 3 7 5 3 1 9 8 2 2 2 8 8 5 7 3 8 9 8 8 6 8 2 3 9 7 5 6
 2 9 2 8 8 1 6 8 8 7 9 1 8 0 1 7 2 0 7 5 1 9 0 2 0 9 8 6 2 3 9 3 8 0 2 1 1
 1 1 4 2 9 7 2 

 7 9 8 6 0 8 1 7 7 1 3 2 3 1 4 2 0]
[4 1 0 5 3 1 9 2 1 9 6 0 4 6 1 7 3 8 7 2 9 6 5 8 3 5 7 1 6 1 0 9 6 2 5 4 2
 3 4 4 6 0 0 2 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9
 8 6 5 0 6 8 9 4 1 9 5 3 0 4 8 9 1 4 0 5 5 2 1 5 4 0 7 6 0 1 7 0 6 8 9 5 1
 7 9 8 6 0 8 1 7 7 1 3 2 3 1 4 2 0]
125 128
[0 7 8 4 6 4 9 3 8 4 7 2 5 6 3 6 9 6 3 2 2 4 6 9 0 2 5 5 1 3 3 9 7 8 7 2 2
 5 7 9 8 2 1 3 1 3 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 5 0 1 2 3 4 5 6 7
 8 5 1 2 6 5 3 0 7 0 4 1 4 3 6 7 2 3 1 2 1 2 5 6 0 1 3 0 2 7 5 7 6 2 9 1 9
 0 6 0 6 0 2 0 6 1 5 8 4 3 0 1 5 4]
[0 7 8 4 6 4 9 3 8 4 7 2 5 6 3 6 9 6 3 2 2 4 6 9 0 2 5 5 1 3 3 9 7 8 7 2 2
 5 7 9 8 2 1 3 1 3 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7
 8 9 1 2 6 5 3 0 7 0 4 1 4 3 6 7 2 3 1 2 1 2 9 6 0 1 3 0 2 7 5 7 6 2 9 1 9
 0 6 0 6 0 2 0 6 1 5 8 4 3 0 1 5 4]
124 128
[4 8 5 7 5 7 8 3 4 8 8 5 2 5 7 1 3 8 1 0 7 5 3 6 9 4 7 7 5 8 3 4 4 3 8 6 2
 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 8 3 9 5 5 2
 6 8 4 9 1 7 1 

 8 1 8 2 8 7 6 8 7 1 6 2 9 3 0 1 2]
128 128
[3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 8 9 5 7 0 3 1 6 8 4
 1 5 6 4 2 7 8 1 3 4 3 4 7 2 0 5 0 1 9 2 3 2 3 5 5 7 8 4 9 9 7 1 1 9 0 7 8
 3 4 8 6 3 8 0 9 6 2 1 0 1 0 6 2 3 8 9 0 7 2 3 4 5 5 2 8 5 4 6 6 6 7 9 1 8
 2 1 5 3 4 7 9 4 0 0 0 1 2 3 4 5 6]
[3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 8 9 5 7 0 3 1 6 8 4
 1 5 6 4 2 7 8 1 3 4 3 4 7 2 0 5 0 1 9 2 3 2 3 5 5 7 8 4 9 9 7 1 1 9 0 7 8
 3 4 8 6 3 8 0 9 6 2 1 0 1 0 6 2 3 8 9 0 7 2 3 4 5 5 2 8 5 4 6 6 6 7 9 1 8
 2 1 5 3 4 7 9 4 0 0 0 1 2 3 4 5 6]
128 128
[7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 9 0 1 3 1 5 1 2 4 9 2 4 6 8 0 1 1
 9 2 6 6 8 7 4 2 9 7 0 2 1 0 3 6 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0
 1 2 3 4 5 6 7 8 9 8 6 5 9 7 0 2 3 4 3 8 5 1 5 2 3 0 1 2 1 3 2 6 5 3 0 7 2
 7 4 6 4 0 5 9 9 8 9 5 3 1 7 4 7 6]
[7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 9 0 1 3 1 5 1 2 4 9 2 4 6 8 0 1 1
 9 2 6 6 8 7 4 2 9 7 0 2 1 0 3 6 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0
 1 2 3 4 5 6 7 

 8 7 1 5 2 4 9 4 3 6 4 1 7 2 6 5 0]
[9 8 6 7 3 4 6 8 7 0 4 2 7 7 5 4 3 4 2 8 1 5 1 0 2 3 3 5 7 0 6 8 6 3 9 9 8
 2 7 7 1 0 1 7 8 9 0 1 2 3 4 5 6 7 8 0 1 2 3 4 7 8 9 7 8 6 4 1 9 3 8 4 4 7
 0 1 9 2 8 7 8 2 6 0 6 5 3 3 3 9 1 4 0 6 1 0 0 6 2 1 1 7 7 8 4 6 0 7 0 3 6
 8 7 1 5 2 4 9 4 3 6 4 1 7 2 6 5 0]
16 16
[1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6]
[1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6]
[[ True  True  True ...  True  True  True]
 [ True  True  True ...  True  True  True]
 [ True  True  True ...  True  True  True]
 ...
 [ True  True  True ...  True  True  True]
 [ True  True  True ...  True  True  True]
 [ True  True  True ...  True  True  True]]
AUC: 99.32%


In [None]:
# Convolutional neural network (two convolutional layers)
class ConvNet(nn.Module):
    
    
    def __init__(self, num_classes=10):
        super(ConvNet, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=0),
#             nn.BatchNorm2d(16),
            nn.ReLU())
#             nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer3 = nn.Sequential(
            
        
        )
        self.fc = nn.Linear(7*7*32, num_classes)
        
    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.reshape(out.size(0), -1)
        out = self.fc(out)
        return out

model = ConvNet(num_classes).to(device)