In [None]:
import sys
from pathlib import Path


In [None]:
# Make the directory three levels above the notebook importable (presumably the repo
# root holding the `src` package imported below). sys.path holds *strings*, so the
# membership test must use str(module_path); comparing the Path object was always
# "not in" and appended a duplicate entry on every re-run of this cell.
module_path = Path.cwd().parents[2]
if str(module_path) not in sys.path:
    sys.path.append(str(module_path))


In [None]:
# Show which directory is used as the import root for project modules.
module_path


In [None]:
# Create the output directory for metric JSON files (no-op if it already exists).
# pathlib instead of the `!mkdir -p` shell escape keeps this working without a POSIX shell.
Path("metrics").mkdir(parents=True, exist_ok=True)


In [None]:
import json
import pickle
import random
from time import perf_counter

import regex as re

In [None]:
import spacy
import torch
from spacy.scorer import Scorer
from spacy.tokens import Doc, DocBin, Span
from spacy.training import Example


In [None]:
# Directory where each evaluation run stores its metrics as a JSON file.
METRICS_PATH = Path.cwd().joinpath("metrics")
METRICS_PATH.mkdir(exist_ok=True, parents=True)

In [None]:
from src.loader import TextLoader
from src.model import DatasetType, Text


In [None]:
PREDICTED_EXAMPLES_PATH = Path("data", "predicted_examples.pkl")

# Predictions are only computed when the cached pickle is missing, and that run uses
# the GPU for speed. The prediction cell intentionally raises right after writing the
# cache; after a kernel restart this guard is skipped and everything can run on CPU.
if not PREDICTED_EXAMPLES_PATH.exists():
    torch.cuda.empty_cache()
    spacy.require_gpu()


In [None]:
# spaCy pipeline used for prediction, re-tokenization and as the vocab for new Docs.
nlp = spacy.load("models/spacy_resume/model-best/")
# NOTE(review): `loader` is not referenced by the visible cells — confirm it is needed.
loader = TextLoader(dataset_type=DatasetType.V1_WITH_PREDICTIONSTRING)


In [None]:
def create_manual_doc(text: Text):
    """Build a Doc whose entities mark the discourse boundaries of ``text``.

    The first word of every discourse is tagged ``B-DS`` (discourse start), the last
    word ``B-DE`` (discourse end) and every other word ``O``.

    Each boundary is tagged by its own role instead of alternating DS/DE while walking
    the words, so a position that is both a start and an end (e.g. a single-word
    discourse) no longer flips the tag of every following boundary. Such a position
    can only carry one tag; it keeps ``B-DS``.

    Args:
        text: object exposing ``words`` and ``discourses``; the first and last element
            of each discourse's ``predictionstring`` are used as word positions
            (assumed to be int indices into ``text.words`` — TODO confirm).

    Returns:
        A ``Doc`` built on ``nlp.vocab`` with one IOB tag per word.
    """
    boundary_tags = {}
    for disc in text.discourses:
        start, end = disc.predictionstring[0], disc.predictionstring[-1]
        boundary_tags[start] = "B-DS"
        # setdefault: an end that coincides with a start keeps the start tag.
        boundary_tags.setdefault(end, "B-DE")

    ents = [boundary_tags.get(ind, "O") for ind in range(len(text.words))]
    return Doc(nlp.vocab, text.words, ents=ents)


def display_doc(doc: Doc):
    """Render the entity annotations of ``doc`` inline in the notebook via displaCy."""
    spacy.displacy.render(doc, jupyter=True, style="ent")


In [None]:
# Reference (test-set) docs serialized as a DocBin; show how many there are.
doc_bin = DocBin().from_disk(Path("data") / "NER_test.spacy")
len(doc_bin)


In [None]:
def fix_tokenization(doc: Doc) -> Doc:
    """Re-tokenize a reference doc so its tokens match the pipeline's tokenizer.

    Every token of ``doc`` is tokenized again with ``nlp``'s tokenizer:

    * the first piece inherits the original IOB prefix, so splitting a token inside an
      entity continues that entity instead of starting a new one;
    * further pieces of an entity token are tagged ``I-<type>``;
    * pieces of tokens without an entity type are tagged ``O``;
    * only the last piece carries the original token's trailing whitespace, so the
      text of the result matches the text of ``doc``.

    Args:
        doc: reference ``Doc`` with entity annotations.

    Returns:
        A new ``Doc`` on ``nlp.vocab`` with the refined tokens and entity tags.
    """
    words: list[str] = []
    spaces: list[bool] = []
    ents: list[str] = []
    for token in doc:
        # Only the tokenizer is needed; make_doc avoids running the whole pipeline for
        # every single token (assumes no pipeline component retokenizes — TODO confirm).
        pieces = [piece.text for piece in nlp.make_doc(token.text)]
        last_pos = len(pieces) - 1
        for pos, piece in enumerate(pieces):
            words.append(piece)
            spaces.append(bool(token.whitespace_) if pos == last_pos else False)

            if not token.ent_type_:
                ents.append("O")
            elif pos == 0:
                ents.append(f"{token.ent_iob_}-{token.ent_type_}")
            else:
                ents.append(f"I-{token.ent_type_}")

    return Doc(nlp.vocab, words=words, spaces=spaces, ents=ents)

In [None]:
# dddd = fix_tokenization(predicted_examples[92].reference)

# pre = predicted_examples[92].predicted

# for ind, ref_token in enumerate(dddd):
#     pre_token = pre[ind]
#     print(f"{ref_token.text:14} {pre_token.text}")

In [None]:
if not PREDICTED_EXAMPLES_PATH.exists():
    # Predict every reference text and pair the prediction with the re-tokenized
    # reference, then cache the examples and stop so the kernel can restart on CPU.
    predicted_examples: list[Example] = []
    references = list(doc_bin.get_docs(nlp.vocab))
    for ind, reference in enumerate(references):
        print(f"\r{ind + 1:3d}/{len(references)}", end="")
        doc = nlp(reference.text)

        reference_fixed = fix_tokenization(reference)
        # Re-tokenization must neither gain/lose entities nor change their text.
        assert len(reference_fixed.ents) == len(reference.ents), "Number of entities do not match!"
        assert [ent.text for ent in reference_fixed.ents] == [ent.text for ent in reference.ents], "Entites do not match!"

        predicted_examples.append(Example(doc, reference_fixed))

    # Context manager: the pickle is flushed and closed before the deliberate raise below.
    with open(PREDICTED_EXAMPLES_PATH, "wb") as f:
        pickle.dump(predicted_examples, f)

    raise Exception("Data saved, restart kernel to run on CPU")

else:
    # NOTE: pickle.load can execute arbitrary code; only load caches this notebook wrote.
    with open(PREDICTED_EXAMPLES_PATH, "rb") as f:
        predicted_examples: list[Example] = pickle.load(f)


In [None]:
# Entity tokens of reference example 9, before and after re-tokenization.
for position, tok in enumerate(predicted_examples[9].reference):
    if tok.ent_type_:
        print(f"{position:3}", f"{tok.text:14}", f"{tok.ent_iob_}-{tok.ent_type_}")

print()

test_doc = fix_tokenization(predicted_examples[9].reference)
for position, tok in enumerate(test_doc):
    if tok.ent_type_:
        print(f"{position:3}", f"{tok.text:14}", f"{tok.ent_iob_}-{tok.ent_type_}")

In [None]:
def get_fixed_doc(example: Example, idx: list[int]):
    """Return a copy of ``example.predicted`` with selected adjacent entities merged.

    ``idx`` holds the start token indices of entities that should be joined with the
    entity immediately following them. For each such entity, the token at its ``end``
    (the first token of the following entity) is re-tagged ``I-<label>`` so the two
    spans become one entity. Using ``end`` rather than ``start + 1`` keeps this correct
    for multi-token entities. All other tokens keep their predicted IOB tags.

    ``idx`` is only read (never consumed), so the caller's list is left intact and an
    empty list simply returns an equivalent copy.

    Args:
        example: example whose ``predicted`` doc is rebuilt.
        idx: start indices of entities to merge with their successor.

    Returns:
        A new ``Doc`` on ``nlp.vocab``.
    """
    doc = example.predicted
    merge_starts = set(idx)
    # Token positions that become continuations of the preceding entity.
    continuation_positions = {ent.end for ent in doc.ents if ent.start in merge_starts}

    tokens_fixed = []
    tokens_spaces = []
    ents_fixed = []
    for token in doc:
        tokens_fixed.append(token.text)
        tokens_spaces.append(bool(token.whitespace_))
        if token.i in continuation_positions and token.ent_type_:
            ents_fixed.append(f"I-{token.ent_type_}")
        elif token.ent_iob_ == "O":
            ents_fixed.append(token.ent_iob_)
        else:
            ents_fixed.append(f"{token.ent_iob_}-{token.ent_type_}")

    return Doc(nlp.vocab, tokens_fixed, spaces=tokens_spaces, ents=ents_fixed)


In [None]:
def merge_invalid_examples(predicted_examples: list[Example], verbose: bool = False) -> list[Example]:
    """Merge predicted entities that were split around an apostrophe.

    Two adjacent entities (the first ends exactly where the second starts) with the
    same label are merged when the second one's text contains ``'`` (e.g. ``I`` +
    ``'m``). Examples with nothing to merge are passed through unchanged; the others get
    a predicted doc rebuilt by ``get_fixed_doc``. Prints how many examples were fixed.

    Args:
        predicted_examples: examples to process.
        verbose: print the entity listing before and after each fix.

    Returns:
        A new list with one example per input, in the same order.
    """
    result: list[Example] = []
    n_fixed = 0
    for example_no, example in enumerate(predicted_examples):
        spans = example.predicted.ents
        merge_starts = [
            left.start
            for left, right in zip(spans, spans[1:])
            if left.end == right.start and left.label_ == right.label_ and "'" in right.text
        ]

        if not merge_starts:
            result.append(example)
            continue

        n_fixed += 1

        if verbose:
            print(f"ind: {example_no}")
            for span in spans:
                print(f"{span.start:>3} {span.end:>3} {span.label_} {span.text:12}", end=" ")
                print("<<<<<" if span.start in merge_starts else "")
            print("\n----\n")

        fixed_doc = get_fixed_doc(example, merge_starts)

        if verbose:
            for span in fixed_doc.ents:
                print(f"{span.start:>3} {span.end:>3} {span.label_} {span.text:12}")
            print("\n----\n")

        result.append(Example(fixed_doc, example.reference))

    print(f"Fixed {n_fixed} examples.")

    return result


In [None]:
# Merge apostrophe-split entities (quietly), then report the total example count.
merged_examples = merge_invalid_examples(predicted_examples)
print(f"All examples: {len(predicted_examples)}")


In [None]:
# NOTE(review): the original note says predicted example 14 has issues with
# consecutive DS/DE entities, but this cell shows example 9 — confirm the intended index.
display_doc(doc=merged_examples[9].predicted)


In [None]:
def inference_missing_tags(
    examples: list[Example], use_first: bool = False, use_sentence_boundaries: bool = True
) -> list[Example]:
    """Repair predicted DS/DE boundaries so every doc alternates DS, DE, DS, DE, ...

    The model may predict consecutive boundaries of the same type, e.g. ``... DS DE DE ...``,
    from which discourses cannot be extracted. Such sequences are repaired with these rules:

    1. The resulting tag sequence must not contain two consecutive tags of the same type.
    2. A missing tag is inserted where possible, e.g. ``DS DE DE`` becomes ``DS DE DS DE``.
    3. A missing tag is placed at a sentence start (missing DS) or sentence end (missing
       DE); sentences are delimited by ``"."`` tokens. When several sentences lie between
       two consecutive tags of the same type, ``use_first`` picks the first sentence
       boundary found, otherwise the last one is used.
    4. Without a usable sentence boundary: if ``use_sentence_boundaries`` is True, the
       earlier of the two same-type tags (the previous entity) is removed; if False, the
       missing DS is put on the first untagged token since the previous tag, or the
       missing DE on the token right before the current tag.

    A trailing DS without a DE gets a DE at the last sentence end, or on the last token.
    Progress is printed as ``<example index>/<last index>``.

    Args:
        examples: examples whose ``predicted`` docs are repaired.
        use_first: use the first (True) or the last (False) sentence boundary.
        use_sentence_boundaries: see rule 4.

    Returns:
        New examples, in input order, with repaired predicted docs and unchanged references.

    Raises:
        AssertionError: if a tag would overwrite an existing DS/DE tag, if the first
            token is DE, or if a repaired doc ends up with an odd number of boundaries.
    """
    fixed_examples = []
    # Named `example_ind` (not `ind`) so the token loop below cannot shadow it; the final
    # assertion message previously reported a token index instead of the example index.
    for example_ind, example in enumerate(examples):
        print(f"\r{example_ind:3d}/{len(examples) - 1}", end="")

        doc = example.predicted

        # Type ("DS"/"DE") and token index of the most recently accepted boundary.
        last_ent = None
        last_ent_ind = None

        # Candidate positions for a missing tag: the token after the latest "." (sentence
        # start, for a missing DS) and the token before it (sentence end, for a missing DE).
        # None means there is no usable sentence boundary.
        last_start_sent_ind = 0
        last_end_sent_ind = 0

        # First untagged token since the last accepted boundary (fallback for a missing DS).
        saved_first_token_ind = None

        tokens_fixed = []
        tokens_spaces = []
        ents_fixed = []

        for ind, token in enumerate(doc):
            tokens_fixed.append(token.text)
            tokens_spaces.append(token.whitespace_)

            if (
                last_start_sent_ind is not None
                and last_start_sent_ind < ind
                and ents_fixed
                and last_ent is not None
            ):
                # Drop the candidate boundary if that token already carries a DS/DE tag.
                if last_ent == "DE":
                    if ents_fixed[last_start_sent_ind] in ("B-DS", "I-DS", "B-DE", "I-DE"):
                        last_start_sent_ind, last_end_sent_ind = None, None
                else:
                    if ents_fixed[last_end_sent_ind] in ("B-DS", "I-DS", "B-DE", "I-DE"):
                        last_start_sent_ind, last_end_sent_ind = None, None

            if use_first and last_start_sent_ind is None and token.text == ".":
                last_start_sent_ind = ind + 1
                last_end_sent_ind = ind - 1

            elif not use_first and token.text == ".":
                last_start_sent_ind = ind + 1
                last_end_sent_ind = ind - 1

            if not token.ent_type_:
                if saved_first_token_ind is None:
                    saved_first_token_ind = ind

                ents_fixed.append(token.ent_iob_)
                continue

            if (token.ent_type_ == "DS" and last_ent == "DE") or (
                token.ent_type_ == "DE" and last_ent == "DS"
            ):
                # Correct alternation: accept the tag and reset the candidates.
                last_ent = token.ent_type_
                last_ent_ind = ind
                ents_fixed.append(f"{token.ent_iob_}-{token.ent_type_}")

                last_start_sent_ind, last_end_sent_ind = None, None
                saved_first_token_ind = None
                continue

            if last_ent is None and token.ent_type_ == "DS":
                last_ent = token.ent_type_
                last_ent_ind = ind

                last_start_sent_ind, last_end_sent_ind = None, None
                saved_first_token_ind = None

                ents_fixed.append(f"{token.ent_iob_}-{token.ent_type_}")
                continue

            assert not (token.ent_type_ == "DE" and ind == 0), "First token must not be DE"

            if last_start_sent_ind is not None:
                # A sentence boundary is available: put the missing tag there.
                if token.ent_type_ == "DE":
                    assert ents_fixed[last_start_sent_ind] not in (
                        "B-DS",
                        "I-DS",
                        "B-DE",
                        "I-DE",
                    ), "The tag for the start of sentence is already set!"
                    ents_fixed[last_start_sent_ind] = "B-DS"
                elif token.ent_type_ == "DS":
                    assert ents_fixed[last_end_sent_ind] not in (
                        "B-DS",
                        "I-DS",
                        "B-DE",
                        "I-DE",
                    ), "The tag for the end of sentence is already set!"
                    ents_fixed[last_end_sent_ind] = "B-DE"
                else:
                    assert False, "Should not happen"
            else:
                if token.ent_iob_ == "I" and ents_fixed[ind - 1] == f"B-{token.ent_type_}":
                    # Continuation of a two-token entity, e.g.:
                    # I  B-DS
                    # 'm I-DS
                    last_ent = token.ent_type_
                    last_ent_ind = ind
                    ents_fixed.append(f"{token.ent_iob_}-{token.ent_type_}")
                    continue
                elif ents_fixed[ind - 1] == f"B-{token.ent_type_}":
                    # Two directly adjacent tags of the same type: keep the first DS
                    # (drop the second) but keep the second DE (drop the first).
                    if token.ent_type_ == "DS":
                        ents_fixed.append("O")
                    else:
                        ents_fixed[ind - 1] = "O"
                        ents_fixed.append(f"{token.ent_iob_}-{token.ent_type_}")
                    continue

                if use_sentence_boundaries:
                    # Remove the previous (middle) tag of the same type.
                    if ents_fixed[last_ent_ind].startswith("I-"):
                        # Multi-token entity: clear it backwards. The `>= 0` guard stops
                        # the walk from wrapping around to the end of the list.
                        curr_ind = last_ent_ind
                        while curr_ind >= 0 and last_ent in ents_fixed[curr_ind]:
                            ents_fixed[curr_ind] = "O"
                            curr_ind -= 1
                    else:
                        ents_fixed[last_ent_ind] = "O"
                else:
                    # Put the missing tag on the first untagged token (missing DS) or on
                    # the token just before the current one (missing DE).
                    if token.ent_type_ == "DE":
                        assert ents_fixed[saved_first_token_ind] not in (
                            "B-DS",
                            "I-DS",
                            "B-DE",
                            "I-DE",
                        ), "Start token already has a tag"
                        ents_fixed[saved_first_token_ind] = "B-DS"
                    elif token.ent_type_ == "DS":
                        assert ents_fixed[ind - 1] not in (
                            "B-DS",
                            "I-DS",
                            "B-DE",
                            "I-DE",
                        ), "End token already has a tag"
                        ents_fixed[ind - 1] = "B-DE"
                    else:
                        assert False, "Should not happen"

            last_start_sent_ind, last_end_sent_ind = None, None
            saved_first_token_ind = None

            last_ent = token.ent_type_
            last_ent_ind = ind
            ents_fixed.append(f"{token.ent_iob_}-{token.ent_type_}")

        # A trailing DS still needs a closing DE.
        if last_ent == "DS" and last_end_sent_ind is not None:
            ents_fixed[last_end_sent_ind] = "B-DE"
        elif last_ent == "DS":
            assert ents_fixed[-1] not in (
                "B-DS",
                "I-DS",
                "B-DE",
                "I-DE",
            ), "Last token already has a tag"
            ents_fixed[-1] = "B-DE"

        # Every DS must be paired with a DE (count boundaries, ignoring I- continuations).
        ents_filtered = [ent for ent in ents_fixed if ent != "O" and not ent.startswith("I-")]
        assert len(ents_filtered) % 2 == 0, f"Example {example_ind} has uneven number of tags: {ents_filtered}!"

        doc_fixed = Doc(nlp.vocab, tokens_fixed, spaces=tokens_spaces, ents=ents_fixed)
        fixed_examples.append(Example(doc_fixed, example.reference))

    print()

    return fixed_examples


In [None]:
# display_doc(merged_examples[108].predicted)

In [None]:
# inferenced_last_sents, = inference_missing_tags(
#     [merged_examples[108]], use_first=False, use_sentence_boundaries=True
# )

In [None]:
# NOTE: the next cell recomputes `inferenced_last_loose` together with the other three
# strategies; this cell is only useful when running that single strategy on its own.
inferenced_last_loose = inference_missing_tags(
    merged_examples, use_sentence_boundaries=False, use_first=False
)

In [None]:
# All four combinations of the two repair options:
#   use_first               -> "first" vs "last" sentence boundary
#   use_sentence_boundaries -> "sents" (drop conflicting tag) vs "loose" (tag a nearby token)
inferenced_last_loose = inference_missing_tags(merged_examples, use_sentence_boundaries=False, use_first=False)
inferenced_last_sents = inference_missing_tags(merged_examples, use_sentence_boundaries=True, use_first=False)
inferenced_first_loose = inference_missing_tags(merged_examples, use_sentence_boundaries=False, use_first=True)
inferenced_first_sents = inference_missing_tags(merged_examples, use_sentence_boundaries=True, use_first=True)


In [None]:
num = 13

# Merged prediction first, then the same example under each repair strategy.
display_doc(merged_examples[num].predicted)
for caption, variant in (
    ("Last loose", inferenced_last_loose),
    ("Last sents", inferenced_last_sents),
    ("First loose", inferenced_first_loose),
    ("First sents", inferenced_first_sents),
):
    print(f"\n{caption}\n")
    display_doc(variant[num].predicted)


In [None]:
# Evaluate the raw predictions and every post-processing strategy. Each result is cached
# as JSON, so re-running the notebook skips runs whose metrics file already exists.
# NOTE(review): `nlp.evaluate` passes `eg.predicted` through the pipeline again before
# scoring; `spacy.scorer.Scorer` (imported above) scores existing predictions directly —
# confirm which is intended.
evaluation_runs = (
    ("predicted_metrics.json", "Predicted", predicted_examples),
    ("merged_metrics.json", "Merged", merged_examples),
    ("last_loose_metrics.json", "Last loose", inferenced_last_loose),
    ("last_sents_metrics.json", "Last sents", inferenced_last_sents),
    ("first_loose_metrics.json", "First loose", inferenced_first_loose),
    ("first_sents_metrics.json", "First sents", inferenced_first_sents),
)

for file_name, label, run_examples in evaluation_runs:
    metric_path = METRICS_PATH / file_name
    if metric_path.exists():
        continue

    start = perf_counter()
    metrics = nlp.evaluate(run_examples, batch_size=256)
    with open(metric_path, "w") as f:
        json.dump(metrics, f, indent=4)

    print(f"{label} done in {perf_counter() - start:.2f}s")

In [None]:
# Print every cached metrics file: its stem, then the raw metrics dict.
for metrics_file in METRICS_PATH.glob("*.json"):
    metrics = json.loads(metrics_file.read_text())
    print(metrics_file.stem)
    print(metrics, end="\n\n")

In [None]:
def extract_discourses(doc: Doc, keep_first_ds: bool = False, keep_first_de: bool = False):
    """Turn DS/DE boundary entities into a list of discourse strings.

    Runs of consecutive entities with the same label (DS or DE) are first collapsed into
    one: the first of the run is kept when ``keep_first_ds`` / ``keep_first_de`` is True,
    otherwise the last one wins. Each DS ... DE pair then yields the tokens from the DS
    start to the DE end joined by single spaces, with the space before a "." removed.

    Raises:
        AssertionError: if a DE entity is not preceded by a DS entity.
    """
    words = [tok.text for tok in doc]

    collapsed = []
    for span in doc.ents:
        repeats_previous = (
            collapsed and span.label_ in ("DS", "DE") and span.label_ == collapsed[-1].label_
        )
        if repeats_previous:
            keep_first = keep_first_ds if span.label_ == "DS" else keep_first_de
            if not keep_first:
                collapsed[-1] = span
            continue
        collapsed.append(span)

    discourses = []
    open_start = None
    previous_tag = None
    for span in collapsed:
        if span.label_ == "DS":
            open_start, previous_tag = span.start, "DS"
        elif span.label_ == "DE":
            assert previous_tag == "DS", "DE without DS"
            joined = " ".join(words[open_start : span.end])
            discourses.append(re.sub(r" \.", ".", joined))
            open_start, previous_tag = None, "DE"

    return discourses


In [None]:
# Reference discourses of example `num` (defaults: the last DS / DE of a run wins).
doc = inferenced_last_loose[num].reference
ref = extract_discourses(doc, keep_first_ds=False, keep_first_de=False)
ref


In [None]:
# Predicted discourses of example `num`; for repeated DE tags keep the first one.
doc = inferenced_last_loose[num].predicted
pred = extract_discourses(doc, keep_first_ds=False, keep_first_de=True)
pred


In [None]:
def create_discourse_doc(doc: Doc):
    """Convert DS/DE boundary entities into contiguous ``DISC`` entities.

    Tokens from a DS boundary up to and including the next DE boundary form one ``DISC``
    entity. The DS token opens it with ``B-DISC`` (``I-DISC`` if the DS token is itself
    an ``I-`` continuation). The following tokens and the DE token(s) get ``I-DISC``.
    Tokens outside a discourse get ``O``.

    Token whitespace is preserved so the new doc has the same text as ``doc``. Without
    ``spaces``, spaCy assumes a space after every token.

    NOTE(review): a DE token seen while no discourse is open is still tagged ``I-DISC``;
    confirm this is the intended handling of stray DE tags.

    Args:
        doc: doc with ``DS`` / ``DE`` entities.

    Returns:
        A new ``Doc`` on ``nlp.vocab`` with ``DISC`` entities.
    """
    disc = "DISC"
    words = []
    spaces = []
    ents = []
    in_disc = False
    for token in doc:
        words.append(token.text)
        spaces.append(bool(token.whitespace_))
        if token.ent_type_ == "DS":
            in_disc = True
            ents.append(f"I-{disc}" if token.ent_iob_ == "I" else f"B-{disc}")
        elif token.ent_type_ == "DE":
            in_disc = False
            ents.append(f"I-{disc}")
        elif in_disc:
            ents.append(f"I-{disc}")
        else:
            ents.append("O")

    return Doc(nlp.vocab, words, spaces=spaces, ents=ents)


In [None]:
# Reference vs. prediction for example 9, plus their DISC-span versions.
doc_ref, doc_pred = inferenced_first_loose[9].reference, inferenced_first_loose[9].predicted
disc_doc_ref, disc_doc_pred = create_discourse_doc(doc_ref), create_discourse_doc(doc_pred)

In [None]:
for caption, shown in (("Ref:\n", doc_ref), ("\nPred:\n", doc_pred)):
    print(caption)
    display_doc(shown)

In [None]:
# For every example, report where the reference and predicted tokenizations diverge.
# `offset` counts extra predicted tokens skipped so far; `cons_idx` counts consecutive
# mismatches. The old check `cons_idx > 0` was always true right after incrementing,
# so it stopped at the first mismatch. That remains the default (0); raise the
# threshold to tolerate short runs of extra predicted tokens.
MAX_CONSECUTIVE_MISMATCHES = 0

for example_ind, example in enumerate(inferenced_first_loose):
    # These global names are reused by the later check_coverage(doc_ref, doc_pred) cell.
    doc_ref = example.reference
    doc_pred = example.predicted

    offset = 0
    cons_idx = 0
    for ind, token_ref in enumerate(doc_ref):
        if ind + offset >= len(doc_pred):
            print(f"Predicted doc ran out of tokens for: {example_ind} (ind: {ind}, offset = {offset})\n")
            break

        token_pred = doc_pred[ind + offset]
        if token_ref.text == token_pred.text:
            cons_idx = 0
            continue

        cons_idx += 1
        if cons_idx > MAX_CONSECUTIVE_MISMATCHES:
            print(f"Consecutive for {cons_idx} times for: {example_ind}")
            print(f"Token mismatch: {token_ref.text} != {token_pred.text} (ind: {ind}, offset = {offset})\n")
            break

        offset += 1

In [None]:
# Compare a window around reference token 137 of example 35 with the predicted tokens
# taken from a slice offset by 7 positions.
ref_center, window, pred_shift = 137, 5, 7
print(inferenced_first_loose[35].reference[ref_center - window : ref_center + window])
for token in inferenced_first_loose[35].predicted[
    ref_center - window + pred_shift : ref_center + window + pred_shift
]:
    print(token.text)

In [None]:
_example_35 = inferenced_first_loose[35]
for position, shown in enumerate((_example_35.reference, _example_35.predicted)):
    if position:
        print()
    display_doc(shown)

In [None]:
def check_coverage(doc_ref: Doc, doc_pred: Doc):
    """Match predicted discourses against reference discourses (work in progress).

    Intended matching rules:

    1. For each reference discourse, the predictions that overlap it are compared.
    2. If the overlap between a prediction and a reference discourse is >= 0.5 AND the
       overlap between the reference discourse and the prediction is >= 0.5, the
       prediction is a match and counts as a True Positive (TP). When several
       predictions overlap one reference discourse, only the one with the highest pair
       of overlaps (in both directions) is taken.
    3. Unmatched reference discourses are False Negatives (FN); unmatched predictions
       are False Positives (FP).

    WARNING: tokenization usually differs between the reference and the prediction, so
    offsets are introduced to align the tokens.

    TODO: overlap bookkeeping and TP/FP/FN counting are not implemented yet (the
    ``...`` placeholders below); the function currently returns ``None``.
    """
    # Reference positions where the aligned predicted token differed; from each such
    # position on, the predicted index is shifted by one more token.
    offsets_at_pos = []
    for ind, token_ref in enumerate(doc_ref):
        token_pred = doc_pred[ind + len(offsets_at_pos)]
        if token_ref.text != token_pred.text:
            offsets_at_pos.append(ind)
    # Set for O(1) membership tests in the main loop (a list made that loop quadratic).
    offset_positions = set(offsets_at_pos)

    # Planned structure:
    # {
    #   "0": [(start, end), [(pred1_start, pred1_end), (pred2_start, pred2_end), ...]],
    #   "1": [(start, end), [(pred1_start, pred1_end), (pred2_start, pred2_end), ...]],
    #   ...
    #   "non-overlapping": [(start1, end1), (start2, end2), ...]
    # }

    # With the offset added to the index, reference and predicted tokens should match.
    discourses_overlaps = {}

    curr_ref_start = None
    curr_ref_end = None

    overlapping_pred_discourses = {}

    offset = 0
    for ind, token_ref in enumerate(doc_ref):
        # NOTE(review): every I- token currently adds a new entry under a fresh key.
        key = len(discourses_overlaps)

        if ind in offset_positions:
            offset += 1

        curr_ind = ind + offset
        token_pred = doc_pred[curr_ind]

        if token_ref.ent_type_ and token_ref.ent_iob_ == "B":
            # Start of a new reference discourse.
            curr_ref_start = ind
        elif token_ref.ent_type_ and token_ref.ent_iob_ == "I":
            # Continuation of the reference discourse.
            curr_ref_end = ind
            discourses_overlaps[key] = [(curr_ref_start, curr_ref_end), []]
            # TODO: merge overlapping predictions.
            ...

        if token_pred.ent_type_ and token_pred.ent_iob_ == "B":
            # Start of a new predicted discourse.
            pred_start = curr_ind
            # TODO: check whether the prediction overlaps any reference discourse.
            ...
        elif token_pred.ent_type_ and token_pred.ent_iob_ == "I":
            # Continuation of the predicted discourse.
            pred_end = curr_ind
            # TODO: check whether the prediction overlaps any reference discourse.
            ...
    

In [None]:
# NOTE(review): doc_ref / doc_pred are whatever was last bound in the namespace; after
# "Run All" that is the last example of the tokenization-mismatch loop, not example 9.
check_coverage(doc_ref=doc_ref, doc_pred=doc_pred)

In [None]:
# Deliberately halt "Run All" at this point; any cells after this one must be run manually.
raise StopIteration
