In [6]:
import sys
from IPython.display import display, HTML
sys.path.append('./divergentmBERT')
import pandas as pd
import numpy as np
import shap
import os
import argparse
from scorer import DivergentmBERTScorer
import torch
device = torch.device("cuda" if torch.cuda.is_available() and not True else "cpu")
n_gpu = torch.cuda.device_count()
device = device

In [7]:
div_to_color = {'3': '#FF7F50',
                '2': '#FFBF00', 
                '1': '#DFFF00',
                '0': '#FFFFFF'}

def color_html(label):
    return f'<span style="background-color:{div_to_color[label]} ">'

def display_HTML(text, div, threshold):
    displayed_text = ''
    previous = 'eq'
    mark_end = '</span>'
    for token, label in zip(text.split(' '), div):
        if float(label) > threshold:
            mark = color_html('0')
            displayed_text += f' {token}'
            previous = 'eq'
        else:
            mark = color_html('1')
            if previous == 'div':
                displayed_text += f'{mark} {token}{mark_end}'
            else:
                displayed_text += f' {mark}{token}{mark_end}'
            previous = 'div'
        
    return display(HTML(displayed_text))

In [8]:
class DivergentmBERTWrapper():
    def __init__(self, model_path, tokenizer_path, do_lower_case, device, explain_source):
        self.scorer = DivergentmBERTScorer(model_path, tokenizer_path, do_lower_case, device)
        self.explain_source = explain_source
        self.text_a = None
        self.text_b = None

    def __call__(self, translations):
        if self.explain_source:
            target = [self.text_b] * len(translations)
            divergent_scores = self.scorer.compute_divergentscore(self.flatten(translations), target, self.explain_source)
        else:
            source = [self.text_a] * len(translations)
            divergent_scores = self.scorer.compute_divergentscore(source, self.flatten(translations), self.explain_source)
        return np.array(divergent_scores)

    def flatten(self, l):
        return [item for sublist in l for item in sublist]
    
    def tokenize_sent(self, sentence):
        return sentence.split(' ')

    def detokenize_sent(self, tokens):
        return ' '.join(tokens)

    def build_feature(self, trans_sent):
        tokens = self.tokenize_sent(trans_sent)
        tdict = {}
        for i, tt in enumerate(tokens):
            tdict['{}_{}'.format(tt, i)] = tt

        df = pd.DataFrame(tdict, index=[0])
        return df

    def mask_model(self, mask, x):
        tokens = []
        #print(f'MASK: {mask}')
        #print(f'{x}\n')
        #exit(0)
        for mm, tt in zip(mask, x):
            if mm:
                tokens.append(tt)
            else:
                tokens.append('[MASK]')
        trans_sent = self.detokenize_sent(tokens)
        sentence = pd.DataFrame([trans_sent])
        #print(f'Sentence: {trans_sent}\n')
        return sentence


In [9]:
class ExplainableDivergentmBERT():
    def __init__(self, model, tokenizer, do_lower_case, device,  explain_source):
        self.wrapper = DivergentmBERTWrapper(model, tokenizer, do_lower_case, device, explain_source)
        self.explainer = shap.Explainer(self.wrapper, self.wrapper.mask_model)
        self.explain_source = explain_source

    def __call__(self, text_a, text_b):
        if self.explain_source:
            return self.wrapper([text_a])
        else:
            return self.wrapper([text_b])

    def explain(self, text_a, text_b, plot=False):
        if self.explain_source: text = text_a
        else: text = text_b
        value = self.explainer(self.wrapper.build_feature(text))
        if plot: shap.waterfall_plot(value[0])
        all_tokens = self.wrapper.tokenize_sent(text)

        return [[token, sv] for token, sv in zip(all_tokens, value[0].values)]

In [10]:
texts_a = ["She made a courtesy call to the Hawaiian Islands at the end of the year and proceeded thence to Puget Sound where she arrived on 2 February 1852 ."]
texts_b = ["Il fait une escale aux îles Hawaï à la fin de l' année , au Puget Sound , le 2 février 1852 ."]
texts_a = ["Colonel-General Heinz Guderian , the Chief of the German General Staff , insisted to Adolf Hitler that the troops in Courland should be evacuated by sea and used for the defense of the Reich ."]
texts_b = ["Le général Heinz Guderian insiste auprès d' Adolf Hitler pour évacuer les soldats par la mer afin de les utiliser pour la défense du Reich ."]

texts_a = ["On his arrival the young man noticed a tree that was somewhat dried up; he split it in two, and found inserted in the middle of it a cross of a brown color and of a regular form."]
texts_a = ["He split it in two, and found inserted in the middle of it a cross of a brown color and of a regular form."]
texts_b = ["À son arrivée, le jeune homme remarque un arbre quelque peu desséché; il le fend en deux et trouve insérée au milieu de celui-ci une croix de couleur brune et de forme régulière."]


exp_scores = []
for explain_source in [True, False]:
        # Initialize explainable DivergentmBERT model
        print('Initialize Explainable DivergentmBERT')
        model = ExplainableDivergentmBERT(model='/fs/clip-divergences/xling-SemDiv/trained_bert/from_WikiMatrix.en-fr.tsv.filtered_sample_50000.moses.seed/contrastive_divergence_ranking/rdpg',
                                  tokenizer='bert-base-multilingual-cased',
                                  do_lower_case=False,
                                  device=device,
                                  explain_source=explain_source)
        print('Instance initialized...')
        for (text_a, text_b) in zip(texts_a, texts_b):
            model.wrapper.text_a = text_a
            model.wrapper.text_b = text_b
            #score = model(text_a, text_b)
            #print(score)
            #exit(0)
            exps = model.explain(text_a, text_b)
            exp_scores.append([float(entry[1]) for entry in exps])

Initialize Explainable DivergentmBERT
Instance initialized...


NameError: name 'json' is not defined

In [6]:
#print(score)
print('Source explanations')
print(exp_scores[0])
print('Target explanations')
print(exp_scores[1])

Source explanations
[0.5110032081604003, 0.8749134540557861, 0.6046473026275635, 0.10313496589660645, 0.0038985252380371095, -0.0007464408874511719, 0.030797290802001952, 0.2368807315826416, 0.17978758811950685, 0.18314628601074218, 0.22305402755737305, 0.23884763717651367, 0.19750890731811524, 0.18692369461059571, -0.23330540657043458, 0.10131587982177734, 0.11740713119506836, 0.35654115676879883, -0.005339956283569336, 0.01355600357055664, -0.03993921279907227, -0.09286775588989257, 0.6138678073883057, 0.028708744049072265]
Target explanations
[0.311589172908238, -0.06536245346069336, 0.5941666194370815, -0.12101922716413226, -0.16486140659877233, -0.3734444890703474, 1.0671052932739258, 0.2023179190499442, -0.7382615634373256, 0.07970571517944336, -0.13772950853620256, -0.060001986367361884, 0.12427571841648646, 0.15236265318734304, 0.23349945885794504, 0.17543772288731166, 0.08234514508928571, 0.23329871041434153, -0.39384310586111887, 0.15059638023376465, -0.06949778965541295, 0.0

In [11]:
display_HTML(text_a, exp_scores[0], 0)
display_HTML(text_b, exp_scores[1], 0)

In [8]:
display_HTML(text_a, exp_scores[0], 0.5)
display_HTML(text_b, exp_scores[1], 0.5)

In [16]:
sum(exp_scores[1][:11])/len(exp_scores[1][:11])

0.059473279234650844

In [15]:
sum(exp_scores[1][11:])/len(exp_scores[1][11:])

0.17512211861548482