In [None]:
import numpy as np
from scipy.optimize import minimize
from sklearn.metrics import log_loss
from sklearn.model_selection import KFold
import matplotlib.pyplot as plt
from sklearn.calibration import calibration_curve

# Find the softmax temperature that minimises the log loss.

# Replace these with your data: per-class logits (one column per class) and
# the matching integer class labels (0..K-1).
all_train_logits = all_train_logits_df.to_numpy()
y_train = y_train_df.to_numpy().reshape(-1)



def temperature_scaled_softmax(logits, temperature):
    """Return softmax(logits / temperature) taken over the last axis.

    The per-row maximum is subtracted before exponentiating so large logits
    cannot overflow; the shift cancels out in the normalisation.
    """
    z = logits / temperature
    z = z - np.max(z, axis=-1, keepdims=True)
    numer = np.exp(z)
    denom = np.sum(numer, axis=-1, keepdims=True)
    return numer / denom

def stratified_brier_score_per_class(y_true, y_prob):
    """Return one class-balanced Brier score per column of ``y_prob``.

    For class ``i`` the score is the mean of two terms:
      * the mean of ``(1 - y_prob[:, i])**2`` over samples with ``y_true == i``;
      * the mean of ``y_prob[:, i]**2`` over samples with ``y_true != i``.
    A stratum containing no samples contributes 0.
    """
    def _stratum_mean_sq(values, count):
        # Mean of squared values over one stratum; an empty stratum scores 0.
        return np.sum(values ** 2) / count if count > 0 else 0

    def _score_for(cls):
        hit = y_true == cls
        miss = y_true != cls
        pos_term = _stratum_mean_sq(1 - y_prob[hit, cls], np.sum(hit))
        neg_term = _stratum_mean_sq(y_prob[miss, cls], np.sum(miss))
        return (pos_term + neg_term) / 2

    return [_score_for(cls) for cls in range(y_prob.shape[1])]

def optimize_temperature(logits, labels, *, x0=1.5, bounds=(0.1, 10.0), classes=None):
    """Return the softmax temperature that minimises multiclass log loss.

    Parameters
    ----------
    logits : array-like, shape (n_samples, n_classes)
        Uncalibrated scores; the softmax is taken over the last axis.
    labels : array-like, shape (n_samples,)
        True class of each row of ``logits``.
    x0 : float, optional
        Starting temperature for the optimiser (default 1.5).
    bounds : tuple of (float, float), optional
        Lower and upper limits on the temperature (default (0.1, 10.0)).
    classes : array-like, optional
        Class label of each column of ``logits``. Defaults to
        ``0 .. n_classes - 1``, the integer coding this script assumes
        elsewhere (labels are compared with ``== i``).

    Returns
    -------
    float
        The temperature found by L-BFGS-B within ``bounds``.
    """
    if classes is None:
        classes = np.arange(np.shape(logits)[-1])

    def loss_to_minimize(temperature):
        # L-BFGS-B passes a length-1 array; it broadcasts over the logits.
        probs = temperature_scaled_softmax(logits, temperature)
        # An explicit class list pins the column <-> class mapping. Without
        # it sklearn infers the classes from ``labels`` alone and raises when
        # a CV training split happens to be missing a class.
        return log_loss(labels, probs, labels=classes)

    result = minimize(loss_to_minimize, x0=x0, bounds=[tuple(bounds)], method='L-BFGS-B')
    return result.x[0]


# Cross-validate the temperature. Each fold's temperature is fitted on that
# fold's training split and applied only to its held-out split, so every
# calibrated probability below is out-of-sample.
# NOTE(review): KFold ignores class balance; with rare classes consider
# StratifiedKFold so that every fold sees every class.
n_splits = 5
kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)

optimal_temperatures = []

# Force a floating dtype: np.zeros_like would inherit an integer dtype from
# integer-valued logits and truncate every probability written into it to 0.
out_of_sample_calibrated_probs = np.zeros_like(all_train_logits, dtype=float)

print(f"--- Starting Corrected {n_splits}-Fold Cross-Validation ---")

for fold, (train_index, val_index) in enumerate(kf.split(all_train_logits)):
    train_logits, val_logits = all_train_logits[train_index], all_train_logits[val_index]
    train_labels, val_labels = y_train[train_index], y_train[val_index]

    # Fit on the training split only.
    optimal_temp = optimize_temperature(train_logits, train_labels)
    optimal_temperatures.append(optimal_temp)
    print(f"Fold {fold+1}/{n_splits} | Optimal Temperature found on TRAIN set: {optimal_temp:.4f}")

    # Calibrate the held-out split and scatter it into the out-of-fold buffer.
    calibrated_val_probs = temperature_scaled_softmax(val_logits, optimal_temp)
    out_of_sample_calibrated_probs[val_index] = calibrated_val_probs

print("\n--- Cross-Validation Complete ---")

## Final robust metrics and analysis ##

print("\n--- Final Metrics (Calculated from Robust Out-of-Sample Predictions) ---")

# The fold-averaged temperature is a single value that can later be applied
# to the full set of training logits (or to new logits).
best_overall_temperature = np.mean(optimal_temperatures)
print(f"Average Optimal Temperature across all folds: {best_overall_temperature:.4f}")

n_classes = out_of_sample_calibrated_probs.shape[1]

# One-vs-rest log loss per class. ``labels=[0, 1]`` stops sklearn raising
# when a class is absent from y_train, or when every sample belongs to it.
final_log_losses = [
    log_loss((y_train == i).astype(int), out_of_sample_calibrated_probs[:, i], labels=[0, 1])
    for i in range(n_classes)
]
# Explicit labels keep the columns aligned with classes 0..K-1.
final_log_loss_overall = log_loss(
    y_train, out_of_sample_calibrated_probs, labels=np.arange(n_classes)
)
final_stratified_brier_scores = stratified_brier_score_per_class(y_train, out_of_sample_calibrated_probs)

print(f"Overall Log-loss: {final_log_loss_overall:.4f}")
print(f"Overall Mean Stratified Brier Score: {np.mean(final_stratified_brier_scores):.4f}")

# These per-class metrics come from the out-of-fold predictions: each sample
# was calibrated with its own fold's temperature, not with the average above.
print('Final metrics with best temperature:')
for i, (class_log_loss, class_brier) in enumerate(zip(final_log_losses, final_stratified_brier_scores)):
    print(f'Class {i}:')
    print(f'    Log-loss: {class_log_loss:.4f}')
    print(f'    Stratified Brier score: {class_brier:.4f}')



# Optional: draw a reliability diagram of the calibrated probabilities.
def plot_calibration_curve(y_true, probs, title, n_bins=10):
    """Plot one one-vs-rest reliability curve per class column of ``probs``.

    Parameters
    ----------
    y_true : array-like, shape (n_samples,)
        Integer class labels; class ``i`` corresponds to ``probs[:, i]``.
    probs : ndarray, shape (n_samples, n_classes)
        Predicted class probabilities.
    title : str
        Figure title.
    n_bins : int, optional
        Number of bins passed to ``calibration_curve`` (default 10, the value
        that was previously hard-coded).
    """
    # Convert first so that ``y_true == i`` is an element-wise mask even for a
    # plain list (``list == int`` would just evaluate to False).
    y_true = np.asarray(y_true)

    plt.figure(figsize=(12, 8))
    for i in range(probs.shape[1]):
        # Binary target: class i versus all other classes.
        prob_true, prob_pred = calibration_curve(y_true == i, probs[:, i], n_bins=n_bins)
        plt.plot(prob_pred, prob_true, marker='o', label=f'Class {i}')

    # Reference diagonal: perfectly calibrated predictions lie on y = x.
    plt.plot([0, 1], [0, 1], linestyle='--', color='gray', label='Perfectly Calibrated')
    plt.xlabel('Mean Predicted Probability', fontweight='bold', fontsize=24)
    plt.ylabel('Fraction of Positives', fontweight='bold', fontsize=24)
    plt.title(title, fontweight='bold', fontsize=24)
    plt.xticks(fontsize=24, fontweight='bold')
    plt.yticks(fontsize=24, fontweight='bold')
    plt.legend(fontsize=15)
    plt.grid(True)
    plt.show()

# Draw the reliability diagram for the out-of-fold calibrated probabilities.
plot_calibration_curve(
    y_true=y_train,
    probs=out_of_sample_calibrated_probs,
    title='Calibration Curve (After Temperature Scaling)',
)