In [None]:
!pip install transformers[sentencepiece] tokenizers datasets seqeval

In [None]:
import numpy as np
from datasets import load_dataset, load_metric
from transformers import AutoModel, AutoModelForSequenceClassification, AutoModelForQuestionAnswering, AutoModelForTokenClassification, AutoTokenizer, DataCollatorWithPadding, DataCollatorForTokenClassification, TrainingArguments, Trainer
import os
import shutil
from tqdm.auto import tqdm
import collections

# Backbone checkpoint shared by every task in this notebook.
# NOTE(review): the training loops assign to `model.electra`, so switching to the
# commented-out DistilBERT checkpoint would also require changing that attribute name.
checkpoint = "google/electra-small-discriminator"
# checkpoint = "distilbert-base-uncased"

# One tokenizer for all tasks, plus a default dynamic-padding collator
# (the training loops build their own collator per task).
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=checkpoint)
data_collator = DataCollatorWithPadding(tokenizer)

In [None]:
# Chunk-tag vocabulary of CoNLL-2003 (class index -> tag string). It sizes the head of the
# "chunk" task and decodes its predictions back to tag strings for seqeval.
_conll_chunk_feature = load_dataset("conll2003")["train"].features["chunk_tags"]
label_list = _conll_chunk_feature.feature.names

In [None]:
# Task chain definitions. Each entry is (task, num_labels, prev):
#   task       - name the training loop dispatches on (model head, dataset, tokenization, metric)
#   num_labels - size of the task head (None for SQuAD v2, whose QA head ignores it)
#   prev       - task whose saved backbone f"{prev}-trainer" initialises this one;
#                None means start from the plain `checkpoint`.
# In the staged run the lists are concatenated in order, so each `prev` is the task trained
# immediately before it; JERICHO_TASKS_FULL instead restarts from the mnli backbone.

# GLUE: https://github.com/huggingface/notebooks/blob/master/examples/text_classification.ipynb
GLUE_TASKS_1 = [  # AutoModelForSequenceClassification
    ("cola", 2, None),
]

# QA: https://github.com/huggingface/notebooks/blob/master/examples/question_answering.ipynb
SQUAD_TASKS = [  # AutoModelForQuestionAnswering
    ("squad_v2", None, "cola")
]

GLUE_TASKS_2 = [  # AutoModelForSequenceClassification
    ("sst2", 2, "squad_v2"),
    ("qqp", 2, "sst2"),
    ("rte", 2, "qqp"),
]

# Token Classification: https://github.com/huggingface/notebooks/blob/master/examples/token_classification.ipynb
CHUNK_TASKS = [  # AutoModelForTokenClassification
    ("chunk", len(label_list), "rte")
]
# Legacy (misspelled) name kept so existing references keep working.
CHUCK_TASKS = CHUNK_TASKS

GLUE_TASKS_3 = [  # AutoModelForSequenceClassification
    ("mrpc", 2, "chunk"),
    ("stsb", 1, "mrpc"),  # single output -> regression head
    ("wnli", 2, "stsb"),
    ("mnli", 3, "wnli"),
]

# Jericho tasks, trained and evaluated on the local train/test TSVs.
JERICHO_TASKS = [  # AutoModelForSequenceClassification
    ("npc", 3, "mnli"),
    ("fn", 54, "npc"),
    ("vn", 70, "fn"),
    ("wn", 84, "vn")
]

# Same Jericho chain, trained on the train-only "full" TSVs; restarts from the mnli backbone.
JERICHO_TASKS_FULL = [  # AutoModelForSequenceClassification
    ("npc_full", 3, "mnli"),
    ("fn_full", 54, "npc_full"),
    ("vn_full", 70, "fn_full"),
    ("wn_full", 84, "vn_full")
]

In [None]:
def load_jericho_tsv(data_files):
    """Load one Jericho dataset from tab-separated files and print a short preview.

    The header row of each file is skipped (``skiprows=1``) and replaced by the
    explicit column names ``idx``, ``label`` and ``sentence``.

    Args:
        data_files: mapping from split name (e.g. ``'train'``, ``'test'``) to TSV path.
            A ``'train'`` split is required because it is used for the preview.

    Returns:
        The loaded dataset, indexable by split name.
    """
    dataset = load_dataset("csv",
                           data_files=data_files,
                           skiprows=1,
                           column_names=['idx', 'label', 'sentence'],
                           delimiter="\t")
    print(dataset)
    print(dataset['train'][0])
    return dataset


# Train/test splits for the staged run (the "test" split doubles as the eval set there).
fn_dataset = load_jericho_tsv({'train': 'fn_train.tsv', 'test': 'fn_test.tsv'})
npc_dataset = load_jericho_tsv({'train': 'npc_train.tsv', 'test': 'npc_test.tsv'})
vn_dataset = load_jericho_tsv({'train': 'vn_train.tsv', 'test': 'vn_test.tsv'})
wn_dataset = load_jericho_tsv({'train': 'wn_train.tsv', 'test': 'wn_test.tsv'})

In [None]:
# Train-only "full" versions of the Jericho TSVs, used for the final fit on all labelled data.
# Every file shares one layout: a header row (skipped and replaced by explicit names) and
# three tab-separated columns.
_full_tsv_options = {
    "skiprows": 1,
    "column_names": ['idx', 'label', 'sentence'],
    "delimiter": "\t",
}

full_fn_dataset = load_dataset("csv", data_files={'train': 'fn_full.tsv'}, **_full_tsv_options)
print(full_fn_dataset)
print(full_fn_dataset['train'][0])

full_npc_dataset = load_dataset("csv", data_files={'train': 'npc_full.tsv'}, **_full_tsv_options)
print(full_npc_dataset)
print(full_npc_dataset['train'][0])

full_vn_dataset = load_dataset("csv", data_files={'train': 'vn_full.tsv'}, **_full_tsv_options)
print(full_vn_dataset)
print(full_vn_dataset['train'][0])

full_wn_dataset = load_dataset("csv", data_files={'train': 'wn_full.tsv'}, **_full_tsv_options)
print(full_wn_dataset)
print(full_wn_dataset['train'][0])

In [None]:
# Turn off the HuggingFace tokenizers' internal parallelism for the training cells below
# (presumably to avoid the fork warning once the Trainer's data loading starts — confirm).
%env TOKENIZERS_PARALLELISM=false

In [None]:
# Staged (intermediate-task) fine-tuning over the whole task chain. For every task:
#   1. build the task model from `checkpoint`;
#   2. replace its ELECTRA backbone with the one saved by `prev` (if any);
#   3. train and evaluate it;
#   4. save only the backbone to f"{task}-trainer" for the next task in the chain.
for task, num_labels, prev in GLUE_TASKS_1 + SQUAD_TASKS + GLUE_TASKS_2 + CHUCK_TASKS + GLUE_TASKS_3 + JERICHO_TASKS:
    # Pick the head type for the task, then swap in the backbone carried over from `prev`.
    # NOTE(review): the `.electra` attribute ties this loop to ELECTRA checkpoints.
    if task == "squad_v2":
        model = AutoModelForQuestionAnswering.from_pretrained(checkpoint)
        if prev is not None:
            model.electra = AutoModel.from_pretrained(f"{prev}-trainer")
    elif task == "chunk":
        model = AutoModelForTokenClassification.from_pretrained(checkpoint, num_labels=num_labels)
        if prev is not None:
            model.electra = AutoModel.from_pretrained(f"{prev}-trainer")
    else:
        model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=num_labels)
        if prev is not None:
            model.electra = AutoModel.from_pretrained(f"{prev}-trainer")

    # print(model)
    # Raw data: hub datasets for SQuAD v2 / CoNLL-2003 / GLUE, local TSVs for the Jericho tasks.
    if task == "squad_v2":
        raw_datasets = load_dataset("squad_v2")
    elif task == "chunk":
        raw_datasets = load_dataset("conll2003")
    elif task == "fn":
        raw_datasets = fn_dataset
    elif task == "npc":
        raw_datasets = npc_dataset
    elif task == "vn":
        raw_datasets = vn_dataset
    elif task == "wn":
        raw_datasets = wn_dataset
    else:
        raw_datasets = load_dataset("glue", task)


    # print(raw_datasets)

    def tokenize_function(example):
        """Batched tokenizer for `raw_datasets.map`, reading each task's own input columns.

        `example` is a batch (dict of column lists). Returns the tokenizer output. For
        chunking it adds token-level `labels`; for SQuAD v2 it adds
        `start_positions`/`end_positions`.
        """
        # NOTE(review): `stride` only changes the output when return_overflowing_tokens=True,
        # so it looks like a no-op in the single-window branches below.
        if task == "cola" or task == "sst2":
            return tokenizer(example["sentence"], truncation=True, stride=128)
        if task == "qqp":
            return tokenizer(example["question1"], example["question2"], truncation=True, stride=128)
        if task == "rte" or task == "mrpc" or task == "stsb" or task == "wnli":
            return tokenizer(example["sentence1"], example["sentence2"], truncation=True, stride=128)
        if task == "mnli":
            return tokenizer(example["premise"], example["hypothesis"], truncation=True, stride=128)
        if task == "chunk":
            tokenized_inputs = tokenizer(example["tokens"], is_split_into_words=True, truncation=True, stride=128)

            # Align word-level chunk tags to sub-word tokens. Every sub-token takes its word's
            # tag, and special tokens (word id None) get -100. compute_metrics filters -100 out
            # again, and PyTorch's cross-entropy ignores it by default.
            labels = []
            for i, label in enumerate(example["chunk_tags"]):
                word_ids = tokenized_inputs.word_ids(batch_index=i)
                label_ids = []
                for word_idx in word_ids:
                    if word_idx is None:
                        label_ids.append(-100)
                    else:
                        label_ids.append(label[word_idx])

                labels.append(label_ids)

            tokenized_inputs["labels"] = labels
            return tokenized_inputs
        if task == "squad_v2":
            # Split long contexts into overlapping, max-length-padded windows. The
            # question/context order, and which sequence gets truncated, follow the
            # tokenizer's padding side.
            tokenized_examples = tokenizer(
                example["question" if tokenizer.padding_side == "right" else "context"],
                example["context" if tokenizer.padding_side == "right" else "question"],
                truncation="only_second" if tokenizer.padding_side == "right" else "only_first",
                stride=128,
                return_overflowing_tokens=True,
                return_offsets_mapping=True,
                padding="max_length"
            )

            # window -> source example index, and per-token character offsets. Both are
            # popped so they are not kept as features.
            sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
            offset_mapping = tokenized_examples.pop("offset_mapping")

            tokenized_examples["start_positions"] = []
            tokenized_examples["end_positions"] = []

            for i, offsets in enumerate(offset_mapping):
                input_ids = tokenized_examples["input_ids"][i]
                # Unanswerable questions and out-of-window answers are labelled with the CLS position.
                cls_index = input_ids.index(tokenizer.cls_token_id)

                sequence_ids = tokenized_examples.sequence_ids(i)

                sample_index = sample_mapping[i]
                answers = example["answers"][sample_index]
                if len(answers["answer_start"]) == 0:
                    tokenized_examples["start_positions"].append(cls_index)
                    tokenized_examples["end_positions"].append(cls_index)
                else:
                    # Character span of the first gold answer.
                    start_char = answers["answer_start"][0]
                    end_char = start_char + len(answers["text"][0])

                    # First and last token of the context inside this window.
                    token_start_index = 0
                    while sequence_ids[token_start_index] != (1 if tokenizer.padding_side == "right" else 0):
                        token_start_index += 1

                    token_end_index = len(input_ids) - 1
                    while sequence_ids[token_end_index] != (1 if tokenizer.padding_side == "right" else 0):
                        token_end_index -= 1

                    # Answer not fully contained in this window -> label with CLS.
                    if not (
                            offsets[token_start_index][0] <= start_char
                            and offsets[token_end_index][1] >= end_char
                    ):
                        tokenized_examples["start_positions"].append(cls_index)
                        tokenized_examples["end_positions"].append(cls_index)
                    else:
                        # Move the context bounds inward onto the answer's first and last tokens.
                        while (
                                token_start_index < len(offsets)
                                and offsets[token_start_index][0] <= start_char
                        ):
                            token_start_index += 1
                        tokenized_examples["start_positions"].append(token_start_index - 1)
                        while offsets[token_end_index][1] >= end_char:
                            token_end_index -= 1
                        tokenized_examples["end_positions"].append(token_end_index + 1)

            return tokenized_examples
        # Remaining tasks (the Jericho fn/npc/vn/wn TSVs) have a single "sentence" column.
        return tokenizer(example["sentence"], truncation=True)


    # Token classification also needs its per-token `labels` padded; other tasks pad inputs only.
    if task == "chunk":
        data_collator = DataCollatorForTokenClassification(tokenizer)
    else:
        data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    # SQuAD v2 emits one row per context window (not per example), so its original columns
    # must be dropped. Other tasks keep theirs and rely on the Trainer to ignore unused columns.
    tokenized_datasets = raw_datasets.map(tokenize_function, batched=True, remove_columns=raw_datasets[
        "train"].column_names if task == "squad_v2" else None)
    
    # print(tokenized_datasets['train'][3])
    
    # Same recipe for every task: evaluate and checkpoint each epoch, keep one checkpoint,
    # and reload the best one at the end.
    training_args = TrainingArguments(f"{task}-trainer",
                                      overwrite_output_dir=True,
                                      optim="adamw_torch",
                                      learning_rate=1e-4,
                                      weight_decay=0.01,
                                      warmup_ratio=0.1,
                                      adam_epsilon=1e-6,
                                      num_train_epochs=10.0,
                                      save_strategy="epoch",
                                      evaluation_strategy="epoch",

                                      # Debug
                                      #save_steps=2,
                                      #eval_steps =2,
                                      #save_strategy="steps",
                                      #evaluation_strategy="steps",
                                      #max_steps=4,

                                      save_total_limit=1,
                                      load_best_model_at_end=True,
                                      per_device_train_batch_size=32,
                                      per_device_eval_batch_size=32)

    def compute_metrics(eval_preds):
        """Score an evaluation pass with the current task's metric.

        `eval_preds` is the (logits, labels) pair passed in by the Trainer. STS-B takes the
        single regression output; other tasks take the argmax class (per token for chunking).
        Chunk predictions are decoded through `label_list` and scored with seqeval.
        NOTE(review): the metric is re-loaded on every evaluation call; it could be created
        once per task outside this function.
        """
        if task == "squad_v2":
            metric = load_metric("squad_v2")
        elif task == "chunk":
            metric = load_metric("seqeval")
        elif task == "npc" or task == "vn" or task == "wn" or task == "fn":
            metric = load_metric("glue", "mnli")  # Simple Accuracy
        else:
            metric = load_metric("glue", task)

        logits, labels = eval_preds
        if task == "stsb":
            predictions = logits[:, 0]
        elif task == "chunk":
            predictions = np.argmax(logits, axis=2)
        else:
            predictions = np.argmax(logits, axis=1)
            
        if task == "chunk":
            # Drop positions labelled -100 (special tokens / padding) before decoding.
            true_predictions = [
                [label_list[p] for (p, l) in zip(prediction, label) if l != -100]
                for prediction, label in zip(predictions, labels)
            ]
            true_labels = [
                [label_list[l] for (p, l) in zip(prediction, label) if l != -100]
                for prediction, label in zip(predictions, labels)
            ]
            results = metric.compute(predictions=true_predictions, references=true_labels)
            return {
                "precision": results["overall_precision"],
                "recall": results["overall_recall"],
                "f1": results["overall_f1"],
                "accuracy": results["overall_accuracy"],
            }
        else:
            return metric.compute(predictions=predictions, references=labels)


    # Eval split: MNLI uses the matched validation set, and Jericho tasks use their "test" TSV.
    # NOTE(review): with load_best_model_at_end the best checkpoint is chosen on that Jericho
    # "test" split (by eval loss unless metric_for_best_model is set), so it is not a held-out
    # estimate. SQuAD v2 gets no compute_metrics, so it is judged on eval loss alone.
    trainer = Trainer(
        model,
        training_args,
        train_dataset=tokenized_datasets["train"],
        eval_dataset=tokenized_datasets["validation_matched" if task == "mnli" else ("test" if task == "npc" or task == "vn" or task == "wn" or task == "fn" else "validation")],
        data_collator=data_collator,
        tokenizer=tokenizer,
        compute_metrics=None if task == "squad_v2" else compute_metrics
    )

    trainer.train()
    # Save only the (best) backbone; the next task loads it via AutoModel.from_pretrained.
    trainer.model.electra.save_pretrained(f"{task}-trainer")
    # Free disk by deleting the previous backbone. mnli-trainer is kept because
    # JERICHO_TASKS_FULL also starts from it.
    if prev is not None and prev != "mnli" and os.path.exists(f"{prev}-trainer"):
        try:
            shutil.rmtree(f"{prev}-trainer")
        except OSError as e:
            print("Error: %s - %s." % (e.filename, e.strerror))

In [None]:
# Final fit on the train-only "full" Jericho data. Each task again starts from the previous
# task's saved ELECTRA backbone and saves only its backbone to f"{task}-trainer".
# There is no evaluation split, so no evaluation or checkpointing is configured.
full_raw_datasets = {
    "npc_full": full_npc_dataset,
    "fn_full": full_fn_dataset,
    "vn_full": full_vn_dataset,
    "wn_full": full_wn_dataset,
}

for task, num_labels, prev in JERICHO_TASKS_FULL:
    # Classification head from the base checkpoint, with the backbone carried over from `prev`.
    model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=num_labels)
    model.electra = AutoModel.from_pretrained(f"{prev}-trainer")

    # Explicit lookup: an unexpected task name raises KeyError instead of silently
    # training on the wn_full data.
    raw_datasets = full_raw_datasets[task]

    def tokenize_function(example):
        """Tokenize the single-sentence Jericho inputs (batched), truncating to the model maximum."""
        return tokenizer(example["sentence"], truncation=True)

    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)

    # Same optimiser recipe as the staged run. Without an eval split, evaluation and saving are
    # disabled. The eval-only options the old version set (load_best_model_at_end,
    # save_total_limit, eval batch size, compute_metrics) had no effect and are dropped.
    training_args = TrainingArguments(f"{task}-trainer",
                                      overwrite_output_dir=True,
                                      optim="adamw_torch",
                                      learning_rate=1e-4,
                                      weight_decay=0.01,
                                      warmup_ratio=0.1,
                                      adam_epsilon=1e-6,
                                      num_train_epochs=10.0,
                                      save_strategy="no",
                                      evaluation_strategy="no",

                                      # Debug
                                      #max_steps=4,

                                      per_device_train_batch_size=32)

    trainer = Trainer(
        model,
        training_args,
        train_dataset=tokenized_datasets["train"],
        data_collator=data_collator,
        tokenizer=tokenizer,
    )

    trainer.train()
    trainer.model.electra.save_pretrained(f"{task}-trainer")
    # Free disk by deleting the previous backbone (including mnli-trainer on the first
    # iteration). npc_full-trainer is kept because the export cell still loads it.
    if prev != "npc_full" and os.path.exists(f"{prev}-trainer"):
        try:
            shutil.rmtree(f"{prev}-trainer")
        except OSError as e:
            print("Error: %s - %s." % (e.filename, e.strerror))

In [None]:
from transformers import TFAutoModel


def export_backbone(trainer_dir, export_dir):
    """Bundle a fine-tuned backbone with the shared tokenizer in PyTorch and TensorFlow form.

    Loads the backbone that the training loops saved into ``trainer_dir`` and writes the
    tokenizer and PyTorch weights to ``export_dir``. It then converts the same weights to
    TensorFlow (``from_pt=True``) and writes them into ``export_dir`` as well.
    (A Flax export could be added the same way with ``FlaxAutoModel``.)
    """
    pt_model = AutoModel.from_pretrained(trainer_dir)
    tokenizer.save_pretrained(export_dir)
    pt_model.save_pretrained(export_dir)

    tf_model = TFAutoModel.from_pretrained(trainer_dir, from_pt=True)
    tf_model.save_pretrained(export_dir)


export_backbone("wn_full-trainer", "if_model")
export_backbone("npc_full-trainer", "npc_model")

# Remove intermediate training directories. Each path gets its own try so that one missing
# directory (e.g. "mlruns" when MLflow logging is off) does not skip the rest of the cleanup.
for leftover_dir in ("wn-trainer", "wn_full-trainer", "npc_full-trainer", "mlruns"):
    try:
        shutil.rmtree(leftover_dir)
    except OSError as e:
        print("Error: %s - %s." % (e.filename, e.strerror))

In [None]:
import shutil

# Package each exported model directory as <name>.zip in the working directory; archive
# entries are prefixed with the directory name, as with `zip -r`. shutil.make_archive needs
# no external `zip` binary. It also writes a fresh archive instead of updating an existing
# one, so re-running the cell never leaves stale entries behind.
for export_dir in ("if_model", "npc_model"):
    shutil.make_archive(export_dir, "zip", root_dir=".", base_dir=export_dir)