In [32]:
import os
import re
import glob

import numpy as np
import pandas as pd 
from IPython.core.display import Markdown

from utils import plot_count_and_normalized_confusion_matrix, \
    load_dataset_task_prompt_mappings, map_label_to_completion

In [33]:
def map_outputs_task_1(output):
    """Normalise a raw task-1 completion to 'RELEVANT', 'IRRELEVANT' or 'NAN'.

    Rules are tried in order: a label at the very start of the completion wins,
    then a label (or a known phrasing) anywhere in the text.

    Args:
        output: The completion text. Non-string values (float NaN, None, ...)
            are treated as missing.

    Returns:
        'RELEVANT', 'IRRELEVANT', or 'NAN' when the value is missing or no
        rule matches. Unrecognised text is printed before 'NAN' is returned.
    """
    # `output == np.nan` is always False (NaN never equals itself) and a float
    # NaN has no `.strip()`, so missing values must be caught before any regex.
    if not isinstance(output, str):
        return "NAN"

    text = output.strip()
    rules = (
        # Label at the start of the completion.
        (r'^(\s)*RELEVANT', 'RELEVANT'),
        (r'^(\s)*IRRELEVANT', 'IRRELEVANT'),
        # Label somewhere inside the completion. The leading \s keeps
        # "IRRELEVANT" from being read as "RELEVANT".
        (r'\sRELEVANT|(t|T)he tweet is relevant', 'RELEVANT'),
        (r'IRRELEVANT', 'IRRELEVANT'),
    )
    for pattern, label in rules:
        if re.search(pattern, text):
            return label

    if output == 'nan':
        return "NAN"

    print(f'Weird value: {text}')
    return "NAN"


In [None]:
def map_outputs_task_2(output):
    """Normalise a raw task-2 completion to 'PROBLEM', 'SOLUTION', 'NEUTRAL' or 'NAN'.

    Rules are tried in order: a label at the very start of the completion (or a
    known NEUTRAL phrasing) wins, then a label or known phrasing anywhere in
    the text. 'NEITHER' is treated as 'NEUTRAL'.

    Args:
        output: The completion text. Non-string values (float NaN, None, ...)
            are treated as missing.

    Returns:
        One of the three labels, or 'NAN' when the value is missing or no rule
        matches. Unrecognised text is printed before 'NAN' is returned.
    """
    # `output == np.nan` is always False and a float NaN has no `.strip()`,
    # so missing values must be caught before any regex is applied.
    if not isinstance(output, str):
        return "NAN"

    text = output.strip()
    rules = (
        # Label at the start of the completion.
        (r'^(\s)*PROBLEM', 'PROBLEM'),
        (r'^(\s)*SOLUTION', 'SOLUTION'),
        (r'^(\s)*(NEITHER|NEUTRAL)|Therefore, I would classify this tweet as NEUTRAL|I think it is NEUTRAL',
         'NEUTRAL'),
        # Label or known phrasing anywhere in the completion.
        (r'PROBLEM|(t|T)he tweet describes content moderation as a problem|'
         r'(T|t)he tweet is describing content moderation as a problem', 'PROBLEM'),
        (r'SOLUTION', 'SOLUTION'),
        (r'(NEITHER|NEUTRAL)', 'NEUTRAL'),
    )
    for pattern, label in rules:
        if re.search(pattern, text):
            return label

    if output == 'nan':
        return "NAN"

    print(f'Weird value: {text}')
    return "NAN"


In [34]:
def map_outputs_task_3(output):
    """Normalise a raw task-3 completion to one of the frame labels or 'NAN'.

    Rules are tried in order. First the completion must *start* with a label
    (optionally preceded by its option letter, e.g. "A: ECONOMY"); if none does,
    the label is searched anywhere in the text, together with a few fallbacks
    (a bare option letter "A"/"B", the misspelling "POLICITAL", and "EDUCATION"
    which is mapped to 'OTHER').

    Args:
        output: The completion text. Non-string values (float NaN, None, ...)
            are treated as missing.

    Returns:
        The matched label, or 'NAN' when the value is missing or no rule
        matches. Unrecognised text is printed before 'NAN' is returned.
    """
    # `output == np.nan` is always False and a float NaN has no `.strip()`,
    # so missing values must be caught before any regex is applied.
    if not isinstance(output, str):
        return "NAN"

    text = output.strip()
    rules = (
        # --- Label (with optional option letter) at the start of the completion.
        (r'^(\s*A:){0,1}(\s)*ECONOMY', 'ECONOMY'),
        (r'^(\s*B:){0,1}(\s)*MORALITY', 'MORALITY'),
        (r'^(\s*C:){0,1}(\s)*FAIRNESS AND EQUALITY', 'FAIRNESS AND EQUALITY'),
        (r'^(\s*D:){0,1}(\s)*POLICY PRESCRIPTION AND EVALUATION', 'POLICY PRESCRIPTION AND EVALUATION'),
        (r'^(\s*E:){0,1}(\s)*LAW AND ORDER, CRIME AND JUSTICE', 'LAW AND ORDER, CRIME AND JUSTICE'),
        (r'^(\s*F:){0,1}(\s)*SECURITY AND DEFENSE', 'SECURITY AND DEFENSE'),
        (r'^(\s*G:){0,1}(\s)*HEALTH AND SAFETY', 'HEALTH AND SAFETY'),
        (r'^(\s*H:){0,1}(\s)*QUALITY OF LIFE', 'QUALITY OF LIFE'),
        (r'^(\s*I:){0,1}(\s)*POLITICAL', 'POLITICAL'),
        (r'^(\s*J:){0,1}(\s)*EXTERNAL REGULATION AND REPUTATION', 'EXTERNAL REGULATION AND REPUTATION'),
        (r'^(\s*K:){0,1}(\s)*OTHER', 'OTHER'),
        # --- Label anywhere in the completion, plus bare-letter fallbacks.
        # `\b` restricts the bare-letter fallback to a standalone "A" ("A", "A:",
        # "A."); a plain `^\s*A` also matched "As an AI ..." or "Answer: B".
        (r'(\s*A:){0,1}(\s)*ECONOMY|^\s*A\b', 'ECONOMY'),
        (r'(\s*B:){0,1}(\s)*MORALITY|^\s*B(\s+|$)', 'MORALITY'),
        (r'(\s*C:){0,1}(\s)*FAIRNESS AND EQUALITY', 'FAIRNESS AND EQUALITY'),
        (r'(\s*D:){0,1}(\s)*POLICY PRESCRIPTION AND EVALUATION', 'POLICY PRESCRIPTION AND EVALUATION'),
        (r'(\s*E:){0,1}(\s)*LAW AND ORDER, CRIME AND JUSTICE', 'LAW AND ORDER, CRIME AND JUSTICE'),
        (r'(\s*F:){0,1}(\s)*SECURITY AND DEFENSE', 'SECURITY AND DEFENSE'),
        (r'(\s*G:){0,1}(\s)*HEALTH AND SAFETY', 'HEALTH AND SAFETY'),
        (r'(\s*H:){0,1}(\s)*QUALITY OF LIFE', 'QUALITY OF LIFE'),
        (r'(\s*I:){0,1}(\s)*(POLITICAL|POLICITAL)', 'POLITICAL'),
        (r'(\s*J:){0,1}(\s)*EXTERNAL REGULATION AND REPUTATION', 'EXTERNAL REGULATION AND REPUTATION'),
        (r'(\s*K:){0,1}(\s)*OTHER|EDUCATION', 'OTHER'),
    )
    for pattern, label in rules:
        if re.search(pattern, text):
            return label

    if output == 'nan':
        return "NAN"

    print(f'Weird value: {text}')
    return "NAN"




In [None]:
def map_outputs_task_4(output):
    """Normalise a raw task-4 completion to 'IN FAVOR OF', 'AGAINST', 'NEUTRAL' or 'NAN'.

    Rules are tried in order: a label at the very start of the completion wins,
    then a label anywhere in the text.

    Args:
        output: The completion text. Non-string values (float NaN, None, ...)
            are treated as missing.

    Returns:
        One of the three stance labels, or 'NAN' when the value is missing or
        no rule matches. Unrecognised text is printed before 'NAN' is returned.
    """
    # `output == np.nan` is always False and a float NaN has no `.strip()`,
    # so missing values must be caught before any regex is applied.
    if not isinstance(output, str):
        return "NAN"

    text = output.strip()
    rules = (
        # Label at the start of the completion.
        (r'^(\s)*IN FAVOR OF', 'IN FAVOR OF'),
        (r'^(\s)*AGAINST', 'AGAINST'),
        (r'^(\s)*NEUTRAL', 'NEUTRAL'),
        # Label anywhere in the completion.
        (r'IN FAVOR OF', 'IN FAVOR OF'),
        (r'AGAINST', 'AGAINST'),
        (r'NEUTRAL', 'NEUTRAL'),
    )
    for pattern, label in rules:
        if re.search(pattern, text):
            return label

    if output == 'nan':
        return "NAN"

    print(f'Weird value: {text}')
    return "NAN"



In [None]:
def map_outputs_task_5(output):
    """Normalise a raw task-5 completion to one of the topic labels.

    Rules are tried in order. First the lower-cased completion must *start*
    with a topic name (optionally after "the tweet is about (the)"); if none
    does, a set of case-sensitive keywords, phrasings and bare option letters
    is searched anywhere in the text.

    Args:
        output: The completion text. Non-string values (float NaN, None, ...)
            are treated as missing.

    Returns:
        'Section 230', 'Trump ban', 'Twitter Support', 'Platform Policies',
        'Complaint' or 'Other'; 'NAN' when the value is missing. Unlike the
        other task mappers, unrecognised text is printed and mapped to 'Other'.
    """
    # `output == np.nan` is always False and a float NaN has no `.strip()`,
    # so missing values must be caught before any regex is applied.
    if not isinstance(output, str):
        return "NAN"

    text = output.strip()
    lowered = output.lower().strip()

    # Topic name at the start of the (lower-cased) completion.
    anchored_rules = (
        (r'^(the tweet is (about (the){0,1})){0,1}(\s)*section 230', 'Section 230'),
        (r'^(the tweet is about (the){0,1}){0,1}(\s)*trump ban', 'Trump ban'),
        (r'^(the tweet is about (the){0,1}){0,1}(\s)*twitter support', 'Twitter Support'),
        (r'^(the tweet is about (the){0,1}){0,1}(\s)*platform policies', 'Platform Policies'),
        (r'^(the tweet is about (the){0,1}){0,1}(\s)*complaint', 'Complaint'),
        # Raw string: the original non-raw literal relied on the invalid escape `\s`.
        (r'^(the tweet is about (the){0,1}){0,1}(\s)*other', 'Other'),
    )
    for pattern, label in anchored_rules:
        if re.search(pattern, lowered):
            return label

    # Case-sensitive keywords / phrasings anywhere in the completion.
    keyword_rules = (
        (r'OTHER|Therefore, it (should|can) be classified as "Other.|Other"', 'Other'),
        (r'SECTION 230|classified as A: Section 230|"Section 230"|A: Section 2', 'Section 230'),
        # `\b` keeps the bare option letter from matching the tail of a longer
        # upper-case word (e.g. "WEB "); the parentheses are escaped so that
        # "B (Trump ban)" is matched literally instead of acting as a group.
        (r'TRUMP BAN|Trump Ban|B: Trump ban|B Trump ban|\bB(\s+|$)|B \(Trump ban\)', 'Trump ban'),
        (r'TWITTER SUPPORT|classified as C \(Twitter Support\).', 'Twitter Support'),
        (r'PLATFORM POLICIES|Platform Policies', 'Platform Policies'),
        # Same `\b` guard: without it "THE " was classified as 'Complaint'.
        (r'COMPLAINT|E: Complaint|"Complaints"|E \(Complaint\)|\bE(\s+|$)|classified as \(E\) Complaint',
         'Complaint'),
    )
    for pattern, label in keyword_rules:
        if re.search(pattern, text):
            return label

    if output == 'nan':
        return "NAN"

    print(f'Weird value: {text}')
    return 'Other'
    

In [35]:
def map_outputs_task_6(output):
    """Normalise a raw task-6 completion to one of four coarse frame labels or 'NAN'.

    The four labels are 'POLICY AND REGULATION' (A), 'MORALITY AND LAW' (B),
    'ECONOMICS' (C) and 'OTHER' (D). Rules are tried in order. First the
    completion must *start* with one of those labels (optionally preceded by its
    option letter); otherwise finer-grained frame names found anywhere in the
    text are collapsed onto the four labels, with bare option letters "A"/"B"
    and a few misspellings as additional fallbacks.

    Args:
        output: The completion text. Non-string values (float NaN, None, ...)
            are treated as missing.

    Returns:
        One of the four labels, or 'NAN' when the value is missing or no rule
        matches. Unrecognised text is printed before 'NAN' is returned.
    """
    # `output == np.nan` is always False and a float NaN has no `.strip()`,
    # so missing values must be caught before any regex is applied.
    if not isinstance(output, str):
        return "NAN"

    text = output.strip()
    rules = (
        # --- Label (with optional option letter) at the start of the completion.
        (r'^(\s*A:){0,1}(\s)*POLICY AND REGULATION', 'POLICY AND REGULATION'),
        (r'^(\s*B:){0,1}(\s)*MORALITY AND LAW', 'MORALITY AND LAW'),
        (r'^(\s*C:){0,1}(\s)*ECONOMICS', 'ECONOMICS'),
        # `\s*D:` like its siblings: the text is stripped, so a prefix that
        # requires a leading whitespace character could never match "D: OTHER".
        (r'^(\s*D:){0,1}(\s)*OTHER', 'OTHER'),
        # --- Frame names anywhere in the completion, collapsed to the four labels.
        (r'(\s*C:){0,1}(\s)*ECONOM(Y|ICS)', 'ECONOMICS'),
        # Bare option letter "A". It is option A = POLICY AND REGULATION (it was
        # wrongly returned as ECONOMICS), and `\b` stops it from matching any
        # completion that merely starts with the letter, e.g. "As an AI ...".
        (r'^\s*A\b', 'POLICY AND REGULATION'),
        (r'(\s*C:){0,1}(\s)*PUBLIC OPINION', 'ECONOMICS'),
        (r'(\s*B:){0,1}(\s)*MORALITY|^\s*B(\s+|$)', 'MORALITY AND LAW'),
        (r'(\s*B:){0,1}(\s)*FAIRNESS AND EQUALITY', 'MORALITY AND LAW'),
        (r'(\s*A:){0,1}(\s)*POLICY PRESCRIPTION AND EVALUATION|POLICY AND REGULATION', 'POLICY AND REGULATION'),
        (r'(\s*B:){0,1}(\s)*LAW AND ORDER, CRIME AND JUSTICE', 'MORALITY AND LAW'),
        (r'(\s*B:){0,1}(\s)*CONSTITUTIONALITY AND JURISPRUDENCE', 'MORALITY AND LAW'),
        # Was returned as 'CAPACITY AND RESOURCES', which is not one of the four
        # task-6 labels; that frame is itself collapsed to ECONOMICS below.
        (r'(\s*C:){0,1}(\s)*SECURITY AND DEFENSE', 'ECONOMICS'),
        (r'(\s*B:){0,1}(\s)*HEALTH AND SAFETY', 'MORALITY AND LAW'),
        (r'(\sC:){0,1}(\s)*QUALITY OF LIFE', 'ECONOMICS'),
        (r'(\sC:){0,1}(\s)*CAPACITY AND RESOURCES', 'ECONOMICS'),
        (r'(\s*A:){0,1}(\s)*(POLITICAL|POLICITAL)', 'POLICY AND REGULATION'),
        (r'(\s*A:){0,1}(\s)*EXTERNAL REGULATION AND REPUTATION', 'POLICY AND REGULATION'),
        (r'(\s*D:){0,1}(\s)*OTHER|EDUCATION', 'OTHER'),
    )
    for pattern, label in rules:
        if re.search(pattern, text):
            return label

    if output == 'nan':
        return "NAN"

    print(f'Weird value: {text}')
    return "NAN"

In [36]:
# Dispatch table: task number (1-6) -> function that normalises that task's raw completion.
map_to_task_label_processing_fn = dict(
    enumerate(
        (
            map_outputs_task_1,
            map_outputs_task_2,
            map_outputs_task_3,
            map_outputs_task_4,
            map_outputs_task_5,
            map_outputs_task_6,
        ),
        start=1,
    )
)

In [37]:
def process_output(output_str, task_num):
    """Strip the prompt from a raw generation and map the rest to a task label.

    Everything up to and including the first '<|endoftext|>' marker is dropped;
    any further segments are joined with single spaces and handed to the
    task-specific mapper from `map_to_task_label_processing_fn`.

    Args:
        output_str: Full generated text (prompt + '<|endoftext|>' + completion).
            Non-string values (float NaN, None, ...) are treated as missing.
        task_num: Key into `map_to_task_label_processing_fn`.

    Returns:
        The label produced by the task mapper, or 'NAN' for a missing value.

    Raises:
        KeyError: If `task_num` has no registered mapper.
    """
    # A missing prediction (e.g. float NaN) has no `.split()`; report it with
    # the same "NAN" marker the task mappers use instead of raising.
    if not isinstance(output_str, str):
        return "NAN"

    # Remove the prompt
    completion = ' '.join(output_str.split('<|endoftext|>')[1:])

    # Process according to the task_num
    return map_to_task_label_processing_fn[task_num](completion)

In [38]:
# Relative path to the dataset/task mapping CSV, two directory levels up.
dataset_task_mappings_fp = os.path.join(os.pardir, os.pardir, 'dataset_task_mappings.csv')

# Load data

In [None]:
# Evaluate one prediction file: map raw generations to labels, compare them with
# the ground truth and append the resulting metrics to `accuracy_summary_list`.
# NOTE(review): ds_num, task_num, df, sample_size, prediction_fp and
# accuracy_summary_list are not defined in any cell shown here — presumably set
# by an earlier cell or an enclosing loop; confirm before running on a fresh kernel.

# Get the expected labelset
dataset_idx, dataset_task_mappings = load_dataset_task_prompt_mappings(
    dataset_num=ds_num, task_num=task_num, dataset_task_mappings_fp=dataset_task_mappings_fp)
label_column = dataset_task_mappings.loc[dataset_idx, "label_column"]
# "labelset_fullword" stores the labels as a single "; "-separated string
labelset = dataset_task_mappings.loc[dataset_idx, "labelset_fullword"].split("; ")
labelset = [label.strip() for label in labelset] 
# Also accept the upper-case spelling of every label
labelset += [label.upper() for label in labelset] 
# Same list object as `labelset` (an alias, not a copy)
labelset_full_description = labelset
num_rows = df.shape[0]

# Get predictions
y_pred = df.prediction_ds.map(lambda x: process_output(x, task_num))
# Dump the prompt-stripped generation next to its mapped label for manual inspection
(pd.concat([df.prediction_ds.map(lambda x: ' '.join(x.split('<|endoftext|>')[1:])), y_pred.rename('pred_label')], axis=1)
    .to_csv(f'test_ds_{ds_num}_task_{task_num}_sample_size{sample_size}.csv'))
assert num_rows == y_pred.shape[0], 'Number of rows in dataframe and predictions do not match'
# Disabled check. NOTE(review): it reads df['pred_label'], which this cell never
# assigns (the predictions live in `y_pred`), so it would need updating before re-enabling.
#assert df['pred_label'].map(lambda pred: pred not in labelset).sum() == 0, 'Prediction not in expected labelset'

# Get ground truth in same format
y_true = df[label_column].map(lambda label: map_label_to_completion(
    label=label, task_num=task_num, full_label=True))
assert y_true.map(lambda pred: pred not in labelset).sum() == 0, 'Ground truth not in expected labelset'
assert num_rows == y_true.shape[0], 'Number of rows in dataframe and ground truth do not match'

# Plot the confusion matrices and compute the classification metrics
labels = labelset
display_labels = labelset_full_description
cm_plot, classification_report, metrics = plot_count_and_normalized_confusion_matrix(
    y_true, y_pred, display_labels, labels, xticks_rotation='horizontal')

# Record the metrics of this experiment in the running summary
accuracy_summary_list.append({
    'exp_name': os.path.basename(prediction_fp),
    'dataset': ds_num,
    'task': task_num,
    'sample_size': sample_size,
    'accuracy': metrics['accuracy'],
    'f1-macro': metrics['f1'],
    'precision': metrics['precision'],
    'recall': metrics['recall']
})