# Import GAN
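Classifiers trained using https://github.com/alinlab/Confident_classifier; the saved state dicts are loaded into their architectures and re-saved as complete models.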

In [None]:
import torch

import utils.confident_classifier as cc


folder = '../Confident_classifier/SavedModels/'

# Confident-classifier checkpoints to convert. Each row holds:
#   (dataset name, checkpoint file inside `folder`, architecture factory).
# The checkpoints are state dicts: each is loaded into a freshly constructed
# architecture, moved to the CPU, and the complete module is pickled to
# SavedModels/other/gan/<dataset>.pth.
# Factories are lambdas so that the `cc` attribute lookup happens only when
# that dataset is processed, exactly as in a straight-line script.
gan_checkpoints = [
    ('MNIST', 'MNIST_epochs100_lr0.0002_classifier.pth',
     lambda: cc.LeNet()),
    ('FMNIST', 'FMNIST_epochs100_lr0.0002_beta1.0_type_VGG_classifier.pth',
     lambda: cc.vgg13(in_channels=1, num_classes=10)),
    ('SVHN', 'SVHN_epochs100_lr0.0002_beta1.0_type_VGG_classifier.pth',
     lambda: cc.vgg13(in_channels=3, num_classes=10)),
    ('CIFAR10', 'CIFAR10_epochs100_lr0.0002_beta1.0_type_VGG_classifier.pth',
     lambda: cc.vgg13(in_channels=3, num_classes=10)),
]

for dataset, file, make_model in gan_checkpoints:
    model = make_model()
    # map_location='cpu' lets checkpoints that were saved with CUDA tensors
    # load on a machine without a GPU; the model is moved to CPU anyway.
    model.load_state_dict(torch.load(folder + file, map_location='cpu'))

    model.cpu()
    torch.save(model, 'SavedModels/other/gan/' + dataset + '.pth')

# Import OE
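Models trained with Outlier Exposure using https://github.com/hendrycks/outlier-exposure; the saved state dicts are loaded into their architectures and re-saved as complete models.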

In [None]:
import torch
import utils.hendrycks as oe
import utils.dataloaders as dl


folder = 'SavedModels/other/outlier-exposure/'
device = torch.device('cpu')

# Outlier-Exposure checkpoints to convert. Each row holds:
#   (checkpoint path, architecture factory, output name inside `folder`).
# Each checkpoint is a state dict; it is loaded onto `device` into a freshly
# built architecture and the complete module is pickled to
# folder + <output name> + '.pth'.
# Factories are lambdas so each `oe` attribute is looked up only when that
# row is processed, matching a straight-line script.
oe_conversions = [
    ('../outlier-exposure/MNIST/snapshots/oe_scratch/oe_scratch_epoch_99.pt',
     lambda: oe.MNIST_ConvNet(),
     'hendrycks_MNIST'),
    # An alternative architecture previously tried here was
    # oe.WideResNet(40, 10, channels=1, dataset='FMNIST').
    ('../outlier-exposure/FMNIST/snapshots/oe_scratch/oe_scratch_epoch_99.pt',
     lambda: oe.ResNet18(num_of_channels=1, num_classes=10, dataset='FMNIST'),
     'hendrycks_FMNIST'),
    # NOTE(review): this checkpoint's file name starts with 'wrn_' (suggesting
    # a WideResNet) but it is loaded into ResNet18 — confirm which
    # architecture produced it.
    ('../outlier-exposure/SVHN/snapshots/oe_scratch/wrn_oe_scratch_epoch_99.pt',
     lambda: oe.ResNet18(num_of_channels=3, num_classes=10, dataset='SVHN'),
     'hendrycks_SVHN'),
    ('../outlier-exposure/CIFAR/snapshots/oe_scratch/cifar10_resnet_oe_scratch_epoch_99.pt',
     lambda: oe.ResNet18(num_of_channels=3, num_classes=10, dataset='CIFAR10'),
     'hendrycks_CIFAR10'),
]

for ckpt_path, build_net, out_name in oe_conversions:
    net = build_net()
    state = torch.load(ckpt_path, map_location=device)
    net.load_state_dict(state)
    net.cpu()
    torch.save(net, folder + out_name + '.pth')


# Calibrate [ODIN](https://arxiv.org/abs/1706.02690)



In [None]:
import torch
import model_params as params
import utils.odin as odin

import model_paths


device = torch.device('cuda:0')

# Forwarded unchanged to odin.grid_search_variables as its `out_seeds` flag.
out_seeds = True

datasets = ['MNIST', 'FMNIST', 'SVHN', 'CIFAR10']

# For each dataset: load the pickled 'Base' model, run ODIN's grid search
# (utils.odin.grid_search_variables) and save the returned ODIN-wrapped model
# to SavedModels/other/odin/<dataset>_ODIN.pth.
for dataset in datasets:
    print(dataset)
    path = model_paths.model_dict[dataset]().file_dict['Base']

    # Default (non-augmented) dataset parameters are used for the search.
    # Previously a params object with augm_flag=True was also built here but
    # was immediately overwritten and never used.
    # NOTE(review): the Mahalanobis calibration uses augm_flag=True — confirm
    # that the difference is intended.
    model_params = params.params_dict[dataset]()

    # map_location places the stored tensors directly on `device`, regardless
    # of which device the base model was saved from.
    base_model = torch.load(path, map_location=device).to(device)
    ODIN_model, _, _ = odin.grid_search_variables(base_model, model_params,
                                                  device, out_seeds=out_seeds)

    torch.save(ODIN_model.cpu(), 'SavedModels/other/odin/' + dataset + '_ODIN.pth')

# Calibrate [Maha](https://arxiv.org/abs/1807.03888)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import utils.dataloaders as dl

import utils.single_maha as maha

import model_params as params
import model_paths


device = torch.device('cuda:0')

datasets = ['MNIST', 'FMNIST', 'SVHN', 'CIFAR10']


# For each dataset: copy the weights of the pickled 'Base' model into the
# matching backbone from utils.single_maha, wrap it in maha.Mahalanobis and
# then maha.ModelODIN, and save the result to
# SavedModels/other/single_mahalanobis/<dataset>.pth.
for dataset in datasets:
    print(dataset)
    model_params = params.params_dict[dataset](augm_flag=True)
    path = model_paths.model_dict[dataset]().file_dict['Base']

    # Only the weights are needed: load onto the CPU and take the state dict,
    # which is then copied into a freshly constructed (CPU) backbone.
    pretrained = torch.load(path, map_location='cpu').state_dict()

    # The backbone class must match the architecture of the saved Base model.
    if dataset == 'MNIST':
        model = maha.LeNetMadry()
    elif dataset == 'FMNIST':
        model = maha.ResNet18(1, 10)
    elif dataset in ('SVHN', 'CIFAR10'):
        model = maha.ResNet18(3, 10)
    else:
        # Without this branch an unknown dataset would silently reuse the
        # previous iteration's backbone (or raise NameError on the first one).
        raise ValueError('No Mahalanobis backbone defined for dataset '
                         + repr(dataset))

    model.load_state_dict(pretrained)

    # NOTE(review): Mahalanobis presumably estimates class-conditional feature
    # statistics from `model_params`' data — confirm in utils.single_maha.
    maha_model = maha.Mahalanobis(model.to(device), model_params, device)
    final_model = maha.ModelODIN(maha_model, model_params, device)

    torch.save(final_model.cpu(), 'SavedModels/other/single_mahalanobis/' + dataset + '.pth')