In [1]:
import pickle
from pathlib import Path

import pandas as pd
import torch
from transformers import (
    BertConfig,
    BertForMaskedLM,
    BertTokenizer,
    DataCollatorForLanguageModeling,
    LineByLineTextDataset,
    Trainer,
    TrainingArguments,
)

In [3]:
# WordPiece vocabulary file used to build the tokenizer for MLM pretraining.
vocab_file_path = Path('30_tok-vocab.txt')
# Construct the tokenizer directly from the vocab file. Passing a path to a
# single file into `from_pretrained` is a deprecated code path in transformers;
# the constructor is the documented way to load a bare vocab.txt.
tokenizer = BertTokenizer(vocab_file=str(vocab_file_path))



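Загружаем готовый датасет из pickle — так быстрее. Внимание: `pickle.load` может выполнить произвольный код, поэтому загружайте только собственные (доверенные) файлы.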

In [5]:
%%time
with open('dataset.pickle', 'rb') as handle:
    dataset = pickle.load(handle)

CPU times: user 2min 2s, sys: 3 s, total: 2min 5s
Wall time: 2min 5s


In [9]:
# Total number of examples in the loaded dataset.
num_examples = len(dataset)
num_examples

3268226

In [12]:
# Size of the training split after holding out 10k examples for validation,
# computed from the dataset itself rather than a copy-pasted length.
len(dataset) - 10_000

3258226

Отложим для валидации 10 000 случайных примеров, остальные используем для обучения.

In [7]:
# Hold out a fixed-size random validation set. The training size is derived from
# the actual dataset length (random_split requires the sizes to sum to it), and
# the split is seeded so Restart & Run All yields the same partition every time.
VAL_SIZE = 10_000
SPLIT_SEED = 42

train_size = len(dataset) - VAL_SIZE
train_set, val_set = torch.utils.data.random_split(
    dataset,
    [train_size, VAL_SIZE],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [8]:
# BERT encoder for masked-language-model pretraining from scratch:
# 6 layers, hidden size 768, 6 attention heads (768 / 6 = 128 dims per head).
# The embedding table is sized from the tokenizer instead of a hard-coded
# number: a vocab_size smaller than the tokenizer's vocabulary lets token ids
# index past the embedding matrix, and a larger one wastes parameters.
config = BertConfig(
    vocab_size=len(tokenizer),
    hidden_size=768,
    num_hidden_layers=6,
    num_attention_heads=6,
    max_position_embeddings=512,
)

model = BertForMaskedLM(config)
print('No of parameters: ', model.num_parameters())

# Masks tokens with probability 0.15 when assembling each batch (MLM objective).
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=True, mlm_probability=0.15
)


No of parameters:  66585648


In [9]:
# Checkpoint and evaluate on the same cadence so every saved checkpoint has a
# matching validation loss in the training log.
EVAL_SAVE_STEPS = 10_000

training_args = TrainingArguments(
    output_dir='output/',
    overwrite_output_dir=True,
    num_train_epochs=10,
    per_device_train_batch_size=64,
    save_steps=EVAL_SAVE_STEPS,
    save_total_limit=10,        # older checkpoints beyond this count are deleted
    prediction_loss_only=True,  # evaluation reports only the loss
    dataloader_num_workers=8,
    # NOTE: newer transformers releases rename this argument to `eval_strategy`.
    evaluation_strategy='steps',
    eval_steps=EVAL_SAVE_STEPS,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_set,
    eval_dataset=val_set,
)

In [None]:
%%time
trainer.train()
trainer.save_model('big10e/')

Step,Training Loss,Validation Loss,Runtime,Samples Per Second
10000,0.8767,1.447483,20.2844,492.991
20000,0.7565,1.222717,18.0336,554.521
30000,0.6818,1.15485,17.7638,562.944
40000,0.6501,1.089546,17.829,560.884
50000,0.6213,1.028589,17.5484,569.851
60000,0.6007,0.985878,18.2523,547.877
70000,0.5746,0.953577,17.8206,561.148
80000,0.5522,0.900199,17.2356,580.195
