In [1]:
import sys
import os
sys.path.append("../src")
import llm_utils
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import pandas as pd
import ast
import numpy as np
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import classification_report

classes = ["War/Terror", "Conspiracy Theory", "Education", "Election Campaign", "Environment", 
              "Government/Public", "Health", "Immigration/Integration", 
              "Justice/Crime", "Labor/Employment", 
              "Macroeconomics/Economic Regulation", "Media/Journalism", "Religion", "Science/Technology", "Others"]
# Transform the labels into binary format
mlb = MultiLabelBinarizer(classes=classes)

def get_report(base_path, test_name, extract_func, mlb, classes, verbose = False):
    df = pd.read_csv(base_path+test_name+"/test_generic_test_0.csv")
    if verbose:
        print("------------------")
        print("Prompt:")
        print(df.prompt[0])
        print("------------------")
        print()
        print()
        for i in range(0,3):
            print("------------------")
            print("Pre-Extraction:")
            print("------------------")
            print("Normalized Tweet: ", df.iloc[i].normalized_tweet)
            print("Response: ", df.iloc[i]['response'])
        print()
    df['annotations'] = df['annotations'].apply(lambda x: extract_func(x, classes))
    df['response'] = df['response'].apply(lambda x: extract_func(x, classes))
    if verbose:
        for i in range(0,3):
            print("------------------")
            print("Post-Extraction:")
            print("------------------")
            print("Normalized Tweet: ", df.iloc[i].normalized_tweet)
            print("Response: ", df.iloc[i]['test'+test_name])
        print()
    y_true = mlb.fit_transform(df['annotations'])
    y_pred = mlb.transform(df['response'])
    report = classification_report(y_true, y_pred, output_dict=True, target_names=classes)
    return pd.DataFrame(report).transpose()

# No Fine Tuning

In [2]:
base_path = "../data/vicuna_4bit/"

report_multilabel_no_fine_tune_v01 = get_report(base_path, "multilabel_no_fine_tune_v01", llm_utils.extract_multilabel_list, mlb, classes)
report_multilabel_no_fine_tune_v02 = get_report(base_path, "multilabel_no_fine_tune_v02", llm_utils.extract_multilabel_list, mlb, classes)
#report_multilabel_no_fine_tune_v01_only_first_label_extracted = get_report(base_path, "multi_label_no_fine_tune_v01", llm_utils.extract_multilabel_list_only_first_class, mlb, classes)
report_multilabel_no_fine_tune_explanation_first_v01 = get_report(base_path, "multi_label_no_fine_tune_explanation_first_v01", llm_utils.extract_multilabel_list_explanation_first, mlb, classes)
report_multilabel_no_fine_tune_explanation_first_v01_only_first_label_extracted = get_report(base_path, "multi_label_no_fine_tune_explanation_first_v01", llm_utils.extract_multilabel_list_explanation_first_only_first_class, mlb, classes)

print()
print("----------------")
print("Multilabel v01 NO FINETUNING")
print("----------------")
print(report_multilabel_no_fine_tune_v01)
print("----------------")
print()
print("----------------")
print("Multilabel v02 2 NO FINETUNING")
print("----------------")
print(report_multilabel_no_fine_tune_v02)
print()
print("----------------")
print("Multilabel v01 NO FINETUNING Explanation First")
print("----------------")
print(report_multilabel_no_fine_tune_explanation_first_v01)
print("----------------")
print()
"""print("----------------")
print("Multilabel v01 NO FINETUNING Explanation First Only First Label Extracted")
print("----------------")
print(report_multilabel_no_fine_tune_explanation_first_v01_only_first_label_extracted)
print("----------------")"""


----------------
Multilabel v01 NO FINETUNING
----------------
                                    precision    recall  f1-score  support
War/Terror                           0.332863  0.925490  0.489627    255.0
Conspiracy Theory                    0.050467  0.600000  0.093103     45.0
Education                            0.018717  0.538462  0.036176     13.0
Election Campaign                    0.069136  0.848485  0.127854     33.0
Environment                          0.023936  0.642857  0.046154     14.0
Government/Public                    0.310850  0.728522  0.435766    291.0
Health                               0.073350  0.652174  0.131868     46.0
Immigration/Integration              0.050119  0.583333  0.092308     36.0
Justice/Crime                        0.189807  0.788321  0.305949    137.0
Labor/Employment                     0.043796  0.642857  0.082005     28.0
Macroeconomics/Economic Regulation   0.084833  0.532258  0.146341     62.0
Media/Journalism                    

'print("----------------")\nprint("Multilabel v01 NO FINETUNING Explanation First Only First Label Extracted")\nprint("----------------")\nprint(report_multilabel_no_fine_tune_explanation_first_v01_only_first_label_extracted)\nprint("----------------")'

# Multilabel Fine-Tuning

In [3]:
vicuna_lora_multilabel_with_rules_v02_df = pd.read_csv("../data/vicuna_4bit/lora/multilabel_with_rules_v02/test_generic_test_0.csv")
vicuna_lora_multilabel_with_rules_v02_predictions_per_class, confusion_matrices, binary_classification_reports, multilabel_classification_reports = llm_utils.calculate_metrics_from_multilabel_list(vicuna_lora_multilabel_with_rules_v02_df, classes, llm_utils.extract_multilabel_list)
vicuna_lora_multilabel_with_rules_v02 = {"confusion_matrices": confusion_matrices, "binary_classification_reports": binary_classification_reports, "multilabel_classification_reports": multilabel_classification_reports}
vicuna_lora_multilabel_with_rules_v02["multilabel_classification_reports"]

Unnamed: 0,precision,recall,f1-score,support
War/Terror,0.97,0.760784,0.852747,255.0
Conspiracy Theory,0.451613,0.622222,0.523364,45.0
Education,0.428571,0.692308,0.529412,13.0
Election Campaign,0.851852,0.69697,0.766667,33.0
Environment,0.615385,0.571429,0.592593,14.0
Government/Public,0.76087,0.721649,0.740741,291.0
Health,0.785714,0.478261,0.594595,46.0
Immigration/Integration,0.789474,0.416667,0.545455,36.0
Justice/Crime,0.886792,0.686131,0.773663,137.0
Labor/Employment,0.521739,0.428571,0.470588,28.0


In [4]:
vicuna_lora_multilabel_with_rules_v02["binary_classification_reports"]

Unnamed: 0,label,f1_score_macro,precision_macro,recall_macro,support_macro,f1_score_class_0,support_class_0,f1_score_class_1,support_class_1,precision_class_0,recall_class_0,precision_class_1,recall_class_1
0,War/Terror,0.904691,0.946875,0.876365,1000,0.956634,745,0.852747,255,0.92375,0.991946,0.97,0.760784
1,Conspiracy Theory,0.748212,0.716745,0.79331,1000,0.973059,955,0.523364,45,0.981876,0.964398,0.451613,0.622222
2,Education,0.760637,0.712243,0.840075,1000,0.991862,987,0.529412,13,0.995914,0.987842,0.428571,0.692308
3,Election Campaign,0.879725,0.920787,0.846417,1000,0.992784,967,0.766667,33,0.989723,0.995863,0.851852,0.69697
4,Environment,0.793509,0.804653,0.783179,1000,0.994425,986,0.592593,14,0.993921,0.994929,0.615385,0.571429
5,Government/Public,0.819079,0.824496,0.81428,1000,0.897418,709,0.740741,291,0.888122,0.906911,0.76087,0.721649
6,Health,0.789509,0.880511,0.735986,1000,0.984424,954,0.594595,46,0.975309,0.993711,0.785714,0.478261
7,Immigration/Integration,0.766301,0.884033,0.706259,1000,0.987147,964,0.545455,36,0.978593,0.995851,0.789474,0.416667
8,Justice/Crime,0.87118,0.919347,0.836113,1000,0.968697,863,0.773663,137,0.951902,0.986095,0.886792,0.686131
9,Labor/Employment,0.728367,0.752681,0.708627,1000,0.986147,972,0.470588,28,0.983623,0.988683,0.521739,0.428571


In [6]:
base_path = "../data/vicuna_4bit/lora/"

In [7]:
report_multilabel_v01_128_rank_retest = get_report(base_path, "multilabel_without_context_v01_retest", llm_utils.extract_multilabel_list, mlb, classes)
report_multilabel_v01 = get_report(base_path, "multilabel_without_context_v01", llm_utils.extract_multilabel_list, mlb, classes)
report_multilabel_v01_256_rank = get_report(base_path, "multilabel_without_context_v01_256_rank", llm_utils.extract_multilabel_list, mlb, classes)
report_multilabel_v02 = get_report(base_path, "multilabel_with_rules_v02", llm_utils.extract_multilabel_list, mlb, classes)
report_multilabel_v02_256_rank = get_report(base_path, "multilabel_with_rules_v02_256_rank", llm_utils.extract_multilabel_list, mlb, classes)

print("All labels extracted from LLM response")
print()
print("----------------")
print("Multilabel v01 128 LoRA rank retest")
print("----------------")
print(report_multilabel_v01_128_rank_retest)
print("----------------")
print()
print("----------------")
print("Multilabel v01")
print("----------------")
print(report_multilabel_v01)
print("----------------")
print()
print("----------------")
print("Multilabel v01 256 LoRA rank")
print("----------------")
print(report_multilabel_v01_256_rank)
print("----------------")
print()
print("----------------")
print("Multilabel v02")
print("----------------")
print(report_multilabel_v02)
print("----------------")
print()
print("----------------")
print("Multilabel v02 256 LoRA rank")
print("----------------")
print(report_multilabel_v02_256_rank)
print("----------------")

All labels extracted from LLM response

----------------
Multilabel v01 128 LoRA rank retest
----------------
                                    precision    recall  f1-score  support
War/Terror                           0.936893  0.756863  0.837310    255.0
Conspiracy Theory                    0.508772  0.644444  0.568627     45.0
Education                            0.444444  0.307692  0.363636     13.0
Election Campaign                    0.933333  0.424242  0.583333     33.0
Environment                          0.833333  0.357143  0.500000     14.0
Government/Public                    0.861789  0.364261  0.512077    291.0
Health                               0.769231  0.434783  0.555556     46.0
Immigration/Integration              0.916667  0.611111  0.733333     36.0
Justice/Crime                        0.825581  0.518248  0.636771    137.0
Labor/Employment                     0.625000  0.357143  0.454545     28.0
Macroeconomics/Economic Regulation   0.727273  0.129032  0.219178

# Metrics for multilabel but only the first class predicted is used

In [8]:
report_multilabel_v01_128_rank_retest = get_report(base_path, "multilabel_without_context_v01_retest", llm_utils.extract_multilabel_list_only_first_class, 
                                                   mlb, classes)

report_multilabel_v01 = get_report(base_path, "multilabel_without_context_v01", llm_utils.extract_multilabel_list_only_first_class,
                                   mlb, classes)

report_multilabel_v01_256_rank = get_report(base_path, "multilabel_without_context_v01_256_rank", llm_utils.extract_multilabel_list_only_first_class,
                                   mlb, classes)

report_multilabel_v02 = get_report(base_path, "multilabel_with_rules_v02", llm_utils.extract_multilabel_list_only_first_class,
                                   mlb, classes)

report_multilabel_v02_256_rank = get_report(base_path, "multilabel_with_rules_v02_256_rank", llm_utils.extract_multilabel_list_only_first_class,
                                   mlb, classes)

print("Only first label is extracted from LLM response")
print()
print("----------------")
print("Multilabel v01 retest")
print("----------------")
print(report_multilabel_v01)
print("----------------")
print()
print("----------------")
print("Multilabel v01")
print("----------------")
print(report_multilabel_v01)
print("----------------")
print()
print("----------------")
print("Multilabel v01 256 LoRA rank")
print("----------------")
print(report_multilabel_v01_256_rank)
print("----------------")
print()
print("----------------")
print("Multilabel v02")
print("----------------")
print(report_multilabel_v02)
print("----------------")
print()
print("----------------")
print("Multilabel v02 256 LoRA rank")
print("----------------")
print(report_multilabel_v02_256_rank)
print("----------------")

Only first label is extracted from LLM response

----------------
Multilabel v01 retest
----------------
                                    precision    recall  f1-score  support
War/Terror                           0.888476  0.937255  0.912214    255.0
Conspiracy Theory                    0.428571  0.631579  0.510638     19.0
Education                            0.444444  0.727273  0.551724     11.0
Election Campaign                    0.782609  0.620690  0.692308     29.0
Environment                          0.666667  0.571429  0.615385     14.0
Government/Public                    0.726027  0.716216  0.721088    222.0
Health                               0.600000  0.600000  0.600000     20.0
Immigration/Integration              0.500000  0.166667  0.250000      6.0
Justice/Crime                        0.780000  0.821053  0.800000     95.0
Labor/Employment                     0.625000  0.714286  0.666667     14.0
Macroeconomics/Economic Regulation   0.619048  0.481481  0.541667     

  _warn_prf(average, modifier, msg_start, len(result))
