In [33]:
import csv
import pandas as pd
from sklearn.metrics import log_loss


In [65]:
# Run configuration: which model's results to score and which prompt chain
# produced them. Both values are interpolated into the result/eval paths.
MODEL = "Mixtral-8x7B-Instruct-v0.1"
CHAIN_NAME = "base_probs"

In [66]:
# Input CSV holding the chain's predictions, and the output CSV that will
# receive the computed metrics. Kept as plain strings (not Path objects).
results_path, eval_path = (
    f"../results/{MODEL}/output/{CHAIN_NAME}.csv",
    f"../results/{MODEL}/eval/{CHAIN_NAME}_eval.csv",
)

In [67]:
# True when the chain name contains 'probs'.
# NOTE(review): this flag is not referenced by any other visible cell.
probabilities = 'probs' in CHAIN_NAME

In [68]:
# Zero-initialise the module-level metric names so they exist before scoring.
logloss = f1 = recall = precision = roc_auc = 0

def rate(path, threshold=.5):
    """Load a predictions CSV and score it against the ground truth.

    The CSV must contain a ``pred`` and a ``truth`` column. The row count of
    the file is printed before scoring.

    * If ``'probs'`` occurs anywhere in ``path``, ``pred`` is treated as a
      numeric score: the log loss is printed (it is not returned) and the
      scores are binarised with ``pred >= threshold`` before evaluation.
    * Otherwise ``pred`` is treated as a label column; only rows whose value
      is exactly ``True`` / ``False`` (as text or as a parsed boolean) are
      kept and scored, all other rows are dropped.

    Args:
        path: Location of the CSV file (string).
        threshold: Cut-off applied to numeric predictions; ignored for label
            predictions.

    Returns:
        The ``(f1, recall, precision, roc_auc)`` tuple produced by
        ``evaluate``.
    """
    df = pd.read_csv(path)
    print(len(df))

    if 'probs' in path:
        scores = df['pred']
        truth = df['truth'].astype(int)
        logloss = log_loss(truth, scores)
        print(f'Log loss: {logloss}')
        # NaN scores compare as False, i.e. they count as negative predictions.
        pred = scores >= threshold
        return evaluate(pred, truth)

    # Map the textual labels explicitly: ``astype(bool)`` on strings turns
    # every non-empty string -- including 'False' -- into True. Going through
    # ``astype(str)`` first also covers a column that read_csv already parsed
    # as booleans.
    labels = df['pred'].astype(str).map({'True': True, 'False': False})
    is_valid = labels.notna()
    pred = labels[is_valid].astype(bool)
    truth = df.loc[is_valid, 'truth'].astype(int)
    # The metrics must be returned; callers unpack the four values.
    return evaluate(pred, truth)

In [69]:
# Score the configured results file, binarising numeric predictions at 0.7.
f1, recall, precision, roc_auc = rate(path=results_path, threshold=0.7)

498
Log loss: 5.496333391694868
F1 Score: 0.6926536731634183
Recall: 0.5620437956204379
Precision: 0.90234375
ROC AUC: 0.6373437368906788


In [64]:
# Persist the four classification metrics as (Metric, Value) rows in the eval CSV.
with open(eval_path, 'w', newline='') as metrics_csvfile:
    fieldnames_metrics = ['Metric', 'Value']
    writer_metrics = csv.DictWriter(metrics_csvfile, fieldnames=fieldnames_metrics)
    writer_metrics.writeheader()
    # One row per metric, in the same order the metrics are printed.
    for metric_label, metric_value in (
        ('F1 Score', f1),
        ('Recall', recall),
        ('Precision', precision),
        ('ROC AUC', roc_auc),
    ):
        writer_metrics.writerow({'Metric': metric_label, 'Value': metric_value})

In [40]:
from sklearn.metrics import f1_score, recall_score, precision_score, roc_auc_score

def evaluate(pred, truth):
    """Compute and print binary-classification metrics.

    Args:
        pred: Predicted labels, passed to the sklearn metrics as ``y_pred``
            (and to ``roc_auc_score`` as the score argument).
        truth: Ground-truth labels, passed as ``y_true``.

    Returns:
        Tuple ``(f1, recall, precision, roc_auc)``. ``roc_auc`` is ``None``
        when ``roc_auc_score`` raises; the exception is printed, not re-raised.
    """
    f1_value = f1_score(truth, pred)
    recall_value = recall_score(truth, pred)
    precision_value = precision_score(truth, pred)

    # ROC AUC is best-effort: report the failure and carry on with None.
    try:
        auc_value = roc_auc_score(truth, pred)
    except Exception as err:
        print(err)
        auc_value = None

    results = (f1_value, recall_value, precision_value, auc_value)

    # Report every metric only after all of them have been computed.
    for label, value in zip(("F1 Score", "Recall", "Precision", "ROC AUC"), results):
        print(f"{label}: {value}")
    return results