In [1]:
import glob
import itertools
import os
from typing import Iterable, Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import tqdm
from scipy.stats import pearsonr as pcorr

In [2]:
# Repository root. Defaults to the original author's checkout; set the SOTA_ROOT
# environment variable to run the notebook from a different location.
PATH_ROOT = os.environ.get("SOTA_ROOT", "/home/shaul/workspace/GitHub/SOTA/")

In [3]:
# Change the working directory to the repository root (IPython `cd` magic) so that
# relative paths used later resolve against PATH_ROOT.
cd {PATH_ROOT}

/home/shaul/workspace/GitHub/SOTA


In [4]:
# Load the combined dataset. get_corr later groups rows by pair_id and counts annotators
# per pair, so a pair presumably appears once per annotator — TODO confirm.
# The first CSV column is used as the index.
df = pd.read_csv(os.path.join(PATH_ROOT, 'data/combined/with_annotators/combined_dataset.csv'), index_col=0)

In [5]:
if "annotator" not in df.columns:
    # Some exports of the combined dataset lack the `annotator` column; recover it from the
    # legacy file. The assignment aligns on the index (first CSV column), so any row whose
    # index is missing from the legacy file gets a NaN annotator.
    df_legacy = pd.read_csv(os.path.join(PATH_ROOT, 'data/combined_dataset.csv.1'), index_col=0)
    df['annotator'] = df_legacy['annotator']

In [6]:
# Columns that are not metric scores. get_corr treats every column NOT listed here as a
# metric to correlate against the human labels.
non_metric_columns = [
    'text1', 'text2',
    'label', 'dataset', 'random',
    'duration', 'total_seconds',
    'pair_id', 'reduced_label', 'annotator',
]

In [76]:
def _corr_with_labels(frame: pd.DataFrame, metrics: list) -> tuple:
    '''
    Correlate every metric column of `frame` with 'label' and with 'reduced_label'.

    return:
        {(pd.Series, pd.Series)} - Pearson correlation (DataFrame.corrwith default) of each
        metric with 'label' and with 'reduced_label', indexed by metric name
    '''
    return (frame[metrics].corrwith(frame['label']),
            frame[metrics].corrwith(frame['reduced_label']))


def _corr_by_group(frame: pd.DataFrame, by: str, metrics: list) -> tuple:
    '''
    Apply `_corr_with_labels` separately to each group of `frame.groupby(by)`.

    return:
        {(pd.DataFrame, pd.DataFrame)} - one row per value of `by`, one column per metric,
        for 'label' and 'reduced_label' respectively
    '''
    label_corr = dict()
    reduced_label_corr = dict()
    for name, group in frame.groupby(by):
        label_corr[name], reduced_label_corr[name] = _corr_with_labels(group, metrics)
    return (pd.DataFrame.from_dict(label_corr).T,
            pd.DataFrame.from_dict(reduced_label_corr).T)


def get_corr(df: pd.DataFrame, bad_annotator: Optional[Iterable],
             exclude_columns: Optional[Iterable[str]] = None) -> dict:
    '''
    Correlate every metric with the human labels, optionally filtering out "bad annotators".

    Rows are first averaged per (pair_id, dataset, random), so each pair contributes one
    observation. Correlations are Pearson (the DataFrame.corrwith default).

    parameters:
        df -- {pd.DataFrame} -- combined dataset; needs 'pair_id', 'dataset', 'random',
              'label', 'reduced_label' and, when filtering, 'annotator' columns
        bad_annotator -- {list or None} -- annotator IDs to filter out; None disables filtering
        exclude_columns -- {iterable of str or None} -- non-metric columns; every other column
              is treated as a metric. Defaults to the notebook-level `non_metric_columns`.

    return:
        {dict} with keys
            'label_by_dataset', 'reduced_label_by_dataset' -- {pd.DataFrame} one row per
                dataset, one column per metric
            'label_by_random', 'reduced_label_by_random' -- {pd.DataFrame} one row per
                value of 'random', one column per metric
            'label_by_combined', 'reduced_label_by_combined' -- {pd.Series} correlations
                over all datasets, indexed by metric
    '''
    if exclude_columns is None:
        exclude_columns = non_metric_columns
    excluded = set(exclude_columns)

    if bad_annotator is not None:
        df = df[~df.annotator.isin(bad_annotator)]
        # Drop pairs left with fewer than two non-null annotator rows after the removal.
        # NOTE(review): this runs only when filtering, so the unfiltered baseline may still
        # contain single-annotator pairs; baseline-vs-filtered deltas mix both effects.
        df = df.groupby('pair_id').filter(lambda x: x.annotator.count() >= 2)

    metrics = [col for col in df.columns if col not in excluded]
    all_labels = metrics + ['label', 'reduced_label']
    # Average over annotators so each pair is a single observation.
    df = df.groupby(['pair_id', 'dataset', 'random'])[all_labels].mean().reset_index()

    label_by_dataset, reduced_label_by_dataset = _corr_by_group(df, 'dataset', metrics)
    label_by_random, reduced_label_by_random = _corr_by_group(df, 'random', metrics)
    label_by_combined, reduced_label_by_combined = _corr_with_labels(df, metrics)

    return {
        'label_by_dataset': label_by_dataset,
        'reduced_label_by_dataset': reduced_label_by_dataset,
        'label_by_random': label_by_random,
        'reduced_label_by_random': reduced_label_by_random,
        'label_by_combined': label_by_combined,
        'reduced_label_by_combined': reduced_label_by_combined,
    }

In [77]:
# NOTE(review): computes the same unfiltered result as `dict_baseline` below, and `test`
# is not used in any visible cell — looks like a leftover sanity check.
test = get_corr(df,None)

In [126]:
# Annotator IDs to exclude ("bad annotators"), one per line.
# Read-only mode: the file is never written here.
with open(os.path.join(PATH_ROOT, 'data/other/ba_all.txt'), 'r') as f:
    # Strip stray whitespace and skip blank lines so they never become bogus IDs.
    list_ba = [line.strip() for line in f.read().splitlines() if line.strip()]
# NOTE(review): IDs are read as strings; if read_csv parsed df.annotator as integers,
# `isin(list_ba)` in get_corr matches nothing — confirm the dtypes agree.

In [127]:
# Correlations using every annotator (baseline) vs. with the listed bad annotators removed.
dict_baseline = get_corr(df,None)
dict_filtered = get_corr(df,list_ba)

In [128]:
def compare_correlations(dict_baseline, dict_filtered):
    '''
    Difference between filtered and baseline correlations, key by key.

    parameters:
        dict_baseline -- {dict} -- reference correlations (e.g. get_corr without filtering)
        dict_filtered -- {dict} -- correlations to compare (e.g. get_corr with bad annotators removed)

    return:
        {dict} - one entry per key of dict_baseline, in the same order, holding
                 dict_filtered[key] - dict_baseline[key]; positive means filtering raised
                 the correlation. Raises KeyError if dict_filtered lacks a baseline key.
    '''
    return {name: dict_filtered[name] - dict_baseline[name] for name in dict_baseline}

In [129]:
# Change in correlation per result key (filtered minus baseline).
ab_dict = compare_correlations(dict_baseline,dict_filtered)

In [132]:
# Change in each metric's correlation with the reduced label, split by the `random` column.
ab_dict['reduced_label_by_random']

Unnamed: 0,bleu,bleu1,glove_cosine,fasttext_cosine,BertScore,chrfScore,POS Dist score,1-gram_overlap,ROUGE-1,ROUGE-2,ROUGE-l
0,-0.008881,0.007736,-0.018534,-0.037464,-0.003839,0.000465,0.030275,-0.004488,0.00318,0.007254,0.007926
1,-0.051604,-0.043984,0.034127,0.045167,0.006353,-0.050304,0.005076,-0.048799,-0.038395,-0.049176,-0.030126
