In [1]:
import numpy as np
import torch
from torch.utils.data import DataLoader, Subset
from transformers import BertTokenizer, BertModel, DataCollatorWithPadding
from datasets import load_dataset
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
dataset = load_dataset("imdb", split="train")

Found cached dataset imdb (C:/Users/Alex/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0)


In [3]:
np.random.seed(100)
idx = np.random.randint(len(dataset), size=200)
index_list = idx.tolist()

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print("Device:", device)
if device.type == 'cuda':
    print("GPU:", torch.cuda.get_device_name())

Device: cuda:0
GPU: NVIDIA GeForce RTX 3080


In [4]:
model = BertModel.from_pretrained("bert-base-cased").to(device)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [5]:
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
def tokenization(example):
    return tokenizer.batch_encode_plus(example['text'], add_special_tokens=True, return_token_type_ids=False,
                                       truncation=True)

In [6]:
subset_dataset = dataset.map(tokenization, batched=True)
subset_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])

                                                                  

In [7]:
loader = DataLoader(Subset(subset_dataset, index_list), batch_size=64, collate_fn=DataCollatorWithPadding(tokenizer=tokenizer),
                          pin_memory=True, shuffle=False)

In [8]:
@torch.no_grad()
def get_embeddings_labels(model, loader):
    model.eval()

    total_embeddings = []
    labels = []

    for batch in tqdm(loader):
#         print(batch)
        labels.append(batch['labels'].unsqueeze(1))

        batch = {key: batch[key].to(device) for key in ['attention_mask', 'input_ids']}

        embeddings = model(**batch)['last_hidden_state'][:, 0, :]

        total_embeddings.append(embeddings.cpu())

    return torch.cat(total_embeddings, dim=0), torch.cat(labels, dim=0).to(torch.float32)

In [9]:
embeddings, labels = get_embeddings_labels(model, loader)

print("Embeddings shape:", embeddings.shape)
print("Labels shape:", labels.shape)

100%|██████████| 4/4 [00:56<00:00, 14.05s/it]

Embeddings shape: torch.Size([200, 768])
Labels shape: torch.Size([200, 1])





In [10]:
torch.save(embeddings, 'embeddings.pt')

In [11]:
from transformers import RobertaTokenizer, RobertaModel

tokenizer_roberta = RobertaTokenizer.from_pretrained('roberta-base')
model_roberta = RobertaModel.from_pretrained('roberta-base').to(device)

Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['lm_head.layer_norm.bias', 'lm_head.decoder.weight', 'lm_head.dense.bias', 'lm_head.dense.weight', 'lm_head.layer_norm.weight', 'lm_head.bias']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [12]:
def tokenization(example):
    return tokenizer_roberta.batch_encode_plus(example['text'], add_special_tokens=True, return_token_type_ids=False,
                                       truncation=True)

In [13]:
subset_dataset = dataset.map(tokenization, batched=True)
subset_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])

                                                                   

In [14]:
loader = DataLoader(Subset(subset_dataset, index_list), batch_size=64, collate_fn=DataCollatorWithPadding(tokenizer=tokenizer_roberta),
                          pin_memory=True, shuffle=False)

In [15]:
embeddings_roberta, labels_roberta = get_embeddings_labels(model_roberta, loader)

print("Embeddings shape:", embeddings.shape)
print("Labels shape:", labels.shape)

 25%|██▌       | 1/4 [00:48<02:25, 48.58s/it]

In [None]:
torch.save(embeddings_roberta, 'embeddings_roberta.pt')

In [None]:
embeddings_roberta

tensor([[-0.0379,  0.0640, -0.0357,  ..., -0.0944, -0.0324,  0.0180],
        [-0.1240,  0.0969,  0.0223,  ..., -0.0980, -0.0369, -0.0156],
        [-0.0631,  0.0646, -0.0219,  ..., -0.0620, -0.0220, -0.0716],
        ...,
        [-0.0834,  0.0961,  0.0103,  ..., -0.0602, -0.0039, -0.0268],
        [-0.0933,  0.1287,  0.0146,  ..., -0.0457, -0.0085,  0.0006],
        [-0.0812,  0.0838, -0.0058,  ..., -0.1366, -0.0729, -0.0551]])

# Distilbert

In [None]:
from transformers import DistilBertTokenizer, DistilBertModel

tokenizer_distilbert = DistilBertTokenizer.from_pretrained('distilbert-base-cased')
model_distilbert = DistilBertModel.from_pretrained("distilbert-base-cased").to(device)

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_layer_norm.bias', 'vocab_layer_norm.weight', 'vocab_projector.bias', 'vocab_transform.weight', 'vocab_transform.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
def tokenization(example):
    return tokenizer_distilbert.batch_encode_plus(example['text'], add_special_tokens=True, return_token_type_ids=False,
                                       truncation=True)

In [None]:
subset_dataset = dataset.map(tokenization, batched=True)
subset_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])

  0%|          | 0/25 [00:00<?, ?ba/s]

In [None]:
loader = DataLoader(Subset(subset_dataset, index_list), batch_size=64, collate_fn=DataCollatorWithPadding(tokenizer=tokenizer_distilbert),
                          pin_memory=True, shuffle=False)

In [None]:
embeddings_distilbert, labels_distilbert = get_embeddings_labels(model_distilbert, loader)

print("Embeddings shape:", embeddings_distilbert.shape)
print("Labels shape:", labels_distilbert.shape)

100%|██████████| 4/4 [00:01<00:00,  2.11it/s]

Embeddings shape: torch.Size([200, 768])
Labels shape: torch.Size([200, 1])





In [None]:
torch.save(embeddings_distilbert, 'embeddings_distilbert.pt')

In [None]:
embeddings_distilbert

tensor([[ 2.1681e-01, -1.9808e-01,  1.0595e-01,  ..., -9.3751e-02,
          5.8220e-01,  3.1306e-01],
        [ 1.6410e-01, -1.0337e-01,  1.8239e-02,  ...,  1.1011e-01,
          4.9750e-01,  1.3359e-01],
        [ 3.2910e-01, -1.7247e-01,  1.4982e-01,  ..., -1.1308e-01,
          5.1901e-01,  4.0267e-01],
        ...,
        [-5.1256e-02, -1.6513e-01, -6.0256e-02,  ...,  7.7827e-05,
          6.9394e-01,  3.7176e-01],
        [ 3.6995e-02, -1.8831e-01,  1.3551e-01,  ..., -1.1412e-01,
          4.9967e-01,  2.5347e-01],
        [ 1.1076e-01, -1.1344e-01,  1.5002e-01,  ..., -7.0742e-04,
          5.4633e-01,  4.0668e-01]])