<a href="https://colab.research.google.com/github/sakshamgarg/Augmenting-Dirichlet-Network/blob/main/finetune.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import pdb
import argparse
import numpy as np
from tqdm import tqdm
from sklearn import metrics
import matplotlib.pyplot as plt
import sys
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torchvision import datasets, transforms
from torch.utils.data import Dataset
from torchvision.utils import make_grid
from torch.autograd import Variable
from itertools import cycle
import seaborn as sns
# from models.vgg import VGG, VarVGG
# from models.densenet import DenseNet3
# from models.resnext import resnext
# from models.wideresnet import WideResNet
# from models.resnet import *
# from utils.ood_metrics import tpr95, detection
# from utils.datasets import GaussianNoise, UniformNoise
# from utils.utils import obtain_dirichelets, cosine_distance, mse_loss, Transit, reproject_dirichlets
import random

# Run on the first CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
# Colab-only step: mount Google Drive at /content/drive. The checkpoint paths
# used elsewhere in this notebook live under '/content/drive/My Drive/...'.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
# Values each option may take (listed for reference; nothing in this cell reads them).
ind_options = ['cifar10']
ood_options = ['svhn']
model_options = ['resnet']
process_options = [
    'baseline', 'ODIN', 'confidence', 'confidence_scaling',
    'dirichlet', 'variational', 'finetune',
]

# Settings for this run.
args_ind_dataset = 'cifar10'
args_model = 'resnet18'
args_checkpoint = 'cifar10_resnet_variational_dirichlet'

# Base name (no extension) of the pre-trained classifier checkpoint.
filename = args_checkpoint

# Destination of the fine-tuned projection weights.
proj_filename = '/content/drive/My Drive/CV/checkpoint/{}_{}_proj.pt'.format(args_ind_dataset, args_model)

# Enable cuDNN benchmark mode; intended to make training faster for large models.
cudnn.benchmark = True


In [4]:
###########################
### Set up data loaders ###
###########################

data_path = 'data/'
num_classes = 10
args_batch_size = 128

# Per-channel normalisation statistics for CIFAR-10 (given on a 0-255 scale).
normalize = transforms.Normalize(
    mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
    std=[x / 255.0 for x in [63.0, 62.1, 66.7]],
)
transform = transforms.Compose([transforms.ToTensor(), normalize])

# In-distribution data: CIFAR-10 test split. Out-of-distribution data: SVHN
# test split. Both go through the same transform.
ind_dataset = datasets.CIFAR10(root=data_path, train=False, transform=transform, download=True)
ood_dataset = datasets.SVHN(root=data_path, split='test', transform=transform, download=True)

# Options shared by every loader below.
_loader_kwargs = dict(batch_size=args_batch_size, pin_memory=True, num_workers=2)

# Only the in-distribution loader is shuffled.
ind_loader = torch.utils.data.DataLoader(dataset=ind_dataset, shuffle=True, **_loader_kwargs)
ood_loader = torch.utils.data.DataLoader(dataset=ood_dataset, shuffle=False, **_loader_kwargs)
# Second loader over the same OOD dataset with identical settings.
valid_ood_loader = torch.utils.data.DataLoader(dataset=ood_dataset, shuffle=False, **_loader_kwargs)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting data/cifar-10-python.tar.gz to data/
Downloading http://ufldl.stanford.edu/housenumbers/test_32x32.mat to data/test_32x32.mat


  0%|          | 0/64275384 [00:00<?, ?it/s]

In [5]:
from torch.nn.parameter import Parameter


class Transit(nn.Module):
    """Residual linear layer with a non-negative weight matrix.

    ``forward(x)`` returns ``x + F.linear(x, relu(weight), bias)``. The weight
    is passed through ReLU on every call, so negative entries contribute
    nothing to the output.

    Args:
        in_features: size of the last dimension of the input.
        out_features: size of the last dimension of the linear term. The
            residual addition in ``forward`` requires it to be broadcastable
            with the input (in practice ``in_features == out_features``).
        bias: when False no bias parameter is created.
    """

    def __init__(self, in_features, out_features, bias=True):
        super(Transit, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Same (out_features, in_features) layout that F.linear expects.
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Set every weight to 0.01 and the bias (if present) to zero."""
        # Fill in place so the initial value always matches the weight's own
        # shape; copying from an (in_features, out_features) tensor raised a
        # shape error whenever in_features != out_features.
        with torch.no_grad():
            self.weight.fill_(0.01)
            if self.bias is not None:
                self.bias.zero_()

    def forward(self, input):
        """Return ``input`` plus its non-negative linear projection."""
        weight = torch.relu(self.weight)
        epsilon = F.linear(input, weight, self.bias)
        return epsilon + input

    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )


In [6]:
##############################
### Load pre-trained model ###
##############################

class BasicBlock(nn.Module):
    """Two 3x3 conv + batch-norm layers with a residual shortcut.

    The shortcut is an empty ``nn.Sequential`` (identity) unless the stride or
    the channel count changes, in which case it is a 1x1 conv followed by
    batch-norm so the shapes match for the addition.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        shape_changes = stride != 1 or in_planes != out_planes
        if shape_changes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Apply conv-bn-relu, conv-bn, add the shortcut, then a final ReLU."""
        residual = self.shortcut(x)
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        return F.relu(hidden + residual)


class ResNet(nn.Module):
    """ResNet backbone: a 3x3 stem, four residual stages and a linear head.

    Args:
        block: residual block class; must expose ``expansion`` and accept
            ``(in_planes, planes, stride)``.
        num_blocks: number of blocks in each of the four stages.
        num_classes: size of the output logit vector.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()
        # Running channel count; _make_layer updates it as blocks are added.
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        # Stage widths double while every stage after the first downsamples by 2.
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)

        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage: the first block uses ``stride``, the rest stride 1."""
        stage = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            stage.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        """Return ``(logits, None)`` for a batch of images."""
        features = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)
        pooled = F.avg_pool2d(features, 4)
        flat = pooled.view(pooled.size(0), -1)
        return self.linear(flat), None


def ResNet18(num_classes):
    """Build a ResNet of BasicBlocks with two blocks in each of the four stages."""
    blocks_per_stage = [2, 2, 2, 2]
    return ResNet(BasicBlock, blocks_per_stage, num_classes=num_classes)


In [9]:
cnn = ResNet18(num_classes=num_classes)

# The model is wrapped in DataParallel *before* the weights are loaded, so the
# checkpoint keys must match the wrapped module's names
# (presumably it was saved from a DataParallel model -- TODO confirm).
cnn = torch.nn.DataParallel(cnn)
cnn.to(device)

# map_location lets a checkpoint that was saved on a GPU load on whatever
# `device` resolved to (including the CPU fallback).
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
pretrained_dict = torch.load('/content/drive/My Drive/CV/checkpoint/' + filename + '.pt',
                             map_location=device)
# Some checkpoints nest the weights under a "state_dict" key.
if "state_dict" in pretrained_dict:
    pretrained_dict = pretrained_dict['state_dict']
cnn.load_state_dict(pretrained_dict)

# The classifier is only used for inference here (eval mode).
cnn.eval()

# Projection applied to the Dirichlet concentrations; its parameters are the
# only ones handed to the optimizer below.
proj = Transit(num_classes, num_classes, bias=False)
proj.to(device)

args_learning_rate = 1e-4
optimizer = torch.optim.Adam(proj.parameters(), lr=args_learning_rate, weight_decay=5e-4)
criterion = nn.CrossEntropyLoss()
criterion.to(device)

# Bounds used to clamp adversarially perturbed (normalised) images.
min_pixel = -3
max_pixel = 3


In [10]:
## Utils.py

# Upper bound applied to Dirichlet concentration values.
clamp_threshold = 1000


def entropy(alpha):
    """Differential entropy of Dirichlet distributions, one per row of ``alpha``.

    Sums run over dimension 1 of ``alpha``. Returns a 4-tuple
    ``(entropy, logB, digamma_1, digamma_2)`` where ``entropy`` is the sum of
    the other three terms:

    * ``logB``      = sum_j lgamma(alpha_j) - lgamma(alpha_0)
    * ``digamma_1`` = (alpha_0 - K) * digamma(alpha_0)
    * ``digamma_2`` = -sum_j (alpha_j - 1) * digamma(alpha_j)

    with ``alpha_0 = sum_j alpha_j`` and ``K = alpha.size(1)``.
    """
    num_components = alpha.size(1)
    alpha_total = alpha.sum(dim=1)

    log_norm = torch.lgamma(alpha).sum(dim=1) - torch.lgamma(alpha_total)
    total_term = (alpha_total - num_components) * torch.digamma(alpha_total)
    component_term = -((alpha - 1) * torch.digamma(alpha)).sum(dim=1)

    return log_norm + total_term + component_term, log_norm, total_term, component_term

def reproject_dirichlets(logits, proj, model, mean=False, logscale=False):
    """Turn logits into projected Dirichlet concentrations and a confidence score.

    Args:
        logits: classifier outputs; the last dimension indexes classes.
        proj: callable applied to the concentrations (e.g. a ``Transit`` module).
        model: unused; kept so existing call sites keep working.
        mean: when True the confidence is averaged into a scalar.
        logscale: when True use ``softplus(logits) + 1`` as concentrations,
            otherwise ``exp(logits)`` clamped to ``[0, clamp_threshold]``.

    Returns:
        ``(conf, alphas)`` where ``conf`` is the negative Dirichlet entropy of
        ``alphas`` and ``alphas`` are the projected, rescaled concentrations.
    """
    if logscale:
        # softplus(x) == log(exp(x) + 1), but it does not overflow to inf
        # (and later NaN) for large logits the way the explicit form does.
        alphas = F.softplus(logits) + 1
    else:
        alphas = torch.exp(logits).clamp(0, clamp_threshold)
    alphas = proj(alphas)

    # Rescale any row whose largest concentration exceeds clamp_threshold so
    # that its maximum equals the threshold; other rows are left untouched.
    row_max = torch.max(alphas, dim=-1, keepdim=True)[0]
    scale = torch.clamp(row_max / clamp_threshold, min=1)
    alphas = alphas / scale

    # Lower entropy means a more peaked Dirichlet, hence higher confidence.
    conf = -entropy(alphas)[0]
    if mean:
        conf = torch.mean(conf)
    return conf, alphas



## ood_metrics.py

def tpr95(ind_confidences, ood_confidences):
    """False-positive rate on OOD data at a 95% in-distribution true-positive rate.

    The threshold is the 5th percentile of the in-distribution confidences; the
    result is the fraction of OOD confidences strictly above that threshold.
    (A disabled alternative that averaged the rate over a fine sweep of
    thresholds used to sit inside this function.)
    """
    threshold = np.percentile(ind_confidences, 5)
    false_positives = np.sum(ood_confidences > threshold)
    return false_positives / float(len(ood_confidences))

In [11]:
# Forwarded by output_conf() to reproject_dirichlets(logscale=...).
args_logscale = False


def fgsm(cnn, images, labels, eps=0.03):
    """Craft FGSM adversarial examples for ``cnn``.

    Takes one signed-gradient step of size ``eps`` in the direction that
    increases ``criterion(cnn(images), labels)`` and clamps the result to
    ``[min_pixel, max_pixel]``.

    Args:
        cnn: model returning ``(logits, extra)``.
        images: input batch; it is not modified.
        labels: target class indices for ``criterion``.
        eps: step size of the perturbation.

    Returns:
        The perturbed batch on ``device``, detached from the autograd graph.

    Side effects:
        Clears and then repopulates the gradients of ``cnn``'s parameters.
    """
    # Fresh leaf tensor on the working device so .grad is populated directly.
    images_adv = images.detach().to(device).requires_grad_(True)
    labels = labels.to(device)

    pre_logits, _ = cnn(images_adv)
    # Negated loss: stepping *against* this gradient below raises the loss.
    cost = -criterion(pre_logits, labels)

    cnn.zero_grad()
    cost.backward()

    step = eps * images_adv.grad.sign()
    # Detach so the returned batch does not drag the attack graph along.
    perturbed = images_adv.detach() - step
    return torch.clamp(perturbed, min_pixel, max_pixel)

def output_conf(images, mean=False):
    """Score a batch with the frozen classifier and the trainable projection.

    Args:
        images: input batch passed to ``cnn``.
        mean: forwarded to ``reproject_dirichlets``; when True the confidence
            is averaged into a scalar.

    Returns:
        ``(conf, pre_logits, alphas)``. ``pre_logits`` carries no gradient, so
        only ``proj`` can receive gradients from ``conf`` / ``alphas``.
    """
    # The logits were always cut off from the graph afterwards; running the
    # forward pass under no_grad gives the same values without building (and
    # then discarding) an autograd graph through the whole classifier.
    with torch.no_grad():
        pre_logits, _ = cnn(images)
    conf, alphas = reproject_dirichlets(pre_logits, proj, args_model, mean, logscale=args_logscale)
    return conf, pre_logits, alphas

def calc_fpr():
    """Evaluate OOD detection on ``ind_loader`` versus ``ood_loader``.

    Returns:
        ``(fpr_at_95_tpr, ind_alphas, ood_alphas)``: the value of ``tpr95`` on
        the two sets of confidence scores, plus the concatenated concentration
        arrays for the in-distribution and OOD data.

    Raises:
        ValueError: if ``args_process`` is not ``'dirichlet'``, the only
            scoring mode implemented here.
    """
    def evaluate(data_loader, mode):
        """Collect confidences and concentrations for every batch of a loader."""
        if mode != 'dirichlet':
            # Without this check nothing is collected and the failure shows up
            # later as an opaque "need at least one array to concatenate".
            raise ValueError("unsupported scoring mode: {!r}".format(mode))

        scores = []
        all_alphas = []
        # Pure evaluation: no gradients are needed for the inputs or for proj.
        with torch.no_grad():
            for data in data_loader:
                # Batches are either [images, labels] or a bare image tensor.
                images = data[0] if isinstance(data, (list, tuple)) else data
                images = images.to(device)

                confidence, _, alphas = output_conf(images, mean=False)
                scores.append(confidence.detach().cpu().numpy())
                all_alphas.append(alphas.detach().cpu().numpy())

        return np.concatenate(scores), np.concatenate(all_alphas)

    ind_scores, alphas = evaluate(ind_loader, args_process)
    ood_scores, ood_alphas = evaluate(ood_loader, args_process)

    fpr_at_95_tpr = tpr95(ind_scores, ood_scores)
    return fpr_at_95_tpr, alphas, ood_alphas


In [12]:
# Endless iterator over OOD batches. Nothing below consumes it: the only
# consumer (an OOD-batch fetch inside the training loop) is disabled.
train_ood_loader = cycle(iter(valid_ood_loader))

args_reload = False         # set True to resume from the projection saved at proj_filename
args_process = 'dirichlet'  # scoring mode read by calc_fpr()
args_epsilon = 0.2          # FGSM step size used to craft adversarial inputs

if os.path.exists(proj_filename) and args_reload:
    # map_location keeps the reload working when the file was saved on another device.
    proj.load_state_dict(torch.load(proj_filename, map_location=device))
    print("loading pre-trained model")


def obtain_conf(data, adv_data=None):
    """Score a clean batch and its adversarial counterpart.

    Args:
        data: ``[images, labels]`` batch, or a bare image tensor when
            ``adv_data`` is supplied.
        adv_data: optional ``(adv_images, _)`` pair used instead of running
            FGSM on ``data``.

    Returns:
        ``(confidence, adv_confidence, alphas, adv_alphas)``; both confidences
        are batch means.

    Raises:
        ValueError: if FGSM is needed but ``data`` carries no labels.
    """
    if isinstance(data, (list, tuple)):
        images, labels = data
    else:
        images, labels = data, None
    images = images.to(device)

    if adv_data is None:
        if labels is None:
            raise ValueError("labels are required to craft FGSM adversarial examples")
        adv_images = fgsm(cnn, images, labels.to(device), eps=args_epsilon)
    else:
        adv_images, _ = adv_data
        adv_images = adv_images.to(device)

    confidence, _, alphas = output_conf(images, mean=True)
    adv_confidence, _, adv_alphas = output_conf(adv_images, mean=True)
    return confidence, adv_confidence, alphas, adv_alphas


score = 1    # lowest FPR seen so far (1 is the worst possible value)
losses = []  # per-step loss values since the last report
cur_score, test_alphas, test_ood_alphas = calc_fpr()
print("starting FPR = {}".format(cur_score))
for epoch in range(20):
    for i, ind_data in enumerate(ind_loader):
        cnn.zero_grad()
        proj.zero_grad()
        #ood_data = next(train_ood_loader)
        conf, adv_conf, alphas, adv_alphas = obtain_conf(ind_data, None)
        # Minimising this raises confidence on clean inputs and lowers it on
        # adversarial ones.
        loss = adv_conf - conf
        # Store a plain float so the list does not keep autograd graphs alive.
        losses.append(loss.item())
        loss.backward()
        optimizer.step()
        if i % 20 == 0:
            cur_score, test_alphas, test_ood_alphas = calc_fpr()
            print("current best results FPR = {} with avg loss = {}".format(cur_score, sum(losses)/len(losses)))
            print("alphas: mean = {}; adv_alphas: mean = {}".format(alphas.mean().item(), adv_alphas.mean().item()))
            print("alphas: mean = {}; ood_alphas: mean = {}".format(test_alphas.mean(), test_ood_alphas.mean()))
            print("")
            # Checkpoint the projection only when the FPR improves.
            if cur_score < score:
                torch.save(proj.state_dict(), proj_filename)
                score = cur_score
            losses = []


starting FPR = 0.4627381684081131
current best results FPR = 0.46246926859250154 with avg loss = -9.439702987670898
alphas: mean = 92.3352279663086; adv_alphas: mean = 25.740859985351562
alphas: mean = 94.80915069580078; ood_alphas: mean = 24.193002700805664

current best results FPR = 0.4572449293177627 with avg loss = -8.799222946166992
alphas: mean = 93.64287567138672; adv_alphas: mean = 32.56784439086914
alphas: mean = 93.9263687133789; ood_alphas: mean = 24.05292320251465

current best results FPR = 0.448640135218193 with avg loss = -9.428142547607422
alphas: mean = 95.21593475341797; adv_alphas: mean = 27.561203002929688
alphas: mean = 92.94032287597656; ood_alphas: mean = 23.87471580505371

current best results FPR = 0.435079901659496 with avg loss = -10.162351608276367
alphas: mean = 91.43299865722656; adv_alphas: mean = 35.95753860473633
alphas: mean = 91.75495910644531; ood_alphas: mean = 23.64591407775879

current best results FPR = 0.422979409956976 with avg loss = -11.5426