In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable
from transformer import *
from util import *
from data import *
import matplotlib.pyplot as plt
from heatmap import html_heatmap
from IPython.display import display, HTML

In [2]:
# Load the CSU data and the BPE text encoder, then register the special tokens
# used for padding, sentence boundaries and unknown words.
train, valid, test = get_dis('/home/yuhuiz/Transformer/data/csu/', 'csu_bpe', 'csu', 600, False)
text_encoder = TextEncoder('/home/yuhuiz/Transformer/data/sage/encoder_bpe_50000.json', '/home/yuhuiz/Transformer/data/sage/vocab_50000.bpe')
# `encoder` aliases text_encoder.encoder, so the special tokens are added to the
# text encoder's own vocabulary dict as well.
encoder = text_encoder.encoder
for special_token in ['_pad_', '_start_', '_end_', '_unk_']:
    encoder[special_token] = len(encoder)
decoder = {v: k for k, v in encoder.items()}


def _to_ragged_array(seqs):
    """Pack a list of variable-length id lists into a 1-D object array.

    ``np.array(list_of_ragged_lists)`` raises ValueError on NumPy >= 1.24, so
    the rows are stored element-wise in an object array instead. Each element
    stays a Python list, which matches what older NumPy produced for ragged input.
    """
    packed = np.empty(len(seqs), dtype=object)
    for i, seq in enumerate(seqs):
        packed[i] = seq
    return packed


# Replace each raw sentence with [_start_] + BPE ids + [_end_].
# Iterate over the split dicts directly instead of eval()-ing their names.
for split in ['s1']:
    for dataset in (train, valid, test):
        num_sents = []
        for sent in dataset[split]:
            num_sent = text_encoder.encode([sent], lazy=True, bpe=False)[0]
            num_sents.append([encoder['_start_']] + num_sent + [encoder['_end_']])
        dataset[split] = _to_ragged_array(num_sents)

In [3]:
# Known-good (sample_idx, label_idx) pairs on the test split:
# [(16, 39), (47, 26), (21, 38), (21, 40)]
def interpret_transformer(sample_idx, label_idx, silent=True,
                          model_path='/home/yuhuiz/Transformer/exp/bpe/transformer_auxiliary_pretrain/model-8.pickle',
                          dataset=None):
    """Show a gradient-x-input token heatmap for one example under the Transformer model.

    The chosen class logit is differentiated with respect to the output of
    ``model.tgt_embed[0]``. That gradient is multiplied element-wise by the same
    output and summed over the last dimension, giving one score per token.

    Args:
        sample_idx: index of the example inside ``dataset['s1']``.
        label_idx: index of the classifier output (logit) being explained.
        silent: when False, also print the logits and the predicted / gold labels.
        model_path: checkpoint passed to ``torch.load``.
        dataset: dict with ``'s1'`` and ``'label'`` entries. Defaults to the
            global ``test`` split; pass ``valid`` to inspect the validation set.

    Returns:
        None. The heatmap is rendered via ``display``.
    """
    if dataset is None:
        dataset = test

    def _active_indices(mask):
        # Flatten before and after nonzero() so the result is always a list,
        # even with zero or one hit. squeeze().tolist() returns a bare int for
        # a single hit, which broke the label-name comprehension below.
        return mask.reshape(-1).nonzero().reshape(-1).tolist()

    # model
    # NOTE(review): torch.load unpickles arbitrary objects; only load trusted checkpoints.
    model = torch.load(model_path, map_location='cpu')
    model.eval()

    # data: a batch containing just the requested example
    s1_batch = pad_batch(dataset['s1'][sample_idx:sample_idx + 1], encoder['_pad_'])
    label_batch = dataset['label'][sample_idx:sample_idx + 1]
    b = Batch(s1_batch, label_batch, [], encoder['_pad_'])

    # interpret: keep `x` (output of tgt_embed[0]) so the logit can be differentiated w.r.t. it
    x = model.tgt_embed[0](b.s1)
    xx = model.tgt_embed[1](x)
    u_h = model.decoder(xx, b.s1_mask)
    u = model.pick_h(u_h, b.s1_lengths)
    picked_s1_mask = model.pick_mask(b.s1_mask, b.s1_lengths)
    u = model.projection_layer(u, u_h, u_h, picked_s1_mask)
    clf_output = model.classifier(u)
    pred = torch.sigmoid(clf_output) > 0.5
    y = clf_output[0][label_idx]
    model.zero_grad()
    grad = x * torch.autograd.grad(y, x)[0]

    # visualize
    token_scores = grad.sum(-1).detach().reshape(-1).tolist()
    text = [decoder[i] for i in b.s1.reshape(-1).tolist()]
    label_names = get_labels('csu')
    if not silent:
        pred_ids = _active_indices(pred)
        gold_ids = _active_indices(b.label)
        print('logits =', clf_output.squeeze())
        print('pred =', pred_ids)
        print('pred =', [label_names[i] for i in pred_ids])
        print('label =', gold_ids)
        print('label =', [label_names[i] for i in gold_ids])
    # Drop position 0, which cell 2 fills with the `_start_` token.
    display(HTML(html_heatmap(text[1:], token_scores[1:])))

In [4]:
def interpret_lstm(sample_idx, label_idx, silent=True,
                   model_path='/home/yuhuiz/Transformer/exp/bpe/lstm_auxiliary_pretrain/model-8.pickle',
                   dataset=None):
    """Show a gradient-x-input token heatmap for one example under the LSTM model.

    This mirrors ``interpret_transformer``, except the encoder step calls
    ``model.autolen_rnn`` instead of ``model.decoder``.

    Args:
        sample_idx: index of the example inside ``dataset['s1']``.
        label_idx: index of the classifier output (logit) being explained.
        silent: when False, also print the logits and the predicted / gold labels.
        model_path: checkpoint passed to ``torch.load``.
        dataset: dict with ``'s1'`` and ``'label'`` entries. Defaults to the
            global ``test`` split; pass ``valid`` to inspect the validation set.

    Returns:
        None. The heatmap is rendered via ``display``.
    """
    if dataset is None:
        dataset = test

    def _active_indices(mask):
        # Flatten before and after nonzero() so the result is always a list,
        # even with zero or one hit. squeeze().tolist() returns a bare int for
        # a single hit, which broke the label-name comprehension below.
        return mask.reshape(-1).nonzero().reshape(-1).tolist()

    # model
    # NOTE(review): torch.load unpickles arbitrary objects; only load trusted checkpoints.
    model = torch.load(model_path, map_location='cpu')
    model.eval()

    # data: a batch containing just the requested example
    s1_batch = pad_batch(dataset['s1'][sample_idx:sample_idx + 1], encoder['_pad_'])
    label_batch = dataset['label'][sample_idx:sample_idx + 1]
    b = Batch(s1_batch, label_batch, [], encoder['_pad_'])

    # interpret: keep `x` (output of tgt_embed[0]) so the logit can be differentiated w.r.t. it
    x = model.tgt_embed[0](b.s1)
    xx = model.tgt_embed[1](x)
    u_h = model.autolen_rnn(xx, b.s1_lengths)
    u = model.pick_h(u_h, b.s1_lengths)
    picked_s1_mask = model.pick_mask(b.s1_mask, b.s1_lengths)
    u = model.projection_layer(u, u_h, u_h, picked_s1_mask)
    clf_output = model.classifier(u)
    pred = torch.sigmoid(clf_output) > 0.5
    y = clf_output[0][label_idx]
    model.zero_grad()
    grad = x * torch.autograd.grad(y, x)[0]

    # visualize
    token_scores = grad.sum(-1).detach().reshape(-1).tolist()
    text = [decoder[i] for i in b.s1.reshape(-1).tolist()]
    label_names = get_labels('csu')
    if not silent:
        pred_ids = _active_indices(pred)
        gold_ids = _active_indices(b.label)
        print('logits =', clf_output.squeeze())
        print('pred =', pred_ids)
        print('pred =', [label_names[i] for i in pred_ids])
        print('label =', gold_ids)
        print('label =', [label_names[i] for i in gold_ids])
    # Drop position 0, which cell 2 fills with the `_start_` token.
    display(HTML(html_heatmap(text[1:], token_scores[1:])))

In [6]:
# Compare both models' token attributions on the same test example (sample 47, label 26).
for interpret in (interpret_transformer, interpret_lstm):
    interpret(47, 26, silent=False)



logits = tensor([-15.8784, -17.1387, -14.8995, -15.5421, -15.3664, -15.9373,
        -14.7549, -12.1735, -10.4198, -10.1664, -11.0308, -11.1979,
        -11.5704, -10.9385,  -8.3562,  -5.7506,  -5.3576,  -9.3167,
        -11.0783,  -7.6009,  -8.8532,  -7.9168,  -4.7730,   8.4960,
         -6.7209,  -0.1891,   8.5590,  -7.2602, -10.5429,  -7.0183,
         -5.5561,  -7.4463,  -6.0965,  -6.9961,  -2.9750,  -9.2255,
         -6.2264,  -7.4816,   5.0928,  -4.2168,  -6.5594,   6.6768])
pred = [23, 26, 38, 41]
pred = ['Propensity to adverse reactions (disorder)', 'Hypersensitivity condition (disorder)', 'Disorder of integument (disorder)', 'Clinical finding (finding)']
label = [23, 26, 38, 41]
label = ['Propensity to adverse reactions (disorder)', 'Hypersensitivity condition (disorder)', 'Disorder of integument (disorder)', 'Clinical finding (finding)']




logits = tensor([-16.7165, -16.6613, -15.7218, -16.8673, -16.3042, -17.4315,
        -11.0038, -11.1158,  -9.9076,  -9.7221, -10.7704,  -9.3859,
         -7.8312, -10.5752,  -9.0651,  -7.8931, -10.2982,  -8.1445,
         -9.1579,  -9.5837,  -9.1034,  -7.4837,  -7.3437,   7.4584,
         -7.8867,  -1.0263,   9.5782,  -5.3957,  -8.1797,  -6.8150,
         -5.2576,  -4.8732,  -4.0158,  -6.8626,  -2.1679,  -7.0167,
         -4.4482,  -6.1156,   6.9229,  -4.6591,  -3.1277,   6.7496])
pred = [23, 26, 38, 41]
pred = ['Propensity to adverse reactions (disorder)', 'Hypersensitivity condition (disorder)', 'Disorder of integument (disorder)', 'Clinical finding (finding)']
label = [23, 26, 38, 41]
label = ['Propensity to adverse reactions (disorder)', 'Hypersensitivity condition (disorder)', 'Disorder of integument (disorder)', 'Clinical finding (finding)']
