In [1]:
import string

import torch
from transformers import AutoTokenizer, AutoModel

In [2]:
# Model identifier plus the compute device, preferring CUDA, then Apple MPS, then CPU
model_name = "mlburnham/Political_DEBATE_base_v1.0"
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

In [3]:
# Loaded (tokenizer, model) pairs keyed by (model name, device string), so repeated
# calls in the notebook reuse the weights instead of reloading them every time.
_MODEL_CACHE = {}

_VALID_DOC_POOLING = ("cls", "mean")
_VALID_WORD_POOLING = ("mean", "max", "sum", "min")


def _load_model(model_name):
    """Return a cached (tokenizer, model) pair for `model_name`, loading it on first use."""
    cache_key = (model_name, str(device))
    if cache_key not in _MODEL_CACHE:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModel.from_pretrained(model_name)
        model.to(device)
        # Inference only: make sure dropout and similar layers are disabled.
        model.eval()
        _MODEL_CACHE[cache_key] = (tokenizer, model)
    return _MODEL_CACHE[cache_key]


def _contains_word(doc, word):
    """Return True if `word` is a whitespace-delimited token of `doc`, ignoring adjacent punctuation."""
    return any(
        piece == word or piece.strip(string.punctuation) == word
        for piece in doc.split()
    )


def _token_span_positions(sequence, pattern):
    """Return the sorted positions in `sequence` covered by contiguous occurrences of `pattern`."""
    width = len(pattern)
    if width == 0:
        return []
    covered = set()
    for start in range(len(sequence) - width + 1):
        if sequence[start:start + width] == pattern:
            covered.update(range(start, start + width))
    return sorted(covered)


def get_embedding(text, embedding_type="document", word=None, doc_pooling="cls", word_pooling="mean", model_name = "mlburnham/Political_DEBATE_base_v1.0"):
    """
    Get embeddings for a single document or a list of documents.

    All arguments are validated before the model is loaded, so an invalid option
    fails immediately instead of after a model load and a forward pass.

    Args:
        text (str or list): The input text (a single document or a list of documents).
        embedding_type (str): Type of embedding. Options: "document" (default) or "word".
        word (str): The word to get the embedding for (required if embedding_type="word").
        doc_pooling (str): Pooling method for document embeddings. Options: "cls" (default) or "mean".
        word_pooling (str): Pooling method for word embeddings. Options: "mean", "max", "sum", "min".
        model_name (str): Name or path passed to `AutoTokenizer.from_pretrained` and
            `AutoModel.from_pretrained`. The loaded pair is cached and reused across calls.

    Returns:
        torch.Tensor or list: A single embedding tensor when `text` is a string, or a list
        with one embedding tensor per document when `text` is a list.

    Raises:
        ValueError: If `text` is neither a string nor a list, if `embedding_type`,
            `doc_pooling` or `word_pooling` is not a supported option, if `word` is missing
            for word embeddings, or if `word` cannot be found in a document.
    """
    # Handle single document input
    if isinstance(text, str):
        texts = [text]
    elif isinstance(text, list):
        texts = text
    else:
        raise ValueError("Input 'text' must be a string or a list of strings.")

    # Validate every option up front, before the expensive model load.
    if embedding_type == "document":
        if doc_pooling not in _VALID_DOC_POOLING:
            raise ValueError(f"Invalid doc_pooling option: {doc_pooling}. Choose 'cls' or 'mean'.")
    elif embedding_type == "word":
        if word is None:
            raise ValueError("The 'word' argument must be specified for word embeddings.")
        for doc in texts:
            if not _contains_word(doc, word):
                raise ValueError(f"The word '{word}' was not found in the document: {doc}")
        if word_pooling not in _VALID_WORD_POOLING:
            raise ValueError(f"Invalid word_pooling option: {word_pooling}. Choose 'mean', 'max', 'sum', or 'min'.")
    else:
        raise ValueError(f"Invalid embedding_type: {embedding_type}. Choose 'document' or 'word'.")

    # Nothing to embed: skip loading the model altogether.
    if not texts:
        return []

    tokenizer, model = _load_model(model_name)

    # The word's token ids do not depend on the document, so compute them once.
    word_token_ids = None
    if embedding_type == "word":
        word_token_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(word))

    embeddings = []
    for doc in texts:
        # Tokenize the input text (one document per forward pass, so no padding is needed)
        inputs = tokenizer(doc, return_tensors="pt", padding=False, truncation=True, max_length=512).to(device)

        # Get the embeddings from the model
        with torch.no_grad():
            outputs = model(**inputs)

        # Drop the batch axis explicitly (batch size is always 1 here) instead of relying
        # on a bare squeeze(), which would also collapse any other size-1 dimension.
        token_embeddings = outputs.last_hidden_state[0]  # (sequence_length, hidden_size)

        if embedding_type == "document":
            if doc_pooling == "cls":
                # Use the first token's embedding as the document embedding
                embedding = token_embeddings[0]
            else:
                # Mean pooling over all tokens, weighted by the attention mask
                mask = inputs["attention_mask"][0].unsqueeze(-1).to(token_embeddings.dtype)
                embedding = (token_embeddings * mask).sum(dim=0) / mask.sum().clamp(min=1)
        else:
            # Locate contiguous runs of the word's token ids. Matching ids individually
            # would also pick up unrelated tokens that merely share a sub-word piece.
            input_ids = inputs["input_ids"][0].tolist()
            word_positions = _token_span_positions(input_ids, word_token_ids)
            if not word_positions:
                raise ValueError(f"The word '{word}' was not found in the document: {doc}")
            # Extract embeddings for the word's tokens
            word_embeddings = token_embeddings[word_positions]  # (num_matched_tokens, hidden_size)
            # Apply the specified pooling method
            if word_pooling == "mean":
                embedding = word_embeddings.mean(dim=0)
            elif word_pooling == "max":
                embedding = word_embeddings.max(dim=0).values
            elif word_pooling == "sum":
                embedding = word_embeddings.sum(dim=0)
            else:
                embedding = word_embeddings.min(dim=0).values

        embeddings.append(embedding)

    # Return a single embedding for single document input, or a list of embeddings for multiple documents
    return embeddings[0] if isinstance(text, str) else embeddings

In [4]:
# Example usage: embed two near-identical sentences
document = [
    "I have a dog and my dog loves to play.",
    "I have a dog and my dog really loves to play.",
]
word = "dog"

# Mean-pooled document embeddings; a list input yields a list of embeddings,
# so inspect individual entries (e.g. doc_embedding[0].shape) rather than the list.
doc_embedding = get_embedding(document, doc_pooling="mean", embedding_type="document")

# Embeddings of the target word, one entry per document
word_embedding = get_embedding(document, word=word, embedding_type="word")