<a href="https://colab.research.google.com/github/Servat0r/HLT-Project-2023/blob/master/utils.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Common utilities for the HLT project.

In [None]:
!pip install --quiet "transformers[sentencepiece]" "transformers[torch]" datasets evaluate openai python-dotenv bert_score rouge_score bert_score

Collecting transformers[sentencepiece]
  Downloading transformers-4.32.0-py3-none-any.whl (7.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.5/7.5 MB[0m [31m52.3 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.15.1 (from transformers[sentencepiece])
  Downloading huggingface_hub-0.16.4-py3-none-any.whl (268 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m268.8/268.8 kB[0m [31m27.0 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1 (from transformers[sentencepiece])
  Downloading tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m103.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting safetensors>=0.3.1 (from transformers[sentencepiece])
  Downloading safetensors-0.3.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

### Dataset Loading and Preprocessing

In [None]:
import os
def get_dataset(hf_url, local_train_path, local_eval_path, local_test_path):
  """Return the cached train/eval/test splits, or the raw dataset when any cache is missing.

  Args:
    hf_url: identifier handed to `load_dataset` when the local copies are not all present.
    local_train_path, local_eval_path, local_test_path: paths checked with `os.path.exists`
      and, if all three exist, read with `load_from_disk`.

  Returns:
    `{'local': True, 'train': ..., 'eval': ..., 'test': ...}` when every path exists,
    otherwise `{'local': False, 'all': <result of load_dataset(hf_url)>}`.
  """
  local_paths = (local_train_path, local_eval_path, local_test_path)
  if not all(os.path.exists(path) for path in local_paths):
    return {'local': False, 'all': load_dataset(hf_url)}
  train_split, eval_split, test_split = [load_from_disk(path) for path in local_paths]
  return {'local': True, 'train': train_split, 'eval': eval_split, 'test': test_split}

In [None]:
def get_maximum_labels_length(dataset, question_column='question'):
  """Return the largest token count among the texts in `dataset[question_column]`.

  Token counts come from the module-level `tokenizer.tokenize`; an empty column
  makes `max` raise `ValueError`.
  """
  return max([len(tokenizer.tokenize(text)) for text in dataset[question_column]])

In [None]:
MAX_INPUTS_LENGTH = 512
MAX_LABELS_LENGTH = 64
def tokenizer_function(
    samples, max_inputs_length=MAX_INPUTS_LENGTH, max_labels_length=MAX_LABELS_LENGTH,
    input_ids_padding=True, train_dataset=None, ignore_index_id=-100,
):
  """Tokenize a batch of prompts and questions with the module-level `tokenizer`.

  Args:
    samples: batch with an 'answer_context' column (model inputs) and a 'question' column (labels).
    max_inputs_length: truncation length for the inputs.
    max_labels_length: padding/truncation length for the labels; replaced by
      `get_maximum_labels_length(train_dataset)` when `train_dataset` is truthy.
    input_ids_padding: `padding` argument forwarded to the tokenizer for the inputs.
    train_dataset: optional dataset used to derive the label length (see above).
    ignore_index_id: value written over padded label positions; 0 disables the overwrite.

  Returns:
    The tokenizer output for the inputs, with the processed label ids stored under 'labels'.
  """
  if train_dataset:
    max_labels_length = get_maximum_labels_length(train_dataset)
  encoded_inputs = tokenizer(
      samples['answer_context'], padding=input_ids_padding, max_length=max_inputs_length,
      truncation=True, return_tensors='pt',
  )
  encoded_labels = tokenizer(
      samples['question'], padding="max_length", max_length=max_labels_length,
      truncation=True, return_tensors='pt',
  )
  label_ids = encoded_labels['input_ids']
  attention = encoded_labels['attention_mask']
  # Index of the smallest mask value per row; for a row that contains padding this
  # is where the zeros start, for a row without padding the value found there is 1.
  first_min_positions = torch.argmin(attention, dim=-1)
  for row, position in enumerate(first_min_positions):
    if ignore_index_id == 0 or attention[row][position] != 0:
      continue
    # Everything from the first padded position onwards gets the ignore value.
    label_ids[row][position:] = ignore_index_id
  encoded_inputs['labels'] = label_ids
  return encoded_inputs

78


In [None]:
def build_answers_squad_it(sample, new_dataset):
  """Append one (answer, question, context) row to `new_dataset` per distinct answer text.

  Args:
    sample: mapping with 'question', 'context' and 'answers' (whose 'text' entry is a
      sequence of answer strings).
    new_dataset: dict of lists with keys 'answer', 'question' and 'context'; extended in place.

  Duplicated answer texts are collapsed, keeping first-occurrence order. A plain `set`
  would order strings by their per-process hash, which makes the resulting row order
  (and any seeded shuffle applied afterwards) differ from run to run.
  """
  unique_answers = list(dict.fromkeys(sample['answers']['text']))
  for answer_text in unique_answers:
    new_dataset['answer'].append(answer_text)
    new_dataset['question'].append(sample['question'])
    new_dataset['context'].append(sample['context'])

In [None]:
def build_train_feature(sample, use_extra_ids=False, context_label='paragraph_answer'):
  """Build the model prompt for one sample.

  Args:
    sample: mapping with an 'answer' entry and a `context_label` entry.
    use_extra_ids: accepted but not used by this function.
    context_label: key of `sample` that holds the passage text.

  Returns:
    A dict with the single key 'answer_context' holding the prompt string.
  """
  answer = sample['answer']
  context = sample[context_label]
  prompt = f"generate questions: <answer> {answer} <answer> <context> {context} <context>"
  return {'answer_context': prompt}

In [None]:
def build_train_feature_squad_it(sample, use_extra_ids=False):
  """Build the prompt for a sample whose passage is stored under the 'context' key."""
  feature = build_train_feature(sample, context_label='context', use_extra_ids=use_extra_ids)
  return feature

In [None]:
def load_and_preprocess_squad_it_dataset(
    dataset_name='squad_it', train_dataset_name='squad_it_qg_train',
    eval_dataset_name='squad_it_qg_eval', test_dataset_name='squad_it_qg_test',
    shuffle_seed=None, train_select=None, eval_select=None, train_dataset_split=0.8,
    use_extra_ids=False,
):
  """Load the SQuAD-it data, build question-generation prompts and tokenize every split.

  Args:
    dataset_name: identifier passed to `get_dataset` for the download path.
    train_dataset_name, eval_dataset_name, test_dataset_name: on-disk locations of the
      prepared splits; read when all exist, written otherwise.
    shuffle_seed: seed for the train/validation shuffle and for the optional sub-selection.
    train_select, eval_select: if truthy, number of rows kept from train / validation.
    train_dataset_split: a value in [0, 1] is the fraction of the development data used
      for training; any other value is taken as an absolute row count.
    use_extra_ids: forwarded to the prompt builder.

  Returns:
    `((train, validation, test), (tokenized_train, tokenized_validation, tokenized_test))`.
  """
  dataset_loading_result = get_dataset(dataset_name, train_dataset_name, eval_dataset_name, test_dataset_name)
  if dataset_loading_result['local']:
    train_dataset = dataset_loading_result['train']
    validation_dataset = dataset_loading_result['eval']
    test_dataset = dataset_loading_result['test']
  else:
    # get_dataset already downloaded the raw data; reuse it instead of loading it again.
    datasets = dataset_loading_result['all']
    dev_dataset = datasets['train'].remove_columns(['id'])
    test_dataset = datasets['test'].remove_columns(['id'])

    # map() is used only for its side effect: the callback fills the dict of lists
    # with one row per distinct answer.
    new_test_dataset = {'answer': [], 'question': [], 'context': []}
    test_dataset.map(lambda sample: build_answers_squad_it(sample, new_test_dataset))
    test_dataset = Dataset.from_dict(new_test_dataset)

    new_dev_dataset = {'answer': [], 'question': [], 'context': []}
    dev_dataset.map(lambda sample: build_answers_squad_it(sample, new_dev_dataset))
    dev_dataset = Dataset.from_dict(new_dev_dataset)
    dev_dataset_length = len(dev_dataset)

    if 0 <= train_dataset_split <= 1:
      train_dataset_length = int(train_dataset_split * dev_dataset_length) + 1
    else:
      train_dataset_length = int(train_dataset_split)
    # Keep the boundary inside the dataset (a split of 1.0 would otherwise ask for one row too many).
    train_dataset_length = max(0, min(train_dataset_length, dev_dataset_length))

    # Shuffle once and cut the same permutation in two: with an unseeded shuffle, two
    # separate shuffles would produce overlapping train and validation rows.
    dev_dataset = dev_dataset.shuffle(seed=shuffle_seed)
    train_dataset = dev_dataset.select(range(train_dataset_length))
    validation_dataset = dev_dataset.select(range(train_dataset_length, dev_dataset_length))
    print(f"Train dataset has {len(train_dataset)} items. Validation dataset has {len(validation_dataset)} items.")

    train_dataset.save_to_disk(train_dataset_name)
    validation_dataset.save_to_disk(eval_dataset_name)
    test_dataset.save_to_disk(test_dataset_name)

  if train_select:
    train_dataset = train_dataset.shuffle(seed=shuffle_seed).select(range(train_select))
  if eval_select:
    validation_dataset = validation_dataset.shuffle(seed=shuffle_seed).select(range(eval_select))

  build_feature = lambda sample: build_train_feature_squad_it(sample, use_extra_ids=use_extra_ids)
  train_dataset = train_dataset.map(build_feature).remove_columns(['answer', 'context'])
  validation_dataset = validation_dataset.map(build_feature).remove_columns(['answer', 'context'])
  test_dataset = test_dataset.map(build_feature).remove_columns(['answer', 'context'])

  tokenized_train_dataset = train_dataset.map(tokenizer_function, batched=True).remove_columns(['answer_context', 'question'])
  tokenized_validation_dataset = validation_dataset.map(tokenizer_function, batched=True).remove_columns(['answer_context', 'question'])
  # Test labels are padded to the longest training question. Compute that length once
  # here rather than once per batch inside tokenizer_function.
  test_labels_length = get_maximum_labels_length(train_dataset) if len(train_dataset) > 0 else MAX_LABELS_LENGTH
  tokenized_test_dataset = test_dataset.map(
      lambda samples: tokenizer_function(samples, input_ids_padding="max_length", max_labels_length=test_labels_length),
      batched=True,
  ).remove_columns(['answer_context'])

  tokenized_train_dataset.set_format("torch")
  tokenized_validation_dataset.set_format("torch")
  tokenized_test_dataset.set_format("torch")

  return (train_dataset, validation_dataset, test_dataset), (tokenized_train_dataset, tokenized_validation_dataset, tokenized_test_dataset)

In [None]:
def build_train_feature_lmqg_squad(sample, use_extra_ids=False):
  """Build the prompt for a sample whose passage is stored under the 'paragraph' key."""
  feature = build_train_feature(sample, context_label='paragraph', use_extra_ids=use_extra_ids)
  return feature

In [None]:
def load_and_preprocess_lmqg_squad_dataset(
    dataset_name='lmqg/qg_squad', train_dataset_name='lmqg_qg_squad_train',
    eval_dataset_name='lmqg_qg_squad_eval', test_dataset_name='lmqg_qg_squad_test',
    shuffle_seed=None, train_select=None, eval_select=None, use_extra_ids=False,
):
  """Load the lmqg/qg_squad data (plain 'paragraph' context), build prompts and tokenize.

  Args:
    dataset_name: identifier passed to `get_dataset` for the download path.
    train_dataset_name, eval_dataset_name, test_dataset_name: on-disk locations of the
      prepared splits; read when all exist, written otherwise.
    shuffle_seed: seed used when sub-selecting rows; `None` keeps the historical seed 0.
    train_select, eval_select: if truthy, number of rows kept from train / validation.
    use_extra_ids: forwarded to the prompt builder.

  Returns:
    `((train, validation, test), (tokenized_train, tokenized_validation, tokenized_test))`.
  """
  dataset_loading_result = get_dataset(dataset_name, train_dataset_name, eval_dataset_name, test_dataset_name)
  if dataset_loading_result['local']:
    train_dataset = dataset_loading_result['train']
    validation_dataset = dataset_loading_result['eval']
    test_dataset = dataset_loading_result['test']
  else:
    # get_dataset already downloaded the raw data; reuse it instead of loading it again.
    datasets = dataset_loading_result['all']
    unused_columns = ['paragraph_question', 'sentence', 'sentence_answer', 'paragraph_answer', 'paragraph_sentence']
    train_dataset = datasets['train'].remove_columns(unused_columns)
    validation_dataset = datasets['validation'].remove_columns(unused_columns)
    test_dataset = datasets['test'].remove_columns(unused_columns)
    print(f"Train dataset has {len(train_dataset)} items. Validation dataset has {len(validation_dataset)} items.")

    train_dataset.save_to_disk(train_dataset_name)
    validation_dataset.save_to_disk(eval_dataset_name)
    test_dataset.save_to_disk(test_dataset_name)

  # Honour the caller's seed; fall back to 0 so the default selection is unchanged.
  selection_seed = 0 if shuffle_seed is None else shuffle_seed
  if train_select:
    train_dataset = train_dataset.shuffle(seed=selection_seed).select(range(train_select))
  if eval_select:
    validation_dataset = validation_dataset.shuffle(seed=selection_seed).select(range(eval_select))

  build_feature = lambda sample: build_train_feature_lmqg_squad(sample, use_extra_ids=use_extra_ids)
  train_dataset = train_dataset.map(build_feature).remove_columns(['answer', 'paragraph'])
  validation_dataset = validation_dataset.map(build_feature).remove_columns(['answer', 'paragraph'])
  test_dataset = test_dataset.map(build_feature).remove_columns(['answer', 'paragraph'])

  tokenized_train_dataset = train_dataset.map(tokenizer_function, batched=True).remove_columns(['answer_context', 'question'])
  tokenized_validation_dataset = validation_dataset.map(tokenizer_function, batched=True).remove_columns(['answer_context', 'question'])
  # The tokenized test split keeps its 'question' column, unlike train and validation.
  tokenized_test_dataset = test_dataset.map(lambda samples: tokenizer_function(samples, input_ids_padding="max_length"), batched=True).remove_columns(['answer_context'])

  tokenized_train_dataset.set_format("torch")
  tokenized_validation_dataset.set_format("torch")
  tokenized_test_dataset.set_format("torch")

  return (train_dataset, validation_dataset, test_dataset), (tokenized_train_dataset, tokenized_validation_dataset, tokenized_test_dataset)

In [None]:
def build_train_feature_lmqg_squad_highlighting(sample, use_extra_ids=False):
  """Build the prompt for a sample whose passage is stored under the 'paragraph_answer' key."""
  feature = build_train_feature(sample, context_label='paragraph_answer', use_extra_ids=use_extra_ids)
  return feature

In [None]:
def load_and_preprocess_lmqg_squad_dataset_highlighting(
    dataset_name='lmqg/qg_squad', train_dataset_name='lmqg_qg_squad_highlighting_train',
    eval_dataset_name='lmqg_qg_squad_highlighting_eval', test_dataset_name='lmqg_qg_squad_highlighting_test',
    shuffle_seed=None, train_select=None, eval_select=None, use_extra_ids=False,
):
  """Load the lmqg/qg_squad data using the 'paragraph_answer' column as context, then tokenize.

  Args:
    dataset_name: identifier passed to `get_dataset` for the download path.
    train_dataset_name, eval_dataset_name, test_dataset_name: on-disk locations of the
      prepared splits; read when all exist, written otherwise.
    shuffle_seed: seed used when sub-selecting rows; `None` keeps the historical seed 0.
    train_select, eval_select: if truthy, number of rows kept from train / validation.
    use_extra_ids: forwarded to the prompt builder.

  Returns:
    `((train, validation, test), (tokenized_train, tokenized_validation, tokenized_test))`.
  """
  dataset_loading_result = get_dataset(dataset_name, train_dataset_name, eval_dataset_name, test_dataset_name)
  if dataset_loading_result['local']:
    train_dataset = dataset_loading_result['train']
    validation_dataset = dataset_loading_result['eval']
    test_dataset = dataset_loading_result['test']
  else:
    # get_dataset already downloaded the raw data; reuse it instead of loading it again.
    datasets = dataset_loading_result['all']
    unused_columns = ['paragraph_question', 'sentence', 'sentence_answer', 'paragraph', 'paragraph_sentence']
    train_dataset = datasets['train'].remove_columns(unused_columns)
    validation_dataset = datasets['validation'].remove_columns(unused_columns)
    test_dataset = datasets['test'].remove_columns(unused_columns)
    print(f"Train dataset has {len(train_dataset)} items. Validation dataset has {len(validation_dataset)} items.")

    train_dataset.save_to_disk(train_dataset_name)
    validation_dataset.save_to_disk(eval_dataset_name)
    test_dataset.save_to_disk(test_dataset_name)

  # Honour the caller's seed; fall back to 0 so the default selection is unchanged.
  selection_seed = 0 if shuffle_seed is None else shuffle_seed
  if train_select:
    train_dataset = train_dataset.shuffle(seed=selection_seed).select(range(train_select))
  if eval_select:
    validation_dataset = validation_dataset.shuffle(seed=selection_seed).select(range(eval_select))

  build_feature = lambda sample: build_train_feature_lmqg_squad_highlighting(sample, use_extra_ids=use_extra_ids)
  train_dataset = train_dataset.map(build_feature).remove_columns(['answer', 'paragraph_answer'])
  validation_dataset = validation_dataset.map(build_feature).remove_columns(['answer', 'paragraph_answer'])
  test_dataset = test_dataset.map(build_feature).remove_columns(['answer', 'paragraph_answer'])

  # Train/validation labels are padded to the longest training question. Measure it
  # once here: passing train_dataset to tokenizer_function would re-tokenize the whole
  # training set for every batch.
  labels_length = get_maximum_labels_length(train_dataset) if len(train_dataset) > 0 else MAX_LABELS_LENGTH
  tokenize_with_train_length = lambda samples: tokenizer_function(samples, max_labels_length=labels_length)
  tokenized_train_dataset = train_dataset.map(tokenize_with_train_length, batched=True).remove_columns(['answer_context', 'question'])
  tokenized_validation_dataset = validation_dataset.map(tokenize_with_train_length, batched=True).remove_columns(['answer_context', 'question'])
  # The test split uses tokenizer_function's default label length, not the training-derived one.
  tokenized_test_dataset = test_dataset.map(lambda samples: tokenizer_function(samples, input_ids_padding="max_length"), batched=True).remove_columns(['answer_context', 'question'])

  tokenized_train_dataset.set_format("torch")
  tokenized_validation_dataset.set_format("torch")
  tokenized_test_dataset.set_format("torch")

  return (train_dataset, validation_dataset, test_dataset), (tokenized_train_dataset, tokenized_validation_dataset, tokenized_test_dataset)

In [None]:
def load_and_preprocess_squad_qg_dataset(
    dataset_name='derek-thomas/squad-v1.1-t5-question-generation', train_dataset_name='squad_qg_train',
    eval_dataset_name='squad_qg_eval', test_dataset_name='squad_qg_test', shuffle_seed=None,
    train_select=None, eval_select=None, use_extra_ids=False, eval_split=0.3,
):
  """Load the SQuAD question-generation data, carve out a validation split and tokenize.

  Args:
    dataset_name: identifier passed to `get_dataset` for the download path.
    train_dataset_name, eval_dataset_name, test_dataset_name: on-disk locations of the
      prepared splits; read when all exist, written otherwise.
    shuffle_seed: seed for the train/validation shuffle; also used for the optional
      sub-selection, where `None` keeps the historical seed 0.
    train_select, eval_select: if truthy, number of rows kept from train / validation.
    use_extra_ids: accepted but not used by this function.
    eval_split: accepted but not used; the validation fraction is fixed at 0.2 below.
      NOTE(review): wiring it in would change the default split from 0.2 to 0.3, so
      it is left untouched here - decide which value is intended.

  Returns:
    `((train, validation, test), (tokenized_train, tokenized_validation, tokenized_test))`.

  NOTE(review): `tokenizer_function` reads an 'answer_context' column, which this
  function never creates - confirm the dataset provides it.
  """
  dataset_loading_result = get_dataset(dataset_name, train_dataset_name, eval_dataset_name, test_dataset_name)
  if dataset_loading_result['local']:
    train_dataset = dataset_loading_result['train']
    # The cached validation split was previously never read, so the code below
    # failed with NameError whenever the data came from disk.
    validation_dataset = dataset_loading_result['eval']
    test_dataset = dataset_loading_result['test']
  else:
    # get_dataset already downloaded the raw data; reuse it instead of loading it again.
    datasets = dataset_loading_result['all']
    dev_dataset, test_dataset = datasets['train'], datasets['validation']
    print(f"Dev dataset has {len(dev_dataset)} items. Test dataset has {len(test_dataset)} items.")

    eval_length = int(0.2 * len(dev_dataset))
    train_length = len(dev_dataset) - eval_length

    dev_dataset = dev_dataset.shuffle(seed=shuffle_seed)
    train_dataset = dev_dataset.select(range(train_length))
    validation_dataset = dev_dataset.select(range(train_length, train_length + eval_length))

    train_dataset.save_to_disk(train_dataset_name)
    validation_dataset.save_to_disk(eval_dataset_name)
    test_dataset.save_to_disk(test_dataset_name)

  # Honour the caller's seed; fall back to 0 so the default selection is unchanged.
  selection_seed = 0 if shuffle_seed is None else shuffle_seed
  if train_select:
    train_dataset = train_dataset.shuffle(seed=selection_seed).select(range(train_select))
  if eval_select:
    validation_dataset = validation_dataset.shuffle(seed=selection_seed).select(range(eval_select))

  tokenized_train_dataset = train_dataset.map(tokenizer_function, batched=True).remove_columns(['context', 'question'])
  tokenized_validation_dataset = validation_dataset.map(tokenizer_function, batched=True).remove_columns(['context', 'question'])
  tokenized_test_dataset = test_dataset.map(lambda samples: tokenizer_function(samples, input_ids_padding="max_length"), batched=True).remove_columns(['context', 'question'])

  tokenized_train_dataset.set_format("torch")
  tokenized_validation_dataset.set_format("torch")
  tokenized_test_dataset.set_format("torch")

  return (train_dataset, validation_dataset, test_dataset), (tokenized_train_dataset, tokenized_validation_dataset, tokenized_test_dataset)

### Metric loading, configuration and computation

In [None]:
def std_conversion_predictions(data, tokenizer):
  """Decode a batch of predicted token ids into strings, dropping special tokens."""
  decoded_texts = tokenizer.batch_decode(data, skip_special_tokens=True)
  return decoded_texts

def std_conversion_references(data, tokenizer):
  """Decode a batch of label ids into single-reference lists of strings.

  Args:
    data: array or tensor of label ids in which -100 marks ignored positions.
    tokenizer: object exposing `pad_token_id` and `batch_decode`.

  Returns:
    One `[text]` list per row, as expected by metrics that accept several references
    per prediction.

  The -100 entries are replaced by the pad id on a copy, so the caller's labels
  (which may still be needed for the loss) are left untouched.
  """
  data = data.clone() if hasattr(data, 'clone') else data.copy()
  data[data == -100] = tokenizer.pad_token_id
  return [[reference] for reference in tokenizer.batch_decode(data, skip_special_tokens=True)]

def get_bleu_config(tokenizer):
  """Return the `(metric, predictions_conversion, references_conversion)` triple for BLEU.

  `tokenizer` is accepted but not used; both conversions are the 'text' marker.
  """
  bleu_metric = evaluate.load('bleu')
  return (bleu_metric, 'text', 'text')

def get_nist_config(tokenizer):
  """Return the `(metric, predictions_conversion, references_conversion)` triple for NIST.

  `tokenizer` is accepted but not used; both conversions are the 'text' marker.
  """
  nist_metric = evaluate.load('nist_mt')
  return (nist_metric, 'text', 'text')

def get_rouge_config(tokenizer):
  """Return the `(metric, predictions_conversion, references_conversion)` triple for ROUGE.

  `tokenizer` is accepted but not used; both conversions are the 'text' marker.
  """
  rouge_metric = evaluate.load('rouge')
  return (rouge_metric, 'text', 'text')

### Training Configuration

In [None]:
from torch.utils.data import DataLoader
from transformers import get_scheduler

def get_training_configuration(
    *, optimizer='adam', learning_rate=1e-3, train_collate_fn=None,
    eval_collate_fn=None, tokenizer=None, train_batch_size=8, eval_batch_size=8,
    num_epochs=3, lr_scheduler='linear', num_warmup_steps=0,
):
  """Build the optimizer, the dataloaders and the learning-rate scheduler for training.

  Relies on notebook-level globals that must exist before the call: `model`,
  `tokenized_train_dataset`, `tokenized_validation_dataset`, `AdamW` and
  `DataCollatorWithPadding`.

  Args:
    optimizer: 'adam' to create `AdamW(model.parameters(), lr=learning_rate)`, or an
      already-built optimizer object, which is used as is.
    learning_rate: learning rate for the 'adam' option.
    train_collate_fn: collate function for the training dataloader; defaults to a
      `DataCollatorWithPadding` built from `tokenizer`.
    eval_collate_fn: collate function for the evaluation dataloader; defaults to the
      training collate function.
    tokenizer: tokenizer handed to the default collator.
    train_batch_size, eval_batch_size: dataloader batch sizes.
    num_epochs: used to compute the total number of training steps.
    lr_scheduler: 'linear' to create a linear schedule, or an already-built scheduler.
    num_warmup_steps: warm-up steps for the 'linear' schedule.

  Returns:
    `(optimizer, train_dataloader, eval_dataloader, lr_scheduler, num_training_steps)`.

  Raises:
    ValueError: for an unrecognised optimizer or scheduler name.
  """
  # Handle optimizer
  if optimizer == 'adam':
    optimizer = AdamW(model.parameters(), lr=learning_rate)
  elif isinstance(optimizer, str):  # Allow for other optimizer objects defined outside to be used
    raise ValueError(f"Unknown optimizer: '{optimizer}'")
  # Handle collators
  data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
  train_collate_fn = train_collate_fn if train_collate_fn else data_collator
  # An explicit eval collator wins; otherwise reuse the training one, which is what
  # every caller got before (the argument used to be discarded unconditionally).
  eval_collate_fn = eval_collate_fn if eval_collate_fn else train_collate_fn
  # Dataloaders
  train_dataloader = DataLoader(
      tokenized_train_dataset, shuffle=True, batch_size=train_batch_size, collate_fn=train_collate_fn)
  eval_dataloader = DataLoader(
      tokenized_validation_dataset, batch_size=eval_batch_size, collate_fn=eval_collate_fn)
  # Handle scheduling
  num_training_steps = num_epochs * len(train_dataloader)
  print(num_training_steps)
  if lr_scheduler == 'linear':
    lr_scheduler = get_scheduler(
        "linear",
        optimizer=optimizer,
        num_warmup_steps=num_warmup_steps,  # was hard-coded to 0, silently ignoring the argument
        num_training_steps=num_training_steps,
    )
  elif isinstance(lr_scheduler, str):  # As before for optimizers
    raise ValueError(f"Unknown scheduler: '{lr_scheduler}'")
  return optimizer, train_dataloader, eval_dataloader, lr_scheduler, num_training_steps


### Fine-tuning

In [None]:
import numpy as np
from evaluate import load
bertscore = load('bertscore')
def bertscore_f1based_score(grouped_predictions, references, verbose=False, lang='en'):
  """Score every candidate of every input with BERTScore F1.

  Args:
    grouped_predictions: one list of candidate strings per input; all groups are
      expected to have the same number of candidates.
    references: one reference string per input.
    verbose: print the candidates and references sent to the metric.
    lang: language code forwarded to the metric (defaults to the previous fixed 'en').

  Returns:
    A float64 array of shape `(len(grouped_predictions), candidates_per_input)`;
    `(0, 0)` when there are no inputs.
  """
  m = len(grouped_predictions)
  if m == 0:
    return np.zeros((0, 0), dtype=np.float64)
  n = len(grouped_predictions[0])
  scores = np.zeros((m, n), dtype=np.float64)
  for i in range(n):
    # The i-th candidate of every input, as plain Python strings.
    current_predictions = [str(group[i]) for group in grouped_predictions]
    if verbose:
      print("Current Predictions:", current_predictions, "Current References:", references, '\n', sep='\n')
    current_scores = bertscore.compute(predictions=current_predictions, references=references, lang=lang)
    scores[:, i] = current_scores['f1']
  return scores

In [None]:
import numpy as np
def select_best_output(
    model, tokenizer, input_ids, references, score_function, max_length, num_beams, top_k, top_p, num_candidates,
    verbose=False, tokenize_output=False,
):
  """Generate several candidates per input and keep the best-scoring one.

  Args:
    model: object exposing `eval()` and `generate(...)`.
    tokenizer: used to decode the candidates (and to re-tokenize when `tokenize_output`).
    input_ids: batch of encoded inputs passed to `model.generate`.
    references: one reference text per input, forwarded to `score_function`.
    score_function: called as `score_function(grouped_predictions, references, verbose=...)`;
      must return one row of scores per input.
    max_length, num_beams, top_k, top_p: forwarded to `model.generate`.
      NOTE(review): no `do_sample` flag is passed, so whether `top_k` / `top_p` have
      any effect depends on the model's generation defaults - confirm.
    num_candidates: passed as `num_return_sequences`, i.e. candidates produced per input.
    verbose: print the score matrix.
    tokenize_output: return the selected texts re-tokenized as padded 'pt' tensors.

  Returns:
    The list of selected strings, or the tokenizer output for them.
  """
  model.eval()
  with torch.no_grad():
    predictions = model.generate(input_ids, max_length=max_length, num_beams=num_beams, top_k=top_k, top_p=top_p, num_return_sequences=num_candidates)
  decoded_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True)
  # generate() returns `num_return_sequences` consecutive rows per input, so that
  # (not the beam width) is the size of each group.
  group_size = num_candidates or 1
  grouped_predictions = [
      decoded_predictions[start:start + group_size]
      for start in range(0, len(decoded_predictions), group_size)
  ]
  scores = score_function(grouped_predictions, references, verbose=verbose)
  if verbose:
    print(scores)
  best_indices = np.argmax(scores, axis=-1)
  final_predictions = [item_predictions[best_index] for item_predictions, best_index in zip(grouped_predictions, best_indices)]
  if tokenize_output:
    final_predictions = tokenizer(final_predictions, padding=True, max_length=512, truncation=True, return_tensors='pt')
  return final_predictions

In [None]:
from tqdm.auto import tqdm
import torch

def evaluation_loop(
    model, device, optimizer, eval_dataloader, lr_scheduler,
    loss_tracker, metrics_tracker=None, metrics=None, progress_bar=None,
    tokenizer=None, num_beams=1, top_k=None, top_p=None, num_candidates=4,
    score_function=bertscore_f1based_score, tokenize_predictions_output=True,
):
    """Run one pass over the evaluation data, recording the loss and optional metrics.

    Args:
        model: model called as `model(**batch)`; its output must expose `.loss`.
        device: device the batch tensors are moved to.
        optimizer, lr_scheduler: accepted for signature symmetry; not used here.
        eval_dataloader: iterable of dict batches.
        loss_tracker: list that receives the evaluation loss of this pass.
        metrics_tracker: list that receives the dict of computed metrics (needed when
            `metrics` is given).
        metrics: mapping `name -> (metric, predictions_conversion, references_conversion)`.
            A conversion is either a callable or a marker string: 'text' means decoded
            strings, and 'id' (references only) means the raw label ids.
        progress_bar: optional bar advanced once per batch.
        tokenizer: required only when `metrics` is given.
        num_beams, top_k, top_p, num_candidates, score_function: forwarded to
            `select_best_output`.
        tokenize_predictions_output: ask `select_best_output` for tokenized output.

    Returns:
        The mean loss over all evaluation batches (0 when the dataloader is empty).
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    for batch in eval_dataloader:
        # Only tensors can be moved to the device; text columns are left as they are.
        batch = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
        # Text columns cannot be fed to the model.
        model_inputs = {k: v for k, v in batch.items() if isinstance(v, torch.Tensor)}
        with torch.no_grad():
            outputs = model(**model_inputs)
        # Accumulate so the reported loss covers the whole split, not just the last batch.
        total_loss += outputs.loss.item()
        num_batches += 1
        if metrics:
            # Reference texts are needed only for metrics, so the tokenizer is not
            # required for a loss-only evaluation.
            if 'question' in batch:
                text_references = batch['question']
            else:
                labels_batch = batch['labels'].clone().detach()
                labels_batch[labels_batch == -100] = tokenizer.pad_token_id
                text_references = tokenizer.batch_decode(labels_batch, skip_special_tokens=True)
            predictions = select_best_output(
                model, tokenizer, batch['input_ids'], text_references, score_function, max_length=100, num_beams=num_beams,
                top_k=top_k, top_p=top_p, num_candidates=num_candidates, verbose=False, tokenize_output=tokenize_predictions_output,
            )
            for metric_name, (metric, conversion_function_predictions, conversion_function_references) in metrics.items():
                if not tokenize_predictions_output:
                    # select_best_output already returned plain strings.
                    converted_predictions = predictions
                elif conversion_function_predictions == 'text':
                    # Tokenized output: decode its ids back to text with the tokenizer.
                    converted_predictions = std_conversion_predictions(predictions['input_ids'], tokenizer)
                else:
                    converted_predictions = conversion_function_predictions(predictions)
                if conversion_function_references == 'text':
                    references = text_references
                elif conversion_function_references == 'id':
                    references = batch["labels"]
                else:
                    references = conversion_function_references(batch["labels"])
                metric.add_batch(predictions=converted_predictions, references=references)
        if progress_bar:
            progress_bar.update(1)
    current_loss = total_loss / num_batches if num_batches else 0
    loss_tracker.append(current_loss)
    if metrics:
        metrics_tracker.append({
            metric_name: metric.compute() for metric_name, (metric, _, _) in metrics.items()
        })
        print(f"Metrics = {metrics_tracker[-1]}")
    return current_loss


def main_training_loop(
    model, device, optimizer, train_dataloader, eval_dataloader,
    lr_scheduler, num_training_steps, num_epochs, metrics=None,
    eval_strategy='epoch', eval_every=1000, model_save_path='model',
    early_stopping=False, early_stopping_min_delta=1e-3, early_stopping_patience=5,
    tokenizer=None, num_beams=1, top_k=None, top_p=None, num_candidates=4,
    score_function=bertscore_f1based_score, tokenize_predictions_output=True,
    start_epoch=0,
):
  """Fine-tune `model`, evaluating periodically and prompting to save / continue.

  Args:
    model, device, optimizer, lr_scheduler: training objects; batches are moved to `device`.
    train_dataloader, eval_dataloader: iterables of dict batches of tensors.
    num_training_steps: total used for the training progress bar.
    num_epochs: number of epochs to run, starting at `start_epoch`.
    metrics: forwarded to `evaluation_loop` when metrics are evaluated.
    eval_strategy: 'epoch', 'steps', or 'no'/falsy. An evaluation pass always runs at
      the end of every epoch to record the loss; metrics are computed there only for
      'epoch', and every `eval_every` steps only for 'steps'.
    eval_every: step interval for the 'steps' strategy.
    model_save_path: prefix for `model.save_pretrained` (suffix `_step{n}` / `_epoch{n}`).
    early_stopping: stop when the evaluation loss has not improved by at least
      `early_stopping_min_delta` for `early_stopping_patience` consecutive evaluations.
      The best checkpoint is tracked regardless of this flag.
    tokenizer, num_beams, top_k, top_p, num_candidates, score_function,
    tokenize_predictions_output: forwarded to `evaluation_loop`.
    start_epoch: number of the first epoch (for resumed runs).

  Returns:
    Dict with the recorded losses and metrics, the best early-stopping checkpoint and
    its type, and - when at least one epoch was scheduled - 'epoch', the last value of
    the epoch counter.

  Raises:
    ValueError: for an unknown `eval_strategy`.
  """
  if eval_strategy == 'epoch':
    num_evaluation_steps = num_epochs * len(eval_dataloader)
  elif eval_strategy == 'steps':
    # One pass every `eval_every` steps plus the pass at the end of each epoch.
    num_evaluation_steps = (num_training_steps // eval_every + num_epochs) * len(eval_dataloader)
  elif (not eval_strategy) or (eval_strategy == 'no'):
    num_evaluation_steps = 0
  else:
    raise ValueError(f"Unknown evaluation strategy: '{eval_strategy}'")
  _flag_evaluate_steps = eval_strategy == 'steps'
  _flag_evaluate_epochs = eval_strategy == 'epoch'

  training_progress_bar = tqdm(range(num_training_steps))
  evaluation_progress_bar = tqdm(range(num_evaluation_steps)) if num_evaluation_steps > 0 else None

  batch_train_losses = []
  epoch_train_losses = []
  epoch_eval_losses = []
  steps_eval_losses = []
  epoch_eval_metrics = []
  steps_eval_metrics = []
  step = 1
  # Early Stopping parameters
  early_stopping_state = {
      'best_value': float('inf'),
      'best_checkpoint': 0,
      'best_type': None,
      'patience': 0,
  }

  stop_training = False
  for epoch in range(start_epoch, num_epochs + start_epoch):
    if stop_training:
      break
    model.train()
    current_loss = 0
    batch_train_losses.append([])
    for batch in train_dataloader:
      if stop_training:
        break
      batch = {k: v.to(device) for k, v in batch.items()}
      outputs = model(**batch)
      loss = outputs.loss
      current_loss = loss.item()
      batch_train_losses[-1].append(current_loss)
      loss.backward()

      optimizer.step()
      lr_scheduler.step()
      optimizer.zero_grad()
      training_progress_bar.update(1)
      if _flag_evaluate_steps and step % eval_every == 0:
        eval_loss = evaluation_loop(
            model, device, optimizer, eval_dataloader, lr_scheduler,
            steps_eval_losses, steps_eval_metrics, metrics, evaluation_progress_bar,
            tokenizer=tokenizer, num_beams=num_beams, top_k=top_k, top_p=top_p, num_candidates=num_candidates,
            score_function=score_function, tokenize_predictions_output=tokenize_predictions_output,
        )
        if metrics:
          print(f"Step {step}: Train Loss = {current_loss}, Eval Loss = {eval_loss}, Metrics = {steps_eval_metrics[-1]}")
        else:
          print(f"Step {step}: Train Loss = {current_loss}, Eval Loss = {eval_loss}")
        if _ask_yes_no('Save this model (y/n)?> ', default=False):
          model.save_pretrained(f'{model_save_path}_step{step}')
        continue_training = _ask_yes_no('Continue training (y/n)?> ', default=True)
        patience_exhausted = _update_early_stopping(
            early_stopping_state, eval_loss, step, 'step',
            early_stopping_min_delta, early_stopping_patience,
        )
        stop_training = (not continue_training) or (early_stopping and patience_exhausted)
        model.train()
      step += 1
    if stop_training:
      break
    if torch.cuda.is_available():
      # Wait for pending GPU work before reading the epoch figures; skipped on CPU-only hosts.
      torch.cuda.synchronize()
    epoch_train_losses.append(current_loss)
    # Even when the strategy is not 'epoch', run an evaluation pass to record the epoch loss
    epoch_eval_metrics_now = epoch_eval_metrics if _flag_evaluate_epochs else None
    metrics_now = metrics if _flag_evaluate_epochs else None
    eval_loss = evaluation_loop(
        model, device, optimizer, eval_dataloader, lr_scheduler,
        epoch_eval_losses, epoch_eval_metrics_now, metrics_now, evaluation_progress_bar,
        tokenizer=tokenizer, num_beams=num_beams, top_k=top_k, top_p=top_p, num_candidates=num_candidates,
        score_function=score_function, tokenize_predictions_output=tokenize_predictions_output,
    )
    if metrics_now:
      print(f"Epoch {epoch}: Train Loss = {current_loss}, Eval Loss = {epoch_eval_losses[-1]}, Metrics = {epoch_eval_metrics[-1]}")
    else:
      print(f"Epoch {epoch}: Train Loss = {current_loss}, Eval Loss = {epoch_eval_losses[-1]}")
    patience_exhausted = _update_early_stopping(
        early_stopping_state, eval_loss, epoch, 'epoch',
        early_stopping_min_delta, early_stopping_patience,
    )
    if _ask_yes_no('Save this model (y/n)?> ', default=False):
      model.save_pretrained(f'{model_save_path}_epoch{epoch}')
    continue_training = _ask_yes_no('Continue training (y/n)?> ', default=True)
    # Either signal stops the run; the prompt answer must not cancel an early stop.
    stop_training = (not continue_training) or (early_stopping and patience_exhausted)
  model.eval()
  results = {
      'batch_train_losses': batch_train_losses,
      'epoch_train_losses': epoch_train_losses,
      'epoch_eval_losses': epoch_eval_losses,
      'steps_eval_losses': steps_eval_losses,
      'epoch_eval_metrics': epoch_eval_metrics,
      'steps_eval_metrics': steps_eval_metrics,
      'best_early_stopping_checkpoint': early_stopping_state['best_checkpoint'],
      'best_early_stopping_type': early_stopping_state['best_type'],
  }
  if num_epochs > 0:
    # The epoch counter exists only if the loop header ran at least once.
    results['epoch'] = epoch
  return results


def _ask_yes_no(prompt, default):
  """Ask a y/n question; return `default` when no interactive input is available."""
  try:
    return input(prompt) == 'y'
  except Exception:  # e.g. EOFError or a front end without stdin; Ctrl-C still propagates
    return default


def _update_early_stopping(state, eval_loss, checkpoint, checkpoint_type, min_delta, patience):
  """Record `eval_loss` in `state`; return True once `patience` evaluations passed without improvement."""
  if eval_loss + min_delta <= state['best_value']:
    state['best_value'] = eval_loss
    state['best_checkpoint'] = checkpoint
    state['best_type'] = checkpoint_type
    state['patience'] = 0
    return False
  state['patience'] += 1
  return state['patience'] >= patience

### Checkpointing and resuming

In [None]:
# Taken from: https://discuss.pytorch.org/t/moving-optimizer-from-cpu-to-gpu/96068/2
def optimizer_to(optim, device):
    """Move every tensor held in `optim.state` (and its gradient, if any) to `device`.

    State entries may be tensors or dicts of tensors; anything else is skipped.
    The tensors are updated in place through their `.data` attribute.
    """
    def _relocate(tensor):
        tensor.data = tensor.data.to(device)
        grad = tensor._grad
        if grad is not None:
            grad.data = grad.data.to(device)

    for entry in optim.state.values():
        # Not sure there are any global tensors in the state dict
        if isinstance(entry, torch.Tensor):
            _relocate(entry)
        elif isinstance(entry, dict):
            for value in entry.values():
                if isinstance(value, torch.Tensor):
                    _relocate(value)

In [None]:
def save_checkpoint(checkpoint_path, model_path, optimizer, lr_scheduler, num_training_steps, model=None, save_model=False):
  """Write the training state needed to resume a run to `checkpoint_path` with `torch.save`.

  The checkpoint holds the model path, the optimizer and scheduler state dicts and
  'num_training_steps', computed as `num_training_steps - _step_count + 1` from the
  scheduler state. When `save_model` is true the `model` object is stored as well.
  """
  optimizer_state = optimizer.state_dict()
  scheduler_state = lr_scheduler.state_dict()
  remaining_steps = num_training_steps - scheduler_state['_step_count'] + 1
  payload = dict(
      model_path=model_path,
      optimizer=optimizer_state,
      lr_scheduler=scheduler_state,
      num_training_steps=remaining_steps,
  )
  if save_model:
    payload['model'] = model
  torch.save(payload, checkpoint_path)

In [None]:
def load_checkpoint(checkpoint_path, device, optimizer, lr_scheduler, model_class, new_model_path=None):
  """Restore model, optimizer and scheduler state from a checkpoint file.

  Args:
    checkpoint_path: file to read with `torch.load`.
    device: device the model and the optimizer state are moved to.
    optimizer: optimizer whose state is replaced by the checkpointed one.
    lr_scheduler: scheduler whose bookkeeping attributes are overwritten.
    model_class: class providing `from_pretrained`.
    new_model_path: if given, overrides the model path stored in the checkpoint.

  Returns:
    `(model, optimizer, lr_scheduler, num_training_steps)`.

  NOTE(review): the returned model is a new object, while `optimizer` was built by
  the caller - presumably around a different model instance. Confirm the caller
  re-binds the optimizer to the returned model's parameters.
  """
  # Deserialize onto the CPU so a checkpoint written on a GPU machine can be read
  # anywhere; the tensors are moved to `device` further down.
  checkpoint = torch.load(checkpoint_path, map_location='cpu')
  if new_model_path:
    checkpoint['model_path'] = new_model_path
  model = model_class.from_pretrained(checkpoint['model_path'], local_files_only=True)
  model.to(device)
  optimizer.load_state_dict(checkpoint['optimizer'])
  num_training_steps = checkpoint['num_training_steps']
  # Copy only the scheduler's bookkeeping fields, attribute by attribute.
  lr_scheduler_dict = checkpoint['lr_scheduler']
  lr_scheduler.base_lrs = lr_scheduler_dict['base_lrs']
  lr_scheduler.last_epoch = lr_scheduler_dict['last_epoch']
  lr_scheduler._step_count = lr_scheduler_dict['_step_count']
  lr_scheduler._get_lr_called_within_step = lr_scheduler_dict['_get_lr_called_within_step']
  lr_scheduler._last_lr = lr_scheduler_dict['_last_lr']
  # Normalise first: comparing a device string such as 'cpu' with torch.device('cpu')
  # directly reports them as different.
  if torch.device(device).type != 'cpu':
    optimizer_to(optimizer, device)
  return model, optimizer, lr_scheduler, num_training_steps

### Metrics

In [None]:
from evaluate import load
accuracy = load('accuracy')
test_loss_tracker=[]
test_metrics_tracker=[]
def accuracy_conversion_function(predictions):
  """Turn per-class scores into predicted class indices (argmax over the last dimension)."""
  return torch.argmax(predictions, dim=-1)

In [None]:
from sklearn.metrics import top_k_accuracy_score
def top_2_accuracy(y_score, y_true):
  """Top-2 accuracy over the 13 class labels 0..12 (note the scores-first argument order)."""
  return top_k_accuracy_score(y_true, y_score, k=2, labels=np.arange(13))


def top_3_accuracy(y_score, y_true):
  """Top-3 accuracy over the 13 class labels 0..12 (note the scores-first argument order)."""
  return top_k_accuracy_score(y_true, y_score, k=3, labels=np.arange(13))

In [None]:
def top_k_accuracy_conversion_function_predictions(logits):
  """Convert logits to softmax probabilities as a detached CPU copy."""
  detached_copy = torch.softmax(logits, dim=-1).clone().detach()
  return detached_copy.cpu()

def top_k_accuracy_conversion_function_references(references):
  """Return the reference labels on the CPU."""
  # NOTE(review): the original author marked this conversion with "????" - confirm
  # that moving to the CPU is all the metric needs.
  cpu_references = references.cpu()
  return cpu_references

In [None]:
class TopKAccuracyMetric:
  """Accumulates top-k accuracy batch by batch, with an `add_batch` / `compute` interface.

  Per-batch scores are kept in `batches`. `compute` averages them weighted by batch
  size, so a smaller final batch does not count as much as a full one.
  """

  def __init__(self, k, num_classes):
    self.k = k
    self.num_classes = num_classes
    # Scores come first to mirror the (predictions, references) order of add_batch.
    self.score_function = lambda y_score, y_true: top_k_accuracy_score(y_true, y_score, k=k, labels=np.arange(num_classes))
    self.batches = []
    self._batch_sizes = []  # number of references behind each entry of `batches`

  def add_batch(self, predictions, references):
    """Score one batch and remember both the score and the batch size."""
    score = self.score_function(predictions, references)
    self.batches.append(score)
    self._batch_sizes.append(len(references))

  def compute(self, *args, **kwargs):
    """Return the size-weighted mean of the recorded scores and reset the metric.

    Returns NaN when no batch has been added.
    """
    if not self.batches:
      return float('nan')
    if sum(self._batch_sizes) > 0:
      result = np.average(self.batches, weights=self._batch_sizes)
    else:
      result = np.mean(self.batches)
    self.clear()
    return result

  def clear(self):
    """Forget every recorded batch."""
    self.batches.clear()
    self._batch_sizes.clear()

In [None]:
def get_default_test_metrics(num_classes, include_top2=True):
  """Return the default metric table: accuracy, plus top-2 accuracy unless disabled.

  Each value is a `(metric, predictions_conversion, references_conversion)` triple;
  the accuracy entry uses the 'id' marker for its references.
  """
  top2_metric = TopKAccuracyMetric(k=2, num_classes=num_classes)
  metrics_table = {'accuracy': (accuracy, accuracy_conversion_function, 'id')}
  if not include_top2:
    return metrics_table
  metrics_table['top_2_accuracy'] = (
      top2_metric,
      top_k_accuracy_conversion_function_predictions,
      top_k_accuracy_conversion_function_references,
  )
  return metrics_table

In [None]:
import numpy as np
def compute_bert_score(text_dataset, tokenized_dataset, model, device, tokenizer, batch_size=32, lang='en', model_type=None, max_length=200, num_beams=4, num_candidates=1):
  """Generate questions for a dataset and summarise their BERTScore against the references.

  Args:
    text_dataset: dataset whose 'question' column holds the reference texts.
    tokenized_dataset: dataset whose 'input_ids' column holds the encoded inputs, in
      the same row order as `text_dataset`.
    model: generation model; put in eval mode and called through `generate`.
    device: device the input ids are moved to.
    tokenizer: used to decode the generated ids.
    batch_size: number of rows generated and scored at a time.
    lang, model_type: forwarded to the BERTScore metric.
    max_length, num_beams: forwarded to `model.generate`.
    num_candidates: sequences generated per input (`num_return_sequences`); every
      candidate is scored against the reference of its own input.

  Returns:
    Dict mapping 'precision', 'recall' and 'f1' to `(mean, std)` computed over the
    per-batch mean scores.
  """
  from evaluate import load
  from tqdm.auto import tqdm
  bert_score = load('bertscore')
  dataset_length = len(text_dataset)
  num_batches = dataset_length // batch_size + (dataset_length % batch_size != 0)
  candidates_per_input = num_candidates or 1
  # Read each column once instead of re-reading it for every batch.
  reference_column = text_dataset['question']
  input_ids_column = tokenized_dataset['input_ids']
  model.eval()
  results = {
      'precision': [],
      'recall': [],
      'f1': [],
  }
  with tqdm(total=num_batches) as progress_bar:
    for start in range(0, dataset_length, batch_size):
      end = min(start + batch_size, dataset_length)
      batch_input_ids = input_ids_column[start:end].to(device)
      with torch.no_grad():
        predictions = model.generate(batch_input_ids, max_length=max_length, num_beams=num_beams, num_return_sequences=num_candidates)
      decoded_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True)
      # generate() returns the candidates of one input consecutively, so repeat each
      # reference to keep predictions and references aligned one-to-one.
      batch_references = [
          reference
          for reference in reference_column[start:end]
          for _ in range(candidates_per_input)
      ]
      bert_score_results = bert_score.compute(predictions=decoded_predictions, references=batch_references, lang=lang, model_type=model_type)
      for metric in ['precision', 'recall', 'f1']:
        results[metric].append(np.mean(bert_score_results[metric]).item())
      progress_bar.update(1)
  return {k: (np.mean(v), np.std(v)) for k, v in results.items()}