# Demonstration of BERT Strengths and Weaknesses

This is a demonstration of the capabilities of Google's BERT language model, viewed in the context of more recent models such as Google's Bard/Gemini. The goal is simply to highlight some key strengths and weaknesses of BERT in order to inform its useful applications.

#### Context:
- **BERT paper**: 'BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding': https://arxiv.org/abs/1810.04805
- **Gemini report**: 'Gemini: A Family of Highly Capable Multimodal Models' (Gemini Team): https://storage.googleapis.com/deepmind-media/gemini/gemini_1_report.pdf

In [108]:
import torch
from torch.nn.functional import softmax
from transformers import BertTokenizer, BertForNextSentencePrediction, BertForMaskedLM

## Models

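We use pre-trained BERT checkpoints, which were trained on both the masked token and next sentence prediction tasks. The tokenizer and the masked token model are loaded from `bert-large-uncased`. Note that the cell below loads the next sentence prediction model from `bert-base-cased`, whose vocabulary differs from the tokenizer's, so the sentence-matching results later on should be read with that mismatch in mind. There are many fine-tuned models built on this foundation, but we will focus on the core comprehension and knowledge capabilities of the model.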

Model documentation: https://huggingface.co/docs/transformers/v4.37.2/en/model_doc/bert

In [109]:
bert_tokenizer = BertTokenizer.from_pretrained('bert-large-uncased')

bert_masked_model = BertForMaskedLM.from_pretrained('bert-large-uncased')
bert_masked_model.eval()

bert_sentence_model = BertForNextSentencePrediction.from_pretrained('bert-base-cased')
bert_sentence_model.eval()

Some weights of the model checkpoint at bert-large-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


BertForNextSentencePrediction(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12,

## Prediction Implementation

### Mask Token Prediction 

In [110]:
MASK_TEXT = bert_tokenizer.mask_token

def print_mask_fill(text_with_single_mask, top_predictions_to_print=3, expected_word=None):
    '''Prints the top predictions for a masked word within some text.
    
    The printed scores are the model's raw logits (no softmax is applied),
    even though they are labelled "Probability".
    
    `text_with_single_mask`: str, text with a single mask token.
    `top_predictions_to_print`: int, number of top predictions to print.
    `expected_word`: str, optional word whose rank and score are also printed.'''

    print(f"Text with mask: {text_with_single_mask}")
    mask_token_index, mask_predictions = _predict_mask_fill(text_with_single_mask)
    _print_top_predictions(top_predictions_to_print, mask_token_index, mask_predictions)
    if expected_word:
        _print_prediction_of_word(expected_word, mask_token_index, mask_predictions)

def print_word_appropriateness(word_to_check, text, top_predictions_to_print=3):
    '''Prints the given word's appropriateness in the given text.
    
    `word_to_check`: str, word from the text to have its appropriateness checked.
    `text`: str, text to check the word's appropriateness in.
    `top_predictions_to_print`: int, number of top predictions to print.'''

    # Mask only the first occurrence so the text contains a single mask token.
    print_mask_fill(text.replace(word_to_check, MASK_TEXT, 1), top_predictions_to_print, word_to_check)

def _predict_mask_fill(text_with_single_mask):
    text_token_ids = bert_tokenizer.encode(text_with_single_mask, return_tensors='pt')
    mask_token_index = torch.where(text_token_ids == bert_tokenizer.mask_token_id)[1]
    with torch.no_grad():
        model_output = bert_masked_model(text_token_ids)
        mask_predictions = model_output[0]
    return mask_token_index, mask_predictions

def _print_top_predictions(top_predictions_to_print, mask_token_index, mask_predictions):
    top_k_predictions = torch.topk(mask_predictions[0, mask_token_index][0], top_predictions_to_print)
    for prediction_index, prediction_token_index in enumerate(top_k_predictions.indices):
        predicted_token_word = bert_tokenizer.decode(prediction_token_index)
        clean_predicted_token_word = predicted_token_word.replace(" ", "")
        predicted_probability = top_k_predictions.values[prediction_index]
        print(f"- Rank: {prediction_index+1}, Word: {clean_predicted_token_word}, Probability: {predicted_probability:.3f}")
        
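def _print_prediction_of_word(word, mask_token_index, mask_predictions):
    # Index 0 of the encoding is [CLS]; index 1 is the word's first word piece.
    # A word that the tokenizer splits into several pieces is scored by its first piece only.
    word_index = bert_tokenizer.encode(word, return_tensors='pt')[0][1]
    word_probability = mask_predictions[0, mask_token_index, word_index][0]
    # Rank = 1 + number of vocabulary tokens with a strictly higher score.
    word_probability_rank = torch.sum(mask_predictions[0, mask_token_index] > word_probability) + 1
    print(f"- Rank: {word_probability_rank}, Word: {word}, Probability: {word_probability:.3f}")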
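`print_word_appropriateness` relies on `str.replace`, which matches substrings as well as whole words. A minimal sketch of a whole-word alternative, using only the standard `re` module, might look like the following. The helper name `mask_whole_word` and the sample sentence are hypothetical and not part of the notebook.

```python
import re

MASK = "[MASK]"  # the mask literal that appears in the notebook's printed outputs

def mask_whole_word(text, word, mask=MASK):
    """Replace the first whole-word occurrence of `word` in `text` with `mask`."""
    # \b anchors the match at word boundaries; re.escape guards special characters.
    pattern = rf"\b{re.escape(word)}\b"
    return re.sub(pattern, mask, text, count=1)

sample = "The hotel room is hot."
print(sample.replace("hot", MASK))     # substring match also hits "hotel"
print(mask_whole_word(sample, "hot"))  # only the standalone word is masked
```

For the two examples used later in this notebook ("fund" and "hot"), both approaches would produce the same masked text, since each word occurs once and not as part of a longer word.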


### Sentence Similarity Prediction

In [111]:
def print_top_sentence_matches(query_sentence, sentences, top_predictions_to_print=3):
    '''Prints the top matching sentences for a given query sentence.
    
    `query_sentence`: str, sentence to find the top matching sentences for.
    `sentences`: list of str, sentences to match the query sentence against.
    `top_predictions_to_print`: int, number of top predictions to print.'''
    
    similarities = []
    for sentence in sentences:
        sentence_probabilities = _calculate_sentence_relationships(query_sentence, sentence)
        similarity = sentence_probabilities[0][0]
        similarities.append((sentence, similarity))
    _print_top_matching_sentences(query_sentence, sentences, similarities, top_predictions_to_print)

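def _calculate_sentence_relationships(sentence_1, sentence_2):
    sentences_encoding = bert_tokenizer.encode_plus(sentence_1, text_pair=sentence_2, return_tensors='pt')
    # Inference only, so skip gradient tracking (as in `_predict_mask_fill`).
    with torch.no_grad():
        sentence_relationship_logits = bert_sentence_model(**sentences_encoding)[0]
    # Column 0: sentence_2 follows sentence_1; column 1: sentence_2 is a random sentence.
    probabilities = softmax(sentence_relationship_logits, dim=1)
    return probabilities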

def _print_top_matching_sentences(query_sentence, sentences_to_match, similarities, top_predictions_to_print):
    similarities.sort(key=lambda x: x[1], reverse=True)
    print(f"Top {top_predictions_to_print} matches for sentence: '{query_sentence}'")
    for sentence, similarity in similarities[:top_predictions_to_print]:
        print(f"- Similarity: {similarity:.4f} Sentence: '{sentence}',")
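For the two-class output of `BertForNextSentencePrediction`, the softmax in `_calculate_sentence_relationships` is equivalent to a sigmoid of the difference between the two logits, so the probability rounds to `1.0000` once that gap is moderately large. The sketch below illustrates this with made-up logits (they are not taken from the model), and `is_next_probability` is a hypothetical helper.

```python
import math

def is_next_probability(logit_is_next, logit_not_next):
    """Two-class softmax for class 0, written as a sigmoid of the logit gap."""
    return 1.0 / (1.0 + math.exp(-(logit_is_next - logit_not_next)))

# Hypothetical logit pairs, for illustration only.
candidates = {"sentence A": (6.0, -4.0), "sentence B": (7.0, -5.0)}
for name, (is_next, not_next) in candidates.items():
    probability = is_next_probability(is_next, not_next)
    print(f"{name}: probability={probability:.4f}, logit gap={is_next - not_next:.1f}")
```

Both hypothetical candidates display the same four-decimal probability despite different gaps, which suggests that sorting on the raw logit gap could separate near-ties better than the rounded probability.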


## BERT Weaknesses

### Numerical reasoning

In [112]:
print_mask_fill(f"10 + 10 = {MASK_TEXT}.", expected_word="20")

Text with mask: 10 + 10 = [MASK].
- Rank: 1, Word: 10, Probability: 10.840
- Rank: 2, Word: 0, Probability: 9.743
- Rank: 3, Word: 1, Probability: 9.624
- Rank: 14, Word: 20, Probability: 7.882
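The values labelled `Probability` in the output above are the raw scores (logits) returned by `BertForMaskedLM`, since no softmax is applied; that is why they exceed 1 and can be negative elsewhere in this notebook. Because softmax probabilities are proportional to `exp(logit)`, the ratio between two candidates' probabilities depends only on their logit difference. The sketch below illustrates this with the standard library. `probability_ratio` is a hypothetical helper, and the two scores are copied from the output above.

```python
import math

def probability_ratio(logit_a, logit_b):
    """How many times more probable candidate A is than candidate B under a softmax."""
    # softmax(a) / softmax(b) = exp(a) / exp(b) = exp(a - b); the normaliser cancels.
    return math.exp(logit_a - logit_b)

# Scores printed above for "10" (rank 1) and "20" (rank 14).
ratio = probability_ratio(10.840, 7.882)
print(f"'10' is about {ratio:.0f}x more probable than '20'")
```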


### Multi-word expression comprehension (Non-compositionality)

In [113]:
print_mask_fill(f"The gardener is {MASK_TEXT} fingered.", expected_word="green")

Text with mask: The gardener is [MASK] fingered.


- Rank: 1, Word: one, Probability: 8.942
- Rank: 2, Word: two, Probability: 8.621
- Rank: 3, Word: three, Probability: 8.307
- Rank: 51, Word: green, Probability: 4.727


### Dialog tracking

In [114]:
print_mask_fill(f"Alice: What's the weather like there? Bob: It's very cold. Is it the same for you? Alice: Not at all, it's {MASK_TEXT} here.", expected_word="hot")

# Counterexample without dialog:
print_mask_fill(f"The weather's not very cold, it's {MASK_TEXT} here.", expected_word="hot")

Text with mask: Alice: What's the weather like there? Bob: It's very cold. Is it the same for you? Alice: Not at all, it's [MASK] here.
- Rank: 1, Word: different, Probability: 12.071
- Rank: 2, Word: cooler, Probability: 10.281
- Rank: 3, Word: warmer, Probability: 10.208
- Rank: 12, Word: hot, Probability: 7.528
Text with mask: The weather's not very cold, it's [MASK] here.
- Rank: 1, Word: warm, Probability: 11.216
- Rank: 2, Word: nice, Probability: 10.769
- Rank: 3, Word: beautiful, Probability: 10.242
- Rank: 4, Word: hot, Probability: 9.525


### Knowledge breadth

In [115]:
print_mask_fill(f"The club in Berlin with a notoriously strict door policy is called {MASK_TEXT}.", expected_word="Berghain")

Text with mask: The club in Berlin with a notoriously strict door policy is called [MASK].
- Rank: 1, Word: club, Probability: 6.132
- Rank: 2, Word: inferno, Probability: 5.731
- Rank: 3, Word: astoria, Probability: 5.689
- Rank: 1719, Word: Berghain, Probability: 1.094


## BERT Strengths

Considering BERT's relatively fast inference speed (compared with much larger LLMs) and its pre-training on masked token prediction and next sentence prediction, it could still find application in a few contexts, as follows.

### Fixing typos

In [116]:
print_word_appropriateness("fund", "I fund the product to be useful but overpriced.")

Text with mask: I [MASK] the product to be useful but overpriced.
- Rank: 1, Word: found, Probability: 15.158
- Rank: 2, Word: find, Probability: 13.372
- Rank: 3, Word: consider, Probability: 12.052
- Rank: 2564, Word: fund, Probability: 0.499


### Fixing OCR misreads

In [117]:
print_word_appropriateness("hot", "There will be additional costs if the above terms and conditions are hot adhered to.")

Text with mask: There will be additional costs if the above terms and conditions are [MASK] adhered to.
- Rank: 1, Word: not, Probability: 18.069
- Rank: 2, Word: strictly, Probability: 12.863
- Rank: 3, Word: fully, Probability: 11.104
- Rank: 7153, Word: hot, Probability: -1.586
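In both examples above, the suspicious word and the model's top suggestion differ by a single character. One possible way to turn the ranked suggestions into a correction is to keep the highest-ranked suggestion that lies within a small edit distance of the original word. The sketch below is an assumption about how this could be done, not part of the notebook. `edit_distance` and `pick_correction` are hypothetical helpers, and the candidate lists are the top-3 words printed above.

```python
def edit_distance(a, b):
    """Levenshtein distance between two strings (insert, delete, substitute)."""
    previous = list(range(len(b) + 1))
    for i, char_a in enumerate(a, start=1):
        current = [i]
        for j, char_b in enumerate(b, start=1):
            cost = 0 if char_a == char_b else 1
            current.append(min(previous[j] + 1,          # deletion
                               current[j - 1] + 1,       # insertion
                               previous[j - 1] + cost))  # substitution
        previous = current
    return previous[-1]

def pick_correction(original_word, ranked_candidates, max_distance=1):
    """Return the highest-ranked candidate within `max_distance` edits, else None."""
    for candidate in ranked_candidates:
        if edit_distance(original_word, candidate) <= max_distance:
            return candidate
    return None

print(pick_correction("fund", ["found", "find", "consider"]))
print(pick_correction("hot", ["not", "strictly", "fully"]))
```

The `max_distance` threshold is a design choice: a value of 1 suits single-character typos and OCR confusions like the ones shown here, while a larger value would admit more candidates at the cost of more false corrections.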


### Discovering similar document titles

In [118]:
query_sentence = "Molecular structure of DNA fragments."
sentences = [
    "Class representation in ancient Egyptian names.", 
    "Methane propellant analysis for impact on marine life.",
    "Genetics and climate.",
    "Material strength of proteins.",
    "Effectiveness of fragmented governments."]

print_top_sentence_matches(query_sentence, sentences)

Top 3 matches for sentence: 'Molecular structure of DNA fragments.'
- Similarity: 1.0000 Sentence: 'Effectiveness of fragmented governments.',
- Similarity: 0.9999 Sentence: 'Material strength of proteins.',
- Similarity: 0.9978 Sentence: 'Genetics and climate.',


### Discovering related sentences in a document

In [119]:
query_sentence = "How long is the warranty?"
sentences = [
    "Size: 30cm x 10cm x 8 cm.", 
    "The warranty length defaults as 3 years.",
    "Keep children away from the wrapping.",
    # Note: there is no comma after the next string, so Python joins it with the
    # string that follows into a single list item (see the second match in the output).
    "There will be a reduction to 1 year in warranty length in the case of multiple users."
    "The guarantee is voided if the product is used for commercial purposes.",
    "Don't use the product in combination with other products."]

print_top_sentence_matches(query_sentence, sentences)

Top 3 matches for sentence: 'How long is the warranty?'
- Similarity: 1.0000 Sentence: 'The warranty length defaults as 3 years.',
- Similarity: 1.0000 Sentence: 'There will be a reduction to 1 year in warranty length in the case of multiple users.The guarantee is voided if the product is used for commercial purposes.',
- Similarity: 0.9994 Sentence: 'Keep children away from the wrapping.',
