In [15]:
import json
import itertools

import numpy as np
import torch
from transformers import BertForTokenClassification, BertConfig, IntervalStrategy, TrainingArguments, Trainer
import datasets
from datasets.arrow_dataset import Dataset
import ruamel.yaml
yaml = ruamel.yaml.YAML()

import abctk.obj.comparative as aoc
import abct_comp_ner_utils.models.NER_with_root as nwr

In [2]:
# Fetch the comparative-NER corpus from the Hugging Face Hub at a fixed
# revision, so that re-runs read the same snapshot of the data.
# NOTE(review): `use_auth_token = True` presumably means the repository needs
# locally stored Hub credentials — confirm.
dataset_raw = datasets.load_dataset(
    path = "abctreebank/comparative-NER-BCCWJ",
    revision = "18dcd7235a4ae43a3517b0545314c888a579995e",
    use_auth_token = True,
)

Using custom data configuration default-60f4c3a656674579
Found cached dataset parquet (/home/twotrees12/.cache/huggingface/datasets/abctreebank___parquet/default-60f4c3a656674579/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


  0%|          | 0/2 [00:00<?, ?it/s]

## NER model 

In [3]:
# Settings shared by the cells below.
BATCH_SIZE = 16  # rows per batch in the batched `Dataset.map` calls
MODEL_PATH = "../results_2022-12-27/"  # location handed to `from_pretrained`

tokenizer = nwr.get_tokenizer()

# Load the token-classification model, then move it onto the GPU.
model = BertForTokenClassification.from_pretrained(MODEL_PATH)
model = model.cuda()

In [4]:
# Convert the annotated test split, batch by batch, with
# `nwr.convert_annotation_entries_to_matrices`.
# NOTE(review): the columns used later (`input_ids`, `attention_mask`,
# `token_type_ids`, `label_ids`, `token_subwords`) presumably originate from
# this conversion — confirm against the helper.
ds_test = dataset_raw["test"].map(
    lambda batch: nwr.convert_annotation_entries_to_matrices(
        batch,
        return_type = "pt",
    ),
    batch_size = BATCH_SIZE,
    batched = True,
)

Loading cached processed dataset at /home/twotrees12/.cache/huggingface/datasets/abctreebank___parquet/default-60f4c3a656674579/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-8a2e51971ad84ef9.arrow


In [5]:
def _predict(
    examples: datasets.arrow_dataset.Batch
):
    """
    Run the module-level `model` on one batch and record the predicted label IDs.

    Arguments
    ---------
    examples
        A batch providing `input_ids`, `attention_mask` and `token_type_ids`.
        Each is converted with `torch.tensor` and moved to the GPU.

    Returns
    -------
    The same `examples` object, with a new `label_ids_predicted_NER` entry:
    a NumPy array holding the argmax of the logits over dimension 2.
    """
    # Inference only: without `no_grad`, every batch would build an autograd
    # graph that is never used and only costs GPU memory.
    with torch.no_grad():
        # Call the module itself rather than `.forward` so that any
        # registered hooks are honoured.
        predictions_raw = model(
            input_ids = torch.tensor(examples["input_ids"]).cuda(),
            attention_mask = torch.tensor(examples["attention_mask"]).cuda(),
            token_type_ids  = torch.tensor(examples["token_type_ids"]).cuda(),
            return_dict = True,
        )

    # No `.detach()` needed: nothing computed under `no_grad` tracks gradients.
    examples["label_ids_predicted_NER"] = (
        predictions_raw.logits
        .argmax(dim = 2,)
        .cpu()
        .numpy()
    )

    return examples
# === END ===

def _predict_and_decode(batch):
    """
    Predict labels for a batch, then pass it through
    `nwr.convert_predictions_to_annotations` twice: first for the predicted
    label IDs (written to `comp_predicted_NER`), then for the reference label
    IDs (written to `comp_subword_aligned`).
    """
    with_predicted_spans = nwr.convert_predictions_to_annotations(
        _predict(batch),
        label_ids_key = "label_ids_predicted_NER",
        comp_key = "comp_predicted_NER",
    )
    return nwr.convert_predictions_to_annotations(
        with_predicted_spans,
        label_ids_key = "label_ids",
        comp_key = "comp_subword_aligned",
    )

ds_test = ds_test.map(
    _predict_and_decode,
    batch_size = BATCH_SIZE,
    batched = True,
)

# Score the predicted label IDs against the reference label IDs.
metric = nwr.NERWithRootMetrics()
metric.add_batch(
    references = ds_test["label_ids"],
    predictions = ds_test["label_ids_predicted_NER"],
)
metric_NER = metric.compute()

# Attach one alignment report per example, pairing predicted spans,
# reference spans and the alignment reported by the metric.
ds_test = ds_test.add_column(
    "matching_NER",
    [
        aoc.print_AlignResult(
            alignment = alignment,
            reference = reference_spans,
            prediction = predicted_spans,
        )
        for predicted_spans, reference_spans, alignment in zip(
            ds_test["comp_predicted_NER"],
            ds_test["comp_subword_aligned"],
            metric_NER["alignments"],
        )
    ]
)

Loading cached processed dataset at /home/twotrees12/.cache/huggingface/datasets/abctreebank___parquet/default-60f4c3a656674579/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-d9e3852ed0114d36.arrow


In [6]:
# Show the NER alignment report of the first test example. Indexing the row
# before the column avoids materialising the whole column to read one value.
ds_test[0]["matching_NER"]

['ref: (9-18)root \t ↔ pred: (9-18)root',
 'ref: (9-12)cont \t ↔ pred: (9-12)cont',
 'ref: (12-14)prej \t ↔ pred: (12-14)prej',
 'ref: (14-15)diff \t ↔ pred: (14-15)diff',
 'ref: (15-17)deg \t ↔ pred: (15-17)deg']

In [7]:
# For each label: its name, partial-match F1, strict-match F1, then a blank
# line. The two averages follow on a final line.
for span_label, span_scores in metric_NER["scores_spanwise"].items():
    print(
        span_label,
        span_scores["F1_partial"],
        span_scores["F1_strict"],
        "",
        sep = "\n",
    )

print(metric_NER["F1_partial_average"], metric_NER["F1_strict_average"])

root
0.7047913446676971
0.5007727975270478

cont
0.7167487684729065
0.5763546798029557

prej
0.8927943760984183
0.8471001757469245

diff
0.8390804597701149
0.7701149425287357

deg
0.8242320819112628
0.7610921501706485

0.79552940618408 0.6910869491552625


## Rule-based model built on GiNZA

In [8]:
# Load the rule-based predictions (one JSON object per line), keyed by record ID.
# The encoding is given explicitly: the default for `open` depends on the
# platform locale, which would make reading non-ASCII text machine-dependent.
with open("./test_spacy.jsonl", encoding = "utf-8") as f:
    ds_rule_predicted = {
        record["ID"]: record
        for record in (
            aoc.dice_CompRecord(**json.loads(line))
            for line in map(str.strip, f)
            if line  # skip blank lines
        )
    }

def _add_rulebased_prediction(
    entry, 
    preds: dict[str, aoc.CompRecord] = ds_rule_predicted
):
    """
    Add the diced reference and the rule-based prediction to a dataset entry.

    Arguments
    ---------
    entry
        A dataset row; its `ID`, `tokens` and `comp` fields are read.
    preds
        Rule-based predictions keyed by record ID.

    Returns
    -------
    The same `entry`, extended with `tokens_diced`, `comp_diced` and
    `comp_predicted_rulebased` (None when `preds` has no record for the ID).
    """
    def _normalize_label(raw_label):
        # Keep only the `name` group when the artifact pattern matches and
        # that group is non-empty; otherwise leave the label untouched.
        found = nwr._RE_FEAT_ARTIFACTS.match(raw_label)
        if found:
            return found.group("name") or raw_label
        return raw_label

    entry_id = entry["ID"]
    diced_record = aoc.dice_CompRecord(
        tokens = entry["tokens"],
        comp = entry["comp"],
        ID = entry_id,
    )

    entry["tokens_diced"] = diced_record["tokens"]
    entry["comp_diced"] = [
        {
            "start": span["start"],
            "end": span["end"],
            "label": _normalize_label(span["label"]),
        }
        for span in diced_record["comp"]
    ]

    if entry_id in preds:
        entry["comp_predicted_rulebased"] = preds[entry_id]["comp"]
    else:
        entry["comp_predicted_rulebased"] = None

    return entry

# Row by row, attach the diced reference and the rule-based prediction.
ds_test = ds_test.map(function = _add_rulebased_prediction)

  0%|          | 0/350 [00:00<?, ?ex/s]

In [10]:
# Read both columns once; they are used for scoring and for the reports.
predictions_rulebased = ds_test["comp_predicted_rulebased"]
references_diced = ds_test["comp_diced"]

# Score only the entries that actually have a rule-based prediction
# (entries without one carry None).
metric_rulebased = aoc.calc_prediction_metrics(
    predictions = (
        pred for pred in predictions_rulebased
        if pred is not None
    ),
    references = (
        ref for pred, ref in zip(predictions_rulebased, references_diced)
        if pred is not None
    )
)

# The alignments cover the scored entries only, so they must be consumed
# one-by-one for exactly those rows. Zipping them against every row would
# shift all alignments after the first unscored entry onto the wrong rows
# and leave the new column shorter than the dataset.
# Assumes the alignments come back in the order the predictions were fed in.
alignments_scored = iter(metric_rulebased["alignments"])

ds_test = ds_test.add_column(
    "matching_rulebased",
    [
        (
            aoc.print_AlignResult(
                prediction = pred,
                reference = ref,
                alignment = next(alignments_scored),
            )
            if pred is not None
            else None  # no prediction, hence no alignment report
        )
        for pred, ref in zip(predictions_rulebased, references_diced)
    ]
)

In [11]:
# For each label: its name, partial-match F1, strict-match F1, then a blank
# line. The two averages follow on a final line.
for span_label, span_scores in metric_rulebased["scores_spanwise"].items():
    print(
        span_label,
        span_scores["F1_partial"],
        span_scores["F1_strict"],
        "",
        sep = "\n",
    )

print(metric_rulebased["F1_partial_average"], metric_rulebased["F1_strict_average"])

prej
0.8441971383147854
0.8139904610492845

deg
0.6911076443057722
0.5522620904836194

cont
0.3959731543624161
0.1375838926174497

diff
0.483050847457627
0.3983050847457627

root
nan
nan

0.6035821961101502 0.47553538222402914


In [9]:
# Save the rule-based scores as YAML, leaving out the `alignments` entry.
with open("rulebased_scores.yaml", "w") as f:
    yaml.dump(
        dict(
            (key, value)
            for key, value in metric_rulebased.items()
            if key != "alignments"
        ),
        stream = f,
    )


In [16]:
# Write a per-example report: the linearized reference, both linearized
# predictions, and both alignment reports.
# The encoding is given explicitly: the default for `open` depends on the
# platform locale, and the report contains the corpus text itself.
# NOTE(review): the key "reference_linearlized" is misspelled but kept as is,
# because anything reading details.yaml may rely on it.
# NOTE(review): `comp_predicted_rulebased` can be None for entries without a
# rule-based prediction — confirm `aoc.linearize_annotations` accepts that.
with open("./details.yaml", "w", encoding = "utf-8") as g:
    yaml.dump(
        [
            {
                "ID": entry["ID"],
                "reference_linearlized": aoc.linearize_annotations(
                    tokens = entry["tokens"],
                    comp = entry["comp"]
                ),
                "predicted_NER_linearized": aoc.linearize_annotations(
                    tokens = entry["token_subwords"],
                    comp = entry["comp_predicted_NER"],
                ),
                "matching_NER": entry["matching_NER"],
                "predicted_rulebased_linearized": aoc.linearize_annotations(
                    tokens = entry["tokens_diced"],
                    comp = entry["comp_predicted_rulebased"],
                ),
                "matching_rulebased": entry["matching_rulebased"],
            }
            for entry in ds_test
        ],
        stream = g,
    )