In [2]:
import pandas as pd
import numpy as np
import torch

from transformers import BertTokenizerFast, BertModel
from datasets import load_dataset

from pprint import pprint

import logging
# Configure the root logger to emit messages at INFO level and above.
logging.basicConfig(level=logging.INFO) 

import matplotlib.pyplot as plt

# Select the compute device: the first CUDA GPU when one is available, else the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
# Bare expression so the notebook displays the chosen device.
device

  from .autonotebook import tqdm as notebook_tqdm


device(type='cuda', index=0)

In [3]:
# Load the tokenizer and the encoder from the same 'bert-base-uncased' checkpoint.
# output_hidden_states=True makes the model output expose `hidden_states`,
# which the embedding-extraction step reads later in the notebook.
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased', output_hidden_states=True).to(device)
# Bare expression so the notebook displays the module tree.
model

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
  

In [37]:
# Read the local JSON file with the `datasets` JSON loader. The displayed result
# is a DatasetDict holding a single `train` split.
arxiv = load_dataset('json', data_files='dataset/arxiv_data.json')
arxiv

Downloading data files: 100%|██████████| 1/1 [00:00<?, ?it/s]
Extracting data files: 100%|██████████| 1/1 [00:00<00:00, 90.87it/s]
Generating train split: 10000 examples [00:00, 101331.52 examples/s]


DatasetDict({
    train: Dataset({
        features: ['id', 'submitter', 'authors', 'title', 'comments', 'journal-ref', 'doi', 'report-no', 'categories', 'license', 'abstract', 'versions', 'update_date', 'authors_parsed'],
        num_rows: 10000
    })
})

In [38]:
def tokenize_dataset(data, max_length=256):
    """Tokenize the ``abstract`` field into fixed-length model inputs.

    Args:
        data: Mapping with an ``'abstract'`` entry (a single example or a
            batch of examples, as passed by ``Dataset.map``).
        max_length: Number of tokens every sequence is truncated or padded to.

    Returns:
        The tokenizer's encoding of the abstracts (``input_ids``,
        ``token_type_ids`` and ``attention_mask``).
    """
    # padding='max_length' rather than padding=True: padding=True only pads to
    # the longest sequence of the batch being tokenized, so the stored row
    # length would depend on which abstracts happen to share a map batch.
    # Padding to max_length gives every row the same length, which is what
    # stacking the embeddings into one [N, max_length, hidden] tensor needs.
    return tokenizer(
        data['abstract'],
        padding='max_length',
        truncation=True,
        max_length=max_length,
    )


# Smoke test on the first two abstracts.
print(tokenize_dataset(arxiv['train'][:2]))

{'input_ids': [[101, 1037, 3929, 11658, 17208, 1999, 2566, 20689, 14479, 3512, 8559, 10381, 21716, 7716, 18279, 22924, 2003, 3591, 2005, 1996, 2537, 1997, 5294, 26383, 7689, 2012, 2018, 4948, 8902, 24198, 2869, 1012, 2035, 2279, 1011, 2000, 1011, 2877, 2344, 2566, 20689, 14479, 3512, 5857, 2013, 24209, 17007, 1011, 3424, 16211, 8024, 1010, 1043, 7630, 2239, 1011, 1006, 3424, 1007, 24209, 17007, 1010, 1998, 1043, 7630, 2239, 1011, 1043, 7630, 2239, 4942, 21572, 9623, 8583, 2024, 2443, 1010, 2004, 2092, 2004, 2035, 1011, 4449, 24501, 2819, 28649, 1997, 3988, 1011, 2110, 1043, 7630, 2239, 8249, 9398, 2012, 2279, 1011, 2000, 1011, 2279, 1011, 2000, 1011, 2877, 8833, 8486, 2705, 7712, 10640, 1012, 1996, 2555, 1997, 4403, 2686, 2003, 9675, 1999, 2029, 1996, 17208, 2003, 2087, 10539, 1012, 2204, 3820, 2003, 7645, 2007, 2951, 2013, 1996, 10768, 28550, 20470, 8915, 22879, 4948, 1010, 1998, 20932, 2024, 2081, 2005, 2062, 6851, 5852, 2007, 3729, 2546, 1998, 2079, 2951, 1012, 20932, 2024, 3491, 20

In [39]:
# Tokenize every abstract and keep only the abstract text next to the tokenizer
# outputs. The columns to drop are derived from the dataset itself instead of a
# hard-coded list, so the cell keeps working if the JSON schema gains or loses
# metadata fields.
columns_to_drop = [name for name in arxiv['train'].column_names if name != 'abstract']
encoded_dataset = arxiv.map(
    tokenize_dataset,
    batched=True,
    batch_size=128,
    remove_columns=columns_to_drop,
)

Map: 100%|██████████| 10000/10000 [00:01<00:00, 8372.27 examples/s]


In [40]:
# Show the tokenized dataset: first the whole DatasetDict, then its train split.
print(f'encoded_dataset: {encoded_dataset}')
print(f'encoded_dataset["train"]: {encoded_dataset["train"]}')

# Persist the tokenized dataset to disk.
# NOTE(review): torch.save pickles the whole object; `DatasetDict.save_to_disk`
# is the datasets-native alternative. Consider switching it together with the
# matching torch.load call.
torch.save(encoded_dataset, 'dataset/arxiv_encoded.pt')

encoded_dataset: DatasetDict({
    train: Dataset({
        features: ['abstract', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 10000
    })
})
encoded_dataset["train"]: Dataset({
    features: ['abstract', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 10000
})


In [41]:
# The file holds a pickled `datasets` object rather than tensors or a state_dict,
# so the full unpickler is needed: PyTorch >= 2.6 defaults to weights_only=True,
# which refuses to load arbitrary Python objects.
# SECURITY: full unpickling can execute arbitrary code. Only load files that
# this notebook produced itself.
loaded = torch.load('dataset/arxiv_encoded.pt', weights_only=False)

In [42]:
def extract_hidden_states(batch):
    """Run the model on one batch and return second-to-last-layer embeddings.

    Args:
        batch: Mapping of column name to tensor, as handed over by
            ``Dataset.map(batched=True)`` on a torch-formatted dataset. Only
            the keys listed in ``tokenizer.model_input_names`` are forwarded
            to the model; everything else is ignored.

    Returns:
        ``{'embeddings': ndarray}`` whose first axis is the batch axis
        (presumably shaped (batch, seq_len, hidden); verify against the
        model config).
    """
    inputs = {
        name: tensor.to(device)
        for name, tensor in batch.items()
        if name in tokenizer.model_input_names
    }
    with torch.no_grad():
        # hidden_states is available because the model was loaded with
        # output_hidden_states=True; index -2 selects the second-to-last entry.
        embeddings = model(**inputs).hidden_states[-2]
    # Deliberately no squeeze(0): it is a no-op for batches of two or more
    # examples but would drop the batch axis for a batch of exactly one
    # (e.g. a trailing batch when len(dataset) % batch_size == 1), leaving
    # Dataset.map with an output whose row count no longer matches its input.
    return {'embeddings': embeddings.cpu().numpy()}

# Switch the dataset's output format in place so that only `input_ids` and
# `attention_mask` are returned, as torch tensors. `token_type_ids` is not in
# the list, so it is not passed on to extract_hidden_states.
encoded_dataset.set_format('torch', columns=['input_ids', 'attention_mask'])

# Run the model over the dataset in batches of 128; the function's output is
# stored in a new `embeddings` column.
dataset_hidden = encoded_dataset.map(extract_hidden_states, batched=True, batch_size=128)

Map: 100%|██████████| 10000/10000 [00:51<00:00, 194.13 examples/s]


In [48]:
# Pull the whole `embeddings` column of the train split out of the dataset.
# NOTE(review): this holds every token embedding at once (the recorded shape is
# 10000 x 256 x 768); confirm the machine has enough memory.
bert_embeddings = dataset_hidden['train']['embeddings']

In [50]:
# Show only the shape: dumping the full tensor would flood the notebook, and the
# recorded output of this cell is a torch.Size.
bert_embeddings.shape

torch.Size([10000, 256, 768])

In [51]:
# Write the extracted embeddings to a file in the current working directory.
torch.save(bert_embeddings, 'bert_embeddings_2.pt')