In [1]:
# Reload every modified module (e.g. edits to the local pyner sources) automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:

import sys
import os

# Make the local pyner checkout importable. It is prepended to sys.path, so it
# takes precedence over any installed copy. Set the PYNER_SRC environment
# variable to point at another checkout.
PYNER_SRC = os.environ.get("PYNER_SRC", "/home/ytaille/pyner/pyner")
sys.path.insert(0, PYNER_SRC)

from pyner import NER, Vocabulary, NER_MTL
from pyner._datasets import BRATDataset
import string
import torch
import pytorch_lightning as pl
from rich_logger import RichTableLogger

# Tasks that get a "<task>_label" vocabulary in the preprocessor below.
task_names = ['conll', 'mantra', 'n2c2', 'quaero_full', 'quaero_medline', 'quaero_emea', 'wikiann', 'cas_pos', 'ftb']

# Transformer checkpoint. It is used both by the preprocessor (tokenization) and
# as the BERT word encoder.
# Alternatives: "bert-base-german-cased", "emilyalsentzer/Bio_ClinicalBERT", "camembert-base".
bert_name = "bert-base-multilingual-cased"
# Tasks that get their own decoder head, in decoder order. Other tasks that
# already have a label vocabulary (see `task_names`) can be enabled by adding
# them here: "quaero_medline", "quaero_emea", "n2c2", "mantra", "wikiann",
# "conll", "cas_pos".
DECODER_TASKS = ["quaero_full", "ftb"]

# Word-encoder sizes. The decoder LSTM input size is derived from them.
CHAR_CNN_OUT_CHANNELS = 50
CHAR_CNN_KERNEL_SIZES = (3, 4, 5)
BERT_HIDDEN_SIZE = 768  # hidden size of BERT-base models such as bert-base-multilingual-cased
# Assumes char_cnn concatenates one `out_channels` feature map per kernel size
# (50 * 3 = 150, the value previously hard-coded). TODO: confirm in pyner.
WORD_EMBEDDING_SIZE = BERT_HIDDEN_SIZE + CHAR_CNN_OUT_CHANNELS * len(CHAR_CNN_KERNEL_SIZES)


def _biaffine_decoder_config(input_size):
    """Return a new config dict for one exhaustive-biaffine NER decoder head.

    Every call builds fresh dicts, so no two tasks share a config object.

    Args:
        input_size: size of the word representations fed to the decoder LSTM.

    Returns:
        The decoder config dict expected by NER_MTL's `decoders` argument.
    """
    return dict(
        module="exhaustive_biaffine_ner",
        dim=192,
        label_dim=64,
        n_labels=None,  # automatically inferred from data
        dropout_p=0.,
        use_batch_norm=False,
        contextualizer=dict(
            module="lstm",
            # use gate = False for better performance but slower convergence (needs ~50 epochs)
            gate=dict(
                module="sigmoid_gate",
                ln_mode="pre",
                init_value=0,
                proj=False,
                dim=192,
            ),
            input_size=input_size,
            hidden_size=192,
            num_layers=4,
            dropout_p=0.,
        ),
    )


model = NER_MTL(
    seed=42,
    preprocessor=dict(
        module="preprocessor",
        bert_name=bert_name,  # transformer name
        # regex used to split sentences (must not contain consuming patterns)
        sentence_split_regex=r"((?:\s*\n)+\s*|(?:(?<=[a-z0-9)]\.)\s+))(?=[A-Z-])",
        sentence_balance_chars=('()',),  # try to avoid splitting between parentheses
        # "raise" when an entity spans more than one sentence, or "split" to cut it in two
        sentence_entity_overlap="split",
        # regex used to extract words (aligned with bert tokens); None keeps wordpieces as is
        word_regex='[\\w\']+|[!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~]',
        # regex substitutions applied to sentences before tokenizing
        substitutions=(
            (r"(?<=[{}\\])(?![ ])".format(string.punctuation), r" "),  # insert a space after punctuation
            (r"(?<![ ])(?=[{}\\])".format(string.punctuation), r" "),  # insert a space before punctuation
            # ("(?<=[a-zA-Z])(?=[0-9])", r" "),  # insert a space between letters and numbers
            # ("(?<=[0-9])(?=[A-Za-z])", r" "),  # insert a space between numbers and letters
        ),
        # Maximum number of tokens in a sentence (longer ones are split).
        # Must be equal to or lower than the max number of tokens of the Bert model.
        max_tokens=512,
        large_sentences="equal-split",  # split large sentences into equal sub-sentences of < max_tokens tokens
        empty_entities="drop",  # when an entity cannot be mapped to any word: "raise" or "drop"
        # Vocabularies to use. Call .train() before initializing so they are
        # filled/completed automatically from the training data.
        vocabularies=torch.nn.ModuleDict({
            "char": Vocabulary(string.punctuation + string.ascii_letters + string.digits, with_unk=True, with_pad=True),
            **{
                f"{task_name}_label": Vocabulary(with_unk=True, with_pad=False)
                for task_name in task_names
            },
        }).train(),
    ),

    # Word encoder parameters
    word_encoders=[
        dict(
            module="char_cnn",
            n_chars=None,  # automatically inferred from data
            in_channels=8,
            out_channels=CHAR_CNN_OUT_CHANNELS,
            kernel_sizes=CHAR_CNN_KERNEL_SIZES,
        ),
        dict(
            module="bert",
            path=bert_name,
            n_layers=4,
            freeze_n_layers=0,  # unfreeze all
            dropout_p=0.1,
        ),
    ],

    # Decoder parameters: one head per task, all with the same architecture
    decoders={task: _biaffine_decoder_config(WORD_EMBEDDING_SIZE) for task in DECODER_TASKS},

    # Initialize last classifying layer bias with log frequencies from labels in data
    init_labels_bias=True,

    batch_size=8,

    # Use learning rate schedules (linearly decay with warmup);
    # warmup_rate is presumably only used when schedules are enabled
    use_lr_schedules=False,
    warmup_rate=0.1,

    gradient_clip_val=5.,

    # Learning rates
    main_lr=1.5e-3,
    top_lr=1.5e-3,
    bert_lr=4e-5,

    # Optimizer, can be class or str
    optimizer_cls="transformers.AdamW",

    share_contextualizers="hybrid",
).train()

def _rich_table_fields(task_names):
    """Build the RichTableLogger column spec.

    Columns, in order:
    - epoch and step;
    - the global train/val loss;
    - for every task and each split ("train", "val"): its loss, F1, precision
      and recall;
    - the three learning rates.

    Args:
        task_names: iterable of task (decoder) names.

    Returns:
        dict of column specs keyed by metric name.
    """
    def lower_is_better(name):
        return {"goal": "lower_is_better", "format": "{:.4f}", "name": name}

    def higher_is_better(name):
        return {"goal": "higher_is_better", "format": "{:.4f}", "name": name}

    fields = {
        "epoch": {},
        "step": {},
        "train_loss": lower_is_better("train_loss"),
        "val_loss": lower_is_better("val_loss"),
    }
    for task_name in task_names:
        for split in ("train", "val"):
            prefix = f"{task_name}_{split}"
            fields[f"{prefix}_loss"] = lower_is_better(f"{prefix}_loss")
            fields[f"{prefix}_f1"] = higher_is_better(f"{prefix}_f1")
            # precision/recall columns use short display names
            fields[f"{prefix}_precision"] = higher_is_better(f"{prefix}_p")
            fields[f"{prefix}_recall"] = higher_is_better(f"{prefix}_r")
    for lr_name in ("main_lr", "top_lr", "bert_lr"):
        fields[lr_name] = {"format": "{:.2e}"}
    return fields


trainer = pl.Trainer(
    gpus=1,
    progress_bar_refresh_rate=0,  # 0 disables Lightning's progress bar
    move_metrics_to_cpu=True,
    logger=[
        # pl.loggers.TestTubeLogger("path/to/logs", name="my_experiment"),
        RichTableLogger(key="epoch", fields=_rich_table_fields(model.decoders.keys())),
    ],
    max_epochs=50,
)


import os

# Root directory of all corpora. Set the PYNER_DATA_ROOT environment variable
# to use another location.
DATA_ROOT = os.environ.get("PYNER_DATA_ROOT", "/home/ytaille/data/resources")

QUAERO_DIR = f"{DATA_ROOT}/quaero/corpus"
MANTRA_FR_DIR = f"{DATA_ROOT}/mantra/Mantra-GSC_new_ann/French"
MANTRA_FR_EMEA = f"{MANTRA_FR_DIR}/EMEA_ec22-cui-best_man"
MANTRA_FR_MEDLINE = f"{MANTRA_FR_DIR}/Medline_EN_FR_ec22-cui-best_man"

# Passed to every BRATDataset below. Presumably the entity labels to discard
# (e.g. ['ACTI', 'CONC', 'GEOG']); an empty list keeps everything. TODO confirm.
dropped_entity_label = []

# BRATDataset arguments used below:
# - train: list of BRAT directories;
# - test: a directory, or None for training only;
# - val: explicit validation directory (or list of directories).
# NOTE(review): earlier inline comments said "first 20% doc will be for
# validation" and "seed=42  # don't shuffle before splitting". With explicit
# `val` directories these likely don't apply; confirm how BRATDataset uses
# `seed`.

# conll_dataset = BRATDataset(
#     train=[f"{DATA_ROOT}/conll/brat/eng/train"],
#     test=f"{DATA_ROOT}/conll/brat/eng/test",
#     val=f"{DATA_ROOT}/conll/brat/eng/dev",
#     seed=42,
#     dropped_entity_label=dropped_entity_label,
# )

# NOTE(review): n2c2 validates on its test set, so model selection sees test data.
# n2c2_dataset = BRATDataset(
#     train=[f"{DATA_ROOT}/n2c2/brat_files/train"],
#     test=f"{DATA_ROOT}/n2c2/brat_files/test",
#     val=f"{DATA_ROOT}/n2c2/brat_files/test",
#     seed=42,
#     dropped_entity_label=dropped_entity_label,
# )

ftb_dataset = BRATDataset(
    train=[f"{DATA_ROOT}/french_treebank/ftb6/brat/train"],
    test=f"{DATA_ROOT}/french_treebank/ftb6/brat/test",
    val=[f"{DATA_ROOT}/french_treebank/ftb6/brat/dev"],
    seed=42,
    dropped_entity_label=dropped_entity_label,
)

# The QUAERO datasets are all validated on the French MANTRA corpus (MEDLINE + EMEA).
# NOTE(review): quaero_full trains on MEDLINE + EMEA but is tested on MEDLINE only; confirm intended.
quaero_full_dataset = BRATDataset(
    train=[f"{QUAERO_DIR}/train/MEDLINE", f"{QUAERO_DIR}/train/EMEA"],
    test=f"{QUAERO_DIR}/test/MEDLINE",
    val=[MANTRA_FR_MEDLINE, MANTRA_FR_EMEA],
    seed=42,
    dropped_entity_label=dropped_entity_label,
)

quaero_medline_dataset = BRATDataset(
    train=[f"{QUAERO_DIR}/train/MEDLINE"],
    test=f"{QUAERO_DIR}/test/MEDLINE",
    val=[MANTRA_FR_MEDLINE, MANTRA_FR_EMEA],
    seed=42,
    dropped_entity_label=dropped_entity_label,
)

quaero_emea_dataset = BRATDataset(
    train=[f"{QUAERO_DIR}/train/EMEA"],
    test=f"{QUAERO_DIR}/test/EMEA",
    val=[MANTRA_FR_MEDLINE, MANTRA_FR_EMEA],
    seed=42,
    dropped_entity_label=dropped_entity_label,
)

# NOTE(review): mantra validates on its test directory, so model selection sees test data.
mantra_dataset = BRATDataset(
    train=[MANTRA_FR_EMEA],
    test=MANTRA_FR_MEDLINE,
    val=MANTRA_FR_MEDLINE,
    seed=42,
    dropped_entity_label=dropped_entity_label,
)

# wikiann_dataset = BRATDataset(
#     train=[f"{DATA_ROOT}/wikiann/fr/train"],
#     test=f"{DATA_ROOT}/wikiann/fr/test",
#     val=f"{DATA_ROOT}/wikiann/fr/val",
#     seed=42,
#     dropped_entity_label=dropped_entity_label,
# )

# WARNING: train, val and test are the SAME directory here, so val/test scores
# only measure fit on the training data, not generalization.
cas_pos_dataset = BRATDataset(
    train=[f"{DATA_ROOT}/corpus_dalloux/CAS_POS_brat"],
    test=f"{DATA_ROOT}/corpus_dalloux/CAS_POS_brat",
    val=f"{DATA_ROOT}/corpus_dalloux/CAS_POS_brat",
    seed=42,
    dropped_entity_label=dropped_entity_label,
)

# TODO: share the LSTM across tasks instead of keeping entirely separate
#       decoders (only the last layer would be task-specific).
# Fixed: the validation-loading problem was solved by using CombinedLoader.

# Datasets used for multi-task training, keyed by task name. The keys must
# match the decoder names in `model.decoders` (checked below).
MTL_data = {
    "quaero_full": quaero_full_dataset,
    "ftb": ftb_dataset,
    # "quaero_medline": quaero_medline_dataset,
    # "quaero_emea": quaero_emea_dataset,
    # "mantra": mantra_dataset,
    # "cas_pos": cas_pos_dataset,
    # "wikiann": wikiann_dataset,
    # "conll": conll_dataset,
    # "n2c2": n2c2_dataset,
}

# Fail fast, before a long training run, if the datasets and decoder heads disagree.
assert set(MTL_data.keys()) == set(model.decoders.keys()), (
    f"Datasets and decoders tasks don't match.\n Datasets:{MTL_data.keys()}\nDecoders:{model.decoders.keys()}"
)

import json
import os
import traceback
from datetime import datetime


def _save_experiment_logs(trainer, model, task_names, log_dir="/home/ytaille/pyner/logs/"):
    """Save the run's RichTableLogger table (HTML) and model config (JSON).

    Files are written as ``table.html`` and ``config.json`` in
    ``<log_dir>/<task names joined by "_">/<DD_MM_YYYY_HH_MM_SS>/``.

    Args:
        trainer: the Lightning trainer; its first logger must expose
            ``printer.console`` and ``printer.table`` (the RichTableLogger).
        model: the model whose config is exported with ``torch_utils.get_config``.
        task_names: iterable of task names used to name the experiment directory.
        log_dir: root directory for all experiment logs.

    Returns:
        The path of the run directory that was written to.
    """
    exp_sub_dir = os.path.join(
        log_dir,
        "_".join(task_names),
        datetime.now().strftime("%d_%m_%Y_%H_%M_%S"),
    )
    os.makedirs(exp_sub_dir, exist_ok=True)

    # Print the table inside capture() so it is not echoed to the notebook, then
    # export the console's recorded output. save_html needs a console created
    # with record=True, presumably set up by RichTableLogger.
    printer = trainer.logger[0].printer
    console = printer.console
    with console.capture():
        console.print(printer.table)
    console.save_html(os.path.join(exp_sub_dir, "table.html"))

    # Imported after the HTML export so the table is still saved if torch_utils is unavailable.
    from torch_utils import get_config

    with open(os.path.join(exp_sub_dir, "config.json"), "w") as fp:
        json.dump(get_config(model), fp)
    return exp_sub_dir


try:
    print("Starting model fit")
    print(f"Will fit for tasks: {', '.join(model.decoders.keys())}")
    trainer.fit(model, MTL_data)
except Exception:
    # Report the failure without re-raising, so partial results are still saved below.
    print("SOMETHING FAILED DURING FIT")
    traceback.print_exc()
finally:
    # This runs exactly once, whether fit succeeded, failed or was interrupted.
    # A failure while saving is therefore reported as itself, not as a fit failure.
    _save_experiment_logs(trainer, model, MTL_data.keys())

GPU available: True, used: True
TPU available: None, using: 0 TPU cores


Starting model fit
Will fit for tasks: quaero_full, ftb


Token indices sequence length is longer than the specified maximum sequence length for this model (5396 > 512). Running this sequence through the model will result in indexing errors
  return super(Tensor, self).rename(names)
Set SLURM handle signals.

  | Name          | Type         | Params
-----------------------------------------------
0 | preprocessor  | Preprocessor | 0     
1 | word_encoders | ModuleList   | 177 M 
2 | decoders      | ModuleDict   | 4.2 M 
-----------------------------------------------
182 M     Trainable params
0         Non-trainable params
182 M     Total params
728.320   Total estimated model params size (MB)


In [5]:
# Per-task RichTableLogger columns. For each decoder task and each split:
# loss (lower is better), then F1 / precision / recall (higher is better).
per_task_fields = {
    f"{task}_{split}_{metric}": {"goal": goal, "format": "{:.4f}", "name": f"{task}_{split}_{short}"}
    for task in model.decoders.keys()
    for split in ("train", "val")
    for metric, short, goal in (
        ("loss", "loss", "lower_is_better"),
        ("f1", "f1", "higher_is_better"),
        ("precision", "p", "higher_is_better"),
        ("recall", "r", "higher_is_better"),
    )
}

trainer = pl.Trainer(
    gpus=1,
    progress_bar_refresh_rate=False,
    move_metrics_to_cpu=True,
    logger=[
        # pl.loggers.TestTubeLogger("path/to/logs", name="my_experiment"),
        RichTableLogger(
            key="epoch",
            fields={
                "epoch": {},
                "step": {},
                "train_loss": {"goal": "lower_is_better", "format": "{:.4f}", "name": "train_loss"},
                "val_loss": {"goal": "lower_is_better", "format": "{:.4f}", "name": "val_loss"},
                **per_task_fields,
                "main_lr": {"format": "{:.2e}"},
                "top_lr": {"format": "{:.2e}"},
                "bert_lr": {"format": "{:.2e}"},
            },
        ),
    ],
    max_epochs=50,
)

# NOTE(review): this fits the same `model` object as the earlier cell. If that
# cell already trained it, training resumes from the current weights rather
# than from scratch.
trainer.fit(model, MTL_data)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
Set SLURM handle signals.

  | Name          | Type         | Params
-----------------------------------------------
0 | preprocessor  | Preprocessor | 0     
1 | word_encoders | ModuleList   | 177 M 
2 | decoders      | ModuleDict   | 2.7 M 
-----------------------------------------------
180 M     Trainable params
0         Non-trainable params
180 M     Total params
722.336   Total estimated model params size (MB)


1

In [None]:
# TODO (MTL): mask labels instead of using separate per-task heads
# TODO: align labels across languages (multilingual label alignment)
# TODO: large-scale training on all datasets and all tasks

In [2]:
import json
import os
import pathlib
from datetime import datetime

# Export this run's results: the RichTableLogger table as HTML and the model
# config as JSON, under <log_dir>/<task names joined by "_">/<timestamp>/.
log_dir = "/home/ytaille/pyner/logs/"
exp_dir = os.path.join(log_dir, "_".join(MTL_data.keys()))
run_timestamp = datetime.now().strftime("%d_%m_%Y_%H_%M_%S")
exp_sub_dir = os.path.join(exp_dir, run_timestamp)
pathlib.Path(exp_sub_dir).mkdir(parents=True, exist_ok=True)

table_html_file = os.path.join(exp_sub_dir, "table.html")
config_file = os.path.join(exp_sub_dir, "config.json")

# Print the table inside capture() so it is not echoed to the notebook, then
# export the console's recorded output as HTML. save_html needs a console
# created with record=True, presumably set up by RichTableLogger.
rich_printer = trainer.logger[0].printer
console, table = rich_printer.console, rich_printer.table
with console.capture() as capture:
    console.print(table)
console.save_html(table_html_file)

from torch_utils import get_config

with open(config_file, "w") as fp:
    json.dump(get_config(model), fp)

In [None]:
# KNOWN ISSUE: with val=None in BRATDataset -> error: preprocess has no attribute "chain"