In [None]:
from models_code.utilities import create_model
from models_code.utilities import dump_results

from models_code.experiments import correlation_test_error_uncertainty
from models_code.experiments import load_notmnist
from models_code.experiments import load_omniglot
from models_code.experiments import load_cifar_bw

from models_code.experiments import not_mnist_predictions
from models_code.experiments import non_distribution
from models_code.experiments import test_eval
from models_code.experiments import softmax2d

from models_code.mnist import perform_training
from models_code.mnist import load_data

from models_code.mnist import Mnist
from models_code.mnist import ISMnist

from models_code.utilities import load_model

import torch

import numpy as np

from sklearn.metrics import accuracy_score
from sklearn.metrics import log_loss

In [None]:
def set_same_seed(seed=9):
    """Seed the random number generators used by this notebook.

    Seeds torch's CPU generator, torch's CUDA generator (silently ignored by
    torch when CUDA is unavailable) and numpy's global generator, so that
    re-running the notebook reproduces the same results.

    Args:
        seed: integer seed; defaults to 9, the value used throughout this
            notebook.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # numpy is used for the evaluation maths below; seed it as well so any
    # numpy-based randomness is reproducible too.
    np.random.seed(seed)

In [None]:
# Shared run settings: mini-batch size, the interval passed to
# perform_training as `log_interval`, and the number of training epochs.
batch_size, log_interval, epochs = 128, 100, 12

# Inhibited softmax

In [None]:
# Seed first, so that any randomness in building the loaders (e.g. shuffling,
# if load_data does any) is reproducible.
set_same_seed()
(train_loader, test_loader) = load_data(batch_size)

In [None]:
# Build the inhibited-softmax network. Judging by the names, create_model also
# returns its optimizer and a cross-entropy criterion; `cross_entropy` is the
# criterion read by is_loss.
(is_, optimizer, cross_entropy) = create_model(ISMnist)


def is_loss(model, reg_weight=1e-6, criterion=None):
    """Build the training loss used for the inhibited-softmax model.

    The returned callable computes::

        criterion(pred, y) - reg_weight * torch.log(aft_cauchy).sum()

    That is, the classification loss minus `reg_weight` times the summed log
    of `aft_cauchy` (presumably the model's post-Cauchy activations, as passed
    in by perform_training — confirm).

    Args:
        model: unused; kept so existing calls such as ``is_loss(is_)`` keep
            working.
        reg_weight: weight of the log term. The default 1e-6 is the value
            this notebook has always used.
        criterion: classification loss called as ``criterion(pred, y)``.
            When None, the notebook-level ``cross_entropy`` (from
            ``create_model``) is looked up each time the loss is evaluated.

    Returns:
        A function ``loss(pred, aft_cauchy, y)`` returning the loss tensor.
    """
    def loss(pred, aft_cauchy, y):
        # Resolve lazily, mirroring the original lambda's late global lookup.
        ce = cross_entropy if criterion is None else criterion
        return ce(pred, y) - reg_weight * torch.log(aft_cauchy).sum()

    return loss

# Checkpoint written by perform_training; load_model can read it back later.
is_model_path = './models/mnist_lenet/is.torch'

# Batches per epoch = ceil(60000 / batch_size), where 60000 is the MNIST
# training-set size (assumes load_data yields the full training split -
# confirm). The former `60000 // batch_size + 1` over-counted by one whenever
# batch_size divides 60000 exactly. For batch_size = 128 both give 469.
n_train_batches = (60000 + batch_size - 1) // batch_size

perform_training(
    epochs,
    is_,
    train_loader,
    test_loader,
    optimizer,
    is_loss(is_),
    log_interval,
    is_model_path,
    n_train_batches,
)

# To reuse a previously trained checkpoint instead of retraining, skip the
# call above and run:
# is_ = load_model(ISMnist, is_model_path)

In [None]:
# Evaluate on the MNIST test loader; test_probs holds the raw model outputs
# that later cells slice into class columns [:, :10] and column 10.
(test_preds, test_labels, test_probs) = test_eval(is_, test_loader)

In [None]:
# Test-set accuracy of the inhibited-softmax model.
accuracy_score(y_true=test_labels, y_pred=test_preds)

In [None]:
# Log loss over the first 10 output columns only (the class columns);
# softmax2d is applied after slicing so the 10 probabilities are normalised
# among themselves.
log_loss(y_true=test_labels, y_pred=softmax2d(test_probs[:, :10]))

### Second experiment - wrong prediction detection

In [None]:
# Uncertainty score: the negated largest class probability. softmax2d is
# applied to all output columns, then only the first 10 (class) columns enter
# the max.
roc, ac, fpr, tpr, pr, re = correlation_test_error_uncertainty(
    lambda logits: -np.max(softmax2d(logits)[:, :10], axis=1),
    test_probs,
    test_labels,
)


In [None]:
# First metric returned by correlation_test_error_uncertainty
# (presumably ROC AUC - confirm in models_code.experiments).
roc

In [None]:
# Second metric returned by correlation_test_error_uncertainty; its exact
# meaning is defined in models_code.experiments - confirm there.
ac

In [None]:
# Save the curve data returned above (fpr/tpr and pr/re) to the given path.
dump_results(
    fpr, tpr,
    pr, re,
    './results/mnist/is.pickle',
)

### Third experiment - out-of-distribution detection

In [None]:
# notMNIST loader, used below as the out-of-distribution set.
not_mnist_loader = load_notmnist(batch_size)

In [None]:
# Run the IS model over the OOD loader. With softmaxed=False the outputs are
# left unnormalised; only notmnist_probs[0] (this single model) is used below.
notmnist_truth, notmnist_probs, notmnist_images = not_mnist_predictions(
    [is_],
    not_mnist_loader,
    softmaxed=False,
)

In [None]:
# OOD score per sample: column 10 of the softmaxed output (the extra column
# beyond the 10 class columns). Sizes are taken from the arrays rather than
# hard-coded (previously 10000 / 18724 / 28724), so this cell keeps working
# if either dataset changes size.
in_dist_scores = softmax2d(test_probs)[:, 10].reshape(-1, 1)
out_dist_scores = softmax2d(notmnist_probs[0])[:, 10].reshape(-1, 1)
roc, ac, fpr, tpr, pr, re = non_distribution(
    test_probs,
    in_dist_scores,
    out_dist_scores,
    len(in_dist_scores) + len(out_dist_scores),
    len(in_dist_scores),
)


In [None]:
# First metric returned by non_distribution for MNIST vs notMNIST
# (presumably ROC AUC - confirm).
roc

In [None]:
# Second metric returned by non_distribution for MNIST vs notMNIST; meaning
# defined in models_code.experiments - confirm there.
ac

In [None]:
# Save the MNIST-vs-notMNIST curve data to the given path.
dump_results(
    fpr, tpr,
    pr, re,
    './results/notmnist/is.pickle',
)

### Omniglot

In [None]:
# NOTE: the variable name is reused - this is the Omniglot loader, stored in
# not_mnist_loader so the following cells can repeat the notMNIST pipeline.
not_mnist_loader = load_omniglot(batch_size)

In [None]:
# Run the IS model over the Omniglot loader (held in not_mnist_loader);
# outputs are left unnormalised (softmaxed=False).
notmnist_truth, notmnist_probs, notmnist_images = not_mnist_predictions(
    [is_],
    not_mnist_loader,
    softmaxed=False,
)

In [None]:
# OOD score per sample: column 10 of the softmaxed output (the extra column
# beyond the 10 class columns). Sizes are taken from the arrays rather than
# hard-coded (previously 10000 / 32460 / 42460), so this cell keeps working
# if either dataset changes size.
in_dist_scores = softmax2d(test_probs)[:, 10].reshape(-1, 1)
out_dist_scores = softmax2d(notmnist_probs[0])[:, 10].reshape(-1, 1)
roc, ac, fpr, tpr, pr, re = non_distribution(
    test_probs,
    in_dist_scores,
    out_dist_scores,
    len(in_dist_scores) + len(out_dist_scores),
    len(in_dist_scores),
)


In [None]:
# First metric returned by non_distribution for MNIST vs Omniglot
# (presumably ROC AUC - confirm).
roc

In [None]:
# Second metric returned by non_distribution for MNIST vs Omniglot; meaning
# defined in models_code.experiments - confirm there.
ac

In [None]:
# Save the MNIST-vs-Omniglot curve data to the given path.
dump_results(
    fpr, tpr,
    pr, re,
    './results/omniglot/is.pickle',
)

### Cifar-bw

In [None]:
# NOTE: the variable name is reused - this is the Cifar-bw loader, stored in
# not_mnist_loader so the following cells can repeat the notMNIST pipeline.
not_mnist_loader = load_cifar_bw(batch_size)

In [None]:
# Run the IS model over the Cifar-bw loader (held in not_mnist_loader);
# outputs are left unnormalised (softmaxed=False).
notmnist_truth, notmnist_probs, notmnist_images = not_mnist_predictions(
    [is_],
    not_mnist_loader,
    softmaxed=False,
)

In [None]:
# OOD score per sample: column 10 of the softmaxed output (the extra column
# beyond the 10 class columns). Sizes are taken from the arrays rather than
# hard-coded (previously 10000 / 50000 / 60000), so this cell keeps working
# if either dataset changes size.
in_dist_scores = softmax2d(test_probs)[:, 10].reshape(-1, 1)
out_dist_scores = softmax2d(notmnist_probs[0])[:, 10].reshape(-1, 1)
roc, ac, fpr, tpr, pr, re = non_distribution(
    test_probs,
    in_dist_scores,
    out_dist_scores,
    len(in_dist_scores) + len(out_dist_scores),
    len(in_dist_scores),
)

In [None]:
# First metric returned by non_distribution for MNIST vs Cifar-bw
# (presumably ROC AUC - confirm).
roc

In [None]:
# Second metric returned by non_distribution for MNIST vs Cifar-bw; meaning
# defined in models_code.experiments - confirm there.
ac

In [None]:
# Save the MNIST-vs-Cifar-bw curve data to the given path.
dump_results(
    fpr, tpr,
    pr, re,
    './results/cifar-bw/is.pickle',
)