In [1]:
import pandas as pd 
import numpy as np 
import warnings 
warnings.filterwarnings("ignore")
from colorama import Fore , Style,Back
import re 
import torch 
import os 
import torch.nn as nn 
from transformers import AutoModelForMaskedLM,AutoTokenizer,Trainer,LineByLineTextDataset,\
DataCollatorForLanguageModeling , TrainingArguments , AutoModel ,AdamW
from torch.utils.data import Dataset,DataLoader
from sklearn.model_selection import StratifiedKFold
import random
r_ = Fore.RED
G_ = Fore.GREEN
Y_ = Fore.YELLOW
st_ = Style.RESET_ALL

In [2]:
config = {
    "batch_size" : 16 ,
    "lr" : 5e-5,
    "wb" : 2e-5,
    "batch_size" : 16,
    "max_len" : 256,
    "fold" : 5,
    "seed" : 42,
    "epochs" : 5
}

In [3]:
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONASSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

In [4]:
seed_everything(config["seed"])

In [5]:
train = pd.read_csv("../input/commonlitreadabilityprize/train.csv")
test = pd.read_csv("../input/commonlitreadabilityprize/test.csv")
datas = pd.concat([train,test])

In [6]:
def clean_text(excerpt):
    punctuations = ".,?!;\(\":-)‘"
    extrait = excerpt
    for p in punctuations : 
      extrait = extrait.replace(p,f" {p} ")
    extrait = re.sub(r"'s"," is ",extrait)
    extrait = extrait.replace("i'm","I'm")
    extrait = extrait.replace("don't","do not")
    extrait = extrait.replace("didn't","did not")
    extrait = extrait.replace("can't","cannot")
    extrait = extrait.replace("i'll","I will")
    extrait = extrait.replace("wouldn't","would not")
    extrait = extrait.replace("i've","I have")
    extrait = re.sub(r"i've","I have",extrait)
    extrait = extrait.replace("won't","will not")
    extrait = extrait.replace("couldn't","could not")
    extrait = extrait.replace("wasn't","was not")
    extrait = extrait.replace("you'll","you will")
    extrait = extrait.replace("isn't","is not")
    extrait = extrait.replace("you're","you are")
    extrait = extrait.replace("hadn't","had not")
    extrait = extrait.replace("you've","you have")
    extrait = extrait.replace("doesn't","does not")
    extrait = extrait.replace("haven't","have not")
    extrait = extrait.replace("they're","they are")
    extrait = extrait.replace("we're","we are")
    #extrait = re.sub(r"(/s+)i(/s+)","I",excerpt)
    #extrait = re.sub(r"don't","do not",extrait)
    #extrait = re.sub(r"i'm","I'm",extrait)
    #extrait = re.sub(r"man's","man is",extrait)
    #extrait = re.sub(r"it's","it is",extrait)
    #extrait = re.sub(r"didn't","did not",extrait)
    #extrait = re.sub(r"can't","cannot",extrait)
    #extrait = re.sub(r"earth's","earth is",extrait)
    #extrait = re.sub(r"father's","father is",extrait)
    #extrait = re.sub(r"i'll","I will",extrait)
    #extrait = re.sub(r"i've","I have",extrait)
    #extrait = re.sub(r"i\'",r"I'",extrait)
    #extrait = re.sub(r"children\'s","children is",extrait)
    
    return extrait 

In [7]:
train["cleaned_excerpt"] = train["excerpt"].map(clean_text)
test["cleaned_excerpt"] = test["excerpt"].map(clean_text)

In [8]:
datas["cleaned_excerpt"] = datas["excerpt"].map(clean_text)

In [9]:
texts = "\n".join(datas["cleaned_excerpt"].values)

In [10]:
model_name = "bert-base-cased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForMaskedLM.from_pretrained(model_name)

Downloading:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/213k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/436k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/29.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/436M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [11]:
with open("./texts.txt","w") as f :
    f.write(texts)

In [12]:
dataset = LineByLineTextDataset(tokenizer = tokenizer,file_path ="./texts.txt",block_size=256 )
val_dataset = LineByLineTextDataset(tokenizer = tokenizer,file_path="./texts.txt",block_size=256)

In [13]:
collator = DataCollatorForLanguageModeling(tokenizer=tokenizer,mlm = True,mlm_probability=0.15)

In [14]:
training_args = TrainingArguments(output_dir="./bert_base_chk",
                                 overwrite_output_dir=True,
                                 num_train_epochs=3,
                                 per_device_train_batch_size=16,
                                 evaluation_strategy= 'steps',
                                 save_total_limit=0, 
                                 eval_steps=150,
                                 metric_for_best_model = 'eval_loss',
                                 greater_is_better=False,
                                 load_best_model_at_end =True,
                                 prediction_loss_only=True,
                                 report_to = "none")

In [15]:
trainer = Trainer(model = model,
                  args = training_args,
                  data_collator = collator,
                  train_dataset = dataset,
                  eval_dataset= val_dataset)

In [16]:
trainer.train()
trainer.save_model("./bert_base_chk")

Step,Training Loss,Validation Loss,Runtime,Samples Per Second
150,No log,1.915052,53.8083,132.21
300,No log,1.830933,53.7622,132.324
450,No log,1.797686,53.8535,132.099
600,2.032400,1.801429,53.8795,132.035
750,2.032400,1.735964,53.9876,131.771
900,2.032400,1.718746,53.9131,131.953
1050,1.917900,1.704699,53.8547,132.096
1200,1.917900,1.70023,53.8798,132.035
