<center><h1>WideResNet: Cifar100</h1></center>

## Imports

In [1]:
from __future__ import division,print_function

%matplotlib inline
%load_ext autoreload
%autoreload 2

import sys
from tqdm import tqdm_notebook as tqdm

import random
import matplotlib.pyplot as plt
import math

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.nn.init as init
from torch.autograd import Variable, grad
from torchvision import datasets, transforms
from torch.nn.parameter import Parameter

import calculate_log as callog

import warnings
warnings.filterwarnings('ignore')

In [2]:
# Make CUDA device index 1 the current device, so the index-less .cuda() calls
# used throughout this notebook place tensors and the model on that GPU.
torch.cuda.set_device(1) #Select the GPU

## Model definition

In [3]:
def conv3x3(in_planes, out_planes, stride=1):
    """Create a 3x3 convolution layer with padding 1 and no bias term.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride (defaults to 1).

    Returns:
        A freshly constructed ``nn.Conv2d``.
    """
    conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_options)

def G_p(ob, p):
    """Return per-channel order-``p`` Gram statistics of a feature tensor.

    The input is detached, raised element-wise to the power ``p`` and
    flattened to (batch, channels, -1). The channel-by-channel Gram matrix is
    summed over its last axis, and the signed ``p``-th root of each row sum is
    taken.

    Args:
        ob: tensor whose first two axes are batch and channel.
        p: exponent applied before, and root taken after, the Gram product.

    Returns:
        Tensor of shape (batch, channels); it carries no autograd history.
    """
    powered = ob.detach() ** p
    n_batch, n_channels = powered.shape[0], powered.shape[1]
    flat = powered.reshape(n_batch, n_channels, -1)

    # Row sums of the (channels x channels) Gram matrix of every sample.
    row_sums = torch.matmul(flat, flat.transpose(dim0=2, dim1=1)).sum(dim=2)

    # Signed root: keeps the sign of negative row sums instead of producing NaN.
    rooted = row_sums.sign() * torch.abs(row_sums) ** (1 / p)
    return rooted.reshape(rooted.shape[0], -1)

class BasicBlock(nn.Module):
    """Residual unit made of two BN -> ReLU -> 3x3 conv stages plus a shortcut.

    When the number of channels changes, the shortcut is a 1x1 convolution
    applied to the activated input; otherwise the untouched input is added
    back. Four intermediate tensors per call are handed to the module-level
    ``torch_model.record``.
    """

    def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
        super(BasicBlock, self).__init__()
        # Attribute names and creation order are deliberately unchanged: they
        # define the state-dict keys and the order in which modules are visited.
        self.bn1 = nn.BatchNorm2d(in_planes, track_running_stats=True)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(
            in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_planes, track_running_stats=True)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(
            out_planes, out_planes, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.droprate = dropRate
        self.equalInOut = (in_planes == out_planes)
        if self.equalInOut:
            self.convShortcut = None
        else:
            self.convShortcut = nn.Conv2d(
                in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False
            )

    def forward(self, x):
        """Apply the residual unit to ``x`` and return the summed output."""
        # First stage: both branches of the block start from the activated input.
        activated = self.relu1(self.bn1(x))
        torch_model.record(activated)
        out = self.conv1(activated)
        torch_model.record(out)

        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, training=self.training)

        # Second stage.
        hidden = self.relu2(self.bn2(out))
        torch_model.record(hidden)
        out = self.conv2(hidden)
        torch_model.record(out)

        # Identity shortcut uses the raw input; the projection shortcut uses
        # the activated input.
        if self.equalInOut:
            residual = x
        else:
            residual = self.convShortcut(activated)
        return torch.add(residual, out)

class NetworkBlock(nn.Module):
    """A stage of ``nb_layers`` residual units applied one after another.

    Only the first unit receives ``in_planes`` and ``stride``; every later
    unit maps ``out_planes`` to ``out_planes`` with stride 1.
    """

    def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0):
        super(NetworkBlock, self).__init__()
        self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate)

    def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate):
        """Instantiate ``nb_layers`` units of ``block`` and chain them.

        ``block`` is called as ``block(in_channels, out_channels, stride,
        dropRate)`` and must return an ``nn.Module``.
        """
        layers = []
        for i in range(nb_layers):
            is_first = (i == 0)
            # Explicit conditionals instead of ``cond and a or b``: that idiom
            # silently replaced a falsy first-layer argument with the fallback
            # value rather than passing it through to ``block``.
            layer_in = in_planes if is_first else out_planes
            layer_stride = stride if is_first else 1
            layers.append(block(layer_in, out_planes, layer_stride, dropRate))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through all units of the stage in order."""
        return self.layer(x)

class WideResNet(nn.Module):
    """Wide residual network that can also expose intermediate feature maps.

    Besides the usual ``forward`` pass, the model can collect the tensors
    passed to ``record`` during a forward pass (``gram_feature_list``). From
    those it derives per-layer Gram-statistic ranges (``get_min_max``) and
    measures how far new inputs fall outside those ranges
    (``get_deviations``).

    NOTE: ``BasicBlock.forward`` reports its features through the module-level
    name ``torch_model``. Feature collection is therefore only complete when
    this instance is the object bound to that name.
    """

    def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0):
        """Build the network.

        Args:
            depth: total depth; must be of the form ``6 * n + 4``.
            num_classes: size of the final linear layer's output.
            widen_factor: multiplier for the channel widths of the three stages.
            dropRate: dropout probability forwarded to every residual unit.
        """
        super(WideResNet, self).__init__()
        nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor]
        assert (depth - 4) % 6 == 0, "depth must be of the form 6*n + 4"
        self.collecting = False
        # Initialised here so ``record`` never meets a missing attribute.
        self.gram_feats = []
        units_per_stage = (depth - 4) // 6
        block = BasicBlock
        # 1st conv before any network block
        self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
                               padding=1, bias=False)
        # 1st block
        self.block1 = NetworkBlock(units_per_stage, nChannels[0], nChannels[1], block, 1, dropRate)
        # 2nd block
        self.block2 = NetworkBlock(units_per_stage, nChannels[1], nChannels[2], block, 2, dropRate)
        # 3rd block
        self.block3 = NetworkBlock(units_per_stage, nChannels[2], nChannels[3], block, 2, dropRate)
        # global average pooling and classifier
        self.bn1 = nn.BatchNorm2d(nChannels[3])
        self.relu = nn.ReLU(inplace=True)
        self.fc = nn.Linear(nChannels[3], num_classes)
        self.nChannels = nChannels[3]

        # Weight initialisation: normal init for convolutions scaled by
        # kernel area times output channels, unit scale and zero shift for
        # batch norm, zero bias for the classifier.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()

    def forward(self, x):
        """Return the class logits for a batch of images ``x``."""
        out = self.conv1(x)
        # Record on this instance rather than on the global ``torch_model``,
        # so the stem feature always lands in the model doing the collecting.
        self.record(out)
        out = self.block1(out)
        out = self.block2(out)
        out = self.block3(out)
        out = self.relu(self.bn1(out))
        out = F.avg_pool2d(out, 8)
        out = out.view(-1, self.nChannels)
        return self.fc(out)

    def record(self, t):
        """Append ``t`` to the collected features while collection is active."""
        if self.collecting:
            self.gram_feats.append(t)

    def gram_feature_list(self, x):
        """Run a forward pass on ``x`` and return the tensors recorded during it.

        The collection flag and buffer are always reset afterwards, even when
        the forward pass raises.
        """
        self.collecting = True
        self.gram_feats = []
        try:
            self.forward(x)
            collected = self.gram_feats
        finally:
            self.collecting = False
            self.gram_feats = []
        return collected

    def load(self, path="cifar100_wrn_oe_scratch_epoch_99.pt"):
        """Load a state dict from ``path`` (mapped to CPU) into this model."""
        tm = torch.load(path,map_location="cpu")
        self.load_state_dict(tm)

    def get_min_max(self, data, power, batch_size=128):
        """Compute element-wise minima and maxima of the Gram statistics.

        Args:
            data: tensor of inputs, sliced along its first axis into batches.
            power: sized iterable of exponents passed to ``G_p``.
            batch_size: number of inputs per forward pass (default 128).

        Returns:
            ``(mins, maxs)``: lists indexed ``[layer][power_index]`` whose
            entries are tensors of shape (1, channels) living on the model's
            device. Both lists are empty when ``data`` is empty.
        """
        mins = []
        maxs = []
        # Follow the model's own device instead of assuming the current CUDA device.
        device = next(self.parameters()).device

        # G_p detaches its input, so no gradient is ever needed here; skipping
        # autograd avoids keeping a graph alive for every recorded feature map.
        with torch.no_grad():
            for i in range(0, len(data), batch_size):
                batch = data[i:i+batch_size].to(device)
                feat_list = self.gram_feature_list(batch)
                for L, feat_L in enumerate(feat_list):
                    if L == len(mins):
                        mins.append([None]*len(power))
                        maxs.append([None]*len(power))

                    for p, P in enumerate(power):
                        g_p = G_p(feat_L, P)

                        current_min = g_p.min(dim=0, keepdim=True)[0]
                        current_max = g_p.max(dim=0, keepdim=True)[0]

                        if mins[L][p] is None:
                            mins[L][p] = current_min
                            maxs[L][p] = current_max
                        else:
                            mins[L][p] = torch.min(current_min, mins[L][p])
                            maxs[L][p] = torch.max(current_max, maxs[L][p])

        return mins, maxs

    def get_deviations(self, data, power, mins, maxs, batch_size=128):
        """Measure how far Gram statistics of ``data`` fall outside ``[mins, maxs]``.

        For every layer, the relative amount by which each statistic is below
        its minimum or above its maximum is summed over channels and over all
        exponents in ``power``.

        Args:
            data: tensor of inputs, sliced along its first axis into batches.
            power: iterable of exponents, in the order used for ``mins``/``maxs``.
            mins, maxs: ranges as returned by ``get_min_max``; they must be on
                the same device as the model.
            batch_size: number of inputs per forward pass (default 128).

        Returns:
            NumPy array of shape (len(data), number of recorded layers).
            ``np.concatenate`` raises ``ValueError`` when ``data`` is empty.
        """
        deviations = []
        device = next(self.parameters()).device

        # Inference only: the result is converted to NumPy, so no graph is needed.
        with torch.no_grad():
            for i in range(0, len(data), batch_size):
                batch = data[i:i+batch_size].to(device)
                feat_list = self.gram_feature_list(batch)
                batch_deviations = []
                for L, feat_L in enumerate(feat_list):
                    dev = 0
                    for p, P in enumerate(power):
                        g_p = G_p(feat_L, P)

                        # The small constant keeps the denominator away from zero.
                        dev += (F.relu(mins[L][p]-g_p)/torch.abs(mins[L][p]+10**-6)).sum(dim=1,keepdim=True)
                        dev += (F.relu(g_p-maxs[L][p])/torch.abs(maxs[L][p]+10**-6)).sum(dim=1,keepdim=True)
                    batch_deviations.append(dev.cpu().detach().numpy())
                batch_deviations = np.concatenate(batch_deviations, axis=1)
                deviations.append(batch_deviations)
        deviations = np.concatenate(deviations, axis=0)

        return deviations

# Build the network (depth 40, widen factor 2, 100 output classes) and bind it
# to the module-level name `torch_model`, which BasicBlock.forward, the
# evaluation cells and Detector all refer to.
torch_model = WideResNet(depth=40,widen_factor=2, num_classes=100)
# torch_model = nn.DataParallel(torch_model)
# torch_model.load_state_dict(torch.load("wrn_cifar100.pth"))

# Load the default checkpoint (WideResNet.load maps it to CPU), then move the
# model to the current CUDA device.
torch_model.load()
torch_model.cuda()
# Keep a list of the parameters as a plain attribute.
# NOTE(review): `params` is not read anywhere in the visible code.
torch_model.params = list(torch_model.parameters())
# Evaluation mode: batch norm uses running statistics and dropout is disabled.
torch_model.eval()
print("Done")    

Done


## Datasets

<b>In-distribution Datasets</b>

In [4]:
# Batch size shared by the two loaders defined in this cell.
batch_size = 128
# Per-channel normalisation constants as (3, 1) column arrays.
# NOTE(review): `mean` and `std` are not referenced again in the visible code;
# the same numbers are passed to transforms.Normalize below.
mean = np.array([[0.4914, 0.4822, 0.4465]]).T

std = np.array([[0.2023, 0.1994, 0.2010]]).T
# NOTE(review): these values look like the commonly used CIFAR-10 statistics;
# confirm they are what the loaded checkpoint was trained with.
normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

# Training-time preprocessing: random 32x32 crop after 4-pixel padding, random
# horizontal flip, tensor conversion, normalisation.
transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize
    ])

# Deterministic preprocessing used for every dataset evaluated below:
# 32x32 centre crop, tensor conversion, normalisation.
transform_test = transforms.Compose([
    transforms.CenterCrop(size=(32, 32)),
        transforms.ToTensor(),
        normalize
    ])

# Shuffled, augmented CIFAR-100 training loader (downloads the data if needed).
# NOTE(review): `train_loader` is not used in the visible code.
train_loader = torch.utils.data.DataLoader(
    datasets.CIFAR100('data', train=True, download=True,
                   transform=transform_train),
    batch_size=batch_size, shuffle=True)
# CIFAR-100 test loader in dataset order; used for the accuracy check.
test_loader = torch.utils.data.DataLoader(
    datasets.CIFAR100('data', train=False, transform=transform_test),
    batch_size=batch_size)


Files already downloaded and verified


In [5]:
# Materialise the CIFAR-100 training split as a list with one entry per image
# (batch_size=1, dataset order), preprocessed with the deterministic test
# transform. Entry[0] is the image batch; later cells index this list by position.
data_train = list(torch.utils.data.DataLoader(
        datasets.CIFAR100('data', train=True, download=True,
                       transform=transform_test),
        batch_size=1, shuffle=False))

Files already downloaded and verified


In [6]:
# Materialise the CIFAR-100 test split the same way: one list entry per image,
# dataset order, deterministic test transform. This is the in-distribution
# test set read (as the global `data`) by Detector.compute_test_deviations.
data = list(torch.utils.data.DataLoader(
    datasets.CIFAR100('data', train=False, download=True,
                   transform=transform_test),
    batch_size=1, shuffle=False))

Files already downloaded and verified


In [7]:
# Sanity check: top-1 accuracy of the loaded model on the CIFAR-100 test set.
torch_model.eval()
correct = 0
total = 0
# Evaluation only: disabling autograd avoids building a graph for each batch.
with torch.no_grad():
    for x,y in test_loader:
        x = x.cuda()
        y = y.numpy()
        correct += (y==np.argmax(torch_model(x).detach().cpu().numpy(),axis=1)).sum()
        total += y.shape[0]
print("Accuracy: ",correct/total)


Accuracy:  0.7503


<b>Out-of-distribution Datasets</b>

In [8]:
# Out-of-distribution set: CIFAR-10 test split, one list entry per image,
# dataset order, same preprocessing as the in-distribution test data.
cifar10 = list(torch.utils.data.DataLoader(
    datasets.CIFAR10('data', train=False, download=True,
                   transform=transform_test),
    batch_size=1, shuffle=False))

Files already downloaded and verified


In [10]:
# Out-of-distribution set read from the local image folder "iSUN/", one list
# entry per image, in folder order.
isun = list(torch.utils.data.DataLoader(
    datasets.ImageFolder("iSUN/",transform=transform_test),batch_size=1,shuffle=False))

In [11]:
# Out-of-distribution set read from the local image folder "LSUN/" (reported
# as "LSUN (C)" in the results). shuffle=True, so the sample order of this
# list differs between runs.
lsun_c = list(torch.utils.data.DataLoader(
    datasets.ImageFolder("LSUN/",transform=transform_test),batch_size=1,shuffle=True))

In [12]:
# Out-of-distribution set read from the local image folder "LSUN_resize/"
# (reported as "LSUN (R)"). shuffle=True, so the sample order differs between runs.
lsun_r = list(torch.utils.data.DataLoader(
    datasets.ImageFolder("LSUN_resize/",transform=transform_test),batch_size=1,shuffle=True))

In [13]:
# Out-of-distribution set read from the local image folder "Imagenet/"
# (reported as "TinyImgNet (C)"). shuffle=True, so the sample order differs between runs.
tinyimagenet_c = list(torch.utils.data.DataLoader(
    datasets.ImageFolder("Imagenet/",transform=transform_test),batch_size=1,shuffle=True))

In [14]:
# Out-of-distribution set read from the local image folder "Imagenet_resize/"
# (reported as "TinyImgNet (R)"). shuffle=True, so the sample order differs between runs.
tinyimagenet_r = list(torch.utils.data.DataLoader(
    datasets.ImageFolder("Imagenet_resize/",transform=transform_test),batch_size=1,shuffle=True))

In [9]:
# Out-of-distribution set: SVHN test split, one list entry per image, same
# preprocessing as the in-distribution test data. shuffle=True, so the sample
# order differs between runs.
svhn = list(torch.utils.data.DataLoader(
    datasets.SVHN('data', split="test", download=True,
                   transform=transform_test),
    batch_size=1, shuffle=True))

Using downloaded and verified file: data/test_32x32.mat


## Code for Detecting OODs

<b> Extract predictions for train and test data </b>

In [15]:
# Predicted class, maximum softmax confidence and raw logits for every
# training sample, computed in chunks of 128. The lists are aligned with
# `data_train` and are read later by Detector.
train_preds = []
train_confs = []
train_logits = []
# Inference only: without no_grad every forward pass would build an autograd
# graph that is thrown away immediately.
with torch.no_grad():
    for idx in range(0,len(data_train),128):
        # Each entry's image has a leading batch axis of size 1; stack, then drop it.
        batch = torch.squeeze(torch.stack([x[0] for x in data_train[idx:idx+128]]),dim=1).cuda()

        logits = torch_model(batch)
        confs = F.softmax(logits,dim=1).cpu().detach().numpy()
        preds = np.argmax(confs,axis=1)
        logits = logits.cpu().detach().numpy()

        train_confs.extend(np.max(confs,axis=1))
        train_preds.extend(preds)
        train_logits.extend(logits)
print("Done")

# Same quantities for the in-distribution test samples, aligned with `data`.
test_preds = []
test_confs = []
test_logits = []

with torch.no_grad():
    for idx in range(0,len(data),128):
        batch = torch.squeeze(torch.stack([x[0] for x in data[idx:idx+128]]),dim=1).cuda()

        logits = torch_model(batch)
        confs = F.softmax(logits,dim=1).cpu().detach().numpy()
        preds = np.argmax(confs,axis=1)
        logits = logits.cpu().detach().numpy()

        test_confs.extend(np.max(confs,axis=1))
        test_preds.extend(preds)
        test_logits.extend(logits)
print("Done")

Done
Done


<b> Code for detecting OODs by identifying anomalies in correlations </b>

In [16]:
import calculate_log as callog

def detect(all_test_deviations,all_ood_deviations, test_confs = None, ood_confs=None, verbose=True, normalize=True):
    """Score in-distribution vs. OOD deviations, averaged over 10 random splits.

    For each seed 1..10, 10% of the in-distribution rows are drawn as a
    validation split. Its per-layer mean (plus a small constant) rescales the
    layer deviations, which are then summed per sample. When confidences are
    supplied, ``thresh * confidence`` is subtracted from each total, where
    ``thresh`` is the largest validation total. The negated totals are passed
    to ``callog.compute_metric``.

    Args:
        all_test_deviations: array (samples, layers) for in-distribution data.
        all_ood_deviations: array (samples, layers) for OOD data.
        test_confs: optional confidences aligned with ``all_test_deviations``.
        ood_confs: optional confidences aligned with ``all_ood_deviations``.
        verbose: print the averaged metrics via ``callog.print_results``.
        normalize: if False, skip the per-layer rescaling.

    Returns:
        Dict mapping each metric name to its average over the 10 splits.

    Note:
        Re-seeds Python's global ``random`` generator on every split.
    """
    with_confs = test_confs is not None
    if with_confs:
        test_confs = np.array(test_confs)
        ood_confs = np.array(ood_confs)

    n_samples = len(all_test_deviations)
    seeds = range(1,11)
    metric_totals = {}

    for seed in seeds:
        random.seed(seed)

        validation_indices = random.sample(range(n_samples),int(0.1*n_samples))
        held_out = set(validation_indices)
        test_indices = [j for j in range(n_samples) if j not in held_out]

        validation = all_test_deviations[validation_indices]

        # Per-layer scale taken from the validation split.
        scale = validation.mean(axis=0)+10**-7
        if not normalize:
            scale = np.ones_like(scale)
        scale_row = scale[np.newaxis,:]

        test_scores = (all_test_deviations[test_indices]/scale_row).sum(axis=1)
        ood_scores = (all_ood_deviations/scale_row).sum(axis=1)

        if with_confs:
            thresh = np.max((validation/scale_row).sum(axis=1))

            ood_scores = ood_scores - thresh*ood_confs
            test_scores = test_scores - thresh*test_confs[test_indices]

        split_results = callog.compute_metric(-test_scores,-ood_scores)
        for metric in split_results:
            metric_totals[metric] = metric_totals.get(metric,0)+split_results[metric]

    n_splits = len(seeds)
    for metric in metric_totals:
        metric_totals[metric] /= n_splits

    if verbose:
        callog.print_results(metric_totals)
    return metric_totals

def cpu(ob):
    """Move every tensor of a list of lists to the CPU, in place.

    Each inner element is replaced by the result of its ``.cpu()`` call. The
    same outer object is returned for convenience.
    """
    for row in ob:
        for j, item in enumerate(row):
            row[j] = item.cpu()
    return ob
    
def cuda(ob):
    """Move every tensor of a list of lists to the current CUDA device, in place.

    Each inner element is replaced by the result of its ``.cuda()`` call. The
    same outer object is returned for convenience.
    """
    for row in ob:
        for j, item in enumerate(row):
            row[j] = item.cuda()
    return ob

class Detector:
    """Out-of-distribution detector based on per-class Gram-statistic ranges.

    Workflow: ``compute_minmaxs`` derives, for every predicted class, the
    ranges of the Gram statistics on the training data;
    ``compute_test_deviations`` measures deviations of the in-distribution
    test data from those ranges; ``compute_ood_deviations`` does the same for
    an OOD set and reports detection metrics.

    The class relies on module-level state: ``torch_model``, ``train_preds``,
    ``test_preds``, ``test_confs``, ``data``, and the helpers ``cpu``,
    ``cuda`` and ``detect``.
    """

    def __init__(self):
        self.all_test_deviations = None
        # Per-class statistics, keyed by predicted class index; kept on the CPU
        # between uses.
        self.mins = {}
        self.maxs = {}

        self.classes = range(100)

    def _class_deviations(self, samples, PRED, POWERS):
        """Deviations of ``samples`` from the stored ranges of class ``PRED``.

        The class statistics are moved to the GPU for the computation and are
        always moved back, even if the computation raises.
        """
        mins = cuda(self.mins[PRED])
        maxs = cuda(self.maxs[PRED])
        try:
            return torch_model.get_deviations(samples,power=POWERS,mins=mins,maxs=maxs)
        finally:
            cpu(mins)
            cpu(maxs)

    def compute_minmaxs(self,data_train,POWERS=(10,)):
        """Store per-class Gram-statistic ranges computed from ``data_train``.

        Training samples are grouped by the model's predicted class (the
        global ``train_preds``, aligned with ``data_train``).

        NOTE(review): a class with no predicted training sample makes
        ``torch.stack`` fail on an empty list; this is left unchanged because
        later steps need statistics for every class they encounter.
        """
        predicted = np.array(train_preds)  # converted once instead of per class
        for PRED in tqdm(self.classes):
            train_indices = np.where(predicted==PRED)[0]
            train_PRED = torch.squeeze(torch.stack([data_train[i][0] for i in train_indices]),dim=1)
            mins,maxs = torch_model.get_min_max(train_PRED,power=POWERS)
            self.mins[PRED] = cpu(mins)
            self.maxs[PRED] = cpu(maxs)
            torch.cuda.empty_cache()

    def compute_test_deviations(self,POWERS=(10,)):
        """Compute deviations for the in-distribution test set (global ``data``).

        Results are stored grouped by predicted class, in the order of
        ``self.classes``:
        ``self.all_test_deviations`` (samples, layers),
        ``self.all_test_deviations_MSP`` (the same divided by each sample's
        confidence) and ``self.all_test_confs``. Both arrays are ``None`` when
        no test sample falls into any class.
        """
        predicted = np.array(test_preds)
        deviation_chunks = []
        msp_chunks = []
        all_test_confs = []
        for PRED in tqdm(self.classes):
            test_indices = np.where(predicted==PRED)[0]
            # Same guard as in compute_ood_deviations: nothing to do for a
            # class that was never predicted (stacking an empty list would fail).
            if len(test_indices)==0:
                continue
            test_PRED = torch.squeeze(torch.stack([data[i][0] for i in test_indices]),dim=1)
            test_confs_PRED = np.array([test_confs[i] for i in test_indices])
            all_test_confs.extend(test_confs_PRED)

            test_deviations = self._class_deviations(test_PRED,PRED,POWERS)
            deviation_chunks.append(test_deviations)
            msp_chunks.append(test_deviations/test_confs_PRED[:,np.newaxis])
            torch.cuda.empty_cache()

        # Concatenate once at the end instead of re-copying the growing array
        # on every class.
        self.all_test_confs = all_test_confs
        self.all_test_deviations = np.concatenate(deviation_chunks,axis=0) if deviation_chunks else None
        self.all_test_deviations_MSP = np.concatenate(msp_chunks,axis=0) if msp_chunks else None

    def compute_ood_deviations(self,ood,POWERS=(10,),msp=False):
        """Evaluate detection of the OOD set ``ood`` and print the metrics.

        Prints the max-softmax baseline ("MSP"), the deviation-based detector
        ("Ours") and the combination with confidences ("Ours+MSP").

        Args:
            ood: list of samples whose first element is an image batch of size 1.
            POWERS: exponents used for the Gram statistics; must match the ones
                used in ``compute_minmaxs``.
            msp: unused; kept for interface compatibility.

        Returns:
            ``(average_results, all_test_deviations, all_ood_deviations)`` where
            ``average_results`` are the "Ours+MSP" metrics.
        """
        ood_preds = []
        ood_confs = []

        # Inference only: no autograd graph is needed for the predictions.
        with torch.no_grad():
            for idx in range(0,len(ood),128):
                batch = torch.squeeze(torch.stack([x[0] for x in ood[idx:idx+128]]),dim=1).cuda()
                logits = torch_model(batch)
                confs = F.softmax(logits,dim=1).cpu().detach().numpy()
                preds = np.argmax(confs,axis=1)

                ood_confs.extend(np.max(confs,axis=1))
                ood_preds.extend(preds)
                torch.cuda.empty_cache()
        print("MSP")
        callog.print_results(callog.compute_metric(np.array(test_confs),np.array(ood_confs)))

        predicted = np.array(ood_preds)
        deviation_chunks = []
        all_ood_confs = []
        for PRED in tqdm(self.classes):
            ood_indices = np.where(predicted==PRED)[0]
            if len(ood_indices)==0:
                continue
            ood_PRED = torch.squeeze(torch.stack([ood[i][0] for i in ood_indices]),dim=1)

            ood_confs_PRED =  np.array([ood_confs[i] for i in ood_indices])
            all_ood_confs.extend(ood_confs_PRED)

            deviation_chunks.append(self._class_deviations(ood_PRED,PRED,POWERS))
            torch.cuda.empty_cache()

        all_ood_deviations = np.concatenate(deviation_chunks,axis=0) if deviation_chunks else None
        self.all_ood_confs = all_ood_confs

        print("Ours")
        detect(self.all_test_deviations,all_ood_deviations)
        print("Ours+MSP")
        average_results = detect(self.all_test_deviations,all_ood_deviations,self.all_test_confs,self.all_ood_confs)
        return average_results, self.all_test_deviations, all_ood_deviations

<center><h1> Results </h1></center>

## Ours

In [17]:
def G_p(ob, p):
    """Return per-channel order-``p`` Gram statistics of a feature tensor.

    Re-binds the module-level name ``G_p``, which ``WideResNet.get_min_max``
    and ``WideResNet.get_deviations`` look up at call time.

    The input is detached, raised element-wise to the power ``p`` and
    flattened to (batch, channels, -1). The channel-by-channel Gram matrix is
    summed over its last axis, and the signed ``p``-th root of each row sum is
    taken.

    Args:
        ob: tensor whose first two axes are batch and channel.
        p: exponent applied before, and root taken after, the Gram product.

    Returns:
        Tensor of shape (batch, channels); it carries no autograd history.
    """
    powered = ob.detach() ** p
    n_batch, n_channels = powered.shape[0], powered.shape[1]
    flat = powered.reshape(n_batch, n_channels, -1)

    # Row sums of the (channels x channels) Gram matrix of every sample.
    row_sums = torch.matmul(flat, flat.transpose(dim0=2, dim1=1)).sum(dim=2)

    # Signed root: keeps the sign of negative row sums instead of producing NaN.
    rooted = row_sums.sign() * torch.abs(row_sums) ** (1 / p)
    return rooted.reshape(rooted.shape[0], -1)

# Fit the detector: per-class Gram-statistic ranges on the training data for
# exponents 1..10. The same POWERS must be used in every later call, because
# the stored ranges are indexed by position in this sequence.
detector = Detector()
detector.compute_minmaxs(data_train,POWERS=range(1,11))

# Deviations of the in-distribution test set; reused for every OOD comparison below.
detector.compute_test_deviations(POWERS=range(1,11))

# Evaluate each OOD dataset in turn. Every call prints the "MSP", "Ours" and
# "Ours+MSP" metric tables and returns
# (average_results, test_deviations, ood_deviations).
print("iSUN")
isun_results = detector.compute_ood_deviations(isun,POWERS=range(1,11))
print("LSUN (R)")
lsunr_results = detector.compute_ood_deviations(lsun_r,POWERS=range(1,11))
print("LSUN (C)")
lsunc_results = detector.compute_ood_deviations(lsun_c,POWERS=range(1,11))
print("TinyImgNet (R)")
timr_results = detector.compute_ood_deviations(tinyimagenet_r,POWERS=range(1,11))
print("TinyImgNet (C)")
timc_results = detector.compute_ood_deviations(tinyimagenet_c,POWERS=range(1,11))
print("SVHN")
svhn_results = detector.compute_ood_deviations(svhn,POWERS=range(1,11))
print("CIFAR-10")
c10_results = detector.compute_ood_deviations(cifar10,POWERS=range(1,11))

HBox(children=(IntProgress(value=0), HTML(value='')))




HBox(children=(IntProgress(value=0), HTML(value='')))


iSUN
MSP
 TNR    AUROC  DTACC  AUIN   AUOUT 
 50.936 89.799 82.313 91.287 87.690


HBox(children=(IntProgress(value=0), HTML(value='')))


Ours
 TNR    AUROC  DTACC  AUIN   AUOUT 
 96.272 99.109 95.975 99.003 99.157
Ours+MSP
 TNR    AUROC  DTACC  AUIN   AUOUT 
 95.564 98.891 96.022 98.783 99.003
LSUN (R)
MSP
 TNR    AUROC  DTACC  AUIN   AUOUT 
 58.310 92.028 84.680 92.609 91.181


HBox(children=(IntProgress(value=0), HTML(value='')))


Ours
 TNR    AUROC  DTACC  AUIN   AUOUT 
 98.381 99.569 97.299 99.500 99.578
Ours+MSP
 TNR    AUROC  DTACC  AUIN   AUOUT 
 97.404 99.319 97.392 99.158 99.420
LSUN (C)
MSP
 TNR    AUROC  DTACC  AUIN   AUOUT 
 69.490 94.003 86.625 94.304 93.810


HBox(children=(IntProgress(value=0), HTML(value='')))


Ours
 TNR    AUROC  DTACC  AUIN   AUOUT 
 69.721 92.564 85.300 90.695 93.731
Ours+MSP
 TNR    AUROC  DTACC  AUIN   AUOUT 
 83.083 96.297 89.662 96.025 96.616
TinyImgNet (R)
MSP
 TNR    AUROC  DTACC  AUIN   AUOUT 
 36.090 85.079 77.545 86.051 83.371


HBox(children=(IntProgress(value=0), HTML(value='')))


Ours
 TNR    AUROC  DTACC  AUIN   AUOUT 
 96.254 99.116 95.851 98.835 99.242
Ours+MSP
 TNR    AUROC  DTACC  AUIN   AUOUT 
 92.806 98.242 94.630 97.822 98.615
TinyImgNet (C)
MSP
 TNR    AUROC  DTACC  AUIN   AUOUT 
 41.570 86.323 78.580 86.913 85.335


HBox(children=(IntProgress(value=0), HTML(value='')))


Ours
 TNR    AUROC  DTACC  AUIN   AUOUT 
 90.113 97.715 92.775 97.174 98.073
Ours+MSP
 TNR    AUROC  DTACC  AUIN   AUOUT 
 87.115 96.933 91.136 96.407 97.505
SVHN
MSP
 TNR    AUROC  DTACC  AUIN   AUOUT 
 56.150 92.497 85.577 87.364 96.150


HBox(children=(IntProgress(value=0), HTML(value='')))


Ours
 TNR    AUROC  DTACC  AUIN   AUOUT 
 84.837 96.510 90.815 90.459 98.676
Ours+MSP
 TNR    AUROC  DTACC  AUIN   AUOUT 
 85.545 96.799 90.415 92.076 98.804
CIFAR-10
MSP
 TNR    AUROC  DTACC  AUIN   AUOUT 
 17.410 78.395 71.725 81.506 73.226


HBox(children=(IntProgress(value=0), HTML(value='')))


Ours
 TNR    AUROC  DTACC  AUIN   AUOUT 
  7.487 59.332 57.341 57.102 59.165
Ours+MSP
 TNR    AUROC  DTACC  AUIN   AUOUT 
 16.500 77.673 71.586 78.561 73.990
