In [1]:
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
from datasets import load_dataset
import torch
import numpy as np
from tqdm import tqdm

In [2]:
# Load the "SetFit/20_newsgroups" dataset via `datasets.load_dataset`.
# Later cells read its "train" and "test" splits and the 'text' / 'label'
# columns of each split.
dataset = load_dataset("SetFit/20_newsgroups")


Repo card metadata block was not found. Setting CardData to empty.


In [3]:
# Run on the first CUDA GPU when PyTorch can see one; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")


In [4]:
# Pull out the two splits that the embedding cells below consume.
train, test = (dataset[split_name] for split_name in ("train", "test"))

In [5]:
# One checkpoint id is shared by the tokenizer and the encoder.
model_path = 'Alibaba-NLP/gte-base-en-v1.5'
tokenizer = AutoTokenizer.from_pretrained(model_path)

# NOTE(review): trust_remote_code=True lets transformers execute Python code
# shipped inside the model repository; only keep it for sources you trust.
model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
# Move the weights to the selected device (Module.to returns the module).
model = model.to(device)



In [28]:
def get_embeddings(data, max_length=4096):
    """Embed every text in ``data`` with the module-level ``model``.

    Each text is tokenized on its own (a batch of one) with the module-level
    ``tokenizer``, run through ``model`` on ``device``, and the vector at
    sequence position 0 of ``last_hidden_state`` is kept as its embedding.

    Args:
        data: Object exposing equally long ``'text'`` and ``'label'`` columns
            through ``data['text']`` / ``data['label']`` (e.g. a dataset
            split).
        max_length: Largest tokenized length (in tokens) an example may have.
            Longer examples are skipped together with their label. The
            tokenizer is called with ``truncation=True`` and no explicit
            limit, so inputs are presumably already truncated to the
            tokenizer's own maximum before this check -- TODO confirm that
            maximum is above ``max_length`` if skipping is intended.

    Returns:
        tuple[np.ndarray, np.ndarray]: ``(embeddings, labels)`` where
        ``embeddings`` stacks one row per kept example along axis 0 and
        ``labels`` holds the matching labels in the same order. When no
        example is kept, ``embeddings`` is an empty array with zero rows
        instead of raising.
    """
    texts, labels = data['text'], data['label']
    embeddings_res = []
    labels_res = []

    # Inference only: without no_grad() every forward pass records an autograd
    # graph that is thrown away, which costs memory and time on long inputs.
    with torch.no_grad():
        # total= lets tqdm draw a real progress bar with an ETA.
        for text, label in tqdm(zip(texts, labels), total=len(texts)):
            encoded = tokenizer(text, padding=True, truncation=True, return_tensors='pt')
            # Check the length before the host-to-device copy so skipped
            # examples never touch the accelerator.
            if encoded['input_ids'].shape[1] > max_length:
                continue
            encoded = encoded.to(device)

            outputs = model(**encoded)
            # Position 0 of the last hidden state, moved to host memory.
            embedding = outputs.last_hidden_state[:, 0].detach().to('cpu').numpy()
            embeddings_res.append(embedding)
            labels_res.append(label)

            # Drop the device tensors and hand cached blocks back so widely
            # varying sequence lengths do not pile up in the CUDA allocator.
            del encoded, outputs
            torch.cuda.empty_cache()

    if not embeddings_res:
        # np.concatenate raises ValueError on an empty list; return a
        # well-formed empty result instead (width taken from the model config
        # when it is available).
        hidden_size = getattr(getattr(model, 'config', None), 'hidden_size', 0)
        return np.empty((0, hidden_size), dtype=np.float32), np.array(labels_res)

    embeddings_res = np.concatenate(embeddings_res, axis=0)
    return embeddings_res, np.array(labels_res)

In [7]:
# Embed both splits with the same token-length cutoff. Each call returns an
# (embeddings, labels) pair of NumPy arrays. The calls are kept as separate
# statements so the train result is already bound if the test call fails.
# NOTE(review): the recorded progress output shows 11314 items in roughly six
# minutes -- presumably the train split; this cell is expensive to re-run.
train_embeddings, train_labels = get_embeddings(train, max_length=4096)
test_embeddings, test_labels = get_embeddings(test, max_length=4096)

11314it [06:19, 29.78it/s]


In [25]:
# Persist every computed array as a .npy file in the current working
# directory, in the same order as before: train first, then test.
for npy_path, npy_array in (
    ('train_embeddings.npy', train_embeddings),
    ('train_labels.npy', train_labels),
    ('test_embeddings.npy', test_embeddings),
    ('test_labels.npy', test_labels),
):
    np.save(npy_path, npy_array)
