In [1]:
%%capture
! pip install transformers
! pip install labml-nn

# Phase 1 (BERT Embedding)
from : https://nn.labml.ai/transformers/retro/bert_embeddings.html

In [None]:
"""
---
title: BERT Embeddings of chunks of text
summary: >
  Generate BERT embeddings for chunks using a frozen BERT model
---

# BERT Embeddings of chunks of text

This is the code to get BERT embeddings of chunks for [RETRO model](index.html).
"""

from typing import List

import torch
from transformers import BertTokenizer, BertModel

from labml import lab, monit


class BERTChunkEmbeddings:
    """
    ## BERT Embeddings

    For a given chunk of text $N$ this class generates BERT embeddings $\text{B\small{ERT}}(N)$.
    $\text{B\small{ERT}}(N)$ is the average of BERT embeddings of all the tokens in $N$.
    """

    def __init__(self, device: torch.device):
        self.device = device

        # Load the BERT tokenizer from [HuggingFace](https://huggingface.co/bert-base-uncased)
        with monit.section('Load BERT tokenizer'):
            self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased',
                                                           cache_dir=str(
                                                               lab.get_data_path() / 'cache' / 'bert-tokenizer'))

        # Load the BERT model from [HuggingFace](https://huggingface.co/bert-base-uncased)
        with monit.section('Load BERT model'):
            self.model = BertModel.from_pretrained("bert-base-uncased",
                                                   cache_dir=str(lab.get_data_path() / 'cache' / 'bert-model'))

            # Move the model to `device`
            self.model.to(device)

    @staticmethod
    def _trim_chunk(chunk: str):
        """
        In this implementation, we do not make chunks with a fixed number of tokens.
        One of the reasons is that this implementation uses character-level tokens and BERT
        uses its sub-word tokenizer.

        So this method will truncate the text to make sure there are no partial tokens.

        For instance, a chunk could be like `s a popular programming la`, with partial
        words (partial sub-word tokens) on the ends.
        We strip them off to get better BERT embeddings.
        As mentioned earlier this is not necessary if we broke chunks after tokenizing.
        """
        # Strip whitespace
        stripped = chunk.strip()
        # Break words
        parts = stripped.split()
        # Remove first and last pieces
        stripped = stripped[len(parts[0]):-len(parts[-1])]

        # Remove whitespace
        stripped = stripped.strip()

        # If empty return original string
        if not stripped:
            return chunk
        # Otherwise, return the stripped string
        else:
            return stripped

    def __call__(self, chunks: List[str]):
        """
        ### Get $\text{B\small{ERT}}(N)$ for a list of chunks.
        """

        # We don't need to compute gradients
        with torch.no_grad():
            # Trim the chunks
            trimmed_chunks = [self._trim_chunk(c) for c in chunks]

            # Tokenize the chunks with BERT tokenizer
            tokens = self.tokenizer(trimmed_chunks, return_tensors='pt', add_special_tokens=False, padding=True)

            # Move token ids, attention mask and token types to the device
            input_ids = tokens['input_ids'].to(self.device)
            attention_mask = tokens['attention_mask'].to(self.device)
            token_type_ids = tokens['token_type_ids'].to(self.device)
            # Evaluate the model
            output = self.model(input_ids=input_ids,
                                attention_mask=attention_mask,
                                token_type_ids=token_type_ids)

            # Get the token embeddings
            state = output['last_hidden_state']
            # Calculate the average token embeddings.
            # Note that the attention mask is `0` if the token is empty padded.
            # We get empty tokens because the chunks are of different lengths.
            emb = (state * attention_mask[:, :, None]).sum(dim=1) / attention_mask[:, :, None].sum(dim=1)

            #
            return emb


def _test():
    """
    ### Code to test BERT embeddings
    """
    from labml.logger import inspect

    # Initialize
    device = torch.device('cuda:0')
    bert = BERTChunkEmbeddings(device)

    # Sample
    text = ["Replace me by any text you'd like.",
            "Second sentence"]

    # Check BERT tokenizer
    encoded_input = bert.tokenizer(text, return_tensors='pt', add_special_tokens=False, padding=True)

    inspect(encoded_input, _expand=True)

    # Check BERT model outputs
    output = bert.model(input_ids=encoded_input['input_ids'].to(device),
                        attention_mask=encoded_input['attention_mask'].to(device),
                        token_type_ids=encoded_input['token_type_ids'].to(device))

    inspect({'last_hidden_state': output['last_hidden_state'],
             'pooler_output': output['pooler_output']},
            _expand=True)

    # Check recreating text from token ids
    inspect(bert.tokenizer.convert_ids_to_tokens(encoded_input['input_ids'][0]), _n=-1)
    inspect(bert.tokenizer.convert_ids_to_tokens(encoded_input['input_ids'][1]), _n=-1)

    # Get chunk embeddings
    inspect(bert(text))


#
if __name__ == '__main__':
    _test()

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


768


# Phase 2 (SimCSE Embedding)

from: https://github.com/princeton-nlp/SimCSE

In [12]:
!gdown 1Un4CyOBS0tIucz3r74WVuQeubtVtpyba -O model_checkpoint.zip

Downloading...
From: https://drive.google.com/uc?id=1Un4CyOBS0tIucz3r74WVuQeubtVtpyba
To: /content/model_checkpoint.zip
100% 454M/454M [00:06<00:00, 65.1MB/s]


In [13]:
!unzip model_checkpoint.zip -d model_checkpoint

Archive:  model_checkpoint.zip
  inflating: model_checkpoint/config.json  
  inflating: model_checkpoint/optimizer.pt  
  inflating: model_checkpoint/pytorch_model.bin  
  inflating: model_checkpoint/scheduler.pt  
  inflating: model_checkpoint/special_tokens_map.json  
  inflating: model_checkpoint/tokenizer_config.json  
  inflating: model_checkpoint/trainer_state.json  
  inflating: model_checkpoint/training_args.bin  
  inflating: model_checkpoint/vocab.txt  


In [25]:
import torch
from scipy.spatial.distance import cosine
from transformers import AutoModel, AutoTokenizer

class SimCSEChunkEmbeddings:
    def __init__(self, device: torch.device):
        
        self.device = device
        
        self.tokenizer = AutoTokenizer.from_pretrained("model_checkpoint/")
        self.model = AutoModel.from_pretrained("model_checkpoint/")

        self.model.to(self.device)

    def __call__(self, sentences: list[str]):
        # Tokenize input texts

        inputs = self.tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")

        inputs = inputs.to(self.device)
        
        # Get the embeddings
        with torch.no_grad():
            embeddings = self.model(**inputs, output_hidden_states=True, return_dict=True).pooler_output

        return  embeddings


In [27]:
simcse_embedding = SimCSEChunkEmbeddings('cuda:0')

Some weights of the model checkpoint at model_checkpoint/ were not used when initializing BertModel: ['mlp.dense.weight', 'mlp.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertModel were not initialized from the model checkpoint at model_checkpoint/ and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [36]:
texts = [
    "دانش آموزان در حال آماده کردن خود برای امتحان هستند",
    # "دانش آموزان فردا امتحان",
    "هوا تهران امروز ابری است",
    "آزمون میان ترم سخت است"
]

embeddings = simcse_embedding(texts)

In [37]:
embeddings

tensor([[ 1.9058e-01, -1.1787e-01,  5.7443e-01,  ...,  2.5692e-01,
          2.4539e-01,  3.6360e-01],
        [-3.2326e-01, -2.7280e-01,  4.7256e-04,  ..., -5.4758e-01,
          3.6839e-02,  2.4196e-01],
        [-1.6141e-01,  2.7397e-02, -5.0372e-03,  ..., -3.0751e-02,
          1.0698e-01,  6.0327e-01]], device='cuda:0')

In [38]:
# Calculate cosine similarities
# Cosine similarities are in [-1, 1]. Higher means more similar
cosine_sim_0_1 = 1 - cosine(embeddings[0].cpu() , embeddings[1].cpu() )
cosine_sim_0_2 = 1 - cosine(embeddings[0].cpu() , embeddings[2].cpu() )


print("Cosine similarity between \"%s\" and \"%s\" is: %.3f" % (texts[0], texts[1], cosine_sim_0_1))
print("Cosine similarity between \"%s\" and \"%s\" is: %.3f" % (texts[0], texts[2], cosine_sim_0_2))

Cosine similarity between "دانش آموزان در حال آماده کردن خود برای امتحان هستند" and "هوا تهران امروز ابری است" is: 0.137
Cosine similarity between "دانش آموزان در حال آماده کردن خود برای امتحان هستند" and "آزمون میان ترم سخت است" is: 0.426
