In [None]:
import pickle
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from sklearn.calibration import CalibrationDisplay

from lllm.classification_utils import Classifier, create_datasets_paired_questions as create_datasets
from simpleFacts import SimpleFacts

In [None]:
# Seed for every stochastic step in this notebook. The single RandomState below is
# passed to both `create_datasets` (train/test split) and `Classifier`, so
# "Restart & Run All" reproduces the same split and the same fitted model.
SEED = 42
rng = np.random.RandomState(SEED)

In [None]:
# Project dataset wrapper; the next cell pulls its log-prob and prob difference tables from it.
simple_facts = SimpleFacts()

In [None]:
# Log-prob and prob difference tables for the SimpleFacts questions (fetched in this order).
# `return_setup=True` presumably adds the per-row setup column(s), e.g.
# "llama-2-7b-chat_probes_setup", which the next cell reads — confirm against SimpleFacts.
logprobs_dataset, probs_dataset = (
    simple_facts.get_logprobs_differences(return_setup=True),
    simple_facts.get_probs_differences(return_setup=True),
)

In [None]:
# Column whose entries are per-row setup mappings holding the lie/truth instructions.
# NOTE(review): hard-coded to the llama-2-7b-chat setup column — update if the model changes.
SETUP_COLUMN = "llama-2-7b-chat_probes_setup"

# Lift the lie and truth instructions out of each setup entry into their own columns
# (added in this order: "lie_instruction", then "truth_instruction").
for instruction_key in ("lie_instruction", "truth_instruction"):
    logprobs_dataset[instruction_key] = [
        setup[instruction_key] for setup in logprobs_dataset[SETUP_COLUMN]
    ]

In [None]:
# Distinct lie / truth instruction values present in the dataset.
# These are array-likes returned by `Series.unique()`, not Python lists, despite the `_list` suffix.
lie_instructions_list, truth_instructions_list = (
    logprobs_dataset[instruction_column].unique()
    for instruction_column in ("lie_instruction", "truth_instruction")
)

In [None]:
# Convert every entry of the first two columns of both tables to a NumPy array.
# Order matches the original cell: logprobs col 0, col 1, then probs col 0, col 1.
# NOTE(review): columns are selected by position, which assumes the per-question feature
# columns come first. The instruction columns added above are appended at the end, so they are not touched.
for frame in (logprobs_dataset, probs_dataset):
    for col_idx in (0, 1):
        frame.iloc[:, col_idx] = frame.iloc[:, col_idx].apply(lambda x: np.array(x))

In [None]:
# Train/test split via `create_datasets_paired_questions` (imported as `create_datasets`),
# seeded by the shared `rng`. Unpacking order must match that function's return order.
(
    X_train_logprobs, X_test_logprobs,
    train_instructions, test_instructions,
    train_datasets, test_datasets,
    X_train_probs, X_test_probs,
    y_train, y_test,
) = create_datasets(logprobs_dataset, probs_dataset, rng=rng)

In [None]:
# Fit `Classifier` on the log-prob training features, then score it on the held-out split.
log_reg_classifier = Classifier(X_train_logprobs, y_train, random_state=rng)
accuracy, auc, conf_matrix = log_reg_classifier.evaluate(X_test_logprobs, y_test)

# One line per metric (same stdout as three separate print calls).
print(
    f"Accuracy: {accuracy}",
    f"AUC: {auc}",
    f"Confusion matrix: {conf_matrix}",
    sep="\n",
)

In [None]:
# Test-set predicted probabilities and their calibration (reliability) curve, 20 quantile bins.
# NOTE(review): CalibrationDisplay expects a 1-D positive-class probability. Confirm that
# Classifier.predict_proba returns that rather than an (n_samples, 2) sklearn-style matrix.
y_pred_proba = log_reg_classifier.predict_proba(X_test_logprobs)
calibration_display = CalibrationDisplay.from_predictions(
    y_test,
    y_pred_proba,
    n_bins=20,
    name="LogReg",
    strategy="quantile",
)

In [None]:
# Distribution of the test-set predicted probabilities (20 bins).
# Use an explicit new figure/axes. Bare `plt.hist` draws on the *current* axes, which is the
# calibration plot's axes whenever figures aren't flushed between cells (e.g. run as a script).
fig, ax = plt.subplots()
ax.hist(y_pred_proba, bins=20)
ax.set(
    title="LogReg predicted probabilities (test set)",
    xlabel="Predicted probability",
    ylabel="Count",
)
plt.show()

In [None]:
# Persist the trained classifier so later notebooks can reload it without retraining.
# Create the output directory first; otherwise `open` raises FileNotFoundError on a fresh checkout.
# NOTE: only unpickle this file from a trusted source, because pickle can execute arbitrary code on load.
classifier_path = Path("trained_classifiers") / "logistic_logprobs_classifier.pkl"
classifier_path.parent.mkdir(parents=True, exist_ok=True)
with classifier_path.open("wb") as f:
    pickle.dump(log_reg_classifier, f)