In [1]:
import torch
from torch import nn
import numpy as np
import torch.nn as nn
from torch.autograd.function import Function
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import torchvision.transforms as transforms
from  torch.utils.data import DataLoader
import torch.optim.lr_scheduler as lr_scheduler
import matplotlib.pyplot as plt

In [2]:
### DATASET ###

!wget www.di.ens.fr/~lelarge/MNIST.tar.gz
!tar -zxvf MNIST.tar.gz
root_dir = './'

# DOWNLOAD TRAINING SET WITH CLASS 0-8

trainset = datasets.MNIST(root=root_dir, download=True,train=True, transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))]))

class YourSampler(torch.utils.data.sampler.Sampler):
    def __init__(self, mask, data_source):
        self.mask = mask
        self.data_source = data_source

    def __iter__(self):
        return iter([i.item() for i in torch.nonzero(mask)])

    def __len__(self):
        return len(self.data_source)

mask = [0 if (trainset[i][1] == 9) else 1 for i in range(len(trainset))]
mask = torch.tensor(mask)  
sampler = YourSampler(mask, trainset)

train_loader = DataLoader(trainset, batch_size=128, sampler = sampler, num_workers=2)  

# DOWNLOAD TESING SET WITH CLASS 0-9

testset = datasets.MNIST(root=root_dir, download=True,train=False, transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))]))

test_loader = DataLoader(testset, batch_size=128, shuffle=False, num_workers=2)  

--2022-05-10 15:57:22--  http://www.di.ens.fr/~lelarge/MNIST.tar.gz
Resolving www.di.ens.fr (www.di.ens.fr)... 129.199.99.14
Connecting to www.di.ens.fr (www.di.ens.fr)|129.199.99.14|:80... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://www.di.ens.fr/~lelarge/MNIST.tar.gz [following]
--2022-05-10 15:57:22--  https://www.di.ens.fr/~lelarge/MNIST.tar.gz
Connecting to www.di.ens.fr (www.di.ens.fr)|129.199.99.14|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: unspecified [application/x-gzip]
Saving to: ‘MNIST.tar.gz’

MNIST.tar.gz            [             <=>    ]  33.20M  11.7MB/s    in 2.8s    

2022-05-10 15:57:26 (11.7 MB/s) - ‘MNIST.tar.gz’ saved [34813078]

MNIST/
MNIST/raw/
MNIST/raw/train-labels-idx1-ubyte
MNIST/raw/t10k-labels-idx1-ubyte.gz
MNIST/raw/t10k-labels-idx1-ubyte
MNIST/raw/t10k-images-idx3-ubyte.gz
MNIST/raw/train-images-idx3-ubyte
MNIST/raw/train-labels-idx1-ubyte.gz
MNIST/raw/t10k-images-idx3-ubyte
MNIST/raw/tra

In [3]:
### BASIC NET ###
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1_1 = nn.Conv2d(1, 32, kernel_size=5, padding=2)
        self.prelu1_1 = nn.PReLU()
        self.conv1_2 = nn.Conv2d(32, 32, kernel_size=5, padding=2)
        self.prelu1_2 = nn.PReLU()
        self.conv2_1 = nn.Conv2d(32, 64, kernel_size=5, padding=2)
        self.prelu2_1 = nn.PReLU()
        self.conv2_2 = nn.Conv2d(64, 64, kernel_size=5, padding=2)
        self.prelu2_2 = nn.PReLU()
        self.conv3_1 = nn.Conv2d(64, 128, kernel_size=5, padding=2)
        self.prelu3_1 = nn.PReLU()
        self.conv3_2 = nn.Conv2d(128, 128, kernel_size=5, padding=2)
        self.prelu3_2 = nn.PReLU()
        self.preluip1 = nn.PReLU()
        self.ip1 = nn.Linear(128*3*3, 2)

    def forward(self, x):
        x = self.prelu1_1(self.conv1_1(x))
        x = self.prelu1_2(self.conv1_2(x))
        x = F.max_pool2d(x,2)
        x = self.prelu2_1(self.conv2_1(x))
        x = self.prelu2_2(self.conv2_2(x))
        x = F.max_pool2d(x,2)
        x = self.prelu3_1(self.conv3_1(x))
        x = self.prelu3_2(self.conv3_2(x))
        x = F.max_pool2d(x,2)
        x = x.view(-1, 128*3*3)
        ip1 = self.preluip1(self.ip1(x))
        return ip1

In [4]:
### CUSTOM COSINE SIMILARITY DENSE ###

class CosineSimilarityDense(nn.Module):

    def __init__(self, inputs, outputs):
        super().__init__()
        self.outputs = outputs
        self.weight = nn.Parameter(torch.randn(inputs, outputs))
        
    def forward(self, x):
        linear = torch.matmul(x, self.weight/torch.norm(self.weight) )
        
        return F.log_softmax(linear)

In [5]:
### COMPLETE NET ###

class Conv2D_Cos(nn.Module):
    def __init__(self):
        super(Conv2D_Cos, self).__init__()
        self.net = Net()
        self.cos_dense = CosineSimilarityDense(2,9)

    def forward(self, x):
        x = self.net(x)
        x = self.cos_dense(x)
        
        return x

    def get_embeddings(self, x):
        x = self.net(x)

        return x

In [6]:
### TRAINING FUNCTIONS/LOOPS ###

def Cos_training(dataloader, model, optimizer):
    size = len(dataloader.dataset)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    correct = 0.0
    loss_sum = 0.0
    cnt_batches = 0

    for batch, (X, y) in enumerate(dataloader):

        X, y = X.to(device), y.to(device)
        cnt_batches += 1
    
        # Backpropagation
        optimizer.zero_grad()
        
        # Compute prediction and loss
        preds = model(X.float().to(device))
       
        loss = CE_loss(preds, y.long())

        loss.backward()
        optimizer.step()
        
        correct += (preds.argmax(1) == y).type(torch.float).sum().item()
        loss_sum += loss
        
        if batch % 20 == 0:
            loss, current = loss.item(), batch * len(X)
    
    # training accuracy
    correct /= size
    loss /= cnt_batches

    print(f"Training Error:  Accuracy: {(100*correct):>0.1f}%")
    loss = loss.to('cpu')
    
    return float(loss.detach().numpy()), correct

In [7]:
### Softmax/CE loss FUNCTION ###
def CE_loss(prediction, label):
    CEL =  torch.nn.CrossEntropyLoss()
    
    return CEL(prediction, label)

In [8]:
### TRAIN THE NETWORK ###

use_cuda = torch.cuda.is_available() and True
device = torch.device("cuda" if use_cuda else "cpu")

# Model
model = Conv2D_Cos().to(device)

# optimzer
optimizer = optim.SGD(model.parameters(),lr=0.001,momentum=0.9, weight_decay=0.0005)
sheduler = lr_scheduler.StepLR(optimizer,20,gamma=0.8)

# Training with 50 epoches
for epoch in range(50):
    sheduler.step()
    print(f"Epoch: : {epoch}")
    Cos_training(train_loader, model, optimizer)

Epoch: : 0


  del sys.path[0]


Training Error:  Accuracy: 9.8%
Epoch: : 1
Training Error:  Accuracy: 17.4%
Epoch: : 2
Training Error:  Accuracy: 18.3%
Epoch: : 3
Training Error:  Accuracy: 20.2%
Epoch: : 4
Training Error:  Accuracy: 44.1%
Epoch: : 5
Training Error:  Accuracy: 56.2%
Epoch: : 6
Training Error:  Accuracy: 68.3%
Epoch: : 7
Training Error:  Accuracy: 77.3%
Epoch: : 8
Training Error:  Accuracy: 82.8%
Epoch: : 9
Training Error:  Accuracy: 84.5%
Epoch: : 10
Training Error:  Accuracy: 85.6%
Epoch: : 11
Training Error:  Accuracy: 86.4%
Epoch: : 12
Training Error:  Accuracy: 86.9%
Epoch: : 13
Training Error:  Accuracy: 87.3%
Epoch: : 14
Training Error:  Accuracy: 87.6%
Epoch: : 15
Training Error:  Accuracy: 87.9%
Epoch: : 16
Training Error:  Accuracy: 88.1%
Epoch: : 17
Training Error:  Accuracy: 88.2%
Epoch: : 18
Training Error:  Accuracy: 88.3%
Epoch: : 19
Training Error:  Accuracy: 88.6%
Epoch: : 20
Training Error:  Accuracy: 88.9%
Epoch: : 21
Training Error:  Accuracy: 89.2%
Epoch: : 22
Training Error:  Acc

In [9]:
### EXTRACT THE CLASS SCORE ###

total_value = []

for data_train in train_loader:
    
    images,real_labels = data_train
    images = images.to(device)
    real_labels = real_labels.to(device)
    
    output = model(images)
    
    # get the output value for all train data
    value, position = torch.max(output,1)
    value = value.tolist()
    total_value.extend(value)

  del sys.path[0]


In [10]:
### DETERMINE THE OOD DETECTION THRESHOLD BY SORTING

total_value.sort()
id = int(0.01 * len(total_value))
threshold = total_value[id]

print(f"The threshold by sorting is: {threshold}")

The threshold by sorting is: -0.022744616493582726


In [11]:
### USE MODEL AND THRESHOLD ON TEST DATA ###

correct = 0
total = 0
i = 0
s = []
labels_pre = []

for data_test in test_loader:
    images,real_labels = data_test
    images = images.to(device)
    real_labels = real_labels.to(device)
    
    output = model(images)
    
    value, position = torch.max(output,1)
    labels_pre.append((position))
    
    # get the ID of predicted ood samples and save in s
    y = torch.zeros(value.shape).to(value.device) 
    y[value < threshold] = 1
    id = (y==1).nonzero(as_tuple = False).cpu()

    ood = i*128+id
    ood = (ood).numpy().flatten()
    s = np.concatenate((s,ood))
    
    i = i + 1
    correct += (position == real_labels).sum()
    
s = s.astype(int)  
print(f"The overall accuracy before detection is: {(1/100*correct):>0.1f}%")

  del sys.path[0]


The overall accuracy before detection is: 87.8%


In [12]:
### MARK THE DETECTED OOD SAMPLES AS 9 ###

labels = torch.cat(labels_pre, 0)

for i in range(len(s)):
    j = s[i]
    labels[j] = 9

In [13]:
### EVALUATE THE RESULTS ###

# Get the number of in and out of distribution samples

testset.targets = testset.targets.to(device)
y = torch.zeros(testset.targets.shape).to(testset.targets.device) 
y[testset.targets == 9] = 1
num_ood = (y==1).nonzero(as_tuple = False).cpu()
num_ind = (y!=1).nonzero(as_tuple = False).cpu()

# Percentage of OOD samples detected as OOD samples 

ood_acc = 0

for k in range(len(testset.targets)):
    if testset.targets[k]==9:
            ood_acc += (testset.targets[k] == labels[k])

ood_acc =ood_acc / len(num_ood) 
print(f"The accuracy of out of distribution detection: \n Accuracy is: {(100*ood_acc):>0.1f}%")

# Percentage of in-distribution samples that were detected as in-distribution  

ind_acc = 0

for k in range(len(testset.targets)):
    if testset.targets[k]!=9:
            ind_acc += (testset.targets[k] == labels[k])

ind_acc =ind_acc / len(num_ind) 
print(f"The accuracy of in distribution detection: \n Accuracy is: {(100*ind_acc):>0.1f}%")

# Evaluate the overall accuracy

overall_acc = 0

for k in range(len(testset.targets)):
    overall_acc += (testset.targets[k] == labels[k])
overall_acc = overall_acc/len(testset.targets) 
print(f"The overall accuracy after ood detection: \n Accuracy is: {(100*overall_acc):>0.1f}%")

The accuracy of out of distribution detection: 
 Accuracy is: 62.8%
The accuracy of in distribution detection: 
 Accuracy is: 95.4%
The overall accuracy after ood detection: 
 Accuracy is: 92.1%
