# General-Purpose Experiment Notebook

This notebook has sections to train models, build uncertainty wrappers around them, and test the wrapped models. Experiment-specific details are expected in `config.py` inside the experiment folder set by `EXP_FOLDER` below.

In [None]:
# IPython autoreload extension: in mode 2 it re-imports modified modules
# before each cell runs, so edits to config.py or library code are picked
# up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Experiment folder, resolved against the current working directory. It must
# contain config.py; the models/, times/ and results*/ output folders used by
# the cells below are created inside it.
EXP_FOLDER = 'MNIST'

In [None]:
import sys
import os
sys.path.append(os.path.abspath(EXP_FOLDER))
import config # imported from EXP_FOLDER
import cProfile

import torch

## Train and save models
Trains an ensemble of `config.N_MODELS` models as specified in `config`, and saves each model's weights to `<EXP_FOLDER>/models/`.

In [None]:
from nn_ood.utils.train import train_ensemble

# Train an ensemble of config.N_MODELS networks. Every setting (model factory,
# dataset, output distribution family, optimiser, scheduler, device, epochs,
# batch size) comes from the experiment's config.py.
models = train_ensemble(config.N_MODELS,
                        config.make_model,
                        config.dataset_class,
                        config.dist_fam,
                        config.opt_class,
                        config.opt_kwargs,
                        config.sched_class,
                        config.sched_kwargs,
                        config.device,
                        num_epochs=config.N_EPOCHS,
                        batch_size=config.BATCH_SIZE)


## SAVE MODEL
# Member i's state_dict is written to <EXP_FOLDER>/models/<FILENAME>_<i>.
# The later cells reload the weights from exactly these paths.
print("saving models")
save_folder = os.path.join(EXP_FOLDER, 'models')
# exist_ok avoids the check-then-create race of os.path.exists + makedirs.
os.makedirs(save_folder, exist_ok=True)

for i, model in enumerate(models):
    filename = os.path.join(save_folder, config.FILENAME + "_%d" % i)
    torch.save(model.state_dict(), filename)

In [None]:
# clear memory: drop every reference to the trained ensemble.
# The save loop's variable `model` still points at the last network, so it
# must be released as well; otherwise that network stays alive.
del models
model = None
# Hand memory cached by PyTorch's CUDA allocator back to the device.
torch.cuda.empty_cache()

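## Process data to create uncertainty wrappers
Loads the first trained model and builds each uncertainty wrapper listed in `config.prep_unc_models` around it. Each wrapper processes a training subset (`N=5000`), and its state is saved to `<EXP_FOLDER>/models/`.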

In [None]:
# Folder for the cProfile dump of each wrapper's process_dataset() call.
times_folder = os.path.join(EXP_FOLDER, 'times')
os.makedirs(times_folder, exist_ok=True)

## SET UP MODEL
model = config.make_model()

## LOAD MODEL
# Only the first ensemble member (suffix _0) is wrapped here.
# map_location loads the weights straight onto config.device, so checkpoints
# saved on another device (e.g. a GPU) still load.
filename = os.path.join(EXP_FOLDER, "models", config.FILENAME + "_0")
model.load_state_dict(torch.load(filename, map_location=config.device))
model = model.to(config.device)
model.eval()

## SETUP DATASET
dataset = config.dataset_class("train", N=5000)

## SET UP UNC WRAPPERS
for name, info in config.prep_unc_models.items():
    print(name)

    # Every wrapper shares the same base network, so reset it first.
    # - model.eval() restores the mode set at load time; an apply_fn from a
    #   previous iteration may have switched submodules to another mode.
    #   NOTE: any weight changes made by apply_fn are not undone here.
    # - unfreeze_model undoes the previous iteration's freezing.
    model.eval()
    config.unfreeze_model(model)
    if 'freeze' in info:
        # A bool passes freeze_frac=None; any other value is forwarded as
        # freeze_frac.
        # NOTE(review): only the key's presence is checked, so
        # 'freeze': False also freezes.
        freeze_frac = None if isinstance(info['freeze'], bool) else info['freeze']
        config.freeze_model(model, freeze_frac=freeze_frac)

    # Membership test (`in`), so wrappers configured with apply_fn really get it.
    if 'apply_fn' in info:
        model.apply(info['apply_fn'])

    unc_model = info['class'](model, config.dist_fam, info['kwargs'])

    # Profile the dataset pass into <EXP_FOLDER>/times/<name>_process.timing.
    # runctx takes this namespace explicitly rather than relying on __main__.
    cProfile.runctx("unc_model.process_dataset(dataset)",
                    globals(), locals(),
                    os.path.join(times_folder, name + "_process.timing"))

    filename = os.path.join(EXP_FOLDER, "models", name + "_" + config.FILENAME)
    torch.save(unc_model.state_dict(), filename)

In [None]:
# clear memory: release the base network, the last wrapper and the N=5000
# training subset built by the previous cell.
del model
# unc_model is only bound if config.prep_unc_models was non-empty.
# Rebinding (rather than `del`) therefore cannot raise NameError.
unc_model = None
del dataset
# Hand memory cached by PyTorch's CUDA allocator back to the device.
torch.cuda.empty_cache()

## Test uncertainty wrappers
Reloads the trained ensemble. It then evaluates the predictions and uncertainty estimates of each wrapper in `config.test_unc_models` on various datasets.

In [None]:
from nn_ood.utils.test import process_datasets

# LOAD BASE MODELS: the trained ensemble that the cells below wrap with
# uncertainty wrappers.
print("Loading models")
models = []
for i in range(config.N_MODELS):
    print("loading model %d" % i)
    filename = os.path.join(EXP_FOLDER, 'models', config.FILENAME + "_%d" % i)
    # map_location loads straight onto config.device, so checkpoints saved on
    # a different device (e.g. a GPU) still load.
    state_dict = torch.load(filename, map_location=config.device)
    model = config.make_model()
    model.load_state_dict(state_dict)
    model.eval()
    model.to(config.device)  # nn.Module.to moves the parameters in place
    models.append(model)

# Single-model wrappers are built around the first ensemble member.
# Multi-model wrappers receive the whole `models` list.
model = models[0]

### Test against out-of-distribution (OoD) datasets

In [None]:
import traceback

# Per-wrapper outputs: torch.save'd results and cProfile dumps.
results_folder = os.path.join(EXP_FOLDER, 'results')
os.makedirs(results_folder, exist_ok=True)
times_folder = os.path.join(EXP_FOLDER, 'times')
os.makedirs(times_folder, exist_ok=True)

for name, info in config.test_unc_models.items():
    print(name)

    # Every wrapper shares the same base network, so reset it first.
    # - model.eval() restores the mode set at load time; an apply_fn from a
    #   previous iteration may have switched submodules to another mode.
    #   NOTE: any weight changes made by apply_fn are not undone here.
    # - unfreeze_model undoes the previous iteration's freezing.
    model.eval()
    config.unfreeze_model(model)
    if 'freeze' in info:
        # A bool passes freeze_frac=None; any other value is forwarded as
        # freeze_frac.
        # NOTE(review): only the key's presence is checked, so
        # 'freeze': False also freezes.
        freeze_frac = None if isinstance(info['freeze'], bool) else info['freeze']
        config.freeze_model(model, freeze_frac=freeze_frac)

    # Membership test (`in`), so wrappers configured with apply_fn really get it.
    if 'apply_fn' in info:
        model.apply(info['apply_fn'])

    # Multi-model wrappers take the whole ensemble; the rest wrap models[0].
    if 'multi_model' in info:
        unc_model = info['class'](models, config.dist_fam, info['kwargs'])
    else:
        unc_model = info['class'](model, config.dist_fam, info['kwargs'])

    # Restore wrapper state saved by the preparation cell, if configured.
    # A missing 'load_name' key is treated like None.
    load_name = info.get('load_name')
    if load_name is not None:
        filename = os.path.join(EXP_FOLDER, "models", load_name + "_" + config.FILENAME)
        print(filename)
        unc_model.load_state_dict(torch.load(filename, map_location=config.device))
        # Follow config.device (as the base models do) instead of hard-coding .cuda().
        unc_model = unc_model.to(config.device)

    try:
        # runctx executes the command in this namespace, so `results` is bound here.
        cProfile.runctx(
            "results = process_datasets(config.dataset_class, "
            "config.test_dataset_args, unc_model, config.device, "
            "N=1000, **info['forward_kwargs'])",
            globals(), locals(),
            os.path.join(times_folder, name))
        torch.save(results, os.path.join(results_folder, name))
    except Exception:
        # Best effort: one failing wrapper should not stop the others.
        # Print the full traceback so the failure can be located.
        traceback.print_exc()

### Test against noise (sweep over `config.transforms`)

In [None]:
from nn_ood.utils.test import transform_sweep

if not hasattr(config, "transforms"):
    raise NameError("No transforms to test for this experiment")

# The sweep writes to its own folders. Writing to results/ and times/ would
# overwrite the OoD results and timings saved under the same wrapper name.
results_folder = os.path.join(EXP_FOLDER, 'results_transforms')
os.makedirs(results_folder, exist_ok=True)
times_folder = os.path.join(EXP_FOLDER, 'times_transforms')
os.makedirs(times_folder, exist_ok=True)

for name, info in config.test_unc_models.items():
    print(name)

    # Every wrapper shares the same base network, so reset it first.
    # - model.eval() restores the mode set at load time; an apply_fn from a
    #   previous iteration may have switched submodules to another mode.
    #   NOTE: any weight changes made by apply_fn are not undone here.
    # - unfreeze_model undoes the previous iteration's freezing.
    model.eval()
    config.unfreeze_model(model)
    if 'freeze' in info:
        # A bool passes freeze_frac=None; any other value is forwarded as
        # freeze_frac.
        # NOTE(review): only the key's presence is checked, so
        # 'freeze': False also freezes.
        freeze_frac = None if isinstance(info['freeze'], bool) else info['freeze']
        config.freeze_model(model, freeze_frac=freeze_frac)

    # Membership test (`in`), so wrappers configured with apply_fn really get it.
    if 'apply_fn' in info:
        model.apply(info['apply_fn'])

    # Multi-model wrappers take the whole ensemble; the rest wrap models[0].
    if 'multi_model' in info:
        unc_model = info['class'](models, config.dist_fam, info['kwargs'])
    else:
        unc_model = info['class'](model, config.dist_fam, info['kwargs'])

    # Restore wrapper state saved by the preparation cell, if configured.
    # A missing 'load_name' key is treated like None.
    load_name = info.get('load_name')
    if load_name is not None:
        filename = os.path.join(EXP_FOLDER, "models", load_name + "_" + config.FILENAME)
        print(filename)
        unc_model.load_state_dict(torch.load(filename, map_location=config.device))
        # Follow config.device (as the base models do) instead of hard-coding .cuda().
        unc_model = unc_model.to(config.device)

    # N=1000 items of the first in-distribution split, rebuilt for every
    # wrapper. NOTE(review): hoist this out of the loop only if
    # transform_sweep is confirmed not to mutate the dataset.
    dataset = config.dataset_class(config.in_dist_splits[0], N=1000)
    # runctx executes the command in this namespace, so `results` is bound here.
    cProfile.runctx(
        "results = transform_sweep(dataset, config.transforms, unc_model, "
        "config.device, **info['forward_kwargs'])",
        globals(), locals(),
        os.path.join(times_folder, name))
    torch.save(results, os.path.join(results_folder, name))
