In [None]:
from transformers import BertTokenizer, DataCollatorForLanguageModeling, Trainer, TrainingArguments, BertForMaskedLM
from datasets import load_dataset
import datasets
import torch
import re
from nltk.tokenize import sent_tokenize
import nltk
import accelerate

# Fetch the Punkt sentence-boundary data needed by `sent_tokenize` further down.
nltk.download(info_or_id='punkt_tab')


In [None]:
# Pre-trained checkpoint that both the tokenizer and the MLM model start from.
model_name = "bert-base-uncased"

# Corpus for continued pre-training.
# TODO Replace with research dataset (currently only the first 1% of the NYT "test" split, for quick iteration).
dataset = load_dataset(
    "ErikCikalleshi/new_york_times_news_1987_1995",
    split='test[:1%]',
)


In [None]:
# Collect one "<year_...>" marker token per distinct date in the dataset.
# TODO: this currently yields one token per distinct `date` value (years); update to decades.
def get_date_tokens(dataset: "datasets.Dataset"):
    """Return the sorted, de-duplicated list of date marker tokens for ``dataset``.

    Args:
        dataset: anything indexable by ``'date'`` that yields an iterable of
            mutually comparable date values (a ``datasets.Dataset`` column here).

    Returns:
        A list of strings of the form ``"<year_{date}>"``, one per distinct date,
        in ascending date order. The format must match the prefix added to each
        sentence when the training text is built.
    """
    # De-duplicate first, then sort. Sorting before set() discards the order,
    # and set iteration order for strings changes between interpreter runs
    # (hash randomisation). That would give the added tokens different ids on
    # every run.
    unique_dates = sorted(set(dataset['date']))
    return [f"<year_{d}>" for d in unique_dates]



In [None]:
# Build the base BERT tokenizer, then register the date markers as special
# tokens so each "<year_...>" marker is kept as a single token.
tokenizer = BertTokenizer.from_pretrained(model_name)
date_tokens = get_date_tokens(dataset)
# The extra_special_tokens key is used for non-standard special tokens, so every date marker goes there.
# NOTE(review): older transformers releases only accept "additional_special_tokens" as a key here — confirm against the installed version.
tokenizer.add_special_tokens({'extra_special_tokens' : date_tokens})

In [None]:
# Prefix every sentence with its date token and build a flat, one-sentence-per-row dataset.

# Character cap applied to each prefixed sentence. NOTE: BERT's 512 limit is
# counted in *tokens*, not characters. This cap only bounds the raw text
# cheaply and does not by itself guarantee a sequence fits the model, so also
# truncate when tokenizing.
MAX_SENTENCE_CHARS = 512

# 1. dict of columns holding the samples
sentenceData = {'text': []}
# 2. iterate articles
for entry in dataset:
    text = entry['content'] # type: ignore
    date = entry['date'] # type: ignore
    # Skip articles with missing/empty content (empty text yields no sentences anyway).
    if not text:
        continue
    # 3. split the article into sentences
    for sentence in sent_tokenize(text):
        # 4. prepend the date token (same "<year_...>" format as get_date_tokens),
        #    cap the length in characters, and store the sample
        sentence = f'<year_{date}> ' + sentence
        sentenceData['text'].append(sentence[:MAX_SENTENCE_CHARS])

# 5. build the dataset from the sample dict (raw text; token ids are added by the tokenization cell)
tokenized_dataset = datasets.Dataset.from_dict(sentenceData)


In [None]:
# Tokenizer function used to map the sentence samples to token ids.
def tokenize_data(examples):
    """Tokenize a batch of samples for masked-language-model training.

    Args:
        examples: a batch dict with a ``"text"`` list, as passed by
            ``Dataset.map(..., batched=True)``.

    Returns:
        The tokenizer's encoding for the batch. Sequences are truncated to the
        tokenizer's maximum length. The upstream cap on sentences is in
        characters, and on its own it does not guarantee a sequence fits BERT's
        512-token position limit. When the tokenizer is a fast tokenizer, a
        ``"word_ids"`` entry is added: one list per example mapping each token
        to its word index (None for special tokens).
    """
    result = tokenizer(examples["text"], truncation=True)
    if tokenizer.is_fast:
        # word_ids() only exists on encodings produced by fast tokenizers.
        result["word_ids"] = [result.word_ids(i) for i in range(len(result["input_ids"]))]
    return result



In [None]:
# Data collator for masked language modeling: each batch has tokens selected
# for masking with probability 0.15 at training time.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)

# Tokenize the sentence dataset in batches of 64 samples.
tokenized_dataset = tokenized_dataset.map(tokenize_data, batch_size=64, batched=True )

# Load the pre-trained model and grow its embedding matrix to cover the added date tokens.
# Seed first: the new embedding rows are initialised randomly, and Trainer only
# seeds later (in its constructor). Without this, those rows differ run to run.
SEED = 42
torch.manual_seed(SEED)
model = BertForMaskedLM.from_pretrained(model_name)
model.resize_token_embeddings(len(tokenizer))

In [None]:
# Train with the Hugging Face Trainer (masked-language-model pre-training).
training_args = TrainingArguments(
    output_dir="./NYT_pretrained_model",
    per_device_train_batch_size=8,
    num_train_epochs=3,
    # checkpoint every 10k training steps, keeping at most the two most recent
    save_total_limit=2,
    save_steps=10_000,
)

# The Trainer combines the model, the tokenized sentences and the MLM collator.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=tokenized_dataset,
)

# Run pre-training.
trainer.train()