In [5]:
import pandas as pd

import warnings
warnings.filterwarnings('ignore')

from transformers import (AutoModel,AutoModelForMaskedLM, 
                          AutoTokenizer, LineByLineTextDataset,
                          DataCollatorForLanguageModeling,
                          Trainer, TrainingArguments)

In [6]:
train_data = pd.read_csv('../input/commonlitreadabilityprize/train.csv')
test_data = pd.read_csv('../input/commonlitreadabilityprize/test.csv')

data = pd.concat([train_data,test_data])
data['excerpt'] = data['excerpt'].apply(lambda x: x.replace('\n',''))

text  = '\n'.join(data.excerpt.tolist())

with open('text.txt','w') as f:
    f.write(text)

In [7]:
model_name = 'roberta-base'
model = AutoModelForMaskedLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.save_pretrained('./clrp_roberta_base');

In [8]:
train_dataset = LineByLineTextDataset(
    tokenizer=tokenizer,
    file_path="text.txt",
    block_size=256)

valid_dataset = LineByLineTextDataset(
    tokenizer=tokenizer,
    file_path="text.txt", 
    block_size=256)

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=True, mlm_probability=0.15)

training_args = TrainingArguments(
    output_dir="./checkpoints", 
    overwrite_output_dir=True,
    num_train_epochs=5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    evaluation_strategy= 'steps',
    save_total_limit=2,
    eval_steps=200,
    metric_for_best_model='eval_loss',
    greater_is_better=False,
    load_best_model_at_end =True,
    prediction_loss_only=True,
    report_to = "none")

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset)

In [9]:
trainer.train()
trainer.save_model(f'./clrp_roberta_base')

Step,Training Loss,Validation Loss,Runtime,Samples Per Second
200,No log,1.441347,31.9752,88.85
400,No log,1.383025,32.0221,88.72
600,1.550100,1.340914,32.0562,88.626
800,1.550100,1.319844,32.092,88.527


In [10]:
print('Done')

Done


In [11]:
!ls

__notebook_source__.ipynb  checkpoints	clrp_roberta_base  text.txt
