In [44]:
import sys
import os
sys.path.append("../src")
import llm_utils
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import pandas as pd
from sklearn.metrics import classification_report
import ast
import numpy as np
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import classification_report

classes = ["War/Terror", "Conspiracy Theory", "Education", "Election Campaign", "Environment", 
              "Government/Public", "Health", "Immigration/Integration", 
              "Justice/Crime", "Labor/Employment", 
              "Macroeconomics/Economic Regulation", "Media/Journalism", "Religion", "Science/Technology", "Others"]
# Transform the labels into binary format
mlb = MultiLabelBinarizer(classes=classes)

def get_report(base_path, test_name, extract_func, mlb, classes, verbose = False):
    df = pd.read_csv(base_path+test_name+"/test_generic_test_0.csv")
    if verbose:
        print("------------------")
        print("Prompt:")
        print(df.prompt[0])
        print("------------------")
        print()
        print()
        for i in range(0,3):
            print("------------------")
            print("Pre-Extraction:")
            print("------------------")
            print("Normalized Tweet: ", df.iloc[i].normalized_tweet)
            print("Response: ", df.iloc[i]['test'+test_name])
        print()
    df['annotations'] = df['annotations'].apply(lambda x: llm_utils.extract_multilabel_list(x, classes))
    df['test'+test_name] = df['test'+test_name].apply(lambda x: extract_func(x, classes))
    if verbose:
        for i in range(0,3):
            print("------------------")
            print("Post-Extraction:")
            print("------------------")
            print("Normalized Tweet: ", df.iloc[i].normalized_tweet)
            print("Response: ", df.iloc[i]['test'+test_name])
        print()
    y_true = mlb.fit_transform(df['annotations'])
    y_pred = mlb.transform(df['test'+test_name])
    report = classification_report(y_true, y_pred, output_dict=True, target_names=classes)
    return pd.DataFrame(report).transpose()

# No Fine Tuning

In [45]:
base_path = "../data/vicuna_4bit/"

report_multilabel_no_fine_tune_v01 = get_report(base_path, "multi_label_no_fine_tune_v01", llm_utils.extract_multilabel_list, mlb, classes)
report_multilabel_no_fine_tune_v01_only_first_label_extracted = get_report(base_path, "multi_label_no_fine_tune_v01", llm_utils.extract_multilabel_list_only_first_class, mlb, classes)
report_multilabel_no_fine_tune_explanation_first_v01 = get_report(base_path, "multi_label_no_fine_tune_explanation_first_v01", llm_utils.extract_multilabel_list_explanation_first, mlb, classes)
report_multilabel_no_fine_tune_explanation_first_v01_only_first_label_extracted = get_report(base_path, "multi_label_no_fine_tune_explanation_first_v01", llm_utils.extract_multilabel_list_explanation_first_only_first_class, mlb, classes)

print()
print("----------------")
print("Multilabel v01 NO FINETUNING")
print("----------------")
print(report_multilabel_no_fine_tune_v01)
print("----------------")
print()
print("----------------")
print("Multilabel v01 NO FINETUNING Only First Label Extracted")
print("----------------")
print(report_multilabel_no_fine_tune_v01_only_first_label_extracted)
print("----------------")
print()
print("----------------")
print("Multilabel v01 NO FINETUNING Explanation First")
print("----------------")
print(report_multilabel_no_fine_tune_explanation_first_v01)
print("----------------")
print()
print("----------------")
print("Multilabel v01 NO FINETUNING Explanation First Only First Label Extracted")
print("----------------")
print(report_multilabel_no_fine_tune_explanation_first_v01_only_first_label_extracted)
print("----------------")


----------------
Multilabel v01 NO FINETUNING
----------------
                                    precision    recall  f1-score  support
War/Terror                           0.301956  0.968627  0.460391    255.0
Conspiracy Theory                    0.052098  0.800000  0.097826     45.0
Education                            0.019753  0.615385  0.038278     13.0
Election Campaign                    0.061728  0.757576  0.114155     33.0
Environment                          0.027228  0.785714  0.052632     14.0
Government/Public                    0.345339  0.560137  0.427261    291.0
Health                               0.082547  0.760870  0.148936     46.0
Immigration/Integration              0.043478  0.500000  0.080000     36.0
Justice/Crime                        0.226190  0.693431  0.341113    137.0
Labor/Employment                     0.042607  0.607143  0.079625     28.0
Macroeconomics/Economic Regulation   0.058228  0.370968  0.100656     62.0
Media/Journalism                    

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


# Multilabel v01

In [4]:
base_path = "../data/vicuna_4bit/lora/"

In [9]:
report_multilabel_v01_128_rank_retest = get_report(base_path, "multilabel_no_context_v01_128_rank_retest", llm_utils.extract_multilabel_list, mlb, classes)
report_multilabel_v01 = get_report(base_path, "multilabel_without_context_v01", llm_utils.extract_multilabel_list, mlb, classes)
report_multilabel_v01_256_rank = get_report(base_path, "multilabel_no_context_v01_256_rank", llm_utils.extract_multilabel_list, mlb, classes)
report_multilabel_v02 = get_report(base_path, "multilabel_with_rules_v02", llm_utils.extract_multilabel_list, mlb, classes)
report_multilabel_v02_256_rank = get_report(base_path, "multilabel_with_rules_v02_256_rank", llm_utils.extract_multilabel_list, mlb, classes)

print("All labels extracted from LLM response")
print()
print("----------------")
print("Multilabel v01 128 LoRA rank retest")
print("----------------")
print(report_multilabel_v01_128_rank_retest)
print("----------------")
print()
print("----------------")
print("Multilabel v01")
print("----------------")
print(report_multilabel_v01)
print("----------------")
print()
print("----------------")
print("Multilabel v01 256 LoRA rank")
print("----------------")
print(report_multilabel_v01_256_rank)
print("----------------")
print()
print("----------------")
print("Multilabel v02")
print("----------------")
print(report_multilabel_v02)
print("----------------")
print()
print("----------------")
print("Multilabel v02 256 LoRA rank")
print("----------------")
print(report_multilabel_v02_256_rank)
print("----------------")

All labels extracted from LLM response

----------------
Multilabel v01 128 LoRA rank retest
----------------
                                    precision    recall  f1-score  support
War/Terror                           0.955446  0.756863  0.844639    255.0
Conspiracy Theory                    0.527273  0.644444  0.580000     45.0
Education                            0.444444  0.307692  0.363636     13.0
Election Campaign                    0.933333  0.424242  0.583333     33.0
Environment                          0.833333  0.357143  0.500000     14.0
Government/Public                    0.861789  0.364261  0.512077    291.0
Health                               0.769231  0.434783  0.555556     46.0
Immigration/Integration              0.916667  0.611111  0.733333     36.0
Justice/Crime                        0.825581  0.518248  0.636771    137.0
Labor/Employment                     0.666667  0.357143  0.465116     28.0
Macroeconomics/Economic Regulation   0.727273  0.129032  0.219178

# Metrics for multilabel but only the first class predicted is used

In [6]:
report_multilabel_v01 = get_report(base_path, "multilabel_without_context_v01", llm_utils.extract_multilabel_list_only_first_class,
                                   mlb, classes)

report_multilabel_v01_256_rank = get_report(base_path, "multilabel_no_context_v01_256_rank", llm_utils.extract_multilabel_list_only_first_class,
                                   mlb, classes)

report_multilabel_v02 = get_report(base_path, "multilabel_with_rules_v02", llm_utils.extract_multilabel_list_only_first_class,
                                   mlb, classes)

report_multilabel_v02_256_rank = get_report(base_path, "multilabel_with_rules_v02_256_rank", llm_utils.extract_multilabel_list_only_first_class,
                                   mlb, classes)

print("Only first label is extracted from LLM response")
print()
print("----------------")
print("Multilabel v01")
print("----------------")
print(report_multilabel_v01)
print("----------------")
print()
print("----------------")
print("Multilabel v01 256 LoRA rank")
print("----------------")
print(report_multilabel_v01_256_rank)
print("----------------")
print()
print("----------------")
print("Multilabel v02")
print("----------------")
print(report_multilabel_v02)
print("----------------")
print()
print("----------------")
print("Multilabel v02 256 LoRA rank")
print("----------------")
print(report_multilabel_v02_256_rank)
print("----------------")

Only first label is extracted from LLM response

----------------
Multilabel v01
----------------
                                    precision    recall  f1-score  support
War/Terror                           0.953947  0.568627  0.712531    255.0
Conspiracy Theory                    0.478873  0.755556  0.586207     45.0
Education                            0.421053  0.615385  0.500000     13.0
Election Campaign                    0.750000  0.545455  0.631579     33.0
Environment                          0.700000  0.500000  0.583333     14.0
Government/Public                    0.781377  0.663230  0.717472    291.0
Health                               0.740741  0.434783  0.547945     46.0
Immigration/Integration              0.750000  0.250000  0.375000     36.0
Justice/Crime                        0.929293  0.671533  0.779661    137.0
Labor/Employment                     0.687500  0.392857  0.500000     28.0
Macroeconomics/Economic Regulation   0.833333  0.322581  0.465116     62.0
Me

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
