# Setup

In [None]:
# Mount Google Drive at /content/drive; later cells use paths under drive/MyDrive/... relative to /content
from google.colab import drive
drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [None]:
!pip install wandb --quiet
!pip install datasets transformers accelerate --quiet
!pip install intervaltree stanza textstat --quiet

In [None]:
# Base imports
import csv, datetime, functools, getpass, itertools, json, math, os, random, re, shutil
import numpy as np
import xml.etree.ElementTree as ET
import wandb
# Transformer model imports
import datasets, torch, transformers
from transformers import AutoTokenizer, DataCollatorForWholeWordMask, TrainingArguments, Trainer, AutoModelForMaskedLM, EarlyStoppingCallback
# NER imports
import nltk
import intervaltree, stanza, textstat
from intervaltree import Interval, IntervalTree

In [None]:
# W&B credentials come from the environment (e.g. a Colab secret exported as WANDB_API_KEY)
# or an interactive prompt. They must never be hard-coded in the notebook.
# NOTE(review): the key previously hard-coded here remains in the notebook's history; revoke it in W&B.
WANDB_API_KEY = os.environ.get('WANDB_API_KEY') or getpass.getpass('W&B API key: ')
wandb.login(key=WANDB_API_KEY)

[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [None]:
# Pretrained checkpoint to fine-tune; its last path segment names the output folder
MODEL_NAME = 'emilyalsentzer/Bio_ClinicalBERT'
MODEL_TITLE = MODEL_NAME.rsplit('/', 1)[-1]

# Locations on the mounted Google Drive (relative to the Colab working directory)
REPO_PATH = 'drive/MyDrive/LiboMsc'
MODEL_PATH = '/'.join([REPO_PATH, MODEL_TITLE])
TRAIN_DATASET_PATH = '/'.join([REPO_PATH, 'data', 'i2b2_2024_T1_train'])

# Number of tokens per training/validation chunk built by group_texts
CHUNK_SIZE = 256
# 'cuda' enables fp16 mixed precision during fine-tuning
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Fraction of each POS category masked in the validation split.
# PHI annotations are always masked (ratio 1.0); MED_* and any annotation type not listed
# here are never masked (ratio 0.0).
TEST_MASK_RATIOS = dict.fromkeys(('NOUN', 'VERB', 'ADJ'), 0.70)

In [None]:
# Create this model's output folder on Drive (including any missing parent folders) on first run.
# isdir (not exists): if a *file* occupies the path, makedirs raises now instead of a later save failing.
if not os.path.isdir(MODEL_PATH):
    os.makedirs(MODEL_PATH, exist_ok=True)
    print('Folder created.')

# Dataset

### Load dataset

In [None]:
def parse_xml_file(file_path):
  ''' Read one XML note and return {'note_id': ..., 'text': ...}.

  note_id is the file name without its extension. text is the content of the root's
  <TEXT> child, or '' when that element is empty. An empty note is returned as '' rather
  than None so downstream writers and annotators receive a string.

  Raises:
    ValueError: if the root element has no <TEXT> child.
  '''
  root = ET.parse(file_path).getroot()
  text_node = root.find('TEXT')
  if text_node is None:
    raise ValueError(f'{file_path}: no <TEXT> element found')
  return {
    'note_id': os.path.splitext(os.path.basename(file_path))[0],
    'text':    text_node.text or '',
  }


def load_xml_folder(folder_path):
  ''' Parse every '*.xml' file directly inside `folder_path` with parse_xml_file.

  Files are visited in sorted name order. os.listdir order is arbitrary and
  filesystem-dependent, and the seeded train/validation split that follows is only
  reproducible if the input order is fixed.

  Returns a list of {'note_id', 'text'} dicts.
  '''
  return [
    parse_xml_file(os.path.join(folder_path, filename))
    for filename in sorted(os.listdir(folder_path))
    if filename.endswith('.xml')
  ]

In [None]:
# Every training note as a {'note_id', 'text'} record
note_records = load_xml_folder(TRAIN_DATASET_PATH)
raw_train_dataset = datasets.Dataset.from_list(note_records)

# Hold out 20% of the notes for validation; the held-out split is stored under the 'test' key
train_dataset = raw_train_dataset.train_test_split(test_size=0.2, seed=42)

### Extract annotations (PHIs, NER, and POS tags)

#### Prepare PHIs

In [None]:
# Download Philter and dependencies
!git clone https://github.com/BCHSI/philter-deidstable1_mirror.git ./philter/src
%cd ./philter/src
!git checkout v1.2024.1
!pip install -r requirements.txt
%cd ../..

nltk.download('averaged_perceptron_tagger', quiet=True)

In [None]:
# Write every note to ./philter/data/<note_id>.txt, the input folder for deidpipe.py.
# Both Philter folders are recreated empty first so files from a previous run are not picked up.
for philter_dir in ('./philter/results', './philter/data'):
  shutil.rmtree(philter_dir, ignore_errors=True)
  os.makedirs(philter_dir)

for letter in raw_train_dataset:
  with open(f"./philter/data/{letter['note_id']}.txt", 'w', encoding='utf-8') as f:
    f.write(letter['text'])

In [None]:
# Extract PHIs from each letter
%cd philter/src
!python3 deidpipe.py -i ../data/ -o ../results/ -f configs/philter_one2024.json -d False
%cd ../..

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Successfully parsed: 7154 dates.
Failed to parse: 174 dates.
/content


In [None]:
# Read the PHI spans Philter found for each note
@functools.lru_cache(maxsize=1)
def _load_phi_marked(path, mtime):
  ''' Parse Philter's phi_marked.json once.

  `mtime` is only part of the cache key, so re-running Philter (which rewrites the file)
  invalidates the cache.
  '''
  with open(path, 'r') as f:
    return json.load(f)


def get_phi(sample, phi_path='./philter/results/log/phi_marked.json'):
  ''' Return the PHI spans Philter recorded for `sample`'s note, relabelled for this notebook.

  Philter's output is keyed by input file path ('../data/<note_id>.txt', relative to
  philter/src where it ran). Each span is returned as a new dict that keeps its other fields
  (e.g. 'start'/'end' offsets), with these changes:
    - 'type' becomes 'label', prefixed with 'PHI_'
    - 'word' becomes 'text'
    - 'context' is dropped
  The parsed JSON is cached across calls instead of being re-read for every sample.

  Raises:
    KeyError: if the note is missing from Philter's output.
  '''
  phi_marked = _load_phi_marked(phi_path, os.path.getmtime(phi_path))
  spans = []
  for p in phi_marked[f'../data/{sample["note_id"]}.txt']:
    # Build a fresh dict: the parsed JSON is shared through the cache and must not be mutated
    span = {k: v for k, v in p.items() if k not in ('type', 'word', 'context')}
    span['label'] = f'PHI_{p["type"]}'
    span['text'] = p['word']
    spans.append(span)
  return spans

#### Prepare NER and POS

In [None]:
# Clinical stanza pipeline: 'mimic' package models plus the i2b2 NER model.
# The models are downloaded explicitly before the pipeline is built.
STANZA_LANG, STANZA_PACKAGE = 'en', 'mimic'
stanza.download(STANZA_LANG, package=STANZA_PACKAGE, processors=dict(ner='i2b2'))
stza_detector = stanza.Pipeline(STANZA_LANG, package=STANZA_PACKAGE, processors=dict(ner='i2b2'))

Downloading https://raw.githubusercontent.com/stanfordnlp/stanza-resources/main/resources_1.8.0.json:   0%|   …

INFO:stanza:Downloaded file to /root/stanza_resources/resources.json
INFO:stanza:Downloading these customized packages for language: en (English)...
| Processor       | Package        |
------------------------------------
| tokenize        | mimic          |
| pos             | mimic_charlm   |
| lemma           | mimic_nocharlm |
| depparse        | mimic_charlm   |
| ner             | i2b2           |
| backward_charlm | mimic          |
| pretrain        | mimic          |
| forward_charlm  | mimic          |



Downloading https://huggingface.co/stanfordnlp/stanza-en/resolve/v1.8.0/models/tokenize/mimic.pt:   0%|       …

INFO:stanza:Downloaded file to /root/stanza_resources/en/tokenize/mimic.pt


Downloading https://huggingface.co/stanfordnlp/stanza-en/resolve/v1.8.0/models/pos/mimic_charlm.pt:   0%|     …

INFO:stanza:Downloaded file to /root/stanza_resources/en/pos/mimic_charlm.pt


Downloading https://huggingface.co/stanfordnlp/stanza-en/resolve/v1.8.0/models/lemma/mimic_nocharlm.pt:   0%| …

INFO:stanza:Downloaded file to /root/stanza_resources/en/lemma/mimic_nocharlm.pt


Downloading https://huggingface.co/stanfordnlp/stanza-en/resolve/v1.8.0/models/depparse/mimic_charlm.pt:   0%|…

INFO:stanza:Downloaded file to /root/stanza_resources/en/depparse/mimic_charlm.pt


Downloading https://huggingface.co/stanfordnlp/stanza-en/resolve/v1.8.0/models/ner/i2b2.pt:   0%|          | 0…

INFO:stanza:Downloaded file to /root/stanza_resources/en/ner/i2b2.pt


Downloading https://huggingface.co/stanfordnlp/stanza-en/resolve/v1.8.0/models/backward_charlm/mimic.pt:   0%|…

INFO:stanza:Downloaded file to /root/stanza_resources/en/backward_charlm/mimic.pt


Downloading https://huggingface.co/stanfordnlp/stanza-en/resolve/v1.8.0/models/pretrain/mimic.pt:   0%|       …

INFO:stanza:Downloaded file to /root/stanza_resources/en/pretrain/mimic.pt


Downloading https://huggingface.co/stanfordnlp/stanza-en/resolve/v1.8.0/models/forward_charlm/mimic.pt:   0%| …

INFO:stanza:Downloaded file to /root/stanza_resources/en/forward_charlm/mimic.pt
INFO:stanza:Finished downloading models and saved to /root/stanza_resources
INFO:stanza:Checking for updates to resources.json in case models have been updated.  Note: this behavior can be turned off with download_method=None or download_method=DownloadMethod.REUSE_RESOURCES


Downloading https://raw.githubusercontent.com/stanfordnlp/stanza-resources/main/resources_1.8.0.json:   0%|   …

INFO:stanza:Downloaded file to /root/stanza_resources/resources.json
INFO:stanza:Loading these models for language: en (English):
| Processor | Package        |
------------------------------
| tokenize  | mimic          |
| pos       | mimic_charlm   |
| lemma     | mimic_nocharlm |
| depparse  | mimic_charlm   |
| ner       | i2b2           |

INFO:stanza:Using device: cuda
INFO:stanza:Loading: tokenize
  checkpoint = torch.load(filename, lambda storage, loc: storage)
INFO:stanza:Loading: pos
  checkpoint = torch.load(filename, lambda storage, loc: storage)
  data = torch.load(self.filename, lambda storage, loc: storage)
  state = torch.load(filename, lambda storage, loc: storage)
INFO:stanza:Loading: lemma
  checkpoint = torch.load(filename, lambda storage, loc: storage)
INFO:stanza:Loading: depparse
  checkpoint = torch.load(filename, lambda storage, loc: storage)
INFO:stanza:Loading: ner
  checkpoint = torch.load(filename, lambda storage, loc: storage)
INFO:stanza:Done loading pro

In [None]:
def _split_into_single_words(anno):
  ''' Yield one annotation per maximal run of word characters in anno['text'].

  Word characters are alphanumerics and the ASCII apostrophe; every other character
  (spaces, hyphens, slashes, ...) separates runs and is dropped. Offsets are shifted by
  anno['start'], and each piece inherits anno['label'].
  '''
  is_word_char = lambda ch: ch.isalnum() or ch == '\''
  offset = 0
  for keep, run in itertools.groupby(anno['text'], key=is_word_char):
    run_length = sum(1 for _ in run)
    if keep:
      yield {
        'start': anno['start'] + offset,
        'end': anno['start'] + offset + run_length,
        'text': anno['text'][offset:offset + run_length],
        'label': anno['label'],
      }
    offset += run_length


def get_all_annotations(text, phi):
  ''' Annotate `text` at word level with PHI, medical-entity and POS labels.

  Precedence:
    1. PHI spans from `phi` (as returned by get_phi).
    2. Stanza NER entities labelled 'MED_<type>', kept only if they overlap no PHI span.
    3. Stanza POS tags (word.upos) for words overlapping neither of the above.
  Every annotation except NUM and PUNCT is then split into single words
  (see _split_into_single_words).

  Returns a list of dicts with keys start, end, text, label, sorted by start.
  '''
  doc = stza_detector(text)

  def hits(tree, lo, hi):
    return bool(tree.overlap(lo, hi))

  phi_tree = IntervalTree.from_tuples([(span['start'], span['end']) for span in phi])

  # Medical entities are kept only where they do not clash with a PHI (PHI wins)
  med_entities = [
    dict(start=ent.start_char, end=ent.end_char, text=ent.text, label=f'MED_{ent.type}')
    for ent in doc.entities
    if not hits(phi_tree, ent.start_char, ent.end_char)
  ]
  protected = [*phi, *med_entities]

  # Remaining words get their POS tag. PHI spans are inside `protected`, so one tree covers both.
  protected_tree = IntervalTree.from_tuples([(a['start'], a['end']) for a in protected])
  pos_words = [
    dict(start=word.start_char, end=word.end_char, text=word.text, label=word.upos)
    for sentence in doc.sentences
    for word in sentence.words
    if not hits(protected_tree, word.start_char, word.end_char)
  ]

  by_position = sorted(protected + pos_words, key=lambda a: a['start'])

  single_words = []
  for anno in by_position:
    if anno['label'] in ('NUM', 'PUNCT'):
      # Numbers and punctuation are kept verbatim, not split
      single_words.append(anno)
    else:
      single_words.extend(_split_into_single_words(anno))

  return sorted(single_words, key=lambda a: a['start'])

#### Extraction

In [None]:
def extract_all_annotations(sample):
    ''' Store every word-level annotation (PHI, MED_*, POS) of the note in sample['annotations'] and return the sample. '''
    phi_spans = get_phi(sample)
    sample['annotations'] = get_all_annotations(sample['text'], phi_spans)
    return sample


def extract_annotations_to_mask_only(sample, mask_ratios, seed=55):
    '''
    Annotate `sample` and keep only the annotations that should be masked at evaluation time.

    Selection rules:
      - PHI_* annotations are always kept (mask ratio 1.0).
      - MED_* annotations and any type absent from `mask_ratios` are dropped (ratio 0.0).
      - For every other type T, int(len(T annotations) * mask_ratios[T]) of them are sampled.

    Sampling uses a private random.Random(seed) per type. Results are therefore reproducible
    without reseeding the global `random` module, which the masking collators draw from via
    random.shuffle.

    Returns `sample` with 'annotations' set to the kept annotations, sorted by 'start'.
    '''
    phis = get_phi(sample)
    all_annotations = get_all_annotations(sample['text'], phis)

    # Group annotations: all PHI_* together, all MED_* together, other labels (POS tags) by exact label
    annotations = {'PHI': [], 'MED': []}
    for anno in all_annotations:
        label = anno['label']
        if label.startswith('PHI'):
            group = 'PHI'
        elif label.startswith('MED'):
            group = 'MED'
        else:
            group = label
        annotations.setdefault(group, []).append(anno)

    # Select entities to mask based on the given ratios
    entities_to_mask = []
    for anno_type, annos in annotations.items():
        if anno_type == 'PHI':
            entities_to_mask.extend(annos)
        elif anno_type == 'MED' or anno_type not in mask_ratios:
            continue
        else:
            # A fresh generator per type reproduces the former `random.seed(seed); random.sample(...)`
            rng = random.Random(seed)
            entities_to_mask.extend(rng.sample(annos, int(len(annos) * mask_ratios[anno_type])))

    # Annotation types no longer matter from here on; order by position in the note
    sample['annotations'] = sorted(entities_to_mask, key=lambda x: x['start'])
    return sample

In [None]:
# Training notes keep every annotation; the training collator randomly samples what to mask.
# Validation notes keep only the annotations pre-selected with TEST_MASK_RATIOS, so evaluation is fixed.
split_annotators = {
    'train': (extract_all_annotations, {}),
    'test': (extract_annotations_to_mask_only, {'fn_kwargs': {'mask_ratios': TEST_MASK_RATIOS}}),
}
for split_name, (annotator, map_kwargs) in split_annotators.items():
    train_dataset[split_name] = train_dataset[split_name].map(annotator, **map_kwargs)

train_dataset

Map:   0%|          | 0/632 [00:00<?, ? examples/s]

Map:   0%|          | 0/158 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['note_id', 'text', 'annotations'],
        num_rows: 632
    })
    test: Dataset({
        features: ['note_id', 'text', 'annotations'],
        num_rows: 158
    })
})

### Tokenize dataset

In [None]:
# align_annotations_with_tokens relies on char_to_token / Encoding objects, which only fast
# tokenizers provide. Request one explicitly and fail early if the checkpoint has none.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
assert tokenizer.is_fast, f'{MODEL_NAME} has no fast tokenizer; character/token alignment needs one.'

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/385 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]



In [None]:
def align_annotations_with_tokens(tokenized_text, annotations):
  ''' Map character-span annotations onto token positions.

  Args:
    tokenized_text: the (fast) tokenizer's output for a single text. It must support
      `[0].tokens` and `char_to_token(char_index)`.
    annotations: dicts with 'start', 'end' (character offsets, end exclusive) and 'label'.

  Returns:
    One label per token. Each token gets the label of the last annotation (in list order)
    covering any of its characters, or 'O' if no annotation does.
  '''
  token_labels = ['O'] * len(tokenized_text[0].tokens)
  for anno in annotations:
    for char_ix in range(anno['start'], anno['end']):
      token_ix = tokenized_text.char_to_token(char_ix)
      # Characters without a token (e.g. whitespace) map to None
      if token_ix is None:
        continue
      token_labels[token_ix] = anno['label']
  return token_labels

In [None]:
def tokenize_function(sample):
  ''' Tokenize one note and attach one annotation label per token.

  Returns the tokenizer's encoding of sample['text'] with an extra 'annotations' entry. That
  entry holds the token-level labels built by align_annotations_with_tokens from the note's
  character-span annotations.
  '''
  char_level_annotations = sample['annotations']
  encoding = tokenizer(sample['text'])
  encoding['annotations'] = align_annotations_with_tokens(
    tokenized_text=encoding, annotations=char_level_annotations
  )
  return encoding

In [None]:
# Tokenize each note and replace its character-level annotations with one label per token.
# The original columns are dropped; tokenize_function re-adds 'annotations' at token level.
original_columns = train_dataset['train'].column_names
tokenized_train_dataset = train_dataset.map(function=tokenize_function, batched=False, remove_columns=original_columns)
tokenized_train_dataset

Map:   0%|          | 0/632 [00:00<?, ? examples/s]

Map:   0%|          | 0/158 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['annotations', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 632
    })
    test: Dataset({
        features: ['annotations', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 158
    })
})

### Group dataset

In [None]:
def group_texts(examples, chunk_size=None):
    """
    Concatenate a batch of tokenized notes and re-split it into fixed-size chunks.

    Every column (input_ids, annotations, attention_mask, ...) is concatenated and cut at the
    same boundaries, so the columns stay aligned token by token. The trailing remainder
    shorter than `chunk_size` is dropped.

    Args:
        examples: batch mapping column name -> list of per-note lists (datasets `batched=True`).
        chunk_size: tokens per chunk; defaults to the notebook-level CHUNK_SIZE.

    Returns:
        dict mapping column name -> list of chunks, plus a 'labels' column that copies 'input_ids'.
    """
    if chunk_size is None:
        chunk_size = CHUNK_SIZE
    # chain avoids sum(lists, []), which re-copies the accumulator on every addition (quadratic)
    concatenated_examples = {k: list(itertools.chain.from_iterable(examples[k])) for k in examples.keys()}
    # All columns have the same length; measure the first one
    total_length = len(concatenated_examples[next(iter(examples.keys()))])
    total_length = (total_length // chunk_size) * chunk_size
    result = {
        k: [t[i : i + chunk_size] for i in range(0, total_length, chunk_size)]
        for k, t in concatenated_examples.items()
    }
    # Separate inner lists so labels never alias input_ids chunks
    result["labels"] = [list(chunk) for chunk in result["input_ids"]]
    return result

In [None]:
# Concatenate the tokenized notes of each split and cut them into CHUNK_SIZE-token blocks
grp_datasets = tokenized_train_dataset.map(function=group_texts, batched=True)
grp_datasets

Map:   0%|          | 0/632 [00:00<?, ? examples/s]

Map:   0%|          | 0/158 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['annotations', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 2896
    })
    test: Dataset({
        features: ['annotations', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 720
    })
})

# Hyperparameter tuning & Training

## Finetuning setup

### Data Collator

In [None]:
from transformers.data.data_collator import *
from transformers.data.data_collator import _torch_collate_batch, _tf_collate_batch, _numpy_collate_batch

@dataclass
class CustomBaseWholeWordMaskingDataCollator(DataCollatorForWholeWordMask):
    """
    Whole-word-masking collator whose choice of words to mask is driven by per-token labels.

    Each example must carry 'input_ids' and an aligned 'annotations' list (one label per
    token). Subclasses implement `_whole_word_mask(input_tokens, token_labels)`, which returns
    a 0/1 list marking the tokens to mask. The parent class's `torch_mask_tokens` /
    `tf_mask_tokens` / `numpy_mask_tokens` then apply the masking and build the labels.
    """

    def __init__(self, tokenizer, mlm_probability):
        # Pass by keyword: the parent dataclass's second positional field is `mlm`. A positional
        # call would store mlm_probability in `mlm` and silently leave `self.mlm_probability`
        # at the library default.
        super().__init__(tokenizer=tokenizer, mlm_probability=mlm_probability)

    @staticmethod
    def _split_examples(examples):
        """Return (list of input_ids, examples as dicts). Bare id lists are wrapped as {'input_ids': ...}.

        NOTE(review): wrapped examples carry no 'annotations', so _collect_mask_labels will raise KeyError on them.
        """
        if isinstance(examples[0], Mapping):
            return [e["input_ids"] for e in examples], examples
        return examples, [{"input_ids": e} for e in examples]

    def _collect_mask_labels(self, examples):
        """Compute every example's 0/1 whole-word mask from its tokens and 'annotations' labels."""
        mask_labels = []
        for e in examples:
            ref_tokens = [self.tokenizer._convert_id_to_token(id) for id in tolist(e["input_ids"])]
            # Chinese sub-words need an explicit marker, e.g. [喜,欢] -> [喜, ##欢]
            if "chinese_ref" in e:
                ref_pos = set(tolist(e["chinese_ref"]))
                for i in range(len(e["input_ids"])):
                    if i in ref_pos:
                        ref_tokens[i] = "##" + ref_tokens[i]
            mask_labels.append(self._whole_word_mask(ref_tokens, e["annotations"]))
        return mask_labels

    def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
        input_ids, examples = self._split_examples(examples)
        batch_input = _torch_collate_batch(input_ids, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
        mask_labels = self._collect_mask_labels(examples)
        batch_mask = _torch_collate_batch(mask_labels, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
        inputs, labels = self.torch_mask_tokens(batch_input, batch_mask)
        return {"input_ids": inputs, "labels": labels}

    def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
        # NOTE(review): relies on a global `tf`, which this notebook does not appear to import
        input_ids, examples = self._split_examples(examples)
        batch_input = _tf_collate_batch(input_ids, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
        mask_labels = self._collect_mask_labels(examples)
        batch_mask = _tf_collate_batch(mask_labels, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
        inputs, labels = self.tf_mask_tokens(tf.cast(batch_input, tf.int64), batch_mask)
        return {"input_ids": inputs, "labels": labels}

    def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
        input_ids, examples = self._split_examples(examples)
        batch_input = _numpy_collate_batch(input_ids, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
        mask_labels = self._collect_mask_labels(examples)
        batch_mask = _numpy_collate_batch(mask_labels, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
        inputs, labels = self.numpy_mask_tokens(batch_input, batch_mask)
        return {"input_ids": inputs, "labels": labels}



@dataclass
class CustomTrainingWordMaskingDataCollator(CustomBaseWholeWordMaskingDataCollator):
    """
    Training collator: chooses which whole words to mask from the per-token labels.

    - PHI_* words: a random `phi_masking_proportion` of them is selected and placed ahead of
      all other candidates. Unselected PHI words are never masked.
    - MED_* and PUNCT tokens are never masked.
    - Every other token (remaining POS tags and 'O') is a regular candidate, in random order.

    Words are masked until round(len(input_tokens) * self.mlm_probability) tokens are
    covered, capped at `max_predictions`. A word that would overflow this budget is skipped.
    """

    def __init__(self, tokenizer, mlm_probability, phi_masking_proportion):
        self.phi_masking_proportion = phi_masking_proportion
        super().__init__(tokenizer, mlm_probability)

    @staticmethod
    def _append_token(groups, index, token):
        """Add token `index` to the list of word groups.

        A '##' sub-word joins the last group only if that group ends at the immediately
        preceding token. Otherwise it starts a new group, so it is never glued to an
        unrelated word after skipped tokens.
        """
        if groups and token.startswith("##") and groups[-1][-1] == index - 1:
            groups[-1].append(index)
        else:
            groups.append([index])

    def _whole_word_mask(self, input_tokens: List[str], token_labels: List[str], max_predictions=512):
        """
        Get 0/1 labels for masked tokens with whole word mask proxy.

        Args:
            input_tokens: WordPiece tokens of one sequence.
            token_labels: one annotation label per token, aligned with input_tokens.
            max_predictions: hard cap on the number of masked tokens.
        Returns:
            A list of len(input_tokens) values in {0, 1}; 1 marks a token to mask.
        """
        if not isinstance(self.tokenizer, (BertTokenizer, BertTokenizerFast)):
            warnings.warn(
                "DataCollatorForWholeWordMask is only suitable for BertTokenizer-like tokenizers. "
                "Please refer to the documentation for more information."
            )

        phi_indexes, cand_indexes = [], []
        for i, (token, label) in enumerate(zip(input_tokens, token_labels)):
            if token in ("[CLS]", "[SEP]"):
                continue
            if label.startswith("PHI"):
                # PHI words form their own, prioritised pool
                self._append_token(phi_indexes, i, token)
            elif label.startswith("MED") or label == "PUNCT":
                # Medical entities and punctuation are never masked
                # TODO(review): the original author considered also excluding every non-NOUN label
                continue
            else:
                self._append_token(cand_indexes, i, token)

        # Keep a random subset of PHI words; the others are not masked at all
        random.shuffle(phi_indexes)
        phi_indexes = phi_indexes[:int(len(phi_indexes) * self.phi_masking_proportion)]

        # PHI words go first, so they are masked before any other candidate
        random.shuffle(cand_indexes)
        cand_indexes = phi_indexes + cand_indexes

        num_to_predict = min(max_predictions, max(1, int(round(len(input_tokens) * self.mlm_probability))))
        masked_lms = []
        covered_indexes = set()
        for index_set in cand_indexes:
            if len(masked_lms) >= num_to_predict:
                break
            # A word that would exceed the budget is skipped; a shorter one later may still fit
            if len(masked_lms) + len(index_set) > num_to_predict:
                continue
            if any(index in covered_indexes for index in index_set):
                continue
            covered_indexes.update(index_set)
            masked_lms.extend(index_set)

        if len(covered_indexes) != len(masked_lms):
            raise ValueError("Length of covered_indexes is not equal to length of masked_lms.")
        return [1 if i in covered_indexes else 0 for i in range(len(input_tokens))]



@dataclass
class CustomEvaluationWordMaskingDataCollator(CustomBaseWholeWordMaskingDataCollator):
    """
    Evaluation collator: every token with a non-'O' label is a masking candidate, i.e. exactly
    the annotations kept in the validation split.

    Candidate words are taken in random order until round(len(input_tokens) *
    self.mlm_probability) tokens are covered, capped at `max_predictions`. The constructor
    defaults mlm_probability to 1.0.
    """

    def __init__(self, tokenizer, mlm_probability=1.0):
        super().__init__(tokenizer, mlm_probability)

    @staticmethod
    def _append_token(groups, index, token):
        """Add token `index` to the list of word groups.

        A '##' sub-word joins the last group only if that group ends at the immediately
        preceding token. Otherwise it starts a new group, so it is never glued to an
        unrelated word after skipped tokens.
        """
        if groups and token.startswith("##") and groups[-1][-1] == index - 1:
            groups[-1].append(index)
        else:
            groups.append([index])

    def _whole_word_mask(self, input_tokens: List[str], token_labels: List[str], max_predictions=512):
        """
        Get 0/1 labels for masked tokens with whole word mask proxy.

        Args:
            input_tokens: WordPiece tokens of one sequence.
            token_labels: one annotation label per token, aligned with input_tokens.
            max_predictions: hard cap on the number of masked tokens.
        Returns:
            A list of len(input_tokens) values in {0, 1}; 1 marks a token to mask.
        """
        if not isinstance(self.tokenizer, (BertTokenizer, BertTokenizerFast)):
            warnings.warn(
                "DataCollatorForWholeWordMask is only suitable for BertTokenizer-like tokenizers. "
                "Please refer to the documentation for more information."
            )

        cand_indexes = []
        for i, (token, label) in enumerate(zip(input_tokens, token_labels)):
            if token in ("[CLS]", "[SEP]"):
                continue
            # Every annotated (non-'O') token is a candidate
            if label != 'O':
                self._append_token(cand_indexes, i, token)

        random.shuffle(cand_indexes)

        num_to_predict = min(max_predictions, max(1, int(round(len(input_tokens) * self.mlm_probability))))
        masked_lms = []
        covered_indexes = set()
        for index_set in cand_indexes:
            if len(masked_lms) >= num_to_predict:
                break
            # A word that would exceed the budget is skipped; a shorter one later may still fit
            if len(masked_lms) + len(index_set) > num_to_predict:
                continue
            if any(index in covered_indexes for index in index_set):
                continue
            covered_indexes.update(index_set)
            masked_lms.extend(index_set)

        if len(covered_indexes) != len(masked_lms):
            raise ValueError("Length of covered_indexes is not equal to length of masked_lms.")
        return [1 if i in covered_indexes else 0 for i in range(len(input_tokens))]

In [None]:
from transformers.trainer import *

class DoubleCollatorTrainer(Trainer):
    """
    Trainer that batches training data with `data_collator` and evaluation data with a
    separate `eval_data_collator`. When no evaluation collator is given, the training
    collator is used for evaluation too.
    """

    def __init__(
        self,
        model: Union[PreTrainedModel, nn.Module] = None,
        args: TrainingArguments = None,
        data_collator: Optional[DataCollator] = None,
        eval_data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Union[Dataset, IterableDataset, "datasets.Dataset"]] = None,
        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset], "datasets.Dataset"]] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        model_init: Optional[Callable[[], PreTrainedModel]] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
    ):
        # Forward by keyword: Trainer's positional parameter order is not stable across transformers releases
        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
            model_init=model_init,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        )
        # Without a dedicated evaluation collator, fall back to the (possibly defaulted) training one
        self.eval_data_collator = eval_data_collator if eval_data_collator is not None else self.data_collator


    def get_eval_dataloader(self, eval_dataset: Optional[Union[str, Dataset]] = None) -> DataLoader:
        """
        Returns the evaluation [`~torch.utils.data.DataLoader`], built with `self.eval_data_collator`.

        NOTE(review): apart from the collator choice, this follows the base Trainer implementation of the
        transformers version in use; re-sync it when upgrading transformers.

        Args:
            eval_dataset (`str` or `torch.utils.data.Dataset`, *optional*):
                If a `str`, will use `self.eval_dataset[eval_dataset]` as the evaluation dataset. If a `Dataset`, will override `self.eval_dataset` and must implement `__len__`. If it is a [`~datasets.Dataset`], columns not accepted by the `model.forward()` method are automatically removed.
        """
        if eval_dataset is None and self.eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")

        # If we have persistent workers, don't do a fork bomb especially as eval datasets
        # don't change during training
        dataloader_key = eval_dataset if isinstance(eval_dataset, str) else "eval"
        if (
            hasattr(self, "_eval_dataloaders")
            and dataloader_key in self._eval_dataloaders
            and self.args.dataloader_persistent_workers
        ):
            return self.accelerator.prepare(self._eval_dataloaders[dataloader_key])

        eval_dataset = (
            self.eval_dataset[eval_dataset]
            if isinstance(eval_dataset, str)
            else eval_dataset
            if eval_dataset is not None
            else self.eval_dataset
        )
        # The only deviation from the base implementation: use the evaluation collator
        data_collator = self.eval_data_collator

        if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
            eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation")
        else:
            data_collator = self._get_collator_with_removed_columns(data_collator, description="evaluation")

        dataloader_params = {
            "batch_size": self.args.eval_batch_size,
            "collate_fn": data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
            "persistent_workers": self.args.dataloader_persistent_workers,
        }

        if not isinstance(eval_dataset, torch.utils.data.IterableDataset):
            dataloader_params["sampler"] = self._get_eval_sampler(eval_dataset)
            dataloader_params["drop_last"] = self.args.dataloader_drop_last
            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor

        # accelerator.free_memory() will destroy the references, so
        # we need to store the non-prepared version
        eval_dataloader = DataLoader(eval_dataset, **dataloader_params)
        if self.args.dataloader_persistent_workers:
            if hasattr(self, "_eval_dataloaders"):
                self._eval_dataloaders[dataloader_key] = eval_dataloader
            else:
                self._eval_dataloaders = {dataloader_key: eval_dataloader}

        return self.accelerator.prepare(eval_dataloader)

### Setup

In [None]:
# Model factory handed to the Trainer (model_init=...)
def model_init():
    ''' Build a fresh masked-LM from MODEL_NAME with every parameter tensor made contiguous.

    NOTE(review): the contiguity pass is presumably there so checkpoints can be written with
    save_safetensors=True; confirm.
    '''
    mlm_model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME)
    for _, weight in mlm_model.named_parameters():
        weight.data = weight.data.contiguous()
    return mlm_model

In [None]:
def get_training_results_and_best_epoch(trainer):
  ''' Summarise per-epoch losses from a finished or interrupted Trainer run.

  Training-loss entries (key 'loss') and evaluation entries (key 'eval_loss') are picked out of
  trainer.state.log_history by key and paired in order. Extra entries, such as the final run
  summary or its absence after an interrupted run, therefore do not shift or drop epochs.

  Returns:
    training_results: dict of lists 'eval_loss', 'epoch', 'train_loss' and 'eval_perplexity'
      (perplexity = exp(eval_loss)). Losses and perplexity are rounded to 3 decimals.
    best_epoch: dict with the same keys, holding the values for the epoch with the lowest
      eval_loss (the first one on ties).
  '''
  history = trainer.state.log_history
  train_logs = [entry for entry in history if 'loss' in entry]
  eval_logs = [entry for entry in history if 'eval_loss' in entry]

  training_results = {'eval_loss': [], 'epoch': [], 'train_loss': [], 'eval_perplexity': []}
  for val, tr in zip(eval_logs, train_logs):
    training_results['eval_loss'].append(round(val['eval_loss'], 3))
    training_results['epoch'].append(val['epoch'])
    training_results['train_loss'].append(round(tr['loss'], 3))
    training_results['eval_perplexity'].append(round(math.exp(val['eval_loss']), 3))

  best_ix = int(np.argmin(training_results['eval_loss']))
  best_epoch = {key: values[best_ix] for key, values in training_results.items()}
  return training_results, best_epoch

In [None]:
def fine_tune(hyperparameters, dataset):
  ''' Fine-tune a fresh copy of the pretrained model with one hyperparameter set.

  Used for the grid search: checkpoints go to a temporary directory that is deleted
  afterwards, and only the logged metrics are returned.

  Args:
    hyperparameters: dict providing 'batch_size', 'learning_rate', 'weight_decay',
      'train_masking_prob' and 'phi_masking_proportion'.
    dataset: mapping with a 'train' split (training data) and a 'test' split
      (validation data used for early stopping and best-model selection).

  Returns:
    (training_results, best_epoch, RUN_NAME): the per-epoch metrics, the metrics of the
    epoch with the lowest validation loss, and the run name used for logging.
  '''
  # Model's and logs directory
  RUN_NAME = f"Run - {datetime.datetime.now().strftime('%m-%d-%H-%M')}"
  DIR = f"temp/hp_tuning/Results/{RUN_NAME}" # Models do not need to be saved during hyperparameter tuning
  os.environ["WANDB_PROJECT"] = 'MLM_LETTER' # set the wandb project where this run will be logged

  # Define training args
  training_args = TrainingArguments(
    run_name=RUN_NAME.replace(' ', '_'),
    output_dir=f"{DIR}/checkpoints",

    # Parameters
    per_device_train_batch_size = hyperparameters["batch_size"],
    learning_rate = hyperparameters["learning_rate"],
    weight_decay = hyperparameters['weight_decay'],
    num_train_epochs = 10,       # Use early stopping (so this is maximum epochs)
    fp16 = (DEVICE == 'cuda'),   # Use 16-bit (mixed) precision instead of 32-bit (ONLY POSSIBLE ON CUDA!)
    optim = "adamw_torch",
    seed=42,                     # Use a seed for reproducibility
    remove_unused_columns=False, # IMPORTANT: if not set to False, the custom data collator will delete 'annotations' column

    # Logging
    logging_dir=f"{DIR}/training_logs",
    logging_strategy="epoch",
    report_to="wandb",
    # Saving
    save_strategy="epoch",
    save_safetensors=True, save_total_limit=1,
    # Evaluating (Use validation loss for model selection and early stopping)
    evaluation_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss", greater_is_better=False,
  )

  # Create a Trainer instance
  trainer = DoubleCollatorTrainer(
    model_init=model_init,
    args=training_args,
    train_dataset=dataset['train'],
    eval_dataset=dataset['test'],
    # NOTE(review): in the recorded grid-search outputs, runs that differ only in
    # train_masking_prob produce identical loss curves — confirm that the collator
    # actually applies mlm_probability.
    data_collator=CustomTrainingWordMaskingDataCollator(
      tokenizer=tokenizer,
      mlm_probability=hyperparameters['train_masking_prob'],
      phi_masking_proportion=hyperparameters['phi_masking_proportion']
    ),
    eval_data_collator=CustomEvaluationWordMaskingDataCollator(tokenizer=tokenizer),
    tokenizer=tokenizer,
    callbacks=[EarlyStoppingCallback(2, 0.0)],  # patience=2 evaluations, threshold=0.0
  )

  try:
    # Fine-tune the model and read the metrics from the in-memory log history
    trainer.train()
    training_results, best_epoch = get_training_results_and_best_epoch(trainer)
  finally:
    # Delete model checkpoints, even when training fails, so temp/ does not fill up.
    shutil.rmtree(DIR, ignore_errors=True)
    # The HF WandbCallback only calls wandb.init when no run is active. Without closing
    # the run here, every later grid-search run would log into this one.
    wandb.finish()
  return training_results, best_epoch, RUN_NAME


## Hyperparameter tuning

In [None]:
# Search space based on the BERT / RoBERTa fine-tuning recommendations.
# Values in trailing comments are not part of the current grid.
HYPERPARAMETERS = {
    'weight_decay': [0.01], # 0.02
    'learning_rate': [1e-4, 5e-5, 3e-5],
    'batch_size': [8, 16], # 4
    'phi_masking_proportion': [0.75, 1.0],
    'train_masking_prob': [0.30, 0.50],
}

# Expand the grid: one dict per element of the Cartesian product of all value lists
keys = tuple(HYPERPARAMETERS.keys())
values = tuple(HYPERPARAMETERS.values())
HYPERPARAMETERS_COMB = [dict(zip(keys, combination)) for combination in itertools.product(*values)]
print(f'Number of hyperparameter sets to search through: {len(HYPERPARAMETERS_COMB)}.')

Number of hyperparameter sets to search through: 24.


In [None]:
# Grid search: fine-tune once per hyperparameter set and record the best-epoch validation perplexity.
# `results` maps perplexity -> hyperparameters (plus the best 'epoch'); `all_runs` keeps every run,
# so runs with tied perplexities are not lost.
results = {}
all_runs = []
with open(f'{MODEL_PATH}/training_logs.txt', 'a+') as log:

    # Fine-tune on each hyperparameter set (grid search)
    for HYPERPARAM in HYPERPARAMETERS_COMB:
        print('--- STARTING FINE-TUNING ---')
        training_results, best_epoch, run_name = fine_tune(hyperparameters=HYPERPARAM, dataset=grp_datasets)
        # Build a new dict instead of mutating the grid entry, so re-running this cell
        # does not carry a stale 'epoch' key inside HYPERPARAMETERS_COMB.
        run_hyperparams = {**HYPERPARAM, 'epoch': best_epoch['epoch']}
        perplexity = best_epoch['eval_perplexity']
        all_runs.append((perplexity, run_hyperparams))
        # Perplexities are rounded to 3 d.p. and do tie (identical curves occurred in practice).
        # Keep the earlier run and report the tie instead of silently overwriting it.
        if perplexity in results:
            print(f'Perplexity tie ({perplexity}) with {results[perplexity]}; keeping the earlier run.')
        else:
            results[perplexity] = run_hyperparams

        print('Run name:', run_name)
        print('Hyperparameters:', run_hyperparams)
        print('Final training results:', training_results)
        print('Best epoch results:', best_epoch)

        log.write(f"Run name: {run_name}\n")
        log.write('Hyperparameters:\n' + ',\n'.join([f'\t{key}={value}' for key, value in run_hyperparams.items()]) + '.\n')
        log.write('Final training results:\n' + ';\n'.join([f'\t{key}={", ".join([str(i) for i in value])}' for key, value in training_results.items()]) + '.\n')
        log.write('Best epoch results:\n' + ',\n'.join([f'\t{key}={value}' for key, value in best_epoch.items()]) + '.\n\n')
        # Persist after every run so completed results survive a runtime disconnect
        log.flush()

--- STARTING FINE-TUNING ---


  self.scaler = torch.cuda.amp.GradScaler(**kwargs)


pytorch_model.bin:   0%|          | 0.00/436M [00:00<?, ?B/s]

[34m[1mwandb[0m: Currently logged in as: [33mbelkadisamuel[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch,Training Loss,Validation Loss
1,3.0724,3.100508
2,2.5673,2.909697
3,2.3644,2.805727
4,2.2137,2.762541
5,2.0764,2.699641
6,1.9547,2.666494
7,1.872,2.690596
8,1.7883,2.589769
9,1.7229,2.654836
10,1.6715,2.604804


There were missing keys in the checkpoint model loaded: ['cls.predictions.decoder.weight', 'cls.predictions.decoder.bias'].


Run name: Run - 09-08-14-04
Hyperparameters: {'weight_decay': 0.01, 'learning_rate': 0.0001, 'batch_size': 8, 'phi_masking_proportion': 0.75, 'train_masking_prob': 0.3, 'epoch': 8.0}
Final training results: {'eval_loss': [3.101, 2.91, 2.806, 2.763, 2.7, 2.666, 2.691, 2.59, 2.655, 2.605], 'epoch': [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], 'train_loss': [3.072, 2.567, 2.364, 2.214, 2.076, 1.955, 1.872, 1.788, 1.723, 1.671], 'eval_perplexity': [22.209, 18.351, 16.539, 15.84, 14.874, 14.389, 14.74, 13.327, 14.223, 13.529]}
Best epoch results: {'eval_loss': 2.59, 'epoch': 8.0, 'train_loss': 1.788, 'eval_perplexity': 13.327}
--- STARTING FINE-TUNING ---


  self.scaler = torch.cuda.amp.GradScaler(**kwargs)


Epoch,Training Loss,Validation Loss
1,3.0724,3.100508
2,2.5673,2.909697
3,2.3644,2.805727
4,2.2137,2.762541
5,2.0764,2.699641
6,1.9547,2.666494
7,1.872,2.690596
8,1.7883,2.589769
9,1.7229,2.654836
10,1.6715,2.604804


There were missing keys in the checkpoint model loaded: ['cls.predictions.decoder.weight', 'cls.predictions.decoder.bias'].


Run name: Run - 09-08-14-16
Hyperparameters: {'weight_decay': 0.01, 'learning_rate': 0.0001, 'batch_size': 8, 'phi_masking_proportion': 0.75, 'train_masking_prob': 0.5, 'epoch': 8.0}
Final training results: {'eval_loss': [3.101, 2.91, 2.806, 2.763, 2.7, 2.666, 2.691, 2.59, 2.655, 2.605], 'epoch': [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], 'train_loss': [3.072, 2.567, 2.364, 2.214, 2.076, 1.955, 1.872, 1.788, 1.723, 1.671], 'eval_perplexity': [22.209, 18.351, 16.539, 15.84, 14.874, 14.389, 14.74, 13.327, 14.223, 13.529]}
Best epoch results: {'eval_loss': 2.59, 'epoch': 8.0, 'train_loss': 1.788, 'eval_perplexity': 13.327}
--- STARTING FINE-TUNING ---


  self.scaler = torch.cuda.amp.GradScaler(**kwargs)


Epoch,Training Loss,Validation Loss
1,3.3536,3.139654
2,2.7717,2.952297
3,2.5251,2.855446
4,2.3404,2.808918
5,2.1841,2.787064
6,2.0235,2.749475
7,1.9084,2.775935
8,1.8028,2.695322
9,1.7093,2.767049
10,1.6495,2.729613


There were missing keys in the checkpoint model loaded: ['cls.predictions.decoder.weight', 'cls.predictions.decoder.bias'].


Run name: Run - 09-08-14-28
Hyperparameters: {'weight_decay': 0.01, 'learning_rate': 0.0001, 'batch_size': 8, 'phi_masking_proportion': 1.0, 'train_masking_prob': 0.3, 'epoch': 8.0}
Final training results: {'eval_loss': [3.14, 2.952, 2.855, 2.809, 2.787, 2.749, 2.776, 2.695, 2.767, 2.73], 'epoch': [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], 'train_loss': [3.354, 2.772, 2.525, 2.34, 2.184, 2.023, 1.908, 1.803, 1.709, 1.649], 'eval_perplexity': [23.096, 19.15, 17.382, 16.592, 16.233, 15.634, 16.054, 14.81, 15.912, 15.327]}
Best epoch results: {'eval_loss': 2.695, 'epoch': 8.0, 'train_loss': 1.803, 'eval_perplexity': 14.81}
--- STARTING FINE-TUNING ---


  self.scaler = torch.cuda.amp.GradScaler(**kwargs)


Epoch,Training Loss,Validation Loss
1,3.3536,3.139654
2,2.7717,2.952297
3,2.5251,2.855446
4,2.3404,2.808918
5,2.1841,2.787064
6,2.0235,2.749475
7,1.9084,2.775935
8,1.8028,2.695322
9,1.7093,2.767049
10,1.6495,2.729613


There were missing keys in the checkpoint model loaded: ['cls.predictions.decoder.weight', 'cls.predictions.decoder.bias'].


Run name: Run - 09-08-14-39
Hyperparameters: {'weight_decay': 0.01, 'learning_rate': 0.0001, 'batch_size': 8, 'phi_masking_proportion': 1.0, 'train_masking_prob': 0.5, 'epoch': 8.0}
Final training results: {'eval_loss': [3.14, 2.952, 2.855, 2.809, 2.787, 2.749, 2.776, 2.695, 2.767, 2.73], 'epoch': [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], 'train_loss': [3.354, 2.772, 2.525, 2.34, 2.184, 2.023, 1.908, 1.803, 1.709, 1.649], 'eval_perplexity': [23.096, 19.15, 17.382, 16.592, 16.233, 15.634, 16.054, 14.81, 15.912, 15.327]}
Best epoch results: {'eval_loss': 2.695, 'epoch': 8.0, 'train_loss': 1.803, 'eval_perplexity': 14.81}
--- STARTING FINE-TUNING ---


  self.scaler = torch.cuda.amp.GradScaler(**kwargs)


Epoch,Training Loss,Validation Loss
1,3.085,3.095192
2,2.5879,2.91448
3,2.3871,2.794341
4,2.2653,2.755175
5,2.1401,2.684023
6,2.0222,2.688478
7,1.956,2.661776
8,1.8935,2.595476
9,1.8376,2.647528
10,1.7945,2.599767


There were missing keys in the checkpoint model loaded: ['cls.predictions.decoder.weight', 'cls.predictions.decoder.bias'].


Run name: Run - 09-08-14-51
Hyperparameters: {'weight_decay': 0.01, 'learning_rate': 0.0001, 'batch_size': 16, 'phi_masking_proportion': 0.75, 'train_masking_prob': 0.3, 'epoch': 8.0}
Final training results: {'eval_loss': [3.095, 2.914, 2.794, 2.755, 2.684, 2.688, 2.662, 2.595, 2.648, 2.6], 'epoch': [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], 'train_loss': [3.085, 2.588, 2.387, 2.265, 2.14, 2.022, 1.956, 1.893, 1.838, 1.794], 'eval_perplexity': [22.091, 18.439, 16.352, 15.724, 14.644, 14.709, 14.322, 13.403, 14.119, 13.461]}
Best epoch results: {'eval_loss': 2.595, 'epoch': 8.0, 'train_loss': 1.893, 'eval_perplexity': 13.403}
--- STARTING FINE-TUNING ---


  self.scaler = torch.cuda.amp.GradScaler(**kwargs)


Epoch,Training Loss,Validation Loss
1,3.085,3.095192
2,2.5879,2.91448
3,2.3871,2.794341
4,2.2653,2.755175
5,2.1401,2.684023
6,2.0222,2.688478
7,1.956,2.661776
8,1.8935,2.595476
9,1.8376,2.647528
10,1.7945,2.599767


There were missing keys in the checkpoint model loaded: ['cls.predictions.decoder.weight', 'cls.predictions.decoder.bias'].


Run name: Run - 09-08-15-00
Hyperparameters: {'weight_decay': 0.01, 'learning_rate': 0.0001, 'batch_size': 16, 'phi_masking_proportion': 0.75, 'train_masking_prob': 0.5, 'epoch': 8.0}
Final training results: {'eval_loss': [3.095, 2.914, 2.794, 2.755, 2.684, 2.688, 2.662, 2.595, 2.648, 2.6], 'epoch': [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], 'train_loss': [3.085, 2.588, 2.387, 2.265, 2.14, 2.022, 1.956, 1.893, 1.838, 1.794], 'eval_perplexity': [22.091, 18.439, 16.352, 15.724, 14.644, 14.709, 14.322, 13.403, 14.119, 13.461]}
Best epoch results: {'eval_loss': 2.595, 'epoch': 8.0, 'train_loss': 1.893, 'eval_perplexity': 13.403}
--- STARTING FINE-TUNING ---


  self.scaler = torch.cuda.amp.GradScaler(**kwargs)


Epoch,Training Loss,Validation Loss
1,3.3459,3.1872
2,2.8118,2.955823
3,2.5615,2.839792
4,2.4064,2.7961
5,2.2569,2.734899
6,2.1098,2.742698
7,2.0222,2.735871


There were missing keys in the checkpoint model loaded: ['cls.predictions.decoder.weight', 'cls.predictions.decoder.bias'].


Run name: Run - 09-08-15-10
Hyperparameters: {'weight_decay': 0.01, 'learning_rate': 0.0001, 'batch_size': 16, 'phi_masking_proportion': 1.0, 'train_masking_prob': 0.3, 'epoch': 5.0}
Final training results: {'eval_loss': [3.187, 2.956, 2.84, 2.796, 2.735, 2.743, 2.736], 'epoch': [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0], 'train_loss': [3.346, 2.812, 2.562, 2.406, 2.257, 2.11, 2.022], 'eval_perplexity': [24.221, 19.218, 17.112, 16.381, 15.408, 15.529, 15.423]}
Best epoch results: {'eval_loss': 2.735, 'epoch': 5.0, 'train_loss': 2.257, 'eval_perplexity': 15.408}
--- STARTING FINE-TUNING ---


  self.scaler = torch.cuda.amp.GradScaler(**kwargs)


Epoch,Training Loss,Validation Loss
1,3.3459,3.1872
2,2.8118,2.955823
3,2.5615,2.839792
4,2.4064,2.7961
5,2.2569,2.734899
6,2.1098,2.742698
7,2.0222,2.735871


There were missing keys in the checkpoint model loaded: ['cls.predictions.decoder.weight', 'cls.predictions.decoder.bias'].


Run name: Run - 09-08-15-17
Hyperparameters: {'weight_decay': 0.01, 'learning_rate': 0.0001, 'batch_size': 16, 'phi_masking_proportion': 1.0, 'train_masking_prob': 0.5, 'epoch': 5.0}
Final training results: {'eval_loss': [3.187, 2.956, 2.84, 2.796, 2.735, 2.743, 2.736], 'epoch': [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0], 'train_loss': [3.346, 2.812, 2.562, 2.406, 2.257, 2.11, 2.022], 'eval_perplexity': [24.221, 19.218, 17.112, 16.381, 15.408, 15.529, 15.423]}
Best epoch results: {'eval_loss': 2.735, 'epoch': 5.0, 'train_loss': 2.257, 'eval_perplexity': 15.408}
--- STARTING FINE-TUNING ---


  self.scaler = torch.cuda.amp.GradScaler(**kwargs)


Epoch,Training Loss,Validation Loss
1,3.1133,3.14437
2,2.6484,2.944725
3,2.4662,2.842561
4,2.3412,2.787303
5,2.2389,2.734298
6,2.1398,2.705681
7,2.0808,2.712857
8,2.0253,2.635262
9,1.981,2.684434
10,1.9498,2.628921


There were missing keys in the checkpoint model loaded: ['cls.predictions.decoder.weight', 'cls.predictions.decoder.bias'].


Run name: Run - 09-08-15-24
Hyperparameters: {'weight_decay': 0.01, 'learning_rate': 5e-05, 'batch_size': 8, 'phi_masking_proportion': 0.75, 'train_masking_prob': 0.3, 'epoch': 10.0}
Final training results: {'eval_loss': [3.144, 2.945, 2.843, 2.787, 2.734, 2.706, 2.713, 2.635, 2.684, 2.629], 'epoch': [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], 'train_loss': [3.113, 2.648, 2.466, 2.341, 2.239, 2.14, 2.081, 2.025, 1.981, 1.95], 'eval_perplexity': [23.205, 19.005, 17.16, 16.237, 15.399, 14.964, 15.072, 13.947, 14.65, 13.859]}
Best epoch results: {'eval_loss': 2.629, 'epoch': 10.0, 'train_loss': 1.95, 'eval_perplexity': 13.859}
--- STARTING FINE-TUNING ---


  self.scaler = torch.cuda.amp.GradScaler(**kwargs)


Epoch,Training Loss,Validation Loss
1,3.1133,3.14437
2,2.6484,2.944725
3,2.4662,2.842561
4,2.3412,2.787303
5,2.2389,2.734298
6,2.1398,2.705681
7,2.0808,2.712857
8,2.0253,2.635262
9,1.981,2.684434
10,1.9498,2.628921


There were missing keys in the checkpoint model loaded: ['cls.predictions.decoder.weight', 'cls.predictions.decoder.bias'].


Run name: Run - 09-08-15-35
Hyperparameters: {'weight_decay': 0.01, 'learning_rate': 5e-05, 'batch_size': 8, 'phi_masking_proportion': 0.75, 'train_masking_prob': 0.5, 'epoch': 10.0}
Final training results: {'eval_loss': [3.144, 2.945, 2.843, 2.787, 2.734, 2.706, 2.713, 2.635, 2.684, 2.629], 'epoch': [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], 'train_loss': [3.113, 2.648, 2.466, 2.341, 2.239, 2.14, 2.081, 2.025, 1.981, 1.95], 'eval_perplexity': [23.205, 19.005, 17.16, 16.237, 15.399, 14.964, 15.072, 13.947, 14.65, 13.859]}
Best epoch results: {'eval_loss': 2.629, 'epoch': 10.0, 'train_loss': 1.95, 'eval_perplexity': 13.859}
--- STARTING FINE-TUNING ---


  self.scaler = torch.cuda.amp.GradScaler(**kwargs)


Epoch,Training Loss,Validation Loss
1,3.3788,3.174569
2,2.854,2.964445
3,2.6267,2.869552
4,2.4781,2.808293
5,2.3562,2.783348
6,2.2364,2.754899
7,2.1599,2.763312
8,2.093,2.700912
9,2.0352,2.746927
10,2.0039,2.703594


There were missing keys in the checkpoint model loaded: ['cls.predictions.decoder.weight', 'cls.predictions.decoder.bias'].


Run name: Run - 09-08-15-47
Hyperparameters: {'weight_decay': 0.01, 'learning_rate': 5e-05, 'batch_size': 8, 'phi_masking_proportion': 1.0, 'train_masking_prob': 0.3, 'epoch': 8.0}
Final training results: {'eval_loss': [3.175, 2.964, 2.87, 2.808, 2.783, 2.755, 2.763, 2.701, 2.747, 2.704], 'epoch': [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], 'train_loss': [3.379, 2.854, 2.627, 2.478, 2.356, 2.236, 2.16, 2.093, 2.035, 2.004], 'eval_perplexity': [23.917, 19.384, 17.629, 16.582, 16.173, 15.719, 15.852, 14.893, 15.595, 14.933]}
Best epoch results: {'eval_loss': 2.701, 'epoch': 8.0, 'train_loss': 2.093, 'eval_perplexity': 14.893}
--- STARTING FINE-TUNING ---


  self.scaler = torch.cuda.amp.GradScaler(**kwargs)


Epoch,Training Loss,Validation Loss
1,3.3788,3.174569
2,2.854,2.964445
3,2.6267,2.869552
4,2.4781,2.808293
5,2.3562,2.783348
6,2.2364,2.754899
7,2.1599,2.763312
8,2.093,2.700912
9,2.0352,2.746927
10,2.0039,2.703594


There were missing keys in the checkpoint model loaded: ['cls.predictions.decoder.weight', 'cls.predictions.decoder.bias'].


Run name: Run - 09-08-15-58
Hyperparameters: {'weight_decay': 0.01, 'learning_rate': 5e-05, 'batch_size': 8, 'phi_masking_proportion': 1.0, 'train_masking_prob': 0.5, 'epoch': 8.0}
Final training results: {'eval_loss': [3.175, 2.964, 2.87, 2.808, 2.783, 2.755, 2.763, 2.701, 2.747, 2.704], 'epoch': [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], 'train_loss': [3.379, 2.854, 2.627, 2.478, 2.356, 2.236, 2.16, 2.093, 2.035, 2.004], 'eval_perplexity': [23.917, 19.384, 17.629, 16.582, 16.173, 15.719, 15.852, 14.893, 15.595, 14.933]}
Best epoch results: {'eval_loss': 2.701, 'epoch': 8.0, 'train_loss': 2.093, 'eval_perplexity': 14.893}
--- STARTING FINE-TUNING ---


  self.scaler = torch.cuda.amp.GradScaler(**kwargs)


Epoch,Training Loss,Validation Loss
1,3.1755,3.194907
2,2.7238,3.015862
3,2.542,2.904717
4,2.44,2.841956
5,2.3339,2.768363
6,2.2362,2.765403
7,2.1911,2.746276
8,2.1487,2.691926
9,2.1115,2.733178
10,2.0839,2.671197


There were missing keys in the checkpoint model loaded: ['cls.predictions.decoder.weight', 'cls.predictions.decoder.bias'].


Run name: Run - 09-08-16-10
Hyperparameters: {'weight_decay': 0.01, 'learning_rate': 5e-05, 'batch_size': 16, 'phi_masking_proportion': 0.75, 'train_masking_prob': 0.3, 'epoch': 10.0}
Final training results: {'eval_loss': [3.195, 3.016, 2.905, 2.842, 2.768, 2.765, 2.746, 2.692, 2.733, 2.671], 'epoch': [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], 'train_loss': [3.175, 2.724, 2.542, 2.44, 2.334, 2.236, 2.191, 2.149, 2.111, 2.084], 'eval_perplexity': [24.408, 20.407, 18.26, 17.149, 15.933, 15.885, 15.584, 14.76, 15.382, 14.457]}
Best epoch results: {'eval_loss': 2.671, 'epoch': 10.0, 'train_loss': 2.084, 'eval_perplexity': 14.457}
--- STARTING FINE-TUNING ---


  self.scaler = torch.cuda.amp.GradScaler(**kwargs)


Epoch,Training Loss,Validation Loss
1,3.1755,3.194907
2,2.7238,3.015862
3,2.542,2.904717
4,2.44,2.841956
5,2.3339,2.768363
6,2.2362,2.765403
7,2.1911,2.746276
8,2.1487,2.691926
9,2.1115,2.733178
10,2.0839,2.671197


There were missing keys in the checkpoint model loaded: ['cls.predictions.decoder.weight', 'cls.predictions.decoder.bias'].


Run name: Run - 09-08-16-20
Hyperparameters: {'weight_decay': 0.01, 'learning_rate': 5e-05, 'batch_size': 16, 'phi_masking_proportion': 0.75, 'train_masking_prob': 0.5, 'epoch': 10.0}
Final training results: {'eval_loss': [3.195, 3.016, 2.905, 2.842, 2.768, 2.765, 2.746, 2.692, 2.733, 2.671], 'epoch': [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], 'train_loss': [3.175, 2.724, 2.542, 2.44, 2.334, 2.236, 2.191, 2.149, 2.111, 2.084], 'eval_perplexity': [24.408, 20.407, 18.26, 17.149, 15.933, 15.885, 15.584, 14.76, 15.382, 14.457]}
Best epoch results: {'eval_loss': 2.671, 'epoch': 10.0, 'train_loss': 2.084, 'eval_perplexity': 14.457}
--- STARTING FINE-TUNING ---


  self.scaler = torch.cuda.amp.GradScaler(**kwargs)


Epoch,Training Loss,Validation Loss
1,3.4497,3.234231
2,2.9433,3.060259
3,2.7256,2.941847
4,2.6086,2.862248
5,2.4828,2.804654
6,2.368,2.798002
7,2.3081,2.781413
8,2.2646,2.729963
9,2.2108,2.776732
10,2.1906,2.726476


There were missing keys in the checkpoint model loaded: ['cls.predictions.decoder.weight', 'cls.predictions.decoder.bias'].


Run name: Run - 09-08-16-29
Hyperparameters: {'weight_decay': 0.01, 'learning_rate': 5e-05, 'batch_size': 16, 'phi_masking_proportion': 1.0, 'train_masking_prob': 0.3, 'epoch': 10.0}
Final training results: {'eval_loss': [3.234, 3.06, 2.942, 2.862, 2.805, 2.798, 2.781, 2.73, 2.777, 2.726], 'epoch': [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], 'train_loss': [3.45, 2.943, 2.726, 2.609, 2.483, 2.368, 2.308, 2.265, 2.211, 2.191], 'eval_perplexity': [25.387, 21.333, 18.951, 17.501, 16.521, 16.412, 16.142, 15.332, 16.066, 15.279]}
Best epoch results: {'eval_loss': 2.726, 'epoch': 10.0, 'train_loss': 2.191, 'eval_perplexity': 15.279}
--- STARTING FINE-TUNING ---


  self.scaler = torch.cuda.amp.GradScaler(**kwargs)


Epoch,Training Loss,Validation Loss
1,3.4497,3.234231
2,2.9433,3.060259
3,2.7256,2.941847
4,2.6086,2.862248
5,2.4828,2.804654
6,2.368,2.798002
7,2.3081,2.781413
8,2.2646,2.729963
9,2.2108,2.776732
10,2.1906,2.726476


There were missing keys in the checkpoint model loaded: ['cls.predictions.decoder.weight', 'cls.predictions.decoder.bias'].


Run name: Run - 09-08-16-39
Hyperparameters: {'weight_decay': 0.01, 'learning_rate': 5e-05, 'batch_size': 16, 'phi_masking_proportion': 1.0, 'train_masking_prob': 0.5, 'epoch': 10.0}
Final training results: {'eval_loss': [3.234, 3.06, 2.942, 2.862, 2.805, 2.798, 2.781, 2.73, 2.777, 2.726], 'epoch': [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], 'train_loss': [3.45, 2.943, 2.726, 2.609, 2.483, 2.368, 2.308, 2.265, 2.211, 2.191], 'eval_perplexity': [25.387, 21.333, 18.951, 17.501, 16.521, 16.412, 16.142, 15.332, 16.066, 15.279]}
Best epoch results: {'eval_loss': 2.726, 'epoch': 10.0, 'train_loss': 2.191, 'eval_perplexity': 15.279}
--- STARTING FINE-TUNING ---


  self.scaler = torch.cuda.amp.GradScaler(**kwargs)


Epoch,Training Loss,Validation Loss
1,3.194,3.224421
2,2.7466,3.030812
3,2.5767,2.93367
4,2.4678,2.859784
5,2.3724,2.806974
6,2.2922,2.783483
7,2.2414,2.785889
8,2.2043,2.727731
9,2.1721,2.760677
10,2.1476,2.700244


There were missing keys in the checkpoint model loaded: ['cls.predictions.decoder.weight', 'cls.predictions.decoder.bias'].


Run name: Run - 09-08-16-49
Hyperparameters: {'weight_decay': 0.01, 'learning_rate': 3e-05, 'batch_size': 8, 'phi_masking_proportion': 0.75, 'train_masking_prob': 0.3, 'epoch': 10.0}
Final training results: {'eval_loss': [3.224, 3.031, 2.934, 2.86, 2.807, 2.783, 2.786, 2.728, 2.761, 2.7], 'epoch': [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], 'train_loss': [3.194, 2.747, 2.577, 2.468, 2.372, 2.292, 2.241, 2.204, 2.172, 2.148], 'eval_perplexity': [25.139, 20.714, 18.796, 17.458, 16.56, 16.175, 16.214, 15.298, 15.811, 14.883]}
Best epoch results: {'eval_loss': 2.7, 'epoch': 10.0, 'train_loss': 2.148, 'eval_perplexity': 14.883}
--- STARTING FINE-TUNING ---


  self.scaler = torch.cuda.amp.GradScaler(**kwargs)


Epoch,Training Loss,Validation Loss
1,3.194,3.224421
2,2.7466,3.030812
3,2.5767,2.93367
4,2.4678,2.859784
5,2.3724,2.806974
6,2.2922,2.783483
7,2.2414,2.785889
8,2.2043,2.727731
9,2.1721,2.760677
10,2.1476,2.700244


There were missing keys in the checkpoint model loaded: ['cls.predictions.decoder.weight', 'cls.predictions.decoder.bias'].


Run name: Run - 09-08-17-00
Hyperparameters: {'weight_decay': 0.01, 'learning_rate': 3e-05, 'batch_size': 8, 'phi_masking_proportion': 0.75, 'train_masking_prob': 0.5, 'epoch': 10.0}
Final training results: {'eval_loss': [3.224, 3.031, 2.934, 2.86, 2.807, 2.783, 2.786, 2.728, 2.761, 2.7], 'epoch': [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], 'train_loss': [3.194, 2.747, 2.577, 2.468, 2.372, 2.292, 2.241, 2.204, 2.172, 2.148], 'eval_perplexity': [25.139, 20.714, 18.796, 17.458, 16.56, 16.175, 16.214, 15.298, 15.811, 14.883]}
Best epoch results: {'eval_loss': 2.7, 'epoch': 10.0, 'train_loss': 2.148, 'eval_perplexity': 14.883}
--- STARTING FINE-TUNING ---


  self.scaler = torch.cuda.amp.GradScaler(**kwargs)


Epoch,Training Loss,Validation Loss
1,3.4691,3.257412
2,2.9688,3.074408
3,2.7683,2.968498
4,2.6434,2.90241
5,2.5431,2.859454
6,2.4444,2.834884
7,2.3804,2.836597
8,2.3346,2.783706
9,2.2916,2.817677
10,2.2715,2.759548


There were missing keys in the checkpoint model loaded: ['cls.predictions.decoder.weight', 'cls.predictions.decoder.bias'].


Run name: Run - 09-08-17-12
Hyperparameters: {'weight_decay': 0.01, 'learning_rate': 3e-05, 'batch_size': 8, 'phi_masking_proportion': 1.0, 'train_masking_prob': 0.3, 'epoch': 10.0}
Final training results: {'eval_loss': [3.257, 3.074, 2.968, 2.902, 2.859, 2.835, 2.837, 2.784, 2.818, 2.76], 'epoch': [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], 'train_loss': [3.469, 2.969, 2.768, 2.643, 2.543, 2.444, 2.38, 2.335, 2.292, 2.272], 'eval_perplexity': [25.982, 21.637, 19.463, 18.218, 17.452, 17.028, 17.058, 16.179, 16.738, 15.793]}
Best epoch results: {'eval_loss': 2.76, 'epoch': 10.0, 'train_loss': 2.272, 'eval_perplexity': 15.793}
--- STARTING FINE-TUNING ---


  self.scaler = torch.cuda.amp.GradScaler(**kwargs)


Epoch,Training Loss,Validation Loss
1,3.4691,3.257412
2,2.9688,3.074408
3,2.7683,2.968498
4,2.6434,2.90241
5,2.5431,2.859454
6,2.4444,2.834884
7,2.3804,2.836597
8,2.3346,2.783706
9,2.2916,2.817677
10,2.2715,2.759548


There were missing keys in the checkpoint model loaded: ['cls.predictions.decoder.weight', 'cls.predictions.decoder.bias'].


Run name: Run - 09-08-17-24
Hyperparameters: {'weight_decay': 0.01, 'learning_rate': 3e-05, 'batch_size': 8, 'phi_masking_proportion': 1.0, 'train_masking_prob': 0.5, 'epoch': 10.0}
Final training results: {'eval_loss': [3.257, 3.074, 2.968, 2.902, 2.859, 2.835, 2.837, 2.784, 2.818, 2.76], 'epoch': [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], 'train_loss': [3.469, 2.969, 2.768, 2.643, 2.543, 2.444, 2.38, 2.335, 2.292, 2.272], 'eval_perplexity': [25.982, 21.637, 19.463, 18.218, 17.452, 17.028, 17.058, 16.179, 16.738, 15.793]}
Best epoch results: {'eval_loss': 2.76, 'epoch': 10.0, 'train_loss': 2.272, 'eval_perplexity': 15.793}
--- STARTING FINE-TUNING ---


  self.scaler = torch.cuda.amp.GradScaler(**kwargs)


Epoch,Training Loss,Validation Loss
1,3.2788,3.300688
2,2.8454,3.114532
3,2.6713,3.000424
4,2.5739,2.934534


Epoch,Training Loss,Validation Loss
1,3.2788,3.300688
2,2.8454,3.114532
3,2.6713,3.000424
4,2.5739,2.934534
5,2.4762,2.852698
6,2.384,2.844365
7,2.3509,2.827296
8,2.3188,2.778606
9,2.29,2.816779
10,2.2697,2.752297


There were missing keys in the checkpoint model loaded: ['cls.predictions.decoder.weight', 'cls.predictions.decoder.bias'].


Run name: Run - 09-08-17-36
Hyperparameters: {'weight_decay': 0.01, 'learning_rate': 3e-05, 'batch_size': 16, 'phi_masking_proportion': 0.75, 'train_masking_prob': 0.3, 'epoch': 10.0}
Final training results: {'eval_loss': [3.301, 3.115, 3.0, 2.935, 2.853, 2.844, 2.827, 2.779, 2.817, 2.752], 'epoch': [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], 'train_loss': [3.279, 2.845, 2.671, 2.574, 2.476, 2.384, 2.351, 2.319, 2.29, 2.27], 'eval_perplexity': [27.131, 22.523, 20.094, 18.813, 17.334, 17.191, 16.9, 16.097, 16.723, 15.679]}
Best epoch results: {'eval_loss': 2.752, 'epoch': 10.0, 'train_loss': 2.27, 'eval_perplexity': 15.679}
--- STARTING FINE-TUNING ---


  self.scaler = torch.cuda.amp.GradScaler(**kwargs)


Epoch,Training Loss,Validation Loss
1,3.2788,3.300688
2,2.8454,3.114532
3,2.6713,3.000424
4,2.5739,2.934534
5,2.4762,2.852698
6,2.384,2.844365
7,2.3509,2.827296
8,2.3188,2.778606
9,2.29,2.816779
10,2.2697,2.752297


There were missing keys in the checkpoint model loaded: ['cls.predictions.decoder.weight', 'cls.predictions.decoder.bias'].


Run name: Run - 09-08-17-45
Hyperparameters: {'weight_decay': 0.01, 'learning_rate': 3e-05, 'batch_size': 16, 'phi_masking_proportion': 0.75, 'train_masking_prob': 0.5, 'epoch': 10.0}
Final training results: {'eval_loss': [3.301, 3.115, 3.0, 2.935, 2.853, 2.844, 2.827, 2.779, 2.817, 2.752], 'epoch': [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], 'train_loss': [3.279, 2.845, 2.671, 2.574, 2.476, 2.384, 2.351, 2.319, 2.29, 2.27], 'eval_perplexity': [27.131, 22.523, 20.094, 18.813, 17.334, 17.191, 16.9, 16.097, 16.723, 15.679]}
Best epoch results: {'eval_loss': 2.752, 'epoch': 10.0, 'train_loss': 2.27, 'eval_perplexity': 15.679}
--- STARTING FINE-TUNING ---


  self.scaler = torch.cuda.amp.GradScaler(**kwargs)


Epoch,Training Loss,Validation Loss
1,3.5439,3.334927
2,3.076,3.159799
3,2.8754,3.043682
4,2.7676,2.967408
5,2.6615,2.916734
6,2.5627,2.897597
7,2.5158,2.887901
8,2.4874,2.841732


## Training (with optimal hyperparameters)

In [None]:
# Select best hyperparameter set: the grid-search entry with the lowest validation perplexity
lowest_perplexity = min(results.keys())
best_hyperparameters = results[lowest_perplexity]
print(f'Best hyperparameters (prplx={lowest_perplexity}):', best_hyperparameters)

In [None]:
# Transform validation dataset into training format (annotations -> tokens -> grouped chunks)
# so that it can be merged with the training split for the final run.
# New names are used instead of overwriting train_dataset['test'] & co., so re-running this
# cell does not apply extract_all_annotations twice and the original splits stay intact.
valid_annotated = train_dataset['test'].map(extract_all_annotations)
valid_tokenized = valid_annotated.map(
  tokenize_function,
  batched=False,
  # Drop this split's own raw columns (not the train split's, which need not match);
  # leftover columns would make concatenate_datasets fail on mismatched features.
  remove_columns=valid_annotated.column_names,
)
valid_grouped = valid_tokenized.map(group_texts, batched=True)

# Combine training and validation sets
grp_datasets_train_valid = datasets.concatenate_datasets([grp_datasets['train'], valid_grouped])

In [None]:
# Model's and logs directory
RUN_NAME = f"Run - {datetime.datetime.now().strftime('%m-%d-%H-%M')}"
DIR = f"{MODEL_PATH}/Results/{RUN_NAME}"
os.environ["WANDB_PROJECT"] = 'MLM_LETTER' # set the wandb project where this run will be logged

# If this Trainer reports to wandb, the HF WandbCallback only starts a new run when none is
# active. A run left open by the grid search would otherwise be reused, and its config
# overwritten. wandb.finish() is a no-op when no run is active.
wandb.finish()

# Define training args.
# Final model: trained on train+validation for the best epoch count found by the grid search,
# so there is no held-out set and evaluation is disabled.
training_args = TrainingArguments(
  run_name=RUN_NAME.replace(' ', '_'),
  output_dir=f"{DIR}/checkpoints",

  # Parameters
  per_device_train_batch_size = best_hyperparameters["batch_size"],
  learning_rate = best_hyperparameters["learning_rate"],
  weight_decay = best_hyperparameters['weight_decay'],
  num_train_epochs = best_hyperparameters['epoch'],
  fp16 = (DEVICE == 'cuda'),   # Use 16-bit (mixed) precision instead of 32-bit (ONLY POSSIBLE ON CUDA!)
  optim = "adamw_torch",
  seed=56,                     # Use a seed for reproducibility
  remove_unused_columns=False, # IMPORTANT: if not set to False, the custom data collator will delete 'annotations' column

  # Logging, Saving, Evaluating
  logging_strategy="no",
  save_strategy="epoch", save_total_limit=1, save_safetensors=True,
  evaluation_strategy="no",
)

# Create a Trainer instance
trainer = DoubleCollatorTrainer(
  model_init=model_init,
  args=training_args,
  train_dataset=grp_datasets_train_valid,
  data_collator=CustomTrainingWordMaskingDataCollator(
    tokenizer=tokenizer,
    mlm_probability=best_hyperparameters['train_masking_prob'],
    phi_masking_proportion=best_hyperparameters['phi_masking_proportion']
  ),
  tokenizer=tokenizer,
)

# Fine-tune the model
train_output = trainer.train()

# Close the wandb run (if any) so it is fully synced before the runtime is released
wandb.finish()
train_output

  self.scaler = torch.cuda.amp.GradScaler(**kwargs)


Step,Training Loss


Step,Training Loss


In [None]:
# Record which hyperparameters produced the final model, next to its checkpoints
hyperparameter_lines = [f'\t{key}={value}' for key, value in best_hyperparameters.items()]
with open(f'{DIR}/best_model_hyperparameters.txt', 'w+') as f:
  f.write('Hyperparameters:\n' + ',\n'.join(hyperparameter_lines) + '.\n\n')

In [None]:
# Save the tokenizer in the same run directory as the final model's checkpoints
tokenizer_dir = f'{DIR}/tokenizer'
tokenizer.save_pretrained(tokenizer_dir)

# Close runtime (save compute units)

In [None]:
# Close google colab runtime to save credits.
# First finish any active wandb run so its remaining data is uploaded before the VM is
# released (wandb.finish() is a no-op when no run is active).
wandb.finish()
from google.colab import runtime
runtime.unassign()