# Load data

For now, we restrict the data to a single named-entity category (PERSON).

In [2]:
import logging

import dedupe

# Send dedupe's log records (DEBUG and above) to one stream handler. The
# handler list is replaced rather than appended to, so re-running this cell
# does not attach a second handler and duplicate every line.
dedupe_logger = logging.getLogger(dedupe.__name__)
dedupe_logger.setLevel(logging.DEBUG)
dedupe_logger.handlers = [logging.StreamHandler()]

In [3]:
import neo4j

# Async driver for the local Neo4j instance (no auth is passed here)
neo4j_driver = neo4j.AsyncGraphDatabase.driver(uri="neo4j://localhost:7687")

In [4]:
# Local Datashare instance; debug document links are built from the project URL
DATASHARE_BASE_URL = "http://localhost:8080"
DATASHARE_PROJECT_URL = DATASHARE_BASE_URL + "/#/d/local-datashare"

In [5]:
from typing import AsyncGenerator, Optional


async def retrieve_ne(
    session: neo4j.AsyncSession,
    ne_category: str,
    *,
    limit: Optional[int] = None,
) -> AsyncGenerator[neo4j.Record, None]:
    """Yield the named entities of one category along with their documents.

    Each yielded record has three fields:

    - ``ne``: a ``NamedEntity`` node that also carries the ``ne_category`` label
    - ``doc``: a ``Document`` node the entity has an outgoing relationship to
    - ``rootDoc``: the ``Document`` that ``doc`` points to through a
      ``HAS_PARENT`` relationship, or ``None`` when there is none

    :param session: async Neo4j session used to run the query
    :param ne_category: label of the entity category, e.g. ``"PERSON"``
    :param limit: maximum number of rows to return; ``None`` means no limit
    """
    # Labels cannot be passed as query parameters, so the category is
    # interpolated. It is backtick-quoted (embedded backticks doubled) so it is
    # always read as a single label and cannot inject Cypher.
    label = ne_category.replace("`", "``")
    # ``[:HAS_PARENT]`` needs the colon: ``[HAS_PARENT]`` would only bind a
    # variable named HAS_PARENT and match relationships of any type.
    query = f"""MATCH (ne:NamedEntity:`{label}`)-->(doc:Document)
OPTIONAL MATCH (doc)-[:HAS_PARENT]->(rootDoc:Document)
RETURN ne, doc, rootDoc
"""
    parameters = {}
    # ``is not None`` so that ``limit=0`` is honoured instead of meaning "no limit"
    if limit is not None:
        query += "LIMIT $limit\n"
        parameters["limit"] = limit
    res = await session.run(query, parameters)
    async for rec in res:
        yield rec

In [6]:
from neo4j_app.ml.graph_dedupe import (
    NE_DEBUG_DOC_URL,
    NE_DEBUG_FILENAME,
    NE_DOC_CONTENT_TYPE,
    NE_DOC_DIR_NAME,
    NE_DOC_FILENAME,
    NE_DOC_ID,
    NE_DOC_ROOT_ID,
)
from neo4j_app.constants import NE_MENTION_NORM
import itertools
import string

In [7]:
from neo4j_app.constants import DOC_CONTENT_TYPE, DOC_DIRNAME, DOC_ID, DOC_PATH
from typing import Dict

# Translation table mapping every ASCII punctuation character to a space
_TRANSLATE_PUNCT = str.maketrans({char: " " for char in string.punctuation})


# TODO: refine this aggressive preprocessing (every punctuation character is replaced by a space)...


def _replace_double_white_spaces(s: str) -> str:
    """Collapse every run of consecutive space characters into a single space.

    Only the space character is handled: tabs and newlines are untouched and
    no stripping is done, so a leading or trailing run becomes one space.
    """
    previous = None
    # Keep halving the runs of spaces until a pass changes nothing
    while previous != s:
        previous, s = s, s.replace("  ", " ")
    return s


def preprocess_filename(filename: str) -> Optional[str]:
    """Normalize a file name into lowercase, space-separated words.

    Every punctuation character becomes a space, runs of spaces are collapsed,
    and the result is lowercased and stripped, e.g. ``"My_File-v2.PDF"``
    becomes ``"my file v2 pdf"``.

    :param filename: raw file name
    :return: the normalized name, or ``None`` when nothing is left after
        normalization (empty or punctuation-only input)
    """
    normalized = _replace_double_white_spaces(filename.translate(_TRANSLATE_PUNCT))
    normalized = normalized.lower().strip()
    # Test for emptiness only after stripping: a punctuation-only name
    # collapses to a single space, which is truthy until it is stripped
    return normalized or None


def preprocess_dirname(dirname: str) -> str:
    """Turn a ``/``-separated directory path into space-separated components.

    Empty components (from leading, trailing or repeated slashes) are dropped,
    e.g. ``"/home//user/docs/"`` becomes ``"home user docs"``.
    """
    return " ".join(filter(None, dirname.split("/")))


def doc_url(project_url: str, *, doc_id: str, root_id: Optional[str]) -> str:
    """Build the Datashare URL ``{project_url}/{root}/{doc_id}`` of a document.

    ``root`` is ``root_id`` when it is not ``None`` and ``doc_id`` otherwise.
    """
    root = doc_id if root_id is None else root_id
    return f"{project_url}/{root}/{doc_id}"


def neo4j_to_record(record: neo4j.Record, project_url: str) -> Dict:
    """Flatten one ``retrieve_ne`` row into a dict keyed by the ``NE_*`` columns.

    Reads the ``ne``, ``doc`` and ``rootDoc`` fields of ``record``. The
    directory and file name are normalized for matching. Two debug columns are
    added: the Datashare URL of the document and its raw file name.

    :param record: row with ``ne``, ``doc`` and (possibly ``None``) ``rootDoc``
    :param project_url: Datashare project URL used to build the debug link
    :return: the flat record
    """
    ne_node = record["ne"]
    doc_node = record["doc"]
    doc_id = doc_node[DOC_ID]
    # NOTE(review): assumes DOC_PATH uses "/" separators — confirm for Windows paths
    basename = doc_node[DOC_PATH].split("/")[-1]
    root_node = record["rootDoc"]
    root_id = None if root_node is None else root_node[DOC_ID]
    return {
        NE_MENTION_NORM: ne_node[NE_MENTION_NORM],
        NE_DOC_ID: doc_id,
        NE_DOC_DIR_NAME: preprocess_dirname(doc_node[DOC_DIRNAME]),
        NE_DOC_FILENAME: preprocess_filename(basename),
        NE_DOC_CONTENT_TYPE: doc_node[DOC_CONTENT_TYPE],
        NE_DOC_ROOT_ID: root_id,
        # Debug-only columns
        NE_DEBUG_DOC_URL: doc_url(project_url, doc_id=doc_id, root_id=root_id),
        NE_DEBUG_FILENAME: basename,
    }

In [8]:
import csv
from typing import AsyncIterable, List, TextIO



In [9]:
from neo4j_app import ROOT_DIR

# Every artifact of the PERSON deduplication run lives under <ROOT_DIR>/data
DATA_PATH = ROOT_DIR / "data"
records_path = DATA_PATH / "person_records.csv"  # exported NE records
training_set_path = DATA_PATH / "person_training.csv"  # passed as training_path
trained_model_path = DATA_PATH / "person_model.pickle"  # passed as model_path
clusters_path = DATA_PATH / "person_clusters.csv"
# NOTE(review): unlike the other files this one has no "person_" prefix —
# confirm it is meant to be shared across entity categories
excluded_set_path = DATA_PATH / "excluded.txt"

In [10]:
from neo4j_app.ml.graph_dedupe import NE_FIELDNAMES, async_write_dataset

# TODO: remove the limit of training data
NUM_SAMPLES = None

async with neo4j_driver.session() as sess:
    if not records_path.exists():
        with records_path.open("w") as df:
            recs = (
                neo4j_to_record(rec, project_url=DATASHARE_PROJECT_URL)
                async for rec in retrieve_ne(
                    sess, ne_category="PERSON", limit=NUM_SAMPLES
                )
            )
            await async_write_dataset(recs, fieldnames=NE_FIELDNAMES, dataset_f=df)

# Training

In [None]:
import functools
from neo4j_app.ml.graph_dedupe import DocumentGraphDedupe, person_fields, run_training

# Training hyper-parameters
training_sample_size = 50000  # TODO: increase the training sample size
target_recall = 0.8

# Factories handed to run_training: DocumentGraphDedupe bound to the NE_DOC_ID
# column, and person_fields with inside_doc=True
dedupe_getter = functools.partial(DocumentGraphDedupe, doc_id=NE_DOC_ID)
fields_getter = functools.partial(person_fields, inside_doc=True)

model = run_training(
    records_path,
    dedupe_getter=dedupe_getter,
    fields_getter=fields_getter,
    excluded_path=excluded_set_path,
    model_path=trained_model_path,
    training_path=training_set_path,
    sample_size=training_sample_size,
    # NOTE(review): NE_MENTION_NORM is repeated across documents, so it may
    # not uniquely identify a record — confirm what run_training expects here
    id_column=NE_MENTION_NORM,
    recall=target_recall,
)