## Helper functions for writing LaTeX files.

In [None]:
# Reference: https://github.com/jiesutd/Text-Attention-Heatmap-Visualization

import numpy as np

# NOTE(review): a one-element list holding a single string of punctuation.
# Nothing in the visible code reads it; clean_word() carries its own list of
# characters to escape.
latex_special_token = ["!@#$%^&*()"]

def generate(text_list, attention_list, sent_pos, new_mentions, color='red', rescale_value = False):
    """Render words as a line of LaTeX ``\\colorbox`` cells shaded by attention.

    Args:
        text_list: words to render; passed through ``clean_word`` first.
        attention_list: one weight per word, used as the colour percentage
            (``color!weight``). Must have the same length as ``text_list``.
        sent_pos: container keyed by 1-based word position; when
            ``idx + 1`` is present, ``[value]`` is emitted before that word.
        new_mentions: container of 0-based word positions to render in bold.
        color: colour name placed inside ``\\colorbox``.
        rescale_value: when true, weights are first passed through
            ``rescale``.

    Returns:
        The concatenated LaTeX cells followed by a newline.
    """
    assert len(text_list) == len(attention_list)

    weights = rescale(attention_list) if rescale_value else attention_list
    cleaned = clean_word(text_list)

    pieces = []
    for idx, token in enumerate(cleaned):
        position = idx + 1
        if position in sent_pos:
            pieces.append(f'[{sent_pos[position]}]')

        # Mentions get an extra \textbf{} wrapper inside the coloured box.
        content = "\\textbf{" + token + "}" if idx in new_mentions else token
        box_open = "\\colorbox{%s!%.3f}{" % (color, weights[idx])
        pieces.append(box_open + "\\strut " + content + "} ")

    pieces.append('\n')
    return ''.join(pieces)

def rescale(input_list):
    """Linearly map values onto the range [0, 100].

    Args:
        input_list: array-like of numbers.

    Returns:
        A list where the minimum input maps to 0 and the maximum to 100.
        Empty input returns an empty list. When every value is identical
        there is no range to stretch, so all zeros are returned rather than
        dividing by zero and producing NaN.
    """
    the_array = np.asarray(input_list)
    if the_array.size == 0:
        return []

    the_min = np.min(the_array)
    span = np.max(the_array) - the_min
    if span == 0:
        return np.zeros(the_array.shape).tolist()

    return ((the_array - the_min) / span * 100).tolist()


# Text-mode replacements for characters LaTeX treats specially. Applied in a
# single pass with str.translate so that braces introduced by one replacement
# (e.g. in \textbackslash{}) are never escaped again by a later one.
_LATEX_ESCAPES = str.maketrans({
    "\\": "\\textbackslash{}",  # "\\\\" would be a line break, not a backslash
    "^": "\\textasciicircum{}",  # "\\^" is an accent that swallows the next char
    "~": "\\textasciitilde{}",
    "$": "\\$",
    "%": "\\%",
    "&": "\\&",
    "#": "\\#",
    "_": "\\_",
    "{": "\\{",
    "}": "\\}",
})


def clean_word(word_list):
    """Escape LaTeX special characters in every word.

    Args:
        word_list: iterable of strings.

    Returns:
        A new list with the same number of items, where each of
        ``\\ ^ ~ $ % & # _ { }`` is replaced by a form that prints the
        literal character in LaTeX text mode.
    """
    return [word.translate(_LATEX_ESCAPES) for word in word_list]


## Load Model and make predictions.

In [None]:
import pickle
import torch
from transformers import AutoConfig, AutoModel, AutoTokenizer
from model import DocREModel
from prepro import read_docred

In [None]:
# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Name or path handed to every Auto* loader below.
model_name_or_path = "bert-base-cased"
# Passed to AutoConfig as num_labels.
# NOTE(review): presumably the number of relation classes — confirm.
num_class = 97
# The remaining three values are forwarded to DocREModel as keyword arguments.
num_labels = 4
max_sent_num = 25
evi_thresh = 0.2
# Stored on the config and passed to the feature reader.
transformer_type = "bert"

config = AutoConfig.from_pretrained(
    model_name_or_path,
    num_labels=num_class,
)
tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path,
)
# Load from a TensorFlow checkpoint only when the name contains ".ckpt".
model = AutoModel.from_pretrained(
    model_name_or_path,
    from_tf=bool(".ckpt" in model_name_or_path),
    config=config,
)

# Extra attributes attached to the config before it is handed to DocREModel.
config.transformer_type = transformer_type

# Alias for the feature reader; called as read(...) further down.
read = read_docred    
config.cls_token_id = tokenizer.cls_token_id
config.sep_token_id = tokenizer.sep_token_id

# Rebind `model`: the pretrained encoder loaded above becomes an argument of
# DocREModel, and `model` refers to the DocREModel from here on.
model = DocREModel(config, model, tokenizer,
                num_labels=num_labels,
                max_sent_num=max_sent_num, 
                evi_thresh=evi_thresh)

model.to(device)

In [None]:
import os

# Directory that contains the trained checkpoint "best.ckpt".
load_path = '/path/to/trained/model'
model_path = os.path.join(load_path, "best.ckpt")
# map_location remaps the saved tensors onto `device`, so a checkpoint written
# on a GPU machine still loads when this notebook falls back to the CPU.
model.load_state_dict(torch.load(model_path, map_location=device))

In [None]:
# File to run predictions on.
test_file = '/path/to/file/to/predict'  
# Forwarded to the reader as max_seq_length.
max_seq_length = 1024
# `read` is the read_docred alias set up earlier; the result is indexed by
# document and by field name in the cells below.
test_features = read(test_file, tokenizer, transformer_type=transformer_type, max_seq_length=max_seq_length)

In [None]:
from torch.utils.data import DataLoader
from utils import collate_fn
from run import load_input

def evaluate(model, features, tag="infer"):
    """Run the model over ``features`` one example per batch and collect attention.

    Args:
        model: module called as ``model(**inputs)``; its output must support
            ``"attns" in outputs`` and ``outputs["attns"]``.
        features: dataset handed to ``DataLoader`` together with ``collate_fn``.
        tag: forwarded unchanged to ``load_input``.

    Returns:
        A tuple ``(toks, attns)``:
        ``toks`` holds ``inputs['input_ids']`` of every batch, in order;
        ``attns`` is a flat list with one numpy array per element of
        ``outputs["attns"]``, for every batch whose output has that key.
        NOTE(review): callers index both lists by document, which presumably
        relies on each batch contributing exactly one entry to ``attns`` —
        confirm against the model.
    """
    dataloader = DataLoader(features, batch_size=1, shuffle=False, collate_fn=collate_fn, drop_last=False)
    toks, attns = [], []

    # Evaluation mode only needs to be set once, not on every batch.
    model.eval()

    for batch in dataloader:
        inputs = load_input(batch, device, tag)
        toks.append(inputs['input_ids'])

        with torch.no_grad():
            outputs = model(**inputs)

        # Only tokens and attention are returned, so relation/evidence
        # predictions are not copied off the device here.
        if "attns" in outputs:  # attention recorded
            attns.extend(a.cpu().numpy() for a in outputs["attns"])

    return toks, attns

In [None]:
# `words` receives the per-batch input ids and `attns` the collected attention
# arrays, in the order evaluate() returns them.
words, attns = evaluate(model, test_features)

## Visualization by writing into a LaTeX file.

In [None]:
# Index of the document to visualise; used below to index `words`, `attns`
# and `test_features`.
doc_id = 925

In [None]:
# evaluate() uses batch_size=1, so [0] picks the only row of that batch's
# input ids.
curr_toks = words[doc_id][0]
curr_attns = attns[doc_id]

In [None]:
# Notebook inspection only: show which fields this document's feature entry has.
test_features[doc_id].keys()

In [None]:
# Drop attention entries whose first dimension is empty.
curr_attns = [attn for attn in curr_attns if attn.shape[0] != 0]

In [None]:
# Get gold relations: for every row of this document's "labels", collect the
# indices in 1..num_class-1 that are non-zero, and keep the row only when at
# least one is set. Index 0 is never scanned.
# NOTE(review): presumably index 0 is the "no relation" class — confirm.

valid_rels = []

for i, one_hot in enumerate(test_features[doc_id]["labels"]):
    valid = [j for j in range(1, num_class) if one_hot[j] != 0]
    if valid:
        # (row index, list of active label indices)
        valid_rels.append((i, valid))

In [None]:
# Get the names of relations.

# json is not imported anywhere else in this notebook, so import it here.
import json

# Context managers close both files once they have been parsed.
with open('meta/rel2id.json', 'r', encoding='utf-8') as fin:
    docred_rel2id = json.load(fin)
with open('meta/rel_info.json', encoding='utf-8') as fin:
    id2name = json.load(fin)

# Key each rel2id value by its rel_info entry; rel2id keys that are missing
# from rel_info are skipped.
docred_id2rel = {v: id2name[k] for (k,v) in docred_rel2id.items() if k in id2name}

In [None]:
def tok2word(toks, attns, sent_pos, ments):
    """Merge word-piece tokens back into words and carry attention along.

    ``'[CLS]'`` and ``'[SEP]'`` tokens are dropped. A token starting with
    ``'##'`` is glued onto the previous word (without the two-character
    prefix) and its weight is added to that word's weight. The merged weights
    are then divided by their sum.

    Args:
        toks: sequence of token strings.
        attns: per-token weights, indexable by token position.
        sent_pos: sequence of sentence spans; only ``pos[0]`` is read, and it
            is shifted by one before being looked up.
            NOTE(review): the +1 presumably skips a leading '[CLS]' — confirm.
        ments: sequence of mention spans ``(start, end)`` in token positions.

    Returns:
        A tuple ``(words, new_attns, new_sent_pos, new_ments, old2new)``:
        ``words`` is the merged word list; ``new_attns`` is a numpy array of
        normalised per-word weights; ``new_sent_pos`` maps
        ``old2new[pos[0] + 1]`` to the 1-based sentence number;
        ``new_ments`` lists every position in
        ``range(old2new[start], old2new[end])`` for each mention;
        ``old2new`` maps a kept token's position to the 1-based index of the
        word that contains it.

    Raises:
        KeyError: if a mention or sentence position refers to a token that
            was dropped ('[CLS]' / '[SEP]') or lies outside ``toks``.
    """
    words = []
    new_attns = []
    old2new = {}

    for tid, tok in enumerate(toks):

        if tok in ['[CLS]', '[SEP]']:
            continue

        if tok.startswith('##'):
            # Remove only the two-character prefix; str.strip('##') would also
            # eat any other leading or trailing '#' that belongs to the text.
            words[-1] += tok[2:]
            new_attns[-1] += attns[tid]
        else:
            words.append(tok)
            new_attns.append(attns[tid])

        old2new[tid] = len(words)

    # Convert before dividing: list / scalar only works when the elements
    # happen to be numpy scalars, and callers rely on an array coming back.
    total = sum(new_attns)
    new_attns = np.asarray(new_attns) / total

    new_ments = []
    for ment in ments:
        new_ments.extend(range(old2new[ment[0]], old2new[ment[1]]))

    new_sent_pos = {}
    for sid, pos in enumerate(sent_pos):
        new_sent_pos[old2new[pos[0] + 1]] = sid + 1

    return words, new_attns, new_sent_pos, new_ments, old2new


In [None]:
def get_rels(words, head, tail, old2new, rels, evi):
    """Build the one-line LaTeX caption describing a relation instance.

    Args:
        words: merged word list.
        head: mention spans of the subject entity; only the first span is used.
        tail: mention spans of the object entity; only the first span is used.
        old2new: mapping from token position to word position, used to turn
            ``span[0] + 1`` and ``span[1] - 1`` into a slice of ``words``.
        rels: relation names, joined with ``', '``.
        evi: evidence sentence numbers, joined with ``','``.

    Returns:
        ``"\\textbf{subject}:...; \\textbf{object}:...; \\textbf{relation}:...;
        \\textbf{evidence}:..."`` followed by a newline. With no evidence the
        label is kept and simply left empty.

    NOTE(review): the words are inserted as-is, without LaTeX escaping.
    """
    head_words = words[old2new[head[0][0]+1]: old2new[head[0][1]-1]]
    tail_words = words[old2new[tail[0][0]+1]: old2new[tail[0][1]-1]]

    # Joining avoids the trailing-comma trim, which used to chop the ':' off
    # the evidence label whenever `evi` was empty.
    evidence = ','.join(str(e) for e in evi)

    return (
        "\\textbf{subject}:" + ' '.join(head_words)
        + "; \\textbf{object}:" + ' '.join(tail_words)
        + "; \\textbf{relation}:" + ', '.join(rels)
        + "; \\textbf{evidence}:" + evidence
        + '\n'
    )

In [None]:
latex_file = 'toy.tex'

# These depend only on the document, not on the relation, so compute them once
# instead of on every loop iteration.
curr_doc_toks = tokenizer.convert_ids_to_tokens(curr_toks)
curr_sent_pos = test_features[doc_id]["sent_pos"]

# Write UTF-8 explicitly: the preamble below declares UTF8 input via CJK*, and
# the platform's default encoding is not guaranteed to match.
with open(latex_file, 'w', encoding='utf-8') as f:
    f.write(r'''\documentclass{standalone}
\usepackage{color}
\usepackage{tcolorbox}
\usepackage{CJK}
\usepackage{adjustbox}
\tcbset{width=0.9\textwidth,boxrule=0pt,colback=red,arc=0pt,auto outer arc,left=0pt,right=0pt,boxsep=5pt}
\begin{document}
\begin{CJK*}{UTF8}{gbsn}'''+'\n')
    string = r'''{\setlength{\fboxsep}{0pt}\colorbox{white!0}{\parbox{0.9\textwidth}{'''+"\n"

    # One caption line plus one attention heatmap per gold relation instance.
    for rid, rel in valid_rels:

        curr_rels = [docred_id2rel[r] for r in rel]
        curr_doc_attn = curr_attns[rid]
        # 1-based numbers of the sentences flagged in sent_labels for this pair.
        curr_evi = [sid + 1 for sid, val in enumerate(test_features[doc_id]["sent_labels"][rid]) if val != 0]
        curr_ht = test_features[doc_id]["hts"][rid]
        curr_head = test_features[doc_id]["entity_pos"][curr_ht[0]]
        curr_tail = test_features[doc_id]["entity_pos"][curr_ht[1]]
        mentions = curr_head + curr_tail
        curr_doc_words, curr_doc_attn, sent_pos, new_mentions, old2new = tok2word(curr_doc_toks, curr_doc_attn, curr_sent_pos, mentions)
        rel_info = get_rels(curr_doc_words, curr_head, curr_tail, old2new, curr_rels, curr_evi)
        string += rel_info + "\n"

        # Stretch the weights onto [0, 100] for use as a colour percentage.
        OldRange = curr_doc_attn.max() - curr_doc_attn.min()
        NewRange = 100
        if OldRange > 0:
            NewValue = ((curr_doc_attn - curr_doc_attn.min()) * NewRange) / OldRange
        else:
            # Flat (or non-finite) range: dividing would put NaN into
            # \colorbox{red!nan}, so render every word unshaded instead.
            NewValue = np.zeros(len(curr_doc_words))
        string += generate(curr_doc_words, NewValue, sent_pos, new_mentions, "red")
        string += '\n'
    string += "\n}}}"
    f.write(string +'\n')
    f.write(r'''\end{CJK*}
\end{document}''')
        
