# Load SageMaker Ground Truth annotations

## Load task manifest file mapping text sample IDs to text

In [None]:
from collections import defaultdict
import codecs
from dataclasses import dataclass
import json
from pathlib import Path
from typing import Dict, List

from sklearn.model_selection import train_test_split
import spacy
from spacy import displacy

In [None]:
# Manifest whose tab-separated lines map an integer sample index to its (escaped) text.
manifest_path = Path(
    "annotations/glue-dir-kbase-dev-sagemaker-ground-truth-labeling-clone/annotations/intermediate/1/annotations.manifest"
)

In [None]:
# Fail early with a clear message; unlike `assert`, this check survives `python -O`.
if not manifest_path.is_file():
    raise FileNotFoundError(f"task manifest not found: {manifest_path}")

In [None]:
# Sample index (first manifest column) -> raw text, still backslash-escaped.
index_to_text_raw: Dict[int, str] = dict()

In [None]:
# Each manifest line is: <index> TAB <text ...> TAB <unknown trailing column>.
# The text itself may contain tabs, so every column between the first and the
# last one is re-joined.
# The file is read explicitly as UTF-8 instead of the platform default encoding.
with open(manifest_path, "rt", encoding="utf-8") as fin:
    for line in fin:
        if not line.strip():
            # tolerate blank lines (e.g. a trailing empty line at end of file)
            continue
        columns = line.split("\t")
        index = int(columns[0])
        text = "\t".join(columns[1:-1])
        _ = columns[-1]  # purpose of this trailing column is unknown; it is ignored
        assert index not in index_to_text_raw, f"duplicate sample index {index} in {manifest_path}"
        index_to_text_raw[index] = text

## Load annotations for the given text sample IDs

In [None]:
# Worker responses: one subdirectory per sample index, each holding one JSON file.
annotations_dir = Path(
    "annotations/glue-dir-kbase-dev-sagemaker-ground-truth-labeling-clone/annotations/worker-response/iteration-1/"
)

In [None]:
# Fail early with a clear message; unlike `assert`, this check survives `python -O`.
if not annotations_dir.is_dir():
    raise FileNotFoundError(f"annotation directory not found: {annotations_dir}")

In [None]:
# Sample index -> list of raw entity dicts taken from the single worker answer.
index_to_annotation_raw: Dict[int, List[dict]] = dict()

In [None]:
# Layout: <annotations_dir>/<sample index>/<exactly one worker-response JSON file>.
for annotations_subdir in annotations_dir.iterdir():
    index = int(annotations_subdir.name)
    for i, annotations_file in enumerate(annotations_subdir.iterdir()):
        assert i == 0, f"found more than one annotation in {annotations_subdir}"
        # Read explicitly as UTF-8 rather than the platform default encoding.
        with open(annotations_file, "rt", encoding="utf-8") as fin:
            response = json.load(fin)
        assert index not in index_to_annotation_raw, f"duplicate sample index {index}"
        answers = response["answers"]
        assert len(answers) == 1, f"expected exactly one answer in {annotations_file}"
        entities = answers[0]["answerContent"]["crowd-entity-annotation"]["entities"]
        index_to_annotation_raw[index] = entities

## Merge text and annotations while removing episode IDs from the text and adjusting entity offsets

In [None]:
@dataclass
class EntityMatch:
    """One annotated entity: its label and its character span in a document's text."""

    label: str  # annotator label, e.g. "Person"
    start_offset: int  # character offset at which the span starts
    end_offset: int  # character offset at which the span ends (presumably exclusive - TODO confirm)

In [None]:
@dataclass
class Doc:
    """One annotated document: its ID, its text without the ID line, and its entities."""

    id_: str  # document ID such as "PythonBytes:91"
    text: str  # document text with the leading ID line removed
    annotations: List[EntityMatch]  # entity spans, offsets relative to `text`

In [None]:
doc_id_to_doc = {}
for index, annotation_raw in index_to_annotation_raw.items():
    text_raw = index_to_text_raw[index]
    # Undo backslash escapes (e.g. "\\n") in the manifest text. Going through
    # latin-1 with "backslashreplace" keeps literal non-ASCII characters intact:
    # calling codecs.unicode_escape_decode() on a str would UTF-8-encode it and
    # read those bytes back as latin-1, turning e.g. "é" into "Ã©" and shifting
    # every entity offset after it.
    text = text_raw.encode("latin-1", "backslashreplace").decode("unicode_escape")
    # The first line of each text is the document ID (e.g. "PythonBytes:91").
    id_, text = text.split("\n", 1)
    id_offset = len(id_) + 1  # +1 for the newline stripped off by the split above
    # Annotation offsets refer to the full text, so shift them past the ID line.
    entity_matches = [
        EntityMatch(
            label=a["label"],
            start_offset=a["startOffset"] - id_offset,
            end_offset=a["endOffset"] - id_offset,
        )
        for a in annotation_raw
    ]
    assert id_ not in doc_id_to_doc, f"duplicate document ID {id_}"
    doc_id_to_doc[id_] = Doc(id_=id_, text=text, annotations=entity_matches)

## Create spaCy Doc objects and load entity annotations (the DocBins are written to disk after the train/dev/test split below)

In [None]:
# Small English pipeline; it tokenizes (and runs on) each document text below.
nlp = spacy.load(name="en_core_web_sm")

In [None]:
# spacy English NER labels https://spacy.io/models/en#en_core_web_sm-labels
# spacy glossary: https://github.com/explosion/spaCy/blob/master/spacy/glossary.py
# Ground Truth label -> closest label in spaCy's English NER scheme.
# Label inventory: https://spacy.io/models/en#en_core_web_sm-labels
# Label glossary:  https://github.com/explosion/spaCy/blob/master/spacy/glossary.py
label_to_spacy_ner_label = dict(
    Book="WORK_OF_ART",
    Person="PERSON",
    Software="PRODUCT",
)

In [None]:
id_to_spacy_docs = {}
# document ID -> entities dropped because they overlap an earlier-annotated entity
check_ents = defaultdict(list)
for d in doc_id_to_doc.values():
    # Encoding with "replace" turns every character that cannot be encoded as
    # UTF-8 (presumably lone surrogates left over from "\ud83d"-style escapes)
    # into a single "?", so character offsets into the text stay valid.
    spacy_doc = nlp(d.text.encode("utf8", "replace").decode("utf8"))
    my_ents = []
    for a in d.annotations:
        # alignment_mode="expand" widens the span to whole tokens when the
        # annotated characters do not line up with token boundaries.
        ent = spacy_doc.char_span(
            a.start_offset,
            a.end_offset,
            label=label_to_spacy_ner_label[a.label],
            alignment_mode="expand",
        )
        assert ent is not None, f"cannot map {a} onto the tokens of {d.id_}"
        my_ents.append(ent)
    assert len(my_ents) > 0, f"no annotations for {d.id_}"

    # set_ents does not accept overlapping spans: keep the first entity (in
    # annotation order) claiming a token; each rejected entity is recorded once.
    tokens_covered = set()
    non_overlapping_ents = []
    for ent in my_ents:
        ent_tokens = set(range(ent.start, ent.end))
        if ent_tokens & tokens_covered:
            check_ents[d.id_].append(ent)
        else:
            tokens_covered |= ent_tokens
            non_overlapping_ents.append(ent)
    assert len(non_overlapping_ents) > 0

    spacy_doc.user_data = {"id": d.id_}
    # TODO keep spacy_doc.ents from default pipeline by setting default="unmodified" below?
    spacy_doc.set_ents(non_overlapping_ents, default="missing")
    assert d.id_ not in id_to_spacy_docs
    id_to_spacy_docs[d.id_] = spacy_doc

In [None]:
# inspect entities
for id_,ents in check_ents.items():
    for e in ents:
        print(f"{id_} -- {e.label} -- '{e.as_doc()}'")
    print("-------------------------------")

## Visualize for sanity checking

In [None]:
# Highlight the entities of one sample document for a visual sanity check.
displacy.render(docs=id_to_spacy_docs["PythonBytes:91"], style="ent")

In [None]:
# Highlight the entities of a second sample document.
displacy.render(docs=id_to_spacy_docs["PythonBytes:100"], style="ent")

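## Split into train/dev/test per podcast, taking time into account (episodes are sorted by episode number, so the earliest episodes go to train and the latest to test)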

In [None]:
# Document IDs look like "<podcast>:<episode number>". rsplit on the last ":"
# so podcast names that themselves contain ":" do not break the unpacking.
podcast_episode = [x.rsplit(":", 1) for x in id_to_spacy_docs.keys()]
podcast_to_sorted_episode = defaultdict(list)
for podcast, episode in podcast_episode:
    podcast_to_sorted_episode[podcast].append(int(episode))
# Sort numerically; presumably higher episode numbers are later in time.
for episodes in podcast_to_sorted_episode.values():
    episodes.sort()

In [None]:
# Chronological split per podcast (no shuffling): roughly the earliest 70% of
# episodes go to train; the remaining ~30% are halved into dev (earlier) and
# test (later).
# NOTE(review): train_test_split raises ValueError when a split would come out
# empty, which appears to happen for podcasts with fewer than 4 episodes -
# confirm whether such podcasts need special handling.
train_ids = []
dev_ids = []
test_ids = []
for podcast, episodes in podcast_to_sorted_episode.items():
    train_eps, held_out_eps = train_test_split(episodes, test_size=0.3, shuffle=False)
    dev_eps, test_eps = train_test_split(held_out_eps, test_size=0.5, shuffle=False)
    train_ids += [f"{podcast}:{ep}" for ep in train_eps]
    dev_ids += [f"{podcast}:{ep}" for ep in dev_eps]
    test_ids += [f"{podcast}:{ep}" for ep in test_eps]

In [None]:
# Write one DocBin per split: ./train.spacy, ./dev.spacy, ./test.spacy.
for name, ids in zip(("train", "dev", "test"), (train_ids, dev_ids, test_ids)):
    output_path = Path(f"./{name}.spacy")
    # Never silently overwrite the output of an earlier run.
    if output_path.exists():
        raise FileExistsError(f"refusing to overwrite existing {output_path}")
    current_docs = [id_to_spacy_docs[id_] for id_ in ids]
    # DocBin drops Doc.user_data unless store_user_data=True; keep it so the
    # {"id": ...} attached to every Doc is written to disk as well.
    doc_bin = spacy.tokens.DocBin(docs=current_docs, store_user_data=True)
    doc_bin.to_disk(output_path)