In [1]:
from biome.text import Dataset, Pipeline, TrainerConfiguration
from biome.text.hpo import TuneExperiment
import itertools
import os
from ray import tune

In [2]:
#os.environ["WANDB_PROJECT"] = "profner"

In [3]:
!wandb offline

W&B offline, running your script from this directory will only write metadata locally.


In [4]:
train_ds = Dataset.from_json("../preprocessing_inference/train_v2.json")
valid_ds = Dataset.from_json("../preprocessing_inference/valid_v2.json")

Using custom data configuration default
Reusing dataset json (/home/david/.cache/huggingface/datasets/json/default-6489373448f25f56/0.0.0/70d89ed4db1394f028c651589fcab6d6b28dddcabbe39d3b21b4d41f9a708514)
Using custom data configuration default
Reusing dataset json (/home/david/.cache/huggingface/datasets/json/default-635e2cada6fc2ad1/0.0.0/70d89ed4db1394f028c651589fcab6d6b28dddcabbe39d3b21b4d41f9a708514)


In [5]:
train_ds.head()

Unnamed: 0,raw_text,tokens,tags_bioul,tags_bio,entity_text,classification_label,file_name
0,Cerramos nuestra querida Radio 😢 Nuestros cola...,"[Cerramos, nuestra, querida, Radio, 😢, Nuestro...","[O, O, O, O, O, O, U-PROFESION, O, U-PROFESION...","[O, O, O, O, O, O, B-PROFESION, O, B-PROFESION...","[colaboradores, conductores]",1,1242399976644325376.txt
1,#OtroEscandalo #HastaCuando \n#DenunciaCCOO #C...,"[#, OtroEscandalo, #, HastaCuando, æ, #, Denun...","[O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, ...","[O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, ...",[],0,1242406334802395137.txt
2,¿Es necesario entregar nuestra privacidad a un...,"[¿, Es, necesario, entregar, nuestra, privacid...","[O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, ...","[O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, ...",[],0,1242407077278093313.txt
3,Así que estás chimbeando mucho con esos Decret...,"[Así, que, estás, chimbeando, mucho, con, esos...","[O, O, O, O, O, O, O, O, O, O, O, U-PROFESION,...","[O, O, O, O, O, O, O, O, O, O, O, B-PROFESION,...",[Presidente],1,1242407274771030016.txt
4,@FeGarPe79 @escipion_r @LuciaMendezEM Estás MU...,"[@FeGarPe79, @escipion_r, @LuciaMendezEM, Está...","[O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, ...","[O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, ...",[],0,1242409866515435520.txt
5,La Generalitat facilitará las videconferencias...,"[La, Generalitat, facilitará, las, videconfere...","[O, O, O, O, O, O, O, O, O, O, O, O, O, O, O]","[O, O, O, O, O, O, O, O, O, O, O, O, O, O, O]",[],0,1242420050167988227.txt
6,“El pánico por coronavirus es injustificado” d...,"[“, El, pánico, por, coronavirus, es, injustif...","[O, O, O, O, O, O, O, O, O, O, U-PROFESION, O,...","[O, O, O, O, O, O, O, O, O, O, B-PROFESION, O,...",[virólogo],1,1242429168505233410.txt
7,La transparencia es necesaria para luchar cont...,"[La, transparencia, es, necesaria, para, lucha...","[O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, ...","[O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, ...",[],0,1242448823810654209.txt
8,Ojo con los mensajes que se están lanzando des...,"[Ojo, con, los, mensajes, que, se, están, lanz...","[O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, ...","[O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, ...",[],0,1242501570824200194.txt
9,¿Dispones de fundas de plástico cubreasientos ...,"[¿, Dispones, de, fundas, de, plástico, cubrea...","[O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, ...","[O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, ...",[],0,1242506209506312193.txt


In [6]:
train_ds.rename_column_("tags_bio", "tags")
valid_ds.rename_column_("tags_bio", "tags")
train_ds.rename_column_("classification_label", "labels")
valid_ds.rename_column_("classification_label", "labels")

# Transformer model

In [7]:
transformers_model: str = "dccuchile/bert-base-spanish-wwm-cased"
#transformers_model: str = "prajjwal1/bert-tiny"

profnert = {
    "name": "profnert",
    "features": {
        "transformers": {
            "model_name": transformers_model,
            "trainable": True,
        }
    },
    "head": {
        "type": "ProfNerT",
        "classification_labels": train_ds.unique("labels"),
        "classification_pooler": {
            "type": "bert_pooler",
            "pretrained_model": transformers_model,
            "requires_grad": True,
            "dropout": 0.1,
        },
        "ner_tags": list(set(itertools.chain.from_iterable(train_ds["tags"]))),
        "ner_tags_encoding": "BIO",
        "transformers_model": transformers_model,
        "dropout": 0.0,
    },
}

In [8]:
pipeline = Pipeline.from_config(profnert)

In [9]:
trainer = TrainerConfiguration(
    optimizer={
        "type": "adamw",
        "lr": tune.loguniform(1e-5, 1e-4),
        "weight_decay": tune.loguniform(5e-3, 5e-2)
    },
    linear_with_warmup=True,
    warmup_steps=tune.uniform(0, 200),
    training_size=len(train_ds),
    batch_size=tune.choice([4, 8, 16]),
    num_epochs=tune.choice([3, 4, 5]),
)

In [10]:
random_search = TuneExperiment(
    pipeline_config=profnert,
    trainer_config=trainer,
    train_dataset=train_ds,
    valid_dataset=valid_ds,
    name="profner",
    num_samples=1,
    local_dir="tune_runs",
    resources_per_trial={"cpu": 1, "gpu": 1},
)

In [None]:
analysis = tune.run(
    random_search,
    scheduler=tune.schedulers.ASHAScheduler(), 
    metric="validation_loss", 
    mode="min",
    progress_reporter=tune.JupyterNotebookReporter(overwrite=True)
)

In [11]:
trainer = TrainerConfiguration(
    optimizer={
        "type": "adamw",
        "lr": 5e-5,
    },
    batch_size=4,
    num_epochs=1,
    cuda_device=-1,
)

In [None]:
train_ds.select(range(10)).to_instances(pipeline, use_cache=False)

In [9]:
train_ds.cleanup_cache_files()
valid_ds.cleanup_cache_files()

0

In [None]:
pipeline.train(
    output="test",
    training=train_ds,
    validation=valid_ds,
    trainer=trainer,
)

HBox(children=(FloatProgress(value=0.0, description='Loading instances into memory', max=6000.0, style=Progres…

2021-02-12 01:24:05,960 - biome.text.dataset - INFO - Caching instances to /home/david/.cache/huggingface/datasets/json/default-6489373448f25f56/0.0.0/70d89ed4db1394f028c651589fcab6d6b28dddcabbe39d3b21b4d41f9a708514/ef52491df94380af.instance_list)





HBox(children=(FloatProgress(value=0.0, description='Loading instances into memory', max=2000.0, style=Progres…

2021-02-12 01:24:13,617 - biome.text.dataset - INFO - Caching instances to /home/david/.cache/huggingface/datasets/json/default-635e2cada6fc2ad1/0.0.0/70d89ed4db1394f028c651589fcab6d6b28dddcabbe39d3b21b4d41f9a708514/100d05899d9f80d2.instance_list)





2021-02-12 01:24:14,287 - allennlp.common.params - INFO - random_seed = 13370
2021-02-12 01:24:14,287 - allennlp.common.params - INFO - numpy_seed = 1337
2021-02-12 01:24:14,288 - allennlp.common.params - INFO - pytorch_seed = 133
2021-02-12 01:24:14,314 - allennlp.common.checks - INFO - Pytorch version: 1.7.1
2021-02-12 01:24:14,354 - allennlp.common.params - INFO - type = gradient_descent
2021-02-12 01:24:14,355 - allennlp.common.params - INFO - local_rank = 0
2021-02-12 01:24:14,355 - allennlp.common.params - INFO - patience = 2
2021-02-12 01:24:14,356 - allennlp.common.params - INFO - validation_metric = -loss
2021-02-12 01:24:14,356 - allennlp.common.params - INFO - num_epochs = 1
2021-02-12 01:24:14,357 - allennlp.common.params - INFO - cuda_device = -1
2021-02-12 01:24:14,358 - allennlp.common.params - INFO - grad_norm = None
2021-02-12 01:24:14,358 - allennlp.common.params - INFO - grad_clipping = None
2021-02-12 01:24:14,359 - allennlp.common.params - INFO - distributed = Fals

HBox(children=(FloatProgress(value=0.0, max=1500.0), HTML(value='')))

In [12]:
print(u'\u2066')

⁦
