In [None]:
%load_ext autoreload
%autoreload 2
from kyle.sampling.fake_clf import DirichletFC, SufficientlyConfidentFC, SufficientlyConfidentFCBuilder
from kyle.evaluation import EvalStats
from kyle.transformations import *
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Number of classes of every fake classifier built in this notebook.
n_classes = 3
# Number of samples drawn per experiment to estimate accuracy and ECE.
n_samples = 15000
# Seed numpy's global RNG so that Restart & Run All reproduces the printed metrics and plots.
# NOTE(review): assumes kyle's samplers draw from numpy's global RNG — TODO confirm.
SEED = 42
np.random.seed(SEED)

## Sufficiently confident fake classifiers

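This model has the advantage that several quantities can be computed analytically. However, a separate
simplex automorphism has to be specified for each class, and these transformations can introduce spurious correlations between classes.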

In [None]:
# Baseline: a sufficiently confident fake classifier built with the builder's defaults
# (no `with_simplex_automorphisms` call).
sufficiently_confident_fc = SufficientlyConfidentFCBuilder(n_classes).build()
baseline_sample_arrays = sufficiently_confident_fc.get_sample_arrays(n_samples)
eval_stats = EvalStats(*baseline_sample_arrays)

# Accuracy should converge to (1 + 1/num_classes) / 2.
print(f"Accuracy is {eval_stats.accuracy()}")
# In this default case the classifier is perfectly calibrated, so the ECE converges to zero.
print(f"ECE is {eval_stats.expected_calibration_error()}")

baseline_classes_to_plot = [0, EvalStats.TOP_CLASS_LABEL]
eval_stats.plot_reliability_curves(baseline_classes_to_plot)


In [None]:
print("Overestimating classes 0 and 1")

# This experiment distorts classes 0 and 1, so it needs at least two classes.
assert n_classes >= 2, "this experiment requires n_classes >= 2"

# Build from `n_classes` directly instead of reading `sufficiently_confident_fc.num_classes`,
# so this cell no longer depends on the previous cell having been run.
overestimating_aut0 = MaxComponentSimplexAutomorphism(n_classes, lambda x: x**7)
overestimating_aut1 = MaxComponentSimplexAutomorphism(n_classes, lambda x: x**2)
# One entry per class. Presumably `None` means "no transformation" for that class — confirm.
# Padding with None keeps the list length equal to n_classes instead of hard-coding three entries.
class_automorphisms = [overestimating_aut0, overestimating_aut1] + [None] * (n_classes - 2)
overestimating_fc = (
    SufficientlyConfidentFCBuilder(n_classes)
    .with_simplex_automorphisms(class_automorphisms)
    .build()
)

eval_stats = EvalStats(*overestimating_fc.get_sample_arrays(n_samples))

print(f"Accuracy is {eval_stats.accuracy()}")
print(f"ECE is {eval_stats.expected_calibration_error()}")
# Use the class attribute, consistent with the other cells.
eval_stats.plot_reliability_curves([0, 1, EvalStats.TOP_CLASS_LABEL])


## Dirichlet fake classifiers

This model is not sufficiently confident but has reduced complexity. In particular, only a single simplex
automorphism has to be defined, and it is harder to produce the kind of spurious correlations between classes
that can arise in the sufficiently confident model.

On the downside, accuracy and ECE are hard to compute analytically...

In [None]:
# Dirichlet fake classifier shared by the cells below. Each of them calls `set_simplex_automorphism`
# on this same object before sampling; re-running this cell creates a fresh instance.
dirichlet_fc = DirichletFC(n_classes)

In [None]:
print("mostly overestimating all classes (starting at 1/n_classes)")

# Power-law exponent 2 for every class. np.full sizes the parameter vector from n_classes
# instead of hard-coding one entry per class, so changing n_classes keeps this cell valid.
overestimating_power_aut = PowerLawSimplexAutomorphism(np.full(n_classes, 2))
# Mutates the shared `dirichlet_fc`; the automorphism set here stays active until another cell replaces it.
dirichlet_fc.set_simplex_automorphism(overestimating_power_aut)
eval_stats = EvalStats(*dirichlet_fc.get_sample_arrays(n_samples))

print(f"Accuracy is {eval_stats.accuracy()}")
print(f"ECE is {eval_stats.expected_calibration_error()}")
eval_stats.plot_reliability_curves([0, 1, EvalStats.TOP_CLASS_LABEL])

In [None]:
print("mostly underestimating all classes (starting at 1/n_classes)")

# Power-law exponent 0.3 for every class. np.full sizes the parameter vector from n_classes
# instead of hard-coding one entry per class, so changing n_classes keeps this cell valid.
underestimating_power_aut = PowerLawSimplexAutomorphism(np.full(n_classes, 0.3))
# Mutates the shared `dirichlet_fc`; the automorphism set here stays active until another cell replaces it.
dirichlet_fc.set_simplex_automorphism(underestimating_power_aut)
eval_stats = EvalStats(*dirichlet_fc.get_sample_arrays(n_samples))

print(f"Accuracy is {eval_stats.accuracy()}")
print(f"ECE is {eval_stats.expected_calibration_error()}")
eval_stats.plot_reliability_curves([0, 1, EvalStats.TOP_CLASS_LABEL])


In [None]:
print("Overestimating predictions")

# Map p -> p / 2 through MaxComponentSimplexAutomorphism
# (by its name, presumably applied to the largest component only — confirm).
overestimating_predicted_class = MaxComponentSimplexAutomorphism(n_classes, lambda p: p / 2)
dirichlet_fc.set_simplex_automorphism(overestimating_predicted_class)
predicted_class_samples = dirichlet_fc.get_sample_arrays(n_samples)
eval_stats = EvalStats(*predicted_class_samples)

predicted_class_accuracy = eval_stats.accuracy()
print(f"Accuracy is {predicted_class_accuracy}")
predicted_class_ece = eval_stats.expected_calibration_error()
print(f"ECE is {predicted_class_ece}")
eval_stats.plot_reliability_curves([0, 1, EvalStats.TOP_CLASS_LABEL])

In [None]:
print("Overestimating class 0")


def _halve_component(value):
    """Return half of a single probability component."""
    return value / 2


# Only component `target_class` is passed to SingleComponentSimplexAutomorphism.
target_class = 0
overestimating_0 = SingleComponentSimplexAutomorphism(n_classes, target_class, _halve_component)
dirichlet_fc.set_simplex_automorphism(overestimating_0)
eval_stats = EvalStats(*dirichlet_fc.get_sample_arrays(n_samples))

class0_accuracy = eval_stats.accuracy()
print(f"Accuracy is {class0_accuracy}")
class0_ece = eval_stats.expected_calibration_error()
print(f"ECE is {class0_ece}")
eval_stats.plot_reliability_curves([target_class, 1, EvalStats.TOP_CLASS_LABEL])