In [None]:
import os
import random

import numpy as np
import torch
from sklearn.metrics import accuracy_score
from sklearn.metrics import log_loss

from models_code.cifar import ISCifar
from models_code.cifar import load_data
from models_code.cifar import load_svhn
from models_code.experiments import correlation_test_error_uncertainty
from models_code.experiments import load_lfw
from models_code.experiments import non_distribution
from models_code.experiments import not_mnist_predictions
from models_code.experiments import softmax2d
from models_code.experiments import test_eval
from models_code.mnist import perform_training
from models_code.utilities import create_model
from models_code.utilities import dump_results
from models_code.utilities import load_model
from utilities.metric import entropy
from sklearn.metrics import log_loss

In [None]:
def set_same_seed(seed=9):
    """Seed every random number generator this notebook may draw from.

    Seeds torch (CPU and all CUDA devices), numpy's global RNG and Python's
    ``random`` module, so that re-running the notebook reproduces results.

    Args:
        seed: seed value. Defaults to 9, the value this notebook always used,
            so the existing zero-argument calls behave as before for torch.
    """
    torch.manual_seed(seed)
    # Silently ignored when CUDA is unavailable; seeds every visible GPU,
    # not only the current device.
    torch.cuda.manual_seed_all(seed)
    # numpy is imported by this notebook but was previously left unseeded.
    np.random.seed(seed)
    random.seed(seed)

In [None]:
# Run configuration shared by the training and evaluation cells below.
# log_interval is forwarded to perform_training (presumably the number of
# batches between progress logs — confirm in models_code.mnist).
batch_size, log_interval, epochs = 64, 100, 100

# Inhibited softmax on CIFAR-10

In [None]:
# Seed before building the loaders so that any randomness drawn from here on
# (presumably including loader shuffling) is reproducible.
set_same_seed()
(train_loader,
 test_loader) = load_data(batch_size)

In [None]:
is_, optimizer, cross_entropy = create_model(ISCifar)


def is_loss(model, penalty_weight=1e-6):
    """Build the training loss for the inhibited-softmax model.

    The returned callable takes ``(pred, aft_cauchy, y)`` and computes
    ``cross_entropy(pred, y) - penalty_weight * aft_cauchy.sum()``, where
    ``cross_entropy`` is the criterion returned by ``create_model`` above.

    Args:
        model: not used by the loss; kept so the existing call
            ``is_loss(is_)`` keeps working.
        penalty_weight: coefficient of the summed ``aft_cauchy`` term.
            The default equals the ``0.000001`` this notebook always used.

    Returns:
        A function ``(pred, aft_cauchy, y) -> loss``.
    """
    def loss(pred, aft_cauchy, y):
        # If the optimizer minimizes this loss, the minus sign favours
        # larger aft_cauchy sums.
        return cross_entropy(pred, y) - penalty_weight * aft_cauchy.sum()

    return loss


perform_training(
    epochs,
    is_,
    train_loader,
    test_loader,
    optimizer,
    is_loss(is_),
    log_interval,
    './models/cifar_lenet/is2.torch',
    # NOTE(review): this was `60000 // batch_size + 1`. 60000 is MNIST's
    # training-set size and appears to be carried over from the MNIST notebook.
    # The same formula is now applied to the actual training data (assumes
    # train_loader is a torch DataLoader exposing `.dataset`). Confirm what
    # perform_training expects for this argument.
    len(train_loader.dataset) // batch_size + 1,
    channels=3,
)

# To reuse the saved model instead of retraining, skip the training call above
# and run:
# is_ = load_model(ISCifar, './models/cifar_lenet/is2.torch')

In [None]:
# Evaluate the trained model on the CIFAR test loader.
test_preds, test_labels, test_probs = test_eval(
    is_,
    test_loader,
    channels=3,
)

In [None]:
# Fraction of CIFAR test samples classified correctly.
accuracy_score(y_true=test_labels, y_pred=test_preds)

In [None]:
# Log-loss over the 10 class columns only (column 10 is dropped before softmax2d).
log_loss(y_true=test_labels, y_pred=softmax2d(test_probs[:, :10]))

### Second experiment: correlation between test error and uncertainty

In [None]:
# Uncertainty score per test sample: the softmax2d probability at output column 10.
roc, ac, fpr, tpr, pr, re = correlation_test_error_uncertainty(
    lambda outputs: softmax2d(outputs)[:, 10], test_probs, test_labels
)

In [None]:
# First value returned by correlation_test_error_uncertainty above
# (presumably ROC AUC of the uncertainty score vs. misclassification — TODO confirm).
roc

In [None]:
# Second value returned by correlation_test_error_uncertainty above
# (meaning not visible here — confirm in models_code.experiments).
ac

In [None]:
# Save the curves from the error/uncertainty experiment.
# NOTE(review): this previously wrote to './results/mnist/is.pickle', which
# would overwrite the MNIST notebook's results with CIFAR numbers. The
# directory name below is a guess at the intended layout — confirm.
os.makedirs('./results/cifar', exist_ok=True)
dump_results(fpr, tpr, pr, re, './results/cifar/is.pickle')

### Third experiment: out-of-distribution detection (CIFAR test set vs. SVHN)

In [None]:
# SVHN loader, using the same batch size; evaluated below as out-of-distribution data.
svhn_loader = load_svhn(batch_size)

In [None]:
# Run the CIFAR-trained model over SVHN; only svhn_probs is used below.
svhn_preds, svhn_labels, svhn_probs = test_eval(
    is_,
    svhn_loader,
    channels=3,
)

In [None]:
# Compare the CIFAR test set (in-distribution) against SVHN (out-of-distribution).
# Each sample is scored by its softmax2d probability at output column 10.
# Sizes are derived from the arrays instead of the hard-coded 10000 / 73257 / 83257,
# so the cell stays correct if either dataset or split changes size.
roc, ac, fpr, tpr, pr, re = non_distribution(
    test_probs,
    softmax2d(test_probs)[:, 10].reshape(-1, 1),
    softmax2d(svhn_probs)[:, 10].reshape(-1, 1),
    len(test_probs) + len(svhn_probs),  # in-distribution + SVHN count (was 83257)
    len(test_probs),  # in-distribution count (was 10000)
)

In [None]:
# `roc` now holds the first value returned by non_distribution for SVHN,
# replacing the earlier experiment's value (presumably ROC AUC — confirm).
roc

In [None]:
# Second value returned by non_distribution for SVHN (replaces the earlier `ac`).
ac

In [None]:
# Save the curves from the SVHN out-of-distribution experiment.
# NOTE(review): this previously wrote to './results/notmnist/is.pickle', which
# mislabels SVHN results and would overwrite the MNIST/notMNIST results. The
# directory name below is a guess at the intended layout — confirm.
os.makedirs('./results/svhn', exist_ok=True)
dump_results(fpr, tpr, pr, re, './results/svhn/is.pickle')

### Fourth experiment: out-of-distribution detection (CIFAR test set vs. LFW-a)

In [None]:
# LFW-a loader, using the same batch size; evaluated below as out-of-distribution data.
lfw_loader = load_lfw(batch_size)

In [None]:
# Run the CIFAR-trained model over LFW-a; only lfw_probs is used below.
lfw_preds, lfw_labels, lfw_probs = test_eval(
    is_,
    lfw_loader,
    channels=3,
)

In [None]:
# Compare the CIFAR test set (in-distribution) against LFW-a (out-of-distribution).
# Each sample is scored by its softmax2d probability at output column 10.
# Sizes are derived from the arrays instead of the hard-coded 10000 / 1054 / 11054.
roc, ac, fpr, tpr, pr, re = non_distribution(
    test_probs,
    softmax2d(test_probs)[:, 10].reshape(-1, 1),
    softmax2d(lfw_probs)[:, 10].reshape(-1, 1),
    len(test_probs) + len(lfw_probs),  # in-distribution + LFW-a count (was 11054)
    len(test_probs),  # in-distribution count (was 10000)
)


In [None]:
# `roc` now holds the first value returned by non_distribution for LFW-a
# (presumably ROC AUC — confirm).
roc

In [None]:
# Second value returned by non_distribution for LFW-a.
ac