In [4]:
import torch
import numpy as np
import matplotlib.pyplot as plt

from transformers import BertTokenizer, BertModel





In [5]:
model_name = "bert-base-uncased"

tokenizer = BertTokenizer.from_pretrained(model_name)

model = BertModel.from_pretrained(
    model_name,
    output_attentions=True
)

model.eval()


BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
  

In [None]:
sentence = "The mechanic inspected the engine because it was noisy."


In [None]:
inputs = tokenizer(sentence, return_tensors="pt")


In [None]:
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
tokens

In [None]:
with torch.no_grad():
    outputs = model(**inputs)

attentions = outputs.attentions


In [None]:
layer = 0
head = 0

attention_matrix = attentions[layer][0, head].cpu().numpy()

In [None]:
plt.figure(figsize=(10, 8))
plt.imshow(attention_matrix, cmap="viridis")

plt.xticks(range(len(tokens)), tokens, rotation=90)
plt.yticks(range(len(tokens)), tokens)

plt.colorbar()
plt.title("Self-Attention Heatmap (Layer 1, Head 0)")
plt.tight_layout()

plt.savefig("outputs/attention_heatmap.png")
plt.show()

A2 â€” Understand Positional Encoding

In [6]:
sentence_original = "The cat sat on the mat"
sentence_scrambled = "Mat the on sat cat the"


In [7]:
import torch.nn.functional as F

def get_sentence_embedding(sentence):
    inputs = tokenizer(sentence, return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs)

    # Last hidden state: (batch, seq_len, hidden_dim)
    token_embeddings = outputs.last_hidden_state

    # Mean pooling over tokens
    sentence_embedding = token_embeddings.mean(dim=1)

    return sentence_embedding


In [8]:
emb_original = get_sentence_embedding(sentence_original)
emb_scrambled = get_sentence_embedding(sentence_scrambled)

In [9]:
sentence_similarity = F.cosine_similarity(
    emb_original, emb_scrambled
)

sentence_similarity.item()

0.7193946838378906

In [10]:
def get_token_embeddings(sentence):
    inputs = tokenizer(sentence, return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs)

    return outputs.last_hidden_state.squeeze(0)


In [11]:
tokens_original = get_token_embeddings(sentence_original)
tokens_scrambled = get_token_embeddings(sentence_scrambled)

token_level_similarity = F.cosine_similarity(
    tokens_original.mean(dim=0),
    tokens_scrambled.mean(dim=0),
    dim=0
)

token_level_similarity.item()

0.7193946838378906