In [None]:
import torch
from torch import nn
from transformers import pipeline, DataCollatorForLanguageModeling, Trainer, TrainingArguments, BertTokenizer, BertForMaskedLM
from datasets import load_dataset
import random

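Goal: continue pretraining BERT (`bert-base-uncased`) with masked language modeling on New York Times articles (1987-1995), prefixing each article with a `<year_YYYY>` date token.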

In [None]:
#TEMP BLOCK FOR TESTING DATA SET + TOKENIZER
# Load every split of the NYT 1987-1995 news dataset from the Hugging Face Hub.
# In this notebook `ds["train"]` is used to collect the distinct article dates.
ds = load_dataset(
    "ErikCikalleshi/new_york_times_news_1987_1995",
)



In [None]:
#TEMP BLOCK FOR TESTING DATA SET + TOKENIZER
# Peek at a few training rows, then build one `<year_{date}>` token per distinct date.
print(ds["train"][:10])
# Deduplicate first, THEN sort. `list(set(sorted(...)))` discarded the ordering because
# set iteration order is arbitrary (and, when the dates are strings, changes between
# interpreter runs). That made the ids later assigned to the new tokens non-reproducible.
unique_dates = sorted(set(ds["train"]["date"]))
print(unique_dates)
custom_date_tokens = [f"<year_{d}>" for d in unique_dates]
print(custom_date_tokens)

In [None]:
#tokenizer + model init
# Start from the pretrained uncased BERT-base checkpoint. The year tokens built
# above are kept under `custom_token` so the next cell can register them.
model_name = "bert-base-uncased"
custom_token = custom_date_tokens

tokenizer, model = (
    BertTokenizer.from_pretrained(model_name),
    BertForMaskedLM.from_pretrained(model_name),
)



In [None]:
print(len(tokenizer))  # vocabulary size before adding the year tokens
# 1. Register ALL year tokens in a single call. For `additional_special_tokens`,
#    `add_special_tokens` replaces the tokenizer's existing list of additional special
#    tokens by default. Adding them one at a time in a loop therefore left only the
#    last year token registered as "special" (e.g. in `all_special_ids`, which is what
#    the MLM data collator consults to keep special tokens from being masked).
num_added_tokens = tokenizer.add_special_tokens(
    {'additional_special_tokens': list(custom_token)}
)
print(num_added_tokens, len(tokenizer))  # tokens added, new vocabulary size

# 2. Resize the model embeddings to account for the new tokens
model.resize_token_embeddings(len(tokenizer))

In [None]:
# Load a small sample (the first 1% of the `test` split) for quick iteration.
# The year-token prefix is added with `Dataset.map` in the next cell. A plain
# `for row in data: row['content'] = ...` loop has no effect: the rows yielded while
# iterating a `datasets.Dataset` are fresh dicts, not views into the stored table.
# NOTE(review): the year tokens were built from the `train` split's dates;
# confirm every date in this split has a matching token.
data = load_dataset("ErikCikalleshi/new_york_times_news_1987_1995", split="test[:1%]")


In [None]:
# Helpers for Dataset.map: prefix each article with its date token, then
# tokenize the text into token ids.

def add_date_tokens(examples):
    """Return a copy of one row with its ``content`` prefixed by ``<year_{date}> ``.

    Intended for non-batched ``Dataset.map``: ``examples`` is a single row with
    ``date`` and ``content`` fields. The input row is not modified.
    """
    # Double quotes inside the f-string expression: reusing the outer quote type
    # is a SyntaxError before Python 3.12.
    return {**examples, "content": f"<year_{examples['date']}> " + examples["content"]}


def tokenize_function(examples, max_length=None):
    """Tokenize a batch of article texts for masked-language-model training.

    Args:
        examples: batch dict whose ``content`` entry is a list of strings.
        max_length: truncation length; ``None`` uses the tokenizer's model maximum.

    Returns:
        The tokenizer's encoding, plus ``word_ids`` when the tokenizer is a fast one.
    """
    # Truncate: whole news articles are often longer than bert-base-uncased's
    # 512 position embeddings, which breaks the model's forward pass.
    # Padding is left to the data collator.
    result = tokenizer(examples["content"], truncation=True, max_length=max_length)
    if tokenizer.is_fast:
        # `word_ids` is only available from fast tokenizers.
        result["word_ids"] = [result.word_ids(i) for i in range(len(result["input_ids"]))]
    return result


# Apply to dataset. The 'date' column must survive until add_date_tokens has run,
# so columns are only removed in the tokenization pass.
# The prefixed rows get their own name. Reassigning `data` made this cell
# non-idempotent (every re-run prepended another year token) and would double the
# prefix for later cells that add their own year token to `data`.
dated_data = data.map(add_date_tokens)
tokenized_dataset = dated_data.map(
    tokenize_function, batched=True, remove_columns=dated_data.column_names
)
# Show a summary and the start of the first example instead of dumping every row.
print(tokenized_dataset)
if len(tokenized_dataset) > 0:
    print(tokenizer.decode(tokenized_dataset[0]["input_ids"][:32]))

In [None]:
def _get_next_sentence(sentence, next_sentence, paragraphs):
    if random.random() < 0.5:
        is_next = True
    else:
        # `paragraphs` is a list of lists of lists
        next_sentence = random.choice(random.choice(paragraphs))
        is_next = False
    return sentence, next_sentence, is_next

def _get_nsp_data_from_paragraph(paragraph, paragraphs, vocab, max_len):
    nsp_data_from_paragraph = []
    for i in range(len(paragraph) - 1):
        tokens_a, tokens_b, is_next = _get_next_sentence(
            paragraph[i], paragraph[i + 1], paragraphs)
        # Consider 1 '<cls>' token and 2 '<sep>' tokens
        if len(tokens_a) + len(tokens_b) + 3 > max_len:
            continue
        tokens, segments = d2l.get_tokens_and_segments(tokens_a, tokens_b)
        nsp_data_from_paragraph.append((tokens, segments, is_next))
    return nsp_data_from_paragraph

In [None]:
#use these to set up the missing tokens to bert to predict
def _replace_mlm_tokens(tokens, candidate_pred_positions, num_mlm_preds,
                        vocab):
    # For the input of a masked language model, make a new copy of tokens and
    # replace some of them by '<mask>' or random tokens
    mlm_input_tokens = [token for token in tokens]
    pred_positions_and_labels = []
    # Shuffle for getting 15% random tokens for prediction in the masked
    # language modeling task
    random.shuffle(candidate_pred_positions)
    for mlm_pred_position in candidate_pred_positions:
        if len(pred_positions_and_labels) >= num_mlm_preds:
            break
        masked_token = None
        # 80% of the time: replace the word with the '<mask>' token
        if random.random() < 0.8:
            masked_token = '<mask>'
        else:
            # 10% of the time: keep the word unchanged
            if random.random() < 0.5:
                masked_token = tokens[mlm_pred_position]
            # 10% of the time: replace the word with a random word
            else:
                masked_token = random.choice(vocab.idx_to_token)
        mlm_input_tokens[mlm_pred_position] = masked_token
        pred_positions_and_labels.append(
            (mlm_pred_position, tokens[mlm_pred_position]))
    return mlm_input_tokens, pred_positions_and_labels

def _get_mlm_data_from_tokens(tokens, vocab):
  candidate_pred_positions = []
  # `tokens` is a list of strings
  for i, token in enumerate(tokens):
      # Special tokens are not predicted in the masked language modeling
      # task
      if token in ['<cls>', '<sep>']:
          continue
      candidate_pred_positions.append(i)
  # 15% of random tokens are predicted in the masked language modeling task
  num_mlm_preds = max(1, round(len(tokens) * 0.15))
  mlm_input_tokens, pred_positions_and_labels = _replace_mlm_tokens(
      tokens, candidate_pred_positions, num_mlm_preds, vocab)
  pred_positions_and_labels = sorted(pred_positions_and_labels,
                                      key=lambda x: x[0])
  pred_positions = [v[0] for v in pred_positions_and_labels]
  mlm_pred_labels = [v[1] for v in pred_positions_and_labels]
  return vocab[mlm_input_tokens], pred_positions, vocab[mlm_pred_labels]



https://d2l.ai/chapter_natural-language-processing-pretraining/bert-dataset.html#generating-the-masked-language-modeling-task
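^ Reference: the Dive into Deep Learning (d2l.ai) section on generating the masked language modeling task when building the BERT pretraining dataset.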

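Open question: do we need to implement the masked-language-modeling steps from that article ourselves? Those steps select ~15% of the tokens for prediction. Of the selected tokens, 80% are replaced with the mask token, 10% with a random token, and 10% are left unchanged. Note that `DataCollatorForLanguageModeling`, used in the pretraining cell below, already applies this masking.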

In [None]:
class nyt_dataset(torch.utils.data.Dataset):
    """Sentence-level NYT dataset with a year token prefixed to every sentence.

    Each article's ``content`` is split on '.', and empty or whitespace-only pieces
    are dropped. Every remaining sentence is prefixed with ``<year_{date}>``. All
    sentences are then tokenized once, truncated and padded to ``max_len``.

    Args:
        samples: iterable of rows with ``date`` and ``content`` fields (e.g. a
            Hugging Face ``Dataset``).
        max_len: length every encoded sentence is truncated/padded to.
        tokenizer: Hugging Face tokenizer used for encoding.

    Each item is a dict of one sentence's encodings (``input_ids``,
    ``attention_mask``, ...), a format ``DataCollatorForLanguageModeling`` accepts.
    """

    def __init__(self, samples, max_len, tokenizer):
        self.max_len = max_len
        # One list of prefixed sentences per (non-empty) article.
        self.paragraphs = self.prep_content(samples)
        sentences = [sentence for paragraph in self.paragraphs for sentence in paragraph]
        self.encodings = self.tokenize_samples(tokenizer, sentences)

    def __len__(self):
        return len(self.encodings["input_ids"])

    def __getitem__(self, idx):
        # `word_ids` (fast tokenizers only) holds None entries that the collator
        # cannot turn into a tensor, so it is not part of the items.
        return {key: values[idx] for key, values in self.encodings.items()
                if key != "word_ids"}

    def prep_content(self, examples):
        """Split each article into non-empty sentences, each prefixed with its year token."""
        paragraphs = []
        for example in examples:
            year_token = f"<year_{example['date']}>"
            # Filter the empty pieces per article: the previous `'' in sentences`
            # check compared strings against lists of sentences and never matched.
            sentences = [f"{year_token} {piece.strip()}"
                         for piece in example['content'].split('.')
                         if piece.strip()]
            if sentences:
                paragraphs.append(sentences)
        return paragraphs

    def tokenize_samples(self, tokenizer, samples):
        """Encode a flat list of sentences, truncating/padding each to ``self.max_len``."""
        if not samples:
            return {"input_ids": []}
        # Tokenize the sentences themselves, one sequence each. `str(samples)` would
        # encode the Python repr of the whole list as a single sequence.
        result = tokenizer(samples, truncation=True, max_length=self.max_len,
                           padding='max_length')
        if tokenizer.is_fast:
            result["word_ids"] = [result.word_ids(i) for i in range(len(result["input_ids"]))]
        return result

# Sentence-level dataset built from `data`; 512 is bert-base-uncased's maximum input length.
nyt_sentence_dataset = nyt_dataset(data, 512, tokenizer)
nyt_sentence_dataset

In [None]:
#pretraining functions
def _get_batch_loss_bert(net, loss, vocab_size, tokens_X,
                         segments_X, valid_lens_x,
                         pred_positions_X, mlm_weights_X,
                         mlm_Y, nsp_y):
    # Forward pass
    _, mlm_Y_hat, nsp_Y_hat = net(tokens_X, segments_X,
                                  valid_lens_x.reshape(-1),
                                  pred_positions_X)
    # Compute masked language model loss
    mlm_l = loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1)) *\
    mlm_weights_X.reshape(-1, 1)
    mlm_l = mlm_l.sum() / (mlm_weights_X.sum() + 1e-8)
    # Compute next sentence prediction loss
    nsp_l = loss(nsp_Y_hat, nsp_y)
    l = mlm_l + nsp_l
    return mlm_l, nsp_l, l


def train_bert(train_iter, net, loss, vocab_size, devices, num_steps):
    net(*next(iter(train_iter))[:4])
    net = nn.DataParallel(net, device_ids=devices).to(devices[0])
    trainer = torch.optim.Adam(net.parameters(), lr=0.01)
    step = 0

    num_steps_reached = False
    while step < num_steps and not num_steps_reached:
        for tokens_X, segments_X, valid_lens_x, pred_positions_X,\
            mlm_weights_X, mlm_Y, nsp_y in train_iter:
            tokens_X = tokens_X.to(devices[0])
            segments_X = segments_X.to(devices[0])
            valid_lens_x = valid_lens_x.to(devices[0])
            pred_positions_X = pred_positions_X.to(devices[0])
            mlm_weights_X = mlm_weights_X.to(devices[0])
            mlm_Y, nsp_y = mlm_Y.to(devices[0]), nsp_y.to(devices[0])
            trainer.zero_grad()
            mlm_l, nsp_l, l = _get_batch_loss_bert(
                net, loss, vocab_size, tokens_X, segments_X, valid_lens_x,
                pred_positions_X, mlm_weights_X, mlm_Y, nsp_y)
            l.backward()
            trainer.step()

            step += 1
            if step == num_steps:
                num_steps_reached = True
                break



In [None]:
#pretraining
# Continued MLM pretraining with the Hugging Face Trainer. The collator pads each
# batch and selects 15% of the tokens (mlm_probability) for the model to predict.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True,
    mlm_probability=0.15
)

training_args = TrainingArguments(
    output_dir="./content/Training Data",
    num_train_epochs=3,
    remove_unused_columns=True,
    per_device_train_batch_size=8,
    save_steps=10_000,
    save_total_limit=2,
    logging_steps=500,
    learning_rate=5e-5,
    weight_decay=0.01,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=tokenized_dataset,
)

# Start pretraining
train_result = trainer.train()

# Save the final model AND the tokenizer explicitly. Intermediate checkpoints are
# written only every `save_steps` (10,000) steps, which a short run may never reach.
# The resized embedding matrix is also only usable together with the tokenizer
# that defines the added year tokens.
trainer.save_model(training_args.output_dir)
tokenizer.save_pretrained(training_args.output_dir)
train_result