# Extract text embeddings from image captions

## Load dataset

In [None]:
from datasets import load_dataset
from tqdm import tqdm
import torch
import transformers
import numpy as np

In [2]:
rhf_dataset_dict = load_dataset('RAraghavarora/RichHumanFeedback')

## Extract embeddings

### BERT

In [3]:
from transformers import BertTokenizer, BertModel

In [4]:
def extract_bert_embedding(dataset):
    """
    Extract BERT CLS embeddings for all image captions in the given dataset
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = BertTokenizer.from_pretrained("bert-large-uncased")
    model = BertModel.from_pretrained("bert-large-uncased")
    model.to(device)
    cls_embeddings = []
    for text in tqdm(dataset):      
        inputs = tokenizer(text,return_tensors='pt')
        inputs = {key: value.to(device) for key, value in inputs.items()}
        with torch.no_grad():
                outputs = model(**inputs)
                cls_embedding = outputs.last_hidden_state[:, 0, :]
        cls_embeddings.append(cls_embedding.cpu().numpy())
    cls_embeddings = np.vstack(cls_embeddings)
    torch.cuda.empty_cache()
    return cls_embeddings

### Sentence-T5

In [5]:
from sentence_transformers import SentenceTransformer

In [6]:
def extract_sentence_T5_embedding(dataset):
    """
    Extract sentence-T5 embeddings for all image captions in the given dataset
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SentenceTransformer('sentence-transformers/sentence-t5-large',device=device)
    embeddings = []
    for text in tqdm(dataset):      
        embedding = model.encode(text)
        embeddings.append(embedding)
    embeddings = np.vstack(embeddings)
    torch.cuda.empty_cache()
    return embeddings

### NV-Embed-v2

In [7]:
def extract_nv_embed_v2_embedding(dataset):
    """
    Extract NV-Embed-V2 embeddings for all image captions in the given dataset
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SentenceTransformer('nvidia/NV-Embed-v2', trust_remote_code=True, device=device)
    model.max_seq_length = 32768
    model.tokenizer.padding_side="right"
    def add_eos(input_examples):
        input_examples = [input_example + model.tokenizer.eos_token for input_example in input_examples]
        return input_examples
    embeddings = []
    for text in tqdm(dataset):      
        embedding = model.encode(add_eos(text), batch_size=1, normalize_embeddings=True)
        embeddings.append(embedding)
    embeddings = np.vstack(embeddings)
    torch.cuda.empty_cache()
    return embeddings

### gte-Qwen2-7B-instruct

In [8]:
def extract_qwen_embedding(dataset):
    """
    Extract gte-Qwen2-7B-instruct embeddings for all image captions in the given dataset
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SentenceTransformer("Alibaba-NLP/gte-Qwen2-7B-instruct", trust_remote_code=True, device=device)
    embeddings = []
    for text in tqdm(dataset):      
        embedding = model.encode(text)
        embeddings.append(embedding)
    embeddings = np.vstack(embeddings)
    torch.cuda.empty_cache()
    return embeddings

### SFR-Embedding-2_R

In [9]:
def extract_sfr_embedding(dataset):
    """
    Extract SFR-Embedding-2_R embeddings for all image captions in the given dataset
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SentenceTransformer("Salesforce/SFR-Embedding-2_R", trust_remote_code=True, device=device)
    embeddings = []
    for text in tqdm(dataset):      
        embedding = model.encode(text)
        embeddings.append(embedding)
    embeddings = np.vstack(embeddings)
    torch.cuda.empty_cache()
    return embeddings

## Run on entire dataset

In [None]:
from datasets import Dataset, DatasetDict

In [None]:
text_emb_dataset_dict = DatasetDict()
for split in ['train','test','dev']:
    data = rhf_dataset_dict[split]['caption']
    bert_embeddings = extract_bert_embedding(data)
    sentence_t5_embeddings = extract_sentence_T5_embedding(data)
    nv_embeddings = extract_nv_embed_v2_embedding(data)
    qwen_embeddings = extract_qwen_embedding(data)
    sfr_embeddings = extract_sfr_embedding(data)
    text_emb_dataset_dict[split] = Dataset.from_dict({
        'BERT_text_embedding': bert_embeddings,
        'Sentence_T5_text_embedding': sentence_t5_embeddings,
        'NV_Embed_v2_text_embedding': nv_embeddings,
        'Qwen2_text_embedding': qwen_embeddings,
        'SFR_text_embedding': sfr_embeddings
        })
text_emb_dataset_dict

100%|██████████| 955/955 [00:09<00:00, 95.77it/s]
100%|██████████| 955/955 [00:15<00:00, 61.53it/s]
100%|██████████| 995/995 [00:09<00:00, 100.20it/s]
100%|██████████| 995/995 [00:16<00:00, 61.80it/s]


DatasetDict({
    test: Dataset({
        features: ['BERT_text_embedding', 'Sentence_T5_text_embedding'],
        num_rows: 955
    })
    dev: Dataset({
        features: ['BERT_text_embedding', 'Sentence_T5_text_embedding'],
        num_rows: 995
    })
})

In [None]:
text_emb_dataset_dict.push_to_hub("appliedml2024/text_embedding")

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

CommitInfo(commit_url='https://huggingface.co/datasets/enwq/text_embedding/commit/778c66e46a4739b6c48604ac888b1627b66d3b51', commit_message='Upload dataset', commit_description='', oid='778c66e46a4739b6c48604ac888b1627b66d3b51', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/enwq/text_embedding', endpoint='https://huggingface.co', repo_type='dataset', repo_id='enwq/text_embedding'), pr_revision=None, pr_num=None)