In [1]:
import os

# Cache directory for Hugging Face downloads (machine-specific; adjust locally).
# Set in the first cell, ahead of the transformers import in the next cell.
# Raw string: in a normal literal "\R", "\m" and "\." are invalid escape
# sequences (SyntaxWarning since Python 3.12).
os.environ['HF_HOME'] = r'D:\Repositories\multilingual_mice\.cache'

In [2]:
import ast
import sys

import numpy as np
import pandas as pd
import spacy
from IPython.display import HTML, display
from tqdm import tqdm
from transformers import MT5TokenizerFast

sys.path.append("..")  # make the local `mmice` package (one directory up) importable
from mmice.edit_finder import EditEvaluator
from mmice.maskers.random_masker import RandomMasker
from mmice.utils import html_highlight_diffs

nlp = spacy.load("en_core_web_sm")

# Passed both as the tokenizer's model_max_length and as RandomMasker's third
# positional argument (presumably its max length as well -- TODO confirm).
MT5_MAX_LENGTH = 700

# Bound to `evaluator`, not `eval`, so the builtin eval() is not shadowed.
# NOTE: the commented-out fluency-scoring cell further down still refers to
# this object by its old name `eval`.
evaluator = EditEvaluator(
    fluency_model_name="google/mt5-small",
    fluency_masker=RandomMasker(
        None,
        MT5TokenizerFast.from_pretrained("google/mt5-small",
                                         model_max_length=MT5_MAX_LENGTH,
                                         legacy=False),
        MT5_MAX_LENGTH,
    ),
)

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'T5Tokenizer'. 
The class this function is called from is 'MT5TokenizerFast'.


In [None]:
# --- Experiment configuration ---
TASK = "imdb"                         # task name; selects the results sub-directory
STAGE2EXP = "mice-mt5-small-lora-02"  # Stage-2 experiment whose edits are evaluated
FIX_FLAG = False                      # True -> repair edits.csv via fix_windows_corruption()
LOAD_BEST = True                      # currently unused by executed code

SAVE_PATH = "../results/" + TASK + "/edits/" + STAGE2EXP + "/"
EDIT_PATH = f"{SAVE_PATH}edits.csv"

In [4]:
def read_edits(path):
    """Load a tab-separated MiCE edits file and normalize its prediction columns.

    Rows containing any missing value are dropped, as are rows that repeat the
    header (``data_idx == "data_idx"``). Each of ``new_pred``, ``orig_pred`` and
    ``contrast_pred`` is normalized independently:

    * a float64 column (numeric labels parsed as floats) becomes strings of
      integers, e.g. ``1.0 -> "1"``;
    * any other column has missing values replaced by ``""``.

    Args:
        path: file path or buffer accepted by ``pandas.read_csv``.

    Returns:
        A new DataFrame that does not share data with the parsed file frame.
    """
    edits = pd.read_csv(path, sep="\t", lineterminator="\n").dropna()
    # .copy() so the column assignments below act on an independent frame
    # (avoids SettingWithCopyWarning on the boolean-filtered result).
    edits = edits[edits['data_idx'] != 'data_idx'].copy()
    for col in ('new_pred', 'orig_pred', 'contrast_pred'):
        # Decide per column: checking only new_pred's dtype left the other
        # columns unconverted (or crashed np.isnan on strings) when dtypes differed.
        if edits[col].dtype == np.dtype('float64'):
            edits[col] = edits[col].map(lambda v: "" if np.isnan(v) else str(int(v)))
        else:
            # Plain assignment: chained `edits[col].fillna(..., inplace=True)`
            # does not update `edits` under pandas copy-on-write.
            edits[col] = edits[col].fillna("")
    return edits

In [5]:
def get_best_edits(edits):
    """Keep only the top-ranked edit per input.

    MiCE writes all edits found in Stage 2, but only the smallest per input
    is evaluated; rows with ``sorted_idx == 0`` are the ones kept (presumably
    the ranking is by minimality -- confirm against the edit writer).

    ``sorted_idx``, ``minimality``, ``data_idx`` and ``duration`` are converted
    to numeric in the returned frame. The input DataFrame is not modified.

    Returns:
        An independent copy, so callers can add columns without
        SettingWithCopyWarning.
    """
    numeric_cols = ('sorted_idx', 'minimality', 'data_idx', 'duration')
    converted = edits.assign(**{col: pd.to_numeric(edits[col]) for col in numeric_cols})
    return converted[converted['sorted_idx'] == 0].copy()
    
def evaluate_edits(edits):
    """Summarize the top-ranked edit per input and print each metric.

    Only rows with ``sorted_idx == 0`` are used.

    Args:
        edits: DataFrame with columns ``sorted_idx``, ``data_idx``, ``new_pred``,
            ``contrast_pred``, ``minimality`` and ``duration``; a ``fluency``
            column is summarized too when present.

    Returns:
        dict containing:

        * ``num_total``: distinct inputs.
        * ``num_flipped``: distinct inputs whose new prediction equals the
          contrast prediction.
        * ``flip_rate``: ``num_flipped / num_total``, or NaN when there are no inputs.
        * mean ``minimality``.
        * mean ``fluency``, if available.
        * mean ``duration``.
    """
    # to_numeric so this also works when sorted_idx is still stored as text.
    best = edits[pd.to_numeric(edits['sorted_idx']) == 0]
    # Compare as strings so columns read with different dtypes can still match.
    flipped = best[best['new_pred'].astype(str) == best['contrast_pred'].astype(str)]

    # Count distinct inputs in both numerator and denominator so duplicated
    # rows cannot push flip_rate above 1.
    num_total = best['data_idx'].nunique()
    num_flipped = flipped['data_idx'].nunique()

    metrics = {
        "num_total": num_total,
        "num_flipped": num_flipped,
        "flip_rate": num_flipped / num_total if num_total else float("nan"),
        "minimality": best['minimality'].mean(),
    }
    if 'fluency' in best.columns:
        metrics["fluency"] = best['fluency'].mean()
    metrics["duration"] = best['duration'].mean()

    for name, value in metrics.items():
        print(f"{name}: \t{round(value, 3)}")
    return metrics

In [6]:
def display_edits(row):
    """Show one edit: print its rounded minimality, then render the original
    and edited editable segments as HTML with differences highlighted by
    ``html_highlight_diffs``."""
    highlighted = html_highlight_diffs(row['orig_editable_seg'],
                                       row['edited_editable_seg'],
                                       nlp)
    print(f"MINIMALITY: \t{round(row['minimality'], 3)}")
    print("")
    for fragment in highlighted:
        display(HTML(fragment))

def print_prediction_summary(row):
    """Print the original, contrast and new labels of one edit.

    The contrast label is shown with the rounded ``orig_contrast_prob_pred``
    value and the new label with the rounded ``new_contrast_prob_pred`` value,
    followed by a blank line.
    """
    orig_prob = round(row['orig_contrast_prob_pred'], 3)
    new_prob = round(row['new_contrast_prob_pred'], 3)
    print(f"ORIG LABEL: \t{row['orig_pred']}")
    print(f"CONTR LABEL: \t{row['contrast_pred']} (Orig Pred Prob: {orig_prob})")
    print(f"NEW LABEL: \t{row['new_pred']} (New Pred Prob: {new_prob})")
    print("")


def display_classif_results(rows):
    """For each row of ``rows`` (a DataFrame of edits), print the label summary
    and render the highlighted edit."""
    for _, row in rows.iterrows():
        print("-----------------------")
        print_prediction_summary(row)
        display_edits(row)


def display_race_results(rows):
    """Like ``display_classif_results``, but first print the question and the
    numbered options parsed from each row's ``orig_input``."""
    for _, row in rows.iterrows():
        # literal_eval parses the stored repr without executing code (a bare
        # eval() here ran text from the CSV, and `eval` may be shadowed).
        # Assumes orig_input is a plain literal such as a dict repr -- TODO confirm.
        orig_input = ast.literal_eval(row['orig_input'])
        print("-----------------------")
        print(f"QUESTION: {orig_input['question']}")
        print("\nOPTIONS:")
        for opt_idx, opt in enumerate(orig_input['options']):
            print(f"  ({opt_idx}) {opt}")
        print("")
        print_prediction_summary(row)
        display_edits(row)

In [7]:
def fix_windows_corruption(file_path):
    """Write a cleaned copy of a corrupted edits file next to the original.

    Each line has every ``;`` removed and surrounding whitespace stripped.
    Lines that become empty are dropped. One leading and one trailing double
    quote are then removed if present. The result is written as
    ``fixed_edits.csv`` in the same directory. It uses UTF-8, matching the
    ``pandas.read_csv`` default used by ``read_edits``, and LF-only line endings
    so ``lineterminator="\\n"`` splits it cleanly on Windows too.

    Args:
        file_path: path of the corrupted file.

    Returns:
        Path of the cleaned file. If ``file_path`` already is that file, it is
        returned unchanged: opening it for writing would truncate it before it
        could be read.
    """
    new_file_path = os.path.join(os.path.dirname(file_path), 'fixed_edits.csv')
    if os.path.abspath(new_file_path) == os.path.abspath(file_path):
        return file_path

    # NOTE(review): this also removes semicolons that belong to the text
    # itself -- confirm that is acceptable for the corruption being repaired.
    with open(file_path, 'r', encoding='utf-8') as src, \
            open(new_file_path, 'w', encoding='utf-8', newline='\n') as dst:
        for line in src:
            cleaned = line.replace(";", "").strip()
            if not cleaned:
                continue
            # startswith/endswith also cope with a line that is just `"`.
            if cleaned.startswith('"'):
                cleaned = cleaned[1:]
            if cleaned.endswith('"'):
                cleaned = cleaned[:-1]
            dst.write(cleaned + "\n")
    return new_file_path


# Repair the raw export only: once EDIT_PATH points at fixed_edits.csv,
# re-running this cell must not feed that file back into the fixer.
if FIX_FLAG and os.path.basename(EDIT_PATH) != 'fixed_edits.csv':
    EDIT_PATH = fix_windows_corruption(EDIT_PATH)

In [8]:
# Load every Stage-2 edit, then keep only the top-ranked edit per input.
edits = get_best_edits(read_edits(EDIT_PATH))

In [9]:
# Peek at the first few retained edits.
edits.head(n=5)

Unnamed: 0,data_idx,sorted_idx,orig_pred,new_pred,contrast_pred,orig_contrast_prob_pred,new_contrast_prob_pred,orig_input,edited_input,orig_editable_seg,edited_editable_seg,minimality,num_edit_rounds,mask_frac,duration,error\r\r
0,59,0,NEGATIVE,POSITIVE,POSITIVE,0.000894,0.988344,"sex, drugs, racism and of course you abc's. wh...",the video is pretty awesome! i love a kid's s...,"sex, drugs, racism and of course you abc's. wh...",the video is pretty awesome! i love a kid's s...,0.000975,1,0.171875,110.385358,False\r\r
8,21,0,NEGATIVE,POSITIVE,POSITIVE,0.005748,0.951881,"coming from kiarostami, this art-house visual ...",how did he do it? 10 minutes. the camera stand...,"coming from kiarostami, this art - house visua...",how did he do it? 10 minutes. the camera stand...,0.00259,1,0.275,1522.793333,False\r\r
52,56,0,NEGATIVE,POSITIVE,POSITIVE,0.000614,0.978662,i wasn't able to last ten minutes on the this ...,i wasn't able to last ten minutes on the this ...,i wasn't able to last ten minutes on the this ...,i wasn't able to last ten minutes on the this ...,0.000159,1,0.06875,159.305464,False\r\r
69,18,0,POSITIVE,NEGATIVE,NEGATIVE,0.000321,0.973081,definitely a movie for people who ask only to ...,definitely a movie for people who ask only to ...,definitely a movie for people who ask only to ...,definitely a movie for people who ask only to ...,7.2e-05,1,0.034375,293.245705,False\r\r
94,33,0,NEGATIVE,POSITIVE,POSITIVE,0.000878,0.981583,unlike terms of endearment and steel magnolia'...,unlike terms of endearment and steel magnolia'...,unlike terms of endearment and steel magnolia'...,unlike terms of endearment and steel magnolia'...,4e-06,1,0.034375,168.752117,False\r\r


In [10]:
from evaluate import load

# One model id for both calls so original and edited perplexities stay comparable.
PERPLEXITY_MODEL_ID = 'facebook/xglm-1.7B'

perplexity = load("perplexity", module_type="metric")

# .tolist() already returns a fresh list; the former extra [:] copy was redundant.
edited_results = perplexity.compute(predictions=edits['edited_input'].tolist(),
                                    model_id=PERPLEXITY_MODEL_ID, batch_size=1)
orig_results = perplexity.compute(predictions=edits['orig_input'].tolist(),
                                  model_id=PERPLEXITY_MODEL_ID, batch_size=1)

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

In [15]:
# Mean perplexity from the previous cell's results (lower = more probable under the model).
print(f"edited mean perplexity: {edited_results['mean_perplexity']}")
print(f"original mean perplexity: {orig_results['mean_perplexity']}")

# .assign builds a new frame, so this is safe even if `edits` is a filtered slice
# (plain column assignment there raises SettingWithCopyWarning).
edits = edits.assign(edit_perplexity=edited_results['perplexities'],
                     orig_perplexity=orig_results['perplexities'])
edits.to_parquet(SAVE_PATH + "best_edits.parquet.gzip",
                 compression='gzip')

40.917891302108764
29.09078013420105


In [12]:
# tqdm.pandas(desc='original sequence loss!')
# a = edits["orig_editable_seg"].progress_apply(lambda x: eval.score_fluency(x, 2))

# tqdm.pandas(desc='edited sequence loss!')
# b = edits["edited_editable_seg"].progress_apply(lambda x: eval.score_fluency(x) if isinstance(x, str) else 0)

# edits['fluency'] =  b/a
# edits.to_csv(SAVE_PATH + "best_edits.csv", sep="\t", lineterminator="\n")

In [13]:
#edits = read_edits(SAVE_PATH + "best_edits.csv")
#edits = get_best_edits(edits)
# Headline metrics for the best edit per input (evaluate_edits also prints them).
metrics = evaluate_edits(edits=edits)

num_total: 	100
num_flipped: 	100
flip_rate: 	1.0
minimality: 	0.001
duration: 	308.72


In [14]:
# Fixed seed so Restart & Run All shows the same example every time.
SAMPLE_SEED = 0
random_rows = edits.sample(1, random_state=SAMPLE_SEED)
display_classif_results(random_rows)
# display_race_results(random_rows)  # use instead for question/options (RACE-style) tasks

-----------------------
ORIG LABEL: 	POSITIVE
CONTR LABEL: 	NEGATIVE (Orig Pred Prob: 0.002)
NEW LABEL: 	NEGATIVE (New Pred Prob: 0.995)

MINIMALITY: 	0.002

