In [7]:
from src.model.nli_models import *
from src.model.novelty_models import *
from src.defaults import *
from torchtext.data import Example 
import pandas as pd
import numpy as np
import html
import random
from IPython.core.display import display, HTML
from IPython.display import IFrame
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import warnings
from transformers import BertTokenizer, DistilBertTokenizer
warnings.filterwarnings("ignore")

def encode_text(text,field):
    """Encode one raw text with ``field`` and return the result as a tensor.

    The text is wrapped in a single torchtext ``Example`` bound to ``field``
    under the attribute name ``"text"``; that attribute is then handed to
    ``field.process`` as a batch containing just this one example.

    Args:
        text: Raw document text.
        field: Field object providing ``process`` (and whatever
            ``Example.fromlist`` needs from it).

    Returns:
        ``torch.tensor`` built from the output of ``field.process``.
    """
    example = Example.fromlist([text], [("text", field)])
    single_batch = [example.text]
    return torch.tensor(field.process(single_batch))

def load_novelty_model(_id):
    """Rebuild a trained novelty model and its text field from a run id.

    The checkpoint is read from ``./results/<_id>/model.pt``. Its
    ``options["load_nli"]`` entry names the NLI run whose encoder the novelty
    model was built on; that encoder is reconstructed first and then passed
    to ``HAN``. ``check_model`` is called for both ids before their data is
    used (its behaviour is defined outside this file).

    Args:
        _id: Identifier of the novelty run (directory name under
            ``./results/``).

    Returns:
        Tuple ``(model, field)``: the ``HAN`` model with its saved weights
        loaded, and the field returned by ``load_field(_id)``.
    """
    check_model(_id)

    def load_model_data(_id):
        model_path = os.path.join("./results/", _id, "model.pt")
        # map_location="cpu" keeps a checkpoint saved from a GPU run readable
        # on a machine without CUDA. The values are copied into freshly built
        # modules by load_state_dict below, so the resulting model is the same.
        # NOTE(review): the encoder checkpoint is read by load_encoder_data,
        # which is defined elsewhere and may need the same treatment.
        return torch.load(model_path, map_location="cpu")

    def load_encoder(enc_data):
        options = enc_data["options"]
        # Every architecture is built with use_glove disabled; presumably the
        # embeddings come from the saved state dict instead — TODO confirm.
        options["use_glove"] = False
        if options.get("attention_layer_param", 0) == 0:
            model = bilstm_snli(options)
        elif options.get("r", 0) == 0:
            model = attn_bilstm_snli(options)
        else:
            model = struc_attn_snli(options)
        model.load_state_dict(enc_data["model_dict"])
        return model

    field = load_field(_id)
    model_data = load_model_data(_id)
    encoder_id = model_data["options"]["load_nli"]
    check_model(encoder_id)

    enc_data = load_encoder_data(encoder_id)
    # Only the encoder sub-module of the NLI model is reused.
    encoder = load_encoder(enc_data).encoder

    model = HAN(model_data["options"],encoder)
    model.load_state_dict(model_data["model_dict"])
    return model,field

def decode(inp,field):
    """Map token ids back to token strings, one list per sentence.

    Args:
        inp: Nested sequence of token ids (sentences of tokens). On the
            tokenizer path it must also provide ``tolist()``.
        field: Object whose ``nesting_field`` may carry a ``vocab`` with an
            ``itos`` lookup.

    Returns:
        List of lists of token strings, in the same layout as ``inp``.
    """
    nested = field.nesting_field
    if not hasattr(nested, "vocab"):
        # No vocabulary on the nested field: the ids are converted with the
        # pretrained DistilBERT tokenizer instead.
        tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
        return [tokenizer.convert_ids_to_tokens(ids) for ids in inp.tolist()]

    itos = nested.vocab.itos
    decoded = []
    for sentence in inp:
        decoded.append([itos[token_id] for token_id in sentence])
    return decoded


def attention_combined(inp,field,s_att,w_att=None):
    """Pair every non-padding token and every sentence with its attention weight.

    Args:
        inp: Token ids of one document, passed to ``decode`` together with
            ``field``.
        field: Field used by ``decode`` to turn ids back into token strings.
        s_att: Per-sentence attention weights; indexed as ``s_att[sentence]``
            and must have one entry per decoded sentence.
        w_att: Per-token attention weights, indexed as
            ``w_att[sentence][token]``. When omitted, every token gets weight
            ``0.0`` so that only sentence-level attention is reported
            (previously the default value crashed on ``w_att.shape``).

    Returns:
        List of ``(tokens, sentence_weight)`` tuples where ``tokens`` is a
        list of ``(token_string, token_weight)`` tuples. Padding tokens
        (``"<pad>"`` / ``"[PAD]"``) are dropped, and sentences consisting
        only of padding are omitted.

    Raises:
        AssertionError: If the attention shapes do not match the decoded
            document.
    """
    pad_tokens = ("<pad>", '[PAD]')
    tok_str = decode(inp,field)
    assert len(tok_str) == s_att.shape[0]
    if w_att is not None:
        assert len(tok_str) == w_att.shape[0]
        # Guarded so an empty document does not raise IndexError on tok_str[0].
        if tok_str:
            assert len(tok_str[0]) == w_att.shape[1]

    opt = []
    for sent_idx, sent_tokens in enumerate(tok_str):
        sent_with_att = []
        for word_idx, word_str in enumerate(sent_tokens):
            if word_str in pad_tokens:
                continue
            weight = 0.0 if w_att is None else w_att[sent_idx][word_idx].item()
            sent_with_att.append((word_str, weight))
        if sent_with_att:
            opt.append((sent_with_att, s_att[sent_idx].item()))
    return opt
        


def html_string(word,color,new_line = False):
    """Render one token as an HTML ``<span>`` with a coloured background.

    The token text is HTML-escaped before insertion, so tokens that look like
    markup (for example ``<unk>``) are shown literally instead of being
    parsed as tags and disappearing from the page.

    Args:
        word: Token text to display.
        color: CSS colour value used for the background (e.g. ``"#ff0000"``).
        new_line: When true, a ``<br>`` is appended after the span.

    Returns:
        The HTML fragment as a string.
    """
    template = '<span class="barcode" style="color: black; background-color: {}">{}</span>'
    # Non-breaking spaces on both sides keep neighbouring tokens apart.
    padded = '&nbsp;' + html.escape(word) + '&nbsp;'
    colored_string = template.format(color, padded)
    if new_line:
        colored_string += "<br>"
    return colored_string


def colorize(attention_list):
    """Build the HTML visualisation for one document's attention weights.

    Every sentence becomes one line: a ``---`` marker shaded with the
    ``Blues`` colormap according to the sentence weight, followed by the
    sentence's tokens shaded with the ``Reds`` colormap according to their
    own weights. Each span is produced by ``html_string``.

    Args:
        attention_list: Iterable of ``(tokens, sentence_weight)`` pairs where
            ``tokens`` is an iterable of ``(token_string, token_weight)``.

    Returns:
        HTML string; sentences are separated by ``<br>`` and the document is
        terminated by four additional ``<br>`` tags.
    """
    cmap_sent = matplotlib.cm.Blues
    cmap_word = matplotlib.cm.Reds

    # Fragments are collected and joined once instead of growing a string.
    pieces = []
    for sent, sent_att in attention_list:
        # The sentence weight is doubled before the colormap lookup,
        # presumably to make small weights easier to see — TODO confirm.
        sent_color = matplotlib.colors.rgb2hex(cmap_sent(sent_att*2)[:3])
        pieces.append(html_string('\t---\t ',sent_color))
        for word,word_att in sent:
            word_color = matplotlib.colors.rgb2hex(cmap_word(word_att)[:3])
            pieces.append(html_string(word,word_color))
        pieces.append("<br>")
    pieces.append("<br><br><br><br>")
    return "".join(pieces)

    


In [78]:
# Load the novelty model stored under run id 'NOV-1015' together with its text field.
model,field = load_novelty_model('NOV-1015')

In [79]:
# Source document for the attention demo. The text is model input and is kept
# verbatim, including its typos and fused words.
source = "A balanced exercise routine for an athlete must consists of cardio as well as weight training. The way we perform these two exercises is rather important . The researchers believe that one should follow their research which suggests the ideal method for working out. A workout should start with some stretching and small dynamic movements to warm up thebody. This should be followed by the weight training or a High Intensity Interval training. We must also include a couple minutes of reset between this transition. The can be followed by aproper session of cardio exercises like running or swimming. One should adjust the intensityof this cardio exercises as per their training goal. Researchers also say that a training sessionshould end with a cool-down which includes stretching and relaxation"



In [80]:
# Target document compared against `source`: it repeats several of the source
# sentences and ends with unrelated sentences. Kept verbatim as model input.
target = "Research published by a team of sports scientists reveals some interesting in-formation about the recipe for in ideal workout session. A balanced exercise routine for an athlete must consists of cardio as well as weight training. The way we perform these twoe xercises is rather important . This should be followed by the weight training or a High In-tensity Interval training. A workout should start with some stretching and small dynamicmovements to warm up the body. The can be followed by a proper session of cardio exerciseslike running or swimming. A couple minutes of rest between these high intensity training isalso considered beneficial. This should be followed by the weight training or a High IntensityInterval training. The fligt from usa to india got delayed.In regex, captures are numbered, but in some implementations, captures can be named. To understand the basics first, let’s see an example in python where capturing is performed in the figure 5.1. In this example, we can see that the input string of ’ac’ is matched."


In [81]:
# seed_torch is defined outside this notebook; presumably it fixes the random
# seeds so repeated runs give the same output — TODO confirm.
seed_torch()

def plot_attention(src,trg,model,field,out_path='colorize.html'):
    """Run one source/target pair through the model and dump its attention as HTML.

    Both texts are encoded with ``field`` and passed to
    ``model.forward_with_attn``. For each document the sentence-level and
    word-level attention weights are rendered by ``colorize``, and the two
    fragments (source first, then target) are written to ``out_path``. The
    class probabilities are printed and returned.

    Args:
        src: Source document text.
        trg: Target document text.
        model: Model exposing ``forward_with_attn(src, trg)`` that returns
            ``(logits, source_attention, target_attention)``; each attention
            value is indexed as ``att[0][0, :, 0]`` (sentence weights) and
            ``att[1][0]`` (word weights).
        field: Field used for encoding and for decoding ids back to tokens.
        out_path: File the HTML is written to. Defaults to the previously
            hard-coded ``'colorize.html'``.

    Returns:
        Tensor of softmax probabilities over the model's output classes.
    """
    # Inputs follow the model's current device, so the function works both
    # before and after model.cuda() / model.cpu().
    device = next(model.parameters()).device
    s_enc = encode_text(src,field).to(device)
    t_enc = encode_text(trg,field).to(device)

    model.eval()
    with torch.no_grad():
        opt,s_att,t_att = model.forward_with_attn(s_enc,t_enc)
        # Explicit dim: the implicit choice is deprecated. dim=1 is what it
        # resolved to for the 2-D (batch, classes) output printed below.
        pred = F.softmax(opt, dim=1)

    # Index 0 selects the single example in the batch.
    src_att_map = attention_combined(s_enc[0],field,s_att[0][0,:,0],s_att[1][0])
    trg_att_map = attention_combined(t_enc[0],field,t_att[0][0,:,0],t_att[1][0])

    s_html = colorize(src_att_map)
    t_html = colorize(trg_att_map)
    with open(out_path, 'w') as f:
        f.write(s_html+t_html)
    print(pred)
    return pred

def disp_attention(path='./colorize.html',width=1200,height=400):
    """Return an ``IFrame`` showing the generated attention page.

    The frame is returned rather than discarded, so calling this function as
    the last expression of a notebook cell displays it.

    Args:
        path: Location of the HTML file to embed.
        width: Frame width passed to ``IFrame``.
        height: Frame height passed to ``IFrame``.

    Returns:
        The ``IPython.display.IFrame`` instance.
    """
    return IFrame(path,width=width,height=height)




In [85]:
# Run the hand-written source/target pair through the model; plot_attention
# writes the attention page to disk and prints the class probabilities.
a = plot_attention(source,target,model,field)

tensor([[0.0631, 0.9369]])


In [86]:
# Embed the attention page written by plot_attention in the notebook output.
IFrame('./colorize.html',width=2200,height=1000)


In [59]:
import json

In [60]:
# Load the DLND examples: the file holds one JSON object per line.
with open('.data/dlnd/TAP-DLND-1.0_LREC2018_modified/dlnd.jsonl','r',encoding='utf-8') as f:
    items = f.readlines()
# Blank lines (for instance a trailing empty line) are skipped so json.loads
# is never handed an empty document.
data = [json.loads(i) for i in items if i.strip()]

In [124]:
# Inspect one dataset example (index 1164): model output versus stored label.
example = data[1164]
print("Prediction:")
plot_attention(example["source"],example["target_text"],model,field)
print("Actual:")
# Last expression of the cell, so the notebook displays the "DLA" label.
example["DLA"]

Prediction:
tensor([[0.1512, 0.8488]])
Actual:


'Non-Novel'

In [125]:
# Embed the attention page for the example inspected above.
IFrame('./colorize.html',width=2200,height=1000)


In [65]:
# Length of every example's 'source' value, then the index of the shortest one.
lens = [len(doc['source']) for doc in data]
print(lens.index(min(lens)))


284


In [74]:
# Pair every length with its example index. This rebinds `lens`, so running
# the cell a second time would pair the pairs again.
lens = list(enumerate(lens))

In [108]:
# Move the model's parameters to the GPU before the full-dataset pass.
model.cuda()

from tqdm import tqdm
def predict(data,model,field):
    """Classify every example and collect the indices the model gets wrong.

    Each example's ``'source'`` and ``'target_text'`` are encoded with
    ``field`` and passed to ``model.forward_with_attn``. The softmax
    probability at class index 1 is thresholded at 0.5: above it the
    prediction is ``"Novel"``, otherwise ``"Non-Novel"``. The prediction is
    compared with the example's ``'DLA'`` label.

    Args:
        data: Sequence of dicts with keys ``'source'``, ``'target_text'`` and
            ``'DLA'``.
        model: Model exposing ``forward_with_attn(src, trg)``.
        field: Field used by ``encode_text``.

    Returns:
        List of positions in ``data`` whose prediction differs from the label.
    """
    # Inputs follow the model's device instead of a hard-coded .cuda(), so the
    # function also runs when the model lives on the CPU.
    device = next(model.parameters()).device
    # Evaluation mode only needs to be set once, not per example.
    model.eval()

    wrong_id = []
    for i, example in enumerate(tqdm(data)):
        s_enc = encode_text(example['source'],field).to(device)
        t_enc = encode_text(example['target_text'],field).to(device)

        with torch.no_grad():
            # The attention outputs are not needed here.
            opt,_,_ = model.forward_with_attn(s_enc,t_enc)
            # Explicit dim avoids the deprecated implicit choice; dim=1 is
            # what it resolved to for a 2-D (batch, classes) output.
            novel_prob = F.softmax(opt, dim=1)[0][1].item()

        pred = "Novel" if novel_prob > 0.5 else "Non-Novel"
        if pred != example['DLA']:
            wrong_id.append(i)
    return wrong_id

In [109]:
# Indices of all dataset examples whose prediction differs from the 'DLA' label.
wrong_id = predict(data,model,field)


  0%|          | 0/5435 [00:00<?, ?it/s][A
  0%|          | 5/5435 [00:00<02:04, 43.75it/s][A
  0%|          | 10/5435 [00:00<02:02, 44.25it/s][A
  0%|          | 15/5435 [00:00<01:59, 45.21it/s][A
  0%|          | 20/5435 [00:00<02:02, 44.14it/s][A
  0%|          | 24/5435 [00:00<02:08, 42.23it/s][A
  1%|          | 29/5435 [00:00<02:12, 40.94it/s][A
  1%|          | 33/5435 [00:00<02:21, 38.22it/s][A
  1%|          | 37/5435 [00:00<02:31, 35.54it/s][A
  1%|          | 41/5435 [00:01<02:30, 35.73it/s][A
  1%|          | 45/5435 [00:01<02:38, 34.10it/s][A
  1%|          | 49/5435 [00:01<02:39, 33.86it/s][A
  1%|          | 53/5435 [00:01<02:32, 35.26it/s][A
  1%|          | 57/5435 [00:01<02:33, 34.97it/s][A
  1%|          | 61/5435 [00:01<02:29, 35.94it/s][A
  1%|          | 65/5435 [00:01<02:30, 35.77it/s][A
  1%|▏         | 69/5435 [00:01<02:30, 35.71it/s][A
  1%|▏         | 73/5435 [00:01<02:37, 34.11it/s][A
  1%|▏         | 77/5435 [00:02<02:37, 34.05it/s][A
  

In [117]:
# Move the model back to the CPU; as the cell's last expression this also
# prints the module structure.
model.cpu()

HAN(
  (encoder): HAN_DOC(
    (encoder): Attn_Encoder(
      (embedding): Embedding(33934, 300, padding_idx=1)
      (translate): Linear(in_features=300, out_features=400, bias=True)
      (relu): ReLU()
      (dropout): Dropout(p=0.3, inplace=False)
      (lstm_layer): LSTM(400, 400, batch_first=True, dropout=0.3, bidirectional=True)
      (attention): Attention(
        (Ws): Linear(in_features=800, out_features=200, bias=False)
        (Wa): Linear(in_features=200, out_features=1, bias=False)
      )
    )
    (translate): Linear(in_features=800, out_features=400, bias=True)
    (act): ReLU()
    (dropout): Dropout(p=0.3, inplace=False)
    (lstm_layer): LSTM(400, 400, bidirectional=True)
    (attention): Attention(
      (Ws): Linear(in_features=800, out_features=200, bias=False)
      (Wa): Linear(in_features=200, out_features=1, bias=False)
    )
  )
  (act): ReLU()
  (dropout): Dropout(p=0.3, inplace=False)
  (fc): Linear(in_features=3200, out_features=2, bias=True)
)

In [115]:
# Display the misclassified example indices collected by predict.
wrong_id

[25,
 69,
 80,
 171,
 188,
 211,
 243,
 273,
 292,
 294,
 296,
 418,
 459,
 464,
 473,
 489,
 520,
 538,
 559,
 561,
 578,
 583,
 587,
 679,
 688,
 691,
 701,
 703,
 707,
 715,
 734,
 736,
 737,
 772,
 790,
 791,
 814,
 819,
 837,
 848,
 875,
 878,
 929,
 942,
 1015,
 1024,
 1028,
 1082,
 1107,
 1130,
 1164,
 1203,
 1233,
 1251,
 1254,
 1258,
 1305,
 1370,
 1371,
 1372,
 1386,
 1394,
 1395,
 1403,
 1518,
 1626,
 1627,
 1631,
 1638,
 1652,
 1658,
 1714,
 1720,
 1783,
 1786,
 1803,
 1872,
 1902,
 1910,
 1967,
 2011,
 2097,
 2110,
 2113,
 2125,
 2131,
 2188,
 2200,
 2203,
 2204,
 2221,
 2224,
 2266,
 2291,
 2315,
 2323,
 2325,
 2339,
 2353,
 2373,
 2385,
 2398,
 2412,
 2413,
 2419,
 2432,
 2439,
 2441,
 2443,
 2449,
 2451,
 2465,
 2473,
 2486,
 2489,
 2526,
 2529,
 2533,
 2553,
 2562,
 2567,
 2589,
 2618,
 2634,
 2657,
 2658,
 2676,
 2689,
 2716,
 2721,
 2732,
 2770,
 2786,
 2795,
 2854,
 2862,
 2868,
 2873,
 3027,
 3055,
 3153,
 3279,
 3403,
 3412,
 3454,
 3471,
 3472,
 3482,
 3484,
 348

In [113]:
# NOTE(review): ad-hoc calculation; the meaning of the constants 200 and 543
# is not documented anywhere in this notebook.
200/(2*543)

0.1841620626151013