In [None]:
import json
import random
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Tuple

import spacy
from spacy import displacy

In [None]:
# Show the kernel's working directory: the data path used in the loading cell
# ("../data/raw/pubtator/") is relative, so it is resolved against this directory.
!pwd

In [None]:
@dataclass(frozen=True)
class SpacyNerEntity:
    """One named-entity span in the ``(start, end, label)`` form used for spaCy training.

    Offsets are character positions relative to the start of the owning
    passage's text; ``end_char`` is computed as ``start_char + length``.
    """

    start_char: int
    end_char: int
    label: str

    @staticmethod
    def from_annotation(annotation: dict, passage_offset: int) -> "SpacyNerEntity":
        """Build a passage-relative entity from one annotation dict.

        Args:
            annotation: dict providing ``annotation["locations"][0]["offset"]``,
                ``annotation["locations"][0]["length"]`` and
                ``annotation["infons"]["type"]`` (presumably BioC/PubTator JSON
                -- confirm against the data files).
            passage_offset: offset of the owning passage; it is subtracted so
                the span indexes into that passage's ``text``.

        Returns:
            The entity with passage-relative offsets.

        Note:
            Only the first entry of ``locations`` is used; any further
            locations on the same annotation are ignored.
        """
        first_location = annotation["locations"][0]
        relative_start = first_location["offset"] - passage_offset
        relative_end = relative_start + first_location["length"]
        return SpacyNerEntity(
            start_char=relative_start,
            end_char=relative_end,
            label=annotation["infons"]["type"],
        )

    def as_primitive(self) -> Tuple[int, int, str]:
        """Return the entity as a plain ``(start_char, end_char, label)`` tuple."""
        return self.start_char, self.end_char, self.label
        
@dataclass(frozen=True)
class SpacyNerExample:
    """A passage's text together with its entity spans, for spaCy NER training."""

    text: str
    entities: List[SpacyNerEntity]

    @staticmethod
    def from_passage(passage: dict) -> "SpacyNerExample":
        """Build an example from one passage dict.

        Reads ``passage["text"]`` and converts every entry of
        ``passage["annotations"]`` with ``SpacyNerEntity.from_annotation``,
        passing ``passage["offset"]`` so each span is relative to the text.
        """
        passage_text = passage["text"]
        converted = [
            SpacyNerEntity.from_annotation(
                annotation=ann,
                passage_offset=passage["offset"],
            )
            for ann in passage["annotations"]
        ]
        return SpacyNerExample(text=passage_text, entities=converted)

    def as_primitive(self) -> Tuple[str, Dict[str, List[Tuple[int, int, str]]]]:
        """Return a ``(text, {"entities": [(start, end, label), ...]})`` pair.

        The second element is the annotation dict later handed to
        ``spacy.training.Example.from_dict``.
        """
        entity_tuples = []
        for entity in self.entities:
            entity_tuples.append(entity.as_primitive())
        return (self.text, {"entities": entity_tuples})
        
# Load PubTator JSONL files (one JSON document per line) and convert every
# passage into a ``(text, {"entities": [(start, end, label), ...]})`` tuple.
PUBTATOR_DIR = Path("../data/raw/pubtator/")  # relative to the kernel's working directory
PUBTATOR_GLOB = "0to9999/0to999.jsonl"  # widen this glob pattern to load more files


def load_pubtator_primitives(
    data_dir: Path, pattern: str
) -> List[Tuple[str, Dict[str, List[Tuple[int, int, str]]]]]:
    """Read every file matching ``pattern`` under ``data_dir`` as JSON Lines.

    Each non-blank line is parsed with ``json.loads`` and each entry of its
    ``"passages"`` list becomes one training tuple via ``SpacyNerExample``.

    Args:
        data_dir: directory to search.
        pattern: ``Path.glob`` pattern relative to ``data_dir``.

    Returns:
        One ``(text, {"entities": [...]})`` tuple per passage, in file order.
    """
    primitives = []
    # sorted(): glob order is filesystem-dependent; sorting keeps runs reproducible.
    for file_path in sorted(data_dir.glob(pattern)):
        # Decode explicitly as UTF-8; read_text() would otherwise use the
        # platform's locale encoding and can mangle/reject non-ASCII text.
        for line in file_path.read_text(encoding="utf-8").splitlines():
            if not line.strip():
                continue  # blank or whitespace-only line (e.g. trailing newline)
            for passage in json.loads(line)["passages"]:
                primitives.append(
                    SpacyNerExample.from_passage(passage=passage).as_primitive()
                )
    return primitives


# NOTE: the next cell rebinds ``training_data`` to spacy Example objects, so
# re-run this cell before re-running that one.
training_data = load_pubtator_primitives(PUBTATOR_DIR, PUBTATOR_GLOB)
# Fail loudly instead of silently training on nothing when the path is wrong.
assert training_data, (
    f"no passages found under {PUBTATOR_DIR.resolve()} matching {PUBTATOR_GLOB!r}"
)

labels = {
    label
    for _text, annotations in training_data
    for _start, _end, label in annotations["entities"]
}

labels

In [None]:
# Blank English pipeline with a single NER component to be trained from scratch.
nlp = spacy.blank("en")
ner = nlp.add_pipe("ner")

# ``labels`` is a set; the iteration order of a set of strings can differ
# between Python processes (hash randomisation). Sorting registers the labels
# in the same order on every run.
for label in sorted(labels):
    ner.add_label(label)

# Pair a tokenised, un-annotated Doc (the side the model predicts on) with the
# gold entity annotations for each passage.
# NOTE: this rebinds ``training_data`` (tuples -> Examples), so the cell is not
# idempotent: re-running it without first re-running the loading cell fails.
training_data = [
    spacy.training.Example.from_dict(nlp.make_doc(text), annotations)
    for text, annotations in training_data
]

In [None]:
# Train the NER component.
N_ITER = 10      # passes over the training data
BATCH_SIZE = 2   # examples per nlp.update() call
SEED = 0         # fixes shuffling and weight initialisation for reproducible runs

# Seeds Python's ``random`` (used for shuffling below) and numpy (used for
# weight initialisation), plus GPU backends when present.
spacy.util.fix_random_seed(SEED)

# ``begin_training`` is deprecated in spaCy v3; ``initialize`` replaces it.
# Supplying the examples lets spaCy initialise the components from real data,
# and the returned optimizer is reused for every update.
optimizer = nlp.initialize(lambda: training_data)

for itn in range(N_ITER):
    # Shuffle a copy so ``training_data`` keeps a stable order for the
    # inspection cell, which indexes into it.
    epoch_examples = random.sample(training_data, k=len(training_data))
    losses = {}
    for batch in spacy.util.minibatch(epoch_examples, size=BATCH_SIZE):
        nlp.update(batch, sgd=optimizer, losses=losses)
    print(f"iteration {itn + 1}/{N_ITER}: {losses}")

In [None]:
# Compare model predictions with gold annotations for a few examples.
# NOTE: these are training examples, so this shows fit, not generalisation.
ENTITY_COLORS = {
    "Chemical": "lightgreen",
    "Disease": "orange",
}
INSPECT_START, INSPECT_STOP = 10, 20


def render_entities(doc) -> None:
    """Render ``doc``'s entity spans with displaCy using ``ENTITY_COLORS``."""
    displacy.render(doc, style="ent", options={"colors": ENTITY_COLORS})


# Clamp the stop index so a small dataset does not raise IndexError.
for i in range(INSPECT_START, min(INSPECT_STOP, len(training_data))):
    print("\n\n===== prediction =====")
    render_entities(nlp(training_data[i].text))

    print("===== reference =====")
    render_entities(training_data[i].reference)