In [1]:
import torch
from transformers import BertForTokenClassification, BertConfig, IntervalStrategy, TrainingArguments, Trainer
import datasets
from datasets.arrow_dataset import Dataset

import abct_comp_ner_utils.models.NER_with_root as nwr

# Where this run's checkpoints, logs and final model are written.
OUTPUT_PATH = "../../results_2022-12-10"
# Batch size for `Dataset.map` preprocessing/prediction (the training batch
# size is set separately in TrainingArguments).
BATCH_SIZE = 32

# Not referenced by the visible cells; kept in case nwr helpers or other
# cells rely on it — TODO confirm and drop if unused.
tokenizer = nwr.get_tokenizer()

In [2]:
# Load the annotated corpus from the Hugging Face Hub, pinned to a fixed
# commit so reruns see identical data. `use_auth_token=True` sends the locally
# stored Hub token (presumably the dataset is gated/private; newer `datasets`
# releases rename this argument to `token`).
DATASET_REPO = "abctreebank/comparative-NER-BCCWJ"
DATASET_REVISION = "de8fe2785391897efc898be36d695d8164863045"

dataset_raw = datasets.load_dataset(
    DATASET_REPO,
    revision = DATASET_REVISION,
    use_auth_token = True,
)

Using custom data configuration abctreebank--comparative-NER-BCCWJ-c32c3cdce4ba824a
Found cached dataset parquet (/home/owner/.cache/huggingface/datasets/abctreebank___parquet/abctreebank--comparative-NER-BCCWJ-c32c3cdce4ba824a/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


  0%|          | 0/2 [00:00<?, ?it/s]

## Training

In [4]:
# Convert the raw training annotations into model inputs/labels, dropping
# every original column so only the converter's outputs remain.
_train_raw: Dataset = dataset_raw["train"]
ds_train: Dataset = _train_raw.map(
    lambda batch: nwr.convert_annotation_entries_to_matrices(batch, return_type = "pt"),
    batched = True,
    batch_size = BATCH_SIZE,
    remove_columns = _train_raw.column_names,
)

  0%|          | 0/98 [00:00<?, ?ba/s]

In [5]:
# train/eval split: hold out 10% of the training data.
# A fixed seed makes the split — and therefore exactly which examples the
# model is trained on — reproducible across kernel restarts. Without it,
# every rerun draws a different split even though TrainingArguments is seeded.
SPLIT_SEED = 2630987289
ds_train_split = ds_train.train_test_split(
    test_size = 0.1,
    shuffle = True,
    seed = SPLIT_SEED,
)

In [6]:
# Single seed shared by weight initialisation and the Trainer.
SEED = 2630987289

config = BertConfig.from_pretrained(
    nwr.BERT_MODEL,
    id2label = nwr.ID2LABEL,
    label2id = nwr.LABEL2ID,
)

# `from_pretrained` randomly initialises the token-classification head (see
# the "not initialized from the model checkpoint" warning in the output).
# Seed torch *before* building the model: the Trainer applies
# `TrainingArguments.seed` only after this point, too late to affect the head.
torch.manual_seed(SEED)
model = BertForTokenClassification.from_pretrained(
    nwr.BERT_MODEL,
    config = config,
)

training_args = TrainingArguments(
    output_dir = OUTPUT_PATH,
    num_train_epochs = 27,
    per_device_train_batch_size = 64,
    per_device_eval_batch_size = 128,
    learning_rate = 5e-5,
    warmup_steps = 200,
    weight_decay = 0,
    save_strategy = IntervalStrategy.STEPS,
    save_steps = 1000,
    seed = SEED,
    logging_dir = f"{OUTPUT_PATH}/logs",
    logging_steps = 10,
)

# Pass the model instance directly. `model_init` is meant to build a *fresh*
# model; the previous `lambda: model` only returned this same instance.
trainer = Trainer(
    model = model,
    args = training_args,
    train_dataset = ds_train_split["train"],
    # Attach the held-out split so `trainer.evaluate()` can be called. No
    # evaluation strategy is configured, so the training loop is unchanged.
    eval_dataset = ds_train_split["test"],
)

Some weights of the model checkpoint at cl-tohoku/bert-base-japanese-whole-word-masking were not used when initializing BertForTokenClassification: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initialized from the m

In [7]:
# Fine-tune, then write the run's artefacts to `training_args.output_dir`
# (OUTPUT_PATH).
train_result = trainer.train()
# The summary metrics (runtime, throughput, mean train loss) are otherwise
# only printed in this cell's output; persist them next to the model too.
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
trainer.save_model()

The following columns in the training set don't have a corresponding argument in `BertForTokenClassification.forward` and have been ignored: token_subwords. If token_subwords are not expected by `BertForTokenClassification.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 2808
  Num Epochs = 27
  Instantaneous batch size per device = 64
  Total train batch size (w. parallel, distributed & accumulation) = 64
  Gradient Accumulation steps = 1
  Total optimization steps = 1188
  Number of trainable parameters = 110031366


  0%|          | 0/1188 [00:00<?, ?it/s]

{'loss': 1.3141, 'learning_rate': 2.5e-06, 'epoch': 0.23}
{'loss': 1.0868, 'learning_rate': 5e-06, 'epoch': 0.45}
{'loss': 0.6381, 'learning_rate': 7.5e-06, 'epoch': 0.68}
{'loss': 0.2293, 'learning_rate': 1e-05, 'epoch': 0.91}
{'loss': 0.1631, 'learning_rate': 1.25e-05, 'epoch': 1.14}
{'loss': 0.1291, 'learning_rate': 1.5e-05, 'epoch': 1.36}
{'loss': 0.1106, 'learning_rate': 1.75e-05, 'epoch': 1.59}
{'loss': 0.0945, 'learning_rate': 2e-05, 'epoch': 1.82}
{'loss': 0.0816, 'learning_rate': 2.25e-05, 'epoch': 2.05}
{'loss': 0.0633, 'learning_rate': 2.5e-05, 'epoch': 2.27}
{'loss': 0.0561, 'learning_rate': 2.7500000000000004e-05, 'epoch': 2.5}
{'loss': 0.046, 'learning_rate': 3e-05, 'epoch': 2.73}
{'loss': 0.0493, 'learning_rate': 3.2500000000000004e-05, 'epoch': 2.95}
{'loss': 0.0402, 'learning_rate': 3.5e-05, 'epoch': 3.18}
{'loss': 0.0372, 'learning_rate': 3.7500000000000003e-05, 'epoch': 3.41}
{'loss': 0.0424, 'learning_rate': 4e-05, 'epoch': 3.64}
{'loss': 0.0397, 'learning_rate': 4.

Saving model checkpoint to ../../results_2022-12-10/checkpoint-1000
Configuration saved in ../../results_2022-12-10/checkpoint-1000/config.json


{'loss': 0.0026, 'learning_rate': 9.51417004048583e-06, 'epoch': 22.73}


Model weights saved in ../../results_2022-12-10/checkpoint-1000/pytorch_model.bin


{'loss': 0.0021, 'learning_rate': 9.008097165991904e-06, 'epoch': 22.95}
{'loss': 0.002, 'learning_rate': 8.502024291497976e-06, 'epoch': 23.18}
{'loss': 0.002, 'learning_rate': 7.99595141700405e-06, 'epoch': 23.41}
{'loss': 0.0021, 'learning_rate': 7.489878542510122e-06, 'epoch': 23.64}
{'loss': 0.0019, 'learning_rate': 6.983805668016195e-06, 'epoch': 23.86}
{'loss': 0.0021, 'learning_rate': 6.4777327935222675e-06, 'epoch': 24.09}
{'loss': 0.0017, 'learning_rate': 5.971659919028341e-06, 'epoch': 24.32}
{'loss': 0.0019, 'learning_rate': 5.465587044534413e-06, 'epoch': 24.55}
{'loss': 0.0017, 'learning_rate': 4.9595141700404865e-06, 'epoch': 24.77}
{'loss': 0.0021, 'learning_rate': 4.453441295546559e-06, 'epoch': 25.0}
{'loss': 0.0019, 'learning_rate': 3.9473684210526315e-06, 'epoch': 25.23}
{'loss': 0.0015, 'learning_rate': 3.4412955465587043e-06, 'epoch': 25.45}
{'loss': 0.0016, 'learning_rate': 2.9352226720647772e-06, 'epoch': 25.68}
{'loss': 0.0019, 'learning_rate': 2.42914979757085



Training completed. Do not forget to share your model on huggingface.co/models =)


Saving model checkpoint to ../../results_2022-12-10
Configuration saved in ../../results_2022-12-10/config.json


{'train_runtime': 652.3249, 'train_samples_per_second': 116.224, 'train_steps_per_second': 1.821, 'train_loss': 0.04150153740045774, 'epoch': 27.0}


Model weights saved in ../../results_2022-12-10/pytorch_model.bin


## Evaluating

In [3]:
# Reload the final model that `trainer.save_model()` wrote to OUTPUT_PATH,
# then place it on the GPU for inference.
SAVED_PATH = OUTPUT_PATH
model = BertForTokenClassification.from_pretrained(SAVED_PATH)
model = model.to("cuda")

In [4]:
def _to_model_inputs(batch):
    """Convert a batch of raw test annotations with the same converter used for training."""
    return nwr.convert_annotation_entries_to_matrices(batch, return_type = "pt")


# Unlike the training split, no `remove_columns` is given here, so the
# original annotation columns stay available to later cells.
ds_test = dataset_raw["test"].map(
    _to_model_inputs,
    batched = True,
    batch_size = BATCH_SIZE,
)

  0%|          | 0/11 [00:00<?, ?ba/s]

In [5]:
def _predict(
    examples: "datasets.arrow_dataset.Batch"
):
    """Run the token classifier on one batch and attach the predicted label ids.

    Args:
        examples: a batched ``Dataset.map`` input with ``input_ids``,
            ``attention_mask`` and ``token_type_ids`` columns.

    Returns:
        The same ``examples`` object, mutated to carry a new
        ``"label_ids_predicted"`` entry: the per-token argmax over the logits,
        as a NumPy array.
    """
    # Inference only: make sure dropout is disabled (in case the global
    # `model` is still in training mode, e.g. reused straight from training),
    # and skip building an autograd graph, which would only waste GPU memory.
    model.eval()
    device = model.device
    with torch.no_grad():
        predictions_raw = model(
            input_ids = torch.tensor(examples["input_ids"], device = device),
            attention_mask = torch.tensor(examples["attention_mask"], device = device),
            token_type_ids = torch.tensor(examples["token_type_ids"], device = device),
            return_dict = True,
        )

    # Argmax over the last (label) dimension -> one label id per token.
    examples["label_ids_predicted"] = (
        predictions_raw.logits
        .argmax(dim = -1)
        .cpu()
        .numpy()
    )

    return examples
# === END ===

def _predict_and_annotate(batch):
    """Predict label ids for ``batch``, then convert both predicted and gold ids.

    The predicted ids (``label_ids_predicted``) are converted into
    ``comp_predicted``, and the gold ids (``label_ids``) into
    ``comp_subword_aligned``. Both conversions use
    ``nwr.convert_predictions_to_annotations``.
    """
    with_predictions = _predict(batch)
    with_predicted_comp = nwr.convert_predictions_to_annotations(
        with_predictions,
        label_ids_key = "label_ids_predicted",
        comp_key = "comp_predicted",
    )
    return nwr.convert_predictions_to_annotations(
        with_predicted_comp,
        label_ids_key = "label_ids",
        comp_key = "comp_subword_aligned",
    )


ds_test = ds_test.map(
    _predict_and_annotate,
    batched = True,
    batch_size = BATCH_SIZE,
)

  0%|          | 0/11 [00:00<?, ?ba/s]

In [6]:
# Metric object from the project utilities; the next cell feeds it the
# predicted and gold label ids of the test split and computes the scores.
metric = nwr.NERWithRootMetrics()

In [8]:
# Score the whole test split in one pass: predicted vs. gold label ids.
predicted_label_ids = ds_test["label_ids_predicted"]
gold_label_ids = ds_test["label_ids"]
metric.add_batch(predictions = predicted_label_ids, references = gold_label_ids)
metric_result = metric.compute()

In [24]:
import json

# Persist the evaluation scores. `ensure_ascii=False` writes any non-ASCII
# characters verbatim, so pin the file encoding to UTF-8. Otherwise `open`
# falls back to the platform's locale encoding, which can raise
# UnicodeEncodeError or produce a file that other tools misread.
with open("result.json", "w", encoding = "utf-8") as f:
    json.dump(
        metric_result,
        f,
        ensure_ascii = False,
    )