In [None]:
import os
import sys

# Project checkout location. The notebook works from <root>/bayesian and also
# needs <root> on sys.path, presumably so that local modules such as `data`,
# `main_bayesian` and `uncertainty_estimation` resolve. Override the default
# with the SNIP_IT_ROOT environment variable instead of editing the path.
PROJECT_ROOT = os.environ.get('SNIP_IT_ROOT', '/nfs/homedirs/ayle/guided-research/SNIP-it')
os.chdir(os.path.join(PROJECT_ROOT, 'bayesian'))
if PROJECT_ROOT not in sys.path:  # avoid piling up duplicates when the cell is re-run
    sys.path.append(PROJECT_ROOT)

# NOTE(review): a bare literal `260926` used to sit here as a no-op expression;
# kept for reference, meaning unknown (possibly a job / run id).

In [None]:
# Run the project's training script for conv6 on CIFAR10 with the StructuredSNR
# prune criterion (presumably training + pruning; see main_bayesian.py).
# NOTE(review): this run uses --pruning_limit 0.2, but every checkpoint loaded
# below has `0.5` in its file name — confirm which run this notebook analyses.
!python main_bayesian.py --net_type conv6 --dataset CIFAR10 --prune_criterion StructuredSNR --pruning_limit 0.2

In [None]:
import argparse
import torch
import numpy as np
import pandas as pd
import seaborn as sns
from PIL import Image
import torchvision
from torch.nn import functional as F
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import pickle

import data
from main_bayesian import getModel
import config_bayesian as cfg

from main_bayesian import validate_model
import metrics
from uncertainty_estimation import *

In [None]:
# Compute device: the first CUDA GPU when one is visible, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [None]:
# `init_dataset` is presumably provided by the star import of uncertainty_estimation.
# NOTE(review): `in_data` is not referenced by any later cell; the loaders used
# for evaluation are built with data.getDataset / data.getDataloader instead.
in_data = init_dataset('CIFAR10')

In [None]:
# Out-of-distribution dataset (SVHN), via the same star-imported `init_dataset`.
# NOTE(review): `ood_data` is not referenced by any later cell; the OOD loaders
# are built with data.getDataset / data.getDataloader instead.
ood_data = init_dataset('SVHN')

In [None]:
# Architecture name handed to getModel; it must match the checkpoint loaded later.
net_type = 'conv6'

In [None]:
import pickle

# LOAD STRUCTURED PRUNED MODEL
# Alternative loading path: for the 'customconv6' architecture the checkpoint is
# read as a whole pickled model object rather than a state_dict.
# NOTE(review): never taken while net_type == 'conv6', and the getModel cell
# further down rebinds `net` unconditionally, so on a top-to-bottom run this
# result is discarded. pickle.load executes arbitrary code — trusted files only.
if net_type == 'customconv6':
    customconv6_checkpoint = './checkpoints/CIFAR10/bayesian/model_conv6_bbb_relu_StructuredSNR_0.5.pt'
    with open(customconv6_checkpoint, 'rb') as checkpoint_file:
        net = pickle.load(checkpoint_file)

In [None]:
# Pickled pre-pruned conv6 model; it is only displayed in the next cell.
# NOTE(review): absolute user path; pickle.load executes arbitrary code —
# only load files you trust.
pre_pruned_path = '/nfs/homedirs/ayle/model_conv6_0.5.pickle'
with open(pre_pruned_path, 'rb') as pre_pruned_file:
    pre_pruned_model = pickle.load(pre_pruned_file)

In [None]:
# Display the loaded pre-pruned model for inspection.
pre_pruned_model

In [None]:
# Build the network for `net_type` with 'bbb' layers and ReLU activations.
# The positional 3 and 10 are presumably the input channels and number of
# classes for CIFAR10 — confirm against getModel's signature.
# NOTE(review): pre_pruned_model=None is passed even though a pre-pruned model
# is loaded a few cells above — confirm this is intended.
net = getModel(net_type, 3, 10, priors=None, layer_type='bbb', activation_type='relu', pre_pruned_model=None)

In [None]:
# Load the trained weights (a state_dict) into the freshly built network.
# map_location=device lets a checkpoint that was saved on a GPU also load on a
# CPU-only machine (without it torch.load restores tensors to their saved device).
checkpoint_path = './checkpoints/CIFAR10/bayesian/model_conv6_bbb_relu_StructuredSNR_0.5_during.pt'
net.load_state_dict(torch.load(checkpoint_path, map_location=device))
net.to(device)
net.eval()

In [None]:
# Attach externally saved masks to the prunable layers of `net`.
# Masks are matched to layers purely by position: the i-th key of the pickled
# dict (insertion order) goes to the i-th module, in net.named_modules() order,
# whose name starts with 'conv' or 'fc'.
# NOTE(review): pickle.load executes arbitrary code — only load trusted files.
with open('/nfs/homedirs/ayle/mask.pickle', 'rb') as f:
    mask = pickle.load(f)

mask_keys = list(mask.keys())
prunable_modules = [module for name, module in net.named_modules()
                    if name.startswith(('conv', 'fc'))]

# Fail loudly rather than silently skipping masks or mis-assigning them when
# the pickled masks and the network's layers disagree in number.
if len(prunable_modules) != len(mask_keys):
    raise ValueError(
        f'{len(mask_keys)} masks in mask.pickle but {len(prunable_modules)} '
        f'conv/fc layers in net')

for module, key in zip(prunable_modules, mask_keys):
    module.mask = mask[key]

In [None]:
# Alternative loading path: replace `net` with a whole pickled model object.
# NOTE(review): the same file is read with torch.load + load_state_dict a few
# cells earlier; only one of the two readers can match how the checkpoint was
# written — confirm which is intended. Running this cell also discards the
# masks attached to the previous `net` object. pickle.load executes arbitrary
# code — only load trusted files.
with open('./checkpoints/CIFAR10/bayesian/model_conv6_bbb_relu_StructuredSNR_0.5_during.pt', 'rb') as f:
    net = pickle.load(f)

# Put the replacement in the same state as the state_dict-loaded network:
# evaluation mode, on the selected device.
net.eval()
net.to(device)

In [None]:
# Global pruning threshold over all conv*/fc* layers. Each weight is scored by
# its signal-to-noise ratio |W_mu| / sigma, with sigma = log(1 + exp(W_rho)).
# `acceptable_score` is the k-th highest score, with
# k = int(total_weights * (1 - sparsity)).
# (Earlier variants divided by module.weight.sigma, or ranked by
# -log(1 + exp(W_rho)) alone.)
sparsity = 0.7
all_scores = []
with torch.no_grad():  # scores are only ranked: no autograd graph through W_mu / W_rho
    for name, module in net.named_modules():
        if name.startswith(('conv', 'fc')):
            scores = torch.abs(module.W_mu) / torch.log1p(torch.exp(module.W_rho))
            all_scores.append(scores.flatten())
    if not all_scores:
        raise ValueError('net has no modules whose name starts with conv/fc')
    all_scores = torch.cat(all_scores)

num_to_keep = int(len(all_scores) * (1 - sparsity))
if num_to_keep < 1:
    # topk(k=0) would return an empty tensor and threshold[-1] would fail opaquely.
    raise ValueError(f'sparsity={sparsity} keeps no weights out of {len(all_scores)}')
threshold, _ = torch.topk(all_scores, num_to_keep, sorted=True)
acceptable_score = threshold[-1]

In [None]:
# Build each conv*/fc* layer's mask from the global threshold and print the
# fraction of weights it keeps. The score formula must match the thresholding
# cell. `>=` keeps the threshold weight itself, so the masks keep the top-k
# weights (plus ties), as in the SNIP reference code; a strict `>` would drop
# the k-th weight. (Earlier variants thresholded -log(1 + exp(W_rho)) or
# -module.weight.sigma instead.)
with torch.no_grad():
    for name, module in net.named_modules():
        if name.startswith(('conv', 'fc')):
            snr = torch.abs(module.W_mu) / torch.log1p(torch.exp(module.W_rho))
            layer_mask = snr >= acceptable_score
            module.mask = layer_mask
            print(name, (layer_mask.sum().float() / torch.numel(layer_mask)).item())

In [None]:
# Seed the global RNGs so that the train/validation split and the Monte Carlo
# weight samples drawn by later cells are reproducible from run to run.
# NOTE(review): assumes data.getDataloader draws its split from the numpy /
# torch global RNGs — confirm.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

valid_size = 0.2  # presumably the fraction of each training set held out for validation
batch_size = 256
num_workers = 4

# In-distribution data: CIFAR10.
trainset, testset, inputs, outputs = data.getDataset('CIFAR10')
train_loader, valid_loader, test_loader = data.getDataloader(
    trainset, testset, valid_size, batch_size, num_workers)

# Out-of-distribution data: SVHN, split with the same settings.
ood_trainset, ood_testset, ood_inputs, ood_outputs = data.getDataset('SVHN')
ood_train_loader, ood_valid_loader, ood_test_loader = data.getDataloader(
    ood_trainset, ood_testset, valid_size, batch_size, num_workers)

# Loss and bookkeeping arguments that validate_model expects. The ELBO is
# sized by the CIFAR10 training set.
criterion = metrics.ELBO(len(trainset)).to(device)
beta_type = 0.1
epoch = 1
n_epochs = 1

In [None]:
# Passed to validate_model as num_ens (presumably the number of sampled
# forward passes averaged per input).
n_ens = 5

In [None]:
# Evaluate on the CIFAR10 validation split. max_probs is presumably the
# per-sample maximum predicted probability; it is used below as the
# in-distribution confidence score.
valid_loss, valid_acc, max_probs = validate_model(net, criterion, valid_loader, num_ens=n_ens, beta_type=beta_type, epoch=epoch, num_epochs=n_epochs)

In [None]:
# Validation accuracy on the in-distribution (CIFAR10) split.
valid_acc

In [None]:
# Same evaluation on the SVHN validation split. Accuracy is discarded
# (presumably meaningless across label spaces); ood_max_probs supplies the OOD
# confidence scores.
ood_valid_loss, _, ood_max_probs = validate_model(net, criterion, ood_valid_loader, num_ens=n_ens, beta_type=beta_type, epoch=epoch, num_epochs=n_epochs)

In [None]:
from sklearn import metrics as sk_metrics
import numpy as np
import torch.nn.functional as F
import matplotlib.pyplot as plt


def calculate_auroc(correct, predictions, plot=True):
    """Return the area under the ROC curve, optionally plotting the curve.

    Args:
        correct: binary ground-truth labels (1 = positive class, e.g. an
            in-distribution sample; 0 = negative class).
        predictions: scores where a higher value means "more likely positive".
        plot: when True (the default, matching the original behaviour), draw
            the ROC curve on the current matplotlib axes with axis labels, a
            chance-level diagonal and an AUC legend. Pass False to compute the
            metric only.

    Returns:
        The ROC AUC.
    """
    fpr, tpr, _ = sk_metrics.roc_curve(correct, predictions)
    auroc = sk_metrics.auc(fpr, tpr)
    if plot:
        plt.plot(fpr, tpr, label=f'ROC (AUC = {auroc:.3f})')
        plt.plot([0, 1], [0, 1], linestyle='--', color='grey', label='chance')
        plt.xlabel('False positive rate')
        plt.ylabel('True positive rate')
        plt.legend(loc='lower right')
    return auroc


def calculate_aupr(correct, predictions):
    """Return scikit-learn's average precision for binary labels and scores.

    Average precision summarises the precision-recall curve; this notebook
    reports it as the AUPR.

    Args:
        correct: binary ground-truth labels (1 = positive class).
        predictions: scores, higher meaning more likely positive.
    """
    score = sk_metrics.average_precision_score(correct, predictions)
    return score

In [None]:
# OOD-detection metrics. CIFAR10 validation samples form the positive class
# (label 1) and SVHN validation samples the negative class (label 0). Each
# sample's score is its entry in max_probs / ood_max_probs.
confidences = np.concatenate((max_probs, ood_max_probs))
corrects = np.concatenate((np.ones_like(max_probs), np.zeros_like(ood_max_probs)))

auroc = calculate_auroc(corrects, confidences)
print(auroc)
aupr = calculate_aupr(corrects, confidences)
print(aupr)

In [None]:
# Per-class epistemic / aleatoric uncertainty, averaged over every sample of
# the in-distribution (CIFAR10) test set. Per-sample values are summed and
# divided by the sample count. Summing per-batch means instead would make the
# result grow with the number of batches, over-weight the smaller last batch,
# and make the numbers incomparable with the OOD set.
# Assumes dim 0 of the arrays returned by get_uncertainty_per_batch indexes
# samples (the original code averaged over dim 0) — TODO confirm.
# The softmax variant (normalized=False) was computed and discarded before;
# call it the same way if it is needed.
epi_sum, ale_sum, n_seen = 0, 0, 0

with torch.no_grad():  # inference only
    for sample in test_loader:
        images = sample[0]
        _, epi_norm, ale_norm = get_uncertainty_per_batch(net, images, T=25, normalized=True)
        epi_sum += epi_norm.sum(0)
        ale_sum += ale_norm.sum(0)
        n_seen += epi_norm.shape[0]

if n_seen == 0:
    raise ValueError('test_loader yielded no samples')
all_epi = epi_sum / n_seen
all_ale = ale_sum / n_seen

In [None]:
# Reduce the in-distribution (CIFAR10) epistemic and then aleatoric
# uncertainty to scalars with .mean() and print them, in that order.
for accumulated in (all_epi, all_ale):
    print(accumulated.mean())

In [None]:
# Per-class epistemic / aleatoric uncertainty, averaged over every sample of
# the OOD (SVHN) test set. Per-sample values are summed and divided by the
# sample count. Summing per-batch means instead would make the result grow
# with the number of batches (SVHN and CIFAR10 have different test-set sizes)
# and over-weight the smaller last batch.
# Assumes dim 0 of the arrays returned by get_uncertainty_per_batch indexes
# samples (the original code averaged over dim 0) — TODO confirm.
# The softmax variant (normalized=False) was computed and discarded before;
# call it the same way if it is needed.
ood_epi_sum, ood_ale_sum, ood_n_seen = 0, 0, 0

with torch.no_grad():  # inference only
    for sample in ood_test_loader:
        images = sample[0]
        _, epi_norm, ale_norm = get_uncertainty_per_batch(net, images, T=25, normalized=True)
        ood_epi_sum += epi_norm.sum(0)
        ood_ale_sum += ale_norm.sum(0)
        ood_n_seen += epi_norm.shape[0]

if ood_n_seen == 0:
    raise ValueError('ood_test_loader yielded no samples')
ood_all_epi = ood_epi_sum / ood_n_seen
ood_all_ale = ood_ale_sum / ood_n_seen

In [None]:
# Reduce the OOD (SVHN) epistemic and then aleatoric uncertainty to scalars
# with .mean() and print them, in that order.
for accumulated in (ood_all_epi, ood_all_ale):
    print(accumulated.mean())

In [None]:
# Results recorded from an earlier run. They are hard-coded here for
# reference and are not recomputed by this cell (previously they were bare
# expressions, so the cell just displayed 1.8008404).
# NOTE(review): presumably the in-distribution (CIFAR10) epistemic / aleatoric
# means followed by the OOD (SVHN) epistemic / aleatoric means, in the order
# the print cells emit them — confirm before quoting.
# 0.6441222
# 1.8379018
#
# 0.69216275
# 1.8008404