In [None]:
from transformers import BertTokenizer, BertForMaskedLM, DataCollatorForLanguageModeling, TrainingArguments, Trainer, pipeline
from functools import partial
import datasets
import torch
import os
import json
import datasets
from utils import *

In [None]:
# Corpus names load_data knows how to read; each lives at data/<name>.csv.
KNOWN_CORPORA = frozenset((
    'ncert_data', 'khan_data', 'learn_data', 'siya_data', 'ck_12', 'em_data', 'openstax',
))


def load_data(c, test_size=0.01, seed=None):
    """Load each requested CSV corpus, split it into train/test, and merge the splits.

    Parameters
    ----------
    c : dict
        Run configuration; ``c['datas']`` lists corpus names. Names found in
        ``KNOWN_CORPORA`` are read from ``data/<name>.csv``; any other name is
        skipped with a printed notice.
    test_size : float, default 0.01
        Fraction of each corpus held out for evaluation.
    seed : int or None, default None
        Seed forwarded to ``train_test_split``; set it for a reproducible split.

    Returns
    -------
    dict
        ``{'train': ..., 'test': ...}`` holding the per-corpus splits concatenated
        with ``datasets.concatenate_datasets``.

    Raises
    ------
    ValueError
        If none of the requested names is a known corpus.
    """
    train_parts = []
    test_parts = []
    for data in c['datas']:
        if data not in KNOWN_CORPORA:
            print(f'skipping unknown corpus {data!r}')
            continue
        temp = datasets.load_dataset('csv', data_files=f'data/{data}.csv', cache_dir='./datasets', split="train")
        # Split each corpus on its own so every source contributes to the test set.
        d = temp.train_test_split(test_size=test_size, seed=seed)
        print(f'load/create data from {data} corpus')
        train_parts.append(d['train'])
        test_parts.append(d['test'])

    # Fail early with a readable message instead of inside concatenate_datasets([]).
    if not train_parts:
        raise ValueError(
            f"no known corpus in c['datas']={c['datas']!r}; expected some of {sorted(KNOWN_CORPORA)}"
        )

    return {
        'train': datasets.concatenate_datasets(train_parts),
        'test': datasets.concatenate_datasets(test_parts),
    }

In [None]:
# Run configuration: per-device batch size and the CSV corpora (under data/) to train on.
medusa_config = dict(
    batch_size=16,
    datas=['siya_data', 'ck_12', 'openstax', 'em_data'],
)
# Sequence length, in tokens, for truncation/padding or for packing chunks.
max_length = 512
# True: truncate/pad each sample to max_length; False: tokenize fully, then pack into max_length chunks.
truncate_longer_samples = True
# Output directory handed to TrainingArguments for checkpoints.
model_path = 'pretrained-edubert'

In [None]:
# Tokenizer and MLM model must come from the same checkpoint (shared vocabulary),
# so name it once instead of repeating the string.
BASE_CHECKPOINT = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(BASE_CHECKPOINT)
model = BertForMaskedLM.from_pretrained(BASE_CHECKPOINT)

In [None]:
def encode_with_truncation(examples):
    """Tokenize a batch of texts, truncating and padding every sample to ``max_length`` tokens.

    ``examples["text"]`` holds the raw strings; the returned encoding also carries
    ``special_tokens_mask``.
    """
    return tokenizer(
        examples["text"],
        truncation=True,
        padding="max_length",
        max_length=max_length,
        return_special_tokens_mask=True,
    )


def encode_without_truncation(examples):
    """Tokenize a batch of texts at full length (no truncation, no padding), with ``special_tokens_mask``."""
    return tokenizer(examples["text"], return_special_tokens_mask=True)


# Pick the tokenization strategy once, from the config flag.
if truncate_longer_samples:
    encode = encode_with_truncation
else:
    encode = encode_without_truncation

In [None]:
# Load the configured corpora as merged {'train', 'test'} splits.
d = load_data(c=medusa_config)

In [None]:
# Tokenize both splits. The raw CSV columns (e.g. "text") are removed so only tokenizer
# outputs remain: group_texts concatenates every column it is given, and a string
# column would make that concatenation fail.
train_dataset = d["train"].map(encode, batched=True, remove_columns=d["train"].column_names)
test_dataset = d["test"].map(encode, batched=True, remove_columns=d["test"].column_names)

# Truncated/padded samples are served as torch tensors; the packing path keeps
# special_tokens_mask and plain Python lists so group_texts can run on them next.
if truncate_longer_samples:
    dataset_format = {"type": "torch", "columns": ["input_ids", "attention_mask"]}
else:
    dataset_format = {"columns": ["input_ids", "attention_mask", "special_tokens_mask"]}
train_dataset.set_format(**dataset_format)
test_dataset.set_format(**dataset_format)
train_dataset, test_dataset

In [None]:
def group_texts(examples, block_size=None):
    """Pack a batch of tokenized samples into contiguous chunks of ``block_size`` tokens.

    Every column in ``examples`` (e.g. ``input_ids``, ``attention_mask``) is
    concatenated across the batch and re-split at the same offsets, so the columns
    stay aligned.

    Parameters
    ----------
    examples : mapping of str -> list of lists
        A batched ``datasets.map`` input.
    block_size : int or None, default None
        Chunk length; ``None`` uses the notebook-level ``max_length``.

    Returns
    -------
    dict
        Same keys, each mapping to a list of chunks. When the batch holds at least
        ``block_size`` tokens, the trailing partial chunk is dropped. Otherwise the
        whole (short) batch is returned as a single chunk. An empty mapping yields ``{}``.
    """
    from itertools import chain

    if block_size is None:
        block_size = max_length
    if not examples:
        return {}
    # chain is linear in the token count; sum(lists, []) re-copies the accumulator on every add.
    concatenated_examples = {k: list(chain.from_iterable(v)) for k, v in examples.items()}
    total_length = len(next(iter(concatenated_examples.values())))
    if total_length >= block_size:
        total_length = (total_length // block_size) * block_size
    return {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }

# Packing only applies when samples were not truncated/padded during tokenization.
if not truncate_longer_samples:
    # Same settings for both splits. The map batch size is independent of the training
    # batch size: group_texts drops the partial chunk left at the end of each map batch,
    # so tiny map batches (e.g. 16 samples) would throw away many tokens.
    group_map_kwargs = dict(
        batched=True,
        batch_size=1000,
        num_proc=4,
        desc=f"Grouping texts in chunks of {max_length}",
    )
    train_dataset = train_dataset.map(group_texts, **group_map_kwargs)
    test_dataset = test_dataset.map(group_texts, **group_map_kwargs)

In [None]:
# Masked-LM collator: selects 20% of tokens per batch for masking
# (the collator's own default would be 0.15).
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True,
    mlm_probability=0.2,
)

In [None]:
# Samples per optimizer update = per_device_train_batch_size * gradient_accumulation_steps
# (times the number of devices); the accumulation below targets ~128 per update.
training_args = TrainingArguments(
    output_dir=model_path,            # where checkpoints are written
    # NOTE(review): newer transformers releases rename this argument to `eval_strategy`.
    evaluation_strategy="steps",      # evaluate every `eval_steps`, which defaults to `logging_steps`
    overwrite_output_dir=True,
    num_train_epochs=10,
    per_device_train_batch_size=medusa_config['batch_size'],
    # max(1, ...) keeps accumulation valid if batch_size is ever raised above 128.
    gradient_accumulation_steps=max(1, 128 // medusa_config['batch_size']),
    per_device_eval_batch_size=medusa_config['batch_size'],
    logging_steps=1000,               # log (and therefore evaluate) every 1000 steps
    save_steps=1000,                  # must be a multiple of eval steps for load_best_model_at_end
    load_best_model_at_end=True,      # reload the best checkpoint (by eval loss) when training ends
    gradient_checkpointing=True,      # trade extra compute for lower activation memory
    fp16=torch.cuda.is_available(),   # fp16 mixed precision requires CUDA; use fp32 otherwise
    # save_total_limit=3,             # uncomment to cap how many checkpoints are kept on disk
)

In [None]:
# Combine model, training configuration, tokenized splits and the MLM collator into one Trainer.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    data_collator=data_collator,
)

In [None]:
# Run the training loop. Checkpoints are written to training_args.output_dir; because
# load_best_model_at_end=True, the best checkpoint (by eval loss) is loaded when training ends.
trainer.train()