In [2]:
import torch
import torch.nn as nn
from torch.autograd import Variable
from transformer import *
from util import *
from data import *
import matplotlib.pyplot as plt
from heatmap import html_heatmap
from IPython.display import display, HTML

In [3]:
# Load the CSU splits and the BPE text encoder, register the special tokens,
# then replace every sentence of the 's1' field with its token ids wrapped in
# _start_ / _end_ markers.
train, valid, test = get_dis('/home/yuhuiz/Transformer/data/csu/', 'csu_bpe', 'csu', 600, False)
text_encoder = TextEncoder('/home/yuhuiz/Transformer/data/sage/encoder_bpe_50000.json', '/home/yuhuiz/Transformer/data/sage/vocab_50000.bpe')

# `encoder` is the same dict object as text_encoder.encoder, so the special
# tokens are registered on text_encoder too. Each one takes the next free id.
encoder = text_encoder.encoder
for special_token in ['_pad_', '_start_', '_end_', '_unk_']:
    encoder[special_token] = len(encoder)
decoder = {v: k for k, v in encoder.items()}

# Explicit name -> split mapping (instead of eval(name)); the values are the
# very same dict objects, so assigning into them updates train/valid/test.
datasets = {'train': train, 'valid': valid, 'test': test}
for split in ['s1']:
    for data_type, dataset in datasets.items():
        num_sents = []
        for sent in dataset[split]:
            num_sent = text_encoder.encode([sent], lazy=True, bpe=False)[0]
            num_sents.append([encoder['_start_']] + num_sent + [encoder['_end_']])
        # Sentences have different lengths. np.array() on such a ragged list
        # raises ValueError on NumPy >= 1.24, so build the 1-D object array of
        # per-sentence id lists explicitly.
        encoded = np.empty(len(num_sents), dtype=object)
        for i, ids in enumerate(num_sents):
            encoded[i] = ids
        dataset[split] = encoded

In [11]:
# Good (sample_idx, label_idx) examples: (16, 39), (47, 26), (21, 38), (21, 40)
def interpret_transformer(sample_idx, label_idx, silent=True, threshold=None,
                          model_path='/home/yuhuiz/Transformer/exp/bpe/transformer_auxiliary_pretrain/model-8.pickle',
                          data=None):
    """Display a gradient-x-input token heatmap for one sample and one output logit.

    Loads the Transformer model, runs the single sample `sample_idx` through it,
    and multiplies the gradient of logit `label_idx` (w.r.t. the output of
    `model.tgt_embed[0]`) by that output. The result is summed over the last
    dimension, giving one score per token, which is rendered with html_heatmap.

    Args:
        sample_idx: index of the sample in data['s1'] / data['label'].
        label_idx: index of the classifier logit to explain.
        silent: if False, also print the logits and the predicted / gold labels.
        threshold: if not None, token scores that are not greater than
            `threshold` are set to 0 before display.
        model_path: file passed to torch.load. NOTE: torch.load unpickles
            arbitrary objects, so only load trusted files.
        data: dict holding 's1' and 'label'; defaults to the global `test`
            split (pass `valid` to inspect the validation split).
    """
    if data is None:
        data = test

    def _indices(mask):
        # nonzero() returns shape (k, ndim). Flatten it so a single hit still
        # gives a list; .squeeze() would collapse it to a bare int and break
        # the label-name lookup below.
        return mask.squeeze().nonzero().view(-1).tolist()

    # model
    model = torch.load(model_path, map_location='cpu')
    model.eval()

    s1 = data['s1']
    target = data['label']

    # data: a one-sample padded batch
    s1_batch = pad_batch(s1[sample_idx:sample_idx+1], encoder['_pad_'])
    label_batch = target[sample_idx:sample_idx+1]
    b = Batch(s1_batch, label_batch, [], encoder['_pad_'])

    # interpret: keep the output `x` of tgt_embed[0] so the selected logit can
    # be differentiated with respect to it
    x = model.tgt_embed[0](b.s1)
    xx = model.tgt_embed[1](x)
    u_h = model.decoder(xx, b.s1_mask)
    u = model.pick_h(u_h, b.s1_lengths)
    picked_s1_mask = model.pick_mask(b.s1_mask, b.s1_lengths)
    u = model.projection_layer(u, u_h, u_h, picked_s1_mask)
    clf_output = model.classifier(u)
    pred = (torch.sigmoid(clf_output) > 0.5)  # labels with probability > 0.5
    y = clf_output[0][label_idx]
    model.zero_grad()
    grad = x * torch.autograd.grad(y, x)[0]  # gradient x input

    # visualize: one score per token
    grad = grad.sum(-1).data.squeeze().numpy()
    if threshold is not None:
        grad = grad * (grad > threshold)
    grad = grad.tolist()

    text_id = b.s1.squeeze().numpy().tolist()
    text = [decoder[i] for i in text_id]
    label = get_labels('csu')
    if not silent:
        pred_idx = _indices(pred)
        gold_idx = _indices(b.label)
        print('logits =', clf_output.squeeze())
        print('pred =', pred_idx)
        print('pred =', [label[i] for i in pred_idx])
        print('label =', gold_idx)
        print('label =', [label[i] for i in gold_idx])
    # Skip position 0, the _start_ token prepended during encoding.
    display(HTML(html_heatmap(text[1:], grad[1:])))

In [4]:
def interpret_lstm(sample_idx, label_idx, silent=True, threshold=None,
                   model_path='/home/yuhuiz/Transformer/exp/bpe/lstm_auxiliary_pretrain/model-8.pickle',
                   data=None):
    """Display a gradient-x-input token heatmap for one sample using the LSTM model.

    Loads the LSTM model, runs the single sample `sample_idx` through it, and
    multiplies the gradient of logit `label_idx` (w.r.t. the output of
    `model.tgt_embed[0]`) by that output. The result is summed over the last
    dimension, giving one score per token, which is rendered with html_heatmap.

    Args:
        sample_idx: index of the sample in data['s1'] / data['label'].
        label_idx: index of the classifier logit to explain.
        silent: if False, also print the logits and the predicted / gold labels.
        threshold: if not None, token scores that are not greater than
            `threshold` are set to 0 before display (same option as
            interpret_transformer).
        model_path: file passed to torch.load. NOTE: torch.load unpickles
            arbitrary objects, so only load trusted files.
        data: dict holding 's1' and 'label'; defaults to the global `test`
            split (pass `valid` to inspect the validation split).
    """
    if data is None:
        data = test

    def _indices(mask):
        # nonzero() returns shape (k, ndim). Flatten it so a single hit still
        # gives a list; .squeeze() would collapse it to a bare int and break
        # the label-name lookup below.
        return mask.squeeze().nonzero().view(-1).tolist()

    # model
    model = torch.load(model_path, map_location='cpu')
    model.eval()

    s1 = data['s1']
    target = data['label']

    # data: a one-sample padded batch
    s1_batch = pad_batch(s1[sample_idx:sample_idx+1], encoder['_pad_'])
    label_batch = target[sample_idx:sample_idx+1]
    b = Batch(s1_batch, label_batch, [], encoder['_pad_'])

    # interpret: keep the output `x` of tgt_embed[0] so the selected logit can
    # be differentiated with respect to it
    x = model.tgt_embed[0](b.s1)
    xx = model.tgt_embed[1](x)
    u_h = model.autolen_rnn(xx, b.s1_lengths)
    u = model.pick_h(u_h, b.s1_lengths)
    picked_s1_mask = model.pick_mask(b.s1_mask, b.s1_lengths)
    u = model.projection_layer(u, u_h, u_h, picked_s1_mask)
    clf_output = model.classifier(u)
    pred = (torch.sigmoid(clf_output) > 0.5)  # labels with probability > 0.5
    y = clf_output[0][label_idx]
    model.zero_grad()
    grad = x * torch.autograd.grad(y, x)[0]  # gradient x input

    # visualize: one score per token
    grad = grad.sum(-1).data.squeeze().numpy()
    if threshold is not None:
        grad = grad * (grad > threshold)
    grad = grad.tolist()

    text_id = b.s1.squeeze().numpy().tolist()
    text = [decoder[i] for i in text_id]
    label = get_labels('csu')
    if not silent:
        pred_idx = _indices(pred)
        gold_idx = _indices(b.label)
        print('logits =', clf_output.squeeze())
        print('pred =', pred_idx)
        print('pred =', [label[i] for i in pred_idx])
        print('label =', gold_idx)
        print('label =', [label[i] for i in gold_idx])
    # Skip position 0, the _start_ token prepended during encoding.
    display(HTML(html_heatmap(text[1:], grad[1:])))

In [30]:
# Explain label 40 for test sample 9 with the Transformer model (no threshold).
interpret_transformer(sample_idx=9, label_idx=40, silent=False)
# LSTM counterpart for comparison: interpret_lstm(47, 26, silent=False)



logits = tensor([-21.3143, -25.2147, -29.8355, -24.4539, -24.8130, -22.6674,
        -27.4273, -23.7908,  -9.7768, -17.0644, -18.7451, -13.3065,
        -11.5323, -14.0741, -16.4290, -16.0483, -10.5599, -10.9031,
         -8.1654,  -3.0968,  -5.1971,  -6.2017, -12.4415,  -1.8935,
         -9.1356, -11.2041,  -2.9586,  -4.3047, -11.0542,  -6.8276,
         -4.0453,  -3.7381,   7.9594, -10.4458,  -0.2958,  -9.0764,
         -2.1545,  -3.0684,   6.3305,  -5.4667,   4.2302,   4.9700])
pred = [32, 38, 40, 41]
pred = ['Traumatic AND/OR non-traumatic injury (disorder)', 'Disorder of integument (disorder)', 'Neoplasm and/or hamartoma (disorder)', 'Clinical finding (finding)']
label = [40, 41]
label = ['Neoplasm and/or hamartoma (disorder)', 'Clinical finding (finding)']
