### After running the tokens through the model and reading their embeddings from the last hidden state, it is crucial to extract the essential information from the resulting output. This notebook walks through several common embedding pooling strategies and compares them.

In [58]:
# for auto reload when changes are made in the package
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [59]:
from transformers import AutoModel, AutoTokenizer
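# embedding_pooling is used further down; it is assumed to live in the same helpers module
from helpers import show_tokenization, cosine_similarity, euclidean_distance, embedding_pooling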

In [60]:
model_card = 'emilyalsentzer/Bio_ClinicalBERT'

In [61]:
tokenizer = AutoTokenizer.from_pretrained(model_card)
model = AutoModel.from_pretrained(model_card)

for param in model.parameters():
    param.requires_grad = False

The full pipeline is: input > tokenization > model > output from `last_hidden_state`. To experiment further, the tokenizer can also be given `padding` and `max_length`. In that case the number of tokens equals the defined max length, and each padded position receives token id 0 and an attention mask value of 0.
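The effect of padding on the token ids and the attention mask can be illustrated with a minimal sketch. This is not the tokenizer's actual implementation: `pad_to_max_length` is a hypothetical helper written only for illustration, and it assumes right padding with a pad id of 0.

```python
def pad_to_max_length(input_ids, max_length, pad_id=0):
    """Right-pad a list of token ids and build the matching attention mask."""
    if len(input_ids) > max_length:
        raise ValueError("sequence is longer than max_length")
    n_pad = max_length - len(input_ids)
    padded_ids = input_ids + [pad_id] * n_pad
    # 1 marks a real token, 0 marks a padded position
    attention_mask = [1] * len(input_ids) + [0] * n_pad
    return padded_ids, attention_mask


# The 19 token ids produced for the example sentence in this notebook
ids = [101, 5351, 1144, 1185, 1607, 1104, 18418, 117, 2841, 1849,
       117, 2445, 1104, 21518, 117, 1137, 11477, 119, 102]

padded_ids, attention_mask = pad_to_max_length(ids, max_length=24)
print(len(padded_ids))       # length now equals max_length
print(sum(attention_mask))   # number of real (non-padded) tokens
```

Summing the attention mask recovers the number of real tokens, which is the quantity that mask-aware pooling relies on.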

In [73]:
text = 'Patient has no history of fatigue, weight change, loss of appetite, or weakness.'
inputs = tokenizer(text, return_tensors="pt")  # optionally add: padding="max_length", max_length=512

print(inputs)
print(f"\ntotal number of tokens is {len(inputs['input_ids'][0])}")
# show_tokenization(inputs, tokenizer)

output = model(**inputs)['last_hidden_state'] # batch_size, sequence_length, hidden_size
print(f'last_hidden_state outputs shape: {output.shape}')

{'input_ids': tensor([[  101,  5351,  1144,  1185,  1607,  1104, 18418,   117,  2841,  1849,
           117,  2445,  1104, 21518,   117,  1137, 11477,   119,   102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}

total number of tokens is 19
last_hidden_state outputs shape: torch.Size([1, 19, 768])


To derive a single embedding from an LLM, you typically pool the hidden states using a strategy such as averaging the embeddings of all tokens, taking the [CLS] token's embedding, or max pooling. The right pooling approach depends on the task and the model design. Attention masks are used during pooling to exclude padding tokens from the result; they matter less for [CLS] pooling, which only reads the first position.

<p align="center">
  <img src="./public/embedding-summary.png" style="border:none"/>
</p>
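The pooling strategies described above can be sketched in NumPy as follows. This is an illustrative sketch, not the `embedding_pooling` helper used in this notebook: the name `pool_embeddings`, its signature, and the assumption of right padding are all hypothetical. Here "eos" is taken to mean the last non-padded token, which for BERT-style models is `[SEP]`.

```python
import numpy as np


def pool_embeddings(hidden, mask, strategy="mean"):
    """Pool (batch, seq_len, hidden_size) token embeddings into (batch, hidden_size).

    hidden: float array of token embeddings.
    mask:   0/1 array of shape (batch, seq_len); 0 marks padding.
    """
    mask = np.asarray(mask)
    if strategy == "cls":
        # The first position holds the [CLS] token in BERT-style models
        return hidden[:, 0, :]
    if strategy == "eos":
        # Index of the last non-padded token (assumes right padding)
        last = mask.sum(axis=1) - 1
        return hidden[np.arange(hidden.shape[0]), last, :]
    expanded = mask[:, :, None].astype(bool)
    if strategy == "max":
        # Padded positions become -inf so they can never win the max
        return np.where(expanded, hidden, -np.inf).max(axis=1)
    if strategy == "mean":
        # Sum only the real tokens, then divide by their count
        summed = (hidden * expanded).sum(axis=1)
        counts = np.clip(mask.sum(axis=1, keepdims=True), 1, None)
        return summed / counts
    raise ValueError(f"unknown strategy: {strategy}")


# Toy batch: one sequence, four positions (the last is padding), hidden size 2
hidden = np.array([[[1.0, 2.0], [3.0, 0.0], [5.0, 4.0], [9.0, 9.0]]])
mask = np.array([[1, 1, 1, 0]])

for name in ("cls", "eos", "max", "mean"):
    print(name, pool_embeddings(hidden, mask, name))
```

In the toy batch the padded position holds the largest values, so ignoring the mask would change both the max and the mean result; the [CLS] result is unaffected either way.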

In [74]:
cls_pooling = embedding_pooling(output, inputs['attention_mask'], 'cls')
eos_pooling = embedding_pooling(output, inputs['attention_mask'], 'eos')
max_pooling = embedding_pooling(output, inputs['attention_mask'], 'max')
mean_pooling = embedding_pooling(output, inputs['attention_mask'], 'mean')

#### To evaluate and compare the pooling strategies, we can analyze how well each one preserves the meaning of the original tokenized text. One way is to measure the similarity between the embeddings produced by the different strategies, or between a pooled embedding and the original token embeddings.
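For reference, the two measures imported from `helpers` are commonly defined as shown below. This is a minimal sketch under the assumption that the helpers follow the standard definitions; the names `cosine_sim` and `euclid_dist` are hypothetical and chosen so they do not shadow the imported functions.

```python
import numpy as np


def cosine_sim(a, b):
    """Cosine of the angle between two vectors: dot(a, b) / (|a| * |b|)."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))


def euclid_dist(a, b):
    """Straight-line (L2) distance between two vectors."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    return float(np.linalg.norm(a - b))


# Parallel vectors have cosine similarity 1 regardless of their length,
# while their Euclidean distance still reflects the difference in magnitude.
print(cosine_sim([1.0, 2.0], [2.0, 4.0]))
print(euclid_dist([1.0, 2.0], [2.0, 4.0]))
```

Cosine similarity ignores vector magnitude and compares direction only, which is why the two measures can rank pooled embeddings differently.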

In [75]:
similarity_cls_cls = cosine_similarity(cls_pooling[0].cpu().detach().numpy(), cls_pooling[0].cpu().detach().numpy())

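# Sanity check: a vector compared with itself must have similarity 1.
# A small tolerance is used because floating-point results are rarely exactly 1.
assert abs(similarity_cls_cls - 1) < 1e-5, "Wrong implementation for cosine similarity"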

In [76]:
# Convert each pooled embedding (first item in the batch) to a NumPy vector once
cls_vec, eos_vec, max_vec, mean_vec = (
    pooled[0].cpu().detach().numpy()
    for pooled in (cls_pooling, eos_pooling, max_pooling, mean_pooling)
)

similarity_cls_eos = cosine_similarity(cls_vec, eos_vec)
similarity_cls_max = cosine_similarity(cls_vec, max_vec)
similarity_cls_mean = cosine_similarity(cls_vec, mean_vec)

In [77]:
print(f'Cosine similarity between CLS and EOS pooled embeddings: {similarity_cls_eos:.4f}')
print(f'Cosine similarity between CLS and MAX pooled embeddings: {similarity_cls_max:.4f}')
print(f'Cosine similarity between CLS and MEAN pooled embeddings: {similarity_cls_mean:.4f}')

Cosine similarity between CLS and EOS pooled embeddings: 0.2897
Cosine similarity between CLS and MAX pooled embeddings: 0.1315
Cosine similarity between CLS and MEAN pooled embeddings: 0.8014


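Of the three strategies compared against CLS pooling, the mean-pooled embedding is the closest (cosine similarity 0.8014), followed by EOS (0.2897), with max pooling the furthest (0.1315).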