### This notebook fine-tunes the mBERT teacher and student models on the WikiANN dataset

The code is heavily inspired by the Hugging Face token classification notebook.

__Note:__ Some values are hard-coded for my own setup. _You_ need to specify _your_ own paths to the relevant files, such as the location of the saved student models.

In [None]:
# Mount Google Drive at /content/drive so that files stored there can be read.
# This only works on Google Colab: comment out the next two lines when running
# anywhere else, and uncomment them again when running on Colab.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
! pip install datasets transformers seqeval

In [None]:
import transformers
import numpy as np
import torch

print(transformers.__version__)

In [4]:
# --- Reproducibility -------------------------------------------------------
SEED = 100
torch.manual_seed(SEED)
np.random.seed(SEED)

# --- Task and optimisation hyperparameters ---------------------------------
task = "ner"  # Should be one of "ner", "pos" or "chunk"
model_checkpoint = "bert-base-multilingual-cased"  # mBERT
BATCH_SIZE = 8
# Pick exactly one learning rate:
# LEARNING_RATE = 1e-4
LEARNING_RATE = 5e-5
NUM_TRAIN_EPOCHS = 3

# --- Run mode switches -----------------------------------------------------
DOING_VALIDATION = False   # True -> hyperparameter tuning run
FINE_TUNE_Teacher = False  # True -> fine-tune a teacher, False -> a student
TRUNCATION = True          # Affects how student model is initialized

# Tags derived from the switches; they only feed the run names below.
SUFFIX = "truncated" if TRUNCATION else "random"
SIZE_STRING = "VAL" if DOING_VALIDATION else "FULL"

# --- Model selection (leave exactly one MODEL_NAME uncommented) ------------
# Teacher model names
# MODEL_NAME = "KBBERT_12"
# MODEL_NAME = "MBBERT_12"
# MODEL_NAME = "MBERT_ADAPT"
# MODEL_NAME = "MBERT_SUPER_ADAPT"

# Student model names
# MODEL_NAME = "KBBERT_6"
# MODEL_NAME = "MBERT_6"
# MODEL_NAME = "MBERT_6_ADAPT"
MODEL_NAME = "MBERT_6_SUPER_ADAPT"

# --- Run names -------------------------------------------------------------
# Same for teacher and student runs: SUCX_FT_<size>_<seed>_<model>_<suffix>.
OUTPUT_DIR = "SUCX_FT_{3}_{0}_{1}_{2}".format(SEED, MODEL_NAME, SUFFIX, SIZE_STRING)
MODEL_SAVE_FILE = "SUCX_FT_{3}_{0}_{1}_{2}.pt".format(SEED, MODEL_NAME, SUFFIX, SIZE_STRING)

# --- Saved weights to start from -------------------------------------------
# MODEL_STATE is only set when a local checkpoint is needed: always for a
# student, and for a teacher only when it is one of the adapted variants.
if not FINE_TUNE_Teacher:
    MODEL_STATE = "/content/drive/MyDrive/coding/distilled_GIGA_FULL_100_MBERT_6_SUPER_ADAPT_truncated.pth"  # <--- Specify the file location
elif MODEL_NAME in ("MBERT_ADAPT", "MBERT_SUPER_ADAPT"):
    MODEL_STATE = "/content/drive/MyDrive/coding/mlm_GIGA_1990_100_MBERT_ADAPT_truncate.pt"  # <--- Specify the file location

In [None]:
print(LEARNING_RATE)

In [6]:
from datasets import load_dataset, load_metric, concatenate_datasets

In [None]:
datasets = load_dataset("wikiann", "en")

In [None]:
print(datasets)

In [9]:
from datasets import ClassLabel, Sequence
import random
import pandas as pd
from IPython.display import display, HTML

def show_random_elements(dataset, num_examples=10):
    """Display `num_examples` distinct, randomly chosen rows of `dataset` as an HTML table.

    Columns whose feature type is a `ClassLabel`, or a `Sequence` of
    `ClassLabel`, are shown with their label names instead of integer ids.

    Args:
        dataset: Object supporting `len()`, indexing with a list of row
            indices, and a `.features` mapping of column name to feature type.
        num_examples: Number of distinct rows to show; must not exceed
            `len(dataset)`.

    Raises:
        AssertionError: If `num_examples` is larger than the dataset.
    """
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."

    # random.sample draws without replacement in a single call, so there is no
    # retry loop and no repeated linear membership test on the picks so far.
    picks = random.sample(range(len(dataset)), num_examples)

    df = pd.DataFrame(dataset[picks])
    for column, typ in dataset.features.items():
        if isinstance(typ, ClassLabel):
            # Single class id per row -> its name.
            df[column] = df[column].transform(lambda i: typ.names[i])
        elif isinstance(typ, Sequence) and isinstance(typ.feature, ClassLabel):
            # List of class ids per row -> list of names.
            df[column] = df[column].transform(lambda x: [typ.feature.names[i] for i in x])
    display(HTML(df.to_html()))



In [None]:
show_random_elements(datasets["train"])

In [11]:
label_list = datasets["train"].features[f"{task}_tags"].feature.names

In [None]:
print(label_list)

In [13]:
from transformers import AutoTokenizer
    
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [14]:
import transformers
assert isinstance(tokenizer, transformers.PreTrainedTokenizerFast)

In [15]:
# Controls how tokenize_and_align_labels labels words that are split into
# several sub-tokens. False: only the first sub-token keeps the word's label and
# the remaining ones get -100. True: every sub-token gets the word's label.
label_all_tokens = False

In [16]:
def tokenize_and_align_labels(examples):
    """Tokenize a batch of pre-split sentences and align word labels to sub-tokens.

    Reads `examples["tokens"]` and `examples[f"{task}_tags"]`, and relies on
    the module-level `tokenizer`, `task` and `label_all_tokens`.

    Returns:
        The tokenizer output with an extra "labels" entry holding one label
        list per sentence, the same length as that sentence's token sequence.
    """
    tokenized_inputs = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)

    # -100 marks positions that are automatically ignored in the loss function.
    ignore_index = -100

    aligned_labels = []
    for sentence_idx, word_labels in enumerate(examples[f"{task}_tags"]):
        word_ids = list(tokenized_inputs.word_ids(batch_index=sentence_idx))
        # Pair every token's word id with the word id of the token before it
        # (None for the very first token).
        preceding_ids = [None] + word_ids[:-1]
        aligned_labels.append([
            # Ignored: special tokens (word id None) and, unless
            # label_all_tokens is set, non-first sub-tokens of a word.
            ignore_index
            if current is None or (current == previous and not label_all_tokens)
            else word_labels[current]
            for previous, current in zip(preceding_ids, word_ids)
        ])

    tokenized_inputs["labels"] = aligned_labels
    return tokenized_inputs

In [None]:
tokenized_datasets = datasets.map(tokenize_and_align_labels, batched=True)

In [19]:
# Drop the raw input columns so that only the tokenizer outputs and "labels"
# remain. The tag column name is derived from `task`, matching how the labels
# are read elsewhere in this notebook (f"{task}_tags").
tokenized_datasets = tokenized_datasets.remove_columns(["tokens", f"{task}_tags", "langs", "spans"])

### Fine-tuning model


In [None]:
from transformers import AutoModelForTokenClassification, TrainingArguments, Trainer, BertConfig

# A local checkpoint (MODEL_STATE) is used for every student and for the
# adapted teachers; any other teacher starts from the stock `model_checkpoint`.
loads_local_state = (not FINE_TUNE_Teacher) or MODEL_NAME in ("MBERT_ADAPT", "MBERT_SUPER_ADAPT")

if loads_local_state:
    # Start from the mBERT configuration and size the classifier head to the
    # label set.
    model_config = BertConfig.from_pretrained(model_checkpoint)
    model_config.num_labels = len(label_list)
    if not FINE_TUNE_Teacher:
        # Students use 6 hidden layers.
        model_config.num_hidden_layers = 6
    model = AutoModelForTokenClassification.from_pretrained(MODEL_STATE, config=model_config)
else:
    model = AutoModelForTokenClassification.from_pretrained(model_checkpoint, num_labels=len(label_list))

In [None]:
print(MODEL_NAME)

In [None]:
print(model)

In [None]:
# Choose the train/eval splits and the checkpointing policy for this run.
if not DOING_VALIDATION:
    # Full run: train on train + validation, evaluate on the test split and
    # save a checkpoint every epoch.
    TRAIN_DATASET = concatenate_datasets(
        [tokenized_datasets["train"], tokenized_datasets["validation"]]
    )
    EVAL_DATASET = tokenized_datasets["test"]
    SAVE_STRATEGY = "epoch"
else:
    # Hyperparameter tuning: train on train, evaluate on validation, and do
    # not save checkpoints.
    TRAIN_DATASET = tokenized_datasets["train"]
    EVAL_DATASET = tokenized_datasets["validation"]
    SAVE_STRATEGY = "no"

print(TRAIN_DATASET, EVAL_DATASET, sep="\n")

In [None]:
from transformers import TrainingArguments
from transformers import get_linear_schedule_with_warmup
from transformers import AdamW
from torch.utils.data import DataLoader


import math

# Total number of optimizer steps: batches per epoch (the last, partial batch
# counts) times the number of epochs. Computed directly from the dataset size
# instead of building a DataLoader just to take its length.
# NOTE(review): assumes a single device and no gradient accumulation, as the
# original DataLoader-based count did.
num_training_steps = math.ceil(len(TRAIN_DATASET) / BATCH_SIZE) * NUM_TRAIN_EPOCHS

optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)

# Linear warmup over the first 10% of the steps, then linear decay.
# The step count is truncated to an int, which is the type the scheduler
# factory documents for `num_warmup_steps`.
# (Despite its name, `scheduler_class` holds the scheduler object returned by
# the factory; the name is kept because the Trainer cell refers to it.)
scheduler_class = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(0.1 * num_training_steps),
    num_training_steps=num_training_steps,
)

# NOTE(review): `optimizer` is handed to the Trainer explicitly, so
# `learning_rate` and `weight_decay` below presumably do not reach that
# optimizer, which was built above with only `lr` set. Confirm whether a
# weight decay of 0.01 was intended to be applied.
args = TrainingArguments(
    f"{MODEL_NAME}-finetuned-{task}",
    evaluation_strategy="epoch",
    save_strategy=SAVE_STRATEGY,
    learning_rate=LEARNING_RATE,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    num_train_epochs=NUM_TRAIN_EPOCHS,
    weight_decay=0.01,
)

In [26]:
from transformers import DataCollatorForTokenClassification

data_collator = DataCollatorForTokenClassification(tokenizer)

In [27]:
metric = load_metric("seqeval")

In [28]:
import numpy as np

def compute_metrics(p):
    """Compute precision, recall, F1 and accuracy from an evaluation result.

    Args:
        p: A pair `(predictions, labels)`. `predictions` is reduced with an
            argmax over axis 2 to get one class id per position; `labels`
            holds the reference class ids, with -100 at ignored positions.

    Returns:
        A dict with the keys "precision", "recall", "f1" and "accuracy",
        taken from the overall scores reported by the module-level `metric`.
    """
    scores, gold_ids = p
    predicted_ids = np.argmax(scores, axis=2)

    true_predictions = []
    true_labels = []
    for predicted_row, gold_row in zip(predicted_ids, gold_ids):
        # Keep only the positions whose reference label is not the ignored
        # index (special tokens).
        kept = [(pred, gold) for pred, gold in zip(predicted_row, gold_row) if gold != -100]
        true_predictions.append([label_list[pred] for pred, _ in kept])
        true_labels.append([label_list[gold] for _, gold in kept])

    results = metric.compute(predictions=true_predictions, references=true_labels)
    reported = ("precision", "recall", "f1", "accuracy")
    return {name: results[f"overall_{name}"] for name in reported}

In [29]:
# Assemble the Trainer. The optimizer and learning-rate scheduler built earlier
# are passed in explicitly as a pair.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=TRAIN_DATASET,
    eval_dataset=EVAL_DATASET,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    optimizers=(optimizer, scheduler_class),
)

In [None]:
trainer.train()