In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import logging

from kyle.sampling.fake_clf import DirichletFC, MultiDirichletFC
from kyle.evaluation import EvalStats, compute_accuracy, compute_ECE, compute_expected_max
from kyle.transformations import *
from kyle.calibration.calibration_methods import TemperatureScaling

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import pytorch_lightning as pl

import scipy.stats
import scipy.optimize
import numpy as np
import matplotlib.pyplot as plt

# Realistic Fake Classifiers

It's good to have a model of what a realistic Fake Classifier should look like.

Probably the simplest model for the fake classifier confidence vector distribution is the **Dirichlet Fake Classifier**:

\begin{equation}
C \sim Dirichlet(\alpha_1, \alpha_2, \alpha_3, ...)
\end{equation}

However, this model is possibly a bit too simple, as its distribution has only a single local maximum. A realistic fake classifier might, for example, have a local maximum in each of the corners of the simplex, i.e. it is generally very confident in its prediction and only very rarely uncertain (center of the simplex). Something similar can actually be achieved with the Dirichlet distribution by setting all parameters $\alpha_n < 1$. This pushes the probability mass out into the corners, but also out onto the sides (faces) of the simplex, which is not quite what we want. The center of a side of the simplex corresponds to a confidence vector $\vec{c} = \left(\frac{1}{K-1}, \frac{1}{K-1}, \ldots, \frac{1}{K-1}, 0\right)$, where $K$ is the number of classes, i.e. a prediction that is maximally uncertain among all classes except one, which it rules out completely.

Therefore we also consider two other Fake Classifiers that can have multiple local maxima in each of the corners and therefore possibly represent real neural networks better:

Firstly the **Multi-Dirichlet Fake Classifier**:

\begin{align}
k & \sim \text{Categorical}(p_1, p_2, p_3, ...) \\
C & \sim \text{Dirichlet}_k(\sigma_k\cdot[1, 1, ..., 1, \alpha_k, 1, ...])
\end{align}

i.e. we first draw $k$ from a $K$-categorical distribution (one category per class) and, based on the result, then draw $C$ from the $k$-th of $K$ Dirichlet distributions, whose $k$-th concentration parameter is $\sigma_k\alpha_k$ while all others are $\sigma_k$. Each of the $K$ Dirichlet distributions therefore has two parameters, $\sigma_k$ and $\alpha_k$, which control the width and the position of the local maximum in the $k$-th corner of the simplex.
Note: the pdf of this mixture distribution is the weighted sum of the individual Dirichlet pdfs, with the categorical probabilities $p_k$ as weights.

Secondly the **Multi-Gaussian Fake Classifier**:

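i.e. we first draw $k$ from a $K$-categorical distribution, then draw a logit vector from the $k$-th of $K$ Gaussians, and finally map it onto the simplex with the softmax function.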

(probably doesn't make too much sense as Multi-Gaussian pdf is analytically intractable due to the softmax transformation)

In order to get an actually realistic Fake Classifier we use these three Fake Classifier models and fit their distributions to the observed confidence vector distributions for a couple of different neural networks.

In this case we use:

**LeNet 5** on CIFAR 10

**ResNet 20** on CIFAR 10

**ResNet 110** on CIFAR 10

In [None]:
# Load the CIFAR-10 train and test splits (downloaded into the working directory on first use).
# The normalization constants are taken from https://github.com/akamaster/pytorch_resnet_cifar10,
# the same repository the pretrained ResNets below come from.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Train split first, then test split; each gets its own ToTensor + normalize pipeline.
cifar_train_set, cifar_test_set = (
    datasets.CIFAR10(os.getcwd(), train=is_train, download=True,
                     transform=transforms.Compose([transforms.ToTensor(), normalize]))
    for is_train in (True, False)
)

# Small mini-batches of 4; only the training split is shuffled.
cifar_train, cifar_test = (
    torch.utils.data.DataLoader(split, batch_size=4, shuffle=shuffle_split, num_workers=2)
    for split, shuffle_split in ((cifar_train_set, True), (cifar_test_set, False))
)

In [None]:
#Small simple LeNet5 for CIFAR 10 classification

class lenet5(pl.LightningModule):
    """Small LeNet-5 style CNN producing 10 class logits for 32x32 RGB images.

    The layer attribute names (conv1, conv2, fc1, fc2, fc3) define the keys of
    the checkpoint loaded with ``lenet5.load_from_checkpoint`` further down, so
    they must not be renamed.
    """

    def __init__(self):
        super().__init__()
        # Two 5x5 convolution stages: 3 -> 6 -> 16 channels.
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # For 32x32 inputs: conv 28 -> pool 14 -> conv 10 -> pool 5, hence 16 * 5 * 5 features.
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return unnormalized class logits of shape (batch, 10)."""
        for conv in (self.conv1, self.conv2):
            x = F.max_pool2d(F.relu(conv(x)), 2)
        x = x.view(x.size(0), -1)
        for hidden in (self.fc1, self.fc2):
            x = F.relu(hidden(x))
        return self.fc3(x)

    def training_step(self, batch, batch_idx):
        """Return the cross-entropy loss of one training batch for Lightning to optimize."""
        inputs, targets = batch
        return F.cross_entropy(self(inputs), targets)

    def validation_step(self, batch, batch_idx):
        """Log the validation cross-entropy as ``val_loss`` at epoch level (not per step)."""
        inputs, targets = batch
        val_loss = F.cross_entropy(self(inputs), targets)
        self.log('val_loss', val_loss, on_step=False, on_epoch=True, prog_bar=True)

    def configure_optimizers(self):
        """Plain SGD with momentum 0.9 and learning rate 0.001."""
        return torch.optim.SGD(self.parameters(), lr=0.001, momentum=0.9)
    

In [None]:
#Proper implementation of the CIFAR-10 ResNets (ResNet20/32/44/56/110/1202). PyTorch/torchvision only ship
#the ImageNet-style ResNets, which differ in architecture and number of parameters.
#Code taken from: https://github.com/akamaster/pytorch_resnet_cifar10

def _weights_init(m):
    classname = m.__class__.__name__
    #print(classname)
    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
        init.kaiming_normal_(m.weight)

class LambdaLayer(nn.Module):
    """Wrap a parameter-free callable as an ``nn.Module`` (used for option-A shortcuts)."""

    def __init__(self, lambd):
        super().__init__()
        self.lambd = lambd

    def forward(self, x):
        return self.lambd(x)


class BasicBlock(nn.Module):
    """Two 3x3 conv + BatchNorm layers with a residual shortcut (CIFAR ResNet block).

    Args:
        in_planes: number of input channels.
        planes: number of output channels.
        stride: stride of the first convolution (spatial downsampling factor).
        option: shortcut type used when the input and output shapes differ:
            'A' - parameter-free: subsample spatially by ``stride`` and zero-pad
                  the missing channels (the ResNet paper uses this for CIFAR-10);
            'B' - 1x1 convolution + BatchNorm projection.

    Raises:
        ValueError: if a shortcut is needed and ``option`` is neither 'A' nor 'B'.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, option='A'):
        super().__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            if option == 'A':
                # Split the missing channels evenly before/after the existing ones and
                # subsample with the block's own stride so the spatial size matches conv1.
                # For the usual stride-2 channel doubling (planes == 2 * in_planes) this is
                # exactly planes // 4 on each side, as in the reference implementation.
                pad_total = planes - in_planes
                pad_front = pad_total // 2
                pad_back = pad_total - pad_front
                self.shortcut = LambdaLayer(lambda x:
                                            F.pad(x[:, :, ::stride, ::stride],
                                                  (0, 0, 0, 0, pad_front, pad_back), "constant", 0))
            elif option == 'B':
                self.shortcut = nn.Sequential(
                     nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                     nn.BatchNorm2d(self.expansion * planes)
                )
            else:
                # Previously an unknown option silently kept the identity shortcut,
                # which only failed later with a shape mismatch in forward().
                raise ValueError(f"unknown shortcut option {option!r}; expected 'A' or 'B'")

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    """CIFAR-style ResNet: a 3x3 stem followed by three stages of 16/32/64 channels.

    Args:
        block: residual block class (e.g. BasicBlock). It must expose ``expansion``
            and be constructible as ``block(in_planes, planes, stride)``.
        num_blocks: number of blocks in each of the three stages.
        num_classes: size of the output logit vector.

    Attribute names (conv1, bn1, layer1..3, linear) match the keys of the
    pretrained checkpoints and must not be renamed.
    """
    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()
        self.in_planes = 16

        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
        # The last stage outputs 64 * expansion channels (64 for BasicBlock).
        self.linear = nn.Linear(64 * block.expansion, num_classes)

        self.apply(_weights_init)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one downsamples with ``stride``."""
        layers = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            layers.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return unnormalized class logits of shape (batch, num_classes)."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        # Global average pooling to 1x1. A kernel of size out.size(3) only does this
        # for square feature maps; the adaptive version handles any spatial shape.
        out = F.adaptive_avg_pool2d(out, 1)
        out = torch.flatten(out, 1)
        return self.linear(out)
    

def _cifar_resnet(blocks_per_stage):
    """Build a CIFAR ResNet with ``blocks_per_stage`` BasicBlocks in each of its 3 stages.

    Depth = 6 * blocks_per_stage + 2: two conv layers per block in three stages,
    plus the stem convolution and the final linear layer.
    """
    return ResNet(BasicBlock, [blocks_per_stage] * 3)


def resnet20():
    """ResNet-20: 3 blocks per stage."""
    return _cifar_resnet(3)


def resnet32():
    """ResNet-32: 5 blocks per stage."""
    return _cifar_resnet(5)


def resnet44():
    """ResNet-44: 7 blocks per stage."""
    return _cifar_resnet(7)


def resnet56():
    """ResNet-56: 9 blocks per stage."""
    return _cifar_resnet(9)


def resnet110():
    """ResNet-110: 18 blocks per stage."""
    return _cifar_resnet(18)


def resnet1202():
    """ResNet-1202: 200 blocks per stage."""
    return _cifar_resnet(200)

In [None]:
#Train a LeNet5 (commented out; kept as the recipe for retraining) and load the
#self-trained LeNet 5 checkpoint plus the pretrained Resnet20 and Resnet110
#(don't have a dedicated GPU at hand D:).
#Pretrained nets taken from https://github.com/akamaster/pytorch_resnet_cifar10

#selftrained_lenet5 = lenet5()
#checkpoint_callback = pl.callbacks.ModelCheckpoint(monitor='val_loss', save_top_k=1, save_last=True)
#trainer = pl.Trainer(max_epochs=20, logger=False, callbacks=[checkpoint_callback])
#trainer.fit(selftrained_lenet5, cifar_train, cifar_test)

selftrained_lenet5 = lenet5.load_from_checkpoint('./trained_models/lenet5.ckpt')


def load_pretrained_cifar_resnet(model, checkpoint_path):
    """Load a pretrained akamaster checkpoint into ``model`` on the CPU and return the model.

    The weights are read from the checkpoint's ``'state_dict'`` entry. Keys may
    carry a leading ``'module.'`` prefix (presumably from saving an
    ``nn.DataParallel``-wrapped model). Only that leading prefix is stripped, so
    the keys match the plain ``ResNet`` modules without touching the rest of each key.
    """
    # NOTE(review): torch.load unpickles the file -- only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    prefix = 'module.'
    state_dict = {(key[len(prefix):] if key.startswith(prefix) else key): value
                  for key, value in checkpoint['state_dict'].items()}
    model.load_state_dict(state_dict)
    return model


pretrained_resnet20 = load_pretrained_cifar_resnet(resnet20(), './trained_models/resnet20-12fca82f.th')
pretrained_resnet110 = load_pretrained_cifar_resnet(resnet110(), './trained_models/resnet110-1d1ed7c2.th')

# Set which neural net to fit here

In [None]:
# Networks whose confidence distribution can be fitted below; change the key to switch.
candidate_nets = {
    'lenet5': selftrained_lenet5,
    'resnet20': pretrained_resnet20,
    'resnet110': pretrained_resnet110,
}
neural_net = candidate_nets['lenet5']

In [None]:
#Get NN predictions (softmax confidences) on the full CIFAR10 test set.
#Inference runs in mini-batches so that deep nets (e.g. ResNet110) do not have to hold
#the activations of all test images in memory at once.
EVAL_BATCH_SIZE = 1000

cifar_test_eval = torch.utils.data.DataLoader(cifar_test_set, batch_size=EVAL_BATCH_SIZE,
                                              shuffle=False, num_workers=2)

neural_net.eval()

with torch.no_grad():
    batch_logits, batch_labels = [], []
    for batch_images, batch_targets in cifar_test_eval:
        batch_logits.append(neural_net(batch_images))
        batch_labels.append(batch_targets)
    # shuffle=False keeps the concatenated rows in test-set order.
    logits = torch.cat(batch_logits)
    labels = torch.cat(batch_labels)

    prob = F.softmax(logits, dim=1)
    _, predicted = torch.max(prob, dim=1)
    # .item() so the NLL prints as a plain number rather than a tensor repr.
    print(f'NLL = {F.cross_entropy(logits, labels).item()}')
    print(f'accuracy = {(predicted == labels).sum().item() / labels.size(0)}')

gt_labels = labels.numpy()
confidences = prob.numpy()


In [None]:
# Keep untouched copies of the raw network outputs so the normalization cell
# can later be re-run from a clean state.
gt_labels_copy, confidences_copy = gt_labels.copy(), confidences.copy()

In [None]:
# Reset labels and confidences to the raw network outputs (undoing any float64
# conversion / renormalization applied to `confidences` afterwards).
gt_labels, confidences = gt_labels_copy.copy(), confidences_copy.copy()

In [None]:
#confidences are not perfectly normalized due to floating point error
#scipy.stats.dirichlet.pdf is very picky about normalization
#convert confidences to float64 first for better/more accurate normalization

def count_non_normalized(conf, tol=1e-10):
    """Count rows of ``conf`` whose sum deviates from 1 by at least ``tol``.

    Both directions count. The previous check ``(sum - 1) >= tol`` missed
    rows that sum to slightly less than 1.
    """
    return int(np.sum(np.abs(np.sum(conf, axis=1) - 1) >= tol))


print(f'num non normalized confidence vectors = {count_non_normalized(confidences)}')

confidences = np.array(confidences, dtype='float64')
confidences = confidences / np.sum(confidences, axis=1, keepdims=True)
print(f'num non normalized confidence vectors = {count_non_normalized(confidences)}')


# Fitting Fake Classifiers using MLE

Having obtained the confidences of our neural net on the CIFAR 10 test set, we can now try to fit an appropriate fake classifier to them. This can be done quite easily for the Dirichlet and Multi-Dirichlet FCs using MLE, as we have relatively simple expressions for their pdfs. The Multi-Gaussian FC is not as easy, as the softmax function complicates the fake classifier's pdf. (It would be necessary to invert the softmax function, which is only possible up to an additive constant. As a result, the Multi-Gaussian fake classifier's pdf is an integral over a Gaussian mixture model's pdf.)

MLE is probably the easiest and simplest approach when the fake classifier's pdf is known exactly. To this end we minimize the negative log likelihood of the neural net's predicted confidence vectors under assumption of either a Dirichlet or Multi-Dirichlet distribution.

As always with fitting, the choice of minimization algorithm, initial guesses and bounds is important:

As discussed at the start, a somewhat reasonable fake classifier can possibly be achieved by using a simple Dirichlet Fake Classifier with alpha parameters $\alpha_n < 1$. For fitting the DirichletFC, appropriate initial guesses and bounds might therefore be $\alpha_\text{init} = (1,1,1,1,...)$ and $\alpha_\min, \alpha_\max = (0.001, \text{None})$

As discussed at the start, the reasoning behind the Multi-Dirichlet FC is that each separate Dirichlet can be used to create a local maximum in one of the corners of the simplex. This only works if the full alpha vector of each Dirichlet has all entries $>1$ (if any entry is $<1$, a local maximum does not exist), which means that for each Dirichlet we need $\alpha_k >1$ and $\sigma_k>1$. We also expect the maxima to be very 'squished' into the corners, i.e. $\alpha_k$ to be large. For fitting the Multi-Dirichlet FC, appropriate initial guesses and bounds might therefore be $\alpha_\text{init} = (10,10,10,10,...)$, $\sigma_\text{init} = (2,2,2,2,...)$ (squished into the corners) and $\alpha_\min, \alpha_\max = (1, \text{None})$, $\sigma_\min, \sigma_\max = (1, \text{None})$

In [None]:
# Confidence distributions of the neural network itself: the marginals of
# classes 0 and 1 and the top-class confidence.
nnet_plot_targets = [0, 1, "top_class"]
NNet_eval = EvalStats(gt_labels, confidences)
NNet_eval.plot_confidence_distributions(nnet_plot_targets)

In [None]:
#Fit Dirichlet and Multi-Dirichlet FCs to the test set confidence vector distributions using MLE fitting

num_classes = confidences.shape[1]

Dir_FC = DirichletFC(num_classes)
MultiDir_FC = MultiDirichletFC(num_classes)


def Dir_NLL(parm):
    """Negative log-likelihood of the NN confidences under a Dirichlet FC with alpha = ``parm``."""
    return -np.sum(np.log(Dir_FC.pdf(confidences, parm)))


def MultiDir_NLL(parm):
    """Negative log-likelihood of the NN confidences under the Multi-Dirichlet FC.

    ``parm`` is the concatenation (alpha, sigma, distribution_weights), each of
    length ``num_classes``, and is split into these three parts in that order.
    """
    return -np.sum(np.log(MultiDir_FC.pdf(confidences, *np.split(parm, 3))))


def report_fit(name, result):
    """Print whether the optimizer reported convergence, so a failed fit is not used unnoticed."""
    status = 'converged' if result.success else 'did NOT converge'
    print(f'{name} fit {status}: {result.message} (NLL = {result.fun})')


#Simple Dirichlet: start from the uniform distribution on the simplex (all alpha = 1);
#alpha only has to stay positive, the small lower bound keeps the pdf well defined.
dirichlet_init_alpha = 1 * np.ones(num_classes)
dirichlet_bounds_alpha = [(0.001, None)] * num_classes

#With bounds and no constraints SciPy picks L-BFGS-B by default; stated explicitly here.
Dirichlet_bestfit = scipy.optimize.minimize(Dir_NLL, dirichlet_init_alpha, method='L-BFGS-B',
                                            bounds=dirichlet_bounds_alpha, options={'disp': True})
report_fit('Dirichlet', Dirichlet_bestfit)


#Multi-Dirichlet: alpha, sigma >= 1 so that every component has a local maximum,
#initialized squished into the corners (alpha = 10, sigma = 2) with uniform weights.
#NOTE(review): the distribution weights are only box-bounded to [0, 1], not constrained to
#sum to 1 -- confirm that MultiDirichletFC.pdf normalizes them, otherwise the fit can inflate them.
multidir_init_alpha = 10 * np.ones(num_classes)
multidir_init_sigma = 2 * np.ones(num_classes)
multidir_init_weights = np.ones(num_classes) / num_classes
multidir_bounds_alpha = [(1, None)] * num_classes
multidir_bounds_sigma = [(1, None)] * num_classes
multidir_bounds_weights = [(0, 1)] * num_classes

MultiDirichlet_bestfit = scipy.optimize.minimize(
    MultiDir_NLL,
    np.concatenate((multidir_init_alpha, multidir_init_sigma, multidir_init_weights)),
    method='L-BFGS-B',
    bounds=multidir_bounds_alpha + multidir_bounds_sigma + multidir_bounds_weights,
    options={'disp': True},
)
report_fit('Multi-Dirichlet', MultiDirichlet_bestfit)


In [None]:
#Set FC parameters to those found from MLE fit

# Transfer the MLE best-fit parameters into the two fake classifiers.
Dir_FC.set_alpha(Dirichlet_bestfit.x)
fitted_alpha, fitted_sigma, fitted_weights = np.split(MultiDirichlet_bestfit.x, 3)
MultiDir_FC.set_parameters(fitted_alpha, fitted_sigma, fitted_weights)

In [None]:
# Sample both fitted FCs and compare their confidence distributions with the neural net's.
NUM_FC_SAMPLES = 50000
NUM_MARGINALS_TO_PLOT = 3  # set to confidences.shape[1] to plot the marginals of every class

Dir_FC_eval = EvalStats(*Dir_FC.get_sample_arrays(NUM_FC_SAMPLES))
MultiDir_FC_eval = EvalStats(*MultiDir_FC.get_sample_arrays(NUM_FC_SAMPLES))
NNet_eval = EvalStats(gt_labels, confidences)

# Every group of plots is drawn in the order: Dirichlet FC, Multi-Dirichlet FC, neural net.
compared_evals = (Dir_FC_eval, MultiDir_FC_eval, NNet_eval)

for model_eval in compared_evals:
    model_eval.plot_confidence_distributions([0, "top_class"])

for class_idx in range(NUM_MARGINALS_TO_PLOT):
    for model_eval in compared_evals:
        model_eval.plot_confidence_distributions([class_idx])


Above we have plotted the top-class confidence distribution as well as the marginal confidence distributions for the actual neural network and our fitted fake classifiers (the full 10D confidence vector distribution is unfortunately a teeny bit difficult to visualise). This allows us to visually inspect how well our fake classifiers really reflect the true confidence distribution of our classifier. The 1st, 4th, ... graphs are the simple Dirichlet FC. The 2nd, 5th, ... graphs are the Multi-Dirichlet FC. The 3rd, 6th, ... graphs are the real neural network's confidence distributions.

As expected the simple dirichlet FC does not work well. It doesn't capture the multimodal nature of the true marginal distribution nor the high frequency of p=1.0 confidences in the top class confidence distributions.

The Multi-dirichlet FC seems to be a bit better capturing both of these qualities. However the maxima of the distributions don't lie far enough at the extremes, which is likely due to the fact that the local maximum of a dirichlet only lies exactly on the corner in the limit of $\alpha_k \rightarrow \infty$

TODO: Figure out why fitted sigma values always hit the lower bound of $\sigma_k = 1$

TODO: Split the fitting of the Multi-Dirichlet FC into fittings of individual Dirichlets. This assumes the neural network is confident enough that each confidence vector can be attributed to the component of its predicted class. The distribution weights would be estimated from the network's class prediction ratios. This would turn the fitting of 30 parameters of the Multi-Dirichlet FC into 10 fittings of 2 parameters each (only alpha and sigma; the distribution weights are estimated as described).

TODO: Try other fitting methods: e.g. moment matching or fitting only to the marginal distributions

TODO: Try 'stochastic fitting' for the Multi-Gaussian FC: sample the FC, calculate the corresponding marginal distributions, and minimize e.g. the squared error between the neural network's marginal distributions and those sampled from the FC. This could get around the problem of the Multi-Gaussian FC having an analytically intractable pdf.

In [None]:
# Best-fit parameters: Dirichlet alphas, then Multi-Dirichlet alpha / sigma / weights.
multi_alpha, multi_sigma, multi_weights = np.split(MultiDirichlet_bestfit.x, 3)
print(Dirichlet_bestfit.x)
print(multi_alpha, multi_sigma, multi_weights)