In [100]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules (e.g. the project's `src` package) are picked up without restarting
# the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
import os

# Repository root that the notebook switches into before importing `src`.
# The default is the original author's checkout; set the SOTA_PATH_ROOT
# environment variable to run the notebook from another location without
# editing this cell.
PATH_ROOT = os.environ.get("SOTA_PATH_ROOT", "/home/shaul/workspace/GitHub/SOTA/")

In [4]:
cd {PATH_ROOT}

/home/shaul/workspace/GitHub/SOTA


In [5]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
import math
import glob
import tqdm
from scipy.stats import pearsonr as pcorr
import itertools
import re
from sklearn.model_selection import train_test_split 
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error
from sklearn import preprocessing
from ipywidgets import interact
from src import metric_exploration

INFO:gensim.corpora.dictionary:adding document #0 to Dictionary(0 unique tokens: [])
INFO:gensim.corpora.dictionary:built Dictionary(12 unique tokens: ['computer', 'human', 'interface', 'response', 'survey']...) from 9 documents (total 29 corpus positions)
[nltk_data] Downloading package stopwords to /home/shaul/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [6]:
# Load the metric table (first CSV column becomes the index) and drop every
# row that contains a missing value. Paths are built from PATH_ROOT so the
# notebook follows the configured repository root.
df = pd.read_csv(os.path.join(PATH_ROOT, 'data/full_DS/full_metrics.csv'), index_col=0).dropna()

# One annotator id per line; these annotators are removed to build the
# filtered frame. The file is only read, so it is opened read-only ('r')
# rather than read/write ('r+').
# NOTE(review): "ba" presumably stands for "bad annotators" — confirm.
with open(os.path.join(PATH_ROOT, 'data/other/ba_all.txt'), 'r') as f:
    list_ba = f.read().splitlines()

df_filtered = df[~df.annotator.isin(list_ba)]

In [114]:
# Columns of `df` that are not treated as metric scores.
non_metric_columns = [
    'text1', 'text2', 'label', 'dataset', 'random', 'duration',
    'total_seconds', 'pair_id', 'reduced_label', 'annotator', 'radical',
    'radical_random', 'radical_non_random', 'radical_or_centralist',
    'num_labels', 'bad_annotator',
]
# Grouping columns handed to the metric_exploration helpers
# (alternative grouping kept for reference: ['radical_or_centralist']).
categories = ['dataset', 'random']

# Every remaining column of `df` is treated as a metric, in original column order.
metrics = [col for col in df.columns if col not in non_metric_columns]
# Metrics plus the two label columns.
all_labels = metrics + ['label', 'reduced_label']

In [115]:
# Dataset names excluded from the "ivan" subset; every other value found in
# `df.dataset` is kept.
# NOTE(review): what "ivan" refers to is not visible in this notebook.
datasets_not_ivan = [
    'bible_random_human',
    'paralex_random_human',
    'paraphrase_random_human',
]
datasets_with_ivan = [name for name in df.dataset.unique() if name not in datasets_not_ivan]

In [116]:
# Dataset names that do not contain the substring "random". Despite the
# variable name, these are values of `df.dataset`, not column labels.
non_random_columns = [name for name in df.dataset.unique() if "random" not in name]

In [117]:
# Keep only the rows whose dataset belongs to the "ivan" subset.
df_ivan = df.loc[df.dataset.isin(datasets_with_ivan)]

## Rank-induced Orders

In [42]:
def rank_datasets_by_mean(frame, columns):
    """Rank datasets by the mean of each column.

    Groups ``frame`` by its ``dataset`` column, averages ``columns`` within
    each dataset, and ranks the datasets per column in descending order
    (rank 1.0 = highest mean).

    Returns the transposed table: one row per entry of ``columns``, one
    column per dataset.
    """
    return frame.groupby('dataset')[columns].mean().rank(ascending=False).T


# Same ranking computed on the full frame and on the annotator-filtered frame.
rank_induced_base = rank_datasets_by_mean(df, all_labels)
rank_induced_filt = rank_datasets_by_mean(df_filtered, all_labels)

In [77]:
# Ranking from the unfiltered frame: rows are metrics/labels, columns are
# datasets, and rank 1.0 marks the dataset with the highest mean.
rank_induced_base

dataset,bible_human,bible_random_human,gyafc_formal_human,gyafc_formal_random_human,gyafc_informal_human,gyafc_informal_random_human,gyafc_rewrites_human,gyafc_rewrites_random_human,paralex_human,paralex_random_human,paraphrase_human,paraphrase_random_human,yelp_human,yelp_random_human
bleu,1.0,10.0,2.0,9.0,5.0,8.0,3.0,12.0,6.0,11.0,7.0,14.0,4.0,13.0
bleu1,1.0,9.0,2.0,10.0,7.0,13.0,3.0,12.0,5.0,8.0,6.0,14.0,4.0,11.0
glove_cosine,14.0,8.0,13.0,7.0,9.0,5.0,12.0,6.0,10.0,3.0,2.0,1.0,11.0,4.0
fasttext_cosine,14.0,8.0,13.0,7.0,9.0,5.0,12.0,6.0,11.0,3.0,2.0,1.0,10.0,4.0
BertScore,2.0,13.0,1.0,10.0,7.0,14.0,4.0,12.0,6.0,8.0,5.0,11.0,3.0,9.0
chrfScore,1.0,8.0,3.0,10.0,7.0,13.0,2.0,12.0,5.0,9.0,6.0,14.0,4.0,11.0
POS Dist score,10.0,1.0,12.0,5.0,7.0,4.0,11.0,3.0,8.0,2.0,13.0,9.0,14.0,6.0
1-gram_overlap,2.0,9.0,1.0,10.0,6.0,13.0,3.0,12.0,5.0,8.0,7.0,14.0,4.0,11.0
ROUGE-1,1.0,9.0,2.0,10.0,6.0,12.0,3.0,13.0,5.0,8.0,7.0,14.0,4.0,11.0
ROUGE-2,3.0,9.0,1.0,11.0,6.0,10.0,2.0,12.0,5.0,8.0,7.0,14.0,4.0,13.0


In [78]:
# Ranking from the annotator-filtered frame: rows are metrics/labels, columns
# are datasets, and rank 1.0 marks the dataset with the highest mean.
rank_induced_filt

dataset,bible_human,bible_random_human,gyafc_formal_human,gyafc_formal_random_human,gyafc_informal_human,gyafc_informal_random_human,gyafc_rewrites_human,gyafc_rewrites_random_human,paralex_human,paralex_random_human,paraphrase_human,paraphrase_random_human,yelp_human,yelp_random_human
bleu,2.0,10.0,1.0,9.0,5.0,8.0,3.0,12.0,6.0,11.0,7.0,14.0,4.0,13.0
bleu1,1.0,9.0,2.0,10.0,7.0,13.0,3.0,12.0,5.0,8.0,6.0,14.0,4.0,11.0
glove_cosine,14.0,8.0,13.0,7.0,9.0,5.0,12.0,6.0,10.0,3.0,2.0,1.0,11.0,4.0
fasttext_cosine,14.0,8.0,13.0,7.0,9.0,5.0,12.0,6.0,11.0,3.0,2.0,1.0,10.0,4.0
BertScore,2.0,13.0,1.0,10.0,7.0,14.0,4.0,12.0,5.0,8.0,6.0,11.0,3.0,9.0
chrfScore,1.0,8.0,3.0,10.0,7.0,13.0,2.0,12.0,5.0,9.0,6.0,14.0,4.0,11.0
POS Dist score,10.0,1.0,12.0,5.0,7.0,4.0,11.0,3.0,8.0,2.0,14.0,9.0,13.0,6.0
1-gram_overlap,3.0,9.0,1.0,10.0,6.0,13.0,2.0,12.0,5.0,8.0,7.0,14.0,4.0,11.0
ROUGE-1,2.0,9.0,1.0,10.0,7.0,11.0,3.0,13.0,4.0,8.0,6.0,14.0,5.0,12.0
ROUGE-2,3.0,9.0,1.0,10.0,6.0,11.0,2.0,12.0,5.0,8.0,7.0,14.0,4.0,13.0


In [67]:
# Correlate each metric's dataset ranking with the ranking induced by `label`
# (unfiltered data), largest correlation first.
(
    rank_induced_base.T[metrics]
    .corrwith(rank_induced_base.T['label'])
    .sort_values(ascending=False)
)

chrfScore          0.837363
bleu               0.832967
ROUGE-l            0.824176
ROUGE-2            0.815385
ROUGE-1            0.806593
1-gram_overlap     0.806593
bleu1              0.806593
BertScore          0.736264
glove_cosine      -0.665934
fasttext_cosine   -0.670330
POS Dist score    -0.679121
WMD               -0.850549
L2_score          -0.859341
dtype: float64

In [66]:
# Correlate each metric's dataset ranking with the ranking induced by `label`
# (annotator-filtered data), largest correlation first.
(
    rank_induced_filt.T[metrics]
    .corrwith(rank_induced_filt.T['label'])
    .sort_values(ascending=False)
)

bleu               0.819780
ROUGE-l            0.784615
ROUGE-2            0.784615
ROUGE-1            0.771429
1-gram_overlap     0.758242
chrfScore          0.753846
bleu1              0.731868
BertScore          0.657143
glove_cosine      -0.595604
fasttext_cosine   -0.600000
POS Dist score    -0.670330
WMD               -0.780220
L2_score          -0.810989
dtype: float64

## Correlation Metrics

In [118]:
# Build the project's correlation helper on the "ivan" subset and compute
# correlations twice: once with None and once with the annotator list.
# NOTE(review): presumably get_corr(list_ba) excludes the listed annotators
# and get_corr(None) applies no filtering — confirm in src/metric_exploration.
mc = metric_exploration.Metrics_Corr(df_ivan,non_metric_columns,categories)
base_results = mc.get_corr(None)
filtered_results = mc.get_corr(list_ba)

In [119]:
@interact
def display_corr(key_v=list(base_results.keys())):
    """Show the unfiltered correlation result stored under ``key_v``.

    The dropdown options come from ``base_results`` itself, so the cell does
    not depend on ``result``, which is only created in a later cell.

    A Series is displayed sorted with the largest values first; any other
    object is displayed unchanged.
    """
    selected = base_results[key_v]
    if isinstance(selected, pd.Series):
        display(selected.sort_values(ascending=False))
    else:
        display(selected)

interactive(children=(Dropdown(description='key_v', options=('label_by_dataset', 'reduced_label_by_dataset', '…

In [113]:
@interact
def display_corr(key_v=list(filtered_results.keys())):
    """Show the annotator-filtered correlation result stored under ``key_v``.

    The dropdown options come from ``filtered_results`` itself, so the cell
    does not depend on ``result``, which is only created in a later cell.

    A Series is displayed sorted with the largest values first; any other
    object is displayed unchanged.
    """
    selected = filtered_results[key_v]
    if isinstance(selected, pd.Series):
        display(selected.sort_values(ascending=False))
    else:
        display(selected)

interactive(children=(Dropdown(description='key_v', options=('label_by_dataset', 'reduced_label_by_dataset', '…

In [93]:
# Compare correlations using the annotator list; the returned object is
# indexed by key in the display cell that follows.
# NOTE(review): presumably contrasts unfiltered vs. filtered correlations —
# confirm against Metrics_Corr.compare_correlations.
result = mc.compare_correlations(list_ba)

In [94]:
@interact
def display_corr(key_v=list(result.keys())):
    """Show the correlation-comparison entry stored under ``key_v``.

    A Series is displayed sorted with the largest values first; any other
    object is displayed unchanged.
    """
    selected = result[key_v]
    if isinstance(selected, pd.Series):
        display(selected.sort_values(ascending=False))
    else:
        display(selected)

interactive(children=(Dropdown(description='key_v', options=('label_by_dataset', 'reduced_label_by_dataset', '…

## Look at the Non-Linear and Linear Models

In [11]:
# Build the project's model helper on the full frame and run it with
# model_type "RF"; the call returns two values, unpacked as scores and
# feature-importance values.
mm = metric_exploration.Metrics_Models(df,non_metric_columns,categories)
scores, fi_values = mm.run_model(model_type = "RF")

In [12]:
# Interactive view of the feature-importance values from the unfiltered "RF" run.
metric_exploration.visualize_fi(fi_values, categories)

interactive(children=(Dropdown(description='key_v', options=('fi_label_by_dataset_bible_human', 'fi_reduced_la…

interactive(children=(Dropdown(description='key_v', options=('fi_label_by_random_0', 'fi_reduced_label_by_rand…

interactive(children=(Dropdown(description='key_v', options=('fi_label_by_radical_or_centralist_centralist', '…

interactive(children=(Dropdown(description='key_v', options=('fi_label_combined', 'fi_reduced_label_combined')…

In [253]:
# Plot the scores of the unfiltered "RF" run.
metric_exploration.visualize_score(scores, "Base Scores")

In [265]:
# Re-run the unfiltered model helper with model_type "MLP" and plot the result.
# NOTE(review): the "RF" call unpacks run_model into (scores, fi_values);
# confirm that model_type="MLP" returns only the scores object that
# visualize_score expects.
scores2 = mm.run_model(model_type="MLP")
metric_exploration.visualize_score(scores2, "MSE for MLP")

In [266]:
# Same "RF" run on the annotator-filtered frame.
# Note: this rebinds `scores` and `fi_values`, so the values from the
# unfiltered run are no longer reachable under those names after this cell.
mm_filt = metric_exploration.Metrics_Models(df_filtered,non_metric_columns,categories)
scores, fi_values = mm_filt.run_model(model_type = "RF")

In [267]:
# Interactive view of feature importances; when cells run top to bottom,
# `fi_values` here comes from the annotator-filtered "RF" run.
metric_exploration.visualize_fi(fi_values, categories)

interactive(children=(Dropdown(description='key_v', options=('fi_label_by_dataset_bible_human', 'fi_reduced_la…

interactive(children=(Dropdown(description='key_v', options=('fi_label_by_random_0', 'fi_reduced_label_by_rand…

interactive(children=(Dropdown(description='key_v', options=('fi_label_by_radical_or_centralist_centralist', '…

interactive(children=(Dropdown(description='key_v', options=('fi_label_combined', 'fi_reduced_label_combined')…

In [268]:
# Plot the scores; when cells run top to bottom, `scores` here comes from the
# annotator-filtered "RF" run.
# NOTE(review): the title is identical to the unfiltered plot's — consider a
# distinct title if the two figures are compared side by side.
metric_exploration.visualize_score(scores, "Base Scores")

In [269]:
# Run the filtered model helper with model_type "MLP" and plot the result.
# This rebinds `scores2`, replacing the value from the unfiltered "MLP" run.
# NOTE(review): confirm that model_type="MLP" returns only the scores object
# (the "RF" call is unpacked into two values).
scores2 = mm_filt.run_model(model_type="MLP")
metric_exploration.visualize_score(scores2, "MSE for MLP")

## See How the Model Improves with the Filtered Dataset

In [270]:
# Compare "RF" scores using the annotator list and plot the comparison.
# `visualize_score` is only available through the `metric_exploration` module
# (nothing imports it as a bare name), so it is called module-qualified like
# every other call in this notebook.
comb_scores = mm.compare_score(list_ba, "RF")
metric_exploration.visualize_score(comb_scores, "Improvement of filtered dataset - RF")

In [271]:
# Compare "MLP" scores using the annotator list and plot the comparison.
# `visualize_score` is only available through the `metric_exploration` module
# (nothing imports it as a bare name), so it is called module-qualified like
# every other call in this notebook.
comb_scores = mm.compare_score(list_ba, "MLP")
metric_exploration.visualize_score(comb_scores, "Improvement of filtered dataset - MLP")