In [1]:
import datasets
from datasets import Sequence
from datasets import ClassLabel
def load_conll_dataset(train_path, dev_path, test_path, token_idx, label_idx, comment_prefix=None):
    """Load tab-separated CoNLL-style train/dev/test files into a ``datasets.DatasetDict``.

    Every non-blank line describes one token, with columns separated by ``\\t``;
    blank lines separate sentences (runs of blank lines count as one separator).
    Column ``token_idx`` of each token line becomes an entry of ``"tokens"`` and
    column ``label_idx`` an entry of ``"target"``.  ``"target"`` is cast to
    ``Sequence(ClassLabel)`` built from the sorted union of labels found in all
    three splits, so a label maps to the same integer id in every split.

    Args:
        train_path: path of the training file (read as UTF-8).
        dev_path: path of the development file; becomes the ``"validation"`` split.
        test_path: path of the test file.
        token_idx: 0-based column index holding the token text.
        label_idx: 0-based column index holding the label.
        comment_prefix: optional line prefix to skip (e.g. ``"#"`` for raw
            CoNLL-U files with ``# sent_id``/``# text`` comments).  ``None``
            (default) treats every non-blank line as a token line.

    Returns:
        ``datasets.DatasetDict`` with ``"train"``, ``"validation"`` and ``"test"``
        splits, each with ``"tokens"`` and ``"target"`` columns.

    Raises:
        IndexError: if a token line has no column at ``token_idx`` or
            ``label_idx``; the message names the file and line number.
    """

    def read_conll_file(file_path, token_idx, label_idx):
        """Parse one file into ``{"tokens": [[tok, ...], ...], "target": [[label, ...], ...]}``."""
        sentences = [[]]
        # CoNLL-U is UTF-8 by specification; never rely on the platform default codec.
        with open(file_path, "r", encoding="utf-8") as f:
            for line_no, line in enumerate(f, start=1):
                line = line.strip()

                if not line:
                    # A blank line closes the current sentence; skip repeated blank lines.
                    if sentences[-1]:
                        sentences.append([])
                    continue

                if comment_prefix is not None and line.startswith(comment_prefix):
                    continue

                columns = line.split("\t")
                try:
                    sentences[-1].append((columns[token_idx], columns[label_idx]))
                except IndexError:
                    raise IndexError(
                        f"{file_path}:{line_no}: need columns {token_idx} (token) and "
                        f"{label_idx} (label) but the line has {len(columns)} "
                        f"tab-separated column(s)"
                    ) from None

        # Drop the empty sentence left behind by a trailing blank line or an empty file.
        if not sentences[-1]:
            sentences.pop()

        return {
            "tokens": [[token for token, _ in sentence] for sentence in sentences],
            "target": [[label for _, label in sentence] for sentence in sentences],
        }

    splits = {
        "train": read_conll_file(train_path, token_idx, label_idx),
        "validation": read_conll_file(dev_path, token_idx, label_idx),
        "test": read_conll_file(test_path, token_idx, label_idx),
    }

    # One shared, sorted label inventory keeps label ids consistent across splits.
    label_names = sorted(
        {label for split in splits.values() for labels in split["target"] for label in labels}
    )

    return datasets.DatasetDict({
        name: datasets.Dataset.from_dict(split).cast_column(
            "target", Sequence(ClassLabel(names=label_names))
        )
        for name, split in splits.items()
    })

# CoNLL-U column layout used here: column 1 is the token text (FORM) and
# column 3 is the universal POS tag (UPOS), as the printed label names confirm.
TOKEN_COLUMN = 1
LABEL_COLUMN = 3

raw_dataset = load_conll_dataset(
    "data/train.conllu",
    "data/dev.conllu",
    "data/test.conllu",
    TOKEN_COLUMN,
    LABEL_COLUMN,
)

# Sanity check: first training sentence, its label ids, and the id -> label-name table.
first_train_example = raw_dataset["train"][0]
print(first_train_example["tokens"])
print(first_train_example["target"])
print(raw_dataset["train"].features["target"].feature.names)


  from .autonotebook import tqdm as notebook_tqdm
                                                                                   

['Al', '-', 'Zaman', ':', 'American', 'forces', 'killed', 'Shaikh', 'Abdullah', 'al', '-', 'Ani', ',', 'the', 'preacher', 'at', 'the', 'mosque', 'in', 'the', 'town', 'of', 'Qaim', ',', 'near', 'the', 'Syrian', 'border', '.']
[11, 12, 11, 12, 0, 7, 15, 11, 11, 11, 12, 11, 12, 5, 7, 1, 5, 7, 1, 5, 7, 1, 11, 12, 1, 5, 0, 7, 12]
['ADJ', 'ADP', 'ADV', 'AUX', 'CCONJ', 'DET', 'INTJ', 'NOUN', 'NUM', 'PART', 'PRON', 'PROPN', 'PUNCT', 'SCONJ', 'SYM', 'VERB', 'X']




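The training setup is read from a JSON configuration file (`config.json`). Input data files are tab-separated (`\t`) columns, one token per line, with blank lines between sentences. For each task, the JSON file specifies the token column and the 'target' (label) column(s). If more than one target column is given, the task is duplicated, once per target column.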

In [2]:
from src.tasks.sequence_classification import SequenceClassification
from src.tasks.token_classification import TokenClassification
from src.utils import *
from src.models import *

import easydict
from frozendict import frozendict
import json

# Read the experiment configuration from config.json; EasyDict gives attribute
# access to its keys (e.g. args.tasks, args.max_seq_length).
# JSON text is UTF-8 per RFC 8259, so don't depend on the platform default encoding.
with open("config.json", "r", encoding="utf-8") as f:
    args = easydict.EasyDict(json.load(f))

# Build one task object per (configured task, target column) pair.  A task whose
# `label_idx` lists several columns is duplicated, once per target column.
TASK_CLASSES = {
    "token_classification": TokenClassification,
    "sequence_classification": SequenceClassification,
}

# Same tokenizer settings for every task, so build them once outside the loop.
tokenizer_kwargs = frozendict(padding="max_length", max_length=args.max_seq_length, truncation=True)

tasks = []
for task in args.tasks:
    # Config keys are `task_type` / `task_name` for every task type.
    task_cls = TASK_CLASSES.get(task.task_type)
    if task_cls is None:
        # Previously an unknown type was silently dropped from the multitask run.
        raise ValueError(
            f"Unsupported task_type {task.task_type!r} for task {task.get('task_name')!r}; "
            f"expected one of {sorted(TASK_CLASSES)}"
        )

    # Accept a single column index as well as a list of indices.
    label_indices = task.label_idx if isinstance(task.label_idx, (list, tuple)) else [task.label_idx]

    # NOTE(review): duplicated tasks all share `task.task_name`; if Model/Trainer key
    # tasks by name, consider suffixing the name with l_idx — confirm against src.
    # NOTE(review): load_conll_dataset yields per-token `target` sequences; presumably
    # not the label format SequenceClassification expects — TODO confirm.
    for l_idx in label_indices:
        tasks.append(
            task_cls(
                dataset=load_conll_dataset(task.train_file, task.eval_file, task.test_file, task.tokens_idx, l_idx),
                name=task.task_name,
                tokenizer_kwargs=tokenizer_kwargs,
            )
        )


        
# Build the multitask model and its trainer from the task list and config.
# NOTE(review): the original author noted that Model is a list of models with, by
# default, a shared encoder plus a task-specific CLS token and head. Model and Trainer
# come from a star import of src.models/src.utils, so this can't be verified here; the
# recorded output only shows a RobertaForTokenClassification initialized from roberta-base.
model   = Model(tasks, args)
# NOTE(review): per the original author, tasks are sampled uniformly by default — confirm in Trainer.
trainer = Trainer(model, tasks, args)

# NOTE(review): the recorded run repeatedly prints "inputs in forward => 16" (presumably a
# leftover debug print in the model's forward — confirm) and was stopped with a
# KeyboardInterrupt at step 283/2352. Remove the print and rerun before relying on results.
trainer.train()

                                                                                   

Labels for task:
['ADJ', 'ADP', 'ADV', 'AUX', 'CCONJ', 'DET', 'INTJ', 'NOUN', 'NUM', 'PART', 'PRON', 'PROPN', 'PUNCT', 'SCONJ', 'SYM', 'VERB', 'X']
[*] Found task 0 => conllu


Some weights of the model checkpoint at roberta-base were not used when initializing RobertaForTokenClassification: ['lm_head.layer_norm.bias', 'lm_head.layer_norm.weight', 'lm_head.bias', 'lm_head.dense.weight', 'lm_head.decoder.weight', 'lm_head.dense.bias']
- This IS expected if you are initializing RobertaForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaForTokenClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able

train batch size = 16
[*] Preprocessing task 0 => conllu




asking for multitask train dataloader
asking for single train dataloader


  0%|          | 0/2352 [00:00<?, ?it/s]You're using a RobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


inputs in forward => 16


  0%|          | 2/2352 [00:00<07:54,  4.96it/s]

inputs in forward => 16
inputs in forward => 16


  0%|          | 4/2352 [00:00<05:21,  7.29it/s]

inputs in forward => 16
inputs in forward => 16


  0%|          | 6/2352 [00:00<04:49,  8.10it/s]

inputs in forward => 16
inputs in forward => 16


  0%|          | 8/2352 [00:01<04:28,  8.73it/s]

inputs in forward => 16
inputs in forward => 16


  0%|          | 10/2352 [00:01<04:14,  9.19it/s]

inputs in forward => 16
inputs in forward => 16


  1%|          | 12/2352 [00:01<04:06,  9.48it/s]

inputs in forward => 16
inputs in forward => 16


  1%|          | 14/2352 [00:01<04:04,  9.57it/s]

inputs in forward => 16
inputs in forward => 16


  1%|          | 16/2352 [00:01<04:02,  9.64it/s]

inputs in forward => 16
inputs in forward => 16


  1%|          | 18/2352 [00:02<04:02,  9.61it/s]

inputs in forward => 16
inputs in forward => 16


  1%|          | 20/2352 [00:02<04:01,  9.66it/s]

inputs in forward => 16
inputs in forward => 16


  1%|          | 22/2352 [00:02<04:00,  9.70it/s]

inputs in forward => 16
inputs in forward => 16


  1%|          | 24/2352 [00:02<03:59,  9.71it/s]

inputs in forward => 16
inputs in forward => 16


  1%|          | 26/2352 [00:02<03:58,  9.76it/s]

inputs in forward => 16
inputs in forward => 16


  1%|          | 28/2352 [00:03<03:58,  9.76it/s]

inputs in forward => 16
inputs in forward => 16


  1%|▏         | 31/2352 [00:03<03:56,  9.80it/s]

inputs in forward => 16
inputs in forward => 16


  1%|▏         | 32/2352 [00:03<03:57,  9.75it/s]

inputs in forward => 16
inputs in forward => 16


  1%|▏         | 34/2352 [00:03<03:59,  9.68it/s]

inputs in forward => 16
inputs in forward => 16


  2%|▏         | 36/2352 [00:03<03:58,  9.73it/s]

inputs in forward => 16
inputs in forward => 16


  2%|▏         | 38/2352 [00:04<04:00,  9.64it/s]

inputs in forward => 16
inputs in forward => 16


  2%|▏         | 40/2352 [00:04<03:57,  9.71it/s]

inputs in forward => 16
inputs in forward => 16


  2%|▏         | 42/2352 [00:04<03:59,  9.65it/s]

inputs in forward => 16
inputs in forward => 16


  2%|▏         | 44/2352 [00:04<03:58,  9.68it/s]

inputs in forward => 16
inputs in forward => 16


  2%|▏         | 46/2352 [00:05<03:57,  9.70it/s]

inputs in forward => 16
inputs in forward => 16


  2%|▏         | 48/2352 [00:05<04:00,  9.58it/s]

inputs in forward => 16
inputs in forward => 16


  2%|▏         | 51/2352 [00:05<03:55,  9.75it/s]

inputs in forward => 16
inputs in forward => 16


  2%|▏         | 52/2352 [00:05<03:56,  9.71it/s]

inputs in forward => 16
inputs in forward => 16


  2%|▏         | 54/2352 [00:05<03:56,  9.70it/s]

inputs in forward => 16
inputs in forward => 16


  2%|▏         | 56/2352 [00:06<03:56,  9.72it/s]

inputs in forward => 16
inputs in forward => 16


  2%|▏         | 58/2352 [00:06<03:55,  9.74it/s]

inputs in forward => 16
inputs in forward => 16


  3%|▎         | 60/2352 [00:06<03:55,  9.74it/s]

inputs in forward => 16
inputs in forward => 16


  3%|▎         | 62/2352 [00:06<03:54,  9.78it/s]

inputs in forward => 16
inputs in forward => 16


  3%|▎         | 64/2352 [00:06<03:57,  9.65it/s]

inputs in forward => 16
inputs in forward => 16


  3%|▎         | 66/2352 [00:07<03:57,  9.64it/s]

inputs in forward => 16
inputs in forward => 16


  3%|▎         | 68/2352 [00:07<04:02,  9.43it/s]

inputs in forward => 16
inputs in forward => 16


  3%|▎         | 70/2352 [00:07<04:01,  9.47it/s]

inputs in forward => 16
inputs in forward => 16


  3%|▎         | 72/2352 [00:07<03:59,  9.52it/s]

inputs in forward => 16
inputs in forward => 16


  3%|▎         | 74/2352 [00:07<04:01,  9.42it/s]

inputs in forward => 16
inputs in forward => 16


  3%|▎         | 76/2352 [00:08<03:59,  9.51it/s]

inputs in forward => 16
inputs in forward => 16


  3%|▎         | 78/2352 [00:08<03:59,  9.49it/s]

inputs in forward => 16
inputs in forward => 16


  3%|▎         | 80/2352 [00:08<03:58,  9.51it/s]

inputs in forward => 16
inputs in forward => 16


  3%|▎         | 82/2352 [00:08<03:55,  9.62it/s]

inputs in forward => 16
inputs in forward => 16


  4%|▎         | 84/2352 [00:08<03:54,  9.65it/s]

inputs in forward => 16
inputs in forward => 16


  4%|▎         | 86/2352 [00:09<03:55,  9.62it/s]

inputs in forward => 16
inputs in forward => 16


  4%|▎         | 88/2352 [00:09<03:55,  9.63it/s]

inputs in forward => 16
inputs in forward => 16


  4%|▍         | 90/2352 [00:09<03:55,  9.61it/s]

inputs in forward => 16
inputs in forward => 16


  4%|▍         | 92/2352 [00:09<03:53,  9.69it/s]

inputs in forward => 16
inputs in forward => 16


  4%|▍         | 94/2352 [00:09<03:51,  9.77it/s]

inputs in forward => 16
inputs in forward => 16


  4%|▍         | 96/2352 [00:10<03:52,  9.72it/s]

inputs in forward => 16
inputs in forward => 16


  4%|▍         | 98/2352 [00:10<03:50,  9.78it/s]

inputs in forward => 16
inputs in forward => 16


  4%|▍         | 100/2352 [00:10<03:51,  9.74it/s]

inputs in forward => 16
inputs in forward => 16


  4%|▍         | 102/2352 [00:10<03:51,  9.73it/s]

inputs in forward => 16
inputs in forward => 16


  4%|▍         | 104/2352 [00:11<03:53,  9.62it/s]

inputs in forward => 16
inputs in forward => 16


  5%|▍         | 106/2352 [00:11<03:52,  9.67it/s]

inputs in forward => 16
inputs in forward => 16


  5%|▍         | 108/2352 [00:11<03:52,  9.67it/s]

inputs in forward => 16
inputs in forward => 16


  5%|▍         | 111/2352 [00:11<03:50,  9.73it/s]

inputs in forward => 16
inputs in forward => 16


  5%|▍         | 112/2352 [00:11<03:51,  9.67it/s]

inputs in forward => 16
inputs in forward => 16


  5%|▍         | 114/2352 [00:12<03:52,  9.62it/s]

inputs in forward => 16
inputs in forward => 16


  5%|▍         | 116/2352 [00:12<03:50,  9.69it/s]

inputs in forward => 16
inputs in forward => 16


  5%|▌         | 118/2352 [00:12<03:49,  9.73it/s]

inputs in forward => 16
inputs in forward => 16


  5%|▌         | 120/2352 [00:12<03:49,  9.74it/s]

inputs in forward => 16
inputs in forward => 16


  5%|▌         | 122/2352 [00:12<03:49,  9.74it/s]

inputs in forward => 16
inputs in forward => 16


  5%|▌         | 124/2352 [00:13<03:48,  9.77it/s]

inputs in forward => 16
inputs in forward => 16


  5%|▌         | 126/2352 [00:13<03:48,  9.75it/s]

inputs in forward => 16
inputs in forward => 16


  5%|▌         | 128/2352 [00:13<03:48,  9.72it/s]

inputs in forward => 16
inputs in forward => 16


  6%|▌         | 130/2352 [00:13<03:48,  9.74it/s]

inputs in forward => 16
inputs in forward => 16


  6%|▌         | 132/2352 [00:13<03:50,  9.64it/s]

inputs in forward => 16
inputs in forward => 16


  6%|▌         | 134/2352 [00:14<03:48,  9.71it/s]

inputs in forward => 16
inputs in forward => 16


  6%|▌         | 137/2352 [00:14<03:45,  9.83it/s]

inputs in forward => 16
inputs in forward => 16


  6%|▌         | 138/2352 [00:14<03:47,  9.73it/s]

inputs in forward => 16
inputs in forward => 16


  6%|▌         | 140/2352 [00:14<03:46,  9.75it/s]

inputs in forward => 16
inputs in forward => 16


  6%|▌         | 143/2352 [00:15<03:46,  9.77it/s]

inputs in forward => 16
inputs in forward => 16


  6%|▌         | 144/2352 [00:15<03:45,  9.78it/s]

inputs in forward => 16
inputs in forward => 16


  6%|▌         | 146/2352 [00:15<03:46,  9.76it/s]

inputs in forward => 16
inputs in forward => 16


  6%|▋         | 148/2352 [00:15<03:47,  9.70it/s]

inputs in forward => 16
inputs in forward => 16


  6%|▋         | 150/2352 [00:15<03:47,  9.69it/s]

inputs in forward => 16
inputs in forward => 16


  6%|▋         | 152/2352 [00:15<03:48,  9.64it/s]

inputs in forward => 16
inputs in forward => 16


  7%|▋         | 154/2352 [00:16<03:47,  9.66it/s]

inputs in forward => 16
inputs in forward => 16


  7%|▋         | 156/2352 [00:16<03:45,  9.72it/s]

inputs in forward => 16
inputs in forward => 16


  7%|▋         | 158/2352 [00:16<03:46,  9.70it/s]

inputs in forward => 16
inputs in forward => 16


  7%|▋         | 161/2352 [00:16<03:43,  9.79it/s]

inputs in forward => 16
inputs in forward => 16


  7%|▋         | 162/2352 [00:16<03:45,  9.73it/s]

inputs in forward => 16
inputs in forward => 16


  7%|▋         | 164/2352 [00:17<03:45,  9.69it/s]

inputs in forward => 16
inputs in forward => 16


  7%|▋         | 166/2352 [00:17<03:43,  9.76it/s]

inputs in forward => 16
inputs in forward => 16


  7%|▋         | 169/2352 [00:17<03:42,  9.80it/s]

inputs in forward => 16
inputs in forward => 16


  7%|▋         | 170/2352 [00:17<03:43,  9.78it/s]

inputs in forward => 16
inputs in forward => 16


  7%|▋         | 172/2352 [00:18<03:42,  9.80it/s]

inputs in forward => 16
inputs in forward => 16


  7%|▋         | 175/2352 [00:18<03:43,  9.74it/s]

inputs in forward => 16
inputs in forward => 16


  8%|▊         | 177/2352 [00:18<03:41,  9.81it/s]

inputs in forward => 16
inputs in forward => 16


  8%|▊         | 178/2352 [00:18<03:42,  9.79it/s]

inputs in forward => 16
inputs in forward => 16


  8%|▊         | 180/2352 [00:18<03:41,  9.80it/s]

inputs in forward => 16
inputs in forward => 16


  8%|▊         | 182/2352 [00:19<03:42,  9.76it/s]

inputs in forward => 16
inputs in forward => 16


  8%|▊         | 184/2352 [00:19<03:45,  9.63it/s]

inputs in forward => 16
inputs in forward => 16


  8%|▊         | 186/2352 [00:19<03:42,  9.75it/s]

inputs in forward => 16
inputs in forward => 16


  8%|▊         | 189/2352 [00:19<03:41,  9.77it/s]

inputs in forward => 16
inputs in forward => 16


  8%|▊         | 190/2352 [00:19<03:42,  9.71it/s]

inputs in forward => 16
inputs in forward => 16


  8%|▊         | 192/2352 [00:20<03:41,  9.73it/s]

inputs in forward => 16
inputs in forward => 16


  8%|▊         | 195/2352 [00:20<03:39,  9.83it/s]

inputs in forward => 16
inputs in forward => 16


  8%|▊         | 197/2352 [00:20<03:39,  9.82it/s]

inputs in forward => 16
inputs in forward => 16


  8%|▊         | 199/2352 [00:20<03:39,  9.83it/s]

inputs in forward => 16
inputs in forward => 16


  9%|▊         | 201/2352 [00:20<03:38,  9.85it/s]

inputs in forward => 16
inputs in forward => 16


  9%|▊         | 202/2352 [00:21<03:43,  9.62it/s]

inputs in forward => 16
inputs in forward => 16


  9%|▊         | 204/2352 [00:21<03:42,  9.64it/s]

inputs in forward => 16
inputs in forward => 16


  9%|▉         | 206/2352 [00:21<03:42,  9.64it/s]

inputs in forward => 16
inputs in forward => 16


  9%|▉         | 208/2352 [00:21<03:41,  9.67it/s]

inputs in forward => 16
inputs in forward => 16


  9%|▉         | 210/2352 [00:21<03:40,  9.69it/s]

inputs in forward => 16
inputs in forward => 16


  9%|▉         | 212/2352 [00:22<03:38,  9.79it/s]

inputs in forward => 16
inputs in forward => 16


  9%|▉         | 214/2352 [00:22<03:38,  9.80it/s]

inputs in forward => 16
inputs in forward => 16


  9%|▉         | 216/2352 [00:22<03:41,  9.63it/s]

inputs in forward => 16
inputs in forward => 16


  9%|▉         | 218/2352 [00:22<03:41,  9.62it/s]

inputs in forward => 16
inputs in forward => 16


  9%|▉         | 220/2352 [00:22<03:39,  9.70it/s]

inputs in forward => 16
inputs in forward => 16


  9%|▉         | 222/2352 [00:23<03:39,  9.68it/s]

inputs in forward => 16
inputs in forward => 16


 10%|▉         | 225/2352 [00:23<03:38,  9.75it/s]

inputs in forward => 16
inputs in forward => 16


 10%|▉         | 226/2352 [00:23<03:38,  9.74it/s]

inputs in forward => 16
inputs in forward => 16


 10%|▉         | 228/2352 [00:23<03:38,  9.74it/s]

inputs in forward => 16
inputs in forward => 16


 10%|▉         | 230/2352 [00:23<03:37,  9.76it/s]

inputs in forward => 16
inputs in forward => 16


 10%|▉         | 232/2352 [00:24<03:37,  9.77it/s]

inputs in forward => 16
inputs in forward => 16


 10%|▉         | 234/2352 [00:24<03:38,  9.69it/s]

inputs in forward => 16
inputs in forward => 16


 10%|█         | 237/2352 [00:24<03:35,  9.79it/s]

inputs in forward => 16
inputs in forward => 16


 10%|█         | 239/2352 [00:24<03:35,  9.80it/s]

inputs in forward => 16
inputs in forward => 16


 10%|█         | 240/2352 [00:25<03:38,  9.67it/s]

inputs in forward => 16
inputs in forward => 16


 10%|█         | 242/2352 [00:25<03:43,  9.45it/s]

inputs in forward => 16
inputs in forward => 16


 10%|█         | 244/2352 [00:25<03:46,  9.29it/s]

inputs in forward => 16
inputs in forward => 16


 10%|█         | 246/2352 [00:25<03:48,  9.22it/s]

inputs in forward => 16
inputs in forward => 16


 11%|█         | 248/2352 [00:25<03:51,  9.09it/s]

inputs in forward => 16
inputs in forward => 16


 11%|█         | 250/2352 [00:26<03:51,  9.08it/s]

inputs in forward => 16
inputs in forward => 16


 11%|█         | 252/2352 [00:26<03:50,  9.10it/s]

inputs in forward => 16
inputs in forward => 16


 11%|█         | 254/2352 [00:26<03:50,  9.11it/s]

inputs in forward => 16
inputs in forward => 16


 11%|█         | 256/2352 [00:26<03:52,  9.02it/s]

inputs in forward => 16
inputs in forward => 16


 11%|█         | 258/2352 [00:26<03:50,  9.09it/s]

inputs in forward => 16
inputs in forward => 16


 11%|█         | 260/2352 [00:27<03:51,  9.03it/s]

inputs in forward => 16
inputs in forward => 16


 11%|█         | 262/2352 [00:27<03:49,  9.11it/s]

inputs in forward => 16
inputs in forward => 16


 11%|█         | 264/2352 [00:27<03:48,  9.13it/s]

inputs in forward => 16
inputs in forward => 16


 11%|█▏        | 266/2352 [00:27<03:47,  9.15it/s]

inputs in forward => 16
inputs in forward => 16


 11%|█▏        | 268/2352 [00:28<03:47,  9.16it/s]

inputs in forward => 16
inputs in forward => 16


 11%|█▏        | 270/2352 [00:28<03:48,  9.12it/s]

inputs in forward => 16
inputs in forward => 16


 12%|█▏        | 272/2352 [00:28<03:48,  9.10it/s]

inputs in forward => 16
inputs in forward => 16


 12%|█▏        | 274/2352 [00:28<03:47,  9.12it/s]

inputs in forward => 16
inputs in forward => 16


 12%|█▏        | 276/2352 [00:28<03:47,  9.11it/s]

inputs in forward => 16
inputs in forward => 16


 12%|█▏        | 278/2352 [00:29<03:46,  9.14it/s]

inputs in forward => 16
inputs in forward => 16


 12%|█▏        | 280/2352 [00:29<03:46,  9.16it/s]

inputs in forward => 16
inputs in forward => 16


 12%|█▏        | 282/2352 [00:29<03:46,  9.12it/s]

inputs in forward => 16
inputs in forward => 16


 12%|█▏        | 283/2352 [00:29<03:46,  9.15it/s]

inputs in forward => 16


KeyboardInterrupt: 