In [1]:
from transformers import BertConfig, BertForMaskedLM, Trainer, TrainingArguments, BertTokenizerFast
import dask.dataframe as dd
import os
from torch.utils.data import Dataset, DataLoader
from dask import delayed
from fastparquet import ParquetFile
import glob
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
from transformers import BertTokenizerFast, RobertaTokenizerFast, AutoTokenizer
import multiprocessing
import random
from datetime import datetime
from matplotlib import pyplot as plt
import pandas as pd

# Single seed shared by every random number generator this notebook relies on.
RANDOM_SEED = 42

# Seed PyTorch, Python's `random` module and NumPy, in that order.
for _seed_rng in (torch.manual_seed, random.seed, np.random.seed):
    _seed_rng(RANDOM_SEED)

# Batch size, fixed token length per example, and target vocabulary size.
BATCH_SIZE = 64
MAX_SEQ_LEN = 256
VOCAB_SIZE = 30000



In [2]:
class MERTDataset(Dataset):
    """Map-style dataset that yields BERT masked-language-modelling examples.

    Rows come either from the parquet file(s) matched by ``file_dict`` or from
    an in-memory ``dataframe``; in both cases the ``text`` column is read.
    Each item is tokenised with the ``bert-base-uncased`` tokenizer, truncated
    and padded to ``max_seq_len`` tokens, and randomly masked.

    Args:
        file_dict: Pattern handed to ``glob.glob``. When falsy, ``dataframe``
            is used as the data source instead.
        max_seq_len: Length of every returned tensor; also the chunk size used
            by ``batch_iterator``.
        vocab_size: Accepted for interface compatibility; not used by this class.
        dataframe: Frame with a ``text`` column, used only when ``file_dict``
            is falsy.
    """

    # Fraction of (non-special) tokens selected as prediction targets.
    MASK_PROB = 0.15
    # Among selected tokens: below this -> [MASK]; between the two -> random
    # token; above RANDOM_CUTOFF -> token left unchanged.
    MASK_CUTOFF = 0.8
    RANDOM_CUTOFF = 0.9

    def __init__(self, file_dict, max_seq_len, vocab_size, dataframe=None):
        self.path = file_dict

        if file_dict:
            # Load every matching file lazily, then materialise one frame.
            files = glob.glob(file_dict)
            print(len(files))

            ddf = dd.from_delayed([self.load_chunk(f) for f in files])
            self.data = ddf.compute()
        else:
            self.data = dataframe

        # Load tokenizer
        self.max_seq_len = max_seq_len
        self.tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
        # self.tokenizer = AutoTokenizer.from_pretrained("tokenizer")
        # self.tokenizer = AutoTokenizer.from_pretrained("distilbert/distilbert-base-uncased")#RobertaTokenizerFast.from_pretrained("FacebookAI/roberta-base")
        # bert_tokenizer = self.tokenizer.train_new_from_iterator(text_iterator=self.batch_iterator(), vocab_size=VOCAB_SIZE)
        # bert_tokenizer.save_pretrained("tokenizer")
        # self.tokenizer = bert_tokenizer

        self.vocab = self.tokenizer.get_vocab()
        self.pad_i = self.vocab['[PAD]']
        self.mask_i = self.vocab['[MASK]']

    @delayed
    def load_chunk(self, pth):
        """Read the parquet file at ``pth`` into a pandas DataFrame (dask-delayed)."""
        return ParquetFile(pth).to_pandas()

    def batch_iterator(self):
        """Yield the ``text`` column in consecutive slices of ``max_seq_len`` rows.

        Only referenced by the commented-out tokenizer-training code in
        ``__init__``.
        """
        for i in tqdm(range(0, len(self.data), self.max_seq_len)):
            yield self.data[i : i + self.max_seq_len]["text"]

    def __len__(self):
        """Return the number of rows in the underlying frame."""
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Build one masked example for the row at position ``idx``.

        Returns:
            dict with three 1-D integer tensors of length ``max_seq_len``:
            ``input_ids`` (corrupted ids, padded with ``[PAD]``), ``labels``
            (original id at selected positions, ``-100`` elsewhere) and
            ``attention_mask`` (1 for real tokens, 0 for padding).
        """
        # Positional lookup: idx runs over 0..len-1, whereas the frame's index
        # labels may be duplicated (several chunks concatenated) or
        # non-contiguous (a filtered ``dataframe``).
        text = self.data['text'].iloc[idx]

        # Calling the tokenizer (rather than ``encode``) actually returns the
        # special-tokens mask, so [CLS]/[SEP] can be excluded from masking.
        encoded = self.tokenizer(
            text,
            max_length=self.max_seq_len,
            truncation=True,
            return_special_tokens_mask=True,
        )
        token_ids = encoded['input_ids']
        special_mask = encoded['special_tokens_mask']

        sentence = []
        label_sentence = []
        for token, is_special in zip(token_ids, special_mask):
            draw = random.random() if not is_special else 1.0
            if draw >= self.MASK_PROB:
                # Not selected (or a special token): keep as-is, no loss here.
                sentence.append(token)
                label_sentence.append(-100)
                continue

            # Rescale the draw to [0, 1) to choose the corruption applied.
            draw /= self.MASK_PROB
            if draw < self.MASK_CUTOFF:
                sentence.append(self.mask_i)
            elif draw < self.RANDOM_CUTOFF:
                sentence.append(random.randrange(len(self.vocab)))
            else:
                sentence.append(token)
            label_sentence.append(token)

        # Pad everything to a common length so examples can be batched.
        n_tokens = len(sentence)
        pad_len = self.max_seq_len - n_tokens
        sentence.extend([self.pad_i] * pad_len)
        label_sentence.extend([-100] * pad_len)
        attention_mask = [1] * n_tokens + [0] * pad_len

        return {
            'input_ids': torch.as_tensor(sentence),
            'labels': torch.as_tensor(label_sentence),
            'attention_mask': torch.as_tensor(attention_mask),
        }

In [3]:
# Data location: defaults to the original local path, but can be overridden
# with the MERT_DATA_PATH environment variable so the notebook runs elsewhere.
DATA_PATH = os.environ.get('MERT_DATA_PATH', '/media/maxim/DataSets/MERT/MERT-DATA/')

dataset = MERTDataset(DATA_PATH, max_seq_len=MAX_SEQ_LEN, vocab_size=VOCAB_SIZE)
# Reuse the dataset's tokenizer so later cells agree with training-time ids.
tokenizer = dataset.tokenizer
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

1


In [4]:
# Training setup: checkpoints go to "output_dir" every 10000 steps; the batch
# size and seed come from the notebook-level constants instead of literals.
train_args = TrainingArguments(
    output_dir="output_dir",
    save_steps=10000,
    bf16=True,
    torch_compile=True,
    per_device_train_batch_size=BATCH_SIZE,
    seed=RANDOM_SEED,
)

# Size the embedding table from the tokenizer so every id the dataset can
# produce (including randomly substituted ones) is a valid index.
config = BertConfig(vocab_size=len(tokenizer))

# Randomly initialised model: this is pre-training from scratch.
model = BertForMaskedLM(config=config)
# model = BertForMaskedLM.from_pretrained('bert-base-uncased')

trainer = Trainer(model, args=train_args, train_dataset=dataset)
trainer.train()

I0000 00:00:1712625758.852424   10451 cpu_client.cc:370] TfrtCpuClient created.
ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mmaxim-g[0m. Use [1m`wandb login --relogin`[0m to force relogin


  0%|          | 0/365814 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Reload masked-LM weights from a saved checkpoint directory — presumably the
# one written by the training cell (output_dir="output_dir", save_steps=10000).
model = BertForMaskedLM.from_pretrained('output_dir/checkpoint-10000')

In [None]:
# Load the 'bert-base-uncased' tokenizer, the same pretrained name that
# MERTDataset loads when building training examples.
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

In [None]:
# Fill-mask probe: one sentence with a single [MASK] slot for the model to fill.
text = 'The horse is [MASK] fat.'
# return_tensors='pt' yields PyTorch tensors with a leading batch dimension.
tokens = tokenizer(text, return_tensors='pt')

# Inspect the unpadded encoding.
print(tokens)

{'input_ids': tensor([[ 101, 1996, 3586, 2003,  103, 6638, 1012,  102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1]])}


In [None]:
# token -> id mapping as reported by the tokenizer.
vocab = tokenizer.get_vocab()
# Inverse mapping (id -> token), used below to turn ids back into text.
vocab_swap = {token_id: token for token, token_id in vocab.items()}

In [None]:
# Pad the single encoded sentence out to MAX_SEQ_LEN by hand.
pad_len = MAX_SEQ_LEN - tokens['input_ids'][0].shape[0]

# Build the padding with integer dtypes matching the tensors it is joined to:
# a float pad would silently promote the token ids to float32 inside torch.cat.
# The id padding uses the tokenizer's own pad id rather than a literal 0.
pad_ids = torch.full((pad_len,), tokenizer.pad_token_id, dtype=tokens['input_ids'].dtype)
pad = torch.zeros(pad_len, dtype=tokens['attention_mask'].dtype)

# Re-add the batch dimension; .int() keeps the dtype the later cells received before.
input_ids = torch.cat((tokens['input_ids'][0], pad_ids), 0).unsqueeze(0).int()
attention_mask = torch.cat((tokens['attention_mask'][0], pad), 0).unsqueeze(0).int()

In [None]:
# Sanity check: the padded ids should have shape (1, MAX_SEQ_LEN).
print(input_ids.shape)

torch.Size([1, 256])


In [None]:
# Sanity check: the padded attention mask should have the same shape as input_ids.
print(attention_mask.shape)

torch.Size([1, 256])


In [None]:
# Turn the padded ids back into token strings (each followed by one space)
# to confirm that the padding positions decode to [PAD].
pred_str = "".join(vocab_swap[token_id.item()] + " " for token_id in input_ids[0])

print(pred_str)

[CLS] the horse is [MASK] fat . [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [P

In [None]:
# `tokens` still holds the tokenizer's unpadded output at this point.
print(tokens)

{'input_ids': tensor([[ 101, 1996, 3586, 2003,  103, 6638, 1012,  102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1]])}


In [None]:
# Rebind `tokens` to the manually padded tensors; token_type_ids is dropped.
tokens = dict(input_ids=input_ids, attention_mask=attention_mask)

In [None]:
# Inspect the padded model inputs (input_ids and attention_mask only).
print(tokens)

{'input_ids': tensor([[ 101, 1996, 3586, 2003,  103, 6638, 1012,  102,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,

In [None]:
# Inference only: skip autograd bookkeeping so no graph is built and the
# returned logits carry no grad_fn.
with torch.no_grad():
    outputs = model(**tokens)

In [None]:
# Show the raw model output object returned by the forward pass.
print(outputs)

MaskedLMOutput(loss=None, logits=tensor([[[ -5.3164,  -3.5232,  -3.5342,  ...,  -2.8094,  -3.6205,  -3.5154],
         [-10.5610,  -6.7457,  -5.4357,  ...,  -6.0925,  -6.6187,  -9.2391],
         [ -7.8776,  -3.8386,  -3.8514,  ...,  -3.9736,  -4.6557,  -6.3922],
         ...,
         [ -8.8214,  -5.8837,  -4.2537,  ...,  -5.3946,  -5.6697,  -7.1925],
         [ -9.1171,  -6.3222,  -4.4915,  ...,  -5.6766,  -5.9392,  -7.3581],
         [ -9.7064,  -7.4902,  -5.4273,  ...,  -6.2417,  -6.4642,  -7.9772]]],
       grad_fn=<ViewBackward0>), hidden_states=None, attentions=None)


In [None]:
# NOTE: `vocab` is rebound here to an id -> token map; the decoding below
# still reads from `vocab_swap`, which holds the same mapping.
vocab = {token_id: token for token, token_id in tokenizer.get_vocab().items()}

# Greedy decode: highest-scoring vocabulary entry at every position of the
# first sequence, each token followed by one space.
pred_str = "".join(
    vocab_swap[position_logits.argmax().item()] + " "
    for position_logits in outputs[0][0]
)

print(pred_str)

[CLS] the horse is a fat . [SEP] . . . . . . . . . . . . . . . . . . . . . . . . . . . . . there . . . . . . . . . there . . . it there . . . . it there there . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . [SEP] 
