In [1]:
import datasets
from datasets import Sequence
from datasets import ClassLabel
# from hfmtl.tasks.sequence_classification import SequenceClassification
# from hfmtl.tasks.token_classification import TokenClassification
# from hfmtl.utils import *
# from hfmtl.models import *

from PYEVALB.scorer import Scorer
from PYEVALB.summary import summary

from codelin.models.const_tree import C_Tree
from codelin.models.const_label import C_Label
from codelin.models.linearized_tree import LinearizedTree
from codelin.encs.constituent import *
from codelin.utils.constants import *

import easydict
from chrono import Timer
from frozendict import frozendict
import os
import torch
import pandas as pd

import logging


# Set logging level
'''
Train the models in multi-task learning fashion. To do this
we will split the fields of the label and train different
tasks according to this. After training, we will evaluate
the decoded trees by re-joining the labels.
'''

ptb_path = "~/Treebanks/const/PENN_TREEBANK/"
ptb_path = os.path.expanduser(ptb_path)

with open(os.path.join(ptb_path,"test.trees")) as f:
    ptb_test = [l.rstrip() for l in f.read().splitlines()]
with open(os.path.join(ptb_path,"dev.trees")) as f:
    ptb_dev = [l.rstrip() for l in f.read().splitlines()]
with open(os.path.join(ptb_path,"train.trees")) as f:
    ptb_train = [l.rstrip() for l in f.read().splitlines()]

def get_n_labels(dsets, tar_field):
    label_set = set()
    for dset in dsets:
        for labels in dset[tar_field]:
            label_set.update(labels)
    label_names = sorted(list(label_set))
    return label_names, len(label_names)

def generate_dataset_from_codelin(train_dset, dev_dset, test_dset=None):
    dsets = [train_dset, dev_dset, test_dset] if test_dset else [train_dset, dev_dset]
    
    l1, nl1 = get_n_labels(dsets, "target_1")
    print("Sample of labels n_commons:", l1[5:10])
    l2, nl2 = get_n_labels(dsets, "target_2")
    print("Sample of labels last_common:", l2[5:10])
    l3, nl3 = get_n_labels(dsets, "target_3")
    print("Sample of labels unary_chain:", l3[5:10])

    train_dset = datasets.Dataset.from_dict(train_dset)
    train_dset = train_dset.cast_column("target_1", Sequence(ClassLabel(num_classes=nl1, names=l1)))
    train_dset = train_dset.cast_column("target_2", Sequence(ClassLabel(num_classes=nl2, names=l2)))
    train_dset = train_dset.cast_column("target_3", Sequence(ClassLabel(num_classes=nl3, names=l3)))

    dev_dset = datasets.Dataset.from_dict(dev_dset)
    dev_dset = dev_dset.cast_column("target_1", Sequence(ClassLabel(num_classes=nl1, names=l1)))
    dev_dset = dev_dset.cast_column("target_2", Sequence(ClassLabel(num_classes=nl2, names=l2)))
    dev_dset = dev_dset.cast_column("target_3", Sequence(ClassLabel(num_classes=nl3, names=l3)))

    if test_dset:
        test_dset = datasets.Dataset.from_dict(test_dset)
        test_dset = test_dset.cast_column("target_1", Sequence(ClassLabel(num_classes=nl1, names=l1)))
        test_dset = test_dset.cast_column("target_2", Sequence(ClassLabel(num_classes=nl2, names=l2)))
        test_dset = test_dset.cast_column("target_3", Sequence(ClassLabel(num_classes=nl3, names=l3)))
    
        # Convert to Hugging Face DatasetDict format
        dataset = datasets.DatasetDict({
                "train": train_dset,
                "validation": dev_dset,
                "test": test_dset
            })
    else:
        # Convert to Hugging Face DatasetDict format
        dataset = datasets.DatasetDict({
                "train": train_dset,
                "validation": dev_dset
            })

    return dataset

def encode_dset(encoder, dset):
    encoded_trees = {"tokens":[], "target_1":[], "target_2":[], "target_3":[]}
    max_len_tree = 0
    for line in dset:
        tree = C_Tree.from_string(line)
        lin_tree = encoder.encode(tree)
        encoded_trees["tokens"].append([w for w in lin_tree.words])
        
        t1,t2,t3 = [],[],[]
        for s1,s2,s3 in lin_tree.get_labels_splitted():
            t1.append(s1)    
            t2.append(s2)
            t3.append(s3)
            
        encoded_trees["target_1"].append(t1)
        encoded_trees["target_2"].append(t2)
        encoded_trees["target_3"].append(t3)
        
        max_len_tree = max(max_len_tree, len(lin_tree.words))
    
    
    return encoded_trees, max_len_tree

def gen_dsets():
    encodings = []

    # naive absolute encodings
    a_enc     = C_NaiveAbsoluteEncoding(separator="[_]", unary_joiner="[+]", reverse=False, binary=False, binary_direction=None, binary_marker="[b]")
    encodings.append({"name":"naive_absolute", "encoder":a_enc})
    a_br_enc  = C_NaiveAbsoluteEncoding(separator="[_]", unary_joiner="[+]", reverse=False, binary=True,  binary_direction="R",  binary_marker="[b]")
    encodings.append({"name":"naive_absolute_br", "encoder":a_br_enc})
    a_bl_enc  = C_NaiveAbsoluteEncoding(separator="[_]", unary_joiner="[+]", reverse=False, binary=True,  binary_direction="L",  binary_marker="[b]")
    encodings.append({"name":"naive_absolute_bl", "encoder":a_bl_enc})
    ar_enc    = C_NaiveAbsoluteEncoding(separator="[_]", unary_joiner="[+]", reverse=True,  binary=False, binary_direction=None, binary_marker="[b]")
    encodings.append({"name":"naive_absolute_r", "encoder":ar_enc})
    ar_br_enc = C_NaiveAbsoluteEncoding(separator="[_]", unary_joiner="[+]", reverse=True,  binary=True,  binary_direction="R",  binary_marker="[b]")
    encodings.append({"name":"naive_absolute_r_br", "encoder":ar_br_enc})
    ar_bl_enc = C_NaiveAbsoluteEncoding(separator="[_]", unary_joiner="[+]", reverse=True,  binary=True,  binary_direction="L",  binary_marker="[b]")
    encodings.append({"name":"naive_absolute_r_bl", "encoder":ar_bl_enc})

    # naive relative encodings
    r_enc     = C_NaiveRelativeEncoding(separator="[_]", unary_joiner="[+]", reverse=False, binary=False, binary_direction=None, binary_marker="[b]")
    encodings.append({"name":"naive_relative", "encoder":r_enc})
    r_br_enc  = C_NaiveRelativeEncoding(separator="[_]", unary_joiner="[+]", reverse=False, binary=True,  binary_direction="R",  binary_marker="[b]")
    encodings.append({"name":"naive_relative_br", "encoder":r_br_enc})
    r_bl_enc  = C_NaiveRelativeEncoding(separator="[_]", unary_joiner="[+]", reverse=False, binary=True,  binary_direction="L",  binary_marker="[b]")
    encodings.append({"name":"naive_relative_bl", "encoder":r_bl_enc})
    rr_enc    = C_NaiveRelativeEncoding(separator="[_]", unary_joiner="[+]", reverse=True,  binary=False, binary_direction=None, binary_marker="[b]")
    encodings.append({"name":"naive_relative_r", "encoder":rr_enc})
    rr_br_enc = C_NaiveRelativeEncoding(separator="[_]", unary_joiner="[+]", reverse=True,  binary=True,  binary_direction="R",  binary_marker="[b]")
    encodings.append({"name":"naive_relative_r_br", "encoder":rr_br_enc})
    rr_bl_enc = C_NaiveRelativeEncoding(separator="[_]", unary_joiner="[+]", reverse=True,  binary=True,  binary_direction="L",  binary_marker="[b]")
    encodings.append({"name":"naive_relative_r_bl", "encoder":rr_bl_enc})

    # naive dynamic encodings
    d_enc     = C_NaiveDynamicEncoding(separator="[_]", unary_joiner="[+]", reverse=False, binary=False, binary_direction=None, binary_marker="[b]")
    encodings.append({"name":"naive_dynamic", "encoder":d_enc})
    d_br_enc  = C_NaiveDynamicEncoding(separator="[_]", unary_joiner="[+]", reverse=False, binary=True,  binary_direction="R",  binary_marker="[b]")
    encodings.append({"name":"naive_dynamic_br", "encoder":d_br_enc})
    d_bl_enc  = C_NaiveDynamicEncoding(separator="[_]", unary_joiner="[+]", reverse=False, binary=True,  binary_direction="L",  binary_marker="[b]")
    encodings.append({"name":"naive_dynamic_bl", "encoder":d_bl_enc})
    dr_enc    = C_NaiveDynamicEncoding(separator="[_]", unary_joiner="[+]", reverse=True,  binary=False, binary_direction=None, binary_marker="[b]")
    encodings.append({"name":"naive_dynamic_r", "encoder":dr_enc})
    dr_br_enc = C_NaiveDynamicEncoding(separator="[_]", unary_joiner="[+]", reverse=True,  binary=True,  binary_direction="R",  binary_marker="[b]")
    encodings.append({"name":"naive_dynamic_r_br", "encoder":dr_br_enc})
    dr_bl_enc = C_NaiveDynamicEncoding(separator="[_]", unary_joiner="[+]", reverse=True,  binary=True,  binary_direction="L",  binary_marker="[b]")
    encodings.append({"name":"naive_dynamic_r_bl", "encoder":dr_bl_enc})

    # gaps encodings
    g_r_enc   = C_GapsEncoding(separator="[_]", unary_joiner="[+]", binary_direction="R", binary_marker="[b]")
    encodings.append({"name":"gaps_r", "encoder":g_r_enc})
    g_l_enc   = C_GapsEncoding(separator="[_]", unary_joiner="[+]", binary_direction="L", binary_marker="[b]")
    encodings.append({"name":"gaps_l", "encoder":g_l_enc})

    # tetra encodings
    t_pr_enc  = C_Tetratag(separator="[_]", unary_joiner="[+]", mode='preorder',  binary_marker="[b]")
    encodings.append({"name":"tetratag_preorder", "encoder":t_pr_enc})
    t_in_enc  = C_Tetratag(separator="[_]", unary_joiner="[+]", mode='inorder',   binary_marker="[b]")
    encodings.append({"name":"tetratag_inorder", "encoder":t_in_enc})
    t_po_enc  = C_Tetratag(separator="[_]", unary_joiner="[+]", mode='postorder', binary_marker="[b]")
    encodings.append({"name":"tetratag_postorder", "encoder":t_po_enc})

    # yuxtaposed encodings
    j_enc   = C_JuxtaposedEncoding(separator="[_]", unary_joiner="[+]", binary=False, binary_direction=None, binary_marker="[b]")
    encodings.append({"name":"juxtaposed", "encoder":j_enc})
    j_r_enc = C_JuxtaposedEncoding(separator="[_]", unary_joiner="[+]", binary=True, binary_direction='R',   binary_marker="[b]")
    encodings.append({"name":"juxtaposed_r", "encoder":j_r_enc})
    j_l_enc = C_JuxtaposedEncoding(separator="[_]", unary_joiner="[+]", binary=True, binary_direction='L',   binary_marker="[b]")
    encodings.append({"name":"juxtaposed_l", "encoder":j_l_enc})

    return encodings


In [2]:
from torch import nn

class TokenClassificationHead(nn.Module):
    def __init__(self, hidden_size, num_labels, dropout_p=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout_p)
        self.classifier = nn.Linear(hidden_size, num_labels)
        self.num_labels = num_labels

        self._init_weights()

    def _init_weights(self):
        self.classifier.weight.data.normal_(mean=0.0, std=0.02)
        if self.classifier.bias is not None:
            self.classifier.bias.data.zero_()

    def forward(self, sequence_output, pooled_output, labels = None, attention_mask = None, **kwargs):
        print("forward of TokenClassificationHead")
        sequence_output_dropout = self.dropout(sequence_output)
        logits = self.classifier(sequence_output_dropout)

        loss = None
        if labels is not None:
            loss_fct = torch.nn.CrossEntropyLoss()
            labels = labels.long()

            # Only keep active parts of the loss
            if attention_mask is not None:
                active_loss = attention_mask.view(-1) == 1
                active_logits = logits.view(-1, self.num_labels)
                active_labels = torch.where(
                    active_loss,
                    labels.view(-1),
                    torch.tensor(loss_fct.ignore_index).type_as(labels),
                )
                loss = loss_fct(active_logits, active_labels)
            else:
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return logits, loss

In [3]:
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import numpy as np
from transformers import EvalPrediction
from torch import nn
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModel, AutoTokenizer
import transformers
from transformers.trainer_utils import EvalLoopOutput

class MultiTaskModel(nn.Module):
    def __init__(self, encoder_name_or_path, tasks):
        super().__init__()
        
        self.encoder = AutoModel.from_pretrained(encoder_name_or_path)

        tokenizer_kwargs = frozendict(padding="max_length", max_length=128, truncation=True, return_tensors="pt")
        self.tokenizer = AutoTokenizer.from_pretrained(encoder_name_or_path, **tokenizer_kwargs)
        self.output_heads = nn.ModuleDict()
        
        for task in tasks:
            task.set_tokenizer(self.tokenizer)
            # for each target in the task, create a decoder
            for subtask in task.y:
                decoder = self._create_output_head(self.encoder.config.hidden_size, task.task_type, task.num_labels[subtask])
                self.output_heads[subtask] = decoder

        self.processed_tasks = self.preprocess_tasks(tasks, self.tokenizer)['naive_absolute_n_commons']
        self.train_dataset = self.processed_tasks['train']
        self.val_dataset = self.processed_tasks['validation']
        print(self.train_dataset.features)
    
    def preprocess_tasks(self, tasks, tokenizer):      
        features_dict = {}
        print("Preprocessing tasks")
        for i, task in enumerate(tasks):
            if hasattr(task, 'processed_features') and tokenizer==task.tokenizer:
                features_dict[task.name] = task.processed_features
                continue

            task.set_tokenizer(tokenizer)
            
            for split in task.dataset:
                task.index = task.dataset[split].index = i
            
            features_dict[task.name] = {}
            for phase, phase_dataset in task.dataset.items():
                phase_dataset.index = i

                features_dict[task.name][phase] = phase_dataset.map(
                    task.preprocess_function, 
                    batched = True,
                    batch_size = 8,
                    load_from_cache_file = True
                )

            print("[TRN] Finished preprocessing task",task.name)

        return features_dict
    
    @staticmethod
    def _create_output_head(encoder_hidden_size: int, task_type, n_labels):
        if task_type == "TokenClassification":
            print("Creating TokenClassification head w/", n_labels, "labels")
            return TokenClassificationHead(encoder_hidden_size, n_labels)
        else:
            raise NotImplementedError()
    
    def forward(self, input_ids = None, attention_mask = None, token_type_ids = None, position_ids = None,
            head_mask = None, inputs_embeds = None, labels = None, task_ids = None, **kwargs):
            
            # compute the transformer output
            outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                position_ids=position_ids,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
            )
            sequence_output, pooled_output = outputs[:2]

            print("3) Transformer has been forwarded")
            unique_task_ids_list = torch.unique(task_ids).tolist()

            loss_list = []
            logits = None
            # print("Computing loss...")
            # print("task_ids", task_ids)
            print("==> I have to compute loss for the following tasks:")
            print("==>", unique_task_ids_list)
            for unique_task_id in unique_task_ids_list:
                print("Task_id =",unique_task_id)
                ptc_train = self.processed_tasks['train']
                target_cols = [col for col in ptc_train.features if col.startswith("target_")]
                print("target_cols =", target_cols)

                for tc in target_cols:
                    print("Target Column =",tc)
                    print("Labels =",labels)
                    logits, task_loss = self.output_heads[str(unique_task_id)].forward(
                        sequence_output[task_id_filter],
                        pooled_output[task_id_filter],
                        labels = None if labels is None else labels[task_id_filter],
                        attention_mask=attention_mask[task_id_filter],
                    )

                    if labels is not None:
                        loss_list.append(task_loss)

            # Loss averaged over all tasks
            outputs = (logits, outputs[2:])
            if loss_list:
                loss = torch.stack(loss_list)
                outputs = (loss.mean(),) + outputs

            return outputs

class MultiTaskTrainer(transformers.Trainer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.tasks = self.model.processed_tasks
        self.train_dataset = self.model.train_dataset
        self.val_dataset = self.model.val_dataset
        self.tokenizer = self.model.tokenizer
        self.pretrained_transformer = self.model.encoder
        self.device = self.pretrained_transformer.device

    def get_eval_dataloader(self, eval_dataset: Dataset | None = None) -> DataLoader:
        print("[*] get_eval_dataloader")
        # if eval_dataset is None:
        #     eval_dataset = self.val_dataset
        # print(eval_dataset)
        return super().get_eval_dataloader(eval_dataset)
    
    def get_train_dataloader(self) -> DataLoader:
        print("[*] get_train_dataloader")
        return super().get_train_dataloader()
    
    def get_test_dataloader(self, test_dataset: Dataset | None = None) -> DataLoader:
        print("[*] get_test_dataloader")
        # if test_dataset is None:
        #     test_dataset = self.val_dataset
        return super().get_test_dataloader(test_dataset)
        
    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=[]):
        
        if ignore_keys is None:
            ignore_keys = []
        
        inputs.to(self.device)
        loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
        loss = loss.mean().detach()
        print("[*] Prediction Step...")
        print("    inputs", inputs.keys())
        print("    outputs", outputs.keys())
        print("    loss", loss)
        
        
        logits_dict = {}
        labels_dict = {}
        
        print("[*] Extracting logits and labels...")
        for i, head in enumerate(self.model.output_heads.values()):
            print("Head found", i+1)
            labels_name = f"target_{i+1}"
            logits_dict[labels_name] = outputs[labels_name]
            print(self.tasks.keys())
            print(self.tasks['train'][labels_name])


        for k,v in outputs.items():
            print(k, v.shape)
            logits_dict[k] = v.detach().cpu().numpy()
            labels_dict[k] = inputs[k].detach().cpu().numpy()
        
        labels = None
        return (loss, logits_dict, labels)
    
    
    
    def evaluation_loop(self, dataloader: DataLoader, description: str, prediction_loss_only: bool | None = None, ignore_keys: List[str] | None = None, metric_key_prefix: str = "eval") -> EvalLoopOutput:
        print("[*] evaluation_loop")
        def has_length(dataset):
            try:
                return len(dataset) is not None
            except TypeError:
                return False
            
        model = self._wrap_model(self.model, training=False, dataloader=dataloader)

        if has_length(dataloader):
            print(f"Num examples = {self.num_examples(dataloader)}")
        
        for step, inputs in enumerate(dataloader):
            print(f"*** Step [{step}] ***")
            print(f"inputs: {inputs}")
            
            
            loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
            
            # logits = self._pad_across_processes(logits)
            for k,v in logits.items():
                print("[eval_loop] logits: ", k, v.shape)

                preds_host = logits[k]
                labels_host = labels
                inputs_host = inputs


                metrics = self.compute_metrics(
                        EvalPrediction(
                            predictions = preds_host, 
                            label_ids   = labels, 
                            inputs      = inputs)
                    )
        
        return super().evaluation_loop(dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix)
        

    def compute_loss(self, model, inputs, return_outputs=False):
        print("[*] Computing Loss...")
        input_ids = inputs['input_ids'] if 'input_ids' in inputs.keys() else None
        attention_mask = inputs['attention_mask'] if 'attention_mask' in inputs.keys() else None
        token_type_ids = inputs['token_type_ids'] if 'token_type_ids' in inputs.keys() else None
        position_ids = inputs['position_ids'] if 'position_ids' in inputs.keys() else None
        head_mask = inputs['head_mask'] if 'head_mask' in inputs.keys() else None
        inputs_embeds = inputs['inputs_embeds'] if 'inputs_embeds' in inputs.keys() else None

        outputs = self.pretrained_transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )

        sequence_output, pooled_output = outputs[:2]
        loss_list = []
        logits_list = {}
        for i, head in enumerate(self.model.output_heads.values()):
            labels_name = f"target_{i+1}"
            labels_i = inputs.pop(labels_name, None)
            logits, loss = head(sequence_output, pooled_output, labels=labels_i, attention_mask=attention_mask)
            loss_list.append(loss)
            logits_list[labels_name] = logits
        
        loss = torch.stack(loss_list)
        print("mean loss", loss)
        return (loss, logits_list) if return_outputs else loss


In [4]:
import numpy as np
from datasets import Dataset
from transformers import DataCollatorForTokenClassification
import evaluate
import funcy as fc
import warnings
from frozendict import frozendict as fdict
from dataclasses import dataclass

@dataclass
class TokenClassification:
    task_type = "TokenClassification"
    name: str = "TokenClassificationTask"
    dataset: Dataset = None
    metric:... = evaluate.load("seqeval")
    main_split: str = "train"
    tokens: str = 'tokens'
    y: str|list = 'target'
    num_labels: int = None
    tokenizer_kwargs: fdict = fdict(padding="max_length", max_length=128, truncation=True)

    @staticmethod
    def _align_labels_with_tokens(labels, word_ids):
        new_labels = []
        current_word = None
        for word_id in word_ids:
            if word_id is None:
                new_labels.append(-100)

            elif word_id != current_word:
                current_word = word_id
                label = -100 if word_id is None else labels[word_id]
                new_labels.append(label)
            
            else:
                label = labels[word_id]
                new_labels.append(label)
        
        return new_labels

    def __post_init__(self):
        self.label_names = {}
        self.num_labels  = {}

        for y in self.y:
            target = self.dataset[self.main_split].features[y]
            self.num_labels[y] = target.feature.num_classes
            self.label_names[y] = target.feature.names if target.feature.names else [None]

        print(self.label_names)

    def get_labels(self):
        return super().get_labels() or self.label_names

    def set_tokenizer(self, tokenizer):
        self.tokenizer = tokenizer
        self.tokenizer.add_prefix_space = True
        self.data_collator = DataCollatorForTokenClassification(
            tokenizer = self.tokenizer
        )

        if examples[self.tokens] and type(examples[self.tokens][0]) == str:
            unsqueeze, examples = True, {k:[v] for k,v in examples.items()}
        
        def get_len(outputs):
            try:
                return len(outputs[fc.first(outputs)])
            except:
                return 1
        
        tokenized_inputs = self.tokenizer(
            examples[self.tokens],
            is_split_into_words=True,
            **self.tokenizer_kwargs
        )

        for target_column in self.y:
            all_labels = examples[target_column]
            new_labels = []
            
            for i, labels in enumerate(all_labels):
                word_ids = tokenized_inputs.word_ids(i)
                new_labels.append(self._align_labels_with_tokens(labels, word_ids))
            
            tokenized_inputs[target_column] = new_labels        
            tokenized_inputs['task_ids'] = [self.index]*get_len(tokenized_inputs)

        return tokenized_inputs       

    def compute_metrics(self, eval_pred):
        logits, labels = eval_pred.predictions, eval_pred.label_ids
        
        predictions = np.argmax(logits, axis=-1)
        true_labels = [
            [self.label_names[l] for l in label if l != -100] for label in labels
        ]
        true_predictions = [
            [self.label_names[p] for (p, l) in zip(prediction, label) if l != -100]
            for prediction, label in zip(predictions, labels)
        ]
        all_metrics = self.metric.compute(
            predictions = true_predictions, 
            references = true_labels
        )
        meta = {"name": self.name, "size": len(predictions), "index": self.index}
        metrics = {k.replace("overall_",""):v for k,v in all_metrics.items() if "overall" in k}
        self.results+=[metrics]
        return {**metrics, **meta}

    def check(self):
        features = self.dataset['train'].features
        return self.tokens in features and self.y in features

In [5]:

from transformers import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy
from typing import Optional, Union
@dataclass
class DataCollatorForTokenClassificationCustom:
    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    label_pad_token_id: int = -100
    return_tensors: str = "pt"
    def __call__(self, features, return_tensors=None):
        if return_tensors is None:
            return_tensors = self.return_tensors
        if return_tensors == "tf":
            return self.tf_call(features)
        elif return_tensors == "pt":
            return self.torch_call(features)
        elif return_tensors == "np":
            return self.numpy_call(features)
        else:
            raise ValueError(f"Framework '{return_tensors}' not recognized!")
    
    def torch_call(self, features):
        import torch
        targets = [{k: v for k, v in feature.items() if "target" in k} for feature in features]
        label_name = "label" if "label" in features[0].keys() else "labels"
        labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
        print("==>",targets)

        no_labels_features = [{k: v for k, v in feature.items() if "target" not in k} for feature in features]
        no_labels_features = [{k: v for k, v in feature.items() if "tokens" not in k} for feature in features]
        print("==>",no_labels_features)

        # error here
        # input_id
        batch = self.tokenizer.pad(
            no_labels_features,
            padding = self.padding,
            max_length = 8,
            pad_to_multiple_of = self.pad_to_multiple_of,
            return_tensors = "pt",
        )
        
        if labels is None:
            return batch

        sequence_length = batch["input_ids"].shape[1]
        padding_side = self.tokenizer.padding_side

        def to_list(tensor_or_iterable):
            if isinstance(tensor_or_iterable, torch.Tensor):
                return tensor_or_iterable.tolist()
            return list(tensor_or_iterable)

        if padding_side == "right":
            batch[label_name] = [
                to_list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels
            ]
        else:
            batch[label_name] = [
                [self.label_pad_token_id] * (sequence_length - len(label)) + to_list(label) for label in labels
            ]

        batch[label_name] = torch.tensor(batch[label_name], dtype=torch.int64)
        
        return batch

In [6]:
# import trainer
from transformers import Trainer, TrainingArguments

# train and evaluate using Evalb
encodings = gen_dsets()
results = {}
train_limit = 10
max_seq_len = 128
model_name = "bert-base-cased"

# probably this could be done in parallel
for enc in encodings[:1]:
    results_df = pd.DataFrame(columns=["encoding", "recall", "precision", "f1", "n_labels"])
    print("[GPU] Starting training; Allocated memory:", torch.cuda.memory_allocated()/1e6,"MB")
    print("[GPU] Starting training; Cached memory:", torch.cuda.memory_cached()/1e6,"MB")
    encoder = enc["encoder"]
    
    print("[DST] Encoding the datasets using CoDeLin")
    train_enc, mlt1 = encode_dset(encoder, ptb_train[:train_limit] if train_limit else ptb_train)
    dev_enc,   mlt2   = encode_dset(encoder, ptb_dev[:train_limit]   if train_limit else ptb_dev)
    dataset  = generate_dataset_from_codelin(train_enc, dev_enc)
    
    tasks = [TokenClassification(
                dataset = dataset,
                y = ["target_1", "target_2", "target_3"],
                name = enc["name"]+"_n_commons",
                tokenizer_kwargs = frozendict(padding="max_length", max_length = max_seq_len, truncation=True)
            )]

    model = MultiTaskModel(model_name, tasks)
    
    training_args = TrainingArguments(
        output_dir = f"results/{enc['name']}",
        num_train_epochs = 1,
        per_device_train_batch_size = 8,
        per_device_eval_batch_size = 8,
        warmup_steps = 500,
        weight_decay = 0.01,
        logging_dir = f"results/{enc['name']}/logs",
        logging_steps = 10,
        evaluation_strategy = "epoch",
        save_strategy = "epoch",
        load_best_model_at_end = True,
        metric_for_best_model = "f1",
        greater_is_better = True,
        save_total_limit = 1,
        remove_unused_columns=False
    )


    # wait until train stops
    data_collator = DataCollatorForTokenClassificationCustom(
        model.tokenizer
    )
    print("training with",model.train_dataset)
    
    train_dataset = model.train_dataset
    print("train_dataset[0]",train_dataset[0])
#    model.train()
    trainer = MultiTaskTrainer(
        model = model,
        args = training_args,
        train_dataset = model.train_dataset,
        eval_dataset = model.val_dataset,
        compute_metrics = None,
        tokenizer = model.tokenizer,
        data_collator = data_collator,        
    )
    
    trainer.train()

[GPU] Starting training; Allocated memory: 0.0 MB
[GPU] Starting training; Cached memory: 0.0 MB
[DST] Encoding the datasets using CoDeLin
Sample of labels n_commons: ['14', '2', '3', '4', '5']
Sample of labels last_common: ['QP', 'S', 'SBAR', 'SBAR[+]S', 'SBAR[+]S[+]VP']
Sample of labels unary_chain: ['VP', 'WHADVP', 'WHNP']




Casting the dataset:   0%|          | 0/10 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/10 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/10 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/10 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/10 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/10 [00:00<?, ? examples/s]

[TSK] Loaded TokenClassification task with {'target_1': 14, 'target_2': 15, 'target_3': 8} labels
{'target_1': ['1', '10', '11', '12', '13', '14', '2', '3', '4', '5', '6', '7', '8', '9'], 'target_2': ['ADJP', 'NP', 'NP[+]QP', 'PP', 'PRN', 'QP', 'S', 'SBAR', 'SBAR[+]S', 'SBAR[+]S[+]VP', 'SINV', 'S[+]VP', 'UCP', 'VP', 'WHNP'], 'target_3': ['-NONE-', 'ADJP', 'ADVP', 'NP', 'PRT', 'VP', 'WHADVP', 'WHNP']}


Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Creating TokenClassification head w/ 14 labels
Creating TokenClassification head w/ 15 labels
Creating TokenClassification head w/ 8 labels
Preprocessing tasks
[TRN] Preprocessing task naive_absolute_n_commons
[TRN] Preprocessing phase train


Map:   0%|          | 0/10 [00:00<?, ? examples/s]

{'tokens': [['In', 'an', 'Oct.', '19', 'review', 'of', '``', 'The', 'Misanthrope', "''", 'at', 'Chicago', "'s", 'Goodman', 'Theatre', '-LRB-', '``', 'Revitalized', 'Classics', 'Take', 'the', 'Stage', 'in', 'Windy', 'City', ',', "''", 'Leisure', '&', 'Arts', '-RRB-', ',', 'the', 'role', 'of', 'Celimene', ',', 'played', 'by', 'Kim', 'Cattrall', ',', 'was', 'mistakenly', 'attributed', 'to', 'Christina', 'Haag', '.'], ['Ms.', 'Haag', 'plays', 'Elianti', '.'], ['Rolls-Royce', 'Motor', 'Cars', 'Inc.', 'said', 'it', 'expects', 'its', 'U.S.', 'sales', 'to', 'remain', 'steady', 'at', 'about', '1,200', 'cars', 'in', '1990', '.'], ['The', 'luxury', 'auto', 'maker', 'last', 'year', 'sold', '1,214', 'cars', 'in', 'the', 'U.S.'], ['Howard', 'Mosher', ',', 'president', 'and', 'chief', 'executive', 'officer', ',', 'said', 'he', 'anticipates', 'growth', 'for', 'the', 'luxury', 'auto', 'maker', 'in', 'Britain', 'and', 'Europe', ',', 'and', 'in', 'Far', 'Eastern', 'markets', '.'], ['BELL', 'INDUSTRIES', 

Map:   0%|          | 0/10 [00:00<?, ? examples/s]

{'tokens': [['Influential', 'members', 'of', 'the', 'House', 'Ways', 'and', 'Means', 'Committee', 'introduced', 'legislation', 'that', 'would', 'restrict', 'how', 'the', 'new', 'savings-and-loan', 'bailout', 'agency', 'can', 'raise', 'capital', ',', 'creating', 'another', 'potential', 'obstacle', 'to', 'the', 'government', "'s", 'sale', 'of', 'sick', 'thrifts', '.'], ['The', 'bill', ',', 'whose', 'backers', 'include', 'Chairman', 'Dan', 'Rostenkowski', '-LRB-', 'D.', ',', 'Ill.', '-RRB-', ',', 'would', 'prevent', 'the', 'Resolution', 'Trust', 'Corp.', 'from', 'raising', 'temporary', 'working', 'capital', 'by', 'having', 'an', 'RTC-owned', 'bank', 'or', 'thrift', 'issue', 'debt', 'that', 'would', "n't", 'be', 'counted', 'on', 'the', 'federal', 'budget', '.'], ['The', 'bill', 'intends', 'to', 'restrict', 'the', 'RTC', 'to', 'Treasury', 'borrowings', 'only', ',', 'unless', 'the', 'agency', 'receives', 'specific', 'congressional', 'authorization', '.'], ['``', 'Such', 'agency', '`', 'self-



  0%|          | 0/1 [00:00<?, ?it/s]

You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


==> [{'target_1': [-100, 7, 6, 6, 6, 7, 7, 8, 8, 6, 0, 6, 8, 9, 9, 9, 9, 10, 11, 12, 12, 12, 10, 12, 13, 13, 11, 11, 11, 12, 1, 13, 0, 0, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100], 'target_2': [-100, 1, 1, 1, 1, 1, 1, 1, 1, 1, 6, 13, 8, 13, 13, 13, 13, 1, 3, 1, 1, 1, 1, 3, 1, 1, 3, 3, 3, 3, 0, 1, 6, 6, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 



forward of TokenClassificationHead
forward of TokenClassificationHead
forward of TokenClassificationHead
mean loss tensor([2.5977, 2.7110, 2.0459], device='cuda:0', grad_fn=<StackBackward0>)
tensor([2.5977, 2.7110, 2.0459], device='cuda:0', grad_fn=<StackBackward0>)
[*] get_eval_dataloader
Returning base Dataloader
[*] evaluation_loop
Num examples = 10
==> [{'target_1': [-100, 7, 7, 7, 7, 6, 7, 8, 8, 8, 8, 8, 8, 0, 6, 7, 8, 10, 11, 12, 1, 1, 1, 1, 1, 1, 1, 1, 1, 13, 1, 2, 11, 11, 13, 2, 2, 1, 2, 5, 5, 4, 4, 3, 4, 5, 0, 0, 0, 0, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100

KeyError: 'target_1'