In [None]:
!pip install transformers
!pip install captum

Collecting captum
  Downloading captum-0.7.0-py3-none-any.whl.metadata (26 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.6->captum)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=1.6->captum)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=1.6->captum)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=1.6->captum)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=1.6->captum)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch>=1.6->captum)
  Downloading nvidia_cufft_cu1

In [None]:
from transformers import BertTokenizer
# Load the cased BERT vocabulary and encode a sample sentence.
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
text = 'He is best boy in the class'
text_ids = tokenizer.encode(text)

# Map the ids back to their token strings to inspect what encode() produced.
sample_tokens = tokenizer.convert_ids_to_tokens(text_ids)
print(sample_tokens)


['[CLS]', 'He', 'is', 'best', 'boy', 'in', 'the', 'class', '[SEP]']


In [None]:
# Same encoding, with the special-token flag spelled out explicitly.
text_ids = tokenizer.encode(text, add_special_tokens=True)
print(text_ids)

[101, 1124, 1110, 1436, 2298, 1107, 1103, 1705, 102]


In [None]:
from transformers import BertModel
import torch
# Run only the embedding layer of pretrained BERT on the sample token ids.
model = BertModel.from_pretrained('bert-base-cased')
ids_batch = torch.tensor([text_ids])  # the extra list level adds a leading batch axis
embeddings = model.embeddings(ids_batch)
print(embeddings)

model.safetensors:   0%|          | 0.00/436M [00:00<?, ?B/s]

tensor([[[ 0.4496,  0.0977, -0.2074,  ...,  0.0578,  0.0406, -0.0951],
         [-0.2538,  0.4405,  0.7437,  ...,  0.6523,  0.4797,  0.3767],
         [-1.1912,  0.2042,  0.6887,  ...,  0.5158,  1.0885,  0.7634],
         ...,
         [-0.8708,  0.5851, -0.8524,  ..., -0.2910,  0.9507,  0.3556],
         [-1.2246,  0.6147,  0.4660,  ...,  0.2802,  0.1035,  0.4250],
         [-0.3162,  0.1007,  0.1413,  ...,  0.5393, -0.4997,  0.3309]]],
       grad_fn=<NativeLayerNormBackward0>)


In [None]:
# Show the dimensions of the embedding tensor computed above.
print(embeddings.shape)

torch.Size([1, 9, 768])


In [None]:
from torch import nn
class BertClassifier(nn.Module):
  """Pretrained 'bert-base-cased' encoder with a small two-output head.

  The head is dropout -> Linear(768, 2) -> ReLU applied to BERT's pooled output.
  """

  def __init__(self, dropout=0.5):
    """Build the encoder and the head.

    Args:
      dropout: drop probability of the dropout layer placed before the linear layer.
    """
    super().__init__()

    self.bert = BertModel.from_pretrained('bert-base-cased')
    self.dropout = nn.Dropout(dropout)
    self.linear = nn.Linear(768, 2)
    self.relu = nn.ReLU()

  def forward(self, input_id, mask=None):
    """Return the head's output for a batch of token ids.

    Args:
      input_id: token ids, forwarded to BERT as `input_ids`.
      mask: optional attention mask, forwarded to BERT as `attention_mask`.

    Returns:
      The ReLU-activated output of the linear layer (negative scores become 0).
    """
    # With return_dict=False BERT returns a tuple; the pooled output is its second item.
    bert_outputs = self.bert(input_ids=input_id, attention_mask=mask, return_dict=False)
    pooled_output = bert_outputs[1]
    return self.relu(self.linear(self.dropout(pooled_output)))

In [None]:
import os

# The linear head is randomly initialised, so seed the RNG first: without this,
# every kernel restart produces different predictions and attributions.
torch.manual_seed(0)
model = BertClassifier()

# Load fine-tuned weights when a checkpoint is available; point this at a real file.
checkpoint_path = 'path/to/bert_model.pt'
if os.path.isfile(checkpoint_path):
  model.load_state_dict(torch.load(checkpoint_path, map_location=torch.device('cpu')))
else:
  # Make it obvious that everything below explains an untrained classification head.
  print(f'WARNING: checkpoint {checkpoint_path!r} not found; '
        'the classification head is randomly initialised.')

# Evaluation mode disables dropout so repeated forward passes are deterministic.
model.eval()

BertClassifier(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwis

In [None]:
def model_output(inputs):
  """Forward function handed to Captum: one scalar score per example.

  `model(inputs)` returns a (batch, 2) tensor. Indexing it with `[0]` would
  select the first *example* rather than reduce over classes, so when Captum
  expands the batch over the interpolation steps, gradients would come from
  the first step only. Taking the per-row maximum keeps the batch axis and
  yields a single score per example, which Captum accepts without `target`.

  Args:
    inputs: batch of token ids passed straight to `model`.

  Returns:
    1-D tensor with the highest class score of every example in the batch.
  """
  scores = model(inputs)
  return scores.max(dim=1).values

# The classifier's embedding layer is the layer attributions are computed against.
model_input = model.bert.embeddings
sample_ids = torch.tensor([text_ids])
print(model_input(sample_ids))

tensor([[[ 0.4496,  0.0977, -0.2074,  ...,  0.0578,  0.0406, -0.0951],
         [-0.2538,  0.4405,  0.7437,  ...,  0.6523,  0.4797,  0.3767],
         [-1.1912,  0.2042,  0.6887,  ...,  0.5158,  1.0885,  0.7634],
         ...,
         [-0.8708,  0.5851, -0.8524,  ..., -0.2910,  0.9507,  0.3556],
         [-1.2246,  0.6147,  0.4660,  ...,  0.2802,  0.1035,  0.4250],
         [-0.3162,  0.1007,  0.1413,  ...,  0.5393, -0.4997,  0.3309]]],
       grad_fn=<NativeLayerNormBackward0>)


In [None]:
from captum.attr import LayerIntegratedGradients
# Integrated gradients computed with respect to the embedding layer's output.
lig = LayerIntegratedGradients(forward_func=model_output, layer=model_input)

In [None]:
# Tokenize a short sample; return_tensors='pt' yields PyTorch tensors.
text = "hello world hate you"
inputs = tokenizer(text, return_tensors='pt')
print(inputs)

{'input_ids': tensor([[  101, 19082,  1362,  4819,  1128,   102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1]])}


In [None]:
# Attribute the sample without passing `baselines`, so Captum uses its default.
lig = LayerIntegratedGradients(forward_func=model_output, layer=model_input)
sample_input_ids = inputs['input_ids']
attributions, delta = lig.attribute(inputs=sample_input_ids, return_convergence_delta=True)
print("Attributions:", attributions)
print("Convergence Delta:", delta)

Attributions: tensor([[[-5.7459e-06, -3.3684e-05, -9.4564e-06,  ..., -8.1912e-06,
           6.3224e-08, -3.8260e-06],
         [ 2.6778e-07, -9.2725e-06, -1.5579e-07,  ..., -4.4119e-05,
          -4.2206e-06, -3.2510e-06],
         [ 7.8447e-07, -3.7329e-05, -7.7593e-07,  ..., -8.3299e-05,
          -7.3609e-06,  1.3442e-05],
         [-5.6477e-07, -8.4533e-05, -1.2546e-05,  ..., -4.5799e-05,
          -4.5825e-06, -3.6581e-06],
         [ 4.2052e-06, -2.1437e-05, -8.7877e-06,  ...,  3.0046e-05,
          -2.5995e-05,  1.2601e-07],
         [-2.4317e-06, -5.5181e-05, -2.6593e-06,  ...,  1.1308e-05,
           1.3481e-05,  1.1687e-07]]], dtype=torch.float64)
Convergence Delta: tensor([-0.1847,  0.5337], dtype=torch.float64)


In [None]:
def construct_input_and_baseline(text, max_length=510):
    """Build the token-id tensors for `text` and for its padding baseline.

    The baseline keeps [CLS] and [SEP] in place and replaces every text token
    with the tokenizer's pad token, so input and baseline have equal length.

    Args:
        text: raw string to encode with the module-level `tokenizer`.
        max_length: maximum number of text tokens kept (longer text is
            truncated). The default of 510 leaves room for [CLS] and [SEP]
            within BERT's 512 position embeddings.

    Returns:
        A tuple `(input_ids, baseline_input_ids, token_list)`: two CPU tensors
        with a leading batch axis of size 1, and the list of token strings
        corresponding to `input_ids`.
    """
    baseline_token_id = tokenizer.pad_token_id
    sep_token_id = tokenizer.sep_token_id
    cls_token_id = tokenizer.cls_token_id

    # Special tokens are added by hand below so the baseline can mirror them exactly.
    text_ids = tokenizer.encode(text, max_length=max_length, truncation=True, add_special_tokens=False)

    input_ids = [cls_token_id] + text_ids + [sep_token_id]
    token_list = tokenizer.convert_ids_to_tokens(input_ids)

    baseline_input_ids = [cls_token_id] + [baseline_token_id] * len(text_ids) + [sep_token_id]
    return torch.tensor([input_ids], device='cpu'), torch.tensor([baseline_input_ids], device='cpu'), token_list

# Build the input/baseline pair for a sample sentence and show both id tensors.
text = 'i am in fever'
input_ids, baseline_input_ids, all_tokens = construct_input_and_baseline(text)

for label, ids in (('original text', input_ids), ('baseline text', baseline_input_ids)):
    print(f'{label}: {ids}')


original text: tensor([[  101,   178,  1821,  1107, 10880,   102]])
baseline text: tensor([[101,   0,   0,   0,   0, 102]])


In [None]:
# Attribute the sample against the explicit padding baseline built above.
ig_kwargs = dict(
    inputs=input_ids,
    baselines=baseline_input_ids,
    return_convergence_delta=True,
)
attributions, delta = lig.attribute(**ig_kwargs)
print(attributions.shape)


torch.Size([1, 6, 768])


In [None]:
def summarize_attributions(attributions):
    """Collapse per-dimension attributions into one normalised score per token.

    Args:
        attributions: tensor whose last axis is summed away; a leading axis of
            size 1, if present, is then squeezed out.

    Returns:
        The summed scores divided by their norm. If the norm is exactly zero
        (every attribution is zero) the unscaled zeros are returned instead of
        NaNs from a division by zero.
    """
    token_scores = attributions.sum(dim=-1).squeeze(0)
    norm = torch.norm(token_scores)
    if norm == 0:
        # Nothing to scale; dividing would turn every score into NaN.
        return token_scores
    return token_scores / norm

# Reduce the attribution tensor to one score per token and show its shape.
attributions_sum = summarize_attributions(attributions)
print(attributions_sum.shape)

torch.Size([6])


In [None]:
from captum.attr import visualization as viz

# Run the classifier once, without autograd bookkeeping, and convert its scores
# into class probabilities so that `pred_prob` really is a probability.
with torch.no_grad():
    class_probs = torch.softmax(model(input_ids)[0], dim=-1)
predicted_class = int(torch.argmax(class_probs))

score_vis = viz.VisualizationDataRecord(
                        word_attributions = attributions_sum,
                        pred_prob = class_probs[predicted_class].item(),
                        pred_class = predicted_class,
                        true_class = 1,
                        attr_class = text,
                        attr_score = attributions_sum.sum(),
                        raw_input_ids = all_tokens,
                        convergence_score = delta)

# The trailing semicolon stops the notebook from echoing the returned HTML,
# which would otherwise render the same table a second time.
viz.visualize_text([score_vis]);

True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
1.0,1 (0.72),i am in fever,-1.41,[CLS] i am in fever [SEP]
,,,,


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
1.0,1 (0.72),i am in fever,-1.41,[CLS] i am in fever [SEP]
,,,,


In [None]:
def interpret_text(text, true_class):
  """Attribute the classifier's output for `text` to its tokens and render it.

  Uses the module-level `lig`, `model` and `viz` objects.

  Args:
    text: raw string to explain.
    true_class: ground-truth label shown next to the prediction.

  Returns:
    None; the attribution table is displayed via `viz.visualize_text`.
  """
  input_ids, baseline_input_ids, all_tokens = construct_input_and_baseline(text)

  attributions, delta = lig.attribute(inputs= input_ids,
                                    baselines= baseline_input_ids,
                                    return_convergence_delta=True
                                    )
  attributions_sum = summarize_attributions(attributions)

  # One forward pass, no autograd graph; softmax turns the raw scores into
  # probabilities so the value reported as `pred_prob` is an actual probability.
  with torch.no_grad():
    class_probs = torch.softmax(model(input_ids)[0], dim=-1)
  predicted_class = int(torch.argmax(class_probs))

  score_vis = viz.VisualizationDataRecord(
                        word_attributions = attributions_sum,
                        pred_prob = class_probs[predicted_class].item(),
                        pred_class = predicted_class,
                        true_class = true_class,
                        attr_class = text,
                        attr_score = attributions_sum.sum(),
                        raw_input_ids = all_tokens,
                        convergence_score= delta)

  viz.visualize_text([score_vis])


In [None]:
# Text interpretation: explain the prediction for one labelled example.

text = "Its a heartfelt flim about love, loss, and legacy"
true_class = 1
interpret_text(text=text, true_class=true_class)

True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
1.0,1 (0.53),"Its a heartfelt flim about love, loss, and legacy",2.44,"[CLS] Its a heart ##fe ##lt fl ##im about love , loss , and legacy [SEP]"
,,,,


In [None]:
# Second labelled example, this time with true class 0.
text = "fuck you bitch harder"
true_class = 0
interpret_text(text=text, true_class=true_class)

True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,1 (0.56),fuck you bitch harder,-1.76,[CLS] fuck you bitch harder [SEP]
,,,,
