# KERMIT - Visualizer -

In this notebook you can view the heat parse tree of an input sentence.


* *Note: Before using this notebook you must have built the syntactic dataset and trained the model, saving its weights.*

* *If you want to try out KERMIT without training a model yourself, you can use the Colab version (https://colab.research.google.com/github/ART-Group-it/KERMIT/blob/master/examples/Notebooks/KERMITviz.ipynb).*

In [None]:
#to display on an html page the heat parse trees
! pip install -qqq pyngrok

In [None]:
import os
import sys
import transformers
import torch
import torch, pickle, copy, transformers
from torchtext import data as datx
import pandas as pd
from torch import nn
import torch.nn.functional as F
from torch import optim
from tqdm import tqdm


In [None]:
#working gpu control
# torch.cuda.get_device_name(0) raises when no CUDA device is usable, so it is only
# queried after the availability check. Both values are shown, because a notebook
# cell only displays its last expression.
cuda_available = torch.cuda.is_available()
gpu_name = torch.cuda.get_device_name(0) if cuda_available else None
(cuda_available, gpu_name)

## Model definition and weights loading

* *Note: Remember that you must have saved the model weights in .pt format in the second notebook (KERMIT_training.ipynb).*

In [None]:
class DTBert(nn.Module):
    """Classifier that joins a BERT output vector with a second input vector.

    The BERT output taken at sequence position 0 is concatenated with ``x_synth``
    and passed through a single linear layer followed by a log-softmax.
    BERT is evaluated under ``torch.no_grad()``, so no gradient flows through it.

    Args:
        input_dim_bert: width of the vector taken from BERT.
        input_dim_dt: width of the ``x_synth`` vector.
        output_dim: number of outputs of the linear layer.
    """

    def __init__(self, input_dim_bert, input_dim_dt, output_dim):
        super().__init__()
        # Use the GPU when one is available, otherwise stay on the CPU.
        bert_device = "cuda" if torch.cuda.is_available() else "cpu"
        pretrained_bert = transformers.BertModel.from_pretrained('bert-base-uncased')
        self.bert = pretrained_bert.to(bert_device)
        # The linear layer takes both vectors side by side.
        joint_dim = input_dim_bert + input_dim_dt
        self.synth_sem_linear = nn.Linear(joint_dim, output_dim)

    def forward(self, x_sem, x_synth):
        """Return log-probabilities over the outputs (normalised along dim 1).

        Args:
            x_sem: input handed to BERT (presumably token ids - TODO confirm).
            x_synth: tensor concatenated with the BERT vector along dim 1.
        """
        with torch.no_grad():
            # First element of the BERT output, sliced at sequence position 0.
            first_position = self.bert(x_sem)[0][:, 0, :]
        joint_input = torch.cat((first_position, x_synth), 1)
        scores = self.synth_sem_linear(joint_input)
        return F.log_softmax(scores, dim=1)
        

        
BERT_DIM = 768
TREE_DIM = 4000

#number of class of dataset
OUTPUT_DIM = 5

#pick the device the same way DTBert.__init__ does, so the cell also runs without a GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#instantiate the model
model = DTBert(BERT_DIM, TREE_DIM, OUTPUT_DIM)
model.to(device)

#Define a Loss function and optimizer
#NOTE(review): neither is used by the other cells shown in this notebook, and the optimizer
#is bound to the parameters of the instance created above, not to the model loaded below.
criterion = nn.NLLLoss()
optimizer = optim.AdamW(model.parameters(), lr=3e-5)

#weights of your model
pathModel = './model.pt'

#torch.load unpickles the file, which can execute arbitrary code: only load files you trust.
#map_location lets a file saved on a GPU be restored on a CPU-only machine.
model = torch.load(pathModel, map_location=device)
#the notebook only runs inference: eval mode disables dropout so repeated runs agree
#(assigning the result keeps the cell from echoing the whole module repr)
model = model.eval()



## kerMIT applied to a single input instance

In [None]:
from kerMIT.dtk import DT
from kerMIT.operation import fast_shuffled_convolution
from kerMIT.explain import activationSubtreeLRP as act
from kerMIT.explain import modelToExplain as mte
from kerMIT.explain import kerMITviz 

#parser definition
# The dimension is taken from TREE_DIM (the size the model's linear layer was built with)
# instead of repeating the literal 4000, so the two cannot drift apart.
calculator = DT(dimension=TREE_DIM, LAMBDA=0.4, operation=fast_shuffled_convolution)

In [None]:
from tree_encode import parse as parse_tree
from kerMIT.samples import utils

# Sentence to be explained.
sentence = "Don't waste your time.  We had two different people come to our house to give us estimates for a deck (one of them the OWNER)."

# target: target value for the sentence.
# index: not read by any other cell shown in this notebook.
target, index = 1, 2

# utils.get_sentence returns three values for the sentence, built with the parser
# defined above; they are used by the following cells.
tree_sentence, dtk_sentence, bert_sentence = utils.get_sentence(sentence, calculator)

In [None]:
# Display the first of the three values returned by utils.get_sentence for the input sentence.
tree_sentence

In [None]:
from  kerMIT.explain_pytorch import LRP_linear_layer as LRP_t

#prediction of the input sentence
# NOTE(review): the DTBert class defined in this notebook only implements forward();
# get_activation must be provided by the class in effect when model.pt is unpickled - confirm.
y_predict = model.get_activation(bert_sentence, dtk_sentence).cpu()
#calculation of contributions through the LRP algorithm
# prepare_input_LRP returns the eight arguments that lrp_linear_torch expects, in order.
hin, w, b, hout, Rout, bias_nb_units, eps, bias_factor = LRP_t.prepare_input_LRP(y_predict, dtk_sentence, model, BERT_DIM)
# Every tensor is moved to the CPU before the relevance computation.
Rin = LRP_t.lrp_linear_torch(hin.cpu(), w.cpu(), b.cpu(), hout.cpu(), Rout.cpu(), bias_nb_units, eps, bias_factor, debug=False)

# Map the relevance values (as a NumPy array) onto the parse tree of the sentence.
act_lrp = act.ActivationSubtreeLRP(calculator)
act_tree_lrp = act_lrp.activationQC(Rin.detach().numpy(), tree_sentence)

In [None]:
from tree_encode import parse as parse_tree

#visualization of previously extracted activations through heat parse tree
heat_parse_tree = kerMITviz.assign_contribution_nodes(act_tree_lrp)
# Alternative display, left disabled; this notebook serves the page over HTTP in the last cell instead.
#kerMITviz.show_kerMITviz(heat_parse_tree)

In [None]:
import kerMIT
import pathlib

#Find the visualizer page (index.html), expressed relative to the current working directory
def search_path(package_dir=None):
    """Return the path of the visualizer's ``index.html`` relative to the working directory.

    The page lives in ``<package_dir>/ACTree/tree_visualizer_pyDTE/index.html``.

    Args:
        package_dir: directory of the kerMIT package. Defaults to
            ``kerMIT.__path__[0]``; may be absolute or relative.

    Returns:
        str: a path that resolves to ``index.html`` from the current working
        directory. It is relative whenever a relative path exists; if none
        exists (``os.path.relpath`` raises ``ValueError``, e.g. for different
        drives on Windows) the absolute path is returned instead.
    """
    if package_dir is None:
        package_dir = kerMIT.__path__[0]
    path_file_html = os.path.join(package_dir, 'ACTree', 'tree_visualizer_pyDTE', 'index.html')
    try:
        # relpath copes with relative package paths and Windows drive letters, which
        # counting path components and stripping the first character did not.
        return os.path.relpath(path_file_html, start=os.getcwd())
    except ValueError:
        return os.path.abspath(path_file_html)

# Path to the visualizer page, as computed by search_path().
path_html = search_path()
# Path to act_trees.js inside the installed kerMIT package.
path_js = os.path.join(
    kerMIT.__path__[0],
    'ACTree',
    'tree_visualizer_pyDTE',
    'heat_parse_trees',
    'act_trees.js',
)

## Visualisation of the heat parse tree

In [None]:
import http.server
import socketserver
from pyngrok import ngrok
import os

PORT = 5000

# path_html points at the visualizer's index.html, but the HTTP handler serves a
# directory: use the folder that contains the page (os.chdir on a file raises).
VISUALIZER_URL = path_html if os.path.isdir(path_html) else (os.path.dirname(path_html) or os.curdir)

Handler = http.server.SimpleHTTPRequestHandler
# Allow the cell to be re-run right after it was stopped without "address already in use".
socketserver.TCPServer.allow_reuse_address = True

# SimpleHTTPRequestHandler serves the current working directory, so switch to the
# visualizer folder while the server runs and switch back afterwards.
# NOTE: the ngrok tunnel makes that folder reachable from the public internet.
previous_cwd = os.getcwd()
os.chdir(VISUALIZER_URL)
try:
    with socketserver.TCPServer(("", PORT), Handler) as httpd:
        # Open the tunnel only once the port is bound. The port is passed
        # positionally because the keyword name differs between pyngrok versions.
        url = ngrok.connect(PORT)
        # Newer pyngrok returns a tunnel object, older versions the URL string itself.
        print("The visualizer is running at the following address:", getattr(url, "public_url", url))
        httpd.serve_forever()
finally:
    # Runs when the cell is interrupted or fails: close the tunnel and restore the cwd.
    ngrok.kill()
    os.chdir(previous_cwd)