In [29]:
import itertools

import numpy as np
import torch
from transformers import BertForTokenClassification, BertConfig, IntervalStrategy, TrainingArguments, Trainer
import datasets
from datasets.arrow_dataset import Dataset

import ruamel.yaml

import abctk.obj.comparative as aoc

import abct_comp_ner_utils.models.NER_with_root as nwr

# Tokenizer provided by the project's NER-with-root helper module.
tokenizer = nwr.get_tokenizer()

# Batch size passed to the batched `Dataset.map(...)` calls in this notebook.
BATCH_SIZE = 16
# Directory used as the Trainer's `output_dir`; the evaluation section reloads the model from it.
OUTPUT_PATH = "../../results_2022-12-27"

In [30]:
# Pin both the dataset repository and the exact commit so that every run
# reads the same data.
DATASET_REPO = "abctreebank/comparative-NER-BCCWJ"
DATASET_REVISION = "18dcd7235a4ae43a3517b0545314c888a579995e"

# `use_auth_token = True` makes the loader authenticate with the locally
# stored Hugging Face credentials.
dataset_raw = datasets.load_dataset(
    DATASET_REPO,
    revision = DATASET_REVISION,
    use_auth_token = True,
)

Using custom data configuration default-60f4c3a656674579
Found cached dataset parquet (/home/twotrees12/.cache/huggingface/datasets/abctreebank___parquet/default-60f4c3a656674579/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


  0%|          | 0/2 [00:00<?, ?it/s]

## Training

In [3]:
# Keep a separate handle on the untouched training split, so that the
# conversion below always starts from the raw annotations.
ds_train_raw = dataset_raw["train"]

# Convert the annotation entries batch by batch. All original columns are
# removed, leaving only what the converter returns.
ds_train: Dataset = ds_train_raw.map(
    lambda batch: nwr.convert_annotation_entries_to_matrices(
        batch,
        return_type = "pt",
    ),
    remove_columns = ds_train_raw.column_names,
    batched = True,
    batch_size = BATCH_SIZE,
)

  0%|          | 0/195 [00:00<?, ?ba/s]

In [4]:
# train/eval split
# The shuffle is seeded so that the 90/10 split is identical on every run;
# without a seed, each execution would hold out a different 10 %.
# (The value matches the seed given to `TrainingArguments`.)
SPLIT_SEED = 2630987289

ds_train_split = ds_train.train_test_split(
    test_size = 0.1,
    shuffle = True,
    seed = SPLIT_SEED,
)

In [5]:
# One seed for both the weight initialisation and the Trainer.
TRAIN_SEED = 2630987289

# Base BERT configuration extended with the NER label inventory.
config = BertConfig.from_pretrained(
    nwr.BERT_MODEL,
    id2label = nwr.ID2LABEL,
    label2id = nwr.LABEL2ID,
)

# Some weights of the token-classification model are not contained in the
# pretrained checkpoint and are therefore randomly initialised (see the
# loading warning). Seed torch *before* instantiating the model so that this
# initialisation is reproducible.
torch.manual_seed(TRAIN_SEED)
model = BertForTokenClassification.from_pretrained(
    nwr.BERT_MODEL,
    config = config,
)

training_args = TrainingArguments(
    output_dir = OUTPUT_PATH,
    num_train_epochs = 27,
    per_device_train_batch_size = 16,
    per_device_eval_batch_size = 16,
    learning_rate = 5e-5,
    warmup_steps = 200,
    weight_decay = 0,
    # Write a checkpoint every 1000 optimisation steps.
    save_strategy = IntervalStrategy.STEPS,
    save_steps = 1000,
    seed = TRAIN_SEED,
    logging_dir = f"{OUTPUT_PATH}/logs",
    logging_steps = 10,
)

trainer = Trainer(
    # NOTE(review): `model_init` is documented as a factory that builds a
    # fresh model for each `train()` call. This lambda hands back the model
    # instantiated above instead, which keeps the module-level `model` name
    # bound to the trained weights for the cells that follow.
    model_init = lambda: model,
    args = training_args,
    train_dataset = ds_train_split["train"],
    # Register the held-out 10 % so that `trainer.evaluate()` can use it.
    # No evaluation strategy is configured, so training itself is unchanged.
    eval_dataset = ds_train_split["test"],
)

Some weights of the model checkpoint at cl-tohoku/bert-base-japanese-whole-word-masking were not used when initializing BertForTokenClassification: ['cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initialized from the m

In [6]:
# Run the fine-tuning loop configured in the previous cell.
trainer.train()
# Persist the trainer state and the final model; the evaluation section
# reloads the model from OUTPUT_PATH.
trainer.save_state()
trainer.save_model()

The following columns in the training set don't have a corresponding argument in `BertForTokenClassification.forward` and have been ignored: token_subwords. If token_subwords are not expected by `BertForTokenClassification.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 2803
  Num Epochs = 27
  Instantaneous batch size per device = 16
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 1
  Total optimization steps = 4752
  Number of trainable parameters = 110031366


Step,Training Loss
10,1.8957
20,1.6596
30,1.1309
40,0.4233
50,0.1751
60,0.1233
70,0.1146
80,0.1001
90,0.0856
100,0.075


Saving model checkpoint to ../../results_2022-12-27/checkpoint-1000
Configuration saved in ../../results_2022-12-27/checkpoint-1000/config.json
Model weights saved in ../../results_2022-12-27/checkpoint-1000/pytorch_model.bin
Saving model checkpoint to ../../results_2022-12-27/checkpoint-2000
Configuration saved in ../../results_2022-12-27/checkpoint-2000/config.json
Model weights saved in ../../results_2022-12-27/checkpoint-2000/pytorch_model.bin
Saving model checkpoint to ../../results_2022-12-27/checkpoint-3000
Configuration saved in ../../results_2022-12-27/checkpoint-3000/config.json
Model weights saved in ../../results_2022-12-27/checkpoint-3000/pytorch_model.bin
Saving model checkpoint to ../../results_2022-12-27/checkpoint-4000
Configuration saved in ../../results_2022-12-27/checkpoint-4000/config.json
Model weights saved in ../../results_2022-12-27/checkpoint-4000/pytorch_model.bin


Training completed. Do not forget to share your model on huggingface.co/models =)


Saving mod

In [19]:
# Record the name of the base checkpoint on the in-memory model.
# NOTE(review): presumably done so that an upload carries the upstream
# checkpoint name rather than a local path — confirm.
model.name_or_path = "cl-tohoku/bert-base-japanese-whole-word-masking"

# # To push:
# model.push_to_hub(
#     "abctreebank/comparative-NER-with-root",
#     private = True,
#     use_auth_token = True,
# )
# tokenizer.push_to_hub(
#     "abctreebank/comparative-NER-with-root",
#     private = True,
#     use_auth_token = True,
# )

## Evaluating

In [31]:
# The fine-tuned model is read back from the training output directory.
SAVED_PATH = OUTPUT_PATH

# Prefer the GPU, but fall back to the CPU so that loading does not crash on
# a machine without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# From here on the model is used for inference only, so it is put into
# evaluation mode explicitly (dropout disabled). The chain is kept inside the
# assignment so the cell does not display the module repr.
model = BertForTokenClassification.from_pretrained(
    SAVED_PATH,
).to(device).eval()

loading configuration file ../../results_2022-12-27/config.json
Model config BertConfig {
  "_name_or_path": "cl-tohoku/bert-base-japanese-whole-word-masking",
  "architectures": [
    "BertForTokenClassification"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "id2label": {
    "0": "O",
    "1": "deg",
    "2": "prej",
    "3": "cont",
    "4": "diff",
    "5": "root"
  },
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "label2id": {
    "IGNORE": -100,
    "O": 0,
    "cont": 3,
    "deg": 1,
    "diff": 4,
    "prej": 2,
    "root": 5
  },
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "tokenizer_class": "BertJapaneseTokenizer",
  "torch_dtype": "float32",
  "transformers_version": "4.25.1",
  "type_vocab_size": 

In [32]:
def _annotations_to_matrices(batch):
    """Run the project converter on one batch of annotation entries."""
    return nwr.convert_annotation_entries_to_matrices(
        batch,
        return_type = "pt",
    )


# No columns are removed here (unlike for the training split): the original
# test columns stay available next to the converter's output.
ds_test = dataset_raw["test"].map(
    _annotations_to_matrices,
    batched = True,
    batch_size = BATCH_SIZE,
)

Loading cached processed dataset at /home/twotrees12/.cache/huggingface/datasets/abctreebank___parquet/default-60f4c3a656674579/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-b2b5b7aabc6254b2.arrow


In [33]:
def _predict(
    examples: datasets.arrow_dataset.Batch
):
    """
    Run the module-level ``model`` on one batch and attach the predicted label IDs.

    Parameters
    ----------
    examples
        A batch providing the ``input_ids``, ``attention_mask`` and
        ``token_type_ids`` columns.

    Returns
    -------
    The same ``examples`` object, with the additional key
    ``label_ids_predicted`` holding the argmax of the logits over
    dimension 2 as a NumPy array.
    """
    # Build the inputs on whatever device the model currently lives on
    # instead of assuming CUDA.
    device = model.device

    # Inference only: no autograd graph is needed, which saves memory.
    with torch.no_grad():
        # Call the module itself (not `.forward`) so that registered hooks run.
        predictions_raw = model(
            input_ids = torch.tensor(examples["input_ids"], device = device),
            attention_mask = torch.tensor(examples["attention_mask"], device = device),
            token_type_ids = torch.tensor(examples["token_type_ids"], device = device),
            return_dict = True,
        )

    examples["label_ids_predicted"] = (
        predictions_raw.logits
        .argmax(dim = 2)
        .cpu()
        .numpy()
    )

    return examples
# === END ===

def _predict_and_annotate(batch):
    """
    Predict label IDs for a batch, then convert both the predicted and the
    reference label IDs into annotations.
    """
    with_label_ids = _predict(batch)

    # Predicted label IDs -> "comp_predicted".
    with_predicted = nwr.convert_predictions_to_annotations(
        with_label_ids,
        label_ids_key = "label_ids_predicted",
        comp_key = "comp_predicted",
    )

    # Reference label IDs -> "comp_subword_aligned".
    return nwr.convert_predictions_to_annotations(
        with_predicted,
        label_ids_key = "label_ids",
        comp_key = "comp_subword_aligned",
    )


ds_test = ds_test.map(
    _predict_and_annotate,
    batched = True,
    batch_size = BATCH_SIZE,
)

Loading cached processed dataset at /home/twotrees12/.cache/huggingface/datasets/abctreebank___parquet/default-60f4c3a656674579/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-e87a1d016301bd3d.arrow


In [34]:
# Score the predicted label IDs against the reference label IDs.
metric = nwr.NERWithRootMetrics()
metric.add_batch(
    predictions = ds_test["label_ids_predicted"],
    references = ds_test["label_ids"],
)
metric_result = metric.compute()

# For every example, collect the names of all judgements that are not
# CORRECT, looking at both alignment directions (prediction -> reference
# first, then reference -> prediction).
errors_per_example = []
for alignment in metric_result["alignments"]:
    judged_pairs = itertools.chain(
        alignment.map_pred_to_ref,
        alignment.map_ref_to_pred,
    )
    error_names = []
    for _, judgement in judged_pairs:
        if judgement != aoc.MatchSpanResult.CORRECT:
            error_names.append(aoc.MatchSpanResult(judgement).name)
    errors_per_example.append(error_names)

ds_test_with_alignments = ds_test.add_column("errors", errors_per_example)

In [35]:
def _linearize_comp(
    batch: datasets.arrow_dataset.Batch
) -> datasets.arrow_dataset.Batch:
    """
    Add linearized reference and prediction annotations to a batch.

    For each example, the subword tokens are cut off at the first ``[SEP]``
    or ``[PAD]`` token and passed to ``aoc.linearize_annotations`` together
    with the reference annotations (``comp_subword_aligned``) and the
    predicted annotations (``comp_predicted``).

    Parameters
    ----------
    batch
        A batch providing the ``token_subwords``, ``comp_subword_aligned``
        and ``comp_predicted`` columns.

    Returns
    -------
    The same ``batch`` object, with the additional keys ``reference_linear``
    and ``prediction_linear``.
    """
    ls_reference_linear = []
    ls_prediction_linear = []

    # Walk the three consumed columns in lockstep, one example at a time.
    for subwords, comp_reference, comp_predicted in zip(
        batch["token_subwords"],
        batch["comp_subword_aligned"],
        batch["comp_predicted"],
    ):
        # Keep only the tokens in front of the first [SEP] / [PAD].
        tokens = tuple(
            itertools.takewhile(
                lambda t: t not in ("[SEP]", "[PAD]"),
                subwords,
            )
        )

        ls_reference_linear.append(
            aoc.linearize_annotations(tokens, comp_reference)
        )
        ls_prediction_linear.append(
            aoc.linearize_annotations(tokens, comp_predicted)
        )

    batch["reference_linear"] = ls_reference_linear
    batch["prediction_linear"] = ls_prediction_linear

    return batch

# Attach the linearized reference / prediction to every example.
ds_test_with_alignments = ds_test_with_alignments.map(
    _linearize_comp,
    batched = True,
    batch_size = BATCH_SIZE,
)

# Only these columns go into the dump; every other column is dropped.
columns_to_keep = (
    "ID",
    "prediction_linear",
    "reference_linear",
    "alignments",
    "errors",
)
columns_to_drop = [
    name for name in ds_test_with_alignments.column_names
    if name not in columns_to_keep
]

ds_test_dump = ds_test_with_alignments.remove_columns(columns_to_drop)

Loading cached processed dataset at /home/twotrees12/.cache/huggingface/datasets/abctreebank___parquet/default-60f4c3a656674579/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-8acef3209db2b632.arrow


In [36]:
yaml = ruamel.yaml.YAML()
# Open the file as UTF-8 explicitly: the dump may contain non-ASCII corpus
# text, and the default encoding of `open` depends on the machine's locale.
with open("./result.yaml", "w", encoding = "utf-8") as f:
    yaml.dump(list(ds_test_dump), f)

In [37]:
# Replace every key of the span-wise score table with its string form.
scores_by_label = metric_result["scores_spanwise"]
metric_result["scores_spanwise"] = dict(
    (str(label), scores)
    for label, scores in scores_by_label.items()
)

In [38]:
# Display the averaged partial F1 (`F1_partial_average`) from the metric result.
metric_result["F1_partial_average"]

0.7942831376833167

In [39]:
# Display the averaged strict F1 (`F1_strict_average`) from the metric result.
metric_result["F1_strict_average"]


0.6900892737256631

In [40]:
# One stanza per span label: label, partial F1, strict F1, then a blank line.
for span_label, span_scores in metric_result["scores_spanwise"].items():
    print(
        span_label,
        span_scores["F1_partial"],
        span_scores["F1_strict"],
        "",
        sep = "\n",
    )

root
0.7026194144838213
0.49922958397534667

cont
0.714987714987715
0.5749385749385749

prej
0.8943661971830986
0.8485915492957746

diff
0.8295454545454544
0.7613636363636364

deg
0.8298969072164948
0.7663230240549826



In [18]:
yaml = ruamel.yaml.YAML()

# Everything except the per-example "alignments" goes into the score report.
scores_to_dump = {
    k : v for k, v in metric_result.items()
    if k != "alignments"
}

# Explicit UTF-8 so the output does not depend on the machine's locale
# (consistent with how the result dump should be written).
with open("scores.yaml", "w", encoding = "utf-8") as g:
    yaml.dump(
        scores_to_dump,
        stream = g,
    )