Two Options:
1. last_hidden_state from self.bert has shape (batch_size, seq_len, embedding_size) and can be viewed as the learned representation of the input (pooled_output, in contrast, has no seq_len axis). Since last_hidden_state has the same seq_len as the input, we can aggregate over the third axis so that each token gets a score for the heatmap.
2. directly back-prop to input_ids.

Original Notebook: https://colab.research.google.com/drive/1PHv-IRLPCtv7oTcIGbsgZHqrB5LPvB7S#scrollTo=PGnlRWvkY-2c

In [1]:
!pip install transformers==2.6.0

Collecting transformers==2.6.0
  Downloading transformers-2.6.0-py3-none-any.whl (540 kB)
[K     |████████████████████████████████| 540 kB 5.3 MB/s 
Collecting sacremoses
  Downloading sacremoses-0.0.47-py2.py3-none-any.whl (895 kB)
[K     |████████████████████████████████| 895 kB 37.9 MB/s 
Collecting sentencepiece
  Downloading sentencepiece-0.1.96-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)
[K     |████████████████████████████████| 1.2 MB 32.2 MB/s 
[?25hCollecting tokenizers==0.5.2
  Downloading tokenizers-0.5.2-cp37-cp37m-manylinux1_x86_64.whl (5.6 MB)
[K     |████████████████████████████████| 5.6 MB 12.0 MB/s 
Collecting boto3
  Downloading boto3-1.21.16-py3-none-any.whl (132 kB)
[K     |████████████████████████████████| 132 kB 33.6 MB/s 
Collecting jmespath<1.0.0,>=0.7.1
  Downloading jmespath-0.10.0-py2.py3-none-any.whl (24 kB)
Collecting botocore<1.25.0,>=1.24.16
  Downloading botocore-1.24.16-py3-none-any.whl (8.6 MB)
[K     |████████████████████

In [9]:
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
import transformers
from transformers import BertModel, BertTokenizer
import torch

import numpy as np
import pandas as pd

import torch.nn.functional as F

In [10]:
# Global path variables: the project root and the NLP dataset folder beneath it.
ROOT_DIR = "drive/MyDrive/11877-AMMML/"
DATASET_DIR = f"{ROOT_DIR}dataset/random/nlp/"

In [11]:
# Load the reviews table from DATASET_DIR; later statements read its `score`
# and `content` columns.
df = pd.read_csv(DATASET_DIR + "reviews.csv")
def to_sentiment(rating):
  """Map a review score to a sentiment class index.

  The score is first truncated with ``int``. Scores of 2 or lower give 0,
  a score of exactly 3 gives 1, and anything higher gives 2.
  """
  score = int(rating)
  if score > 3:
    return 2
  return 1 if score == 3 else 0

# Derive the integer class label for every review from its score.
df['sentiment'] = df.score.apply(to_sentiment)
# Display names, indexed by the label values 0, 1 and 2.
class_names = ['negative', 'neutral', 'positive']
# Shuffle all rows (no seed is passed, so the order differs between runs) and
# renumber the index from 0, discarding the old one.
df = df.sample(frac=1).reset_index(drop=True)

In [14]:
# Checkpoint name used for every from_pretrained() call in this notebook.
PRE_TRAINED_MODEL_NAME = 'bert-base-cased'
# Tokenizer matching the checkpoint; handed to the dataset below.
tokenizer = BertTokenizer.from_pretrained(PRE_TRAINED_MODEL_NAME)
# Stand-alone encoder instance loaded from the same checkpoint.
bert_model = BertModel.from_pretrained(PRE_TRAINED_MODEL_NAME)

In [121]:
class TextDataset(Dataset):
  """Dataset of (text, label) pairs that are tokenized when indexed.

  Args:
    texts: indexable collection of raw texts (each one is passed through ``str``).
    targets: indexable collection of integer class labels, aligned with ``texts``.
    tokenizer: object exposing ``encode_plus``.
    max_len: value forwarded to the tokenizer as ``max_length``.
  """

  def __init__(self, texts, targets, tokenizer, max_len):
    self.texts, self.targets = texts, targets
    self.tokenizer, self.max_len = tokenizer, max_len

  def __len__(self):
    # One sample per text.
    return len(self.texts)

  def __getitem__(self, item):
    """Return a dict with the raw text, flattened token tensors and the label."""
    text = str(self.texts[item])
    label = self.targets[item]

    # Options for the tokenizer. `padding="max_length"` is used here in place
    # of the `pad_to_max_length=True` flag the notebook used previously.
    # NOTE(review): the recorded run shows 75 tokens and no [PAD] although
    # max_len is 160, so the padding option may not take effect under the
    # pinned transformers version -- confirm.
    encode_options = dict(
      add_special_tokens=True,
      max_length=self.max_len,
      return_token_type_ids=False,
      padding="max_length",
      return_attention_mask=True,
      return_tensors='pt',
    )
    encoding = self.tokenizer.encode_plus(text, **encode_options)

    sample = {'text': text}
    # The tokenizer output is flattened to 1-D before being returned.
    sample['input_ids'] = encoding['input_ids'].flatten()
    sample['attention_mask'] = encoding['attention_mask'].flatten()
    sample['targets'] = torch.tensor(label, dtype=torch.long)
    return sample

In [122]:
# Small demo dataset: only the first 10 (already shuffled) reviews are used.
dataset = TextDataset(
    texts=df.content.to_numpy()[:10],
    targets=df.sentiment.to_numpy()[:10],
    tokenizer=tokenizer,
    max_len=160
  )

# One sample per batch, drawn in random order.
dataloader = DataLoader(dataset=dataset, shuffle=True, batch_size=1)

In [158]:
class SentimentClassifier(nn.Module):
  """BERT encoder with a dropout + linear head, instrumented for saliency maps.

  Besides producing class logits, the module keeps what is needed to build a
  per-token heatmap: the input ids of the last forward pass and, in
  "gradcam" mode, the gradient that reaches the encoder's last hidden state
  during ``backward()``.

  Args:
    n_classes: number of output classes of the linear head.
    visualization: "gradcam" (hook the last hidden state) or "raw"
      (only remember the input ids). Any other value disables both.
  """

  def __init__(self, n_classes, visualization="gradcam"):
    super().__init__()
    self.bert = BertModel.from_pretrained(PRE_TRAINED_MODEL_NAME)
    self.drop = nn.Dropout(p=0.3)
    self.out = nn.Linear(self.bert.config.hidden_size, n_classes)

    # raw or gradcam
    self.visualization = visualization

    # Placeholder for the gradient captured by activations_hook.
    self.gradients = None
    # Placeholder for the input ids of the most recent forward pass, so the
    # getters return None instead of raising before forward() has run.
    self.raw_input = None

  def activations_hook(self, grad):
    """Tensor hook: store the gradient flowing into the hooked activations."""
    self.gradients = grad

  def forward(self, input_ids, attention_mask):
    """Return the class logits for a batch.

    Side effects: stores a NumPy copy of ``input_ids`` and, in "gradcam"
    mode, registers ``activations_hook`` on the encoder's last hidden state.
    """
    self.raw_input = input_ids.clone().detach().cpu().numpy()
    encoded = self.bert(
      input_ids=input_ids,
      attention_mask=attention_mask
    )
    # Index instead of tuple-unpacking so that both a plain tuple and a
    # dict-like model output object are handled.
    last_hidden_state, pooled_output = encoded[0], encoded[1]

    # NOTE: classifying from pooled_output was observed to produce many zero
    # gradients on the hidden states; classifying from
    # torch.mean(last_hidden_state, dim=1) instead might avoid that, but the
    # head would need retraining.
    # A hook can only be registered on a tensor that requires grad; skip it
    # under torch.no_grad() or with a frozen encoder instead of raising.
    if self.visualization == "gradcam" and last_hidden_state.requires_grad:
      last_hidden_state.register_hook(self.activations_hook)
    output = self.drop(pooled_output)
    return self.out(output)

  def get_activations_gradient(self):
    """Return the gradient captured by the hook, or None if none was captured."""
    return self.gradients

  def get_activations(self, input_ids, attention_mask):
    """Return the activations matching the configured visualization mode.

    "gradcam": runs the encoder on the given inputs and returns its last
    hidden state as a NumPy array. "raw": returns the input ids stored by
    the last forward pass (None before any forward pass). Otherwise None.
    """
    if self.visualization == "gradcam":
      # The result is detached anyway, so skip building an autograd graph.
      with torch.no_grad():
        encoded = self.bert(
          input_ids=input_ids,
          attention_mask=attention_mask
        )
      return encoded[0].detach().cpu().numpy()
    if self.visualization == "raw":
      return self.raw_input
    return None

  def get_raw_input(self):
    """Return the input ids of the last forward pass as a NumPy array (or None)."""
    return self.raw_input

In [159]:
# !gdown --id 1V8itWtowCYnb2Bc9KlK9SxGff9WwmogA
# Use the first GPU when available, otherwise run on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = SentimentClassifier(len(class_names))
# map_location remaps the checkpoint's tensors onto `device`, so weights
# saved from a GPU session also load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source.
model.load_state_dict(
  torch.load(DATASET_DIR + 'sentiment.bin', map_location=device)
)
model = model.to(device)

In [160]:
# Put the module in evaluation mode. Autograd is left enabled on purpose:
# a later cell calls backward() on these logits.
model.eval()
# Fetch one batch; the dataloader was built with batch_size=1.
data = next(iter(dataloader))

text = data["text"]
input_ids = data["input_ids"].to(device)
attention_mask = data["attention_mask"].to(device)
targets = data["targets"].to(device)

# Forward pass; `outputs` holds the class logits.
outputs = model(
  input_ids=input_ids,
  attention_mask=attention_mask
)

# Class probabilities (not used by the prints in this cell).
probs = F.softmax(outputs, dim=1)
# Index of the highest logit per sample.
_, preds = torch.max(outputs, dim=1)

# Predicted class index of the first (and only) sample in the batch.
pred_index = preds.detach().cpu().numpy()[0]

print("Input text is", text)
print("Length of input_ids is", len(input_ids[0]))
print("Input tokens are", tokenizer.convert_ids_to_tokens(input_ids[0]))
print("Ground truth label is", class_names[targets.detach().cpu().numpy()[0]])
print("Prediction is", class_names[pred_index])

Input text is ["I have been using forest since 2017 and had collected over 3000 coins. Unfortunately my phone got stolen and I had to buy a new one. Now when I reinstalled the app it's saying that I need to start collecting again and to log in I need a premium account which I have never had. How can I get my points back? Thank you."]
Length of input_ids is 75
Input tokens are ['[CLS]', 'I', 'have', 'been', 'using', 'forest', 'since', '2017', 'and', 'had', 'collected', 'over', '3000', 'coins', '.', 'Unfortunately', 'my', 'phone', 'got', 'stolen', 'and', 'I', 'had', 'to', 'buy', 'a', 'new', 'one', '.', 'Now', 'when', 'I', 'reins', '##tal', '##led', 'the', 'app', 'it', "'", 's', 'saying', 'that', 'I', 'need', 'to', 'start', 'collecting', 'again', 'and', 'to', 'log', 'in', 'I', 'need', 'a', 'premium', 'account', 'which', 'I', 'have', 'never', 'had', '.', 'How', 'can', 'I', 'get', 'my', 'points', 'back', '?', 'Thank', 'you', '.', '[SEP]']
Ground truth label is neutral
Prediction is positive

In [161]:
# Run the encoder on the same inputs and show its last hidden state as a NumPy
# array (the model was created with the default visualization="gradcam").
model.get_activations(input_ids, attention_mask)

array([[[ 2.0243027 , -0.77808666, -0.17660998, ...,  0.01387538,
          0.15429512,  0.37072766],
        [ 2.2487502 , -0.73420584,  0.06617729, ..., -0.03584908,
          0.21698795,  0.26387444],
        [ 1.9699428 , -0.8088339 , -0.5741832 , ...,  0.02192109,
         -0.13866429,  0.37770966],
        ...,
        [ 2.5377512 , -0.4671873 ,  0.15997146, ..., -0.13483839,
          0.3124979 ,  1.174147  ],
        [ 2.0100698 , -0.48299772, -0.336403  , ..., -0.14898776,
          0.04484099,  0.9281454 ],
        [ 2.3049784 , -0.6505696 ,  0.08683315, ..., -0.28160125,
          0.55691224,  0.22923389]]], dtype=float32)

In [162]:
# Back-propagate from the logit of the predicted class; this triggers the hook
# registered in forward(), which stores the gradient on the model.
outputs[:, pred_index].backward()

In [163]:
# Gradient captured by the model's activations_hook during backward().
gradients = model.get_activations_gradient()

In [164]:
# Show the captured gradient tensor.
gradients

tensor([[[ 4.4222e-04,  4.4790e-04,  8.6981e-05,  ...,  1.5007e-04,
           2.0501e-04, -4.5824e-04],
         [ 4.4222e-04,  4.4790e-04,  8.6981e-05,  ...,  1.5007e-04,
           2.0501e-04, -4.5824e-04],
         [ 4.4222e-04,  4.4790e-04,  8.6981e-05,  ...,  1.5007e-04,
           2.0501e-04, -4.5824e-04],
         ...,
         [ 4.4222e-04,  4.4790e-04,  8.6981e-05,  ...,  1.5007e-04,
           2.0501e-04, -4.5824e-04],
         [ 4.4222e-04,  4.4790e-04,  8.6981e-05,  ...,  1.5007e-04,
           2.0501e-04, -4.5824e-04],
         [ 4.4222e-04,  4.4790e-04,  8.6981e-05,  ...,  1.5007e-04,
           2.0501e-04, -4.5824e-04]]], device='cuda:0')

In [165]:
# Average the gradient over its last axis (dim=2) to get one value per token.
# NOTE(review): in the recorded run every token position shows the same
# gradient row, so this average is identical for all tokens and would give a
# flat heatmap -- confirm whether that is expected.
pooled_gradients = torch.mean(gradients, dim=2)

In [166]:
# Show the per-token averaged gradients.
pooled_gradients

tensor([[-7.3654e-06, -7.3654e-06, -7.3654e-06, -7.3654e-06, -7.3654e-06,
         -7.3654e-06, -7.3654e-06, -7.3654e-06, -7.3654e-06, -7.3654e-06,
         -7.3654e-06, -7.3654e-06, -7.3654e-06, -7.3654e-06, -7.3654e-06,
         -7.3654e-06, -7.3654e-06, -7.3654e-06, -7.3654e-06, -7.3654e-06,
         -7.3654e-06, -7.3654e-06, -7.3654e-06, -7.3654e-06, -7.3654e-06,
         -7.3654e-06, -7.3654e-06, -7.3654e-06, -7.3654e-06, -7.3654e-06,
         -7.3654e-06, -7.3654e-06, -7.3654e-06, -7.3654e-06, -7.3654e-06,
         -7.3654e-06, -7.3654e-06, -7.3654e-06, -7.3654e-06, -7.3654e-06,
         -7.3654e-06, -7.3654e-06, -7.3654e-06, -7.3654e-06, -7.3654e-06,
         -7.3654e-06, -7.3654e-06, -7.3654e-06, -7.3654e-06, -7.3654e-06,
         -7.3654e-06, -7.3654e-06, -7.3654e-06, -7.3654e-06, -7.3654e-06,
         -7.3654e-06, -7.3654e-06, -7.3654e-06, -7.3654e-06, -7.3654e-06,
         -7.3654e-06, -7.3654e-06, -7.3654e-06, -7.3654e-06, -7.3654e-06,
         -7.3654e-06, -7.3654e-06, -7.