## Import Libraries

Requirements:
* numpy
* torch
* matplotlib
* torchvision (for the datasets)
* scipy (for loading the datasets)

To set up the environment, follow the instructions in the repository README: https://github.com/AbinavRavi/OOD-detection/blob/master/README.md

After setting up the environment, run the cells in order. The notebook does not need a GPU: it falls back to the CPU when CUDA is unavailable.

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import torch
from torchvision import datasets
import torchvision.transforms as T
import os
import sys

### Load the data

The in-distribution dataset used in this demonstration is MNIST, and the out-of-distribution (OOD) dataset is FashionMNIST.

Reasons for choosing FashionMNIST as the OOD dataset:
1. Same size: 28×28 grayscale images, the same format as MNIST
2. Different content and pixel intensities: clothing items instead of handwritten digits

In a production pipeline, every incoming image goes through the same preprocessing, so there is no OOD dataset to choose at inference time. Because this is a toy example, we pick one explicitly.
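As a rough illustration of "the same preprocessing for every input", the sketch below applies one hypothetical `preprocess` function to an in-distribution image and an OOD image alike. It only mimics what `T.ToTensor()` does for a 2-D `uint8` grayscale image (add a channel axis, scale to [0, 1]); it is not the torchvision implementation, and the images are random stand-ins rather than real data.

```python
import numpy as np


def preprocess(image: np.ndarray) -> np.ndarray:
    """Hypothetical shared preprocessing step, mimicking T.ToTensor()
    for a 2-D uint8 grayscale image: H x W -> 1 x H x W, scaled to [0, 1]."""
    if image.ndim != 2 or image.dtype != np.uint8:
        raise ValueError("expected a 2-D uint8 grayscale image")
    return image.astype(np.float32)[np.newaxis] / 255.0


# Stand-ins for one MNIST digit and one OOD image (random pixels, not real data).
rng = np.random.default_rng(0)
id_image = rng.integers(0, 256, size=(28, 28), dtype=np.uint8)
ood_image = rng.integers(0, 256, size=(28, 28), dtype=np.uint8)

# Every input goes through the same function; deciding whether it is
# in- or out-of-distribution is left to the detector, not the pipeline.
batch = np.stack([preprocess(id_image), preprocess(ood_image)])
print(batch.shape, batch.dtype)  # (2, 1, 28, 28) float32
```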

In [2]:
transform = T.ToTensor()  # PIL image (H x W, uint8) -> float tensor (1 x H x W) in [0, 1]
path = './IDdata/'
download = not os.path.exists(path)  # only download when the data folder is missing
train_data = datasets.MNIST(root='IDdata', train=True, download=download, transform=transform)
test_data = datasets.MNIST(root='IDdata', train=False, download=download, transform=transform)

In [3]:
# FashionMNIST images are already 28 x 28 grayscale, so the MNIST transform is reused as-is.
out_path = './OODdata/'
download_ood = not os.path.exists(out_path)  # only download when the data folder is missing
ood_data = datasets.FashionMNIST(root='OODdata', train=True, download=download_ood, transform=transform)

In [8]:
batch_size = 32
num_workers = 2
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
# Shuffle only the training data; evaluation order does not matter.
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=num_workers)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, num_workers=num_workers)
ood_loader = torch.utils.data.DataLoader(ood_data, batch_size=batch_size, num_workers=num_workers)
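One quick way to check the "different intensity" reason given for the OOD set is to compare per-image mean pixel values between the MNIST test loader and the OOD loader. A minimal sketch; `mean_intensities` is a hypothetical helper, and the demo wraps small random tensors in a `TensorDataset` so it runs without downloading anything.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def mean_intensities(loader, device="cpu"):
    """Hypothetical helper: per-image mean pixel value over a whole loader."""
    means = []
    with torch.no_grad():
        for images, _ in loader:  # images: (batch, 1, 28, 28)
            images = images.to(device)
            means.append(images.flatten(start_dim=1).mean(dim=1).cpu())
    return torch.cat(means)


# Demo with random stand-in data (not MNIST): 100 images with fake labels.
fake_images = torch.rand(100, 1, 28, 28)
fake_labels = torch.randint(0, 10, (100,))
demo_loader = DataLoader(TensorDataset(fake_images, fake_labels), batch_size=32)

demo_means = mean_intensities(demo_loader)
print(demo_means.shape)  # torch.Size([100])
```

With the loaders above, comparing `mean_intensities(test_loader, device)` against `mean_intensities(ood_loader, device)` gives a coarse first look; a histogram of both tensors (for example with `plt.hist`) shows more than the two averages.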

## Model 

In [7]:
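class Net(nn.Module):
    """Small CNN for 1 x 28 x 28 inputs (MNIST layout) with 10 output classes."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)   # 1x28x28 -> 32x26x26
        self.conv2 = nn.Conv2d(32, 64, 3, 1)  # 32x26x26 -> 64x24x24
        self.dropout1 = nn.Dropout2d(0.25)    # zeros whole feature maps of the 4-D activations
        self.dropout2 = nn.Dropout(0.5)       # element-wise dropout on the flat (N, 128) features
        self.fc1 = nn.Linear(9216, 128)       # 64 * 12 * 12 = 9216 after 2x2 max pooling
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout2(x)
        x = self.fc2(x)
        # Log-probabilities over the 10 classes; pair with F.nll_loss when training.
        return F.log_softmax(x, dim=1)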
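The `9216` in `nn.Linear(9216, 128)` follows from the 28×28 input: each unpadded 3×3 convolution with stride 1 shrinks the side by 2, and the 2×2 max-pool halves it. A small sketch of that arithmetic; the `conv_out` helper is ours and uses the standard output-size formula for convolution and pooling without dilation.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output side length of a conv/pool layer (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


side = 28                                  # MNIST images are 28 x 28
side = conv_out(side, kernel=3)            # conv1: 28 -> 26
side = conv_out(side, kernel=3)            # conv2: 26 -> 24
side = conv_out(side, kernel=2, stride=2)  # max_pool2d(x, 2): 24 -> 12
channels = 64                              # conv2 has 64 output channels
flat_features = channels * side * side
print(side, flat_features)  # 12 9216
```

This is also why every input, in-distribution or OOD, must arrive as 1×28×28: a larger image would flatten to more than 9216 features and `fc1` would raise a shape error.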

In [9]:
model = Net()
model = model.to(device)
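Because `forward` returns log-probabilities (`F.log_softmax`), the model pairs with `F.nll_loss` for training, and its outputs convert directly into a confidence score. One common OOD baseline, not necessarily the score this notebook uses, is the maximum softmax probability (MSP): low values suggest an input may be out of distribution. A minimal sketch with random stand-in logits in place of real model outputs; `msp_score` is a hypothetical name.

```python
import torch
import torch.nn.functional as F


def msp_score(log_probs: torch.Tensor) -> torch.Tensor:
    """Maximum softmax probability per row of a (batch, classes) log-prob tensor."""
    return log_probs.exp().max(dim=1).values


# Stand-in for model(images): random logits turned into log-probabilities.
logits = torch.randn(4, 10)
log_probs = F.log_softmax(logits, dim=1)

scores = msp_score(log_probs)  # each value lies in [1/10, 1]
print(scores.shape)            # torch.Size([4])

# Training side: nll_loss on log_softmax output equals cross_entropy on raw logits.
targets = torch.tensor([0, 3, 7, 9])
assert torch.allclose(F.nll_loss(log_probs, targets), F.cross_entropy(logits, targets))
```

With the real model, call `model.eval()` and wrap the forward pass in `torch.no_grad()` before scoring, so that dropout is disabled and no gradients are tracked.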