In [2]:
import pandas as pd
import numpy as np
from scipy.special import log_softmax, softmax, logsumexp
from scipy.optimize import minimize
import matplotlib.pyplot as plt

In [16]:

class IterativeCalibration:
    """Affine recalibration of log-probabilities: ``log_softmax(alpha * logprobs + beta)``.

    ``alpha`` is a scalar scale and ``beta`` a per-class bias vector of length
    ``num_classes``. They are fit by alternating updates (see :meth:`fit`):
    ``alpha`` on one labelled set and ``beta`` on another.
    """

    def __init__(self, num_classes, tolerance=1e-5, max_iter=100):
        """
        Args:
            num_classes: number of classes K; ``beta`` has shape (K,).
            tolerance: the loop stops once |delta alpha| + ||delta beta||_2 falls
                below this value; it is also passed as ``tol`` to
                ``scipy.optimize.minimize`` for the alpha step.
            max_iter: maximum number of alternating alpha/beta updates.
        """
        self.num_classes = num_classes
        self.tolerance = tolerance
        self.max_iter = max_iter
        self.alpha = 1
        self.beta = np.zeros(num_classes)

    def fit(self, alpha_train_logprobs, alpha_train_labels, beta_train_logprobs, beta_train_labels):
        """Alternate alpha and beta updates until the parameters stop moving.

        Args:
            alpha_train_logprobs: 2-D array (rows = samples, columns = classes)
                of log-probabilities used for the alpha step.
            alpha_train_labels: integer class index per row of ``alpha_train_logprobs``.
            beta_train_logprobs: 2-D log-probability array used for the beta step.
            beta_train_labels: integer class index per row of ``beta_train_logprobs``.

        Returns:
            self, so calls can be chained.
        """
        for _ in range(self.max_iter):
            new_alpha = self._next_alpha(alpha_train_logprobs, alpha_train_labels, self.alpha, self.beta)
            # The beta step already uses the freshly updated alpha.
            new_beta = self._next_beta(beta_train_logprobs, beta_train_labels, new_alpha, self.beta)
            step = self._param_change(new_alpha, new_beta)
            self.alpha = new_alpha
            self.beta = new_beta
            # Converged when the update itself is small (not when two updates are equally large).
            if step < self.tolerance:
                break
        return self

    def _param_change(self, new_alpha, new_beta):
        """Return |delta alpha| + ||delta beta||_2 between the current and new parameters.

        Entries of beta that are identical in both iterates count as zero
        change. This includes -inf biases for classes absent from the labels,
        which would otherwise produce NaN from ``-inf - -inf``.
        """
        alpha_change = float(np.max(np.abs(np.asarray(new_alpha, dtype=float) - self.alpha)))
        moved = new_beta != self.beta
        beta_delta = np.zeros(np.shape(new_beta), dtype=float)
        beta_delta[moved] = new_beta[moved] - self.beta[moved]
        return alpha_change + float(np.linalg.norm(beta_delta))

    def _next_alpha(self, logprobs, labels, alpha, beta):
        """Return the scale for which expected and empirical label log-likelihood agree.

        The objective is |soft_ce(a) - ce|, where ``ce`` is the mean negative
        log-probability of the true labels and ``soft_ce(a)`` is the same
        quantity taken in expectation under ``softmax(a * logprobs + beta)``.
        Their difference is the derivative of the calibrated cross-entropy
        with respect to ``a``, so its root is the CE-optimal scale for fixed ``beta``.
        Returns ``scipy.optimize.minimize(...).x``, a length-1 array.
        """
        # Independent of `a`, so compute it once rather than on every objective call.
        ce = -np.mean(logprobs[np.arange(len(logprobs)), labels])

        def compute_alpha_loss(a):
            calprobs = softmax(a * logprobs + beta, axis=1)
            soft_ce = -np.mean(np.sum(calprobs * logprobs, axis=1))
            return np.abs(soft_ce - ce)

        res = minimize(compute_alpha_loss, alpha, method='L-BFGS-B', tol=self.tolerance)
        return res.x

    def _next_beta(self, logprobs, labels, alpha, beta):
        """One fixed-point step for the class biases.

        beta_k <- log prior_k - log mean_n[ exp(alpha*z_nk) / sum_j exp(alpha*z_nj + beta_j) ].
        This is the stationarity condition "mean calibrated probability of
        class k == empirical frequency of k", evaluated entirely in log space.
        Classes with no labels get prior 0, so their bias is -inf.
        """
        # minlength keeps the priors aligned with beta even when the last classes are unlabeled.
        counts = np.bincount(labels, minlength=self.num_classes)
        with np.errstate(divide="ignore"):  # log(0) -> -inf is the intended value for absent classes
            logpriors = np.log(counts / len(labels))
        scores = alpha * logprobs
        # log of the per-row normalizer sum_j exp(alpha*z_nj + beta_j)
        log_norm = logsumexp(scores + beta, axis=1, keepdims=True)
        # log mean over rows of exp(alpha*z_nk - log_norm_n), without under/overflow
        logmean = logsumexp(scores - log_norm, axis=0) - np.log(len(logprobs))
        return logpriors - logmean

    def calibrate(self, logprobs):
        """Return calibrated log-probabilities ``log_softmax(alpha * logprobs + beta)``.

        Because log_softmax ignores per-row constant shifts, passing raw logits
        gives the same result as passing their log_softmax.
        """
        return log_softmax(self.alpha * logprobs + self.beta, axis=1)

# Experiment configuration. Alternative setting also used with this notebook:
#   dataset, size, test_list = "sst2", 128, "test_400"
dataset = "banking77"
size = 616
test_list = "test_1000"
seed = 2
run_dir = f"../outputs/adaptation/llama3.2-1b/lora_ans_instruct/{dataset}_{size}_0.3_{seed}/test={dataset}"


def load_split(run_dir, list_name):
    """Read ``logits.csv`` / ``labels.csv`` for one evaluation list of a run.

    Returns ``(logits, logprobs, labels)``, where ``logprobs`` is the row-wise
    log-softmax of ``logits`` and ``labels`` is a flat integer array.
    """
    split_dir = f"{run_dir}/list={list_name}"
    logits = pd.read_csv(f"{split_dir}/logits.csv", index_col=0, header=None).values.astype(float)
    labels = pd.read_csv(f"{split_dir}/labels.csv", index_col=0, header=None).values.astype(int).flatten()
    assert len(logits) == len(labels), f"{list_name}: {len(logits)} logit rows vs {len(labels)} labels"
    return logits, log_softmax(logits, axis=1), labels


train_logits, train_logprobs, train_labels = load_split(run_dir, f"train_{size}_0.0_{seed}")
val_logits, val_logprobs, val_labels = load_split(run_dir, f"val_{size}_0.3_{seed}")
test_logits, test_logprobs, test_labels = load_split(run_dir, test_list)

num_classes = train_logits.shape[1]
assert val_logits.shape[1] == num_classes and test_logits.shape[1] == num_classes, "class count differs across splits"

# alpha (scale) is fit on the validation split, beta (class biases) on the train split.
calibrator = IterativeCalibration(num_classes, tolerance=1e-5)
calibrator.fit(val_logprobs, val_labels, train_logprobs, train_labels)
# The calibrator was fit on log-probabilities, so apply it to log-probabilities as well.
# log_softmax is shift-invariant per row, so this equals passing raw logits.
test_calibrated_logprobs = calibrator.calibrate(test_logprobs)

# NCE = cross-entropy / cross-entropy of a prior-only predictor. The priors are
# estimated from the test labels themselves, so both rows share one normalizer.
test_idx = np.arange(len(test_labels))
test_ce_priors = -np.mean(np.log((np.bincount(test_labels) / len(test_labels))[test_labels]))
test_calibrated_ce_priors = test_ce_priors

test_ce = -np.mean(test_logprobs[test_idx, test_labels])
print(f"Test NCE: {test_ce/test_ce_priors:.4f}")

test_calibrated_ce = -np.mean(test_calibrated_logprobs[test_idx, test_labels])
print(f"Test calibrated NCE: {test_calibrated_ce/test_calibrated_ce_priors:.4f}")

Test NCE: 0.4165
Test calibrated NCE: 0.2400


In [11]:
# Baseline without post-hoc calibration: the `lora_ans_instruct_all_train` run
# (presumably trained without a held-out split, hence `_0.0_` - TODO confirm).
# NOTE: reuses dataset/size/seed/test_list from the calibration cell and
# deliberately overwrites its test_* variables.
all_train_dir = f"../outputs/adaptation/llama3.2-1b/lora_ans_instruct_all_train/{dataset}_{size}_0.0_{seed}/test={dataset}/list={test_list}"
test_logits = pd.read_csv(f"{all_train_dir}/logits.csv", index_col=0, header=None).values.astype(float)
test_logprobs = log_softmax(test_logits, axis=1)
test_labels = pd.read_csv(f"{all_train_dir}/labels.csv", index_col=0, header=None).values.astype(int).flatten()
assert len(test_logprobs) == len(test_labels), "logits and labels row counts differ"

# NCE normalizer: cross-entropy of predicting the empirical test-label priors.
test_ce = -np.mean(test_logprobs[np.arange(len(test_logprobs)), test_labels])
test_ce_priors = -np.mean(np.log((np.bincount(test_labels) / len(test_labels))[test_labels]))
print(f"Test NCE: {test_ce/test_ce_priors:.4f}")

Test NCE: 0.3862
