In [1]:
import json
import itertools
import re

import numpy as np
import pandas as pd
pd.options.display.float_format = '{: <10.1%}'.format

import torch
from transformers import BertForTokenClassification, BertConfig, IntervalStrategy, TrainingArguments, Trainer
import datasets
from datasets.arrow_dataset import Dataset
import ruamel.yaml
yaml = ruamel.yaml.YAML()

import abctk.obj.comparative as aoc
import abct_comp_ner_utils.models.NER_with_root as nwr

In [2]:
# Load the comparative-NER corpus, pinned to a fixed commit of the dataset
# repository; `use_auth_token = True` uses the locally stored HF token.
dataset_raw = datasets.load_dataset(
    "abctreebank/comparative-NER-BCCWJ",
    revision = "e3cdaf016f1fba88d10194500c313f951b0d2df3",
    use_auth_token = True,
)

# Raw test split (replaced by a preprocessed version in the NER section).
ds_test = dataset_raw["test"]

Using custom data configuration default-935290dee194d9be
Found cached dataset parquet (/home/twotrees12/.cache/huggingface/datasets/abctreebank___parquet/default-935290dee194d9be/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


  0%|          | 0/2 [00:00<?, ?it/s]

## Data descriptions

In [3]:
_RE_kurabe = re.compile(r"kurabe")
_RE_rentai = re.compile(r"関係|連体")
_RE_renyoo = re.compile(r"連用")
_RE_fromtree = re.compile(r"^FROM TREE:")
_RE_questionable = re.compile("？|\\?")

# Counts per (comparative marker, construction category) over train + test.
df_ds_all_stats = pd.DataFrame(
    data = np.zeros((2, 4), dtype = np.int_),
    index = ["より", "比べ"],
    columns = ["連体", "連用", "その他比較", "非比較"],
)


def _classify_stats_cell(record) -> tuple[str, str]:
    """Return the ``(row, column)`` of ``df_ds_all_stats`` that ``record`` counts towards.

    Row: "比べ" when the record ID matches ``_RE_kurabe``, otherwise "より".
    Column (first matching rule wins), looking only at comments that do not
    start with "FROM TREE:":
      * a comment matches 関係/連体   -> "連体"
      * a comment matches 連用        -> "連用"
      * the record has no comp spans -> "非比較"
      * otherwise                    -> "その他比較"
    """
    row = "比べ" if _RE_kurabe.search(record["ID"]) else "より"

    manual_comments = [
        comment for comment in record["comments"]
        if not _RE_fromtree.search(comment)
    ]

    if any(_RE_rentai.search(comment) for comment in manual_comments):
        column = "連体"
    elif any(_RE_renyoo.search(comment) for comment in manual_comments):
        column = "連用"
    elif not record["comp"]:
        column = "非比較"
    else:
        column = "その他比較"
    return row, column


for record in datasets.concatenate_datasets(
    [dataset_raw["train"], dataset_raw["test"]],
):
    row, column = _classify_stats_cell(record)
    df_ds_all_stats.loc[row, column] += 1

In [4]:
# Display the corpus statistics: rows are the comparative marker (より / 比べ),
# columns the construction categories counted in the previous cell.
df_ds_all_stats

Unnamed: 0,連体,連用,その他比較,非比較
より,344,289,1111,778
比べ,55,103,487,293


## NER model 

In [5]:
# Number of examples per batch for the batched `Dataset.map` calls below.
BATCH_SIZE = 20

tokenizer = nwr.get_tokenizer()

# Fine-tuned token-classification model, pinned to a fixed repository commit,
# then moved to the default CUDA device (`Module.cuda()` returns the module).
model = BertForTokenClassification.from_pretrained(
    "abctreebank/comparative-NER-with-root",
    revision = "06e75bfe73a0bdf68ec5c57a93c07ebeb4704126",
    use_auth_token  = True,
)
model = model.cuda()

In [6]:
def _annotations_to_matrices(batch):
    """Convert a batch of annotation entries with
    ``nwr.convert_annotation_entries_to_matrices``, requesting PyTorch ("pt")
    outputs."""
    return nwr.convert_annotation_entries_to_matrices(batch, return_type = "pt")


ds_test = dataset_raw["test"].map(
    _annotations_to_matrices,
    batched = True,
    batch_size = BATCH_SIZE,
)

Loading cached processed dataset at /home/twotrees12/.cache/huggingface/datasets/abctreebank___parquet/default-935290dee194d9be/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-b58d2c44a41c8064.arrow


In [7]:
def _chomp(
    example: datasets.arrow_dataset.Example
):
    """Store the gold ``comp`` spans, as returned by ``aoc.chomp_CompRecord``
    for this example's ``token_subwords``, under ``comp_subword_aligned``.

    The example is mutated in place and returned (the shape ``Dataset.map``
    expects for a non-batched function).
    """
    example["comp_subword_aligned"] = aoc.chomp_CompRecord(
        tokens_subworeded = example["token_subwords"],
        comp = example["comp"],
    )["comp"]
    return example

ds_test = ds_test.map(_chomp)

Loading cached processed dataset at /home/twotrees12/.cache/huggingface/datasets/abctreebank___parquet/default-935290dee194d9be/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-7b28814ef55a1dcf.arrow


In [8]:
def _predict(
    examples: datasets.arrow_dataset.Batch
):
    """Run the NER model on a batch and attach the argmax label IDs.

    Reads ``input_ids``, ``attention_mask`` and ``token_type_ids`` from
    ``examples``, builds tensors directly on the device of the module-level
    ``model``, and stores the per-token argmax over the logits (as a NumPy
    array) under ``examples["label_ids_predicted_NER"]``.

    ``examples`` is mutated in place and also returned, so the call can be
    chained inside a batched ``Dataset.map``.
    """
    device = model.device

    # Pure inference: no_grad avoids building (and then discarding) an
    # autograd graph for every batch, saving GPU memory and time.
    with torch.no_grad():
        outputs = model(
            input_ids = torch.tensor(examples["input_ids"], device = device),
            attention_mask = torch.tensor(examples["attention_mask"], device = device),
            token_type_ids = torch.tensor(examples["token_type_ids"], device = device),
            return_dict = True,
        )

    # Token-classification logits are (batch, seq_len, num_labels);
    # argmax over the label axis gives one label ID per token.
    examples["label_ids_predicted_NER"] = (
        outputs.logits
        .argmax(dim = 2)
        .cpu()
        .numpy()
    )

    return examples
# === END ===

def _predict_and_annotate(batch):
    """Predict label IDs for ``batch`` with ``_predict`` and convert them with
    ``nwr.convert_predictions_to_annotations`` into ``comp_predicted_NER``."""
    return nwr.convert_predictions_to_annotations(
        _predict(batch),
        label_ids_key = "label_ids_predicted_NER",
        comp_key = "comp_predicted_NER",
    )


ds_test = ds_test.map(
    _predict_and_annotate,
    batched = True,
    batch_size = BATCH_SIZE,
)

# Score the NER predictions against the subword-aligned gold spans.
metric_NER = aoc.calc_prediction_metrics(
    predictions = ds_test["comp_predicted_NER"],
    references = ds_test["comp_subword_aligned"],
)

# One printable prediction/reference alignment per test example, in order.
ds_test = ds_test.add_column(
    "matching_NER",
    [
        aoc.print_AlignResult(
            prediction = predicted,
            reference = gold,
            alignment = alignment,
        )
        for predicted, gold, alignment in zip(
            ds_test["comp_predicted_NER"],
            ds_test["comp_subword_aligned"],
            metric_NER["alignments"],
        )
    ],
)

Loading cached processed dataset at /home/twotrees12/.cache/huggingface/datasets/abctreebank___parquet/default-935290dee194d9be/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-2d5b0c69e46f27b1.arrow


In [9]:
# Shallow copy of the metrics without the bulky per-example alignments.
metric_NER_wo_align = metric_NER.copy()
del metric_NER_wo_align["alignments"]

# Per-label span scores as a table; absent entries become 0 and the count
# columns are stored as int32.
df_NER = (
    pd.DataFrame.from_dict(metric_NER_wo_align["scores_spanwise"], orient = "index")
    .fillna(0)
    .astype(dict.fromkeys(["CORRECT", "SPURIOUS", "MISSING", "WRONG_SPAN"], "int32"))
)

In [10]:
# Span counts plus strict/partial precision, recall and F1 per label (incl. root).
df_NER[
    ["CORRECT", "SPURIOUS", "MISSING", "WRONG_SPAN",
     "precision_strict", "recall_strict", "F1_strict",
     "precision_partial", "recall_partial", "F1_partial"]
].loc[["prej", "cont", "deg", "diff", "root"]]

Unnamed: 0,CORRECT,SPURIOUS,MISSING,WRONG_SPAN,precision_strict,recall_strict,F1_strict,precision_partial,recall_partial,F1_partial
prej,228,29,8,21,82.0%,88.7%,85.2%,85.8%,92.8%,89.2%
cont,119,53,23,35,57.5%,67.2%,62.0%,65.9%,77.1%,71.1%
deg,219,48,41,10,79.1%,81.1%,80.1%,80.9%,83.0%,81.9%
diff,66,9,7,10,77.6%,79.5%,78.6%,83.5%,85.5%,84.5%
root,141,87,4,110,41.7%,55.3%,47.6%,58.0%,76.9%,66.1%


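### Scores restricted to sentences with gold annotations

Sentences without any gold span are dropped, so predictions on them (which can only be SPURIOUS) no longer count.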

In [11]:
# Evaluate only on test examples with at least one gold span: predictions on
# examples without any gold annotation are left out of the scores.
metric_NER_wo_spurious = aoc.calc_prediction_metrics(
    predictions = (
        predicted
        for predicted, gold in zip(
            ds_test["comp_predicted_NER"],
            ds_test["comp_subword_aligned"],
        )
        if gold
    ),
    references = (
        gold for gold in ds_test["comp_subword_aligned"] if gold
    ),
)

df_NER_wo_spurious = (
    pd.DataFrame.from_dict(metric_NER_wo_spurious["scores_spanwise"], orient = "index")
    .fillna(0)
    .astype(dict.fromkeys(["CORRECT", "SPURIOUS", "MISSING", "WRONG_SPAN"], "int32"))
)

In [12]:
# Same table as above, restricted to examples that carry gold annotation.
df_NER_wo_spurious[
    ["CORRECT", "SPURIOUS", "MISSING", "WRONG_SPAN",
     "precision_strict", "recall_strict", "F1_strict",
     "precision_partial", "recall_partial", "F1_partial"]
].loc[["prej", "cont", "deg", "diff", "root"]]

Unnamed: 0,CORRECT,SPURIOUS,MISSING,WRONG_SPAN,precision_strict,recall_strict,F1_strict,precision_partial,recall_partial,F1_partial
prej,228,7,8,21,89.1%,88.7%,88.9%,93.2%,92.8%,93.0%
cont,119,39,23,35,61.7%,67.2%,64.3%,70.7%,77.1%,73.8%
deg,219,31,41,10,84.2%,81.1%,82.6%,86.2%,83.0%,84.5%
diff,66,7,7,10,79.5%,79.5%,79.5%,85.5%,85.5%,85.5%
root,141,57,4,110,45.8%,55.3%,50.1%,63.6%,76.9%,69.6%


In [13]:
# Change in mean strict precision (prej/cont/deg/diff) when unannotated
# examples are excluded.
(
    df_NER_wo_spurious["precision_strict"].loc[["prej", "cont", "deg", "diff"]].mean()
    - df_NER["precision_strict"].loc[["prej", "cont", "deg", "diff"]].mean()
)

0.045646576872003686

## Rule-based model (GiNZA)

In [14]:
# Rule-based predictions: one JSON record per line (blank lines skipped),
# each passed through `aoc.dice_CompRecord` and keyed by its "ID".
# Later lines with a duplicate ID overwrite earlier ones.
# The encoding is given explicitly so loading does not depend on the
# platform's locale-default encoding.
with open("./predictions_SpaCy_2023-01-13.jsonl", encoding = "utf-8") as f:
    ds_rule_predicted = {
        record["ID"]: record
        for record in (
            aoc.dice_CompRecord(**json.loads(line))
            for line in map(str.strip, f)
            if line
        )
    }

def _add_rulebased_prediction(
    entry, 
    preds: dict[str, aoc.CompRecord] = ds_rule_predicted
):
    """Attach diced tokens, diced gold spans and the rule-based prediction.

    Sets, on ``entry`` (mutated in place and returned):
      * ``tokens_diced``: the tokens returned by ``aoc.dice_CompRecord``;
      * ``comp_diced``: the diced gold spans (start/end/label), where a label
        matching ``nwr._RE_FEAT_ARTIFACTS`` is replaced by its ``name`` group
        when that group is non-empty;
      * ``comp_predicted_rulebased``: ``preds[ID]["comp"]``, or ``None`` when
        ``preds`` has no record for this ID.

    ``preds`` defaults to ``ds_rule_predicted`` as bound when this cell ran.
    """
    ID = entry["ID"]
    diced = aoc.dice_CompRecord(
        tokens = entry["tokens"], comp = entry["comp"],
        ID = ID, 
    )

    def _clean_label(label):
        found = nwr._RE_FEAT_ARTIFACTS.match(label)
        if found is None:
            return label
        return found.group("name") or label

    entry["tokens_diced"] = diced["tokens"]
    entry["comp_diced"] = [
        {
            "start": span["start"],
            "end": span["end"],
            "label": _clean_label(span["label"]),
        }
        for span in diced["comp"]
    ]

    if ID in preds:
        entry["comp_predicted_rulebased"] = preds[ID]["comp"]
    else:
        entry["comp_predicted_rulebased"] = None

    return entry

ds_test = ds_test.map(_add_rulebased_prediction)

Loading cached processed dataset at /home/twotrees12/.cache/huggingface/datasets/abctreebank___parquet/default-935290dee194d9be/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-6d663900b78060c8.arrow


In [15]:
# Score the rule-based predictions against the diced gold spans.
metric_rulebased = aoc.calc_prediction_metrics(
    predictions = (row["comp_predicted_rulebased"] for row in ds_test),
    references = (row["comp_diced"] for row in ds_test),
)

# One printable prediction/reference alignment per test example, in order.
ds_test = ds_test.add_column(
    "matching_rulebased",
    [
        aoc.print_AlignResult(
            prediction = predicted,
            reference = gold,
            alignment = alignment,
        )
        for predicted, gold, alignment in zip(
            ds_test["comp_predicted_rulebased"],
            ds_test["comp_diced"],
            metric_rulebased["alignments"],
        )
    ],
)

In [16]:
# Per-label span scores of the rule-based system; absent entries become 0
# and the count columns are stored as int32.
df_rulebased = (
    pd.DataFrame.from_dict(metric_rulebased["scores_spanwise"], orient = "index")
    .fillna(0)
    .astype(dict.fromkeys(["CORRECT", "SPURIOUS", "MISSING", "WRONG_SPAN"], "int32"))
)

In [17]:
# Span counts plus strict/partial precision, recall and F1 per label.
df_rulebased[
    ["CORRECT", "SPURIOUS", "MISSING", "WRONG_SPAN",
     "precision_strict", "recall_strict", "F1_strict",
     "precision_partial", "recall_partial", "F1_partial"]
].loc[["prej", "cont", "deg", "diff"]]

Unnamed: 0,CORRECT,SPURIOUS,MISSING,WRONG_SPAN,precision_strict,recall_strict,F1_strict,precision_partial,recall_partial,F1_partial
prej,239,86,4,14,70.5%,93.0%,80.2%,72.6%,95.7%,82.6%
cont,50,104,91,36,26.3%,28.2%,27.2%,35.8%,38.4%,37.1%
deg,158,159,100,12,48.0%,58.5%,52.8%,49.8%,60.7%,54.8%
diff,52,62,24,7,43.0%,62.7%,51.0%,45.9%,66.9%,54.4%


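### Scores restricted to sentences with gold annotations

Sentences without any gold span are dropped, so predictions on them (which can only be SPURIOUS) no longer count.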

In [18]:
# Evaluate only on test examples with at least one diced gold span:
# predictions on examples without any gold annotation are left out.
metric_rulebased_wo_spurious = aoc.calc_prediction_metrics(
    predictions = (
        predicted
        for predicted, gold in zip(
            ds_test["comp_predicted_rulebased"],
            ds_test["comp_diced"],
        )
        if gold
    ),
    references = (
        gold for gold in ds_test["comp_diced"] if gold
    ),
)

df_rulebased_wo_spurious = (
    pd.DataFrame.from_dict(metric_rulebased_wo_spurious["scores_spanwise"], orient = "index")
    .fillna(0)
    .astype(dict.fromkeys(["CORRECT", "SPURIOUS", "MISSING", "WRONG_SPAN"], "int32"))
)

In [19]:

# Same table as above, restricted to examples that carry gold annotation.
df_rulebased_wo_spurious[
    ["CORRECT", "SPURIOUS", "MISSING", "WRONG_SPAN",
     "precision_strict", "recall_strict", "F1_strict",
     "precision_partial", "recall_partial", "F1_partial"]
].loc[["prej", "cont", "deg", "diff"]]

Unnamed: 0,CORRECT,SPURIOUS,MISSING,WRONG_SPAN,precision_strict,recall_strict,F1_strict,precision_partial,recall_partial,F1_partial
prej,239,1,4,14,94.1%,93.0%,93.5%,96.9%,95.7%,96.3%
cont,50,66,91,36,32.9%,28.2%,30.4%,44.7%,38.4%,41.3%
deg,158,82,100,12,62.7%,58.5%,60.5%,65.1%,60.7%,62.8%
diff,52,44,24,7,50.5%,62.7%,55.9%,53.9%,66.9%,59.7%


In [20]:
# Change in mean strict precision (prej/cont/deg/diff) when unannotated
# examples are excluded.
(
    df_rulebased_wo_spurious["precision_strict"].loc[["prej", "cont", "deg", "diff"]].mean()
    - df_rulebased["precision_strict"].loc[["prej", "cont", "deg", "diff"]].mean()
)

0.13089071875441893

## Health check

Both evaluations should see the same number of gold spans (CORRECT + MISSING + WRONG_SPAN).

In [21]:
# CORRECT + MISSING + WRONG_SPAN over prej/cont/deg/diff for the rule-based
# evaluation (presumably its gold-span total); compare with the NER cell below.
df_rulebased[["CORRECT", "MISSING", "WRONG_SPAN"]].loc[
    ["prej", "cont", "deg", "diff"]
].sum().sum()

787

In [22]:
# Health check: CORRECT + MISSING + WRONG_SPAN over prej/cont/deg/diff
# (presumably the gold-span total) must agree between the NER and the
# rule-based evaluations; otherwise their scores are not comparable.
n_gold_NER = df_NER.loc[
    ["prej", "cont", "deg", "diff"],
    ["CORRECT", "MISSING", "WRONG_SPAN"],
].sum().sum()
n_gold_rulebased = df_rulebased.loc[
    ["prej", "cont", "deg", "diff"],
    ["CORRECT", "MISSING", "WRONG_SPAN"],
].sum().sum()
assert n_gold_NER == n_gold_rulebased, (
    f"gold span totals differ: NER {n_gold_NER} vs rule-based {n_gold_rulebased}"
)

n_gold_NER

787

## Per-sentence predictions in detail (written to `details.yaml`)

In [23]:
def _detail_record(entry) -> dict:
    """Build the per-example record written to ``details.yaml``.

    Holds the example ID, the gold annotation and both systems' predictions
    (each linearized with ``aoc.linearize_annotations`` over its own
    tokenization), plus the printed alignments from the metric cells.
    """
    return {
        "ID": entry["ID"],
        # NOTE(review): "linearlized" is spelled differently from the other
        # keys; kept unchanged in case readers of details.yaml rely on it.
        "reference_linearlized": aoc.linearize_annotations(
            tokens = entry["tokens"],
            comp = entry["comp"]
        ),
        "predicted_NER_linearized": aoc.linearize_annotations(
            tokens = entry["token_subwords"],
            comp = entry["comp_predicted_NER"],
        ),
        "matching_NER": entry["matching_NER"],
        "predicted_rulebased_linearized": aoc.linearize_annotations(
            tokens = entry["tokens_diced"],
            comp = entry["comp_predicted_rulebased"],
        ),
        "matching_rulebased": entry["matching_rulebased"],
    }


# Write as UTF-8 explicitly: the records presumably contain Japanese text
# (BCCWJ), and `open`'s default encoding depends on the platform locale.
with open("./details.yaml", "w", encoding = "utf-8") as g:
    yaml.dump(
        [_detail_record(entry) for entry in ds_test],
        stream = g,
    )