In [None]:
%load_ext autoreload
%autoreload 2
from kale.sampling.fake_clf import DirichletFC, SufficientlyConfidentFC, SufficientlyConfidentFCBuilder
from kale.evaluation import EvalStats
from kale.transformations import *
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Experiment size: number of classes and number of samples to draw.
n_classes, n_samples = 3, 15000

# Create the fake classifier through its builder; no builder options are set.
sufficiently_confident_fc = SufficientlyConfidentFCBuilder(n_classes).build()

# The arrays returned by get_sample_arrays are unpacked positionally into EvalStats
# (presumably ground-truth labels and predicted confidences — TODO confirm order).
eval_stats = EvalStats(
    *sufficiently_confident_fc.get_sample_arrays(n_samples)
)

In [None]:
# Sanity check on the sampled accuracy. Per the original note it should converge to
# 1/2 * (1 + 1/n_classes) as n_samples grows; reading that as (1 + 1/n_classes) / 2,
# the limit is 2/3 for n_classes = 3 (TODO confirm the intended formula).
eval_stats.accuracy()

In [None]:
# Line plots are used instead of a bar plot: bars behaved very strangely with float
# x-values (probably because plt.bar's default width of 0.8 is much wider than the
# 1/bins spacing used here — unverified).
# Zero values in the curves do not contribute to any calibration metric; there are
# simply no data points in those regions.

# The default case is perfectly calibrated, so the ECE converges to zero.
# Reference diagonal y = x, sampled at k / bins for k = 0 .. bins - 1.
diagonal = np.arange(eval_stats.bins) / eval_stats.bins
plt.plot(diagonal, diagonal, marker="o", label="perfect calibration (y = x)")
plt.plot(*eval_stats.top_class_reliability_hist(), marker="o", label="top class reliability")
# The argument 0 presumably selects the class index — TODO confirm.
plt.plot(*eval_stats.marginal_reliability_hist(0), marker="o", label="marginal reliability (0)")
# Without a legend the three overlaid curves cannot be told apart.
plt.legend()
print(f"ECE is {eval_stats.expected_calibration_error()}")

In [None]:
# Two max-component automorphisms built from power maps: x**7 and x**2.
# (Named "overestimating" by the author; the exact effect on the confidence
# vectors is defined in kale.transformations.)
overestimating_aut0 = MaxComponentSimplexAutomorphism(
    sufficiently_confident_fc.num_classes,
    lambda p: p ** 7,
)
overestimating_aut1 = MaxComponentSimplexAutomorphism(
    sufficiently_confident_fc.num_classes,
    lambda p: p ** 2,
)

# Register them under keys 0 and 1 (presumably the class indices — TODO confirm).
# Note: this mutates sufficiently_confident_fc in place, so samples drawn after
# this cell differ from those drawn before it.
sufficiently_confident_fc.set_simplex_automorphism(0, overestimating_aut0)
sufficiently_confident_fc.set_simplex_automorphism(1, overestimating_aut1)

In [None]:
# Draw a fresh sample from the (now modified) classifier and wrap it in EvalStats.
eval_stats_overestimating = EvalStats(
    *sufficiently_confident_fc.get_sample_arrays(n_samples)
)

In [None]:
# Report ECE first, then accuracy, one value per line.
print(
    eval_stats_overestimating.expected_calibration_error(),
    eval_stats_overestimating.accuracy(),
    sep="\n",
)

In [None]:
# Reference diagonal y = x, sampled at k / bins for k = 0 .. bins - 1.
diagonal = np.arange(eval_stats_overestimating.bins) / eval_stats_overestimating.bins
plt.plot(diagonal, diagonal, marker="o", label="perfect calibration (y = x)")
# Marginal reliability curves for arguments 0 and 1 — the same keys that received
# a simplex automorphism (presumably class indices — TODO confirm).
plt.plot(*eval_stats_overestimating.marginal_reliability_hist(0), marker="o", label="marginal reliability (0)")
plt.plot(*eval_stats_overestimating.marginal_reliability_hist(1), marker="o", label="marginal reliability (1)")
# Label the curves; the trailing semicolon keeps the Legend repr out of the cell output.
plt.legend();

In [None]:
# Second fake classifier: DirichletFC over the same number of classes.
# NOTE: the spelling `diriclet_fc` is kept as-is because other cells may refer to it.
diriclet_fc = DirichletFC(n_classes)

# Sample from it and wrap the resulting arrays in EvalStats.
dirichlet_eval_stats = EvalStats(
    *diriclet_fc.get_sample_arrays(n_samples)
)


In [None]:
# Report accuracy first, then ECE, one value per line.
print(
    dirichlet_eval_stats.accuracy(),
    dirichlet_eval_stats.expected_calibration_error(),
    sep="\n",
)

In [None]:
# Reference diagonal y = x, sampled at k / bins for k = 0 .. bins - 1.
diagonal = np.arange(dirichlet_eval_stats.bins) / dirichlet_eval_stats.bins
plt.plot(diagonal, diagonal, marker="o", label="perfect calibration (y = x)")
plt.plot(*dirichlet_eval_stats.top_class_reliability_hist(), marker="o", label="top class reliability")
# Label the curves; the trailing semicolon keeps the Legend repr out of the cell output.
plt.legend();