In [1]:
import datasets
from datasets import Sequence
from datasets import ClassLabel
# from hfmtl.tasks.sequence_classification import SequenceClassification
# from hfmtl.tasks.token_classification import TokenClassification
# from hfmtl.utils import *
# from hfmtl.models import *

from PYEVALB.scorer import Scorer
from PYEVALB.summary import summary

from codelin.models.const_tree import C_Tree
from codelin.models.const_label import C_Label
from codelin.models.linearized_tree import LinearizedTree
from codelin.encs.constituent import *
from codelin.utils.constants import *

import easydict
from chrono import Timer
from frozendict import frozendict
import os
import torch
import pandas as pd

import logging


# Set logging level
'''
Train the models in multi-task learning fashion. To do this
we will split the fields of the label and train different
tasks according to this. After training, we will evaluate
the decoded trees by re-joining the labels.
'''

ptb_path = "~/Treebanks/const/PENN_TREEBANK/"
ptb_path = os.path.expanduser(ptb_path)

with open(os.path.join(ptb_path,"test.trees")) as f:
    ptb_test = [l.rstrip() for l in f.read().splitlines()]
with open(os.path.join(ptb_path,"dev.trees")) as f:
    ptb_dev = [l.rstrip() for l in f.read().splitlines()]
with open(os.path.join(ptb_path,"train.trees")) as f:
    ptb_train = [l.rstrip() for l in f.read().splitlines()]

def get_n_labels(dsets, tar_field):
    label_set = set()
    for dset in dsets:
        for labels in dset[tar_field]:
            label_set.update(labels)
    label_names = sorted(list(label_set))
    return label_names, len(label_names)

def generate_dataset_from_codelin(train_dset, dev_dset, test_dset=None):
    dsets = [train_dset, dev_dset, test_dset] if test_dset else [train_dset, dev_dset]
    
    l1, nl1 = get_n_labels(dsets, "target_1")
    print("Sample of labels target_1: n_commons:", l1[5:10])
    l2, nl2 = get_n_labels(dsets, "target_2")
    print("Sample of labels target_2: last_common:", l2[5:10])
    l3, nl3 = get_n_labels(dsets, "target_3")
    print("Sample of labels target_3: unary_chain:", l3[5:10])

    train_dset = datasets.Dataset.from_dict(train_dset)
    train_dset = train_dset.cast_column("target_1", Sequence(ClassLabel(num_classes=nl1, names=l1)))
    train_dset = train_dset.cast_column("target_2", Sequence(ClassLabel(num_classes=nl2, names=l2)))
    train_dset = train_dset.cast_column("target_3", Sequence(ClassLabel(num_classes=nl3, names=l3)))

    dev_dset = datasets.Dataset.from_dict(dev_dset)
    dev_dset = dev_dset.cast_column("target_1", Sequence(ClassLabel(num_classes=nl1, names=l1)))
    dev_dset = dev_dset.cast_column("target_2", Sequence(ClassLabel(num_classes=nl2, names=l2)))
    dev_dset = dev_dset.cast_column("target_3", Sequence(ClassLabel(num_classes=nl3, names=l3)))

    if test_dset:
        test_dset = datasets.Dataset.from_dict(test_dset)
        test_dset = test_dset.cast_column("target_1", Sequence(ClassLabel(num_classes=nl1, names=l1)))
        test_dset = test_dset.cast_column("target_2", Sequence(ClassLabel(num_classes=nl2, names=l2)))
        test_dset = test_dset.cast_column("target_3", Sequence(ClassLabel(num_classes=nl3, names=l3)))
    
        # Convert to Hugging Face DatasetDict format
        dataset = datasets.DatasetDict({
                "train": train_dset,
                "validation": dev_dset,
                "test": test_dset
            })
    else:
        # Convert to Hugging Face DatasetDict format
        dataset = datasets.DatasetDict({
                "train": train_dset,
                "validation": dev_dset
            })

    return dataset

def encode_dset(encoder, dset):
    encoded_trees = {"tokens":[], "target_1":[], "target_2":[], "target_3":[]}
    max_len_tree = 0
    for line in dset:
        tree = C_Tree.from_string(line)
        lin_tree = encoder.encode(tree)
        encoded_trees["tokens"].append([w for w in lin_tree.words])
        
        t1,t2,t3 = [],[],[]
        for s1,s2,s3 in lin_tree.get_labels_splitted():
            t1.append(s1)    
            t2.append(s2)
            t3.append(s3)
            
        encoded_trees["target_1"].append(t1)
        encoded_trees["target_2"].append(t2)
        encoded_trees["target_3"].append(t3)
        
        max_len_tree = max(max_len_tree, len(lin_tree.words))
    
    
    return encoded_trees, max_len_tree

def gen_dsets():
    encodings = []

    # naive absolute encodings
    a_enc     = C_NaiveAbsoluteEncoding(separator="[_]", unary_joiner="[+]", reverse=False, binary=False, binary_direction=None, binary_marker="[b]")
    encodings.append({"name":"naive_absolute", "encoder":a_enc})
    a_br_enc  = C_NaiveAbsoluteEncoding(separator="[_]", unary_joiner="[+]", reverse=False, binary=True,  binary_direction="R",  binary_marker="[b]")
    encodings.append({"name":"naive_absolute_br", "encoder":a_br_enc})
    a_bl_enc  = C_NaiveAbsoluteEncoding(separator="[_]", unary_joiner="[+]", reverse=False, binary=True,  binary_direction="L",  binary_marker="[b]")
    encodings.append({"name":"naive_absolute_bl", "encoder":a_bl_enc})
    ar_enc    = C_NaiveAbsoluteEncoding(separator="[_]", unary_joiner="[+]", reverse=True,  binary=False, binary_direction=None, binary_marker="[b]")
    encodings.append({"name":"naive_absolute_r", "encoder":ar_enc})
    ar_br_enc = C_NaiveAbsoluteEncoding(separator="[_]", unary_joiner="[+]", reverse=True,  binary=True,  binary_direction="R",  binary_marker="[b]")
    encodings.append({"name":"naive_absolute_r_br", "encoder":ar_br_enc})
    ar_bl_enc = C_NaiveAbsoluteEncoding(separator="[_]", unary_joiner="[+]", reverse=True,  binary=True,  binary_direction="L",  binary_marker="[b]")
    encodings.append({"name":"naive_absolute_r_bl", "encoder":ar_bl_enc})

    # naive relative encodings
    r_enc     = C_NaiveRelativeEncoding(separator="[_]", unary_joiner="[+]", reverse=False, binary=False, binary_direction=None, binary_marker="[b]")
    encodings.append({"name":"naive_relative", "encoder":r_enc})
    r_br_enc  = C_NaiveRelativeEncoding(separator="[_]", unary_joiner="[+]", reverse=False, binary=True,  binary_direction="R",  binary_marker="[b]")
    encodings.append({"name":"naive_relative_br", "encoder":r_br_enc})
    r_bl_enc  = C_NaiveRelativeEncoding(separator="[_]", unary_joiner="[+]", reverse=False, binary=True,  binary_direction="L",  binary_marker="[b]")
    encodings.append({"name":"naive_relative_bl", "encoder":r_bl_enc})
    rr_enc    = C_NaiveRelativeEncoding(separator="[_]", unary_joiner="[+]", reverse=True,  binary=False, binary_direction=None, binary_marker="[b]")
    encodings.append({"name":"naive_relative_r", "encoder":rr_enc})
    rr_br_enc = C_NaiveRelativeEncoding(separator="[_]", unary_joiner="[+]", reverse=True,  binary=True,  binary_direction="R",  binary_marker="[b]")
    encodings.append({"name":"naive_relative_r_br", "encoder":rr_br_enc})
    rr_bl_enc = C_NaiveRelativeEncoding(separator="[_]", unary_joiner="[+]", reverse=True,  binary=True,  binary_direction="L",  binary_marker="[b]")
    encodings.append({"name":"naive_relative_r_bl", "encoder":rr_bl_enc})

    # naive dynamic encodings
    d_enc     = C_NaiveDynamicEncoding(separator="[_]", unary_joiner="[+]", reverse=False, binary=False, binary_direction=None, binary_marker="[b]")
    encodings.append({"name":"naive_dynamic", "encoder":d_enc})
    d_br_enc  = C_NaiveDynamicEncoding(separator="[_]", unary_joiner="[+]", reverse=False, binary=True,  binary_direction="R",  binary_marker="[b]")
    encodings.append({"name":"naive_dynamic_br", "encoder":d_br_enc})
    d_bl_enc  = C_NaiveDynamicEncoding(separator="[_]", unary_joiner="[+]", reverse=False, binary=True,  binary_direction="L",  binary_marker="[b]")
    encodings.append({"name":"naive_dynamic_bl", "encoder":d_bl_enc})
    dr_enc    = C_NaiveDynamicEncoding(separator="[_]", unary_joiner="[+]", reverse=True,  binary=False, binary_direction=None, binary_marker="[b]")
    encodings.append({"name":"naive_dynamic_r", "encoder":dr_enc})
    dr_br_enc = C_NaiveDynamicEncoding(separator="[_]", unary_joiner="[+]", reverse=True,  binary=True,  binary_direction="R",  binary_marker="[b]")
    encodings.append({"name":"naive_dynamic_r_br", "encoder":dr_br_enc})
    dr_bl_enc = C_NaiveDynamicEncoding(separator="[_]", unary_joiner="[+]", reverse=True,  binary=True,  binary_direction="L",  binary_marker="[b]")
    encodings.append({"name":"naive_dynamic_r_bl", "encoder":dr_bl_enc})

    # gaps encodings
    g_r_enc   = C_GapsEncoding(separator="[_]", unary_joiner="[+]", binary_direction="R", binary_marker="[b]")
    encodings.append({"name":"gaps_r", "encoder":g_r_enc})
    g_l_enc   = C_GapsEncoding(separator="[_]", unary_joiner="[+]", binary_direction="L", binary_marker="[b]")
    encodings.append({"name":"gaps_l", "encoder":g_l_enc})

    # tetra encodings
    t_pr_enc  = C_Tetratag(separator="[_]", unary_joiner="[+]", mode='preorder',  binary_marker="[b]")
    encodings.append({"name":"tetratag_preorder", "encoder":t_pr_enc})
    t_in_enc  = C_Tetratag(separator="[_]", unary_joiner="[+]", mode='inorder',   binary_marker="[b]")
    encodings.append({"name":"tetratag_inorder", "encoder":t_in_enc})
    t_po_enc  = C_Tetratag(separator="[_]", unary_joiner="[+]", mode='postorder', binary_marker="[b]")
    encodings.append({"name":"tetratag_postorder", "encoder":t_po_enc})

    # yuxtaposed encodings
    j_enc   = C_JuxtaposedEncoding(separator="[_]", unary_joiner="[+]", binary=False, binary_direction=None, binary_marker="[b]")
    encodings.append({"name":"juxtaposed", "encoder":j_enc})
    j_r_enc = C_JuxtaposedEncoding(separator="[_]", unary_joiner="[+]", binary=True, binary_direction='R',   binary_marker="[b]")
    encodings.append({"name":"juxtaposed_r", "encoder":j_r_enc})
    j_l_enc = C_JuxtaposedEncoding(separator="[_]", unary_joiner="[+]", binary=True, binary_direction='L',   binary_marker="[b]")
    encodings.append({"name":"juxtaposed_l", "encoder":j_l_enc})

    return encodings


In [2]:
from torch import nn

class TokenClassificationHead(nn.Module):
    def __init__(self, hidden_size, num_labels, dropout_p=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout_p)
        self.classifier = nn.Linear(hidden_size, num_labels)
        self.num_labels = num_labels

        self._init_weights()

    def _init_weights(self):
        self.classifier.weight.data.normal_(mean=0.0, std=0.02)
        if self.classifier.bias is not None:
            self.classifier.bias.data.zero_()

    def forward(self, sequence_output, pooled_output, labels = None, attention_mask = None, **kwargs):
        sequence_output_dropout = self.dropout(sequence_output)
        logits = self.classifier(sequence_output_dropout)

        loss = None
        if labels is not None:
            loss_fct = torch.nn.CrossEntropyLoss()
            labels = labels.long()

            # Only keep active parts of the loss
            if attention_mask is not None:
                active_loss = attention_mask.view(-1) == 1
                active_logits = logits.view(-1, self.num_labels)
                active_labels = torch.where(
                    active_loss,
                    labels.view(-1),
                    torch.tensor(loss_fct.ignore_index).type_as(labels),
                )
                loss = loss_fct(active_logits, active_labels)
            else:
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return logits, loss

In [3]:
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import numpy as np
import evaluate
from transformers import EvalPrediction
from torch import nn
from torch.utils.data.sampler import RandomSampler, WeightedRandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from transformers.data.data_collator import InputDataClass
from types import MappingProxyType
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModel, AutoTokenizer
import transformers
from transformers.trainer_utils import EvalLoopOutput

class DataLoaderWithTaskname:
    def __init__(self, task_name, data_loader):
        self.task = task_name
        self.data_loader = data_loader
        self.batch_size = data_loader.batch_size
        self.dataset = data_loader.dataset

    def __len__(self):
        return len(self.data_loader)

    def __iter__(self):
        for batch in self.data_loader:
            yield batch

class NLPDataCollator:
    def __init__(self, tasks):
        self.tasks = tasks

    def __call__(self, features: List[Union[InputDataClass, Dict]]) -> Dict[str, torch.Tensor]:
        features = [{k:v for k,v in x.items() if k!='task_ids'} for x in features]
        return features
class MultitaskDataloader:
    """
    Data loader that combines and samples from multiple single-task
    data loaders.
    """

    def __init__(self, dataloader_dict, p=1):
        self.dataloader_dict = dataloader_dict
        N = max([len(x)**(1-p) for x in dataloader_dict.values()])
        
        f_p = lambda x: int(N*x**p)

        self.num_batches_dict = {
            task_name: f_p(len(dataloader))
            for task_name, dataloader in self.dataloader_dict.items()
        }
        self.task_name_list = list(self.dataloader_dict)
        self.dataset = [None] * sum(
            f_p(len(dataloader.dataset)) for dataloader in self.dataloader_dict.values()
        )

    def __len__(self):
        return sum(self.num_batches_dict.values())

    def __iter__(self):
        """
        For each batch, sample a task, and yield a batch from the respective
        task Dataloader.
        """
        task_choice_list = []
        for i, task_name in enumerate(self.task_name_list):
            task_choice_list += [i] * self.num_batches_dict[task_name]
        task_choice_list = np.array(task_choice_list)
        np.random.shuffle(task_choice_list)
        
        dataloader_iter_dict = {
            task_name: iter(dataloader)
            for task_name, dataloader in self.dataloader_dict.items()
        }

        for task_choice in task_choice_list:
            task_name = self.task_name_list[task_choice]
            yield next(dataloader_iter_dict[task_name])



class MultiTaskModel(nn.Module):
    def __init__(self, encoder_name_or_path, tasks):
        super().__init__()
        
        self.encoder = AutoModel.from_pretrained(encoder_name_or_path)
        tokenizer_kwargs = frozendict(padding="max_length", max_length=128, truncation=True, return_tensors="pt")
        self.tokenizer = AutoTokenizer.from_pretrained(encoder_name_or_path, **tokenizer_kwargs)
        self.output_heads = nn.ModuleDict()
        
        for task in tasks:
            ###############################
            print("[TRN] Creating output head for task", task.name)
            print("      - Task type:", task.task_type)
            print("      - Number of labels:", task.num_labels)
            print("      - Label names:")
            for k, v in task.label_names.items():
                print("            -", k, ":", v)            
            print("[TRN] Example input:")
            sample = task.dataset['train'][0]
            for k, v in sample.items():
                print("      -", k, ":", v)
            print("[TRN] Example input (real labels):")
            sample = task.dataset['train'][0]
            for k, v in sample.items():
                if k in task.label_names.keys():
                    print("      -", [task.label_names[k][vi] for vi in v])
            ###############################
            
            task.set_tokenizer(self.tokenizer)
            for subtask in task.y:
                decoder = self._create_output_head(
                    self.encoder.config.hidden_size, 
                    task.task_type, 
                    task.num_labels[subtask]
                )
                
                self.output_heads[subtask] = decoder

        self.processed_tasks = self.preprocess_tasks(tasks, self.tokenizer)
        self.label_names = {task.name: task.label_names for task in tasks}
        self.train_dataset = {self.processed_tasks[task.name]['train'] for task in tasks}
        self.eval_dataset = {self.processed_tasks[task.name]['validation'] for task in tasks}
        
        print("[TRN] Model has", len(self.output_heads), "output heads")
        print("[TRN] Model has", len(self.train_dataset), "training datasets")
        print("[TRN] Model has", len(self.eval_dataset), "evaluation datasets")
    
    def preprocess_tasks(self, tasks, tokenizer):      
        features_dict = {}
        for i, task in enumerate(tasks):
            print("Model is preprocessing task", task.name)
            
            if hasattr(task, 'processed_features') and tokenizer == task.tokenizer:
                print("==> Task features are already processed, skipping...")
                features_dict[task.name] = task.processed_features
                continue
            
            for split in task.dataset:
                task.index = task.dataset[split].index = i
            
            features_dict[task.name] = {}
            for phase, phase_dataset in task.dataset.items():
                phase_dataset.index = i

                features_dict[task.name][phase] = phase_dataset.map(
                    task.preprocess_function, 
                    batched = True,
                    batch_size = 8,
                    load_from_cache_file = True
                )
            print("Model finished preprocessing task", task.name)
        return features_dict
    
    @staticmethod
    def _create_output_head(encoder_hidden_size: int, task_type, n_labels):
        if task_type == "TokenClassification":
            print("Creating TokenClassification head w/", n_labels, "labels")
            return TokenClassificationHead(encoder_hidden_size, n_labels)
        else:
            raise NotImplementedError()
    
    def forward(self, input_ids = None, attention_mask = None, token_type_ids = None, position_ids = None,
            head_mask = None, inputs_embeds = None, labels = None, task_ids = None, **kwargs):
            
            # compute the transformer output
            outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                position_ids=position_ids,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
            )
            sequence_output, pooled_output = outputs[:2]

            print("3) Transformer has been forwarded")
            unique_task_ids_list = torch.unique(task_ids).tolist()

            loss_list = []
            logits = None
            # print("Computing loss...")
            # print("task_ids", task_ids)
            print("==> I have to compute loss for the following tasks:")
            print("==>", unique_task_ids_list)
            for unique_task_id in unique_task_ids_list:
                print("Task_id =",unique_task_id)
                ptc_train = self.processed_tasks['train']
                target_cols = [col for col in ptc_train.features if col.startswith("target_")]
                print("target_cols =", target_cols)

                for tc in target_cols:
                    print("Target Column =",tc)
                    print("Labels =",labels)
                    logits, task_loss = self.output_heads[str(unique_task_id)].forward(
                        sequence_output[task_id_filter],
                        pooled_output[task_id_filter],
                        labels = None if labels is None else labels[task_id_filter],
                        attention_mask=attention_mask[task_id_filter],
                    )

                    if labels is not None:
                        loss_list.append(task_loss)

            # Loss averaged over all tasks
            outputs = (logits, outputs[2:])
            if loss_list:
                loss = torch.stack(loss_list)
                outputs = (loss.mean(),) + outputs

            return outputs

class MultiTaskTrainer(transformers.Trainer):
    def __init__(self, tasks, **kwargs):
        super().__init__(**kwargs)
        self.p = 1
        self.processed_tasks = self.model.processed_tasks
        self.label_names = self.model.label_names
        self.train_dataset = {
            task: dataset["train"]
            for task, dataset in self.processed_tasks.items()
        }
        self.eval_dataset = {
            task: dataset["validation"]
            for task, dataset in self.processed_tasks.items()
        }
        self.eval_dataset = MappingProxyType(self.eval_dataset)
        self.tokenizer = self.model.tokenizer
        self.pretrained_transformer = self.model.encoder
        self.device = self.pretrained_transformer.device
        self.data_collator = NLPDataCollator(tasks)
        
        print("[*] Init multitask trainer with tasks:", self.processed_tasks)
        print("[*] Label names are:", self.label_names)
        
    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=[]):
        if ignore_keys is None:
            ignore_keys = []

        # loss function returns the loss and the prediction logits
        loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
        loss = loss.mean().detach()        
        
        print("[*] Prediction Step...")
        print("    inputs", inputs[0].keys())
        print("    outputs", outputs.keys())
        print("    loss", loss)
        
        
        logits_dict = {}
        labels_dict = {}
        print("[*] Extracting logits and labels...")
        for task_name, label_names in self.label_names.items():
            logits_dict[task_name] = {}
            labels_dict[task_name] = {}
            for label_name in label_names:
                logits_dict[task_name][label_name] = outputs[label_name]
                logits_dict[task_name][label_name] = np.argmax(outputs[label_name].detach().cpu().numpy(), axis=2)
                target_labels = []
                for i in inputs:
                    target_labels.append(i[label_name])
                labels_dict[task_name][label_name] = torch.tensor(target_labels)
        
        print("[*] Prediction step ended:")
        return (loss, logits_dict, labels_dict)
    
    def get_single_train_dataloader(self, task_name, train_dataset):
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")
        
        train_sampler = (SequentialSampler(train_dataset) if self.args.local_rank == -1 else DistributedSampler(train_dataset))

        data_loader = DataLoaderWithTaskname(
            task_name = task_name,
            data_loader = DataLoader(
                train_dataset,
                batch_size = self.args.train_batch_size,
                shuffle = False,
                sampler = train_sampler,
                collate_fn = self.data_collator.__call__,
            ),
        )

        return data_loader

    def get_train_dataloader(self):
        return MultitaskDataloader(
            {
                task_name: self.get_single_train_dataloader(task_name, task_dataset)
                for task_name, task_dataset in self.train_dataset.items()
            }, p = self.p,
        )
    
    def get_eval_dataloader(self, eval_dataset=None):
        return MultitaskDataloader(
            {
                task_name: self.get_single_train_dataloader(task_name, task_dataset)
                for task_name, task_dataset in (
                    eval_dataset if eval_dataset else self.eval_dataset
                ).items()
            }
        )

    def evaluation_loop(self, dataloader: DataLoader, description: str, 
                        prediction_loss_only: bool | None = None, ignore_keys: List[str] | None = None, 
                        metric_key_prefix: str = "eval_") -> EvalLoopOutput:
        
        print("[*] Evaluation_loop")        
        model = self._wrap_model(self.model, training=False, dataloader=dataloader)
        eval_results = {}

        for step, inputs in enumerate(dataloader):
            print("[*] Step", step, "of", len(dataloader), "...")
            print("    inputs", inputs)
            
            loss, preds, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
            
            for task, label_names in self.label_names.items():
                print("[*] Evaluating task:",task)
                preds_task = preds[task]
                labels_task = labels[task]
                
                for label_name, labels_values in label_names.items():
                    print("[*] Computing metrics for subtask", label_name, "...")
                    preds_tl  = preds_task[label_name]
                    labels_tl = labels_task[label_name]
                    
                    eval_pred = EvalPrediction(
                                predictions = preds_tl, 
                                label_ids   = labels_tl, 
                                inputs      = inputs)

                    # compute metrics foreach head using the corresponding task eval_function
                    # i copied the function from the task-specific class to this one
                    metrics = self.compute_metrics_token_classification(eval_pred, label_name)
                    metrics_eval = {}
                    for metric in metrics.items():
                        metrics_eval[metric_key_prefix + "_" + metric[0]] = metric[1]

        return EvalLoopOutput(predictions=preds_tl, label_ids=labels_tl, metrics=metrics_eval, num_samples=len(self.eval_dataset))
    
    def compute_metrics_token_classification(self, eval_pred, label_names):
        predictions, labels = eval_pred.predictions, eval_pred.label_ids
        
        print("Labels shape =>", labels.shape)
        print("Labels sample =>", labels[0])
        
        print("Predictions shape =>", predictions.shape)
        print("Predictions sample =>", predictions[0])
        
        print("Label names =>", self.label_names)
        
        # for task in tasks
        label_names_task = self.label_names['naive_absolute_n_commons']
        true_labels = [
            [label_names_task[label_names][int(l)] for l in label if l != -100] for label in labels
        ]
        
        true_predictions = [
            [label_names_task[label_names][p] for (p, l) in zip(prediction, label) if l != -100]
            for prediction, label in zip(predictions, labels)
        ]
        metric = evaluate.load("seqeval")
        all_metrics = metric.compute(
            predictions = true_predictions, 
            references = true_labels
        )
        print("==> all_metrics", all_metrics)
        
        meta = {"name": 'naive_absolute_n_commons', "size": len(predictions), "index": 0}
        metrics = {k.replace("overall_",""):v for k,v in all_metrics.items() if "overall" in k}
        
        return {**metrics, **meta}      

    def compute_loss(self, model, inputs, return_outputs=False):
        print("[*] Computing Loss...")
        keys = inputs[0].keys()

        input_ids = torch.tensor([i['input_ids'] for i in inputs], device=self.args.device) if 'input_ids' in keys else None
        attention_mask = torch.tensor([i['attention_mask'] for i in inputs], device=self.args.device) if 'attention_mask' in keys else None        
        token_type_ids = torch.tensor([i['token_type_ids'] for i in inputs], device=self.args.device) if 'token_type_ids' in keys else None        
        position_ids = torch.tensor([i['position_ids'] for i in inputs], device=self.args.device) if 'position_ids' in keys else None        
        head_mask = torch.tensor([i['head_mask'] for i in inputs], device=self.args.device) if 'head_mask' in keys else None        
        inputs_embeds = torch.tensor([i['inputs_embeds'] for i in inputs], device=self.args.device) if 'inputs_embeds' in keys else None
        
        outputs = self.pretrained_transformer(
            input_ids = input_ids,
            attention_mask = attention_mask,
            token_type_ids = token_type_ids,
            position_ids = position_ids,
            head_mask = head_mask,
            inputs_embeds = inputs_embeds,
        )

        sequence_output, pooled_output = outputs[:2]
        loss_list = []
        logits_list = {}
        
        for i, head in enumerate(self.model.output_heads.values()):
            labels_name = f"target_{i+1}"
            labels_i = torch.tensor([i[labels_name] for i in inputs], device=self.args.device)
            logits, loss = head(sequence_output, pooled_output, labels=labels_i, attention_mask=attention_mask)
            loss_list.append(loss)
            logits_list[labels_name] = logits
        
        loss = torch.stack(loss_list)
        print("    loss", loss)
        return (loss, logits_list) if return_outputs else loss


In [4]:
import numpy as np
from datasets import Dataset
from transformers import DataCollatorForTokenClassification
import evaluate
import funcy as fc
import warnings
from frozendict import frozendict as fdict
from dataclasses import dataclass

@dataclass
class TokenClassification:
    task_type = "TokenClassification"
    name: str = "TokenClassificationTask"
    dataset: Dataset = None
    metric:... = evaluate.load("seqeval")
    main_split: str = "train"
    tokens: str = 'tokens'
    y: str|list = 'target'
    num_labels: int = None
    label_names: dict = None
    tokenizer_kwargs: fdict = fdict(padding="max_length", max_length=128, truncation=True)

    @staticmethod
    def _align_labels_with_tokens(labels, word_ids):
        new_labels = []
        current_word = None
        for word_id in word_ids:
            if word_id is None:
                new_labels.append(-100)

            elif word_id != current_word:
                current_word = word_id
                label = -100 if word_id is None else labels[word_id]
                new_labels.append(label)
            
            else:
                label = labels[word_id]
                new_labels.append(label)
        
        return new_labels

    def __post_init__(self):
        self.label_names = {}
        self.num_labels  = {}

        for y in self.y:
            target = self.dataset[self.main_split].features[y]
            self.num_labels[y] = target.feature.num_classes
            self.label_names[y] = target.feature.names if target.feature.names else [None]
        
        print(f"Task loaded {self.task_type} task with {self.num_labels} labels")
        for k,v in self.label_names.items():
            print(f"      {k} labels: {v}")

    def get_labels(self):
        return super().get_labels() or self.label_names

    def set_tokenizer(self, tokenizer):
        self.tokenizer = tokenizer
        self.tokenizer.add_prefix_space = True
        self.data_collator = DataCollatorForTokenClassification(
            tokenizer = self.tokenizer
        )

    def preprocess_function(self, examples):
        if examples[self.tokens] and type(examples[self.tokens][0]) == str:
            unsqueeze, examples = True, {k:[v] for k,v in examples.items()}
        
        def get_len(outputs):
            try:
                return len(outputs[fc.first(outputs)])
            except:
                return 1
        
        tokenized_inputs = self.tokenizer(
            examples[self.tokens],
            is_split_into_words=True,
            **self.tokenizer_kwargs
        )

        for target_column in self.y:
            all_labels = examples[target_column]
            new_labels = []
            
            for i, labels in enumerate(all_labels):
                word_ids = tokenized_inputs.word_ids(i)
                new_labels.append(self._align_labels_with_tokens(labels, word_ids))
            
            tokenized_inputs[target_column] = new_labels        
            tokenized_inputs['task_ids'] = [self.index]*get_len(tokenized_inputs)

        return tokenized_inputs       

    def compute_metrics(self, eval_pred):
        logits, labels = eval_pred.predictions, eval_pred.label_ids
        predictions = np.argmax(logits, axis=-1)
        
        true_labels = [
            [self.label_names[l] for l in label if l != -100] for label in labels
        ]
        true_predictions = [
            [self.label_names[p] for (p, l) in zip(prediction, label) if l != -100]
            for prediction, label in zip(predictions, labels)
        ]
        all_metrics = self.metric.compute(
            predictions = true_predictions, 
            references = true_labels
        )
        meta = {"name": self.name, "size": len(predictions), "index": self.index}
        metrics = {k.replace("overall_",""):v for k,v in all_metrics.items() if "overall" in k}
        self.results+=[metrics]
        return {**metrics, **meta}

    def check(self):
        features = self.dataset['train'].features
        return self.tokens in features and self.y in features

In [5]:

from transformers import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy
from typing import Optional, Union
@dataclass
class DataCollatorForTokenClassificationCustom:
    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    label_pad_token_id: int = -100
    return_tensors: str = "pt"
    def __call__(self, features, return_tensors=None):
        if return_tensors is None:
            return_tensors = self.return_tensors
        if return_tensors == "tf":
            return self.tf_call(features)
        elif return_tensors == "pt":
            return self.torch_call(features)
        elif return_tensors == "np":
            return self.numpy_call(features)
        else:
            raise ValueError(f"Framework '{return_tensors}' not recognized!")
    
    def torch_call(self, features):
        import torch
        targets = [{k: v for k, v in feature.items() if "target" in k} for feature in features]
        label_name = "label" if "label" in features[0].keys() else "labels"
        labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
        print("==>",targets)

        no_labels_features = [{k: v for k, v in feature.items() if "target" not in k} for feature in features]
        no_labels_features = [{k: v for k, v in feature.items() if "tokens" not in k} for feature in features]
        print("==>",no_labels_features)

        # error here
        # input_id
        batch = self.tokenizer.pad(
            no_labels_features,
            padding = self.padding,
            max_length = 8,
            pad_to_multiple_of = self.pad_to_multiple_of,
            return_tensors = "pt",
        )
        
        if labels is None:
            return batch

        sequence_length = batch["input_ids"].shape[1]
        padding_side = self.tokenizer.padding_side

        def to_list(tensor_or_iterable):
            if isinstance(tensor_or_iterable, torch.Tensor):
                return tensor_or_iterable.tolist()
            return list(tensor_or_iterable)

        if padding_side == "right":
            batch[label_name] = [
                to_list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels
            ]
        else:
            batch[label_name] = [
                [self.label_pad_token_id] * (sequence_length - len(label)) + to_list(label) for label in labels
            ]

        batch[label_name] = torch.tensor(batch[label_name], dtype=torch.int64)
        
        return batch

In [6]:
# import trainer
from transformers import Trainer, TrainingArguments

# train and evaluate using Evalb
encodings = gen_dsets()
results = {}
train_limit = 256
max_seq_len = 128
model_name = "bert-base-cased"

# probably this could be done in parallel
for enc in encodings[:1]:
    results_df = pd.DataFrame(columns=["encoding", "recall", "precision", "f1", "n_labels"])
    encoder = enc["encoder"]
    
    train_enc, mlt1 = encode_dset(encoder, ptb_train[:train_limit] if train_limit else ptb_train)
    dev_enc,   mlt2   = encode_dset(encoder, ptb_dev[:train_limit]   if train_limit else ptb_dev)
    dataset  = generate_dataset_from_codelin(train_enc, dev_enc)
    
    print("[*] Sample encoded sentence")
    print("   ",train_enc['tokens'][0])
    print("   ",train_enc['target_1'][0])
    print("   ",train_enc['target_2'][0])
    print("   ",train_enc['target_3'][0])

    tasks = [TokenClassification(
                dataset = dataset,
                y = ["target_1", "target_2", "target_3"],
                name = enc["name"]+"_n_commons",
                tokenizer_kwargs = frozendict(padding="max_length", max_length = max_seq_len, truncation=True)
            )]
    
    model = MultiTaskModel(model_name, tasks)
    
    training_args = TrainingArguments(
        output_dir = f"results/{enc['name']}",
        num_train_epochs = 2,
        per_device_train_batch_size = 8,
        per_device_eval_batch_size = 8,
        warmup_steps = 500,
        weight_decay = 0.01,
        logging_dir = f"results/{enc['name']}/logs",
        logging_steps = 10,
        evaluation_strategy = "epoch",
        save_strategy = "epoch",
        load_best_model_at_end = True,
        metric_for_best_model = "f1",
        greater_is_better = True,
        save_total_limit = 1,
        remove_unused_columns=False
    )

    train_dataset = model.train_dataset
    trainer = MultiTaskTrainer(
        model = model,
        tasks = tasks,
        args = training_args,
        train_dataset = model.train_dataset,
        eval_dataset = model.eval_dataset,
        compute_metrics = None,
        tokenizer = model.tokenizer
    )
    
    trainer.train()

Sample of labels target_1: n_commons: ['14', '15', '16', '17', '18']
Sample of labels target_2: last_common: ['FRAG[+]ADJP', 'FRAG[+]NP', 'LST', 'NAC', 'NP']
Sample of labels target_3: unary_chain: ['INTJ', 'NP', 'NP[+]NP', 'NX', 'PP']


Casting the dataset:   0%|          | 0/256 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/256 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/256 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/256 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/256 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/256 [00:00<?, ? examples/s]

[*] Sample encoded sentence
    ['In', 'an', 'Oct.', '19', 'review', 'of', '``', 'The', 'Misanthrope', "''", 'at', 'Chicago', "'s", 'Goodman', 'Theatre', '-LRB-', '``', 'Revitalized', 'Classics', 'Take', 'the', 'Stage', 'in', 'Windy', 'City', ',', "''", 'Leisure', '&', 'Arts', '-RRB-', ',', 'the', 'role', 'of', 'Celimene', ',', 'played', 'by', 'Kim', 'Cattrall', ',', 'was', 'mistakenly', 'attributed', 'to', 'Christina', 'Haag', '.']
    ['2', '4', '4', '4', '3', '4', '5', '6', '5', '5', '6', '8', '7', '7', '3', '4', '4', '6', '5', '6', '7', '6', '7', '8', '4', '4', '4', '5', '5', '4', '1', '1', '4', '3', '4', '2', '2', '3', '4', '5', '2', '1', '2', '3', '3', '4', '5', '1', '1']
    ['PP', 'NP', 'NP', 'NP', 'NP', 'PP', 'NP', 'NP', 'NP', 'NP', 'PP', 'NP', 'NP', 'NP', 'NP', 'PRN', 'PRN', 'NP', 'S', 'VP', 'NP', 'VP', 'PP', 'NP', 'PRN', 'PRN', 'PRN', 'NP', 'NP', 'PRN', 'S', 'S', 'NP', 'NP', 'PP', 'NP', 'NP', 'VP', 'PP', 'NP', 'NP', 'S', 'VP', 'VP', 'VP', 'PP', 'NP', 'S', 'S']
    ['-NONE-',

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


[TRN] Creating output head for task naive_absolute_n_commons
      - Task type: TokenClassification
      - Number of labels: {'target_1': 28, 'target_2': 37, 'target_3': 16}
      - Label names:
            - target_1 : ['1', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '2', '20', '21', '22', '24', '25', '26', '27', '28', '29', '3', '4', '5', '6', '7', '8', '9']
            - target_2 : ['ADJP', 'ADJP[+]QP', 'ADVP', 'CONJP', 'FRAG', 'FRAG[+]ADJP', 'FRAG[+]NP', 'LST', 'NAC', 'NP', 'NP[+]NP', 'NP[+]QP', 'NX', 'PP', 'PRN', 'PRN[+]S', 'QP', 'S', 'SBAR', 'SBARQ', 'SBAR[+]S', 'SBAR[+]SINV', 'SBAR[+]S[+]VP', 'SINV', 'SQ', 'SQ[+]VP', 'S[+]ADJP', 'S[+]NP', 'S[+]PP', 'S[+]VP', 'UCP', 'VP', 'VP[+]VP', 'WHADJP', 'WHADVP', 'WHNP', 'WHPP']
            - target_3 : ['-NONE-', 'ADJP', 'ADJP[+]ADJP', 'ADVP', 'ADVP[+]ADVP', 'INTJ', 'NP', 'NP[+]NP', 'NX', 'PP', 'PRT', 'S[+]ADJP', 'S[+]VP', 'VP', 'WHADVP', 'WHNP']
[TRN] Example input:
      - tokens : ['In', 'an', 'Oct.', '19', 'review', '

Map:   0%|          | 0/256 [00:00<?, ? examples/s]

Map:   0%|          | 0/256 [00:00<?, ? examples/s]

Model finished preprocessing task naive_absolute_n_commons
[TRN] Model has 3 output heads
[TRN] Model has 1 training datasets
[TRN] Model has 1 evaluation datasets
[*] Init multitask trainer with tasks: {'naive_absolute_n_commons': {'train': Dataset({
    features: ['tokens', 'target_1', 'target_2', 'target_3', 'input_ids', 'token_type_ids', 'attention_mask', 'task_ids'],
    num_rows: 256
}), 'validation': Dataset({
    features: ['tokens', 'target_1', 'target_2', 'target_3', 'input_ids', 'token_type_ids', 'attention_mask', 'task_ids'],
    num_rows: 256
})}}
[*] Label names are: {'naive_absolute_n_commons': {'target_1': ['1', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '2', '20', '21', '22', '24', '25', '26', '27', '28', '29', '3', '4', '5', '6', '7', '8', '9'], 'target_2': ['ADJP', 'ADJP[+]QP', 'ADVP', 'CONJP', 'FRAG', 'FRAG[+]ADJP', 'FRAG[+]NP', 'LST', 'NAC', 'NP', 'NP[+]NP', 'NP[+]QP', 'NX', 'PP', 'PRN', 'PRN[+]S', 'QP', 'S', 'SBAR', 'SBARQ', 'SBAR[+]S', 'SBAR[+]SI



  0%|          | 0/32 [00:00<?, ?it/s]

[*] Computing Loss...
    loss tensor([3.4347, 3.6126, 2.7503], device='cuda:0', grad_fn=<StackBackward0>)
[*] Computing Loss...
    loss tensor([3.4038, 3.6206, 2.7273], device='cuda:0', grad_fn=<StackBackward0>)
[*] Computing Loss...
    loss tensor([3.4008, 3.5965, 2.7073], device='cuda:0', grad_fn=<StackBackward0>)
[*] Computing Loss...
    loss tensor([3.3610, 3.5981, 2.7351], device='cuda:0', grad_fn=<StackBackward0>)
[*] Computing Loss...
    loss tensor([3.3841, 3.5750, 2.7291], device='cuda:0', grad_fn=<StackBackward0>)
[*] Computing Loss...
    loss tensor([3.3848, 3.5966, 2.7288], device='cuda:0', grad_fn=<StackBackward0>)
[*] Computing Loss...
    loss tensor([3.3932, 3.6320, 2.7317], device='cuda:0', grad_fn=<StackBackward0>)
[*] Computing Loss...
    loss tensor([3.4034, 3.6244, 2.6706], device='cuda:0', grad_fn=<StackBackward0>)
[*] Computing Loss...
    loss tensor([3.3876, 3.6036, 2.5906], device='cuda:0', grad_fn=<StackBackward0>)
[*] Computing Loss...
    loss tensor

  _warn_prf(average, modifier, msg_start, len(result))


==> all_metrics {'0': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 20}, '1': {'precision': 0.022222222222222223, 'recall': 0.14285714285714285, 'f1': 0.038461538461538464, 'number': 14}, '2': {'precision': 0.058823529411764705, 'recall': 0.125, 'f1': 0.07999999999999999, 'number': 8}, '3': {'precision': 0.027777777777777776, 'recall': 0.125, 'f1': 0.04545454545454545, 'number': 8}, '4': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 4}, '5': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 1}, '6': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 1}, '7': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, '8': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, '9': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, '_': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 22}, 'overall_precision': 0.010582010582010581, 'overall_recall': 0.05128205128205128, 'overall_f1': 0.017543859649122806, 'overall_accuracy': 0.0356472795497185

  _warn_prf(average, modifier, msg_start, len(result))


==> all_metrics {'BAR': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 16}, 'BAR[+]S': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 5}, 'BAR[+]SINV': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'BAR[+]S[+]VP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 1}, 'CP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'DJP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 3}, 'DVP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'HADVP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'HNP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 1}, 'HPP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'INV': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 20}, 'ONJP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'P': {'precision': 0.08, 'recall': 0.05128205128205128, 'f1': 0.0625, 'number': 78}, 'P[+]NP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'P[+]QP': {'p



==> all_metrics {'DJP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 1}, 'DJP[+]ADJP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'DVP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 6}, 'DVP[+]ADVP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'HADVP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 1}, 'HNP': {'precision': 0.125, 'recall': 0.2, 'f1': 0.15384615384615385, 'number': 5}, 'NTJ': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'P': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 28}, 'P[+]NP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'RT': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 1}, 'X': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, '[+]ADJP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, '[+]VP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, '_': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 37}, 'overall_precision': 



==> all_metrics {'AC': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'BAR': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 10}, 'BAR[+]S': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 3}, 'BAR[+]SINV': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'BAR[+]S[+]VP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 1}, 'CP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'DJP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 7}, 'DVP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 3}, 'HADJP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'HADVP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'HNP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'HPP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 1}, 'ONJP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'P': {'precision': 0.09375, 'recall': 0.0967741935483871, 'f1': 0.09523809523809523, 'number': 62}, '



==> all_metrics {'BAR': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 10}, 'BARQ': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'BAR[+]S': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 2}, 'BAR[+]SINV': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'BAR[+]S[+]VP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 1}, 'CP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 1}, 'DJP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 2}, 'DJP[+]QP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'DVP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 2}, 'HADVP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'HNP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'HPP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'INV': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 16}, 'ONJP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'P': {'precision': 0.05555555



==> all_metrics {'AC': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 1}, 'BAR': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 8}, 'BAR[+]SINV': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'BAR[+]S[+]VP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'CP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'DJP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'DJP[+]QP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 1}, 'DVP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 1}, 'HADJP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'HADVP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'HNP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'HPP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'ONJP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 1}, 'P': {'precision': 0.034482758620689655, 'recall': 0.04081632653061224, 'f1': 0.03738317757009345, 'n

  _warn_prf(average, modifier, msg_start, len(result))


==> all_metrics {'0': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 20}, '1': {'precision': 0.011627906976744186, 'recall': 0.07142857142857142, 'f1': 0.02, 'number': 14}, '2': {'precision': 0.14285714285714285, 'recall': 0.125, 'f1': 0.13333333333333333, 'number': 8}, '3': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 8}, '4': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 4}, '5': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 1}, '6': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 1}, '7': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, '8': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, '9': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, '_': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 22}, 'overall_precision': 0.005763688760806916, 'overall_recall': 0.02564102564102564, 'overall_f1': 0.009411764705882352, 'overall_accuracy': 0.054409005628517824}
[*] Computing metrics for subtask target_2 ...
L

  _warn_prf(average, modifier, msg_start, len(result))


==> all_metrics {'AC': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'BAR': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 16}, 'BARQ': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'BAR[+]S': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 5}, 'BAR[+]SINV': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'BAR[+]S[+]VP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 1}, 'DJP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 3}, 'DVP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'HADJP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'HADVP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'HNP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 1}, 'HPP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'INV': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 20}, 'ONJP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'P': {'precision': 0.05479452054



==> all_metrics {'DJP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 1}, 'DJP[+]ADJP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'DVP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 6}, 'DVP[+]ADVP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'HADVP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 1}, 'HNP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 5}, 'NTJ': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'P': {'precision': 0.16666666666666666, 'recall': 0.03571428571428571, 'f1': 0.058823529411764705, 'number': 28}, 'P[+]NP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'RT': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 1}, 'X': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, '[+]ADJP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, '[+]VP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, '_': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'numb



==> all_metrics {'AC': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'BAR': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 7}, 'BARQ': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'BAR[+]S': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 1}, 'BAR[+]SINV': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'BAR[+]S[+]VP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'CP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'DJP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 3}, 'DVP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 3}, 'HADVP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'HNP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'HPP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'INV': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 3}, 'ONJP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'P': {'precision': 0.0821917808219178



==> all_metrics {'BAR': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 10}, 'BAR[+]S': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 3}, 'BAR[+]SINV': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'BAR[+]S[+]VP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 1}, 'CP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'DJP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 7}, 'DJP[+]QP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'DVP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 3}, 'HADJP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'HADVP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'HNP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'HPP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 1}, 'ONJP': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, 'P': {'precision': 0.07246376811594203, 'recall': 0.08064516129032258, 'f1': 0.0763358778625954



==> all_metrics {'0': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 10}, '1': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 5}, '2': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 6}, '3': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 5}, '4': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 3}, '5': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 5}, '6': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 2}, '7': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 1}, '8': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 1}, '9': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0}, '_': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 10}, 'overall_precision': 0.0, 'overall_recall': 0.0, 'overall_f1': 0.0, 'overall_accuracy': 0.0413625304136253}
[*] Computing metrics for subtask target_2 ...
Labels shape => torch.Size([16, 128])
Labels sample => tensor([-100,   17,   31,   29,   31,   18,    9,    9,   17,   31,