In [None]:
extraction_dataset_max_length = 3200000
bert_model_path = "..\\custom_models\\bert_model"
extraction_dataset_path = "..\\custom_datasets\\extraction_dataset"
open_web_text_dataset_path = "..\\custom_datasets\\open_web_text_dataset"

In [1]:
import os
import sys
sys.path.append(os.path.realpath("../"))

In [2]:
import gc
import torch
from datasets import load_from_disk, Dataset, DatasetDict
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import BertForPreTraining, DefaultDataCollator
from TokenizedBERTDatasetModule import TokenizedBERTDataset

In [3]:
device = torch.device("cuda") if torch.cuda.is_available else torch.device("cpu")

In [None]:
open_web_text_dataset = load_from_disk(open_web_text_dataset_path)

In [5]:
data_loader = DataLoader(
    TokenizedBERTDataset([open_web_text_dataset['train']], truncate_resulting_item_flag=True, include_idx_flag=True), batch_size=1,
    collate_fn=DefaultDataCollator(), shuffle=True, pin_memory=True, pin_memory_device="cuda:0"
)



In [6]:
bert_model = BertForPreTraining.from_pretrained(bert_model_path)
bert_model.to(device)
bert_model.eval()

BertForPreTraining(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwis

In [7]:
def build_processed_labels(batch, predictions):
    processed_labels = [(batch['labels'][0][i] if batch['labels'][0][i] == -100 else predictions[0][i]).item() for i in range(len(predictions[0]))]
    processed_labels = torch.tensor(processed_labels, dtype=torch.long).to(device)
    return processed_labels

In [8]:
def extraction_dataset_gen(dataset_name, max_length=0):
    print(dataset_name + ": ")
    progress_bar = tqdm(range(len(data_loader) if max_length <= 0 else max_length))
    
    index = 0
    for batch in data_loader:
        batch = {k: v.to(device) for k, v in batch.items()}
        idx = batch['idx']
        del batch['idx']
        outputs = bert_model(**batch)
        predictions = torch.argmax(outputs.prediction_logits, dim=-1)
        extracted_batch = {
            'idx': idx[0],
            'input_ids': batch['input_ids'][0],
            'token_type_ids': batch['token_type_ids'][0],
            'attention_mask': batch['attention_mask'][0],
            'labels': build_processed_labels(batch, predictions),
            'next_sentence_label': torch.argmax(outputs.seq_relationship_logits, dim=-1)[0]
        }
        yield extracted_batch

        #   Descomente as linhas abaixo somente se estiver tendo problemas de consumo de memória ao executar o script, já que elas aumentam o tempo
        #de execução
        
        # gc.collect()
        # torch.cuda.empty_cache()
        progress_bar.update()
        index += 1
        if(max_length > 0 and index >= max_length):
            break

In [9]:
def build_extraction_dataset_dict():
    return DatasetDict({
        "train": Dataset.from_generator(extraction_dataset_gen, gen_kwargs={"dataset_name": "train", "max_length": extraction_dataset_max_length})
    })

In [10]:
extraction_dataset = build_extraction_dataset_dict()

Generating train split: 0 examples [00:00, ? examples/s]

train: 


  0%|          | 0/80000 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

validation: 


  0%|          | 0/20000 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

test: 


  0%|          | 0/40000 [00:00<?, ?it/s]

In [11]:
extraction_dataset.save_to_disk(extraction_dataset_path)

Saving the dataset (0/2 shards):   0%|          | 0/80000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/20000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/40000 [00:00<?, ? examples/s]