In [1]:
import pandas as pd
import sys
sys.path.append("..")
from mmice.utils import html_highlight_diffs
from mmice.edit_finder import EditEvaluator
from mmice.maskers.random_masker import RandomMasker
from transformers import MT5TokenizerFast
from IPython.display import display, HTML
import numpy as np
import spacy
from tqdm import tqdm

# spaCy's small English pipeline; it is handed to html_highlight_diffs when edits are rendered.
nlp = spacy.load("en_core_web_sm")

# Evaluator whose score_fluency method is used later in the notebook, backed by google/mt5-small.
# NOTE(review): this rebinds the name `eval`, shadowing Python's builtin eval() for the rest of
# the notebook; any later call written as eval(<string>) reaches this object instead.
# NOTE(review): 700 is passed both as the tokenizer's model_max_length and as RandomMasker's
# third positional argument -- presumably the same length limit; confirm against RandomMasker.
eval = EditEvaluator(fluency_model_name="google/mt5-small",
                     fluency_masker=RandomMasker(None, MT5TokenizerFast.from_pretrained("google/mt5-small", model_max_length=700, legacy=False), 700))



In [2]:

# Which experiment to inspect: the task name and the Stage 2 experiment name.
TASK = "imdb"
STAGE2EXP = "mmice-test-editor-bert"

# Directory holding that experiment's edit files, and the edits file read by this notebook.
SAVE_PATH = f"../results/{TASK}/edits/{STAGE2EXP}/"
EDIT_PATH = f"{SAVE_PATH}edits.csv"

In [3]:
def read_edits(path):
    """Load a tab-separated edits file into a DataFrame.

    Rows containing any missing value are dropped, and rows that merely repeat
    the header (``data_idx`` equal to the literal ``'data_idx'``) are removed.
    The three prediction columns are then normalised: when ``new_pred`` was
    parsed as float64, all three are converted to integer strings
    (``1.0`` -> ``"1"``); otherwise any remaining missing value becomes ``""``.

    Args:
        path: Location of the tab-separated, newline-terminated edits file.

    Returns:
        A new, independent DataFrame that keeps the original row labels of the
        surviving rows.
    """
    pred_columns = ['new_pred', 'orig_pred', 'contrast_pred']

    # NOTE: dropna() removes every row with a missing value in ANY column, so the
    # NaN handling further down is only a safety net for what survives this step.
    edits = pd.read_csv(path, sep="\t", lineterminator="\n").dropna()

    # .copy() detaches the filtered rows from the parsed frame, so the column
    # assignments below write to this frame without SettingWithCopyWarning.
    edits = edits[edits['data_idx'] != 'data_idx'].copy()

    if edits['new_pred'].dtype == np.dtype('float64'):
        def to_label(value):
            # 1.0 -> "1"; a missing value becomes the empty string.
            return str(int(value) if not np.isnan(value) else "")

        for column in pred_columns:
            edits[column] = edits[column].map(to_label)
    else:
        for column in pred_columns:
            # Assign the result back: fillna(inplace=True) on a column selection is
            # chained assignment and is not guaranteed to update the frame.
            edits[column] = edits[column].fillna("")
    return edits

In [4]:
def get_best_edits(edits):
    """Keep only the top-ranked edit rows (``sorted_idx == 0``).

    MiCE writes all edits that are found in Stage 2, but only the smallest edit
    per input should be evaluated; those are the rows whose ``sorted_idx`` is 0.
    The ``sorted_idx``, ``minimality``, ``data_idx`` and ``duration`` columns are
    converted to numeric types first.

    Args:
        edits: DataFrame of edits containing the four columns named above.

    Returns:
        A new DataFrame with the numeric conversions applied and only the
        ``sorted_idx == 0`` rows. The input frame is left untouched, and the
        result is an independent copy, so adding columns to it later is safe.
    """
    numeric_columns = ['sorted_idx', 'minimality', 'data_idx', 'duration']
    # assign() builds a new frame, so the caller's DataFrame is not mutated.
    converted = edits.assign(
        **{column: pd.to_numeric(edits[column]) for column in numeric_columns}
    )
    # .copy() so callers can add columns without SettingWithCopyWarning.
    return converted[converted['sorted_idx'] == 0].copy()
    
def evaluate_edits(edits):
    """Compute, print and return summary metrics over the ``sorted_idx == 0`` rows.

    Args:
        edits: DataFrame with ``sorted_idx``, ``data_idx``, ``new_pred``,
            ``contrast_pred``, ``minimality`` and ``duration`` columns.

    Returns:
        dict with:
            num_total: number of distinct ``data_idx`` values.
            num_flipped: number of distinct ``data_idx`` values having a row whose
                ``new_pred`` equals ``contrast_pred`` (compared as strings).
            flip_rate: num_flipped / num_total, or NaN when there are no rows.
            minimality: mean of the ``minimality`` column.
            duration: mean of the ``duration`` column.
    """
    best = edits[edits['sorted_idx'] == 0]
    flipped = best[best['new_pred'].astype(str) == best['contrast_pred'].astype(str)]

    # Count inputs, not rows, on both sides of the ratio so that duplicated rows
    # for one input cannot push the flip rate above 1.
    num_total = best['data_idx'].nunique()
    num_flipped = flipped['data_idx'].nunique()
    # No rows at all: report NaN rather than raising ZeroDivisionError.
    flip_rate = num_flipped / num_total if num_total else float("nan")

    metrics = {
        "num_total": num_total,
        "num_flipped": num_flipped,
        "flip_rate": flip_rate,
        "minimality": best['minimality'].mean(),
        # Fluency is left out: the column only exists once the fluency cells have run.
        #"fluency": best['fluency'].mean(),
        "duration": best['duration'].mean(),
    }
    for k, v in metrics.items():
        print(f"{k}: \t{round(v, 3)}")
    return metrics

In [5]:
def display_edits(row):
    """Print one edit's minimality, then render its original and edited text.

    Args:
        row: A mapping-like record with ``orig_editable_seg``,
            ``edited_editable_seg`` and ``minimality`` entries.
    """
    segments = (row['orig_editable_seg'], row['edited_editable_seg'])
    # html_highlight_diffs returns one HTML string per segment.
    before_html, after_html = html_highlight_diffs(*segments, nlp)
    minim = round(row['minimality'], 3)
    print(f"MINIMALITY: \t{minim}")
    print("")
    for markup in (before_html, after_html):
        display(HTML(markup))

def display_classif_results(rows):
    """Print the label summary for each row, followed by its highlighted edit.

    Args:
        rows: DataFrame of edits; every row needs the prediction columns, the two
            ``*_contrast_prob_pred`` columns and whatever ``display_edits`` reads.
    """
    for _, row in rows.iterrows():
        # Round both contrast-label probabilities for display.
        orig_contrast_prob_pred, new_contrast_prob_pred = (
            round(row[column], 3)
            for column in ('orig_contrast_prob_pred', 'new_contrast_prob_pred')
        )
        print("-----------------------")
        print(f"ORIG LABEL: \t{row['orig_pred']}")
        print(f"CONTR LABEL: \t{row['contrast_pred']} (Orig Pred Prob: {orig_contrast_prob_pred})")
        print(f"NEW LABEL: \t{row['new_pred']} (New Pred Prob: {new_contrast_prob_pred})")
        print("")
        display_edits(row)

def display_race_results(rows):
    """Print question, options and label summary per row, then its highlighted edit.

    Args:
        rows: DataFrame of edits. ``orig_input`` must hold the text of a Python
            dict literal with ``question`` and ``options`` entries; the remaining
            columns are the ones ``display_classif_results`` also needs.

    Raises:
        ValueError / SyntaxError: if ``orig_input`` is not a valid Python literal.
    """
    # Local import keeps this function self-contained within the notebook.
    import ast

    for _, row in rows.iterrows():
        orig_contrast_prob_pred = round(row['orig_contrast_prob_pred'], 3)
        new_contrast_prob_pred = round(row['new_contrast_prob_pred'], 3)
        # Parse with ast.literal_eval rather than eval(): this notebook rebinds the
        # name `eval` to an EditEvaluator instance, so eval(...) would not reach the
        # builtin, and literal_eval also refuses to execute arbitrary code.
        # NOTE(review): assumes orig_input contains only plain literals -- confirm.
        orig_input = ast.literal_eval(row['orig_input'])
        options = orig_input['options']
        print("-----------------------")
        print(f"QUESTION: {orig_input['question']}")
        print("\nOPTIONS:")
        for opt_idx, opt in enumerate(options):
            print(f"  ({opt_idx}) {opt}")
        print(f"\nORIG LABEL: \t{row['orig_pred']}")
        print(f"CONTR LABEL: \t{row['contrast_pred']} (Orig Pred Prob: {orig_contrast_prob_pred})")
        print(f"NEW LABEL: \t{row['new_pred']} (New Pred Prob: {new_contrast_prob_pred})")
        print("")
        display_edits(row)

In [6]:
# Load the configured experiment's edits file, then keep only the sorted_idx == 0 rows.
edits = get_best_edits(read_edits(EDIT_PATH))

In [7]:
# Inspect the filtered edits table (the rendered output is truncated by pandas).
edits

Unnamed: 0,data_idx,sorted_idx,orig_pred,new_pred,contrast_pred,orig_contrast_prob_pred,new_contrast_prob_pred,orig_input,edited_input,orig_editable_seg,edited_editable_seg,minimality,num_edit_rounds,mask_frac,duration,error\r\r
0,473,0,NEGATIVE,POSITIVE,POSITIVE,0.005314,0.969722,I've got as much testosterone as the next blok...,"i've got as muche as the nextquel, e, and ra ...",I've got as much testosterone as the next blok...,"i've got as muche as the nextquel, e, and ra ...",0.487315,1.0,0.481250,6.673160,False\r\r
4,132,0,NEGATIVE,POSITIVE,POSITIVE,0.000425,0.942704,"This should be re-named ""Everybody Loves Sebas...",""" "" "" "" "" of the best of,,, the,,,,,,,,,,,,,,...","This should be re-named ""Everybody Loves Sebas...",""" "" "" "" "" of the best of,,, the,,,,,,,,,,,,,,...",0.923399,3.0,0.412500,39.963010,False\r\r
19,472,0,POSITIVE,NEGATIVE,NEGATIVE,0.383635,0.996140,_The Wild Life_ has an obvious resemblance to ...,_ the wild life _ has an obvious resemblance ...,_The Wild Life_ has an obvious resemblance to ...,_ the wild life _ has an obvious resemblance ...,0.133411,1.0,0.034375,26.357021,False\r\r
79,46,0,POSITIVE,NEGATIVE,NEGATIVE,0.012707,0.980625,"Late night on BBC1, was on my way to bed but c...","late night on bbc1, was on my way to bed but ...","Late night on BBC1, was on my way to bed but c...","late night on bbc1, was on my way to bed but ...",0.079130,1.0,0.034375,43.300757,False\r\r
138,40,0,NEGATIVE,POSITIVE,POSITIVE,0.000193,0.523847,This film is bad. It's filled with glaring plo...,this film is with interesting it's filled wit...,This film is bad. It's filled with glaring plo...,this film is with interesting it's filled wit...,0.225589,1.0,0.171875,2.589823,False\r\r
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
10617,36,0,NEGATIVE,POSITIVE,POSITIVE,0.002066,0.874596,"Lovely music. Beautiful photography, some of s...",one of how about the farmers and of life in t...,"Lovely music. Beautiful photography, some of s...",one of how about the farmers and of life in t...,0.572046,2.0,0.515625,7.929349,False\r\r
10632,97,0,NEGATIVE,POSITIVE,POSITIVE,0.000240,0.955944,Wow! What a movie if you want to blow your bud...,i i i this this. on the.... on the... the hav...,Wow! What a movie if you want to blow your bud...,i i i this this. on the.... on the... the hav...,0.560897,2.0,0.240625,2.600762,False\r\r
10640,449,0,POSITIVE,NEGATIVE,NEGATIVE,0.032546,0.990759,"""Where to begin, where to begin . . ?(Savannah...",""" where to begin, where to begin..? ( savanna...","""Where to begin, where to begin . . ?(Savannah...",""" where to begin, where to begin..? ( savanna...",0.099490,1.0,0.034375,44.027389,False\r\r
10698,461,0,NEGATIVE,POSITIVE,POSITIVE,0.001476,0.969125,A VERY un-Tom and Jerry short. Jerry narrates ...,tom tom tom deanna and about this story that ...,A VERY un-Tom and Jerry short. Jerry narrates ...,tom tom tom deanna and about this story that ...,0.556619,2.0,0.171875,4.617000,False\r\r


In [8]:
# Fluency score of every unedited segment, with a progress bar.
# NOTE(review): slow -- the captured run was interrupted after 2 of 464 rows (~35 s per row).
# NOTE(review): score_fluency receives an extra positional argument (1) here but not in the
# edited-sequence cell; confirm the two calls are meant to differ.
tqdm.pandas(desc='original sequence loss!')
a = edits["orig_editable_seg"].progress_apply(
    lambda segment: eval.score_fluency(segment, 1)
)

original sequence loss!:   0%|          | 2/464 [01:10<4:31:35, 35.27s/it]


KeyboardInterrupt: 

In [None]:
# Fluency score of every edited segment; anything that is not a string scores 0.
tqdm.pandas(desc='edited sequence loss!')
b = edits["edited_editable_seg"].progress_apply(
    lambda segment: 0 if not isinstance(segment, str) else eval.score_fluency(segment)
)

In [None]:
# Fluency ratio: edited-segment score divided by original-segment score, row by row.
# NOTE(review): requires both `a` and `b` from the two fluency cells; in the captured run the
# cell computing `a` was interrupted, so a fresh Run All would stop here.
edits['fluency'] =  b/a
# Persist the scored best edits next to the input file (tab-separated, index included).
edits.to_csv(SAVE_PATH + "best_edits.csv", sep="\t", lineterminator="\n")

In [9]:
# Optional: uncomment to reload the saved best_edits.csv instead of using the in-memory frame.
#edits = read_edits(SAVE_PATH + "best_edits.csv")
#edits = get_best_edits(edits)
# Print and keep the summary metrics for the current edits.
metrics = evaluate_edits(edits)


num_total: 	464
num_flipped: 	464
flip_rate: 	1.0
minimality: 	0.409
duration: 	24.218


In [20]:
# Print (position, data_idx) wherever the sorted data_idx values differ from 0, 1, 2, ...
# After the first missing data_idx, every later position is reported as well.
for i, j in enumerate(edits['data_idx'].sort_values()):
    if i == j:
        continue
    print(i, j)

5 7
6 8
7 9
8 10
9 11
10 12
11 13
12 14
13 15
14 16
15 17
16 18
17 19
18 20
19 21
20 22
21 23
22 24
23 25
24 26
25 27
26 28
27 29
28 30
29 31
30 32
31 33
32 34
33 35
34 36
35 37
36 38
37 40
38 41
39 42
40 43
41 44
42 45
43 46
44 47
45 48
46 49
47 50
48 51
49 52
50 53
51 54
52 55
53 56
54 57
55 58
56 59
57 61
58 62
59 63
60 64
61 65
62 66
63 67
64 68
65 69
66 70
67 71
68 72
69 74
70 75
71 76
72 77
73 78
74 79
75 81
76 82
77 83
78 84
79 85
80 86
81 87
82 88
83 89
84 90
85 91
86 92
87 93
88 94
89 95
90 96
91 97
92 99
93 100
94 101
95 102
96 103
97 104
98 105
99 106
100 107
101 108
102 109
103 110
104 111
105 112
106 113
107 114
108 115
109 116
110 117
111 118
112 120
113 121
114 123
115 124
116 125
117 126
118 127
119 128
120 129
121 130
122 131
123 132
124 133
125 134
126 135
127 136
128 138
129 139
130 140
131 141
132 142
133 143
134 144
135 145
136 146
137 147
138 148
139 149
140 150
141 151
142 152
143 153
144 154
145 155
146 156
147 157
148 158
149 159
150 160
151 161
152 162
153 163

In [None]:
# Show one randomly sampled edit; no random_state is passed, so reruns may pick another row.
random_rows = edits.sample(1)
display_classif_results(random_rows)
# Alternative renderer that also prints the question and options parsed from orig_input:
# display_race_results(random_rows)

-----------------------
ORIG LABEL: 	NEGATIVE
CONTR LABEL: 	POSITIVE (Orig Pred Prob: 0.004)
NEW LABEL: 	POSITIVE (New Pred Prob: 0.99)

MINIMALITY: 	0.55

