In [1]:
import pandas as pd
import sys
sys.path.append("..")
from mmice.utils import html_highlight_diffs
from mmice.edit_finder import EditEvaluator
from mmice.maskers.random_masker import RandomMasker
from transformers import MT5TokenizerFast
from IPython.display import display, HTML
import numpy as np
import spacy
from tqdm import tqdm

# spaCy pipeline; passed to html_highlight_diffs when rendering edits.
nlp = spacy.load("en_core_web_sm")

# Build the fluency scorer step by step: tokenizer -> masker -> evaluator.
# NOTE(review): the tokenizer checkpoint is "google/mt5-small" while the
# fluency model is "google/umt5-small" -- confirm this pairing is intended.
fluency_tokenizer = MT5TokenizerFast.from_pretrained(
    "google/mt5-small",
    model_max_length=700,
    legacy=False,
)
fluency_masker = RandomMasker(None, fluency_tokenizer, 700)

# NOTE(review): this binding shadows Python's built-in `eval` for every later
# cell; code that needs to parse strings must not rely on the builtin.
eval = EditEvaluator(
    fluency_model_name="google/umt5-small",
    fluency_masker=fluency_masker,
)



In [2]:
# --- Run configuration -------------------------------------------------------
# FIX_FLAG: when True, fix_windows_corruption() is run on EDIT_PATH before loading.
FIX_FLAG = False
# LOAD_BEST: when True, the load cell reads SAVE_PATH + "best_edits.csv"
# instead of EDIT_PATH.
LOAD_BEST = True
# TASK and STAGE2EXP only serve to build the results directory below.
TASK = "imdb"
STAGE2EXP = "mmice-test-editor-umt5"
# Directory holding the edit files for this task/experiment (note the trailing "/").
SAVE_PATH = f"../results/{TASK}/edits/{STAGE2EXP}/"
# Tab-separated file with the edits to read when LOAD_BEST is False.
EDIT_PATH = SAVE_PATH + "edits.csv"

In [3]:
def read_edits(path):
    """Load a tab-separated edits file and normalise its prediction columns.

    Steps:
      1. Read the file (tab separator, "\\n" line terminator) and drop every
         row that contains a missing value.
      2. Drop rows whose 'data_idx' equals the literal string 'data_idx',
         i.e. header lines repeated inside the file.
      3. Normalise 'new_pred', 'orig_pred' and 'contrast_pred': when
         'new_pred' was parsed as float64, numeric prediction columns are
         rendered as integer strings ("1.0" -> "1"); otherwise remaining
         missing values are replaced by "".

    Args:
        path: path of the tab-separated edits file.

    Returns:
        A new DataFrame (independent copy) with the cleaned rows.
    """
    pred_cols = ['new_pred', 'orig_pred', 'contrast_pred']

    edits = pd.read_csv(path, sep="\t", lineterminator="\n").dropna()
    # .copy() detaches the filtered rows so the column assignments below write
    # to a real frame instead of a slice (no SettingWithCopyWarning).
    edits = edits[edits['data_idx'] != 'data_idx'].copy()

    # Same gate as before: numeric labels are detected through 'new_pred'.
    numeric_labels = edits['new_pred'].dtype == np.dtype('float64')
    for col in pred_cols:
        if numeric_labels and pd.api.types.is_numeric_dtype(edits[col]):
            # Vectorised per-column conversion; a non-numeric sibling column
            # (e.g. string labels) no longer reaches np.isnan and crashes.
            edits[col] = edits[col].map(lambda v: "" if np.isnan(v) else str(int(v)))
        else:
            # Assign the result back: chained `fillna(inplace=True)` on a
            # column selection does not update the frame under Copy-on-Write.
            edits[col] = edits[col].fillna("")
    return edits

In [4]:
def get_best_edits(edits):
    """Coerce the bookkeeping columns to numeric and keep the top-ranked rows.

    MiCE writes every edit found in Stage 2, but only the smallest edit per
    input should be evaluated; this is implemented by keeping the rows whose
    'sorted_idx' is 0.

    Side effect: 'sorted_idx', 'minimality', 'data_idx' and 'duration' are
    converted with pd.to_numeric on the passed-in frame itself.

    Args:
        edits: DataFrame containing at least the four columns listed above.

    Returns:
        An independent copy of the rows with sorted_idx == 0, so that adding
        columns to the result later does not raise SettingWithCopyWarning.
    """
    for col in ('sorted_idx', 'minimality', 'data_idx', 'duration'):
        edits[col] = pd.to_numeric(edits[col])
    # Boolean filtering yields a frame pandas still treats as a slice of
    # `edits`; copy it because callers go on to assign new columns to it.
    return edits[edits['sorted_idx'] == 0].copy()
    
def evaluate_edits(edits):
    """Compute and print summary metrics over the rows with sorted_idx == 0.

    Metrics:
        num_total:   number of distinct 'data_idx' values.
        num_flipped: rows whose 'new_pred' equals 'contrast_pred' (compared
                     as strings, so "1" and 1 match).
        flip_rate:   num_flipped / num_total, or NaN when there are no rows
                     (previously a ZeroDivisionError).
        minimality:  mean of 'minimality'.
        duration:    mean of 'duration'.

    Args:
        edits: DataFrame with columns 'sorted_idx', 'minimality', 'new_pred',
            'contrast_pred', 'data_idx' and 'duration'.

    Returns:
        dict mapping metric name to value; each metric is also printed
        rounded to three decimals.
    """
    temp = edits[edits['sorted_idx'] == 0]
    flipped = temp[temp['new_pred'].astype(str) == temp['contrast_pred'].astype(str)]
    nunique = temp['data_idx'].nunique()

    # An empty selection has no defined flip rate; report NaN like the
    # (already NaN) means below instead of dividing by zero.
    flip_rate = len(flipped) / nunique if nunique else float("nan")

    metrics = {
        "num_total": nunique,
        "num_flipped": len(flipped),
        "flip_rate": flip_rate,
        "minimality": temp['minimality'].mean(),
        #"fluency": temp['fluency'].mean(),
        "duration": temp['duration'].mean(),
    }
    for k, v in metrics.items():
        print(f"{k}: \t{round(v, 3)}")
    return metrics

In [5]:
def display_edits(row):
    """Show a single edit: its minimality, then the original and edited segments.

    `row` is indexed by 'orig_editable_seg', 'edited_editable_seg' and
    'minimality'. The two segments are passed to html_highlight_diffs (with
    the module-level `nlp`) and the two returned HTML strings are rendered in
    order: original first, edited second.
    """
    highlighted = html_highlight_diffs(row['orig_editable_seg'], row['edited_editable_seg'], nlp)
    html_original, html_edited = highlighted

    minim = round(row['minimality'], 3)
    # One call emits the score line followed by an empty line.
    print(f"MINIMALITY: \t{minim}", "", sep="\n")

    for markup in (html_original, html_edited):
        display(HTML(markup))

def display_classif_results(rows):
    """Print the label summary of every row in `rows`, then render its edit.

    For each row this prints a separator, the original / contrast / new
    labels with the contrast-label probabilities rounded to three decimals,
    an empty line, and finally delegates to display_edits(row).
    """
    for row in (record for _, record in rows.iterrows()):
        orig_contrast_prob_pred = round(row['orig_contrast_prob_pred'], 3)
        new_contrast_prob_pred = round(row['new_contrast_prob_pred'], 3)
        summary = [
            "-----------------------",
            f"ORIG LABEL: \t{row['orig_pred']}",
            f"CONTR LABEL: \t{row['contrast_pred']} (Orig Pred Prob: {orig_contrast_prob_pred})",
            f"NEW LABEL: \t{row['new_pred']} (New Pred Prob: {new_contrast_prob_pred})",
            "",
        ]
        print("\n".join(summary))
        display_edits(row)

def display_race_results(rows):
    """Print question, options and label summary of every row, then render its edit.

    For each row, 'orig_input' is parsed into a mapping that is indexed by
    'question' and 'options'; the question and the enumerated options are
    printed, followed by the original / contrast / new labels (probabilities
    rounded to three decimals), and display_edits(row) renders the diff.

    Args:
        rows: DataFrame whose rows provide 'orig_input', 'orig_pred',
            'contrast_pred', 'new_pred', 'orig_contrast_prob_pred',
            'new_contrast_prob_pred' plus the columns display_edits reads.
    """
    # Local import: the notebook rebinds the name `eval` to an EditEvaluator
    # instance, so the builtin is not reachable here. literal_eval also
    # refuses to execute arbitrary code found in the results file.
    # NOTE(review): assumes 'orig_input' holds a Python-literal dict
    # (strings/lists only) -- confirm against the writer of the edits file.
    import ast

    for _, row in rows.iterrows():
        orig_contrast_prob_pred = round(row['orig_contrast_prob_pred'], 3)
        new_contrast_prob_pred = round(row['new_contrast_prob_pred'], 3)
        orig_input = ast.literal_eval(row['orig_input'])
        options = orig_input['options']
        print("-----------------------")
        print(f"QUESTION: {orig_input['question']}")
        print("\nOPTIONS:")
        for opt_idx, opt in enumerate(options):
            print(f"  ({opt_idx}) {opt}")
        print(f"\nORIG LABEL: \t{row['orig_pred']}")
        print(f"CONTR LABEL: \t{row['contrast_pred']} (Orig Pred Prob: {orig_contrast_prob_pred})")
        print(f"NEW LABEL: \t{row['new_pred']} (New Pred Prob: {new_contrast_prob_pred})")
        print("")
        display_edits(row)

In [6]:
def fix_windows_corruption(file_path):
    """Write a cleaned copy of an edits file next to it as 'fixed_edits.csv'.

    Every line of `file_path` is cleaned as follows:
      * all ';' characters are removed and surrounding whitespace is stripped,
      * one leading and one trailing double quote are removed if present,
      * lines that are empty after cleaning are skipped.

    The output is always written with "\\n" line endings (independent of the
    platform) so that it matches the lineterminator read_edits() uses.

    Args:
        file_path: '/'-separated path of the file to repair.

    Returns:
        Path of the written file: same directory, name 'fixed_edits.csv'.

    Raises:
        ValueError: if `file_path` already is the 'fixed_edits.csv' target;
            opening it for writing would truncate the input before it is read.
    """
    directory, sep, _ = file_path.rpartition("/")
    new_file_path = directory + sep + 'fixed_edits.csv'
    if new_file_path == file_path:
        raise ValueError(f"refusing to overwrite the input file in place: {file_path}")

    with open(file_path, 'r') as f, open(new_file_path, 'w', newline="\n") as new_file:
        for line in f:
            line_ = line.replace(";", "").strip()
            # startswith/endswith are safe on empty strings; indexing [0]/[-1]
            # raised IndexError for a line consisting of a single quote.
            if line_.startswith("\""):
                line_ = line_[1:]
            if line_.endswith("\""):
                line_ = line_[:-1]
            if not line_:
                continue
            new_file.write(line_ + "\n")
    return new_file_path


# Repair the raw edits file once. After the first run EDIT_PATH points at
# 'fixed_edits.csv'; feeding that path back into fix_windows_corruption would
# open the same file for reading and writing and wipe it, so re-running this
# cell must be a no-op.
if FIX_FLAG and EDIT_PATH.split("/")[-1] != 'fixed_edits.csv':
    EDIT_PATH = fix_windows_corruption(EDIT_PATH)

In [7]:
# Choose the input file: the previously saved best edits, or the Stage 2 file.
edits_file = SAVE_PATH + "best_edits.csv" if LOAD_BEST else EDIT_PATH
edits = read_edits(edits_file)
# Keep only the rows with sorted_idx == 0 (numeric columns are coerced on the way).
edits = get_best_edits(edits)

In [8]:
# Preview the first rows of the loaded edits as a sanity check.
edits.head()

Unnamed: 0,data_idx,sorted_idx,orig_pred,new_pred,contrast_pred,orig_contrast_prob_pred,new_contrast_prob_pred,orig_input,edited_input,orig_editable_seg,edited_editable_seg,minimality,num_edit_rounds,mask_frac,duration,error\r,perplexity
0,473,0,NEGATIVE,POSITIVE,POSITIVE,0.005314,0.897639,I've got as much testosterone as the next blok...,I've got as much testosterone as the next blok...,I've got as much testosterone as the next blok...,I've got as much testosterone as the next blok...,0.160677,1.0,0.034375,119.766565,False\r,54.948025
1,132,0,NEGATIVE,POSITIVE,POSITIVE,0.000425,0.862147,"This should be re-named """"""""Everybody Loves Se...","This should be re-named """"""""Everybody Loves Se...","This should be re-named """"""""Everybody Loves Se...","This should be re-named """"""""Everybody Loves Se...",0.240575,1.0,0.06875,186.746424,False\r,113.194962
2,472,0,POSITIVE,NEGATIVE,NEGATIVE,0.383635,0.998703,_The Wild Life_ has an obvious resemblance to ...,_The Wild Life_ has an obvious resemblance to ...,_The Wild Life_ has an obvious resemblance to ...,_The Wild Life_ has an obvious resemblance to ...,0.041763,1.0,0.034375,172.030831,False\r,94.364273
3,46,0,POSITIVE,NEGATIVE,NEGATIVE,0.012707,0.654644,"Late night on BBC1, was on my way to bed but c...","Late night on BBC1, was on my way to bed but c...","Late night on BBC1, was on my way to bed but c...","Late night on BBC1, was on my way to bed but c...",0.097391,1.0,0.06875,215.840218,False\r,78.321251
4,40,0,NEGATIVE,POSITIVE,POSITIVE,0.000193,0.980735,This film is bad. It's filled with glaring plo...,This film is this . It's filled with excellen...,This film is bad. It's filled with glaring plo...,This film is this . It's filled with excellen...,0.191919,1.0,0.06875,46.403196,False\r,111.334877


In [16]:
from mmice.ppl import Perplexity

# Score every edited input with GPT-2 perplexity.
ppl = Perplexity(model_id='gpt2', device='gpu')
edited_texts = edits['edited_input'].tolist()
results = ppl._compute(edited_texts, max_length=1024, batch_size=8)
print(results['mean_perplexity'])

# Store the per-row scores and persist them with the best edits
# (index=False keeps the row index out of the file).
edits['perplexity'] = results['perplexities']
best_edits_file = SAVE_PATH + "best_edits.csv"
edits.to_csv(best_edits_file, sep="\t", lineterminator="\n", index=False)

  0%|          | 0/49 [00:00<?, ?it/s]

87.66469121590639


In [9]:
def _orig_segment_loss(text):
    """Fluency score of one original segment (second argument fixed to 2)."""
    # NOTE(review): the cell scoring the edited segments calls score_fluency
    # without this second argument and guards non-string values -- confirm the
    # asymmetry is intended before relying on the b/a ratio.
    return eval.score_fluency(text, 2)


# Register the progress bar, then score every original editable segment.
tqdm.pandas(desc='original sequence loss!')
a = edits["orig_editable_seg"].progress_apply(_orig_segment_loss)

original sequence loss!:   5%|▍         | 19/390 [1:17:19<22:54:44, 222.33s/it]Token indices sequence length is longer than the specified maximum sequence length for this model (1309 > 700). Running this sequence through the model will result in indexing errors


In [None]:
def _edited_segment_loss(text):
    """Fluency score of one edited segment; non-string values score 0."""
    if not isinstance(text, str):
        return 0
    return eval.score_fluency(text)


# Register the progress bar, then score every edited editable segment.
tqdm.pandas(desc='edited sequence loss!')
b = edits["edited_editable_seg"].progress_apply(_edited_segment_loss)

In [None]:
# Fluency = score of the edited segment relative to the original segment.
edits['fluency'] = b / a
# index=False (as in the perplexity cell): writing the row index adds an
# unnamed first column that comes back as "Unnamed: 0" when the file is reloaded.
edits.to_csv(SAVE_PATH + "best_edits.csv", sep="\t", lineterminator="\n", index=False)

In [None]:
# Optional: uncomment the next two lines to evaluate the saved best edits
# instead of the in-memory `edits` frame.
#edits = read_edits(SAVE_PATH + "best_edits.csv")
#edits = get_best_edits(edits)
# Prints each metric and keeps the dict for later use.
metrics = evaluate_edits(edits)

num_total: 	498
num_flipped: 	498
flip_rate: 	1.0
minimality: 	0.194
duration: 	37.402


In [None]:
# Draw one row at random (no random_state is passed, so each run may show a
# different row) and print its labels together with the highlighted diff.
random_rows = edits.sample(1)
display_classif_results(random_rows)
# Alternative renderer that also prints the question and options parsed from 'orig_input':
# display_race_results(random_rows)

-----------------------
ORIG LABEL: 	POSITIVE
CONTR LABEL: 	NEGATIVE (Orig Pred Prob: 0.0)
NEW LABEL: 	NEGATIVE (New Pred Prob: 0.587)

MINIMALITY: 	0.226

