In [172]:
from dataclasses import dataclass
import pandas as pd
import pathlib
import typing as T

import ast

In [204]:
@dataclass
class RawDataConfig:
    root = pathlib.Path("/home/khaymon/hse_nlp/spacy_ner/data/raw")
    
    train = root / "train.csv"
    val = root / "validation.csv"
    test = root / "test.csv"
    
@dataclass
class PreparedDataConfig:
    root = pathlib.Path("/home/khaymon/hse_nlp/spacy_ner/data/prepared")
    
    train = root / "train.spacy"
    val = root / "validation.spacy"
    test = root / "test.spacy"

In [174]:
DATA_COLUMNS = [f"Section_{idx}" for idx in range(1, 7)]


def parse_json_str(json_str: str) -> T.Dict:
    return ast.literal_eval(json_str)

def parse_json_str_series(column: pd.Series) -> pd.Series:
    return column.apply(parse_json_str)

def read_preprocess_dataframe(path: pathlib.Path, data_columns: T.Sequence[str] = DATA_COLUMNS) -> pd.DataFrame:
    dataframe = pd.read_csv(path, index_col="ID")
    for column in data_columns:
        dataframe[column] = parse_json_str_series(dataframe[column])
        
    return dataframe


train = read_preprocess_dataframe(RawDataConfig.train)
val = read_preprocess_dataframe(RawDataConfig.val)
test = read_preprocess_dataframe(RawDataConfig.test)

In [199]:
Entity = T.Tuple[int, int, str]
Annotation = T.Dict[str, T.List[Entity]]
Dataset = T.List[T.Tuple[str, Annotation]]


def get_entities_from_dict(example: T.Dict) -> T.List[Entity]:
    raw_entities = example["Entity_Recognition"]
    if raw_entities is None:
        return []
    unique_entities = set(
        (ent["BeginOffset"], ent["EndOffset"], ent["Type"])
        for ent in example["Entity_Recognition"]
        if ent["BeginOffset"] != ent["EndOffset"]
    )
    return list(unique_entities)


def get_dataset_from_series(column: pd.Series) -> Dataset:
    dataset = []
    for row in column:
        text = row["Section_Content"]
        if text is None:
            continue

        annotation = {"entities": get_entities_from_dict(row)}
        dataset.append((text, annotation))
    
    return dataset


def get_dataset_from_dataframe(dataframe: pd.DataFrame, columns: T.Sequence[str] = DATA_COLUMNS):
    dataset = []
    for column in columns:
        dataset.extend(get_dataset_from_series(dataframe[column]))

    return dataset


train_dataset = get_dataset_from_dataframe(train)
val_dataset = get_dataset_from_dataframe(val)
test_dataset = get_dataset_from_dataframe(test)

In [200]:
from spacy.tokens import DocBin
import spacy

nlp = spacy.blank("en")


def get_docbin(dataset: Dataset) -> DocBin:
    doc_bin = DocBin()
    for text, annotations in dataset:
        doc = nlp(text)
        entities = []
        for start_idx, end_idx, label in annotations["entities"]:
            span = doc.char_span(start_idx, end_idx, label=label)
            if span is not None:
                entities.append(span)
        try:
            doc.ents = entities
        except:
            ...
        doc_bin.add(doc)
        
    return doc_bin


train_docbin = get_docbin(train_dataset)
val_docbin = get_docbin(val_dataset)
test_docbin = get_docbin(test_dataset)

In [205]:
train_docbin.to_disk(PreparedDataConfig.train)
val_docbin.to_disk(PreparedDataConfig.val)
test_docbin.to_disk(PreparedDataConfig.test)

In [213]:
train_docbin.

TypeError: 'DocBin' object is not iterable

In [170]:
epochs = 20
optimizer = nlp.create_optimizer()

nlp = spacy.blank("en")

for _ in range(epochs):
    random.shuffle(train_dataset)
    
    for raw_example in train_dataset:
        doc = nlp(raw_example[0])
        annotation = raw_example[1]
        
        print(annotation, ';'.join(raw_example[0][from_idx:to_idx] for from_idx, to_idx, _ in annotation["entities"]))
        example = Example.from_dict(doc, annotation)
        nlp.update([example], sgd=optimizer)
    
    ner = nlp.get_pipe("ner")
    ner.to_disk("ner_model_spacy")

{'entities': [(5, 18, 'TREATMENT'), (20, 33, 'TREATMENT'), (44, 56, 'DX_NAME'), (104, 120, 'TREATMENT'), (121, 132, 'TIME_TO_MEDICATION_NAME'), (167, 181, 'PROBLEM'), (191, 200, 'DX_NAME'), (214, 254, 'PROBLEM'), (229, 230, 'NUMBER'), (279, 294, 'DX_NAME'), (307, 320, 'PROBLEM'), (321, 329, 'TIME_TO_DX_NAME'), (333, 341, 'DX_NAME'), (350, 355, 'DX_NAME'), (399, 405, 'DX_NAME'), (414, 432, 'DX_NAME'), (501, 511, 'DX_NAME'), (525, 529, 'DX_NAME'), (531, 538, 'DX_NAME'), (540, 548, 'DX_NAME'), (552, 585, 'PROBLEM'), (589, 593, 'SYSTEM_ORGAN_SITE'), (598, 602, 'SYSTEM_ORGAN_SITE'), (603, 616, 'DX_NAME'), (630, 634, 'DX_NAME'), (636, 644, 'DX_NAME'), (646, 653, 'DX_NAME'), (657, 686, 'PROBLEM'), (690, 695, 'DX_NAME'), (709, 722, 'TEST'), (726, 728, 'NUMBER'), (743, 752, 'DX_NAME'), (781, 790, 'DX_NAME'), (801, 809, 'DX_NAME'), (813, 818, 'PROBLEM'), (823, 838, 'PROBLEM'), (842, 852, 'DX_NAME'), (872, 876, 'DX_NAME'), (908, 913, 'SYSTEM_ORGAN_SITE'), (965, 981, 'PROBLEM'), (1001, 1012, 'DX_N

ValueError: [E103] Trying to set conflicting doc.ents: '(214, 254, 'PROBLEM')' and '(229, 230, 'NUMBER')'. A token can only be part of one entity, so make sure the entities you're setting don't overlap. To work with overlapping entities, consider using doc.spans instead.