In [1]:
import json
from pathlib import Path
import itertools as it

import flair
import torch
from flair.data import Corpus
from flair.embeddings import TransformerWordEmbeddings
from flair.optim import LinearSchedulerWithWarmup
from hyperopt import hp
from torch.optim import AdamW
from transformers import AdamW

from meddocan.data.corpus import MEDDOCAN
from meddocan.hyperparameter.param_selection import (
    OptimizationValue,
    Parameter,
    SearchSpace,
    SequenceTaggerParamSelector,
)
from meddocan.hyperparameter.parameter import Parameter

# Directory handed to the parameter selector as its base path and used as the
# parent of the TensorBoard log directory.
# `__file__` is not defined in a plain Jupyter/IPython kernel, so fall back to
# the current working directory instead of failing with NameError.
base_path = Path(__file__).parent if "__file__" in globals() else Path.cwd()

# Seed passed to flair and embedded in the embedding names built in `training`.
SEED = 1

flair.set_seed(SEED)
# All flair computation is pinned to the first CUDA device.
flair.device = torch.device("cuda:0")

def training(sentences, window):
    """Run one parameter-selection pass on the MEDDOCAN corpus.

    Builds the corpus, completes the NER label dictionary with every label
    reported by the corpus statistics, defines a search space in which each
    parameter has exactly one candidate value, and runs
    ``SequenceTaggerParamSelector.optimize`` on a down-sampled corpus.

    Args:
        sentences: Forwarded to ``MEDDOCAN(sentences=...)``. When truthy, the
            embedding name does not include the window size.
        window: Forwarded to ``MEDDOCAN(window=...)``. Included in the
            embedding name when ``sentences`` is falsy.

    Returns:
        None. ``base_path`` is given to the selector as its base path and as
        the parent of the TensorBoard log directory.
    """
    # 1. Get the corpus.
    corpus: Corpus = MEDDOCAN(
        sentences=sentences,
        window=window,
        document_separator_token="-DOCSTART-",
    )
    print(corpus)

    # 2. Name the embeddings after the context configuration and the seed.
    if sentences:
        emb_name = f"beto-cased-context_FT_True_Ly_-1_seed_{SEED}"
    else:
        emb_name = f"beto-cased-context_window_{window}_FT_True_Ly_-1_seed_{SEED}"

    # 3. Make the label dictionary from the corpus, then add every label that
    #    appears in the statistics of any reported dataset so none is missing.
    label_dict = corpus.make_label_dictionary(label_type="ner")

    statistics = corpus.obtain_statistics("ner", pretty_print=False)
    for split_stats in statistics.values():
        for label in split_stats["number_of_documents_per_class"]:
            label_dict.add_item(label)

    print(json.dumps(statistics, indent=4))

    # 4. Define the search space. Every parameter has a single option, so the
    #    "search" evaluates exactly one configuration.
    search_space = SearchSpace()
    search_space.add(
        Parameter.EMBEDDINGS,
        hp.choice,
        options=[
            TransformerWordEmbeddings(
                model="dccuchile/bert-base-spanish-wwm-cased",
                fine_tune=True,
                layers="-1",
                use_context=64,
                layer_mean=True,
                name=emb_name,
                subtoken_pooling="first",
                allow_long_sentences=True,
            ),
        ],
    )
    search_space.add(Parameter.USE_CRF, hp.choice, options=[False])
    search_space.add(Parameter.USE_RNN, hp.choice, options=[False])
    search_space.add(Parameter.REPROJECT_EMBEDDINGS, hp.choice, options=[False])
    search_space.add(Parameter.NUM_WORKERS, hp.choice, options=[4])
    search_space.add(Parameter.DROPOUT, hp.choice, options=[0])
    search_space.add(Parameter.LEARNING_RATE, hp.choice, options=[5e-6])
    search_space.add(Parameter.MINI_BATCH_SIZE, hp.choice, options=[4])
    search_space.add(Parameter.ANNEAL_WITH_RESTARTS, hp.choice, options=[False])
    # NOTE(review): this file imports `AdamW` from both `torch.optim` and
    # `transformers`; the later import wins, so this is the `transformers`
    # class. Confirm that this is the intended optimizer.
    search_space.add(Parameter.OPTIMIZER, hp.choice, options=[AdamW])
    search_space.add(
        Parameter.SCHEDULER, hp.choice, options=[LinearSchedulerWithWarmup]
    )
    search_space.add(Parameter.WARMUP_FRACTION, hp.choice, options=[0.1])
    search_space.add(Parameter.EMBEDDINGS_STORAGE_MODE, hp.choice, options=["gpu"])
    search_space.add(Parameter.MAX_EPOCHS, hp.choice, options=[10])

    # 5. Create the parameter selector on a down-sampled corpus (factor 0.1).
    param_selector = SequenceTaggerParamSelector(
        corpus.downsample(0.1),
        "ner",
        base_path,
        training_runs=1,
        optimization_value=OptimizationValue.DEV_SCORE,
        tensorboard_logdir=base_path / "logs",
        save_model=True,
    )
    # Replace the selector's tag dictionary with the completed one built in
    # step 3 (made before the corpus was down-sampled).
    param_selector.tag_dictionary = label_dict

    # 6. Start the optimization.
    param_selector.optimize(search_space, max_evals=1)

# Sweep over the context-window sizes. The window-based configuration is
# trained for every window size; the sentence-based configuration is trained
# only once, the first time it is encountered.
sentences_yet = False

for window in (200, 100, 80, 60, 40, 20):
    for sentences in (False, True):
        print(f"sentences is {sentences}")
        if not sentences:
            training(sentences, window)
        elif not sentences_yet:
            print(sentences, window)
            training(sentences, window)
            sentences_yet = True
        # Release cached GPU memory before the next configuration.
        torch.cuda.empty_cache()

sentences is False




2022-09-29 15:34:27,811 Reading data from /tmp/tmph5vsgm1y
2022-09-29 15:34:27,812 Train: /tmp/tmph5vsgm1y/train
2022-09-29 15:34:27,813 Dev: /tmp/tmph5vsgm1y/dev
2022-09-29 15:34:27,813 Test: /tmp/tmph5vsgm1y/test
Corpus: 10811 train + 5518 dev + 5405 test sentences
2022-09-29 15:34:35,139 Computing label dictionary. Progress:


10811it [00:00, 39841.58it/s]

2022-09-29 15:34:35,436 Dictionary created for label 'ner' with 22 values: TERRITORIO (seen 1875 times), FECHAS (seen 1231 times), EDAD_SUJETO_ASISTENCIA (seen 1035 times), NOMBRE_SUJETO_ASISTENCIA (seen 1009 times), NOMBRE_PERSONAL_SANITARIO (seen 1000 times), SEXO_SUJETO_ASISTENCIA (seen 925 times), CALLE (seen 862 times), PAIS (seen 713 times), ID_SUJETO_ASISTENCIA (seen 567 times), ID_TITULACION_PERSONAL_SANITARIO (seen 471 times), CORREO_ELECTRONICO (seen 469 times), ID_ASEGURAMIENTO (seen 391 times), HOSPITAL (seen 255 times), FAMILIARES_SUJETO_ASISTENCIA (seen 243 times), INSTITUCION (seen 98 times), ID_CONTACTO_ASISTENCIAL (seen 77 times), NUMERO_TELEFONO (seen 58 times), PROFESION (seen 24 times), NUMERO_FAX (seen 15 times), OTROS_SUJETO_ASISTENCIA (seen 9 times)





{
    "TRAIN": {
        "dataset": "TRAIN",
        "total_number_of_documents": 10811,
        "number_of_documents_per_class": {
            "NOMBRE_SUJETO_ASISTENCIA": 1009,
            "ID_SUJETO_ASISTENCIA": 567,
            "ID_ASEGURAMIENTO": 391,
            "CALLE": 862,
            "TERRITORIO": 1875,
            "FECHAS": 1231,
            "PAIS": 713,
            "EDAD_SUJETO_ASISTENCIA": 1035,
            "SEXO_SUJETO_ASISTENCIA": 925,
            "NOMBRE_PERSONAL_SANITARIO": 1000,
            "ID_TITULACION_PERSONAL_SANITARIO": 471,
            "CORREO_ELECTRONICO": 469,
            "HOSPITAL": 255,
            "FAMILIARES_SUJETO_ASISTENCIA": 243,
            "OTROS_SUJETO_ASISTENCIA": 9,
            "INSTITUCION": 98,
            "NUMERO_TELEFONO": 58,
            "ID_CONTACTO_ASISTENCIAL": 77,
            "NUMERO_FAX": 15,
            "CENTRO_SALUD": 6,
            "PROFESION": 24
        },
        "number_of_tokens_per_tag": {},
        "number_of_tokens": {
        

1081it [00:00, 37852.45it/s]

2022-09-29 15:34:39,340 Dictionary created for label 'ner' with 22 values: TERRITORIO (seen 194 times), FECHAS (seen 116 times), EDAD_SUJETO_ASISTENCIA (seen 114 times), NOMBRE_PERSONAL_SANITARIO (seen 109 times), NOMBRE_SUJETO_ASISTENCIA (seen 105 times), SEXO_SUJETO_ASISTENCIA (seen 98 times), CALLE (seen 83 times), PAIS (seen 68 times), ID_TITULACION_PERSONAL_SANITARIO (seen 57 times), ID_SUJETO_ASISTENCIA (seen 51 times), ID_ASEGURAMIENTO (seen 47 times), CORREO_ELECTRONICO (seen 46 times), HOSPITAL (seen 31 times), FAMILIARES_SUJETO_ASISTENCIA (seen 23 times), INSTITUCION (seen 7 times), PROFESION (seen 6 times), ID_CONTACTO_ASISTENCIAL (seen 5 times), NUMERO_TELEFONO (seen 3 times), OTROS_SUJETO_ASISTENCIA (seen 2 times), CENTRO_SALUD (seen 2 times)
  0%|          | 0/1 [00:00<?, ?trial/s, best loss=?]2022-09-29 15:34:39,354 ----------------------------------------------------------------------------------------------------
2022-09-29 15:34:39,355 Evaluation run: 1
2022-09-29 15:





2022-09-29 15:34:46,670 epoch 1 - iter 27/271 - loss 4.74793462 - samples/sec: 15.00 - lr: 0.000000
2022-09-29 15:34:53,365 epoch 1 - iter 54/271 - loss 4.70221438 - samples/sec: 16.14 - lr: 0.000001
2022-09-29 15:35:01,810 epoch 1 - iter 81/271 - loss 4.56019087 - samples/sec: 12.79 - lr: 0.000001
2022-09-29 15:35:08,581 epoch 1 - iter 108/271 - loss 4.44216144 - samples/sec: 15.96 - lr: 0.000002
2022-09-29 15:35:14,485 epoch 1 - iter 135/271 - loss 4.31120824 - samples/sec: 18.30 - lr: 0.000002
2022-09-29 15:35:20,759 epoch 1 - iter 162/271 - loss 4.06750852 - samples/sec: 17.23 - lr: 0.000003
2022-09-29 15:35:26,755 epoch 1 - iter 189/271 - loss 3.83177936 - samples/sec: 18.03 - lr: 0.000003
2022-09-29 15:35:33,458 epoch 1 - iter 216/271 - loss 3.45376651 - samples/sec: 16.12 - lr: 0.000004
2022-09-29 15:35:34,173 ----------------------------------------------------------------------------------------------------
2022-09-29 15:35:34,174 Exiting from training early.
2022-09-29 15:35:

  0%|          | 0/135 [00:00<?, ?it/s]
  1%|1         | 2/135 [00:00<00:13,  9.67it/s]
  2%|2         | 3/135 [00:00<00:13,  9.50it/s]
  3%|2         | 4/135 [00:00<00:14,  9.29it/s]
  4%|3         | 5/135 [00:00<00:14,  9.16it/s]
  5%|5         | 7/135 [00:00<00:13,  9.76it/s]
  6%|5         | 8/135 [00:00<00:12,  9.82it/s]
  7%|6         | 9/135 [00:01<00:17,  7.36it/s]
  7%|7         | 10/135 [00:01<00:16,  7.62it/s]
  8%|8         | 11/135 [00:01<00:15,  7.88it/s]
  9%|8         | 12/135 [00:01<00:16,  7.43it/s]
 10%|#         | 14/135 [00:01<00:15,  7.90it/s]
 11%|#1        | 15/135 [00:01<00:15,  7.93it/s]
 12%|#1        | 16/135 [00:01<00:14,  8.03it/s]
 13%|#2        | 17/135 [00:02<00:14,  8.42it/s]
 14%|#4        | 19/135 [00:02<00:14,  8.03it/s]
 15%|#4        | 20/135 [00:02<00:14,  8.11it/s]
 16%|#5        | 21/135 [00:02<00:13,  8.41it/s]
 17%|#7        | 23/135 [00:02<00:12,  9.20it/s]
 18%|#7        | 24/135 [00:02<00:11,  9.37it/s]
 19%|#8        | 25/135 [00:02<00:12