In [1]:
from src.model.nli_models import *
from src.model.novelty_models import *
from src.defaults import *
from torchtext.data import Example 
import pandas as pd
import numpy as np
import html
import random
from IPython.core.display import display, HTML
from IPython.display import IFrame
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import warnings
from transformers import BertTokenizer, DistilBertTokenizer
# Silence every warning for the rest of the session.
# NOTE(review): this also hides deprecation warnings (e.g. implicit-dim softmax,
# tensor copy-construction); consider narrowing it to specific categories.
warnings.simplefilter("ignore")

def encode_text(text, field):
    """Numericalize a single document into a batch of size one.

    Wraps ``text`` in a torchtext ``Example`` under the attribute ``text``
    and runs it through ``field.process``.

    Args:
        text: raw document string.
        field: torchtext field used to preprocess and numericalize ``text``.

    Returns:
        torch.Tensor built from ``field.process([ex.text])`` (a batch that
        holds this one document).
    """
    ex = Example.fromlist([text], [("text", field)])
    enc = field.process([ex.text])
    # ``field.process`` may already return a tensor; ``torch.as_tensor`` reuses
    # it instead of copy-constructing (which ``torch.tensor`` does, with a warning).
    return torch.as_tensor(enc)

def _build_nli_model(enc_data):
    """Rebuild the NLI model class matching a saved encoder checkpoint.

    The class is chosen from ``enc_data["options"]``:
    ``attention_layer_param`` missing/0 -> ``bilstm_snli``; otherwise
    ``r`` missing/0 -> ``attn_bilstm_snli``; otherwise ``struc_attn_snli``.
    ``use_glove`` is forced to False for every variant (presumably because
    the embedding weights are restored from ``enc_data["model_dict"]``).

    Returns:
        The NLI model with ``enc_data["model_dict"]`` loaded.
    """
    options = enc_data["options"]
    options["use_glove"] = False
    if options.get("attention_layer_param", 0) == 0:
        model = bilstm_snli(options)
    elif options.get("r", 0) == 0:
        model = attn_bilstm_snli(options)
    else:
        model = struc_attn_snli(options)
    model.load_state_dict(enc_data["model_dict"])
    return model


def load_novelty_model(_id):
    """Load a trained HAN novelty model together with its field.

    Reads ``./results/<_id>/model.pt``, rebuilds the NLI encoder named by the
    checkpoint's ``options["load_nli"]`` and wraps it in ``HAN``.

    Args:
        _id: run id of the novelty model (e.g. ``'NOV-1145'``).

    Returns:
        (model, field): the ``HAN`` model with its weights loaded, and the
        field returned by ``load_field(_id)``.
    """
    check_model(_id)
    field = load_field(_id)

    model_path = os.path.join("./results/", _id, "model.pt")
    # Map to CPU so a checkpoint saved from a GPU run can be opened on a
    # CPU-only machine; the tensors are copied into a freshly built module below.
    model_data = torch.load(model_path, map_location="cpu")

    encoder_id = model_data["options"]["load_nli"]
    check_model(encoder_id)
    enc_data = load_encoder_data(encoder_id)
    encoder = _build_nli_model(enc_data).encoder

    model = HAN(model_data["options"], encoder)
    model.load_state_dict(model_data["model_dict"])
    return model, field

# Lazily created DistilBERT tokenizer, shared by every decode() call.
_DISTILBERT_TOKENIZER = None


def _distilbert_tokenizer():
    """Return the shared DistilBERT tokenizer, loading it on first use."""
    global _DISTILBERT_TOKENIZER
    if _DISTILBERT_TOKENIZER is None:
        _DISTILBERT_TOKENIZER = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
    return _DISTILBERT_TOKENIZER


def decode(inp, field):
    """Map token ids back to token strings, one list per sentence.

    Args:
        inp: iterable of sentences of token ids. In the tokenizer branch it
            must support ``.tolist()`` (e.g. a 2-D tensor).
        field: nested field. If ``field.nesting_field`` has a ``vocab``, its
            ``itos`` table is used. Otherwise the ids are assumed to be
            DistilBERT (``distilbert-base-uncased``) ids.

    Returns:
        list of lists of token strings.
    """
    if hasattr(field.nesting_field, "vocab"):
        itos = field.nesting_field.vocab.itos
        return [[itos[i] for i in sent] for sent in inp]
    # Loading the tokenizer is expensive, so reuse one instance across calls.
    tok = _distilbert_tokenizer()
    return [tok.convert_ids_to_tokens(ids) for ids in inp.tolist()]


def attention_combined(inp, field, s_att, w_att=None):
    """Pair each decoded token with its word attention and each sentence
    with its sentence attention.

    Args:
        inp: token ids of one document, one row per sentence (decoded with
            ``decode``).
        field: the field used to build ``inp``.
        s_att: per-sentence scores; ``s_att[k].item()`` is sentence ``k``'s.
        w_att: per-word scores indexed ``w_att[k][j]``. When ``None``, every
            word gets a score of ``0.0``.

    Returns:
        list of ``(words, sentence_score)``, where ``words`` is a list of
        ``(token, word_score)``. Padding tokens (``<pad>``/``[PAD]``) are
        dropped, and sentences containing only padding are skipped.
    """
    tok_str = decode(inp, field)
    assert len(tok_str) == s_att.shape[0]
    if w_att is not None:
        assert len(tok_str) == w_att.shape[0]
        assert not tok_str or len(tok_str[0]) == w_att.shape[1]

    pad_tokens = ("<pad>", "[PAD]")
    opt = []
    for k, sent_tokens in enumerate(tok_str):
        words = []
        # Iterate this sentence's own tokens rather than assuming every row
        # has the length of the first one.
        for j, word_str in enumerate(sent_tokens):
            if word_str in pad_tokens:
                continue
            score = 0.0 if w_att is None else w_att[k][j].item()
            words.append((word_str, score))
        if words:
            opt.append((words, s_att[k].item()))
    return opt
        


def html_string(word, color, new_line=False):
    """Render ``word`` as an HTML ``<span>`` with the given background colour.

    Args:
        word: text to display. It is HTML-escaped so tokens such as ``<unk>``
            or ``&`` appear literally instead of being parsed as markup.
        color: any CSS colour value, e.g. ``'#ff0000'``.
        new_line: if True, append a ``<br>`` after the span.

    Returns:
        The HTML fragment as a string.
    """
    template = '<span class="barcode" style="color: black; background-color: {}">{}</span>'
    content = "&nbsp;" + html.escape(word, quote=False) + "&nbsp;"
    colored_string = template.format(color, content)
    if new_line:
        colored_string += "<br>"
    return colored_string


def colorize(attention_list):
    """Render the output of ``attention_combined`` as an HTML heat map.

    Each sentence starts with a marker span shaded by the Blues colormap
    according to its sentence score. Its words follow, each shaded by the
    Reds colormap according to its word score. Every sentence ends with
    ``<br>``, and four more ``<br>`` close the fragment.

    Args:
        attention_list: list of ``(words, sentence_score)``, where ``words``
            is a list of ``(token, word_score)``. Scores are passed straight
            to the matplotlib colormaps, which map floats over [0, 1].

    Returns:
        HTML string.
    """
    cmap_sent = matplotlib.cm.Blues
    cmap_word = matplotlib.cm.Reds

    parts = []
    for sent, sent_att in attention_list:
        sent_color = matplotlib.colors.rgb2hex(cmap_sent(sent_att)[:3])
        parts.append(html_string('\t---\t ', sent_color))
        for word, word_att in sent:
            word_color = matplotlib.colors.rgb2hex(cmap_word(word_att)[:3])
            parts.append(html_string(word, word_color))
        parts.append("<br>")
    parts.append("<br><br><br><br>")
    return "".join(parts)

def plot_attention(src, trg, model, field, out_path='colorize.html'):
    """Run ``model`` on a (source, target) pair, write the attention heat map
    to ``out_path`` and print the class probabilities.

    Args:
        src: source document text.
        trg: target document text.
        model: novelty model whose ``forward_with_attn(s, t)`` returns
            ``(output, s_att, t_att)``.
        field: field used to encode both texts.
        out_path: HTML file receiving the source map followed by the target
            map (default ``'colorize.html'``).

    Returns:
        The softmax probabilities over the model output (also printed).
    """
    # Encode on the model's device so this works whether it sits on CPU or GPU.
    device = next(model.parameters()).device
    s_enc = encode_text(src, field).to(device)
    t_enc = encode_text(trg, field).to(device)

    model.eval()
    with torch.no_grad():
        opt, s_att, t_att = model.forward_with_attn(s_enc, t_enc)
        # Explicit dim: softmax without ``dim`` is deprecated.
        pred = F.softmax(opt, dim=-1)

    # [0] is passed as the sentence-level scores and [1] as the word-level
    # scores of the first document in the batch.
    # NOTE(review): the exact tensor layout comes from forward_with_attn; confirm.
    src_att_map = attention_combined(s_enc[0], field, s_att[0][0, :, 0], s_att[1][0])
    trg_att_map = attention_combined(t_enc[0], field, t_att[0][0, :, 0], t_att[1][0])

    with open(out_path, 'w', encoding='utf-8') as f:
        f.write(colorize(src_att_map) + colorize(trg_att_map))
    print(pred)
    return pred

def disp_attention(path='./colorize.html', width=1200, height=400):
    """Return an IFrame showing the HTML heat map written by ``plot_attention``.

    Jupyter renders it when the call is the last expression of a cell.

    Args:
        path: HTML file to embed (default ``'./colorize.html'``).
        width: frame width in pixels.
        height: frame height in pixels.
    """
    # The IFrame must be returned: an object that is built but discarded is never displayed.
    return IFrame(path, width=width, height=height)




In [2]:
# Run id of the trained HAN novelty model stored under ./results/.
MODEL_ID = 'NOV-1145'
model, field = load_novelty_model(MODEL_ID)

In [3]:
# Example source document for the attention visualisation (used as `src` below).
source = "We also experimented with the document encoder to find if document level pretraining has any impact on the novelty detection performance. We train our document encoder described in on the Reuters dataset with an objective of 10 class classification. The reuters dataset aligns with the dataset we use for novelty detection, the Reuters dataset contains news articles which are to be classified into categories like Investment, Shipping, Crop, Oil and so on"



In [4]:
# Example target document compared against `source` (used as `trg` below).
target = "Identifing each of these classes requires the ability to extract features which tell which industry the news is related to. We hypothesise that this information is also essential while calculating the novelty of a document, since knowing if the target document is talking about the same thing or topic is also important. This can be seen as assisting the information filtering task. For this experiment we have 3 settings, we test the impact with and without pretraining for Reuters dataset and Reuters+NLI dataset combined. The settings used are listed below."


{'attention_hops': 10,
 'attention_input': 400,
 'attention_layer_param': 200,
 'dataset': 'dlnd',
 'device': 'cuda',
 'dropout': 0.3,
 'encoder_dim': 400,
 'folds': False,
 'freeze_encoder': False,
 'hidden_size': 400,
 'labeled': -1,
 'load_han': 'None',
 'load_nli': 'NLI-93',
 'max_num_sent': 50,
 'num_layers': 1,
 'reset_enc': False,
 'results_dir': 'results',
 'secondary_dataset': 'None',
 'seed': -1,
 'sent_tokenizer': 'spacy'}

In [5]:
# Render the attention heat map for the hand-written pair into colorize.html.
a = plot_attention(src=source, trg=target, model=model, field=field)

,
           0.0259,  0.0294, -0.0192],
         [ 0.0004,  0.0159,  0.0234,  0.0238,  0.0238, -0.0176, -0.0311,
           0.0259,  0.0294, -0.0192],
         [ 0.0004,  0.0159,  0.0234,  0.0238,  0.0238, -0.0176, -0.0311,
           0.0259,  0.0294, -0.0192],
         [ 0.0004,  0.0159,  0.0234,  0.0238,  0.0238, -0.0176, -0.0311,
           0.0259,  0.0294, -0.0192],
         [ 0.0004,  0.0159,  0.0234,  0.0238,  0.0238, -0.0176, -0.0311,
           0.0259,  0.0294, -0.0192],
         [ 0.0004,  0.0159,  0.0234,  0.0238,  0.0238, -0.0176, -0.0311,
           0.0259,  0.0294, -0.0192],
         [ 0.0004,  0.0159,  0.0234,  0.0238,  0.0238, -0.0176, -0.0311,
           0.0259,  0.0294, -0.0192],
         [ 0.0004,  0.0159,  0.0234,  0.0238,  0.0238, -0.0176, -0.0311,
           0.0259,  0.0294, -0.0192],
         [ 0.0004,  0.0159,  0.0234,  0.0238,  0.0238, -0.0176, -0.0311,
           0.0259,  0.0294, -0.0192],
         [ 0.0004,  0.0159,  0.0234,  0.0238,  0.0238, -0.0176, -0.0311,

In [6]:
# Embed the heat map written by plot_attention.
IFrame(src='./colorize.html', width=2200, height=1000)


In [7]:
import json

In [8]:
# DLND novelty-detection dataset: one JSON object per line.
DLND_PATH = '.data/dlnd/TAP-DLND-1.0_LREC2018_modified/dlnd.jsonl'
with open(DLND_PATH, 'r', encoding='utf-8') as f:
    items = f.readlines()
# Skip blank lines (e.g. a trailing empty line), which json.loads would reject.
data = [json.loads(line) for line in items if line.strip()]

In [29]:
# 886 is one of the indices in `wrong_id` (pairs misclassified by `predict`).
EXAMPLE_IDX = 886
example = data[EXAMPLE_IDX]
src_text, trg_text = example["source"], example["target_text"]

print("Prediction:")
plot_attention(src_text, trg_text, model, field)
print("Actual:")
example["DLA"]

Prediction:
tensor([[1., 0.]])
Actual:


'Novel'

In [30]:
# Taller frame: real DLND documents produce much longer heat maps.
IFrame(src='./colorize.html', width=2200, height=2000)


In [18]:
# Character length of every source document; print the index of the shortest.
lens = [len(item['source']) for item in data]
print(lens.index(min(lens)))


284


In [19]:
# Pair each length with its dataset index: lens becomes [(idx, length), ...].
# NOTE(review): this rebinds `lens`, so the cell is not idempotent; re-run the cell above first.
lens = list(enumerate(lens))

In [20]:
# Move the model to the GPU for the full-dataset prediction pass (Module.cuda returns self).
model = model.cuda()

from tqdm import tqdm
def predict(data, model, field, threshold=0.5):
    """Classify every (source, target) pair in ``data`` and collect the
    indices whose prediction disagrees with the gold label.

    Args:
        data: list of dicts with ``'source'``, ``'target_text'`` and ``'DLA'``
            keys. ``'DLA'`` is compared against ``"Novel"``/``"Non-Novel"``.
        model: novelty model exposing ``forward_with_attn``. Inputs are moved
            to the device of its parameters.
        field: field used to encode the texts.
        threshold: a pair is labelled ``"Novel"`` when the softmax probability
            of class 1 exceeds this value (default 0.5).

    Returns:
        list of misclassified indices into ``data``, in ascending order.
    """
    device = next(model.parameters()).device
    model.eval()  # loop-invariant: set once, not once per example
    wrong_id = []
    with torch.no_grad():
        for i in tqdm(range(len(data))):
            item = data[i]
            s_enc = encode_text(item['source'], field).to(device)
            t_enc = encode_text(item['target_text'], field).to(device)
            opt, _, _ = model.forward_with_attn(s_enc, t_enc)
            novel_prob = F.softmax(opt, dim=-1)[0][1].item()
            pred = "Novel" if novel_prob > threshold else "Non-Novel"
            if pred != item['DLA']:
                wrong_id.append(i)
    return wrong_id

In [21]:
# Indices of every DLND pair the model gets wrong (takes a few minutes on GPU).
wrong_id = predict(data=data, model=model, field=field)

100%|██████████| 5435/5435 [02:52<00:00, 31.56it/s]


In [22]:
# Back to CPU for interactive plotting; Module.to returns self, so the repr is displayed.
model.to('cpu')

HAN(
  (encoder): HAN_DOC(
    (encoder): Attn_Encoder(
      (embedding): Embedding(33934, 300, padding_idx=1)
      (translate): Linear(in_features=300, out_features=400, bias=True)
      (relu): ReLU()
      (dropout): Dropout(p=0.3, inplace=False)
      (lstm_layer): LSTM(400, 400, batch_first=True, dropout=0.3, bidirectional=True)
      (attention): Attention(
        (Ws): Linear(in_features=800, out_features=200, bias=False)
        (Wa): Linear(in_features=200, out_features=1, bias=False)
      )
    )
    (translate): Linear(in_features=800, out_features=400, bias=True)
    (act): ReLU()
    (dropout): Dropout(p=0.3, inplace=False)
    (lstm_layer): LSTM(400, 400, bidirectional=True)
    (attention): StrucSelfAttention(
      (ut_dense): Linear(in_features=800, out_features=200, bias=False)
      (et_dense): Linear(in_features=200, out_features=10, bias=False)
    )
  )
  (act): ReLU()
  (dropout): Dropout(p=0.3, inplace=False)
  (fc): Linear(in_features=32000, out_features=2,

In [137]:
# Find the shortest (by source length) misclassified pair, and `c`, its 1-based
# rank in ascending source-length order.
wrong_set = set(wrong_id)  # O(1) membership instead of scanning the list each step
c = 0
for i in sorted(lens, key=lambda x: x[1]):
    c += 1
    if i[0] in wrong_set:
        print(i)
        break
else:
    print("no misclassified example found")

(292, 1737)


In [23]:
# Indices into `data` that `predict` misclassified.
wrong_id

[69,
 83,
 112,
 131,
 148,
 188,
 218,
 265,
 290,
 294,
 296,
 316,
 378,
 380,
 381,
 410,
 419,
 440,
 471,
 473,
 476,
 517,
 525,
 535,
 537,
 549,
 691,
 693,
 707,
 717,
 741,
 755,
 756,
 767,
 773,
 791,
 805,
 816,
 833,
 839,
 841,
 842,
 851,
 853,
 872,
 881,
 882,
 886,
 890,
 939,
 1024,
 1059,
 1070,
 1103,
 1123,
 1138,
 1141,
 1180,
 1203,
 1217,
 1233,
 1263,
 1264,
 1268,
 1280,
 1286,
 1300,
 1306,
 1313,
 1332,
 1338,
 1362,
 1416,
 1628,
 1637,
 1646,
 1647,
 1648,
 1658,
 1696,
 1697,
 1701,
 1714,
 1746,
 1753,
 1754,
 1756,
 1757,
 1759,
 1760,
 1761,
 1769,
 1771,
 1774,
 1778,
 1783,
 1795,
 1805,
 1830,
 1838,
 1839,
 1872,
 1902,
 1910,
 1964,
 1988,
 2015,
 2035,
 2052,
 2057,
 2058,
 2063,
 2070,
 2084,
 2088,
 2120,
 2123,
 2126,
 2136,
 2158,
 2165,
 2174,
 2176,
 2188,
 2255,
 2274,
 2276,
 2281,
 2287,
 2289,
 2296,
 2304,
 2305,
 2314,
 2323,
 2324,
 2326,
 2327,
 2341,
 2354,
 2355,
 2373,
 2398,
 2427,
 2435,
 2439,
 2447,
 2533,
 2565,
 2586,
 2

In [139]:
# 1-based rank of the first misclassified pair in ascending source-length order.
c

9