In [14]:
import sys
import os
sys.path.append("../src")
import llm_utils
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, classification_report
import ast
import numpy as np
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import classification_report

# Topic labels of the classification task, one per line for easy diffing.
# The evaluation cells shown here pass ["War/Terror"] explicitly rather than
# reading this list.
classes = [
    "War/Terror",
    "Conspiracy Theory",
    "Education",
    "Election Campaign",
    "Environment",
    "Government/Public",
    "Health",
    "Immigration/Integration",
    "Justice/Crime",
    "Labor/Employment",
    "Macroeconomics/Economic Regulation",
    "Media/Journalism",
    "Religion",
    "Science/Technology",
]

def calculate_binary_metrics(df, classes, extraction_function,
                             pred_column_template="testbinary_war_v01"):
    """Compute per-class binary confusion matrices and classification reports.

    For each label in ``classes``:

    - The prediction column ``pred_column_template.format(label=label)`` is
      restricted to rows with a non-null raw value.
    - Those values are parsed with ``extraction_function``.
    - The parsed values are compared against ``int(label in annotations)``
      for each row.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain an ``annotations`` column plus the prediction column(s).
    classes : list of str
        Labels to evaluate.
    extraction_function : callable
        Applied to every non-null raw prediction. Results that are NaN/None
        are dropped; the remaining results must be castable to ``int``.
    pred_column_template : str, optional
        Name of the prediction column. It may contain ``{label}``, which is
        replaced by each class label (e.g. ``"{label}_pred"``). The default,
        ``"testbinary_war_v01"``, contains no placeholder, so the same column
        is used for every label, as in the previous version.

    Returns
    -------
    predictions_per_class : list
        One entry per label, in ``classes`` order. Each entry is the filtered
        DataFrame whose prediction column holds the extracted values, or
        ``None`` when the prediction column does not exist.
    confusion_matrices : dict
        Maps each label to ``confusion_matrix(y_true, y_pred)``, or to
        ``None`` when the column is missing or no row has a valid
        extracted prediction.
    classification_reports : dict
        Maps each label to ``classification_report(..., output_dict=True)``,
        or to ``None`` in the same cases as ``confusion_matrices``.

    Raises
    ------
    KeyError
        If ``df`` has no ``annotations`` column, or if the template contains
        a placeholder other than ``{label}``.
    TypeError
        If an ``annotations`` value does not support ``in`` (e.g. NaN).
        Previously this error was silently swallowed and yielded empty metrics.
    """
    predictions_per_class = []
    confusion_matrices = {}
    classification_reports = {}

    for label in classes:
        # str.format leaves a template without "{label}" unchanged, so the
        # default reproduces the original fixed column name.
        pred_column_name = pred_column_template.format(label=label)

        if pred_column_name not in df.columns:
            # Explicit replacement for the old KeyError/TypeError swallow,
            # which went on to score empty arrays.
            warnings.warn(
                f"Prediction column {pred_column_name!r} not found; "
                f"skipping metrics for {label!r}."
            )
            predictions_per_class.append(None)
            confusion_matrices[label] = None
            classification_reports[label] = None
            continue

        # Keep only rows that have a raw model output, then parse it.
        pred_column_df = df[df[pred_column_name].notna()].copy()
        pred_column_df[pred_column_name] = pred_column_df[pred_column_name].apply(extraction_function)
        predictions_per_class.append(pred_column_df)

        # The extraction function may return NaN/None for unparseable
        # outputs; those rows are excluded from scoring.
        valid_rows = pred_column_df[pred_column_name].notna()
        if not valid_rows.any():
            warnings.warn(
                f"No valid predictions in {pred_column_name!r} for {label!r}; "
                f"skipping metrics."
            )
            confusion_matrices[label] = None
            classification_reports[label] = None
            continue

        # NOTE(review): if `annotations` holds the *string* form of a list
        # (as read from CSV), `label in x` is a substring test. Confirm this
        # cannot produce false positives for the label set in use.
        y_true = pred_column_df.loc[valid_rows, "annotations"].apply(lambda x: int(label in x))
        y_pred = pred_column_df.loc[valid_rows, pred_column_name].astype(int)

        confusion_matrices[label] = confusion_matrix(y_true, y_pred)
        classification_reports[label] = classification_report(y_true, y_pred, output_dict=True)

    return predictions_per_class, confusion_matrices, classification_reports

# Score the results CSV of the `vicuna_4bit/lora/binary_war_v01` run on the
# single "War/Terror" label, using the locally defined calculate_binary_metrics.
binary_war_lora_df = pd.read_csv(
    filepath_or_buffer="../data/vicuna_4bit/lora/binary_war_v01/test_generic_test_0.csv"
)
extraction_function = llm_utils.get_extraction_function("extract_using_class_token", 1)
_, confusion_matrices, classification_reports = calculate_binary_metrics(
    df=binary_war_lora_df,
    classes=["War/Terror"],
    extraction_function=extraction_function,
)
binary_war_lora = dict(
    confusion_matrices=confusion_matrices,
    classification_reports=classification_reports,
)

In [19]:
# Display the War/Terror confusion matrix and classification report for the LoRA run.
binary_war_lora

{'confusion_matrices': {'War/Terror': array([[711,  34],
         [ 22, 233]])},
 'classification_reports': {'War/Terror': {'0': {'precision': 0.9699863574351978,
    'recall': 0.9543624161073826,
    'f1-score': 0.9621109607577808,
    'support': 745},
   '1': {'precision': 0.8726591760299626,
    'recall': 0.9137254901960784,
    'f1-score': 0.89272030651341,
    'support': 255},
   'accuracy': 0.944,
   'macro avg': {'precision': 0.9213227667325802,
    'recall': 0.9340439531517305,
    'f1-score': 0.9274156336355954,
    'support': 1000},
   'weighted avg': {'precision': 0.9451679261768628,
    'recall': 0.944,
    'f1-score': 0.9444163439254663,
    'support': 1000}}}}

In [17]:
# Score the results CSV of the
# `vicuna_4bit/generic_prompt_without_context_only_classification` run on
# the single "War/Terror" label.
# NOTE(review): this cell uses llm_utils.calculate_binary_metrics, not the
# local calculate_binary_metrics defined in this notebook. Its report covers
# 129 rows, versus 1000 for the LoRA run; confirm the two results are
# comparable before contrasting them.
binary_war_df = pd.read_csv(
    filepath_or_buffer="../data/vicuna_4bit/generic_prompt_without_context_only_classification/generic_test_0.csv"
)
extraction_function = llm_utils.get_extraction_function("extract_using_class_token", 1)
(
    binary_war_predictions_per_class,
    confusion_matrices,
    classification_reports,
) = llm_utils.calculate_binary_metrics(binary_war_df, ["War/Terror"], extraction_function)
binary_war = dict(
    confusion_matrices=confusion_matrices,
    classification_reports=classification_reports,
)

In [18]:
# Display the War/Terror metrics for the generic-prompt run.
# NOTE(review): support here is 129 rows, versus 1000 in the LoRA output, so
# the two runs were not scored on the same rows. Confirm before comparing.
binary_war

{'confusion_matrices': {'War/Terror': array([[54, 11],
         [13, 51]])},
 'classification_reports': {'War/Terror': {'0': {'precision': 0.8059701492537313,
    'recall': 0.8307692307692308,
    'f1-score': 0.8181818181818182,
    'support': 65},
   '1': {'precision': 0.8225806451612904,
    'recall': 0.796875,
    'f1-score': 0.8095238095238094,
    'support': 64},
   'accuracy': 0.813953488372093,
   'macro avg': {'precision': 0.8142753972075109,
    'recall': 0.8138221153846155,
    'f1-score': 0.8138528138528138,
    'support': 129},
   'weighted avg': {'precision': 0.8142110154404273,
    'recall': 0.813953488372093,
    'f1-score': 0.8138863720259069,
    'support': 129}}}}