# Setup

In [1]:
import json, os, sys, re
import pandas as pd
from collections import defaultdict
import numpy as np
import torch
from scipy import stats
from scipy.stats import entropy
from datasets import load_dataset, Dataset
import itertools
import torch
from pathlib import Path
from tqdm.auto import tqdm
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
from scipy.special import softmax
import pickle
from joblib import Parallel, delayed
import language_tool_python
from itertools import combinations
# add src folder to path
sys.path.append('..')

from dev.ProbLM import JointLM, ConditionalLM
from exp_3_set_proba.prepare_data import correct_grammar, few_shot_examples 
from exp_3_set_proba.analyze import calculate_ranking, calculate_instance_probability # was calculate_p_t_V2

from exp_3_set_proba.utils import hist_of_all_p_t_values, classify, stacked_p_t_plot, hist_of_all_p_t_values, evaluate_classifier, boxplots, scatterplots, calculate_macro_avg, plot_roc_curve, plot_coverage_risk_curve_2, calculate_entropies, get_data_permutations
from exp_3_set_proba.utils import get_data, save_plot, combine_stats_dfs
from dev.ProbLM import JointLM, ConditionalLM


from data_utils import get_wiki_summary
%load_ext autoreload
%autoreload 2

HOME_PATH = os.path.expanduser("~/")

# Root folder with the experiment outputs (TODO: hard-coded local Desktop location).
BASE_PATH = Path(HOME_PATH + "/Desktop/exp_3_set_proba_V4")

# Per-run descriptive statistics columns.
stat_metrics = [
    'n_objs', 'n_subjs', 'n_para', 'n_instances',
    'dataset', 'model', 'run_name',
]

# Dataset-level classification metrics, plain and argmax variants.
# Also available: 'tp_global', 'tn_global', 'fp_global', 'fn_global',
# 'tp_argmax_global', 'tn_argmax_global', 'fp_argmax_global', 'fn_argmax_global'.
metrics_global = [
    'coverage_abs', 'coverage_rel',
    'precision_global', 'recall_global', 'f1_global', 'accuracy_global', 'fpr_global',
    'precision_argmax_global', 'recall_argmax_global', 'f1_argmax_global',
    'accuracy_argmax_global', 'fpr_argmax_global',
]

metrics_selective = ['precision_selective']

# Threshold-sweep (ROC / AUC) metrics at dataset level.
# The name keeps its historical spelling ("thershold") since other cells may reference it.
metrics_global_0_thershold = [
    'auc_global',
    'fpr_by_threshold_global', 'tpr_by_threshold_global', 'roc_thresholds_global',
    'fpr_by_threshold_argmax_global', 'tpr_by_threshold_argmax_global',
    'roc_thresholds_argmax_global', 'auc_argmax_global',
]

# Metrics computed per paraphrase template.
# Also available: 'tp_argmax_pp', 'tn_argmax_pp', 'fp_argmax_pp', 'fn_argmax_pp'.
metrics_per_paraphrase = [
    'precision_argmax_pp', 'recall_argmax_pp', 'f1_argmax_pp',
    'accuracy_argmax_pp', 'fpr_argmax_pp',
]

# Threshold-sweep (ROC / AUC) metrics per paraphrase template.
metrics_per_paraphrase_0_threshold = [
    'fpr_by_threshold_argmax_pp', 'tpr_by_threshold_argmax_pp',
    'roc_thresholds_argmax_pp', 'auc_argmax_pp',
]

BASE_PATH


PosixPath('/Users/dug/Desktop/exp_3_set_proba_V4')

In [2]:
run_names = ['hypernymy_2000_50_mistral7B', 'trex_test_2000_50_mistral7B', 'PopQA_test_2000_50_mistral7B']
# NOTE(review): 'trex' and 'PopQA' are commented out, so the sampling cell below only has an
# output name for the first run; presumably deliberate (re-sample hypernymy only) — confirm.
dataset_per_run = ['hypernymy']  # , 'trex', 'PopQA'
df_stats, df_instance_permutations = get_data(run_names[0], BASE_PATH)
# Printed explicitly: only the last expression of a cell is displayed.
print(f"Unique objects in {run_names[0]}: {df_instance_permutations['obj_label'].nunique()}")

df_instance_permutations.columns

Index(['r_s_id', 'o_permutation_n', 'sub_label', 'obj_label', 'sequence',
       'label', 'orig_relation_template', 'orig_relation_id',
       'paraphrased_relation_template', 'paraphrase_id', 'sequence_original',
       'log_seq_prob', 's_r_t_id', 'p_t_s_r', 'p_t_over_paraphrases', 'p_t',
       'rank_o', 'rank_o|p_t_s_r,r_s_t(r)_i'],
      dtype='object')

# Subset for Manual Evaluation

In [3]:
# Select a random, label-balanced subset per relation (per object for hypernymy) for manual evaluation.
RERUN = False  # Overwrite annotated data!
n = 1  # positives and negatives drawn per paraphrase template -> 2 rows per template (50% pos / 50% neg)
N_GROUPS_SAMPLED = 10  # hypernymy has a single relation, so this many objects are sampled instead
N_PARAPHRASES_SAMPLED = 10  # templates sampled per relation/object; all are used if fewer exist
SAMPLE_SEED = 0  # makes a re-run draw the same subset

# Column order of the CSV handed to annotators.
SAMPLE_COLUMNS = [
    'r_s_id', 'o_permutation_n',
    'label', 'orig_relation_template', 'orig_relation_id',
    'paraphrase_id', 'sequence_original',
    'log_seq_prob', 's_r_t_id', 'p_t_s_r', 'p_t_over_paraphrases', 'p_t',
    'rank_o', 'rank_o|p_t_s_r,r_s_t(r)_i', 'paraphrased_relation_template', 'sub_label', 'obj_label', 'sequence',
]


def sample_for_manual_eval(df, group_col, groups, n_per_label, n_paraphrases, rng):
    """Draw a label-balanced sample per group and paraphrase template.

    For each value in ``groups`` (matched against ``df[group_col]``), pick
    ``n_paraphrases`` distinct ``paraphrase_id`` values at random (or all of
    them if fewer exist). For each picked template, draw ``n_per_label`` rows
    with ``label == 'pos'`` and ``n_per_label`` rows with ``label == 'neg'``.

    Args:
        df: instance-level frame with ``paraphrase_id`` and ``label`` columns.
        group_col: column identifying a group (relation id, or object for hypernymy).
        groups: group values to sample from.
        n_per_label: rows drawn per label and template.
        n_paraphrases: maximum number of templates sampled per group.
        rng: ``numpy.random.Generator`` driving every random choice.

    Returns:
        The sampled rows with all columns of ``df``, or an empty frame with
        ``df``'s columns when nothing was sampled.

    Raises:
        ValueError: from ``DataFrame.sample`` when a template has fewer than
            ``n_per_label`` rows of either label.
    """
    parts = []
    for i, group in enumerate(groups):
        df_group = df[df[group_col] == group]
        templates = df_group['paraphrase_id'].unique()
        if len(templates) >= n_paraphrases:
            templates = rng.choice(templates, n_paraphrases, replace=False)
        else:
            print(f'Not enough paraphrases for relation {group}: {len(templates)}')
        if i == 0 and len(templates) > 0:
            print(f'unique t(r): {df_group["paraphrase_id"].nunique()}')
        for template in templates:
            df_template = df_group[df_group['paraphrase_id'] == template]
            parts.append(df_template[df_template['label'] == 'pos'].sample(n_per_label, random_state=rng))
            parts.append(df_template[df_template['label'] == 'neg'].sample(n_per_label, random_state=rng))
    if not parts:
        return df.iloc[0:0]
    # Concatenate once instead of growing a frame inside the loop.
    return pd.concat(parts)


if RERUN:
    rng = np.random.default_rng(SAMPLE_SEED)
    for r, run_name in enumerate(run_names):
        # dataset_per_run may be shorter than run_names; indexing it blindly would raise IndexError.
        if r >= len(dataset_per_run):
            print(f'Skipping {run_name}: no dataset name in dataset_per_run')
            continue
        dataset_name = dataset_per_run[r]
        df_stats, df_instance_permutations = get_data(run_name, BASE_PATH)

        if run_name == 'hypernymy_2000_50_mistral7B':
            # has only 1 relation: use objects instead
            group_col = 'obj_label'
            objects_unsampled = df_instance_permutations['obj_label'].unique()
            groups = rng.choice(objects_unsampled, N_GROUPS_SAMPLED, replace=False)
            print(f'Data: {run_name}, objects sampled: {len(groups)}/{len(objects_unsampled)}')
        else:
            group_col = 'orig_relation_id'
            groups = df_instance_permutations['orig_relation_id'].unique()
            print(f'Relations: {groups.shape}')
            print(f'Data: {run_name}, n_relations: {len(groups)}')

        samples = sample_for_manual_eval(
            df_instance_permutations, group_col, groups, n, N_PARAPHRASES_SAMPLED, rng
        )
        # rearrange columns and save
        samples = samples[SAMPLE_COLUMNS]
        out_path = BASE_PATH / f"sequence_sample_{dataset_name}.csv"
        samples.to_csv(out_path, index=False)
        print(out_path)
        print(len(samples))
        


# Hypernymy: Context Evaluation
Number and share of subjects S whose name appears (case-insensitively) in their own context.


In [4]:
BASE_PATH_CODE = Path(f"{HOME_PATH}/Py/MAI_Codebase/exp_3_set_proba/") # TODO
# JSON is UTF-8 by specification; don't depend on the platform's default encoding.
with open(BASE_PATH_CODE / "s_contexts_hypernymy.json", encoding="utf-8") as f:
    s_contexts_hypernymy = json.load(f)  # maps each subject to its context string

# Count subjects whose text occurs (case-insensitively) in their own context.
s_in_context = sum(
    subject.lower() in context.lower()
    for subject, context in s_contexts_hypernymy.items()
)

n_subjects = len(s_contexts_hypernymy)
# An empty file yields NaN instead of raising ZeroDivisionError.
result = s_in_context / n_subjects if n_subjects else float('nan')
print(f'Subject contained in context in: {s_in_context}/{n_subjects} = {result*100:.2f}%')

Subject contained in context in: 533/574 = 92.86%


# Grammar Postprocessing Errors
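- Detect cases where grammar post-processing changed the subject or object, i.e. the label no longer appears verbatim (case-insensitively) in the post-processed sequence.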

In [5]:

run_names = ['hypernymy_2000_50_mistral7B', 'trex_test_2000_50_mistral7B', 'PopQA_test_2000_50_mistral7B']
dataset_per_run = ['hypernymy', 'trex', 'PopQA']
# One output name per run; a mismatch would mislabel or drop a run's CSV.
assert len(run_names) == len(dataset_per_run), "dataset_per_run must have one entry per run"

for r, (run_name, dataset_name) in enumerate(zip(run_names, dataset_per_run)):
    print(r, run_name)
    df_stats, df_instance_permutations = get_data(run_name, BASE_PATH)

    subjects = df_instance_permutations['sub_label'].to_list()
    objects = df_instance_permutations['obj_label'].to_list()
    postprocessed_seqs = df_instance_permutations['sequence'].to_list()
    # Lower-case every sequence once; both checks below reuse it.
    seqs_lower = [seq.lower() for seq in postprocessed_seqs]

    # 1 = the label no longer occurs (case-insensitively) in the post-processed sequence, else 0.
    s_mistakes = [int(s.lower() not in seq) for s, seq in zip(subjects, seqs_lower)]
    o_mistakes = [int(o.lower() not in seq) for o, seq in zip(objects, seqs_lower)]

    df_instance_permutations['s_mistakes'] = s_mistakes
    df_instance_permutations['o_mistakes'] = o_mistakes

    df_instance_permutations.to_csv(BASE_PATH / f"grammar_postprocessing_mistakes_{dataset_name}.csv")

    n_rows = len(subjects)
    s_mistakes_sum = sum(s_mistakes)
    o_mistakes_sum = sum(o_mistakes)
    # Avoid dividing by zero for a run without rows.
    s_pct = s_mistakes_sum / n_rows * 100 if n_rows else float('nan')
    o_pct = o_mistakes_sum / n_rows * 100 if n_rows else float('nan')

    print(f"Subjects affected by postprocessing: {s_mistakes_sum}/{len(subjects)} %: {s_pct}")
    print(f"Objects affected by postprocessing: {o_mistakes_sum}/{len(objects)} %: {o_pct}")
    

0 hypernymy_2000_50_mistral7B
Subjects affected by postprocessing: 6504/415440 %: 1.5655690352397458
Objects affected by postprocessing: 0/415440 %: 0.0
1 trex_test_2000_50_mistral7B
Subjects affected by postprocessing: 710833/2293951 %: 30.987279152867693
Objects affected by postprocessing: 18663/2293951 %: 0.8135744835003015
2 PopQA_test_2000_50_mistral7B
Subjects affected by postprocessing: 84059/267091 %: 31.47204510822154
Objects affected by postprocessing: 68712/267091 %: 25.726063401612187
