In [1]:
 #! pip install accelerate==0.27.2

In [2]:
from transformers import (
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    BertForMaskedLM,
    get_scheduler,
    TrainingArguments,
    Trainer,
    TrainerCallback
)
import io
from datasets import load_dataset
import numpy as np
import torch
from torch import nn
from torch.optim import Adam
import wandb
from collections import OrderedDict

In [3]:
from MinioHandler import MinioHandler

minio = MinioHandler()

In [4]:
wandb.login()

wandb.init(
    project='pretrain-bert',
    entity='grammar-bert',
    name="Poly MLM head higher lr"
)

[34m[1mwandb[0m: Currently logged in as: [33mxenomirant[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Currently logged in as: [33mxenomirant[0m ([33mgrammar-bert[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [5]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

In [6]:
TRAIN_PATH = 'data/train_dataset.csv'
TEST_PATH = 'data/test_dataset.csv'
MODEL_NAME = 'DeepPavlov/rubert-base-cased'
SEQ_LEN = 64
BATCH_SIZE = 128
MLM_PROB = 0.15

In [7]:
WEIGHTS_PATH = "ckpt/pretrained_bert/model_epoch_10.pt"

In [8]:
def collate_func(batch):
    batch = [data_collator.torch_call(item) for item in zip(*batch)]
    return batch

In [9]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True

tokenizer.pad_token = '[SEP]'
tokenizer.eos_token = '[SEP]'
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=MLM_PROB)

In [10]:
dt = load_dataset("csv", 
                  data_files={"train": TRAIN_PATH,
                                "test": TEST_PATH},)

In [11]:
def tokenize_function(examples):
    return tokenizer(examples["polypers"])

In [12]:
tokenized_dt = dt.map(tokenize_function, batched=True, remove_columns=["Unnamed: 0", "base", "was_changed"])

In [13]:
model = BertForMaskedLM.from_pretrained(MODEL_NAME)
model.to(device)
pass

In [14]:
model

BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(119547, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_

In [15]:
ckpt = minio.get_object(WEIGHTS_PATH, type="model")
model_dict = torch.load(ckpt)

# necessary for averaged models
# model_dict["model_state_dict"] = {".".join(k.split(".")[1:]): v for k, v in model_dict["model_state_dict"].items() if ".".join(k.split(".")[1:])}

In [16]:
model.load_state_dict(model_dict["model_state_dict"])

<All keys matched successfully>

In [17]:
model.cls.predictions.transform(torch.randn((5, 768)).to(model.device))

tensor([[-2.0411, -2.3667, -1.9080,  ..., -1.8968, -1.8275, -1.8176],
        [ 3.2326, -2.7675, -2.5292,  ..., -2.2934,  1.3292, -2.3785],
        [ 2.1270,  0.2167, -1.5603,  ..., -2.1788, -0.1589, -2.4368],
        [ 0.0673,  0.5578, -3.3598,  ...,  3.8780,  9.6325, -3.7395],
        [-2.1819,  2.7796, -2.1316,  ..., -2.0703, -2.4410, -2.5248]],
       device='cuda:0', grad_fn=<NativeLayerNormBackward0>)

In [18]:
for param in model.parameters():
    param.requires_grad = False

for name, param in model.named_parameters():
    if name.startswith("cls"):
        param.requires_grad = True
    

In [19]:
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name)

cls.predictions.bias
cls.predictions.transform.dense.weight
cls.predictions.transform.dense.bias
cls.predictions.transform.LayerNorm.weight
cls.predictions.transform.LayerNorm.bias


In [20]:
class SaveCallback(TrainerCallback):


    def on_train_begin(self, args, state, control, **kwargs):
        '''
        A callback that prints a message at the beginning of training
        '''
        print("Starting training")

    def on_epoch_end(self, args, state, control, **kwargs):
        '''
        Saves to S3 at the end of epoch
        '''
        print("Saving model checkpoint...")
        buffer = io.BytesIO()
        torch.save({
                    'epoch': state.epoch,
                    'model_state_dict': OrderedDict({k: v for k, v in kwargs["model"].cls.state_dict().items() if ~k.startswith("predictions.decoder")}),
                    'optimizer_state_dict': kwargs["optimizer"].state_dict(),
                    }, 
                   f=buffer)
                # TODO -- add custom hash to model instead of value
        minio.put_object(buffer.getvalue(), 
                             save_name=f"ckpt/poly_mlm-head_higher_lr_epoch_{int(state.epoch)}.pt")

In [21]:
import sys, os
from transformers.trainer_callback import ProgressCallback
# Disable
def blockPrint():
    sys.stdout = open(os.devnull, 'w')

# Restore
def enablePrint():
    sys.stdout = sys.__stdout__


def on_log(self, args, state, control, logs=None, **kwargs):
    if state.is_local_process_zero and self.training_bar is not None:
        _ = logs.pop("total_flos", None)
ProgressCallback.on_log = on_log

In [22]:
training_args = TrainingArguments(
    output_dir="ckpt/poly_mlm-head higher lr",
    evaluation_strategy="epoch",
    dataloader_drop_last=True,
    dataloader_num_workers=6, 
    learning_rate=3e-3,
    num_train_epochs=5,
    gradient_accumulation_steps=6,
    per_device_train_batch_size=8,
    adafactor=True,
    optim="adafactor",
    warmup_steps=1000,
    report_to="wandb", 
    logging_steps=5000,
    save_steps=25000,
    save_total_limit=10,
    fp16=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dt["train"],
    eval_dataset=tokenized_dt["test"],
    data_collator=data_collator,
    callbacks=[SaveCallback, ProgressCallback]
)



In [23]:
import os

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

#### Infinite tries to disable logging to stdout

In [None]:
trainer.train();

Starting training


  0%|          | 0/215480 [00:00<?, ?it/s]

Epoch,Training Loss,Validation Loss
0,1.8525,1.878477
2,1.7717,1.82357
4,1.7578,1.802954


Saving model checkpoint...
ModularLM/ckpt/poly_mlm-head_higher_lr_epoch_0.pt: |####################| 353.43 MB/353.43 MB 100% [elapsed: 00:04 left: 00:00, 73.03 MB/sec]

  0%|          | 0/28731 [00:00<?, ?it/s]

Saving model checkpoint...
ModularLM/ckpt/poly_mlm-head_higher_lr_epoch_4.pt: |####################| 353.43 MB/353.43 MB 100% [elapsed: 00:04 left: 00:00, 86.31 MB/sec] 

  0%|          | 0/28731 [00:00<?, ?it/s]