In [2]:
import torch

import matplotlib.pyplot as plt
import numpy as np

import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from transformers import BertTokenizer, BertModel
from torch.utils.data import DataLoader, Dataset
from torchtext.data.functional import to_map_style_dataset
from torchtext.datasets import AG_NEWS
from torch.utils.data.dataset import random_split

from bert import EncodedDataset, Classifier_model

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def evaluate_topk(model, dataloader, topk=1):
    """Compute the top-k accuracy of ``model`` over every batch of ``dataloader``.

    Args:
        model: Module that is switched to eval mode and called as
            ``model(batch)``; it must return a 2-D tensor of per-class scores
            with one row per example.
        dataloader: Iterable of batches. Each batch is a mapping whose
            ``'targets'`` entry holds the class labels offset by one (they are
            decremented here so they line up with the score columns).
        topk: Number of highest-scoring classes that count as a hit. Values
            larger than the number of classes are clamped.

    Returns:
        Fraction of examples whose label is among the ``topk`` best-scoring
        classes, as a float. Returns 0.0 when no examples were seen.
    """
    model.eval()
    hits, seen = 0, 0
    with torch.no_grad():
        for batch in dataloader:
            # Scalar subtraction keeps the labels on the device the targets
            # already live on (a fresh torch.ones(...) would always be on CPU).
            labels = batch['targets'] - 1
            scores = model(batch)
            # Softmax is monotonic, so ranking the raw scores selects the same
            # top-k classes without the extra pass.
            k = min(topk, scores.size(1))
            _, top_idx = scores.topk(k=k, dim=1)
            # Column vector so each label is compared against its own row.
            labels = labels.reshape(-1, 1).to(top_idx.device)
            hits += int((top_idx == labels).any(dim=1).sum().item())
            seen += labels.size(0)
    if seen == 0:
        # Nothing to evaluate; avoid a ZeroDivisionError.
        return 0.0
    return hits / seen

In [4]:
# ---- Configuration ---------------------------------------------------------
device = torch.device("cpu")

batch_size = 16
num_epochs = 64
max_len = 64     # forwarded to EncodedDataset together with the tokenizer
vsplit = 0.05    # fraction of the AG_NEWS train split held out for validation
split_seed = 0   # fixes the train/validation partition across kernel restarts
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# ---- Train / validation split ----------------------------------------------
train_iter = AG_NEWS(split='train')
full_trainset = to_map_style_dataset(train_iter)
# n_val is the held-out size and n_train the remainder, so each name matches
# the subset it sizes.
n_val = int(len(full_trainset) * vsplit)
n_train = len(full_trainset) - n_val
valset, trainset = random_split(
    full_trainset,
    [n_val, n_train],
    generator=torch.Generator().manual_seed(split_seed),
)

valset = EncodedDataset(valset, tokenizer, max_len, device)
trainset = EncodedDataset(trainset, tokenizer, max_len, device)

trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
# Evaluation does not depend on sample order, so keep validation deterministic.
valloader = DataLoader(valset, batch_size=batch_size, shuffle=False)

# ---- Test subset -------------------------------------------------------------
test_iter = AG_NEWS(split='test')
testset = to_map_style_dataset(test_iter)
# Only the first 100 test examples are encoded.
testset = EncodedDataset(testset[:100], tokenizer, max_len, device)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False)

print(f"Datasets loaded - train: {len(trainset)}, val: {len(valset)}, test: {len(testset)}")




Datasets loaded - train: 114000, val: 6000, test: 100


In [5]:
# NOTE(review): torch.load unpickles a complete model object here, which can
# execute arbitrary code -- only point it at a checkpoint you trust.
model = torch.load('saved_bert', map_location=torch.device('cpu'))

# Turn gradient tracking on for every parameter of the loaded model
# (presumably so the saliency backward passes further down can propagate).
for parameter in model.parameters():
    parameter.requires_grad = True

In [6]:
import inspect

# Print the argument specification of the wrapped BERT encoder's forward().
bert_forward = model.bert.forward
argspec = inspect.getfullargspec(bert_forward)
print(argspec)

FullArgSpec(args=['self', 'input_ids', 'attention_mask', 'token_type_ids', 'position_ids', 'head_mask', 'inputs_embeds', 'encoder_hidden_states', 'encoder_attention_mask', 'past_key_values', 'use_cache', 'output_attentions', 'output_hidden_states', 'return_dict'], varargs=None, varkw=None, defaults=(None, None, None, None, None, None, None, None, None, None, None, None, None), kwonlyargs=[], kwonlydefaults=None, annotations={})


In [5]:
# Display the module tree of the BERT encoder held by the loaded model.
model.bert

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          

In [6]:
# Top-2 accuracy on the test loader (built from the first 100 test examples).
evaluate_topk(model, testloader, topk=2)

0.98

In [7]:
# Zero-based class index -> readable topic name (raw targets are one higher;
# they are decremented before being looked up here).
idx2label = dict(enumerate(['World', 'Sports', 'Business', 'Sci/Tec']))

In [34]:
def _register_embedding_list_hook(model, embeddings_list):
    def forward_hook(module, inputs, output):
        embeddings_list.append(output.squeeze(0).clone().cpu().detach().numpy())
    embedding_layer = model.bert.embeddings.word_embeddings
    handle = embedding_layer.register_forward_hook(forward_hook)
    return handle

def _register_embedding_gradient_hooks(model, embeddings_gradients):
    def hook_layers(module, grad_in, grad_out):
        # TODO maybe try to do autograd(grad_out, module.weight) or something and append that
        embeddings_gradients.append(grad_out[0])
        #d2grad = torch.autograd.grad(grad_out, embedding_layer)
        #embeddings_gradients.append(d2grad)
    embedding_layer = model.bert.embeddings.word_embeddings
    hook = embedding_layer.register_full_backward_hook(hook_layers)
    return hook


In [33]:
# Peek at row 0 of the word-embedding matrix. The layer's only parameter is
# `weight`, so this is the same tensor as list(...parameters())[0][0].
model.bert.embeddings.word_embeddings.weight[0]

tensor([-1.0183e-02, -6.1549e-02, -2.6497e-02, -4.2061e-02,  1.1672e-03,
        -2.8272e-02, -4.4500e-02, -2.2465e-02, -4.6553e-03, -8.2129e-02,
        -5.0238e-03, -4.6508e-02, -4.9514e-02,  2.1517e-02, -1.6588e-02,
        -3.7279e-02, -7.2888e-02, -4.6671e-02,  1.9787e-03, -5.5847e-02,
        -2.8919e-02, -2.2304e-02, -4.4846e-03, -1.5506e-02, -1.0986e-01,
        -2.6746e-02,  8.3565e-03, -5.3755e-02,  8.1516e-03, -2.5817e-02,
        -2.8301e-02, -2.6342e-03, -1.7270e-02, -1.7444e-02, -5.0403e-02,
        -5.4036e-02, -3.3925e-02, -1.9397e-02, -6.2235e-02, -1.9178e-03,
        -3.0086e-02, -3.1459e-02, -5.0693e-02, -1.8174e-02,  6.8573e-03,
        -8.9839e-03, -1.1808e-02, -3.2866e-02, -3.8003e-03, -2.7472e-02,
        -3.3144e-02, -1.6076e-02, -5.8682e-02,  1.0107e-01, -2.9100e-02,
        -2.4062e-02, -1.5432e-02,  5.2106e-03, -2.3103e-03,  4.4728e-03,
        -1.1664e-02, -1.4309e-02,  1.0915e-01, -4.0001e-02, -2.9073e-02,
        -1.1655e-02, -2.0877e-02, -3.0113e-02, -6.7

In [35]:
def compute_bert_saliency_map(X, label, model, loss):
    """Gradient-times-input saliency of one class score w.r.t. the word embeddings.

    Runs one forward pass of ``model(X)``, back-propagates the score at
    ``[0, label]``, multiplies the gradient captured at the word-embedding
    output by the captured embeddings and sums over the embedding dimension.

    Args:
        X: Model input, passed straight to ``model(X)``. Only row 0 of the
            output is used, so a single example is expected.
        label: Column index of the score to differentiate.
        model: Classifier exposing ``model.bert.embeddings.word_embeddings``.
        loss: Unused; kept so existing callers keep working.

    Returns:
        1-D NumPy array of absolute, L1-normalised saliency values (one per
        row of the captured embeddings). All zeros if the gradient vanishes.
    """
    embeddings_list = []
    embeddings_gradients = []
    handle = _register_embedding_list_hook(model, embeddings_list)
    hook = _register_embedding_gradient_hooks(model, embeddings_gradients)
    try:
        # A bare `torch.enable_grad()` call only builds a context manager and
        # throws it away; it must be entered to actually re-enable autograd
        # (e.g. when this function is invoked under torch.no_grad()).
        with torch.enable_grad():
            model.zero_grad()
            A = model(X)
            A[0, label].backward()
    finally:
        # Always detach the hooks, even if the forward/backward pass fails,
        # so they do not pile up on the model across calls.
        handle.remove()
        hook.remove()

    saliency_grad = embeddings_gradients[0].detach().cpu().numpy()
    # Drop the batch axis of the gradient, then reduce over the embedding axis.
    saliency_grad = np.sum(saliency_grad[0] * embeddings_list[0], axis=1)
    norm = np.linalg.norm(saliency_grad, ord=1)
    if norm == 0:
        # Avoid 0/0 -> NaN when every contribution is exactly zero.
        return np.zeros_like(saliency_grad)
    return np.abs(saliency_grad / norm)


In [38]:
def compute_bert_saliency_map_d2(X, label, model, loss):
    """Experimental saliency map intended to become a second-derivative variant.

    NOTE(review): the second-order part is not implemented yet. The backward
    hook fires once per backward pass, so exactly one gradient is captured and
    this currently yields the same first-order gradient-times-input values as
    ``compute_bert_saliency_map``.

    TODO: obtain a true second-order signal, e.g. by differentiating the score
    w.r.t. the embedding output with ``torch.autograd.grad(...,
    create_graph=True)`` and differentiating that result again.

    Args:
        X: Model input, passed straight to ``model(X)``. Only row 0 of the
            output is used, so a single example is expected.
        label: Column index of the score to differentiate.
        model: Classifier exposing ``model.bert.embeddings.word_embeddings``.
        loss: Unused; kept so existing callers keep working.

    Returns:
        1-D NumPy array of absolute, L1-normalised saliency values (one per
        row of the captured embeddings). All zeros if the gradient vanishes.
    """
    embeddings_list = []
    embeddings_gradients = []
    handle = _register_embedding_list_hook(model, embeddings_list)
    hook = _register_embedding_gradient_hooks(model, embeddings_gradients)
    try:
        # `torch.enable_grad()` has to be entered as a context manager; calling
        # it and discarding the result changes nothing.
        with torch.enable_grad():
            model.zero_grad()
            A = model(X)
            A[0, label].backward()
    finally:
        # Remove the hooks even when the pass above raises.
        handle.remove()
        hook.remove()

    # Only index 0 exists: a single backward pass records a single gradient,
    # so reading index 1 would raise IndexError.
    saliency_grad = embeddings_gradients[0].detach().cpu().numpy()
    print(f'd2 saliency grad shape: {saliency_grad.shape}')
    saliency_grad = np.sum(saliency_grad[0] * embeddings_list[0], axis=1)
    norm = np.linalg.norm(saliency_grad, ord=1)
    if norm == 0:
        # Avoid 0/0 -> NaN when every contribution is exactly zero.
        return np.zeros_like(saliency_grad)
    return np.abs(saliency_grad / norm)


In [39]:
def show_text_saliency_maps(X, y, tokenizer, correct_label, label_dict, model, saliency_fn=None):
    """Plot per-token saliency for each class in ``y``, one subplot per class.

    Args:
        X: Single-example model input; ``X['ids']`` holds the token ids with a
            leading batch dimension of 1.
        y: 1-D tensor of class indices to explain.
        tokenizer: Tokenizer used to turn the ids back into token strings.
        correct_label: Readable name of the true class (only printed).
        label_dict: Mapping from class index to readable class name.
        model: Classifier passed on to the saliency function.
        saliency_fn: Callable ``(X, label, model, loss) -> 1-D array`` with one
            value per token. Defaults to ``compute_bert_saliency_map_d2``.

    Returns:
        None. Prints the predicted/true class and shows the figure.
    """
    if saliency_fn is None:
        saliency_fn = compute_bert_saliency_map_d2

    # The prediction is only reported, so no autograd graph is needed for it.
    with torch.no_grad():
        prediction = model(X)
    predicted_class = prediction.argmax()
    print(f'The predicted class is: {label_dict[predicted_class.item()]}, the correct class is: {correct_label}')

    n_tokens = X['ids'].shape[1]
    n_classes = y.size()[0]
    if n_classes == 0:
        # Nothing to plot; plt.subplots rejects a zero-row grid.
        return

    saliencies = np.zeros(shape=(n_tokens, n_classes))
    loss = nn.CrossEntropyLoss()
    for i, label in enumerate(y):
        saliencies[:, i] = saliency_fn(X, label, model, loss)

    # One string per token id. decode(...).split() merges word pieces back
    # into words, which shifts the labels and the [SEP] position relative to
    # the per-token saliency values.
    tokens = np.array(tokenizer.convert_ids_to_tokens(X['ids'].flatten().tolist()))
    sep_positions = np.where(tokens == '[SEP]')[0]
    # Plot up to and including the first [SEP]; fall back to every token.
    pad_idx = int(sep_positions.min()) if sep_positions.size else len(tokens) - 1
    x_ticks = np.arange(pad_idx + 1)

    # squeeze=False keeps `axes` 2-D even for a single class; the height
    # scales with the class count (15 per row, i.e. 60 for four classes).
    fig, axes = plt.subplots(n_classes, 1, sharex=False, sharey=False,
                             figsize=(8, 15 * n_classes), squeeze=False)
    for i, label in enumerate(y):
        ax = axes[i, 0]
        ax.plot(saliencies[:pad_idx + 1, i], x_ticks, '-o')
        ax.set_yticks(ticks=x_ticks, labels=tokens[:pad_idx + 1])
        ax.set_title(f'class \'{label_dict[label.item()]}\' prediction')
        ax.grid()
        # Put the first token at the top so the text reads downwards.
        ax.invert_yaxis()
    plt.show()

In [11]:
# Draw one example at random from the test subset to explain.
n = np.random.randint(len(testset))
test_point = testset[n]

In [40]:
# Zero-based class index of the selected example and its readable name.
idx = test_point['targets'].item() - 1
label = idx2label[idx]

# Give every field a leading batch dimension of 1 (updates test_point in place).
for key in tuple(test_point):
    test_point[key] = test_point[key].reshape(1, -1)

# Explain the model's score for each of the four classes.
classes_to_explain = torch.tensor([0, 1, 2, 3])
show_text_saliency_maps(test_point, classes_to_explain, tokenizer, label, idx2label, model)

The predicted class is: Business, the correct class is: Sci/Tec


TypeError: 'Embedding' object is not iterable

In [23]:
# Re-enable gradient tracking on every model parameter (same effect as the
# loop that runs right after the model is loaded).
for parameter in model.parameters():
    parameter.requires_grad = True