In [None]:
from transformers import BertTokenizer, AutoModelForMaskedLM, DataCollatorForLanguageModeling, Trainer, TrainingArguments
from datasets import load_dataset
import datasets
import torch
import re
from nltk.tokenize import sent_tokenize
import nltk
import accelerate

# Fetch the NLTK 'punkt_tab' resource up front; presumably this is the data
# that sent_tokenize (used further down) relies on — confirm for your NLTK version.
nltk.download("punkt_tab")

In [None]:
# Only the first 1% of the 'test' split is loaded, which keeps the run small.
dataset = load_dataset(
    "ErikCikalleshi/new_york_times_news_1987_1995",
    split="test[:1%]",
)

In [None]:
# Build one special token per distinct value of the 'date' column.
# De-duplicate first and sort last: iterating a set yields an arbitrary order,
# so sorting afterwards is what keeps the order in which tokens are added
# (and therefore their ids) stable from one kernel run to the next.
unique_dates = sorted(set(dataset['date']))
custom_date_tokens = [f"<year_{d}>" for d in unique_dates]
model_name = "bert-base-uncased"
custom_token = custom_date_tokens  # alias of the same list object
tokenizer = BertTokenizer.from_pretrained(model_name)
# Register the date tokens so each "<year_...>" prefix is kept as a single token.
# NOTE(review): the accepted dictionary key depends on the installed transformers
# release (some releases expect 'additional_special_tokens') — confirm.
tokenizer.add_special_tokens({'extra_special_tokens' : custom_date_tokens})


In [None]:
# Split every article into sentences and prefix each sentence with the date
# token of its article. Each prefixed sentence is capped at 512 characters
# (characters, not tokens).
year_tagged_sentences = []
for entry in dataset:
    date = entry['date'] # type: ignore
    text = entry['content'] # type: ignore
    prefix = f'<year_{date}> '
    for sentence in sent_tokenize(text):
        year_tagged_sentences.append((prefix + sentence)[:512])

sentenceData = {'text': year_tagged_sentences}
tokenized_dataset = datasets.Dataset.from_dict(sentenceData)


In [None]:
def tokenize_data(examples):
    """Tokenize the "text" field of a batch with the module-level ``tokenizer``.

    Args:
        examples: Mapping with a "text" entry holding the strings to encode.

    Returns:
        The tokenizer's output. When the tokenizer reports ``is_fast``, a
        "word_ids" entry is added with one ``word_ids(i)`` result per encoded
        row; otherwise the output is returned unchanged.
    """
    encoded = tokenizer(examples["text"])
    if not tokenizer.is_fast:
        # word_ids() is only queried when the tokenizer is a fast one.
        return encoded

    row_count = len(encoded["input_ids"])
    encoded["word_ids"] = [encoded.word_ids(row) for row in range(row_count)]
    return encoded



In [None]:
# Data collator for Masked Language Modeling (masks 15% of the tokens per batch)
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)

# Tokenize the sentence dataset in batches of 64 examples.
tokenized_dataset = tokenized_dataset.map(tokenize_data, batch_size=64, batched=True)

# Load the pre-trained model from the same checkpoint as the tokenizer.
model = AutoModelForMaskedLM.from_pretrained(model_name)
# Size the embedding matrix with len(tokenizer), not tokenizer.vocab_size:
# vocab_size only counts the base vocabulary, so the added "<year_...>" tokens
# would otherwise get ids that fall outside the embedding table.
model.resize_token_embeddings(len(tokenizer))

In [None]:
# Training configuration: 3 epochs, batches of 8 per device, and a checkpoint
# written to ./domain_pretrained_model every 10,000 steps (limit of 2 kept).
training_args = TrainingArguments(
    output_dir="./domain_pretrained_model",
    per_device_train_batch_size=8,
    num_train_epochs=3,
    save_total_limit=2,
    save_steps=10_000,
)

# Wire the model, the tokenized sentences and the MLM collator together.
trainer = Trainer(
    args=training_args,
    model=model,
    data_collator=data_collator,
    train_dataset=tokenized_dataset,
)

# Run the masked-language-model training.
trainer.train()