In [1]:
# Import the required libraries
import torch
import torch.nn as nn
import numpy as np
from warnings import filterwarnings
filterwarnings('ignore')
from transformers import AutoTokenizer
from transformers import BertModel
from transformers import RobertaModel
from transformers import DistilBertModel

# Load a pretrained tokenizer/encoder pair by its short model name
def get_model(model_name):
    """Return the tokenizer and pretrained model for a supported short name.

    Args:
        model_name: One of 'bert', 'roberta' or 'distilbert'.

    Returns:
        A ``(tokenizer, model)`` tuple. Both objects are created with
        ``from_pretrained`` from the checkpoint mapped to ``model_name``.

    Raises:
        AssertionError: If ``model_name`` is not one of the supported names.
    """
    assert model_name in ['bert', 'roberta', 'distilbert']

    # Short name -> (checkpoint identifier, model class to instantiate).
    registry = {
        'bert': ('bert-base-cased', BertModel),
        'roberta': ('roberta-base', RobertaModel),
        'distilbert': ('distilbert-base-cased', DistilBertModel),
    }
    checkpoint, model_cls = registry[model_name]

    # The tokenizer is created first, then the model, from the same checkpoint.
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    encoder = model_cls.from_pretrained(checkpoint)
    return tokenizer, encoder

In [2]:
from tqdm import tqdm

# Collect first-token embeddings and labels for every batch of a loader
@torch.inference_mode()
def get_embeddings_labels(model, loader):
    """Run ``model`` over ``loader`` and gather first-token embeddings and labels.

    Args:
        model: Module called as ``model(attention_mask=..., input_ids=...)``
            whose output is indexable by ``'last_hidden_state'``.
        loader: Iterable of batches; each batch maps ``'input_ids'``,
            ``'attention_mask'`` and ``'labels'`` to tensors.

    Returns:
        A ``(embeddings, labels)`` tuple. ``embeddings`` is the CPU
        concatenation, along dim 0, of position 0 of the last hidden state of
        every batch. ``labels`` is the float32 concatenation of every batch's
        labels after an extra axis has been inserted at dim 1.
    """
    model.eval()

    # Use the device the model itself lives on rather than a notebook-level
    # global, so the function works no matter which cells ran before it.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device('cpu')

    total_embeddings = []
    labels = []

    for batch in tqdm(loader):
        # Labels never go through the model, so they are collected as-is.
        labels.append(batch['labels'].unsqueeze(1))

        # Only the tensors the model consumes are moved to the model's device.
        model_inputs = {key: batch[key].to(target_device) for key in ['attention_mask', 'input_ids']}

        # Keep the hidden state at sequence position 0 for each example.
        embeddings = model(**model_inputs)['last_hidden_state'][:, 0, :]

        # Move results to CPU right away so they do not accumulate on the device.
        total_embeddings.append(embeddings.cpu())

    return torch.cat(total_embeddings, dim=0), torch.cat(labels, dim=0).to(torch.float32)

In [3]:
from datasets import load_dataset

# Load the training split of the IMDB dataset
dataset = load_dataset(path="imdb", split="train")

Found cached dataset imdb (C:/Users/Alex/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0)


In [4]:
# Pick the compute device (CUDA when available, otherwise CPU), then load the
# RoBERTa tokenizer/model pair and move the model onto that device
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
tokenizer, model = get_model('roberta')
model = model.to(device)

# Tokenize the `text` field of a batch of examples
def tokenization(example):
    """Tokenize ``example['text']`` with the notebook-level ``tokenizer``.

    Args:
        example: Mapping with a ``'text'`` entry holding the texts to encode
            (a list of strings when used with a batched ``map``).

    Returns:
        Whatever the tokenizer returns for those texts, requested with special
        tokens added, no token type ids, truncation enabled and every
        sequence padded to a fixed length of 512.
    """
    # Call the tokenizer directly: `__call__` is the documented entry point,
    # while `batch_encode_plus` is the legacy method it supersedes.
    # NOTE(review): every sequence is padded to 512 here even though a padding
    # collator is used later in the notebook — confirm the fixed padding is
    # still wanted.
    return tokenizer(
        example['text'],
        add_special_tokens=True,
        return_token_type_ids=False,
        truncation=True,
        padding='max_length',
        max_length=512,
    )


Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['lm_head.dense.weight', 'lm_head.layer_norm.weight', 'lm_head.decoder.weight', 'lm_head.bias', 'lm_head.layer_norm.bias', 'lm_head.dense.bias']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [5]:
# Tokenize the whole dataset in batches, then expose only the model inputs and
# the label column, formatted as torch tensors
dataset = dataset.map(tokenization, batched=True)
dataset.set_format("torch", columns=["input_ids", "attention_mask", "label"])

                                                                   

In [8]:
from torch.utils.data import DataLoader, Subset
from transformers import DataCollatorWithPadding

# Draw 200 random example indices with a fixed seed.
# NOTE(review): randint samples with replacement, so `idx` may contain
# repeated indices — confirm that is intended.
np.random.seed(100)
idx = np.random.randint(len(dataset), size=200).tolist()

# Collator used to assemble individual tokenized examples into a batch
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# Iterate over the selected subset in order, 16 examples per batch
loader = DataLoader(
    Subset(dataset, idx),
    batch_size=16,
    shuffle=False,
    collate_fn=data_collator,
)


In [9]:
# Embed the sampled examples and collect their labels
embeddings, labels = get_embeddings_labels(model, loader)

# Sanity check: expect one 768-dimensional vector for each of the 200 samples
assert tuple(embeddings.shape) == (200, 768), 'The embeddings tensor has the wrong shape.'

  0%|          | 0/13 [00:00<?, ?it/s]You're using a RobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
100%|██████████| 13/13 [00:03<00:00,  3.32it/s]


In [10]:
# Show the shapes of the embeddings and labels tensors on one line
print(*(tensor.shape for tensor in (embeddings, labels)))

torch.Size([200, 768]) torch.Size([200, 1])


In [None]:
# Persist the embeddings tensor to disk
torch.save(obj=embeddings, f='wtf.pt')