# Calibration curve

In [16]:
import numpy as np

# Machine Learning
from sklearn.datasets import make_classification
from sklearn.calibration import calibration_curve
from sklearn.ensemble import RandomForestClassifier
from sklearn.isotonic import IsotonicRegression
from sklearn.linear_model import LogisticRegression

# Data visualization
import matplotlib.pyplot as plt

In [None]:
def expected_calibration_error(y, proba, bins='fd'):
    """Compute the expected calibration error (ECE) of predicted probabilities.

    The predictions are grouped into bins by their predicted probability. For
    each bin the absolute gap between the mean predicted probability and the
    observed fraction of positives is computed, and the gaps are averaged with
    weights proportional to the number of samples in each bin:

        ECE = sum_b (n_b / N) * |mean(y_b) - mean(proba_b)|

    Parameters
    ----------
    y : array-like of shape (n_samples,)
        True binary labels (0/1 or booleans).
    proba : array-like of shape (n_samples,)
        Predicted probability of the positive class for each sample.
    bins : int, sequence of scalars or str, default='fd'
        Binning specification, forwarded unchanged to
        ``numpy.histogram_bin_edges`` ('fd' is the Freedman-Diaconis rule).

    Returns
    -------
    float
        The expected calibration error; 0.0 means perfectly calibrated within
        the chosen bins.

    Raises
    ------
    ValueError
        If ``y`` and ``proba`` differ in length or are empty.
    """
    y = np.asarray(y, dtype=float).ravel()
    proba = np.asarray(proba, dtype=float).ravel()

    if y.shape != proba.shape:
        raise ValueError(
            f"y and proba must have the same length, got {y.size} and {proba.size}"
        )
    if y.size == 0:
        raise ValueError("y and proba must contain at least one sample")

    edges = np.histogram_bin_edges(proba, bins=bins)
    n_bins = len(edges) - 1

    # Locate each sample among the *interior* edges only: this yields indices in
    # [0, n_bins - 1] and keeps the right-most edge inclusive, like np.histogram.
    bin_index = np.searchsorted(edges[1:-1], proba, side='right')

    # Per-bin sums; |sum(y_b) - sum(proba_b)| equals n_b * |mean(y_b) - mean(proba_b)|,
    # so empty bins contribute exactly zero and need no special-casing.
    positives_per_bin = np.bincount(bin_index, weights=y, minlength=n_bins)
    confidence_per_bin = np.bincount(bin_index, weights=proba, minlength=n_bins)

    return float(np.abs(positives_per_bin - confidence_per_bin).sum() / y.size)

In [4]:
# Synthetic binary-classification problem: 15,000 rows, 50 features
# (30 informative + 20 redundant), with a 90% / 10% class imbalance.
X, y = make_classification(
    n_samples=15000,
    n_features=50,
    n_informative=30,
    n_redundant=20,
    weights=[0.9, 0.1],
    random_state=0,
)

# Cut the rows at 5,000 and 10,000 to get three consecutive partitions:
# train = first 5,000, validation = next 5,000, test = the remainder.
X_train, X_valid, X_test = np.split(X, [5000, 10000])
y_train, y_valid, y_test = np.split(y, [5000, 10000])

#### Random Forest Classifier

In [13]:
# Train a Random Forest Classifier on the training split.
# random_state is pinned (same seed as the make_classification call) so that the
# forest's bootstrap sampling and feature sub-sampling -- and therefore every
# probability computed from it -- are reproducible after a kernel restart.
forest = RandomForestClassifier(random_state = 0).fit(X_train, y_train)

print(f"Random Forest Classifier classes = {forest.classes_}")

# predict_proba has one column per entry of forest.classes_ ([0 1]), so
# column 1 holds the positive-class probability; computed on the validation split.
proba_valid = forest.predict_proba(X_valid)[:, 1]

Random Forest Classifier classes = [0 1]


#### Isotonic Regression

In [14]:
# Calibrator 1: isotonic regression, a monotonic mapping fitted from the forest's
# validation positive-class probabilities to the true validation labels.
# Outputs are bounded to [0, 1]; inputs outside the fitted range are clipped.
iso_reg = IsotonicRegression(y_min=0, y_max=1, out_of_bounds='clip')
iso_reg.fit(proba_valid, y_valid)

# Calibrated test-set probabilities: forest score first, then the isotonic map
proba_test_forest_isoreg = iso_reg.predict(
    forest.predict_proba(X_test)[:, 1]
)

#### Logistic Regression

In [10]:
# Calibrator 2: a logistic regression whose single input feature is the forest's
# positive-class probability on the validation split.
# The feature is given an explicit column axis so it is passed as a 2-D (n, 1) matrix.
log_reg = LogisticRegression()
log_reg.fit(proba_valid[:, np.newaxis], y_valid)

# Calibrated test-set probabilities: forest score first, then the logistic map
# (column 1 of predict_proba is the positive class).
proba_test_forest_logreg = log_reg.predict_proba(
    forest.predict_proba(X_test)[:, 1][:, np.newaxis]
)[:, 1]

array([0, 1])

In [9]:
# Display the validation feature matrix (the repr of this array is abbreviated with "...")
X_valid

array([[ -0.86181522, -10.16068478,   4.79682908, ...,   3.86521852,
         -2.75142771,  -1.15713864],
       [  2.84748018,  -0.94559916,  -3.76606529, ...,   0.35651591,
        -11.27550557,   0.7618289 ],
       [  1.79590375,   6.62760703,   3.46470833, ...,  -4.59172565,
         -6.7055294 ,   1.59135773],
       ...,
       [ -5.52164103,   9.82889974,  -5.55077572, ...,  -6.01953906,
          2.58728136,   1.25963628],
       [ -5.78350537, -10.0472246 ,  -3.04950453, ...,  -1.6825213 ,
         -3.84211095,  -1.37807477],
       [ -2.30487826,  -0.9325564 ,  -0.7018942 , ...,  -1.28766483,
         16.02023276,   0.63219399]])