In [None]:
!pip install datasets transformers

In [None]:
# some may be extra
from collections import Counter, defaultdict

import numpy as np
import pandas as pd
import scipy
import torch
import torch.nn.functional as F
from datasets import load_dataset
from tokenizers import normalizers
from torch import nn
from torch.nn.functional import sigmoid
from torch.optim import AdamW
from torch.utils.data.dataset import Dataset
from transformers import AutoTokenizer, AutoModelForTokenClassification, DataCollatorForTokenClassification, Trainer, TrainingArguments, DataCollatorWithPadding
from transformers import AutoModel, AutoConfig

In [None]:
# Hub id or local path of the pretrained encoder -- replace the placeholder.
model_checkpoint = "you/your-model"

# A fast tokenizer is required because the dataset below calls word_ids().
# add_prefix_space=True is presumably needed for a byte-level BPE (RoBERTa-style)
# tokenizer fed pre-split words -- confirm for the chosen checkpoint.
tokenizer = AutoTokenizer.from_pretrained(
    model_checkpoint,
    add_prefix_space=True,
    use_fast=True,
)

In [None]:
def make_last_subtoken_mask(mask, has_cls=True, has_eos=True):
    """Mark the last subtoken of every word in a tokenized sentence.

    Args:
        mask: per-token word ids (e.g. ``BatchEncoding.word_ids()``); subtokens of
            the same word share a value, so a token ends its word when the next
            token's value differs.
        has_cls: ``mask`` starts with one leading special token; it is marked False.
        has_eos: ``mask`` ends with one trailing special token; it is marked False.

    Returns:
        A list of bools of exactly ``len(mask)`` items, True only at the last
        subtoken of each word. A mask with no content tokens yields all False.
    """
    n = len(mask)
    start = 1 if has_cls else 0
    end = n - 1 if has_eos else n
    # the final content token always closes a word; any other content token does
    # so when its successor belongs to a different word
    return [
        start <= i < end and (i == end - 1 or mask[i] != mask[i + 1])
        for i in range(n)
    ]

In [None]:
# Collect the task inventory from the training split: "POS" (the UPOS column)
# plus every feature name that occurs in the FEATS column.
all_tasks = {"POS"}
with open('train.conllu', "r", encoding="utf8") as fin:
    for line in fin:
        columns = line.strip().split("\t")
        # same token filter as read_mt: skip blank lines, comments, multiword
        # ranges ("1-2") and empty nodes ("8.1"), whose ID is not a plain integer
        if not columns[0].isdigit():
            continue
        feats = columns[5]
        if feats != "_":
            for feat in feats.split("|"):
                # split on the first "=" only, so unpacking cannot fail on the value
                key, _ = feat.split("=", 1)
                all_tasks.add(key)
task_names = sorted(all_tasks)

In [None]:
# Read CoNLL-U files into per-sentence records of words and per-task labels
def read_mt(infile, tasks=None):
    """Read a CoNLL-U file into a list of multi-task sentence records.

    Each record is ``{"words": [...], "labels": {task: [...]}}``; every label list
    is aligned with ``words``. The "POS" task takes the UPOS column (index 3);
    any other task takes the value of the same-named feature from the FEATS
    column (index 5), or the string "None" when the word lacks that feature.
    Lines whose ID is not a plain integer (comments, multiword-token ranges such
    as "1-2", empty nodes such as "8.1") are skipped.

    Args:
        infile: path to a UTF-8 CoNLL-U file.
        tasks: task names to extract; defaults to the global ``task_names``.

    Returns:
        The sentence records in file order.
    """
    if tasks is None:
        tasks = task_names
    answer = []
    sent = []
    labels = {task: [] for task in tasks}
    with open(infile, "r", encoding="utf8") as fin:
        for line in fin:
            line = line.strip()
            if line == "":
                # a blank line closes the current sentence
                if sent:
                    answer.append({"words": sent, "labels": labels})
                sent = []
                labels = {task: [] for task in tasks}
                continue

            splitted = line.split("\t")
            if not splitted[0].isdigit():
                continue
            sent.append(splitted[1])
            pos_tag, feats = splitted[3], splitted[5]
            feats_dict = {}
            if feats != "_":
                for feat in feats.split("|"):
                    # split on the first "=" only so unpacking cannot fail
                    key, val = feat.split("=", 1)
                    feats_dict[key] = val

            for task in tasks:
                if task == "POS":
                    labels[task].append(pos_tag)
                else:
                    labels[task].append(feats_dict.get(task, "None"))

    if sent:  # the file may not end with a blank line
        answer.append({"words": sent, "labels": labels})
    return answer


In [None]:
train_mt = read_mt('train.conllu')

# peek at one parsed record to check the reader's output format
for field, value in train_mt[1].items():
    print(field, value)

eval_mt, test_mt = read_mt('dev.conllu'), read_mt('test.conllu')

In [None]:
class MultiTaskUDDataset(Dataset):
    """Token-classification dataset with one label sequence per task.

    Each record in ``data`` looks like ``{"words": [...], "labels": {task: [...]}}``.
    Records are tokenized on access. Each word's label index goes on its last
    subtoken; every other position gets ``ignore_index`` (-100).

    Args:
        data: list of sentence records.
        tokenizer: a fast tokenizer (``word_ids()`` is used).
        min_count: when a task's vocabulary is built from ``data``, tags seen
            fewer than this many times get no index.
        tags: optional ``{task: [tag, ...]}`` vocabulary (e.g. the training
            dataset's ``tags_``) so that another split reuses its indices.

    A label with no index in the vocabulary is mapped to ``ignore_index``. This
    covers a tag unseen in train, or one dropped by ``min_count``. Such a label
    is therefore excluded from the loss and from accuracy counts instead of
    raising KeyError.
    """

    def __init__(self, data, tokenizer, min_count=1, tags=None):
        self.data = data
        self.tokenizer = tokenizer
        self.ignore_index = -100
        # tasks come from the first record; an empty split falls back to `tags`
        if data:
            self.tasks = list(data[0]["labels"].keys())
        else:
            self.tasks = list(tags) if tags else []

        # per task: ordered tag list and tag -> index map
        self.tags_ = {}
        self.tag_indexes_ = {}
        for task in self.tasks:
            if tags is None or task not in tags:
                tag_counts = Counter(label for item in data for label in item["labels"][task])
                task_tags = [tag for tag, count in tag_counts.items() if count >= min_count]
            else:
                task_tags = tags[task]
            self.tags_[task] = task_tags
            self.tag_indexes_[task] = {tag: i for i, tag in enumerate(task_tags)}

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        item = self.data[index]
        tokenization = self.tokenizer(item["words"], is_split_into_words=True)
        input_ids = tokenization["input_ids"]
        answer = {"input_ids": input_ids}
        if "labels" not in item:
            return answer

        last_subtoken_mask = np.array(make_last_subtoken_mask(tokenization.word_ids()), dtype=bool)
        labels_out = {}
        for task in self.tasks:
            tag_index = self.tag_indexes_[task]
            labels = [tag_index.get(label, self.ignore_index) for label in item["labels"][task]]
            task_labels = np.full(len(input_ids), self.ignore_index, dtype=int)
            # labels are assigned to the last subtoken of each word
            task_labels[last_subtoken_mask] = labels
            labels_out[task] = task_labels
        answer["labels"] = labels_out
        return answer

In [None]:
train_ds = MultiTaskUDDataset(train_mt, tokenizer=tokenizer)
# dev/test reuse the training tag vocabulary so label indices line up
eval_ds, test_ds = [
    MultiTaskUDDataset(split_records, tokenizer=tokenizer, tags=train_ds.tags_)
    for split_records in (eval_mt, test_mt)
]

In [None]:
# Compare one raw dev record with its encoded form (token ids + per-task label arrays).
print(eval_mt[2])
for key, value in eval_ds[2].items():
    print(key)
    if type(value) != dict:
        print(value)
        continue
    for task, task_labels in value.items():
        print(task, task_labels)

In [None]:
# output size of each task head = size of that task's training tag vocabulary
num_labels_dict = {task: len(task_tags) for task, task_tags in train_ds.tags_.items()}

print(num_labels_dict)
print(train_ds.tag_indexes_)

In [None]:
class MultiTaskDataCollator(DataCollatorWithPadding):
    """Pads a batch whose ``labels`` is a ``{task: per-token label sequence}`` dict.

    ``DataCollatorWithPadding`` pads only the tokenizer fields. This subclass
    takes the per-task labels out first and lets the parent pad the rest. It then
    pads every task's labels to the padded sequence length, on the same side as
    the tokenizer pads. The result is ``batch["labels"] = {task: LongTensor}``.
    Features without a ``"labels"`` key are passed straight to the parent.
    """

    # fill value for padded label positions; matches the -100 ignore_index used
    # for labels throughout this notebook
    label_pad_token_id = -100

    def __call__(self, features):
        has_labels = bool(features) and "labels" in features[0]
        # copy instead of pop() so the caller's feature dicts are not mutated
        labels_per_feature = [feature["labels"] for feature in features] if has_labels else []
        inputs = [{k: v for k, v in feature.items() if k != "labels"} for feature in features]

        # input_ids, attention_mask, token_type_ids padding
        batch = super().__call__(inputs)
        if not has_labels:
            return batch

        max_length = batch["input_ids"].shape[1]
        pad_left = getattr(self.tokenizer, "padding_side", "right") == "left"
        batch_labels = {}
        for task_name in labels_per_feature[0]:
            padded_task_labels = []
            for feature_labels in labels_per_feature:
                label = np.asarray(feature_labels[task_name])
                padding_length = max_length - label.shape[0]
                if padding_length > 0:
                    width = (padding_length, 0) if pad_left else (0, padding_length)
                    label = np.pad(label, width, constant_values=self.label_pad_token_id)
                padded_task_labels.append(label)
            batch_labels[task_name] = torch.tensor(np.array(padded_task_labels), dtype=torch.long)
        batch["labels"] = batch_labels
        return batch


In [None]:
class AutoModelForMultiTaskTokenClassification(nn.Module):
    """Shared pretrained encoder with one linear token-classification head per task.

    Args:
        model_name: checkpoint passed to ``AutoModel.from_pretrained``.
        num_labels_dict: ``{task: number_of_labels}``; one ``nn.Linear`` head is
            created per entry on top of the encoder's last hidden state.
    """

    def __init__(self, model_name, num_labels_dict):
        super().__init__()
        self.model = AutoModel.from_pretrained(model_name)
        hidden_size = self.model.config.hidden_size
        # a classifier for each task
        self.classifiers = nn.ModuleDict({
            task: nn.Linear(hidden_size, num_labels)
            for task, num_labels in num_labels_dict.items()
        })
        self.loss_fct = nn.CrossEntropyLoss(ignore_index=-100)

    def forward(self, input_ids, attention_mask=None, labels=None, token_type_ids=None):
        """Run the encoder and every task head.

        Args:
            input_ids, attention_mask: forwarded to the encoder.
            labels: optional ``{task: (batch, seq) LongTensor}`` with -100 at
                positions to ignore.
            token_type_ids: forwarded only when given (the padding collator
                emits it for some tokenizers).

        Returns:
            ``{"logits": {task: (batch, seq, num_labels)}}``. When ``labels`` is
            given, ``"loss"`` is also included. It is the unweighted mean of the
            per-task losses, skipping tasks that are absent or fully ignored. If
            no task contributes, ``"loss"`` is a zero that is still attached to
            the graph, so a training step never crashes.
        """
        encoder_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask}
        # not every encoder's forward() accepts token_type_ids, so omit it when unset
        if token_type_ids is not None:
            encoder_kwargs["token_type_ids"] = token_type_ids
        sequence_output = self.model(**encoder_kwargs).last_hidden_state  # (batch, seq_len, hidden)

        ignore_index = self.loss_fct.ignore_index
        logits, losses = {}, {}
        for task, classifier in self.classifiers.items():
            task_logits = classifier(sequence_output)  # (batch, seq, num_labels)
            logits[task] = task_logits
            if labels is None or task not in labels:
                continue
            task_labels = labels[task]
            # a mean over zero non-ignored targets would be NaN
            if not (task_labels != ignore_index).any():
                continue
            # CrossEntropyLoss <- (batch*seq, num_labels)
            losses[task] = self.loss_fct(
                task_logits.reshape(-1, task_logits.size(-1)),
                task_labels.reshape(-1),
            )

        if labels is None:
            return {"logits": logits}
        if losses:
            total_loss = sum(losses.values()) / len(losses)
        else:
            total_loss = sequence_output.sum() * 0.0
        return {"loss": total_loss, "logits": logits}

In [None]:
def compute_metrics_with_tasks(eval_pred):
    """Token-, sentence- and per-task accuracy for the multi-task tagger.

    ``eval_pred.predictions`` and ``eval_pred.label_ids`` are dicts keyed by task.
    ``predictions[task]`` is treated as logits whose last axis is the class axis.
    ``label_ids[task]`` is treated as a (sentences, positions) array in which
    -100 marks positions without a label.

    Returns a dict (task order follows ``predictions``):
      * ``token_acc``: % of positions labelled by at least one task where no
        labelled task is wrong;
      * ``sent_acc``: % of sentences with no error in any task;
      * ``<task>_acc``: % of that task's labelled positions predicted correctly.
    A percentage whose denominator is zero is reported as 0.0 instead of raising.
    """
    ignore_index = -100
    preds, labels = eval_pred.predictions, eval_pred.label_ids
    tasks = list(preds.keys())
    if not tasks:
        return {}

    def pct(numerator, denominator):
        return 100 * numerator / denominator if denominator else 0.0

    valid, wrong, task_correct = {}, {}, {}
    for task in tasks:
        task_labels = np.asarray(labels[task])
        task_pred = np.asarray(preds[task]).argmax(axis=-1)
        is_valid = task_labels != ignore_index
        is_hit = task_pred == task_labels
        valid[task] = is_valid
        wrong[task] = is_valid & ~is_hit
        task_correct[task] = int((is_valid & is_hit).sum())

    # combine tasks position-wise: a position is "real" if any task labels it and
    # "wrong" if any task errs there
    is_real = np.stack([valid[task] for task in tasks]).any(axis=0)
    is_wrong = np.stack([wrong[task] for task in tasks]).any(axis=0)

    metrics = {
        "token_acc": pct(int((is_real & ~is_wrong).sum()), int(is_real.sum())),
        "sent_acc": pct(int((~is_wrong.any(axis=1)).sum()), is_wrong.shape[0]),
    }
    for task in tasks:
        metrics[f"{task}_acc"] = pct(task_correct[task], int(valid[task].sum()))
    return metrics

In [None]:
# Shared encoder from model_checkpoint plus one classification head per task,
# each sized by the training tag vocabulary in num_labels_dict.
model = AutoModelForMultiTaskTokenClassification(model_checkpoint, num_labels_dict)

In [None]:
training_args = TrainingArguments(
    # optimisation
    learning_rate=5e-5,
    weight_decay=0.01,
    num_train_epochs=5,
    # batching
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    # evaluate every 200 steps
    eval_strategy='steps',
    eval_steps=200,
    # no experiment-tracking integrations
    report_to="none",
)

In [None]:
# The AdamW optimizer reads its learning rate and weight decay from training_args,
# so the two settings cannot drift apart. Passing None as the scheduler lets
# Trainer build its default one.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    optimizers=(
        AdamW(
            model.parameters(),
            lr=training_args.learning_rate,
            weight_decay=training_args.weight_decay,
        ),
        None,
    ),
    compute_metrics=compute_metrics_with_tasks,
    data_collator=MultiTaskDataCollator(tokenizer),
)

In [None]:
# Fine-tune; eval_ds is scored every eval_steps as configured in training_args.
trainer.train()

In [None]:
# Score the fine-tuned model on every split and tabulate the accuracy metrics.
predictions_train = trainer.predict(train_ds)
predictions_eval = trainer.predict(eval_ds)
predictions_test = trainer.predict(test_ds)

# With its default metric_key_prefix, trainer.predict names every key "test_..."
# whatever the split, and also reports loss/runtime/throughput entries. Pick the
# accuracy keys by name instead of by their position in the dict.
metric_prefix = "test_"
rows = []
for key in predictions_train.metrics:
    if not key.endswith("_acc"):
        continue
    name = key[len(metric_prefix):] if key.startswith(metric_prefix) else key
    rows.append({
        "task": name,
        "train": predictions_train.metrics[key],
        "eval": predictions_eval.metrics.get(key),
        "test": predictions_test.metrics.get(key),
    })
df = pd.DataFrame(rows)
df