In [None]:
# Requires transformers>=4.51.0
import torch
import torch.nn.functional as F
from torch import Tensor
from transformers import AutoTokenizer, AutoModel
import transformers
import numpy as np
import pickle as pkl
from tqdm import tqdm
transformers.__version__

In [None]:
# Checkpoint identifier handed to both from_pretrained calls below.
model_path = "xxx"
# Run on the GPU when torch reports one as available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# The tokenizer is configured to pad on the left side of each sequence.
tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side='left')

model = AutoModel.from_pretrained(model_path)
model = model.to(device)

# Freeze all parameters of the LLM: none of them should track gradients.
for param in model.parameters():
    param.requires_grad_(False)

In [None]:
def last_token_pool(last_hidden_states: Tensor,
                 attention_mask: Tensor) -> Tensor:
    """Pick, for every row, the hidden state at its last attended position.

    Args:
        last_hidden_states: Hidden states indexed as (row, position, ...).
        attention_mask: Mask indexed as (row, position); assumed to hold 1 for
            real tokens and 0 for padding -- TODO confirm against the tokenizer.

    Returns:
        One hidden-state vector per row.
    """
    num_rows = attention_mask.shape[0]

    # When the final mask column sums to the number of rows, every row ends
    # on an attended position (left padding), so the last position is the answer.
    if attention_mask[:, -1].sum() == num_rows:
        return last_hidden_states[:, -1]

    # Otherwise the last attended position differs per row: it is the number
    # of attended positions minus one.
    last_positions = attention_mask.sum(dim=1) - 1
    row_index = torch.arange(last_hidden_states.shape[0], device=last_hidden_states.device)
    return last_hidden_states[row_index, last_positions]
    
def get_embedding(input_texts, max_length):
    """Embed ``input_texts`` with one forward pass of the global ``model``.

    Args:
        input_texts: Text input forwarded unchanged to the global ``tokenizer``
            (presumably a string or a list of strings -- confirm against callers).
        max_length: Token limit handed to the tokenizer; truncation is enabled.

    Returns:
        A 3-tuple of detached CPU tensors: the last-token pooled embeddings,
        the model's full ``last_hidden_state``, and the attention mask.
    """
    # Tokenize the input texts; padding=True pads dynamically rather than to
    # max_length, so the sequence dimension can vary between calls.
    batch_dict = tokenizer(
        input_texts,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    # Rebind the result so this does not rely on .to() mutating in place.
    batch_dict = batch_dict.to(model.device)

    # Inference only: never build an autograd graph, regardless of whether the
    # model parameters were frozen elsewhere (matches get_embedding_batch).
    with torch.no_grad():
        outputs = model(**batch_dict)

    attention_mask = batch_dict['attention_mask']
    embeddings = last_token_pool(outputs.last_hidden_state, attention_mask)
    return (
        embeddings.detach().cpu(),
        outputs.last_hidden_state.detach().cpu(),
        attention_mask.detach().cpu(),
    )

def get_embedding_batch(input_texts: list,
                       batch_size: int = 8,
                       max_length: int = 8192) -> tuple[Tensor, Tensor, Tensor]:
    """Embed ``input_texts`` in mini-batches using the global ``tokenizer``/``model``.

    Args:
        input_texts: Texts to embed; sliced into consecutive chunks of
            ``batch_size`` items.
        batch_size: Number of texts per forward pass. Must be at least 1.
        max_length: Every sequence is truncated or padded to exactly this many
            tokens.

    Returns:
        A 3-tuple of CPU tensors concatenated over all batches along dim 0:
        the last-token pooled embeddings, the full ``last_hidden_state`` of
        every input, and the attention masks.

    Raises:
        ValueError: If ``input_texts`` is empty or ``batch_size`` is below 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if len(input_texts) == 0:
        raise ValueError("input_texts must contain at least one text")

    embeddings = []
    last_hidden_states = []
    attention_masks = []

    # Iterating tqdm over the range gives the bar the exact number of batches
    # (the hand-computed `len // batch_size + 1` was one too many whenever the
    # length was a multiple of batch_size).
    batch_starts = range(0, len(input_texts), batch_size)

    with torch.no_grad():
        for start in tqdm(batch_starts, desc="Generating embeddings", unit="batch"):
            batch_texts = input_texts[start:start + batch_size]

            # Tokenize. Padding to max_length (not to the longest text in the
            # batch) gives every batch the same sequence length, which the
            # final concatenation of hidden states and masks depends on.
            batch_dict = tokenizer(
                batch_texts,
                padding="max_length",
                truncation=True,
                max_length=max_length,
                return_tensors="pt",
            ).to(model.device)

            # Forward pass
            outputs = model(**batch_dict)
            batch_mask = batch_dict['attention_mask']
            batch_emb = last_token_pool(outputs.last_hidden_state, batch_mask)

            # Move results to the CPU right away so they do not pile up on
            # the model's device across batches.
            embeddings.append(batch_emb.detach().cpu())
            last_hidden_states.append(outputs.last_hidden_state.detach().cpu())
            attention_masks.append(batch_mask.detach().cpu())

    return (
        torch.cat(embeddings, dim=0),
        torch.cat(last_hidden_states, dim=0),
        torch.cat(attention_masks, dim=0),
    )