### This notebook fine-tunes student or teacher models on the SUCX 3.0 dataset.
The code is heavily inspired by the Hugging Face token classification notebook.

__Note:__ There are hard-coded values that were relevant to me. _You_ need to specify _your own_ paths for the relevant files, such as where the saved student models are stored.

In [None]:
!pip install transformers
!pip install datasets

In [None]:
# Mount Google Drive under /content/drive.
# Only relevant on Google Colab: comment this cell out when running elsewhere.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!pip install seqeval

In [None]:
import numpy as np
import torch

# Seeds to be used: 100, 101, 102
SEED = 100

# Report whether a CUDA device is visible before anything else runs.
print(torch.cuda.is_available())

# Seed both random number generators used in this notebook.
torch.manual_seed(SEED)
np.random.seed(SEED)

#### Constants, remember to choose correct values

In [2]:
# True when one of the teacher model names below is selected, False for a student.
FINE_TUNE_Teacher = False

TRUNCATION = True # Affects how student model is initialized
SUFFIX = "truncated" if TRUNCATION else "random"

# "VAL" denotes hyperparameter tuning; "FULL" means training on the entire
# train + val data.
DOING_VALIDATION = False
SIZE_STRING = "VAL" if DOING_VALIDATION else "FULL"

# Teacher model names ------
# MODEL_NAME = "KBBERT_12" 
# MODEL_NAME = "MBERT_12" 
# MODEL_NAME = "MBERT_ADAPT"
# MODEL_NAME = "MBERT_SUPER_ADAPT"

# Student model names ------
# MODEL_NAME = "KBBERT_6" # KBBERT_6 when training student from scratch, or when KB-BERT is used for task-specific distillation
# MODEL_NAME = "MBERT_6"  # MBERT_6 when student is going to be task-distilled from MBERT
MODEL_NAME = "MBERT_6_ADAPT" 
# MODEL_NAME = "MBERT_6_SUPER_ADAPT"

# Output names follow the same pattern for teachers and students, so they are
# built once instead of once per branch.
OUTPUT_DIR = f"SUCX_FT_{SIZE_STRING}_{SEED}_{MODEL_NAME}_{SUFFIX}"
MODEL_SAVE_FILE = f"{OUTPUT_DIR}.pt"

# MODEL_STATE is the local checkpoint the model is initialised from. It stays
# None for teachers that need no local checkpoint (KBBERT_12, MBERT_12), so
# that later cells reading MODEL_STATE do not fail with a NameError.
MODEL_STATE = None
if FINE_TUNE_Teacher:
    if MODEL_NAME in ("MBERT_ADAPT", "MBERT_SUPER_ADAPT"):
        MODEL_STATE = "/content/drive/MyDrive/coding/mlm_GIGA_1990_100_MBERT_ADAPT_truncate.pt" # <--- SPECIFY the file location 
else:
    MODEL_STATE = "/content/drive/MyDrive/coding/distilled_GIGA_FULL_100_MBERT_6_ADAPT_truncated.pth" # <--- SPECIFY the file location 

In [None]:
# Echo the run configuration so it is recorded in the notebook output.
print("SEED = {0}, MODEL_NAME = {1}".format(SEED,MODEL_NAME))
print("OUTPUT_DIR = {0}, MODEL_SAVE_FILE = {1}".format(OUTPUT_DIR,MODEL_SAVE_FILE))

In [None]:
# Show which checkpoint path the model-loading cell will read from.
print(MODEL_STATE)

#### Training arguments

In [4]:
# Hyperparameters read by the optimizer / TrainingArguments cell further down.
NUM_EPOCHS = 3
WEIGHT_DECAY = 0.01
LEARNING_RATE = 5e-5 # main candidates: 1e-4, 5e-5
# Used for both the training and the evaluation batch size.
BATCH_SIZE = 8

In [None]:
# Echo the learning rate chosen for this run.
print(LEARNING_RATE)

#### Download data

In [None]:
import torch
from datasets import load_dataset

# When validating, train on 'train' and evaluate on 'validation'.
# Otherwise note the merging of train and validation: the model is trained on
# both and evaluated on 'test'.
splits = ['train', 'validation'] if DOING_VALIDATION else ['train+validation', 'test']
raw_dataset_train, raw_dataset_val = load_dataset('KBLab/sucx3_ner', 'original_cased', split=splits)

In [None]:
# Inspect the two loaded splits.
print(raw_dataset_train)
print(raw_dataset_val)

In [None]:
# Tag name -> integer id. This is the single source of truth; the inverse map
# and the ordered name list below are derived from it so they cannot drift apart.
label2id = {'B-animal': 0, 'I-person': 1, 'B-other': 2, 'I-inst': 3, 'O': 4, 'I-myth': 5, 'I-event': 6, 'I-other': 7, 'I-product': 8, 'B-event': 9, 'B-place': 10, 'I-animal': 11, 'B-myth': 12, 'I-work': 13, 'B-person': 14, 'B-work': 15, 'B-product': 16, 'B-inst': 17, 'I-place': 18}
id2label = {tag_id: tag for tag, tag_id in label2id.items()}

# label_names[i] is the tag with id i. Index explicitly by id rather than
# relying on dict insertion order.
label_names = [id2label[tag_id] for tag_id in range(len(id2label))]
ner_tag_length = len(id2label)
print(label_names)

#### Choose tokenizer

In [9]:
from transformers import AutoTokenizer

# (substring of MODEL_NAME, pretrained checkpoint name), checked in this order.
_CHECKPOINT_BY_FAMILY = (
    ("KBBERT", "KB/bert-base-swedish-cased"),
    ("MBERT", "bert-base-multilingual-cased"),
)

for family, checkpoint in _CHECKPOINT_BY_FAMILY:
    if family in MODEL_NAME:
        teacher_finetuned_name = checkpoint
        break
else:
    # No known family matched the configured model name.
    raise Exception("Undefined MODEL_NAME {0}".format(MODEL_NAME))

tokenizer = AutoTokenizer.from_pretrained(teacher_finetuned_name)

In [10]:
### We do dynamic padding: this collator is built from the tokenizer and is
### later handed to the Trainer as its data_collator.
from transformers import DataCollatorForTokenClassification

data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)

In [11]:
def tokenize_and_align_labels(examples):
    """Tokenize a batch of pre-split sentences and align NER labels to tokens.

    Reads ``examples["tokens"]`` and ``examples["ner_tags"]``. The first token
    of every word receives the integer id of that word's tag (looked up in
    ``label2id``); special tokens and the remaining tokens of a word receive
    -100. The result is the tokenizer output with an added ``"labels"`` entry.
    """
    encoded = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)

    def _align(batch_index, word_tags):
        # Build the per-token label list for one sentence of the batch.
        aligned = []
        last_word = None
        for word_id in encoded.word_ids(batch_index=batch_index):
            starts_new_word = word_id is not None and word_id != last_word
            # Tags are turned into ints via label2id; everything else is -100.
            aligned.append(label2id[word_tags[word_id]] if starts_new_word else -100)
            last_word = word_id
        return aligned

    encoded["labels"] = [
        _align(batch_index, word_tags)
        for batch_index, word_tags in enumerate(examples["ner_tags"])
    ]
    return encoded

In [None]:
# Tokenize both splits the same way. Each split drops its *own* raw columns,
# so only the tokenizer outputs and the aligned labels remain.
# (The former if/else on DOING_VALIDATION had two identical branches.)
tokenized_suc_train = raw_dataset_train.map(
    tokenize_and_align_labels,
    batched=True,
    remove_columns=raw_dataset_train.column_names,
)
tokenized_suc_val = raw_dataset_val.map(
    tokenize_and_align_labels,
    batched=True,
    remove_columns=raw_dataset_val.column_names,
)

In [None]:
# Inspect the tokenized splits.
print(tokenized_suc_train)
print(tokenized_suc_val)

#### Metric code

In [14]:
# Load the seqeval metric used by compute_metrics below.
# NOTE(review): newer `datasets` releases appear to have dropped `load_metric`
# in favour of the separate `evaluate` package; confirm against the installed version.
from datasets import load_metric

metric = load_metric("seqeval")

In [15]:
import numpy as np


def compute_metrics(eval_preds):
    """Compute overall precision, recall, F1 and accuracy with the loaded metric.

    ``eval_preds`` unpacks into ``(logits, labels)``. Predictions are the
    argmax over the last axis of the logits. Positions whose label is -100
    (the ignored index) are dropped from both predictions and references
    before ids are converted to tag names via ``label_names``.
    """
    logits, labels = eval_preds
    predictions = np.argmax(logits, axis=-1)

    def _tag_names(ids, gold_row):
        # Keep only positions with a real label and map ids to tag names.
        return [label_names[i] for i, gold in zip(ids, gold_row) if gold != -100]

    true_labels = [_tag_names(gold_row, gold_row) for gold_row in labels]
    true_predictions = [
        _tag_names(pred_row, gold_row)
        for pred_row, gold_row in zip(predictions, labels)
    ]

    scores = metric.compute(predictions=true_predictions, references=true_labels)
    reported = ("precision", "recall", "f1", "accuracy")
    return {name: scores["overall_" + name] for name in reported}

In [None]:
# Load the models
# For the 6-layer student variants and further pre-trained mBERT_12 a state file must be provided

from transformers import AutoModelForTokenClassification, BertConfig

if MODEL_NAME in ("KBBERT_12", "MBERT_12"):
    # Plain teachers are loaded directly by their pretrained checkpoint name.
    model = AutoModelForTokenClassification.from_pretrained(
        teacher_finetuned_name,
        id2label=id2label,
        label2id=label2id,
    )
else:
    # Both further pre-trained teachers keep the full depth of the base
    # config. Previously "MBERT_SUPER_ADAPT" fell through to the student
    # branch and was cut down to 6 layers.
    is_student = MODEL_NAME not in ("MBERT_ADAPT", "MBERT_SUPER_ADAPT")
    if is_student:
        print("student")

    model_config = BertConfig.from_pretrained(teacher_finetuned_name)
    model_config.num_labels = ner_tag_length
    # Carry the tag names in the config as well, consistent with the branch
    # above. Set after num_labels so the names are not replaced by defaults.
    model_config.id2label = id2label
    model_config.label2id = label2id
    if is_student:
        model_config.num_hidden_layers = 6

    model = AutoModelForTokenClassification.from_pretrained(MODEL_STATE, config=model_config)


### Setup trainer

In [None]:
from transformers import TrainingArguments
from transformers import get_linear_schedule_with_warmup
from transformers import AdamW
from torch.utils.data import DataLoader


# Total number of optimizer steps: batches per epoch times number of epochs.
# The DataLoader is only used to count batches.
num_training_steps = len(DataLoader(tokenized_suc_train, batch_size=BATCH_SIZE)) * NUM_EPOCHS

# An optimizer passed to the Trainer is used as-is: the `weight_decay` given to
# TrainingArguments only affects an optimizer the Trainer creates itself.
# WEIGHT_DECAY is therefore applied here, skipping biases and LayerNorm weights
# as is customary for BERT fine-tuning.
no_decay_markers = ("bias", "LayerNorm.weight")
decay_params, no_decay_params = [], []
for param_name, param in model.named_parameters():
    if any(marker in param_name for marker in no_decay_markers):
        no_decay_params.append(param)
    else:
        decay_params.append(param)

optimizer = AdamW(
    [
        {"params": decay_params, "weight_decay": WEIGHT_DECAY},
        {"params": no_decay_params, "weight_decay": 0.0},
    ],
    lr=LEARNING_RATE,
)

# Linear warmup over the first 10% of the steps, then linear decay.
# The step count is documented as an int, so truncate the product.
scheduler_class = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(0.1 * num_training_steps),
    num_training_steps=num_training_steps,
)


args = TrainingArguments(
    OUTPUT_DIR,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=LEARNING_RATE,
    num_train_epochs=NUM_EPOCHS,
    weight_decay=WEIGHT_DECAY,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    # Without this the Trainer reseeds with its own default seed, so runs with
    # different SEED values would share the same training randomness.
    seed=SEED,
)

In [18]:
from transformers import Trainer

# Gather everything the Trainer needs, then build it in one call.
trainer_kwargs = {
    "model": model,
    "args": args,
    # Data: tokenized splits plus the collator.
    "train_dataset": tokenized_suc_train,
    "eval_dataset": tokenized_suc_val,
    "data_collator": data_collator,
    "tokenizer": tokenizer,
    # Evaluation callback producing precision / recall / f1 / accuracy.
    "compute_metrics": compute_metrics,
    # Use the optimizer and scheduler created in the previous cell.
    "optimizers": (optimizer, scheduler_class),
}
trainer = Trainer(**trainer_kwargs)

In [None]:
# Run fine-tuning with the Trainer configured above.
trainer.train()