In [1]:
import json
import itertools

import numpy as np
import pandas as pd
pd.options.display.float_format = '{: <10.2%}'.format

import torch
from transformers import BertForTokenClassification, BertConfig, IntervalStrategy, TrainingArguments, Trainer
import datasets
from datasets.arrow_dataset import Dataset
import ruamel.yaml
yaml = ruamel.yaml.YAML()

import abctk.obj.comparative as aoc
import abct_comp_ner_utils.models.NER_with_root as nwr

In [2]:
# Load a pinned revision of the comparative-NER BCCWJ dataset from the HF Hub
# (authenticated with the locally stored token). Only the test split is used here.
dataset_raw = datasets.load_dataset(
    path = "abctreebank/comparative-NER-BCCWJ",
    revision = "6c51f916ecd23c32e546a3d4f695c69d8c47e21e",
    use_auth_token = True,
)
ds_test = dataset_raw["test"]

Using custom data configuration default-9ccbc70a477221d0
Found cached dataset parquet (/home/twotrees12/.cache/huggingface/datasets/abctreebank___parquet/default-9ccbc70a477221d0/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


  0%|          | 0/2 [00:00<?, ?it/s]

## NER model (BERT token classification)

In [3]:
# Number of examples per batch in the batched Dataset.map calls below
# (matrix conversion and model inference).
BATCH_SIZE = 20
# Prefer the GPU, but fall back to the CPU so the notebook still runs without CUDA.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = nwr.get_tokenizer()
# Pinned model revision; `.eval()` makes inference mode (no dropout) explicit and,
# because it returns the module itself, keeps the cell from displaying the model repr.
model = BertForTokenClassification.from_pretrained(
    "abctreebank/comparative-NER-with-root",
    revision = "ed1b1834de445a5fc998839677d6c40872c8ad3c",
    use_auth_token = True,
).to(DEVICE).eval()

Downloading:   0%|          | 0.00/960 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/440M [00:00<?, ?B/s]

In [4]:
# Turn the raw annotation entries into model-ready matrices, BATCH_SIZE examples
# at a time. This starts again from the raw test split, so re-running is safe.
ds_test = dataset_raw["test"].map(
    lambda batch: nwr.convert_annotation_entries_to_matrices(batch, return_type = "pt"),
    batch_size = BATCH_SIZE,
    batched = True,
)

  0%|          | 0/18 [00:00<?, ?ba/s]

In [5]:
def _chomp(
    example: datasets.arrow_dataset.Example
):
    """Add ``comp_subword_aligned`` to a single example.

    The gold ``comp`` spans and the example's ``token_subwords`` are passed to
    ``aoc.chomp_CompRecord``; only the ``"comp"`` entry of its result is kept.
    Returns the mutated example, as ``Dataset.map`` expects.
    """
    # NB: `tokens_subworeded` is the keyword name aoc.chomp_CompRecord is called
    # with; its spelling belongs to that API and must not be "fixed" here.
    aligned_spans = aoc.chomp_CompRecord(
        comp = example["comp"],
        tokens_subworeded = example["token_subwords"],
    )["comp"]
    example["comp_subword_aligned"] = aligned_spans
    return example


ds_test = ds_test.map(_chomp)

  0%|          | 0/349 [00:00<?, ?ex/s]

In [6]:
def _predict(
    examples: datasets.arrow_dataset.Batch
):
    """Run the NER model on one batch and attach its predicted label ids.

    Reads ``input_ids``, ``attention_mask`` and ``token_type_ids`` from the batch
    and moves them onto the model's device. It then stores the argmax of the
    logits over dim 2 (the label axis of token-classification logits) as a NumPy
    array under ``examples["label_ids_predicted_NER"]``. Returns the mutated
    batch, as ``Dataset.map(batched=True)`` expects.
    """
    device = model.device
    # Pure inference: no_grad avoids building an autograd graph (and keeping its
    # activations alive) for every batch.
    with torch.no_grad():
        outputs = model(
            input_ids = torch.tensor(examples["input_ids"], device = device),
            attention_mask = torch.tensor(examples["attention_mask"], device = device),
            token_type_ids = torch.tensor(examples["token_type_ids"], device = device),
            return_dict = True,
        )

    examples["label_ids_predicted_NER"] = (
        outputs.logits
        .argmax(dim = 2)
        .cpu()
        .numpy()
    )
    return examples

# Predict label ids batch-wise, then decode them into span annotations stored
# under `comp_predicted_NER`.
ds_test = ds_test.map(
    lambda batch: nwr.convert_predictions_to_annotations(
        _predict(batch),
        comp_key = "comp_predicted_NER",
        label_ids_key = "label_ids_predicted_NER",
    ),
    batch_size = BATCH_SIZE,
    batched = True,
)

# Score the NER predictions against the subword-aligned references.
metric_NER = aoc.calc_prediction_metrics(
    predictions = ds_test["comp_predicted_NER"],
    references = ds_test["comp_subword_aligned"],
)

# Attach a human-readable alignment report per example (one alignment per
# example, in dataset order). A column left over from an earlier execution of
# this cell is dropped first, so re-running does not hit a duplicate column.
if "matching_NER" in ds_test.column_names:
    ds_test = ds_test.remove_columns("matching_NER")
ds_test = ds_test.add_column(
    "matching_NER",
    [
        aoc.print_AlignResult(
            prediction = pred,
            reference = ref,
            alignment = align,
        )
        for pred, ref, align in zip(
            ds_test["comp_predicted_NER"],
            ds_test["comp_subword_aligned"],
            metric_NER["alignments"],
        )
    ],
)

  0%|          | 0/18 [00:00<?, ?ba/s]

In [7]:
# Keep only the aggregate scores: drop the per-example alignments from a copy.
metric_NER_wo_align = metric_NER.copy()
metric_NER_wo_align.pop("alignments")

# One row per label; missing entries become 0 and the count columns are int32.
df_NER = (
    pd.DataFrame.from_dict(metric_NER_wo_align["scores_spanwise"], orient = "index")
    .fillna(0)
    .astype({col: "int32" for col in ("CORRECT", "SPURIOUS", "MISSING", "WRONG_SPAN")})
)

In [8]:
# NER results: raw span counts, then strict and partial precision / recall / F1.
df_NER.loc[
    ["prej", "cont", "deg", "diff", "root"],
    [
        "CORRECT", "SPURIOUS", "MISSING", "WRONG_SPAN",
        "precision_strict", "recall_strict", "F1_strict",
        "precision_partial", "recall_partial", "F1_partial",
    ],
]

Unnamed: 0,CORRECT,SPURIOUS,MISSING,WRONG_SPAN,precision_strict,recall_strict,F1_strict,precision_partial,recall_partial,F1_partial
prej,228,24,8,20,83.82%,89.06%,86.36%,87.50%,92.97%,90.15%
cont,114,51,27,41,55.34%,62.64%,58.76%,65.29%,73.90%,69.33%
deg,216,43,35,16,78.55%,80.90%,79.70%,81.45%,83.90%,82.66%
diff,63,14,8,12,70.79%,75.90%,73.26%,77.53%,83.13%,80.23%
root,135,83,3,115,40.54%,53.36%,46.08%,57.81%,76.09%,65.70%


## Rule-based model (GiNZA)

In [9]:
# Load the rule-based (spaCy/GiNZA) predictions into a dict keyed by example ID.
# Blank lines are skipped; every other line is one JSON record, diced with
# aoc.dice_CompRecord. If an ID repeats, the later record wins.
# Read explicitly as UTF-8: the records hold Japanese text, and the locale's
# default encoding is not guaranteed to be UTF-8.
with open("./predictions_SpaCy_2023-01-07.jsonl", encoding = "utf-8") as f:
    ds_rule_predicted = {
        record["ID"]: record
        for record in (
            aoc.dice_CompRecord(**json.loads(line))
            for line in map(str.strip, f)
            if line
        )
    }

def _add_rulebased_prediction(
    entry,
    preds: dict[str, aoc.CompRecord] = ds_rule_predicted
):
    """Attach the diced reference and the rule-based prediction to one example.

    Adds three fields to ``entry`` and returns it:

    * ``tokens_diced``: the tokens produced by ``aoc.dice_CompRecord``.
    * ``comp_diced``: the diced reference spans (start/end/label). Each label is
      cleaned with ``nwr._RE_FEAT_ARTIFACTS``: when the regex matches and its
      ``name`` group is non-empty, that group replaces the label.
    * ``comp_predicted_rulebased``: ``preds[ID]["comp"]``, or ``None`` when
      ``preds`` has no record for this example's ID.
    """
    entry_id = entry["ID"]
    diced = aoc.dice_CompRecord(
        tokens = entry["tokens"],
        comp = entry["comp"],
        ID = entry_id,
    )

    def _clean_label(raw_label):
        found = nwr._RE_FEAT_ARTIFACTS.match(raw_label)
        if not found:
            return raw_label
        return found.group("name") or raw_label

    entry["tokens_diced"] = diced["tokens"]
    entry["comp_diced"] = [
        {
            "start": span["start"],
            "end": span["end"],
            "label": _clean_label(span["label"]),
        }
        for span in diced["comp"]
    ]

    if entry_id in preds:
        entry["comp_predicted_rulebased"] = preds[entry_id]["comp"]
    else:
        entry["comp_predicted_rulebased"] = None

    return entry


ds_test = ds_test.map(_add_rulebased_prediction)

  0%|          | 0/349 [00:00<?, ?ex/s]

In [10]:
# Score the rule-based predictions against the diced references. Examples for
# which the rule-based system produced nothing (None) are excluded from scoring.
preds_rulebased = ds_test["comp_predicted_rulebased"]
refs_rulebased = ds_test["comp_diced"]
scored_pairs_rulebased = [
    (pred, ref)
    for pred, ref in zip(preds_rulebased, refs_rulebased)
    if pred is not None
]

metric_rulebased = aoc.calc_prediction_metrics(
    predictions = [pred for pred, _ in scored_pairs_rulebased],
    references = [ref for _, ref in scored_pairs_rulebased],
)

# `alignments` covers only the scored examples, so it has to be consumed in step
# with them. Zipping it against *every* example would shift reports onto the
# wrong rows, or produce a too-short column, as soon as one prediction is
# missing. Unscored examples get None.
alignments_rulebased = list(metric_rulebased["alignments"])
assert len(alignments_rulebased) == len(scored_pairs_rulebased), (
    "expected exactly one alignment per scored example"
)
alignment_iter_rulebased = iter(alignments_rulebased)
matching_rulebased = [
    aoc.print_AlignResult(
        prediction = pred,
        reference = ref,
        alignment = next(alignment_iter_rulebased),
    )
    if pred is not None else None
    for pred, ref in zip(preds_rulebased, refs_rulebased)
]

# Drop a column left over from an earlier execution so re-running this cell is safe.
if "matching_rulebased" in ds_test.column_names:
    ds_test = ds_test.remove_columns("matching_rulebased")
ds_test = ds_test.add_column("matching_rulebased", matching_rulebased)

In [11]:
# Tabulate the rule-based span-wise scores the same way as for the NER model:
# one row per label, missing entries become 0, count columns are int32.
df_rulebased = (
    pd.DataFrame.from_dict(metric_rulebased["scores_spanwise"], orient = "index")
    .fillna(0)
    .astype(dict.fromkeys(["CORRECT", "SPURIOUS", "MISSING", "WRONG_SPAN"], "int32"))
)

In [12]:
# Rule-based results: raw span counts, then strict and partial precision / recall / F1.
df_rulebased.loc[
    ["prej", "cont", "deg", "diff", "root"],
    [
        "CORRECT", "SPURIOUS", "MISSING", "WRONG_SPAN",
        "precision_strict", "recall_strict", "F1_strict",
        "precision_partial", "recall_partial", "F1_partial",
    ],
]

Unnamed: 0,CORRECT,SPURIOUS,MISSING,WRONG_SPAN,precision_strict,recall_strict,F1_strict,precision_partial,recall_partial,F1_partial
prej,239,89,4,13,70.09%,93.36%,80.07%,71.99%,95.90%,82.24%
cont,49,244,78,55,14.08%,26.92%,18.49%,21.98%,42.03%,28.87%
deg,172,150,85,10,51.81%,64.42%,57.43%,53.31%,66.29%,59.10%
diff,50,96,21,12,31.65%,60.24%,41.49%,35.44%,67.47%,46.47%
root,0,0,253,0,inf%,0.00%,0.00%,inf%,0.00%,0.00%


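## Health Check

Both models should be evaluated against the same reference spans. Summing CORRECT + MISSING + WRONG_SPAN over `prej`, `cont`, `deg` and `diff` should therefore give the same total for the rule-based model and the NER model.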

In [13]:
# Rule-based total of CORRECT + MISSING + WRONG_SPAN over the four comparative labels.
df_rulebased.loc[
    ["prej", "cont", "deg", "diff"],
    ["CORRECT", "MISSING", "WRONG_SPAN"],
].sum(axis = "index").sum()

788

In [14]:
# NER total of CORRECT + MISSING + WRONG_SPAN over the four comparative labels.
df_NER.loc[
    ["prej", "cont", "deg", "diff"],
    ["CORRECT", "MISSING", "WRONG_SPAN"],
].sum(axis = "index").sum()

788

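## Inspect each prediction in detail

For every test example, write the linearized reference, both models' linearized predictions and their alignment reports to `details.yaml`.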

In [15]:
# One record per test example: the linearized reference, each model's linearized
# prediction, and the alignment reports attached earlier.
details_records = [
    {
        "ID": entry["ID"],
        # Key spelling ("linearlized") is kept unchanged so the YAML schema stays stable.
        "reference_linearlized": aoc.linearize_annotations(
            tokens = entry["tokens"],
            comp = entry["comp"],
        ),
        "predicted_NER_linearized": aoc.linearize_annotations(
            tokens = entry["token_subwords"],
            comp = entry["comp_predicted_NER"],
        ),
        "matching_NER": entry["matching_NER"],
        # Examples with no rule-based prediction (None) are written as null
        # instead of passing None to linearize_annotations.
        "predicted_rulebased_linearized": (
            aoc.linearize_annotations(
                tokens = entry["tokens_diced"],
                comp = entry["comp_predicted_rulebased"],
            )
            if entry["comp_predicted_rulebased"] is not None
            else None
        ),
        "matching_rulebased": entry["matching_rulebased"],
    }
    for entry in ds_test
]

# Build all records before opening the file, so a failure cannot leave a
# truncated details.yaml behind. Write explicitly as UTF-8: the BCCWJ data is
# Japanese text, and the platform default encoding may not be UTF-8.
with open("./details.yaml", "w", encoding = "utf-8") as g:
    yaml.dump(details_records, stream = g)