# 1. Data & Preprocessing

In [133]:
import numpy as np 
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from torch.autograd import Variable
from torch.utils.data import Dataset

import torchvision
from torchvision import datasets, transforms

import argparse

from sklearn.metrics import confusion_matrix

## Datasets, dataloaders and transformations

In [2]:
# Random seed (for reproducibility)
np.random.seed(9)
batch_size = 128

MNIST - for training

In [3]:
mnist_transform = transforms.Compose([transforms.ToTensor(),
                              transforms.Normalize(
                                  (0.1307,),
                                  (0.3081,)), # mnist mean & std
                              transforms.Resize(28)])

mnist_all_train_set = torchvision.datasets.MNIST(root='./data',
                                                 train=True,
                                                 download=True,
                                                 transform=mnist_transform)

# Train, validation splits
mnist_train_size = int(0.8*len(mnist_all_train_set))
mnist_validation_size = len(mnist_all_train_set) - mnist_train_size

mnist_train_set, mnist_validation_set = torch.utils.data.random_split(
    mnist_all_train_set, [mnist_train_size, mnist_validation_size])
mnist_test_set = torchvision.datasets.MNIST(root='./data',
                                            train=False,
                                            download=True,
                                            transform=mnist_transform)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



In [4]:
mnist_train_loader = torch.utils.data.DataLoader(mnist_train_set,
                                                 batch_size=batch_size,
                                                 shuffle=True,
                                                 num_workers=2)
mnist_validation_loader = torch.utils.data.DataLoader(mnist_validation_size,
                                                      batch_size=batch_size,
                                                      shuffle=False,
                                                      num_workers=2)
mnist_test_loader = torch.utils.data.DataLoader(mnist_test_set,
                                                batch_size=batch_size,
                                                shuffle=False,
                                                num_workers=2)

CIFAR10 as OOD dataset - for evaluation

In [143]:
cifar10_transform = transforms.Compose([transforms.ToTensor(),
                              transforms.Normalize(
                                  (0.4914, 0.4822, 0.4465),
                                  (0.247, 0.243, 0.261)), # cifar10 mean & std
                              transforms.Resize(28),
                              transforms.Grayscale(num_output_channels=1)])

cifar10_data_set = torchvision.datasets.CIFAR10(root='./data',
                                              train=False,
                                              download=True,
                                              transform=cifar10_transform)

cifar10_data_set = torch.utils.data.Subset(cifar10_data_set, np.random.choice(range(len(cifar10_data_set)), size=500))

Files already downloaded and verified


In [54]:
class sub_data_set(Dataset):
  """Combined dataset."""
  def __init__(self, dataset, targets):
    self.dataset = dataset
    self.targets = targets
      
  def __getitem__(self, idx):
      image = self.dataset[idx][0]
      target = self.targets[idx]
      return (image, target)

  def __len__(self):
      return len(self.dataset)

In [146]:
cifar10_size = 500

sub_cifar10_data = torch.utils.data.Subset(cifar10_data_set, np.random.choice(range(len(cifar10_data_set)), size=cifar10_size))
sub_cifar10_targets = [10 for i in range(cifar10_size)]

sub_cifar10_data_set = sub_data_set(sub_cifar10_data, sub_cifar10_targets)

sub_cifar10_test_loader = torch.utils.data.DataLoader(sub_cifar10_data_set,
                                                batch_size=batch_size,
                                                shuffle=False,
                                                num_workers=2)

combined_test_loader = torch.utils.data.DataLoader(
    torch.utils.data.ConcatDataset([mnist_test_set, sub_cifar10_data_set]),
        batch_size=batch_size, shuffle=True)


# 2. Models

In [60]:
def out_size(W, F, S, P):
  return ((W-F+2*P) // S) + 1

Baseline model class

In [61]:
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
        self.conv2 = nn.Conv2d(32, 32, kernel_size=5)
        self.conv3 = nn.Conv2d(32,64, kernel_size=5)
        self.fc1 = nn.Linear(3*3*64, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        #x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu(F.avg_pool2d(self.conv2(x), 2))
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu(F.avg_pool2d(self.conv3(x),2))
        x = F.dropout(x, p=0.5, training=self.training)
        x = x.view(-1,3*3*64 )
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

OSR model class

In [70]:
class OSR_CNN(nn.Module):
    def __init__(self):
        super(OSR_CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
        self.conv2 = nn.Conv2d(32, 32, kernel_size=5)
        self.conv3 = nn.Conv2d(32,64, kernel_size=5)
        self.fc1 = nn.Linear(3*3*64, 256)
        self.fc2 = nn.Linear(256, 11)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        #x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu(F.avg_pool2d(self.conv2(x), 2))
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu(F.avg_pool2d(self.conv3(x),2))
        x = F.dropout(x, p=0.5, training=self.training)
        x = x.view(-1,3*3*64 )
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

# 3. Training

In [105]:
def fit(model, train_loader):
    it = iter(train_loader)
    X_batch, y_batch = next(it)
    optimizer = torch.optim.Adam(model.parameters())#,lr=0.001, betas=(0.9,0.999))
    error = nn.CrossEntropyLoss()
    EPOCHS = 1 # return to 5 at the end
    model.train()
    for epoch in range(EPOCHS):
        correct = 0
        for batch_idx, (X_batch, y_batch) in enumerate(train_loader):
            var_X_batch = Variable(X_batch).float()
            var_y_batch = Variable(y_batch)
            optimizer.zero_grad()
            output = model(var_X_batch)
            loss = error(output, var_y_batch)
            loss.backward()
            optimizer.step()

            # Total correct predictions
            predicted = torch.max(output.data, 1)[1] 
            correct += (predicted == var_y_batch).sum()
            #print(correct)
            if batch_idx % 50 == 0:
                print('Epoch : {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\t Accuracy:{:.3f}%'.format(
                    epoch,
                    batch_idx*len(X_batch),
                    len(train_loader.dataset),
                    100.*batch_idx / len(train_loader),
                    loss.data,
                    float(correct*100) / float(batch_size*(batch_idx+1))))


Training procedure for baseline model

In [104]:
cnn = CNN()
fit(cnn, mnist_train_loader)



Training procedure for OSR model

In [106]:
osr_cnn = OSR_CNN()
fit(osr_cnn, mnist_train_loader)



# 4. Evaluation

In [127]:
dictionary = {0:0, 1:0, 2:0, 3:0, 4:0, 5:0, 6:0, 7:0, 8:0, 9:0, 10:1}

def evaluate(model, labels, test_loader, binary=False):
  y_true = []
  y_pred = []
  correct = 0 
  for test_imgs, test_labels in test_loader:
    
    # print(test_imgs.shape)
    test_imgs = Variable(test_imgs).float()
    output = model(test_imgs)
    predicted = torch.max(output,1)[1]
    correct += (predicted == test_labels).sum()

    y_true += list(test_labels.detach().cpu())
    y_pred += list(predicted.detach().cpu())

  if binary:
    y_true = [dictionary[int(y.detach().cpu())] for y in y_true]
    y_pred = [dictionary[int(y.detach().cpu())] for y in y_pred]

  print("Test accuracy:{:.3f}% ".format( float(correct*100) / (len(test_loader)*batch_size)))
  a = confusion_matrix(y_true, y_pred, labels=labels)
  return a

Baseline results

In [139]:
baseline_c_matrix = evaluate(cnn, range(10), mnist_test_loader)
print(baseline_c_matrix)

Test accuracy:94.996% 
[[ 961    0    3    2    1    2    4    3    3    1]
 [   0 1118    5    1    1    2    2    4    2    0]
 [   4    4 1000    4    2    2    1   11    3    1]
 [   1    1    7  968    0   17    0   10    6    0]
 [   1    3    1    0  931    2   10    3    2   29]
 [   3    1    0   21    1  857    4    2    0    3]
 [   9    1    4    0    6   11  926    0    1    0]
 [   2    5   19    5    1    1    0  986    1    8]
 [   5    0    9    4    4   17    9    3  908   15]
 [   6    4    1    5   13   14    1   11    3  951]]


OSR rational

:TODO

OOD results

In [147]:
ood_c_matrix = evaluate(osr_cnn, range(2), sub_cifar10_test_loader, True)
print(ood_c_matrix)

Test accuracy:0.000% 
[[  0   0]
 [500   0]]


OSR results

In [148]:
osr_c_matrix = evaluate(osr_cnn, range(11), combined_test_loader)
print(osr_c_matrix)

Test accuracy:90.258% 
[[ 953    0    1    0    3    3    8    1    6    5    0]
 [   1 1113    5    0    2    1    1    2   10    0    0]
 [   6    5  987    1    1    0    3   17   12    0    0]
 [   1    0    5  975    0    7    0   11    9    2    0]
 [   0    0    3    1  954    0    6    0    0   18    0]
 [   2    0    0   20    0  826    8    2   25    9    0]
 [   4    4    1    0   13    4  925    0    7    0    0]
 [   1    5   26    4    1    2    0  968    7   14    0]
 [   4    1   12    4    5    0    1    2  938    7    0]
 [   2    1    0    4   17    7    1    9   18  950    0]
 [  15    9  115   89   15  161   16   48   17   15    0]]
