In [None]:
!pip install gpytorch

In [None]:
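''' Models swag-20/swag-40, saved in the checkpoints directory, are used to evaluate results on drifted data. Two experiments
    are performed here:
    1- Using the same data as input that SWAG was trained on (undrifted)
    2- Using different data as input (drifted data)
    
    Two output files are produced, one per run, each written with np.savez to save_path
    (entropies, predictions and targets): output_in for in-class data and output_out
    for out-of-class data.'''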

In [None]:
import argparse
import torch
import torch.nn.functional as F
import numpy as np
import os
import tqdm


In [None]:

from swag import data, losses, models, utils
from swag.posteriors import SWAG, KFACLaplace


In [None]:
import imageio
import glob


In [None]:
from PIL import Image

In [None]:
from skimage.transform import rotate, AffineTransform, warp
from sklearn.model_selection import train_test_split

In [None]:
# Training-time hyper-parameters. They should match the values the checkpoint
# was trained with; most of them are not read by the evaluation cells of this
# notebook.

# --- SWAG posterior ---
swa = True            # use SWAG
cov_mat = False       # NOTE(review): reassigned to True in the next config cell
max_num_models = 20   # SWAG max_num_models setting

# --- optimisation ---
loss = 'CE'           # loss function to use
lr_init = 0.1
wd = 3e-4
momentum = 0.9
no_schedule = False

# --- schedule / bookkeeping ---
start_epoch = 0
epochs = 10           # total number of training epochs
swa_start = 161
swa_lr = 0.02
eval_freq = 5
save_freq = 2
resume = None
swa_resume = None


In [None]:
# Evaluation settings.
file = r'./checkpoints/swag-20.pt'   # SWAG checkpoint to evaluate

# --- data loading ---
use_test = False
batch_size = 16
split_classes = 1
num_workers = 4       # NOTE(review): not passed to the DataLoaders in this notebook

# --- model / SWAG posterior ---
model = 'PreResNet56' # architecture name, looked up in swag.models
method = 'SWAG'
N = 10                # number of evaluation passes
scale = 1.0           # SWAG parameter
cov_mat = True
use_diag = True
seed = 1              # passed to torch.manual_seed
num_classes = 2

In [None]:
# Destination for the saved entropies/predictions/targets. Switch to the
# "output_in" variant when evaluating in-class data; "output_out" is for
# out-of-class data.
save_path = (
    "/content/drive/My Drive/swa_gaussian-master/data/"
    "output_out"
)
# save_path = r"/content/drive/My Drive/swa_gaussian-master/data/output_in"


In [None]:
# Small constant added inside log() when computing predictive entropies.
eps = 1e-12
# Normalise the flag to a plain bool (equivalent to the old if/else branch).
cov_mat = bool(cov_mat)

# cuDNN autotuning picks the fastest convolution algorithms; results may then
# vary slightly between runs.
torch.backends.cudnn.benchmark = True
# Seed every RNG this notebook draws from: NumPy's global generator (which
# scikit-learn's train_test_split falls back to when no random_state is given)
# and torch on CPU and CUDA.
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)


In [None]:
# Replace file path here when running experiment for out of class data
def load_pngs():
    good, bad = [], []
    for im_path in glob.glob("./data/bottle/good/*.png"):
        im = Image.open(im_path)
        good.append(im)
    for im_path in glob.glob("./data/toothbrush/train/good/*.png"):
        im = Image.open(im_path)
        bad.append(im)
    return good, bad

In [None]:
class MvTecDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing in-memory images with labels.

    Args:
        imgs: indexable collection of images (PIL images in this notebook).
        labels: sequence of labels; its length defines the dataset size.
        transform: callable applied to an image each time it is accessed.
    """

    def __init__(self, imgs, labels, transform):
        # Images are stored as given; all conversion happens in `transform`.
        self.imgs, self.labels, self.transform = imgs, labels, transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        transformed = self.transform(self.imgs[idx])
        return transformed, self.labels[idx]

In [None]:
# Read both image groups from disk using load_pngs' glob patterns.
(good_imgs, bad_imgs) = load_pngs()

In [None]:
# Concatenate both groups in order; label 1 marks images from the first list
# returned by load_pngs, label 0 those from the second.
images = [*good_imgs, *bad_imgs]
labels = [1] * len(good_imgs) + [0] * len(bad_imgs)

In [None]:

def nll(outputs, labels, eps=None):
    """Summed negative log-likelihood of class labels under predicted probabilities.

    Args:
        outputs: array-like of shape (n_samples, n_classes) holding class
            probabilities (e.g. averaged softmax outputs).
        labels: array-like of length n_samples with class indices; float
            values (such as a ``np.zeros``-initialised targets array) are
            truncated to int.
        eps: optional lower bound applied to the selected probabilities before
            the log, so a zero probability gives a large finite loss instead of
            ``inf``. ``None`` (default) keeps the exact, unclipped value.

    Returns:
        ``-sum_i log(outputs[i, labels[i]])`` over the whole set (a sum, not a mean).
    """
    probs = np.asarray(outputs)
    label_idx = np.asarray(labels).astype(int)
    picked = probs[np.arange(label_idx.size), label_idx]
    if eps is not None:
        picked = np.clip(picked, eps, None)
    return -np.sum(np.log(picked))

In [None]:
print("Using model {}".format(model))
# Look up the architecture config by name; later cells read its .base, .args,
# .kwargs and .transform_test attributes.
model_cfg = getattr(models, model)



In [None]:
# Display the resolved architecture config object for inspection.
model_cfg

In [None]:
print("Preparing model")
# Wrap the base architecture in a SWAG posterior. The config's positional args
# follow the base class directly (as the original call also passed them);
# everything else is passed by keyword.
model = SWAG(
    model_cfg.base,
    *model_cfg.args,
    no_cov_mat=not cov_mat,
    # Use the configured value instead of a second, hard-coded 20.
    max_num_models=max_num_models,
    num_classes=num_classes,
    **model_cfg.kwargs
)



In [None]:
# Hold out 14.5% of the images. The split is seeded so reruns see the same images.
X_train, X_test, y_train, y_test = train_test_split(
    images, labels, test_size=0.145, shuffle=True, random_state=seed
)

# Evaluation loaders keep a fixed sample order and keep every sample. The
# evaluation loop writes batch outputs into a preallocated array by position and
# accumulates them over N passes. Shuffling would therefore sum outputs of
# different images into one row, and drop_last would leave trailing rows unfilled.
loaders = {
    split: torch.utils.data.DataLoader(
        MvTecDataset(xs, ys, model_cfg.transform_test),
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
    )
    for split, xs, ys in (("train", X_train, y_train), ("test", X_test, y_test))
}

print(len(X_train), len(X_test))

In [None]:
def train_dropout(m):
    """Switch ``m`` to training mode if its type is exactly ``torch.nn.Dropout``.

    Other module types, including dropout subclasses and variants such as
    Dropout2d, are left unchanged. Presumably meant for ``model.apply(...)`` to
    keep dropout active at evaluation time; it is not called in this notebook.
    Returns None.
    """
    is_plain_dropout = type(m) is torch.nn.modules.dropout.Dropout
    if not is_plain_dropout:
        return
    m.train()


In [None]:
print("Loading model %s" % file)
# Map tensors to the CPU first so a checkpoint saved on a GPU also opens on a
# machine without that device. load_state_dict then copies the values into the
# model's own tensors.
checkpoint = torch.load(file, map_location="cpu")
model.load_state_dict(checkpoint["state_dict"])


In [None]:
# Preallocate accumulators: one label slot and one row of class probabilities
# per sample in the "train" loader's dataset.
targets = np.zeros(len(loaders["train"].dataset))
predictions = np.zeros((targets.size, num_classes))
print(targets.size)

In [None]:
# Evaluate SWAG on the loaded data. Each of the N passes draws one posterior
# weight sample, runs every batch through it and accumulates the softmax
# outputs. Their average gives `predictions`, from which per-image predictive
# `entropies` are computed; `targets` holds the labels in loader order.
# The number of passes comes from `N` in the configuration cell.
import warnings

eval_loader = loaders["train"]
if isinstance(eval_loader.sampler, torch.utils.data.RandomSampler):
    # Rows are filled by loader position and summed across passes, so a
    # shuffling loader adds up outputs of different images in the same row.
    warnings.warn("evaluation loader shuffles; per-image averaging over passes is invalid")

# Inputs are moved to CUDA below, so the model has to live there as well.
model.cuda()

# Fresh accumulators so re-running this cell does not add onto previous results.
predictions = np.zeros((len(eval_loader.dataset), num_classes))
targets = np.zeros(len(eval_loader.dataset))
# Full-covariance sampling only when the covariance is kept and use_diag is off.
sample_with_cov = cov_mat and not use_diag
k = 0

for i in range(N):
    print("%d/%d" % (i + 1, N))
    model.eval()
    # Draw the i-th posterior sample; seeding makes each pass reproducible.
    # NOTE(review): SWAG.sample lives in the swag package (not shown here); the
    # weights used by model(inputs) are presumably only set by this call.
    torch.manual_seed(i)
    model.sample(scale=scale, cov=sample_with_cov)
    # NOTE(review): SWAG's reference uncertainty script also recomputes
    # BatchNorm statistics after each sample (utils.bn_update on training
    # data). Doing that on this evaluation data would adapt BN to the drift
    # under test, so it is left out; confirm which behaviour is intended.

    k = 0
    for inputs, target in tqdm.tqdm(eval_loader):
        inputs = inputs.cuda(non_blocking=True)
        torch.manual_seed(i)
        # Inference only: no autograd graph is needed for the forward pass.
        with torch.no_grad():
            output = model(inputs)
            predictions[k : k + inputs.size(0)] += F.softmax(output, dim=1).cpu().numpy()
        targets[k : k + target.size(0)] = target.numpy()
        k += inputs.size(0)

    # Metrics use only the rows filled this pass (a drop_last loader leaves a tail).
    print("Accuracy:", np.mean(np.argmax(predictions[:k], axis=1) == targets[:k]))
    # nll is a sum over the entire evaluated set
    print("NLL:", nll(predictions[:k] / (i + 1), targets[:k]))

# Average over passes and drop any rows the loader never filled.
predictions = predictions[:k] / N
targets = targets[:k]

entropies = -np.sum(np.log(predictions + eps) * predictions, axis=1)

In [None]:
# Persist the averaged predictions, their predictive entropies and the targets
# under the keys "entropies", "predictions" and "targets".
# np.savez appends ".npz" to save_path when it lacks that extension.
np.savez( save_path, entropies=entropies, predictions=predictions, targets=targets)