# In [1]:
import os
# Run from the repository root so `dataset`, `models` and `./configs/...` resolve.
# Remember the notebook's own directory the first time this cell runs, so that
# re-executing the cell does not keep walking further up the directory tree.
_NOTEBOOK_DIR = globals().setdefault('_NOTEBOOK_DIR', os.getcwd())
os.chdir(os.path.join(_NOTEBOOK_DIR, '..'))

# In [57]:
import warnings

# Silence every warning for the rest of the session to keep the widget output
# clean. NOTE: this also hides library deprecation notices.
warnings.simplefilter('ignore')

import yaml
import json
import pandas as pd

import seaborn as sns
import torch
import matplotlib
import matplotlib.pyplot as plt
import numpy as np



from dataset import create_tokenizer
from models import create_model

from captum.attr import TokenReferenceBase, LayerIntegratedGradients

import ipywidgets as widgets
from IPython.display import display, clear_output

def _get_mask(start_idx, input_ids, sent_ids):
    """Return a boolean mask marking where ``sent_ids`` occurs in ``input_ids``.

    The search starts at position ``start_idx`` and stops at the first exact
    match. If fewer than ``len(sent_ids)`` tokens remain after ``start_idx``
    (e.g. the sentence was cut off by truncation), every remaining position is
    marked instead. If no match is found, the mask is all ``False``.

    Args:
        start_idx: first position of ``input_ids`` to search from (int or
            0-d integer tensor).
        input_ids: token ids of the whole document (assumed 1-D).
        sent_ids: token ids of one sentence (assumed 1-D).

    Returns:
        Bool tensor shaped like ``input_ids``.
    """
    mask = torch.zeros_like(input_ids, dtype=torch.bool)
    n = len(sent_ids)

    # Not enough tokens left for a full match: claim the remainder.
    if len(input_ids[start_idx:]) < n:
        mask[start_idx:] = True
        return mask

    # `+ 1` so the last window (a sentence ending on the final token) is tested.
    for i in range(start_idx, len(input_ids) - n + 1):
        if (input_ids[i:i + n] == sent_ids).all():
            mask[i:i + n] = True
            break

    return mask

def get_sentences_mask(input_ids, sents, encode_func):
    """Locate each sentence of ``sents`` inside ``input_ids``, in order.

    Sentences are searched left to right: each search starts right after the
    last token matched for the previous sentence. The first sentence that
    cannot be located gets an all-False mask, and the remaining sentences are
    skipped. The returned list may therefore be shorter than ``sents``.

    Args:
        input_ids: token ids of the full document.
        sents: sentence strings, in document order.
        encode_func: callable mapping a sentence string to a list of token ids.

    Returns:
        List of bool tensors shaped like ``input_ids``, one per processed
        sentence.
    """
    sents_mask = []
    start_idx = 0
    for sent in sents:
        mask = _get_mask(
            start_idx=start_idx,
            input_ids=input_ids,
            sent_ids=torch.tensor(encode_func(sent)),
        )
        sents_mask.append(mask)

        matched = torch.where(mask)[0]
        # Test for emptiness, not for an index-sum of 0: a one-token match at
        # position 0 has index-sum 0 but is still a valid match.
        if matched.numel() == 0:
            break
        start_idx = int(matched[-1]) + 1

    return sents_mask

def print_sent_with_score(sent, score):
    """Render one sentence as an HTML span shaded by ``score``.

    Args:
        sent: sentence text; HTML-escaped before being embedded.
        score: value looked up in Matplotlib's 'Greens' colormap (expected in
            [0, 1]; higher values give a darker green).

    Returns:
        ``widgets.HTML`` holding a single ``<span>``.
    """
    import html

    # `plt.get_cmap` works across Matplotlib versions; `plt.cm.get_cmap` was
    # deprecated in 3.7 and removed in 3.9.
    cmap = plt.get_cmap('Greens')
    r, g, b, a = cmap(score)

    # Build the CSS colour explicitly. Interpolating `tuple(ndarray)` renders
    # `np.float64(...)` under NumPy 2, which is not valid CSS. RGB channels are
    # truncated to 0-255 integers, and alpha is damped to 80 %.
    rgba = f'rgba({int(r * 255)}, {int(g * 255)}, {int(b * 255)}, {a * 0.8})'

    return widgets.HTML(
        f'<span class="barcode"; style="color: black; background-color: {rgba}">'
        f'{html.escape(sent)}</span>'
    )


output_model = widgets.Output()

@output_model.capture()
def on_click_model(b: widgets.Button) -> None:
    """Load the config, tokenizer, model and dataset for the selected model type.

    Callback for the "select" button. Everything is loaded into locals first
    and published as the module globals ``cfg``, ``tokenizer``, ``model`` and
    ``dataset`` only at the end. ``on_click_run`` reads these globals.
    Output produced while loading is routed to ``output_model``.

    Args:
        b: the clicked button (unused; required by ``Button.on_click``).
    """
    import importlib

    global model
    global dataset
    global tokenizer
    global cfg

    # Dropdown value -> file-name stem of that model's test config.
    modelname = {
        'BERT':'BERT',
        'FNDNet':'FNDNet',
        'HAND':'HAND'
    }

    choice = model_select.value
    config_path = f'./configs/{choice}/{modelname[choice]}-test.yaml'
    # `with` closes the config file handle once it has been parsed.
    with open(config_path, 'r') as f:
        new_cfg = yaml.load(f, Loader=yaml.FullLoader)

    new_tokenizer, word_embed = create_tokenizer(
        name            = new_cfg['TOKENIZER']['name'],
        vocab_path      = new_cfg['TOKENIZER'].get('vocab_path', None),
        max_vocab_size  = new_cfg['TOKENIZER'].get('max_vocab_size', None)
    )

    new_model = create_model(
        modelname                 = new_cfg['MODEL']['modelname'],
        hparams                   = new_cfg['MODEL']['PARAMETERS'],
        word_embed                = word_embed,
        tokenizer                 = new_tokenizer,
        freeze_word_embed         = new_cfg['MODEL'].get('freeze_word_embed',False),
        use_pretrained_word_embed = new_cfg['MODEL'].get('use_pretrained_word_embed',False),
        checkpoint_path           = new_cfg['MODEL']['CHECKPOINT']['checkpoint_path'],
    )
    new_model.eval()

    # The global name `dataset` is rebound to a Dataset *instance* below, so
    # the `dataset` *module* is looked up by name instead of through it.
    dataset_module = importlib.import_module('dataset')
    dataset_cls = getattr(dataset_module, f'{new_cfg["DATASET"]["name"]}Dataset')
    new_dataset = dataset_cls(
        tokenizer = new_tokenizer,
        **new_cfg["DATASET"]['PARAMETERS']
    )

    # Publish only after everything loaded, so a failure part-way through
    # cannot leave `cfg` describing a different model than `model`.
    cfg, tokenizer, model, dataset = new_cfg, new_tokenizer, new_model, new_dataset
    
    
output = widgets.Output(layout=widgets.Layout(width='700px', border='1px solid black'))


def _batched(inputs):
    """Return ``inputs`` with a leading batch dimension of 1 on every tensor."""
    return {k: v.unsqueeze(0) for k, v in inputs.items()}


def _normalize_scores(scores):
    """Scale ``scores`` so they sum to 1 and return them as a float NumPy array.

    If the total is not positive (e.g. no sentence received any positive
    attribution), the scores are returned unscaled. Dividing by zero there
    would yield NaN/inf.
    """
    scores = np.asarray(scores, dtype=float)
    total = scores.sum()
    return scores / total if total > 0 else scores


def _sentence_attributions(inputs, target, sents):
    """Score sentences with Layer Integrated Gradients (BERT / FNDNet models).

    Token attributions are summed over the layer output's last dimension
    (presumably the embedding dimension). A sentence's score is the sum of
    its tokens' *positive* attributions.

    Args:
        inputs: un-batched tensors returned by ``dataset.transform``.
        target: predicted class index to attribute.
        sents: sentence strings, in document order.

    Returns:
        List of floats, one per mask from ``get_sentences_mask`` (may be
        shorter than ``sents``).

    Raises:
        ValueError: if the selected model type has no attribution setup.
    """
    # Select the embedding layer, the padding id used as the IG baseline
    # token, and a sentence -> token-id encoder matching the tokenizer.
    if model_select.value == 'BERT':
        def bert_encode(src):
            return [tokenizer.convert_tokens_to_ids(s) for s in tokenizer(src)]

        layer = model.bert.embeddings.word_embeddings
        refer_token_idx = tokenizer.vocab.token_to_idx['PAD']
        encode_func = bert_encode
    elif model_select.value == 'FNDNet':
        layer = model.w2e
        refer_token_idx = tokenizer.pad_token_id
        encode_func = tokenizer.encode
    else:
        raise ValueError(f'No attribution setup for model type {model_select.value!r}')

    input_ids = inputs['input_ids']

    # The baseline must match the input's shape and device exactly, so size
    # it from the input itself rather than from the config's `max_word_len`.
    token_reference = TokenReferenceBase(reference_token_idx=refer_token_idx)
    reference_indices = token_reference.generate_reference(
        sequence_length = input_ids.shape[-1],
        device          = input_ids.device,
    ).unsqueeze(0)

    attr_score = LayerIntegratedGradients(model, layer).attribute(
        inputs    = input_ids.unsqueeze(0),
        baselines = reference_indices,
        target    = target,
        n_steps   = 10,
    )
    # Drop the batch dim, then reduce the last dim -> one value per token.
    attr_score = attr_score.squeeze(0).sum(dim=-1)

    scores = []
    for mask in get_sentences_mask(input_ids=input_ids, sents=sents, encode_func=encode_func):
        token_scores = attr_score[mask]
        # only positive attributions count towards the sentence score
        scores.append(token_scores[token_scores > 0].sum().item())
    return scores


@output.capture()
def on_click_run(b: widgets.Button) -> None:
    """Classify the entered article and show per-sentence importance.

    Callback for the "run" button. Reads the globals set by
    ``on_click_model`` and the ``title`` / ``text`` widgets, then displays:
    the predicted class with its probability, the title, and each sentence
    shaded by its normalised score. HAND uses the model's own attention
    output; BERT / FNDNet use Layer Integrated Gradients.

    Args:
        b: the clicked button (unused; required by ``Button.on_click``).
    """
    import html

    # NOTE(review): the separator is '\\n' (a backslash followed by 'n'), not
    # a newline character, so line breaks typed into the Textarea are NOT
    # split on. Presumably this matches pasted text with escaped newlines —
    # confirm.
    sents = text.value.split('\\n')

    # Pass a copy: how transform treats the list is not visible here.
    inputs = dataset.transform(
        title = title.value,
        text  = list(sents)
    )
    batch = _batched(inputs)

    # Prediction needs no autograd graph.
    with torch.no_grad():
        outputs = torch.nn.functional.softmax(model(**batch)[0], dim=-1)
    pred = int(outputs.argmax(dim=-1))

    # NOTE(review): this branches on the dropdown's *current* value, which can
    # differ from the model last loaded via "select" if the dropdown changed.
    if model_select.value == 'HAND':
        with torch.no_grad():
            sents_score = model(**batch, output_attentions=True)[2][0]
    else:
        sents_score = _sentence_attributions(inputs, pred, sents)

    sents_score = _normalize_scores(sents_score)

    cls = ['real','fake']
    html_list = [
        widgets.HTML('<h1>Result</h1>'),
        widgets.HTML(f"{cls[pred].capitalize()} ( {float(outputs[pred]):.2%} )"),
        widgets.HTML('<h1>Title</h1>'),
        widgets.HTML(f'<span class="barcode"; style="color: black;">{html.escape(title.value)}</span>'),
        widgets.HTML('<h1>Context</h1>'),
    ]
    html_list.extend(
        print_sent_with_score(sent=sent, score=score)
        for sent, score in zip(sents, sents_score)
    )

    display(widgets.VBox(html_list))

    # wait=True: the panel is cleared only when the next run emits new output.
    clear_output(wait=True)
    
    
# =====================
# layout
# =====================

# The callbacks above read `model_select`, `title` and `text` as globals.
model_select = widgets.Dropdown(
    options=['FNDNet','HAND','BERT'],
    value='BERT',
    description='Model Type:',
    disabled=False,
)

title = widgets.Textarea(
    value=' ',
    placeholder='',
    description='Title:',
    layout = widgets.Layout(width='700px', height='30px')
)

text = widgets.Textarea(
    value=' ',
    placeholder='',
    description='Context:',
    layout = widgets.Layout(width='700px', height='500px')
)

button_modelname = widgets.Button(description='select')
button_modelname.on_click(on_click_model)

button_run = widgets.Button(description='run')
button_run.on_click(on_click_run)

display(widgets.HBox([model_select, button_modelname]))

# `on_click_model` runs inside `output_model.capture()`, which routes its
# output and tracebacks into `output_model`. Display that widget so a failed
# model load is visible instead of silently swallowed.
display(output_model)

display(widgets.VBox([widgets.HBox([title,button_run]), text]))

display(output)

# Recorded cell output (stale: the dropdown option was then named 'HAN', now 'HAND'):
# HBox(children=(Dropdown(description='Model Type:', index=2, options=('FNDNet', 'HAN', 'BERT'), value='BERT'), …
#
# VBox(children=(HBox(children=(Textarea(value=' ', description='Title:', layout=Layout(height='30px', width='70…
#
# Output(layout=Layout(border='1px solid black', width='700px'))