# Contextualized embedding with transformer models illustrated

In this notebook, we begin to peek under the hood of a BERT transformer model to understand how contextualized embeddings work.
We then also introduce a couple of potential use cases that leverage contextualized embeddings.

<br>
<a target="_blank" href="https://colab.research.google.com/github/haukelicht/advanced_text_analysis/blob/main/notebooks/embedding/contextualized_embedding_transformers_explained.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

### Setup

#### Colab

In [None]:
# check if on colab
COLAB = True
try:
    import google.colab
except ImportError:
    COLAB = False

if COLAB:
    # shallow clone of current state of main branch 
    !git clone --branch main --single-branch --depth 1 --filter=blob:none https://github.com/haukelicht/advanced_text_analysis.git

    !pip install -q transformers==4.44.1 matplotlib==3.9.2 umap-learn==0.5.6 bertviz==1.4.0


#### Required packages

In [None]:
from pathlib import Path
import numpy as np
import pandas as pd

import torch
from transformers import BertForMaskedLM, BertModel, BertTokenizer

from bertviz import head_view
from bertviz.transformers_neuron_view import BertModel as BertVizModel 
from bertviz.transformers_neuron_view import BertTokenizer as BertVizTokenizer
from bertviz.neuron_view import show

import umap
from sklearn.cluster import KMeans

from sklearn.metrics import accuracy_score
from sklearn.metrics.pairwise import cosine_similarity

import matplotlib.pyplot as plt

## Intro to the `transformers` library

In Python, the standard library for working with transformer models is `transformers`.
It provides access to pre-trained transformer models through its [model hub]().
The `transformers` library is developed and maintained by Hugging Face Inc.

### Pre-trained models and tokenizers

To use a pre-trained model for embedding texts, we need two things:

1. the model's tokenizer
1. and of course the model itself

We use the model to process a text through its **layers** to obtain the text's **embedding**.
But to be able to do this, we first need to **tokenize** the text to convert it into numbers, because deep neural networks can only process numbers, not raw text.
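
To make the idea of "converting text into numbers" concrete, here is a minimal sketch with a made-up vocabulary and a whitespace tokenizer. The names `vocab`, `toy_tokenize`, and `embedding_matrix` are hypothetical and only serve as illustration; BERT's actual tokenizer splits text into subword pieces and has a much larger vocabulary.

```python
import numpy as np

# Hypothetical toy vocabulary: maps each known token to an integer ID
vocab = {"[UNK]": 0, "the": 1, "river": 2, "bank": 3, "money": 4}

def toy_tokenize(text):
    """Lowercase, split on whitespace, and look up each token's ID."""
    tokens = text.lower().split()
    # unknown tokens fall back to the [UNK] ID
    return [vocab.get(tok, vocab["[UNK]"]) for tok in tokens]

# Toy input embedding matrix: one row per vocabulary entry, 4 dimensions each
rng = np.random.default_rng(42)
embedding_matrix = rng.normal(size=(len(vocab), 4))

token_ids = toy_tokenize("The river bank")
# Token IDs are row indices into the embedding matrix
token_embeddings = embedding_matrix[token_ids]

print(token_ids)               # [1, 2, 3]
print(token_embeddings.shape)  # (3, 4)
```

The token IDs are nothing more than row indices: looking up a token's input embedding amounts to selecting a row of the embedding matrix.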

Below we load a pre-trained BERT model, specifically "bert-base-uncased", which is a smallish version of BERT (hence 'base' instead of 'large') that does not distinguish between upper- and lowercase letters (hence 'uncased'). 

In [None]:
# define the name of the model we want to load
model_id = 'bert-base-uncased'

# load the pre-trained model and tokenizer 
model = BertModel.from_pretrained(model_id)
tokenizer = BertTokenizer.from_pretrained(model_id)
# NOTE: this will trigger downloading the model and tokenizer if you haven't done so before

Let's get some information about the model by looking at its configuration attribute (`config`):

In [None]:
# let's get some important information about the model
print('embedding dimensionality:', model.config.hidden_size)
print('number of layers:', model.config.num_hidden_layers)
print('vocabulary size:', model.config.vocab_size)

In [None]:
# let's have a look at the model architecture
print(model)

- the model's first component is a `BertEmbeddings` module that contains, among other things,
    1. the initial word embedding layer
    2. the positional embedding
- after this we have the `BertEncoder` module that consists of 12 `BertLayer`s

If we just want to get the initial word embeddings, we can access them like this.

In [None]:
model.embeddings.word_embeddings.weight.shape

In [None]:
print(model.embeddings.word_embeddings)

# let's get the first five values of the first embedding
model.embeddings.word_embeddings.weight[0][:5].detach().numpy()

Notes: 

- the layers are attributes of the `model`, and they are organized and nested as shown when calling `print(model)` 
- we get the actual parameters of the model from a layer's "weights" (weights is just the machine learning term for parameters)
- weights are $n$-dimensional arrays (called "tensors" in `pytorch` etc.) and we can index them just like numpy arrays
- we use `detach()` because the model's weights (parameters) are tracked by PyTorch's automatic differentiation engine for gradient computation, which we don't need when we only want to see the weight values

But the main reason we use BERT & Co. is to obtain contextualized embeddings.

## Contextualized embedding

To illustrate how contextualized embedding works in transformers, we will first look at how embeddings of the same word differ if their context differs.

Let's take two sentences that contain the word "bank" but use it with different meanings:

In [None]:
sentences = [
    "Today, I will hike along the bank of a river.",
    "Today, I will open a new account at my bank and deposit some money.",
] 

To get the transformer embedding of the word "bank" in these two sentences, we need to follow three steps:

1. tokenize the texts and convert the tokens into token IDs (to look up their input embeddings)
2. process these inputs through the model
3. locate the embedding of the focal word in the two sentences.

#### 1) tokenize

The tokenizer converts the text into tokens and maps the tokens to token IDs.

Token IDs indicate tokens' locations in the tokenizer's vocabulary and hence in the model's input embedding matrix. 

In [None]:
inputs = tokenizer(sentences, return_tensors="pt", padding=True, truncation=True)

In [None]:
inputs

In [None]:
inputs['input_ids']

We can "decode" these token IDs into their tokens:

In [None]:
tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])

Notes: 

- the `[CLS]` token is a special token used to summarize the information in a sequence (e.g., for classification tasks)
- the `[SEP]` token is the special "separator" token that indicates sequence boundaries
- the `[PAD]` token is the special "padding" token that is appended to sequences that are shorter than the longest sequence in a batch to make the input rectangular (i.e., all rows have an equal number of columns)
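
As a rough sketch of what padding does, the hypothetical helper `pad_batch` below pads lists of token IDs to a common length and builds the matching attention mask (1 for real tokens, 0 for padding). The token IDs and the padding ID of 0 are assumptions made for illustration; in the notebook, the tokenizer takes care of this when called with `padding=True`.

```python
def pad_batch(sequences, pad_id=0):
    """Pad lists of token IDs to equal length and build attention masks."""
    max_len = max(len(seq) for seq in sequences)
    input_ids, attention_mask = [], []
    for seq in sequences:
        n_pad = max_len - len(seq)
        # append padding IDs to the right of shorter sequences
        input_ids.append(seq + [pad_id] * n_pad)
        # 1 marks a real token, 0 marks padding
        attention_mask.append([1] * len(seq) + [0] * n_pad)
    return input_ids, attention_mask

# two toy sequences of different lengths (IDs are made up)
batch = [[101, 7, 8, 9, 102], [101, 7, 102]]
input_ids, attention_mask = pad_batch(batch)

print(input_ids)       # [[101, 7, 8, 9, 102], [101, 7, 102, 0, 0]]
print(attention_mask)  # [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0]]
```

The attention mask is what lets the model ignore the padded positions when it processes the batch.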

In [None]:
# let's use the tokenizer to get the token ID of the focal word 
focal_word_id = tokenizer.convert_tokens_to_ids('bank')
focal_word_id

In [None]:
# create mask that is true where input ID == focal word ID
mask = inputs['input_ids'] == focal_word_id
mask

#### 2) embed (process through model)

In [None]:
# get the initial embedding of the focal word ("bank")
model.embeddings.word_embeddings.weight[focal_word_id].shape

In [None]:
with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)

**_Note:_** We use `torch.no_grad()` to disable gradient tracking. Gradients are only needed for "backpropagation", the method used to optimize deep neural networks' parameters, so we can skip them when we just want to compute embeddings.

In [None]:
print(type(outputs))
# list the object's attributes
list(dict(outputs).keys())

In [None]:
outputs.hidden_states[3].shape

In [None]:
# hidden states are the embeddings after each layer (plus the input embeddings at index 0)
len(outputs.hidden_states)

In [None]:
# the final-layer embedding of the first token in the first sentence can be accessed like this: 
outputs.hidden_states[-1][0][0].shape

In [None]:
# let's look at the shape of the final layer's output (sentences x tokens x dimensions):
outputs.hidden_states[-1].shape

#### 3) get the words' contextualized embeddings 

In [None]:
# final transformer embeddings of bank in different contexts
embeddings = outputs.last_hidden_state[mask]

In [None]:
embeddings.shape
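
The boolean-mask indexing used above can be illustrated with a small NumPy sketch. The arrays below are toy stand-ins (assumed shapes and made-up IDs, with 7 playing the role of the focal word's ID), not actual model outputs.

```python
import numpy as np

# Toy "last hidden state": 2 sentences, 4 tokens each, 3 dimensions per token
hidden = np.arange(2 * 4 * 3).reshape(2, 4, 3)

# Toy token IDs; we pretend that ID 7 is the focal word
input_ids = np.array([[1, 7, 2, 0],
                      [1, 3, 4, 7]])

# Boolean mask of shape (2, 4): True where the focal word occurs
mask = input_ids == 7

# Indexing with the mask keeps one row per True entry
selected = hidden[mask]

print(selected.shape)  # (2, 3)
print(selected)
```

The result has one row per occurrence of the focal word, so a sentence that contained the word twice would contribute two rows.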

In [None]:
# compute cosine similarity between the two embeddings
cosine_similarity(embeddings[0].reshape(1, -1), embeddings[1].reshape(1, -1))
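
For reference, the cosine similarity of two vectors $u$ and $v$ is $\frac{u \cdot v}{\lVert u \rVert \, \lVert v \rVert}$. A minimal NumPy sketch of this formula, with the hypothetical helper `cosine_sim` and toy vectors, could look like this:

```python
import numpy as np

def cosine_sim(u, v):
    """Cosine of the angle between two vectors: dot product over product of norms."""
    return float(np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v)))

a = np.array([1.0, 0.0])
b = np.array([0.0, 2.0])   # orthogonal to a
c = np.array([3.0, 0.0])   # same direction as a, different length

print(cosine_sim(a, b))    # 0.0
print(cosine_sim(a, c))    # 1.0
print(cosine_sim(a, -c))   # -1.0
```

Because the measure only depends on the angle between the vectors, it ignores differences in vector length.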

Below you can see that the similarity of the two transformer embeddings of "bank" depends on the model layer we look at.

In [None]:
# iterate over all layers
for i, layer in enumerate(outputs.hidden_states):
    # skip input embeddings
    if i == 0:
        continue
    embeddings = layer[mask]
    similarity = cosine_similarity(embeddings[0].reshape(1, -1), embeddings[1].reshape(1, -1))
    print(f'layer {i}: {similarity}')

### 🔥 Competition 🔥

**Try it yourself!** 

- Define a pair of sentences that use a word with different meanings.
- Whoever gets the **lowest similarity score** (at the final layer of `bert-base-uncased`) for their example pair wins!

*Bonus:* Can you think of a word from your research domain or area of interest that has multiple meanings? If so, does BERT seem to distinguish these meanings?

## Attention &mdash; peeking under the hood

Let's use the amazing `bertviz` library to have a deeper look into the workings of transformers.
Let's define the sentence we want to inspect:

In [None]:
sentence = "Today, I will hike along the bank of a river."

### 🔥 Exercise 🔥

Use the interactive attention head and neuron views below to answer the following questions:

1. In which layers does BERT attend to "river", a context token of "bank"? And which head focuses the most on this context word?
1. What other context tokens of "bank" does BERT attend to across layers?
1. How does that change across layers?

### Head View

<b>The head view visualizes attention in one or more heads from a single Transformer layer.</b> Each line shows the attention from one token (left) to another (right). Line weight reflects the attention value (ranges from 0 to 1), while line color identifies the attention head. When multiple heads are selected (indicated by the colored tiles at the top), the corresponding  visualizations are overlaid onto one another.  For a more detailed explanation of attention in Transformer models, please refer to the [blog](https://towardsdatascience.com/deconstructing-bert-part-2-visualizing-the-inner-workings-of-attention-60a16d86b5c1).

In [None]:
# load the model
model = BertModel.from_pretrained(model_id, output_attentions=True)
tokenizer = BertTokenizer.from_pretrained(model_id)

In [None]:
# retrieve attention weights
inputs = tokenizer.encode_plus(sentence, return_tensors='pt')
input_ids = inputs['input_ids']
attention = model(input_ids)[-1]
input_id_list = input_ids[0].tolist()
tokens = tokenizer.convert_ids_to_tokens(input_id_list)

In [None]:
# visualize
head_view(attention, tokens)

#### *Usage*

- **Hover** over any **token** on the left/right side of the visualization to filter attention from/to that token. <br/>
- **Double-click** on any of the **colored tiles** at the top to filter to the corresponding attention head.<br/>
- **Single-click** on any of the **colored tiles** to toggle selection of the corresponding attention head. <br/>
- **Click** on the **Layer** drop-down to change the model layer (zero-indexed).

### Neuron View
<b>The neuron view visualizes the intermediate representations (e.g. query and key vectors) that are used to compute attention.</b> In the collapsed view (initial state), the lines show the attention from each token (left) to every other token (right). In the expanded view, the tool traces the chain of computations that produce these attention weights. For a detailed explanation of the attention mechanism, please refer to the [blog](https://towardsdatascience.com/deconstructing-bert-part-2-visualizing-the-inner-workings-of-attention-60a16d86b5c1).
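
As a rough sketch of the computation behind these visualizations, the attention weights of a single head are obtained by taking scaled dot products of query and key vectors and applying a softmax over the keys. The example below uses random toy vectors (not BERT's actual queries and keys), and `attention_weights` is a hypothetical helper name.

```python
import numpy as np

def softmax(x, axis=-1):
    """Numerically stable softmax along the given axis."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention_weights(Q, K):
    """Scaled dot-product attention weights for a single head."""
    d_k = Q.shape[-1]
    # one score per (query token, key token) pair
    scores = Q @ K.T / np.sqrt(d_k)
    # normalize over the key tokens so each row sums to 1
    return softmax(scores, axis=-1)

rng = np.random.default_rng(0)
n_tokens, d_k = 5, 8
Q = rng.normal(size=(n_tokens, d_k))  # toy query vectors, one per token
K = rng.normal(size=(n_tokens, d_k))  # toy key vectors, one per token

W = attention_weights(Q, K)
print(W.shape)        # (5, 5)
print(W.sum(axis=1))  # each row sums to 1
```

Row `i` of `W` holds the attention from token `i` to every token in the sequence, which is why the attention values shown in the views range from 0 to 1.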

In [None]:
bertviz_model = BertVizModel.from_pretrained(model_id, output_attentions=True)
bertviz_tokenizer = BertVizTokenizer.from_pretrained(model_id)
show(bertviz_model, 'bert', bertviz_tokenizer, sentence, layer=4, head=3)

#### *Usage*

- **Hover** over any of the tokens on the left side of the visualization to filter attention from that token.<br/>
- Then **click** on the **plus** icon that is revealed when hovering. This exposes the query vectors, key vectors, and other intermediate representations used to compute the attention weights. Each color band represents a single neuron value, where color intensity indicates the magnitude and hue the sign (blue=positive, orange=negative).<br/>
- Once in the expanded view, **hover** over any other **token** on the left to see the associated attention computations.<br/>
- **Click** on the **Layer** or **Head** drop-downs to change the model layer or head (zero-indexed).

### 🔥 Brainstorming session 🔥

Assume you have a BERT model that has only been trained on a corpus of texts specific to your research domain.

**Questions:** 

- Can you think of any uses of the attention-level information we inspected above to understand language use and discourse in this text corpus?
- Do you think you could use the pretrained `bert-base-uncased` model loaded above just as well? Why or why not?


## Example application: Using BERT for word sense disambiguation

**Question:** How can we use transformers to categorize in what sense a word is used in its context?

**_Idea:_**

1. a word's context clarifies its meaning
2. contextualized embeddings capture this by shifting embeddings towards their context
3. this means that the contextualized embeddings of a word's different senses occupy different "locations" in the embedding space
4. given that the embeddings are high-dimensional numeric vectors, we can cluster them to disambiguate senses.


#### Implementation

I have asked OpenAI's GPT-4o to generate a list of sentences that use the word "bank" in different senses.
Below, we'll use this data to see how well BERT's ability to generate contextualized embeddings allows us to disambiguate between this word's contextual meanings.

In [None]:

# Load the pre-trained BERT model and its tokenizer
model = BertModel.from_pretrained(model_id)
tokenizer = BertTokenizer.from_pretrained(model_id)

In [None]:

base_path = '/content/advanced_text_analysis/' if COLAB else '../../'
base_path = Path(base_path)
data_path = base_path / 'data' / 'misc'

# load the file
fp = data_path / 'bank_sentences_with_senses.csv'
df = pd.read_csv(fp)

In [None]:
# check that all sentences contain the word "bank"
df.text.str.contains('bank').value_counts()
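
Note that a substring check like the one above also matches strings such as "banks" or "embankment", whereas the token-ID mask used further below only picks up tokens that are exactly "bank". The following sketch, with made-up example sentences (not taken from the data file), shows one way to count whole-word occurrences with a regular expression.

```python
import re

# \b marks a word boundary, so only the standalone word "bank" matches
pattern = re.compile(r"\bbank\b", flags=re.IGNORECASE)

# hypothetical example sentences for illustration
examples = [
    "She sat on the river bank.",
    "The banks raised their fees.",
    "Bank to bank transfers are free.",
    "They walked along the embankment.",
]

# substring check: True for all four sentences
substring_hits = ["bank" in s.lower() for s in examples]
# whole-word count: number of standalone occurrences per sentence
counts = [len(pattern.findall(s)) for s in examples]

print(substring_hits)  # [True, True, True, True]
print(counts)          # [1, 0, 2, 0]
```

Sentences with a count other than 1 deserve a second look, since they would yield either no embedding or more than one embedding for the focal word.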

In [None]:
df.groupby('sense').sample(1, random_state=42)

In [None]:
df.value_counts('sense')

In [None]:
# tokenize the sentences
inputs = tokenizer(df.text.to_list(), return_tensors="pt", padding=True, truncation=True)

In [None]:
# create mask that is true where input ID == focal word ID
focal_word_id = tokenizer.convert_tokens_to_ids('bank')
mask = inputs['input_ids'] == focal_word_id

In [None]:
# process the inputs through the model
with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)

In [None]:
# Apply the mask to the last hidden layer output to get the focal word's embeddings
focal_word_embeddings = outputs.last_hidden_state[mask]

In [None]:
# now we have the 768-dimensional embeddings of the focal word "bank" in each sentence
focal_word_embeddings.shape

**Question:** How can we *see* whether, and if so how, the embeddings of words used in similar senses occupy similar locations in the embedding space?

**Answer:** dimensionality reduction

In [None]:
# Reduce the embeddings to 2D
reducer = umap.UMAP(n_components=2, random_state=42, n_jobs=1)
embeddings_2d = reducer.fit_transform(focal_word_embeddings)

In [None]:
# plot the 2D embeddings by sense, using different colors and a legend indicating the sense
for sense in df.sense.unique():
    idxs = df.sense == sense
    plt.scatter(embeddings_2d[idxs, 0], embeddings_2d[idxs, 1], label=sense, s=10)
plt.legend()
plt.show()

So far, we have only eyeballed the data to find clusters.

**Question:** (How reliably) Can we automate this disambiguation approach?

In [None]:
# cluster the full (768-dimensional) embeddings using k-means with k=3
kmeans = KMeans(n_clusters=3, random_state=42)
df['cluster'] = kmeans.fit_predict(focal_word_embeddings)

In [None]:
# get category indicator of the sense
# cross tabulate the cluster labels with the sense labels
pd.crosstab(df.cluster, df.sense)

In [None]:
# use majority to label induced clusters
cluster_to_sense = {0: 'geographical', 1: 'motion', 2: 'financial'}
df['cluster_label'] = df.cluster.map(cluster_to_sense)
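
The cluster IDs that k-means assigns are arbitrary, so a hard-coded mapping like the one above may need to be adjusted if the clustering changes. As a sketch of an alternative, the hypothetical helper `majority_label_map` below derives the mapping from the majority label in each cluster. The toy cluster assignments and labels are made up for illustration.

```python
from collections import Counter, defaultdict

def majority_label_map(clusters, labels):
    """Map each cluster ID to the most frequent label among its members."""
    counts = defaultdict(Counter)
    for cluster, label in zip(clusters, labels):
        counts[cluster][label] += 1
    return {c: counter.most_common(1)[0][0] for c, counter in counts.items()}

# toy cluster assignments and "true" senses (made up)
clusters = [0, 0, 0, 1, 1, 2, 2, 2]
senses = ["geographical", "geographical", "financial",
          "motion", "motion",
          "financial", "financial", "geographical"]

mapping = majority_label_map(clusters, senses)
predicted = [mapping[c] for c in clusters]
accuracy = sum(p == s for p, s in zip(predicted, senses)) / len(senses)

print(mapping)   # {0: 'geographical', 1: 'motion', 2: 'financial'}
print(accuracy)  # 0.75
```

With the notebook's data, the same helper could be called as `majority_label_map(df.cluster, df.sense)`. Keep in mind that naming clusters with the help of the true labels makes the resulting accuracy an optimistic estimate.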

In [None]:
# compute the accuracy of the clustering
accuracy_score(df.sense, df.cluster_label)

In [None]:
# get examples where cluster label disagrees with label
for row in df[df.sense != df.cluster_label].itertuples():
    print(f'in cluster \'{row.cluster_label}\' instead of \'{row.sense}\': "{row.text}"')

### 🔥 Brainstorming session 🔥

- Can you think of any potential uses of BERT's ability to contextualize words' embeddings in your research, for example to study differences in word use across actors or domains?
- Do you think you could use the pretrained `bert-base-uncased` model loaded above just as well? Why or why not?


## Predicting masked-out words

In [None]:
# Load the pre-trained model for masked language modeling
model = BertForMaskedLM.from_pretrained(model_id)

# Define the text with a masked token
text = "He was walking along the [MASK] of the river."

# Tokenize the input text
inputs = tokenizer(text, return_tensors="pt")

# Get the index of the masked token
masked_index = (inputs['input_ids'] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1].item()

# Predict the masked token
with torch.no_grad():
    outputs = model(**inputs)
    predictions = outputs.logits

In [None]:
predictions[0, masked_index].shape

In [None]:
# Get the log probabilities of the 10 best fitting words
log_probs = torch.log_softmax(predictions[0, masked_index], dim=-1)
top_10_log_probs, top_10_indices = torch.topk(log_probs, 10)

# Convert indices to tokens
top_10_tokens = tokenizer.convert_ids_to_tokens(top_10_indices.tolist())

# Print the results
for token, log_prob in zip(top_10_tokens, top_10_log_probs):
    print(f"{token}: {log_prob.item()}")
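
To unpack what the log-softmax and top-k steps do, here is a minimal NumPy sketch on a made-up four-word vocabulary with made-up logits (the names `vocab`, `logits`, and `log_softmax` are hypothetical stand-ins, not the model's actual outputs).

```python
import numpy as np

def log_softmax(logits):
    """Convert raw scores (logits) into log probabilities."""
    shifted = logits - logits.max()  # subtract the max for numerical stability
    return shifted - np.log(np.exp(shifted).sum())

# toy vocabulary and toy logits for the masked position
vocab = ["bank", "edge", "side", "car"]
logits = np.array([2.0, 1.0, 0.5, -1.0])

log_probs = log_softmax(logits)

# indices of the k highest log probabilities, best first
top_k = 2
top_idx = np.argsort(log_probs)[::-1][:top_k]

for i in top_idx:
    # exponentiating a log probability gives the probability
    print(f"{vocab[i]}: log_prob={log_probs[i]:.3f}, prob={np.exp(log_probs[i]):.3f}")
```

Exponentiating all log probabilities gives values that sum to 1 across the vocabulary, so they can be read as the model's predicted distribution over candidate words for the masked position.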

## Example application: Using masked token prediction to study gender bias

In [None]:
# Load the pre-trained model for masked language modeling
model = BertForMaskedLM.from_pretrained(model_id)
tokenizer = BertTokenizer.from_pretrained(model_id)

def get_topk_words(text):
    
    # Tokenize the input text
    inputs = tokenizer(text, return_tensors="pt")

    # Get the index of the masked token
    masked_index = (inputs['input_ids'] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1].item()

    # Predict the masked token
    with torch.no_grad():
        outputs = model(**inputs)
        predictions = outputs.logits

    # Get the log probabilities of the 10 best fitting words
    log_probs = torch.log_softmax(predictions[0, masked_index], dim=-1)
    top_10_log_probs, top_10_indices = torch.topk(log_probs, 10)

    # Convert indices to tokens
    top_10_tokens = tokenizer.convert_ids_to_tokens(top_10_indices.tolist())

    # Collect the results in a data frame (converting the tensor to a numpy array)
    out = pd.DataFrame({'token': top_10_tokens, 'log_prob': top_10_log_probs.numpy()})
    out['prob'] = np.exp(out.log_prob.to_numpy())
    return out

In [None]:
get_topk_words('He was very [MASK].')

In [None]:
get_topk_words('She was very [MASK].')

In [None]:
print("'Homosexuals'\n", get_topk_words('Homosexuals are making our country [MASK].'))
print("'Straights'\n", get_topk_words('Straights are making our country [MASK].'))

In [None]:
print("'Muslims'\n", get_topk_words('Muslims are making our country [MASK].'))
print("'Christians'\n", get_topk_words('Christians are making our country [MASK].'))

In [None]:
print("'Kids'\n", get_topk_words('Kids are very [MASK].'))
print("'Teens'\n", get_topk_words('Teens are very [MASK].'))
print("'Adults'\n", get_topk_words('Adults are very [MASK].'))