In [1]:
import sys
import os
sys.path.append("../src")
import llm_utils
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, classification_report
import ast
import numpy as np
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import classification_report

# Label names of the annotation classes (14 entries, one per line).
classes = [
    "War/Terror",
    "Conspiracy Theory",
    "Education",
    "Election Campaign",
    "Environment",
    "Government/Public",
    "Health",
    "Immigration/Integration",
    "Justice/Crime",
    "Labor/Employment",
    "Macroeconomics/Economic Regulation",
    "Media/Journalism",
    "Religion",
    "Science/Technology",
]

def calculate_binary_metrics(df, class_short, class_full, extraction_function):
    """Score one binary classifier's predictions against the gold annotations.

    Raw model output is read from the column ``testbinary_{class_short}_v01``.
    Rows without a value in that column are dropped, ``extraction_function``
    is applied to the remaining cells, and rows for which the extraction
    yields a missing value are excluded from the metrics.

    Args:
        df: DataFrame containing the prediction column and an ``annotations``
            column.
        class_short: Short class identifier used to build the prediction
            column name.
        class_full: Full class label. A row is a positive example when
            ``class_full in annotations`` is true (for string annotations
            this is a substring match).
        extraction_function: Callable applied to every raw prediction cell;
            its result must be convertible to ``int`` or be a missing value.

    Returns:
        A tuple ``(prediction_per_class, confusion_matrices,
        classification_reports)``: the filtered copy of ``df`` with the
        extracted predictions, and two dicts keyed by ``class_full`` holding
        the confusion matrix and the ``classification_report`` dict.

    Raises:
        KeyError: if the prediction column is not present in ``df``.
    """
    pred_column_name = f"testbinary_{class_short}_v01"

    # Keep only rows that carry a raw prediction, then parse them on a copy
    # so the caller's frame is left untouched.
    prediction_per_class = df[df[pred_column_name].notna()].copy()
    prediction_per_class[pred_column_name] = prediction_per_class[pred_column_name].apply(extraction_function)

    # The extraction itself may produce missing values; skip those rows.
    valid_rows = prediction_per_class[pred_column_name].notna()

    try:
        y_true = prediction_per_class.loc[valid_rows, 'annotations'].apply(lambda x: int(class_full in x))
        y_pred = prediction_per_class.loc[valid_rows, pred_column_name].astype(int)
    except (KeyError, TypeError) as exc:
        # Best-effort fallback (e.g. missing 'annotations' column or a
        # non-iterable annotation cell): keep going with empty inputs, but
        # say so instead of hiding the problem.
        # NOTE(review): how classification_report treats empty inputs is not
        # verified here.
        warnings.warn(
            f"Could not derive labels for {class_full!r} "
            f"({type(exc).__name__}: {exc}); metrics are computed on empty inputs."
        )
        y_true = []
        y_pred = []

    # Always include both binary labels so the matrix is at least 2x2 even
    # when one class never occurs; any other observed label is kept as well.
    labels = sorted({0, 1}.union(y_true, y_pred))

    confusion_matrices = {class_full: confusion_matrix(y_true, y_pred, labels=labels)}
    classification_reports = {class_full: classification_report(y_true, y_pred, output_dict=True)}

    return prediction_per_class, confusion_matrices, classification_reports

def _evaluate_lora_binary(csv_path, class_short, class_full, extraction_function):
    """Load one LoRA test CSV and score it with ``calculate_binary_metrics``.

    Args:
        csv_path: Path of the CSV file to read with ``pd.read_csv``.
        class_short: Short class identifier forwarded to ``calculate_binary_metrics``.
        class_full: Full class label forwarded to ``calculate_binary_metrics``.
        extraction_function: Callable forwarded to ``calculate_binary_metrics``.

    Returns:
        A tuple ``(predictions_df, metrics)`` where ``predictions_df`` is the
        loaded DataFrame and ``metrics`` is a dict with the keys
        ``"confusion_matrices"`` and ``"classification_reports"``.
    """
    predictions_df = pd.read_csv(csv_path)
    _, confusion_by_class, report_by_class = calculate_binary_metrics(
        predictions_df, class_short, class_full, extraction_function
    )
    metrics = {"confusion_matrices": confusion_by_class, "classification_reports": report_by_class}
    return predictions_df, metrics


# Every run below uses the same extraction settings, so the function is built once.
# Assumes get_extraction_function returns a stateless callable -- TODO confirm.
extraction_function = llm_utils.get_extraction_function("extract_using_class_token", 1)

binary_war_lora_df, binary_war_lora = _evaluate_lora_binary(
    "../data/vicuna_4bit/lora/binary_war_v01/test_generic_test_0.csv",
    "war", "War/Terror", extraction_function,
)

binary_conspiracy_lora_df, binary_conspiracy_lora = _evaluate_lora_binary(
    "../data/vicuna_4bit/lora/binary_conspiracy_theory_v01/test_generic_test_0.csv",
    "conspiracy_theory", "Conspiracy Theory", extraction_function,
)

binary_education_lora_df, binary_education_lora = _evaluate_lora_binary(
    "../data/vicuna_4bit/lora/binary_education_v01/test_generic_test_0.csv",
    "education", "Education", extraction_function,
)

binary_election_campaigns_lora_df, binary_election_lora = _evaluate_lora_binary(
    "../data/vicuna_4bit/lora/binary_election_v01/test_generic_test_0.csv",
    "election", "Election Campaign", extraction_function,
)

In [2]:
# Print the metrics dict of each LoRA binary classifier, one after the other.
for lora_metrics in (binary_war_lora, binary_conspiracy_lora, binary_education_lora, binary_election_lora):
    print(lora_metrics)

{'confusion_matrices': {'War/Terror': array([[711,  34],
       [ 22, 233]])}, 'classification_reports': {'War/Terror': {'0': {'precision': 0.9699863574351978, 'recall': 0.9543624161073826, 'f1-score': 0.9621109607577808, 'support': 745}, '1': {'precision': 0.8726591760299626, 'recall': 0.9137254901960784, 'f1-score': 0.89272030651341, 'support': 255}, 'accuracy': 0.944, 'macro avg': {'precision': 0.9213227667325802, 'recall': 0.9340439531517305, 'f1-score': 0.9274156336355954, 'support': 1000}, 'weighted avg': {'precision': 0.9451679261768628, 'recall': 0.944, 'f1-score': 0.9444163439254663, 'support': 1000}}}}
{'confusion_matrices': {'Conspiracy Theory': array([[537, 397],
       [  5,  39]])}, 'classification_reports': {'Conspiracy Theory': {'0': {'precision': 0.9907749077490775, 'recall': 0.5749464668094219, 'f1-score': 0.7276422764227642, 'support': 934}, '1': {'precision': 0.08944954128440367, 'recall': 0.8863636363636364, 'f1-score': 0.16250000000000003, 'support': 44}, 'accurac

In [30]:
# Display the confusion matrix and classification report of the LoRA "Conspiracy Theory" run.
binary_conspiracy_lora

{'confusion_matrices': {'Conspiracy Theory': array([[537, 397],
         [  5,  39]])},
 'classification_reports': {'Conspiracy Theory': {'0': {'precision': 0.9907749077490775,
    'recall': 0.5749464668094219,
    'f1-score': 0.7276422764227642,
    'support': 934},
   '1': {'precision': 0.08944954128440367,
    'recall': 0.8863636363636364,
    'f1-score': 0.16250000000000003,
    'support': 44},
   'accuracy': 0.588957055214724,
   'macro avg': {'precision': 0.5401122245167406,
    'recall': 0.7306550515865291,
    'f1-score': 0.4450711382113821,
    'support': 978},
   'weighted avg': {'precision': 0.9502244822639593,
    'recall': 0.588957055214724,
    'f1-score': 0.7022166525346235,
    'support': 978}}}}

In [31]:
# Display the confusion matrix and classification report of the LoRA "Education" run.
binary_education_lora

{'confusion_matrices': {'Education': array([[823,  71],
         [  3,  10]])},
 'classification_reports': {'Education': {'0': {'precision': 0.9963680387409201,
    'recall': 0.9205816554809844,
    'f1-score': 0.9569767441860465,
    'support': 894},
   '1': {'precision': 0.12345679012345678,
    'recall': 0.7692307692307693,
    'f1-score': 0.2127659574468085,
    'support': 13},
   'accuracy': 0.918412348401323,
   'macro avg': {'precision': 0.5599124144321884,
    'recall': 0.8449062123558768,
    'f1-score': 0.5848713508164275,
    'support': 907},
   'weighted avg': {'precision': 0.9838566316493799,
    'recall': 0.918412348401323,
    'f1-score': 0.9463099964158038,
    'support': 907}}}}

In [17]:
# War/Terror evaluation on the predictions stored under
# generic_prompt_without_context_only_classification.
# Note: this uses llm_utils.calculate_binary_metrics, which is called with a
# list of class labels, not the notebook-local function of the same name.
binary_war_df = pd.read_csv(
    "../data/vicuna_4bit/generic_prompt_without_context_only_classification/generic_test_0.csv"
)
extraction_function = llm_utils.get_extraction_function("extract_using_class_token", 1)
(
    binary_war_predictions_per_class,
    confusion_matrices,
    classification_reports,
) = llm_utils.calculate_binary_metrics(binary_war_df, ["War/Terror"], extraction_function)
binary_war = {
    "confusion_matrices": confusion_matrices,
    "classification_reports": classification_reports,
}

In [18]:
# Display the War/Terror metrics computed from the generic-prompt (without context) predictions.
binary_war

{'confusion_matrices': {'War/Terror': array([[54, 11],
         [13, 51]])},
 'classification_reports': {'War/Terror': {'0': {'precision': 0.8059701492537313,
    'recall': 0.8307692307692308,
    'f1-score': 0.8181818181818182,
    'support': 65},
   '1': {'precision': 0.8225806451612904,
    'recall': 0.796875,
    'f1-score': 0.8095238095238094,
    'support': 64},
   'accuracy': 0.813953488372093,
   'macro avg': {'precision': 0.8142753972075109,
    'recall': 0.8138221153846155,
    'f1-score': 0.8138528138528138,
    'support': 129},
   'weighted avg': {'precision': 0.8142110154404273,
    'recall': 0.813953488372093,
    'f1-score': 0.8138863720259069,
    'support': 129}}}}