In [1]:
import torch
import pickle
from transformers import AutoTokenizer, AutoModel
from peft import PeftModel, PeftConfig

The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.


0it [00:00, ?it/s]

In [2]:
# Configuration
# Hugging Face ID of the RepLLaMA adapter checkpoint.
# NOTE(review): the model-loading cell passes a different literal
# ('castorini/repllama-v1-7b-lora-doc') to get_model instead of this constant;
# confirm which checkpoint is intended.
MODEL_NAME = 'castorini/repllama-v1-7b-lora-passage'
# File name for the pickled embeddings.
OUTPUT_PATH = "wsj_rapllama_embeddings.pkl"
# Model's max token length.
MAX_LENGTH = 512

# Setup device: use the GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
def get_model(peft_model_name):
    """
    Build an inference-ready model from a PEFT adapter checkpoint.

    The adapter's config names the base model; the base model is loaded,
    wrapped with the adapter, and the adapter is then merged into the base
    weights so a plain model is returned. The model is moved to the
    module-level `device` after every stage and put in eval mode.

    Args:
        peft_model_name (str): Identifier or path of the PEFT adapter.

    Returns:
        The merged model, on `device`, in eval mode.
    """
    # The adapter config records which base checkpoint it was trained on.
    adapter_config = PeftConfig.from_pretrained(peft_model_name)

    backbone = AutoModel.from_pretrained(adapter_config.base_model_name_or_path)
    backbone = backbone.to(device)

    # Attach the adapter weights on top of the base model.
    adapted = PeftModel.from_pretrained(backbone, peft_model_name)
    adapted = adapted.to(device)

    # Fold the adapter into the base weights and drop the PEFT wrapper.
    merged = adapted.merge_and_unload()
    merged = merged.to(device)

    merged.eval()
    return merged

In [4]:
# Load tokenizer and model
# The tokenizer is taken from the Llama-2 7B checkpoint; the model is built by
# get_model from a PEFT adapter ID.
# NOTE(review): the adapter ID below ('...-lora-doc') differs from MODEL_NAME in
# the configuration cell ('...-lora-passage') — confirm which checkpoint is
# intended and whether this call should read the constant.
tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b-hf')
model = get_model('castorini/repllama-v1-7b-lora-doc')

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [5]:
def embed_text(document, max_length=None):
    """
    Embed a single document with the module-level `tokenizer` and `model`.

    The text is wrapped as 'passage: {document}</s>', run through the model
    without gradients, and the hidden state at the last token position is
    L2-normalized and returned.

    Inputs longer than the token budget are shortened to the first
    ``budget - 1`` tokens plus the original final token, so the pooled
    position is still the token that terminated the untruncated sequence.
    Inputs within the budget are passed through unchanged.

    Args:
        document (str): Text to embed.
        max_length (int, optional): Token budget for the encoded input.
            Defaults to the module-level MAX_LENGTH, read at call time.

    Returns:
        torch.Tensor: The normalized last-token hidden state
            (presumably 1-D of size hidden_dim — assumes last_hidden_state
            is (batch, seq, dim)).

    Raises:
        ValueError: If the token budget is smaller than 1.
    """
    token_budget = MAX_LENGTH if max_length is None else max_length
    if token_budget < 1:
        raise ValueError(f"max_length must be >= 1, got {token_budget}")

    encoded = tokenizer(f'passage: {document}</s>', return_tensors='pt')

    document_input = {}
    for name, tensor in encoded.items():
        if tensor.shape[-1] > token_budget:
            # Keep the head of the sequence and re-attach the final token so
            # last-token pooling still reads the terminator position.
            tensor = torch.cat(
                [tensor[..., :token_budget - 1], tensor[..., -1:]], dim=-1
            )
        document_input[name] = tensor.to(device)

    with torch.no_grad():
        # compute document embedding
        document_outputs = model(**document_input)
        # First (only) sequence in the batch, last token position.
        document_embeddings = document_outputs.last_hidden_state[0][-1]
        document_embeddings = torch.nn.functional.normalize(document_embeddings, p=2, dim=0)
    return document_embeddings

In [6]:
def embed_corpus(documents):
    """
    Compute one embedding per document and return them as a single tensor.

    Each document is passed to `embed_text` in input order; the per-document
    results are stacked along a new leading dimension.

    Args:
        documents (List[str]): Text passages to embed.

    Returns:
        torch.Tensor: The stacked `embed_text` results, one row per document,
                      in the same order as `documents`.
    """
    rows = []
    for passage in documents:
        rows.append(embed_text(passage))

    # Combine the per-document vectors into one tensor.
    return torch.stack(rows)

# Example usage:
# docs = ["First document.", "Second document.", "Third document."]
# corpus_embeddings = embed_corpus(docs)
# print(corpus_embeddings.shape)  # -> torch.Size([3, embedding_dim])