In [1]:
from transformers import (
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    BertForMaskedLM,
    get_scheduler,
    TrainingArguments,
    Trainer,
    TrainerCallback
)
import io
from datasets import load_dataset
import numpy as np
import torch
from torch import nn
from torch.optim import Adam
import wandb

In [2]:
from MinioHandler import MinioHandler

minio = MinioHandler()

In [3]:
wandb.login()

wandb.init(
    project='pretrain-bert',
    entity='grammar-bert'
)

[34m[1mwandb[0m: Currently logged in as: [33mxenomirant[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Currently logged in as: [33mxenomirant[0m ([33mgrammar-bert[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [4]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

In [5]:
TRAIN_PATH = 'data/train_dataset.csv'
TEST_PATH = 'data/test_dataset.csv'
MODEL_NAME = 'DeepPavlov/rubert-base-cased'
WEIGHTS_PATH = "ckpt/pretrained_bert/model_epoch_10.pt"
SEQ_LEN = 64
BATCH_SIZE = 16
MLM_PROB = 0.15

In [6]:
def collate_func(batch):
    batch = [data_collator.torch_call(item) for item in zip(*batch)]
    return batch

In [7]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True

tokenizer.pad_token = '[SEP]'
tokenizer.eos_token = '[SEP]'
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=MLM_PROB)

In [8]:
dt = load_dataset("csv", 
                  data_files={"train": "data/train_dataset.csv",
                                "test": "data/test_dataset.csv"},)

In [9]:
def tokenize_function(examples):
    return tokenizer(examples["base"])

In [10]:
tokenized_dt = dt.map(tokenize_function, batched=True, remove_columns=["Unnamed: 0", "polypers", "was_changed"])

In [11]:
model = BertForMaskedLM.from_pretrained(MODEL_NAME)
model.to(device)
pass

In [12]:
ckpt = minio.get_object(WEIGHTS_PATH, type="model")
model_dict = torch.load(ckpt)

In [13]:
model_dict.keys()

dict_keys(['epoch', 'model_state_dict', 'optimizer_state_dict'])

In [14]:
model.load_state_dict(model_dict["model_state_dict"])

<All keys matched successfully>

In [15]:
model.train();

In [26]:
model

BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(119547, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_

In [17]:
class SaveCallback(TrainerCallback):


    def on_train_begin(self, args, state, control, **kwargs):
        '''
        A callback that prints a message at the beginning of training
        '''
        print("Starting training")

    def on_epoch_end(self, args, state, control, **kwargs):
        '''
        Saves to S3 at the end of epoch
        '''
        print("Saving model checkpoint...")
        buffer = io.BytesIO()
        torch.save({
                    'epoch': state.epoch,
                    'model_state_dict': kwargs["model"].state_dict(),
                    'optimizer_state_dict': kwargs["optimizer"].state_dict(),
                    }, 
                   f=buffer)
                # TODO -- add custom hash to model instead of value
        minio.put_object(buffer.getvalue(), 
                             save_name=f"ckpt/pretrained_bert_v2/epoch_{int(state.epoch // 1)}.pt")

In [18]:
import sys, os
from transformers.trainer_callback import ProgressCallback
# Disable
def blockPrint():
    sys.stdout = open(os.devnull, 'w')

# Restore
def enablePrint():
    sys.stdout = sys.__stdout__


def on_log(self, args, state, control, logs=None, **kwargs):
    if state.is_local_process_zero and self.training_bar is not None:
        _ = logs.pop("total_flos", None)
ProgressCallback.on_log = on_log

In [27]:
training_args = TrainingArguments(
    output_dir="ckpt/pretrained_bert",
    dataloader_drop_last=True,
    dataloader_num_workers=6, 
    learning_rate=2e-5,
    num_train_epochs=12,
    weight_decay=0.001, 
    per_device_train_batch_size=6,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={'use_reentrant': True}, 
    optim="adafactor",
    warmup_steps=1000,
    report_to="wandb", 
    logging_steps=2000,
    save_steps=5000,
    save_total_limit=10,
    fp16=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dt["train"],
    eval_dataset=tokenized_dt["test"],
    data_collator=data_collator,
    callbacks=[SaveCallback, ProgressCallback]
)

In [20]:
import os

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

#### Infinite tries to disable logging to stdout

In [21]:
# blockPrint()

In [28]:
trainer.train(resume_from_checkpoint=True);

You are resuming training from a checkpoint trained with 4.36.2 of Transformers but your current version is 4.39.3. This is not recommended and could yield to errors or unwanted behaviors.
There were missing keys in the checkpoint model loaded: ['cls.predictions.decoder.weight', 'cls.predictions.decoder.bias'].


Starting training


  0%|          | 0/775728 [00:00<?, ?it/s]

Step,Training Loss
646000,1.6602
648000,1.6857
650000,1.6937
652000,1.6864


Saving model checkpoint...
ModularLM/ckpt/pretrained_bert_v2/epoch_9.pt: |####################| 681.16 MB/681.16 MB 100% [elapsed: 00:08 left: 00:00, 84.46 MB/sec] 

OutOfMemoryError: CUDA out of memory. Tried to allocate 1.57 GiB. GPU 0 has a total capacity of 7.78 GiB of which 644.81 MiB is free. Process 1594116 has 7.13 GiB memory in use. Of the allocated memory 5.68 GiB is allocated by PyTorch, and 1.33 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)