In [60]:
import logging

import dedupe

# Send dedupe's debug logs to a StreamHandler. The handler list is reset before adding
# the handler so that re-running this cell does not stack duplicate handlers
# (which would print every log line several times).
dedupe_logger = logging.getLogger(dedupe.__name__)
dedupe_logger.handlers = []
dedupe_logger.addHandler(logging.StreamHandler())
dedupe_logger.setLevel(logging.DEBUG)

In [61]:
from neo4j_app import ROOT_DIR

# Every artifact of this notebook lives under ROOT_DIR/data.
DATA_PATH = ROOT_DIR / "data"

# Document-level artifacts: raw mention records, clusters, doc-level dedupe output,
# excluded mentions and labeled pairs.
records_path = DATA_PATH.joinpath("person_records.csv")
clusters_path = DATA_PATH.joinpath("person_clusters.json")
doc_level_dedupe_path = DATA_PATH.joinpath("person_deduped.csv")
excluded_set_path = DATA_PATH.joinpath("excluded.txt")
training_set_path = DATA_PATH.joinpath("person_training.csv")

# Graph-level artifacts produced by this notebook.
graph_level_trained_model_path = DATA_PATH.joinpath("graph_level_person_model.pickle")
graph_level_records_path = DATA_PATH.joinpath("graph_level_person_records.csv")
graph_level_training_path = DATA_PATH.joinpath("graph_level_person_training.json")
graph_level_excluded_set_path = DATA_PATH.joinpath("graph_level_excluded.txt")

# Load a sample of the data

In [62]:
from neo4j_app.ml.graph_dedupe import NE_DOC_ID, NE_MENTION_NORM_DOC_ID
import csv
from dedupe._typing import RecordDict
from typing import Dict, Set, TextIO, Tuple
from neo4j_app.constants import NE_MENTION_NORM


def read_graph_level_records(
    dataset_f: TextIO,
    *,
    mention_column: str,
    doc_id_column: str,
    invalid_mentions: Set[str],
) -> Tuple[Dict[str, RecordDict], Set[str]]:
    """Read mention rows from a CSV file and key them by ``(mention, doc id)``.

    Each row gets the id ``str((row[mention_column], row[doc_id_column]))``, which is
    also stored in the row under ``NE_MENTION_NORM_DOC_ID``.

    :param dataset_f: open text file whose first line is the CSV header
    :param mention_column: column holding the (normalized) mention
    :param doc_id_column: column holding the document id
    :param invalid_mentions: mentions whose rows are dropped; ``None`` is accepted
        and treated as "no exclusions"
    :return: ``(records, invalid_records)``: ``records`` maps record id to row,
        ``invalid_records`` holds the ids of the rows that were dropped
    """
    if invalid_mentions is None:
        # Callers may not have an exclusion set at hand; `x in None` would raise.
        invalid_mentions = set()
    records: Dict[str, RecordDict] = {}
    invalid_records: Set[str] = set()
    for row in csv.DictReader(dataset_f):
        # TODO: this simplification assumes that all identical (mentionNorm, docId) pairs
        #  refer to the same entity, which is not true in practice: if Michel A and
        #  Michel B are both referred to as "Michel" in doc2, their rows share the same
        #  id and only the last one read is kept.
        mention = row[mention_column]
        rec_id = str((mention, row[doc_id_column]))
        if mention in invalid_mentions:
            invalid_records.add(rec_id)
            continue
        row[NE_MENTION_NORM_DOC_ID] = rec_id
        records[rec_id] = row
    return records, invalid_records

In [63]:
from typing import Sequence
from dedupe._typing import Data


def sample_data(dataset: Data, n_samples: int, sort_keys: Sequence[str]) -> Data:
    """Return a deterministic subset of ``dataset``.

    Records are ordered by the tuple of their ``sort_keys`` values and the first
    ``n_samples`` (id, record) pairs are kept; no randomness is involved.
    """

    def _by_sort_keys(item):
        _rec_id, record = item
        return tuple(record[key] for key in sort_keys)

    ordered_items = sorted(dataset.items(), key=_by_sort_keys)
    return {rec_id: record for rec_id, record in ordered_items[:n_samples]}

In [64]:
from copy import deepcopy
from dedupe._typing import RecordDict, TrainingData
from typing import Callable, Dict


def add_mention_cluster_field(
    record: RecordDict,
    clusters: Dict,
    *,
    mention_field: str,
    cluster_field: str,
    doc_id_field: str
) -> RecordDict:
    """Return a deep copy of ``record`` with two extra fields.

    - ``cluster_field``: ``clusters[record[mention_field]]`` (raises ``KeyError``
      when the mention has no cluster);
    - ``NE_MENTION_NORM_DOC_ID``: ``str((mention, doc id))``, the same id format
      ``read_graph_level_records`` uses.

    The input record is left untouched.
    """
    mention = record[mention_field]
    augmented = deepcopy(record)
    augmented[cluster_field] = clusters[mention]
    augmented[NE_MENTION_NORM_DOC_ID] = str((mention, record[doc_id_field]))
    return augmented


def augment_training_set(
    labeled_pairs: TrainingData, augment_fn: Callable[[RecordDict], RecordDict]
) -> TrainingData:
    """Apply ``augment_fn`` to both records of every labeled pair.

    Only the ``"distinct"`` and ``"match"`` entries are carried over to the result;
    ``distinct`` pairs are processed before ``match`` pairs.
    """

    def _augment_pairs(label):
        return [
            (augment_fn(left), augment_fn(right))
            for left, right in labeled_pairs[label]
        ]

    distinct_pairs = _augment_pairs("distinct")
    match_pairs = _augment_pairs("match")
    return TrainingData(distinct=distinct_pairs, match=match_pairs)

In [65]:
from neo4j_app.ml.graph_dedupe import read_records

# Map each doc-level record id (the NE_MENTION_NORM column) to the "cluster_id" stored
# in the doc-level dedupe output; no record is excluded here.
with doc_level_dedupe_path.open() as f:
    doc_level_clusters = {
        rec_id: rec["cluster_id"]
        for rec_id, rec in read_records(
            f, id_column=NE_MENTION_NORM, invalid_ids=set()
        ).items()
    }

In [66]:
from neo4j_app.ml.graph_dedupe import (
    NE_FIELDNAMES,
    NE_MENTION_CLUSTER,
)

# Set to an int to work on a deterministic subset (first N records ordered by doc id);
# None keeps the full dataset.
N_SAMPLES = None
GRAPH_FIELDNAMES = NE_FIELDNAMES + [NE_MENTION_CLUSTER, NE_MENTION_NORM_DOC_ID]


def _read_non_empty_lines(path) -> Set[str]:
    """Return the set of stripped, non-empty lines of the text file at ``path``."""
    with path.open() as lines_f:
        return {line.strip() for line in lines_f if line.strip()}


if graph_level_records_path.exists():
    # Graph-level records were already computed: reload them, dropping the record ids
    # listed in the graph-level excluded file.
    invalid_ids = _read_non_empty_lines(graph_level_excluded_set_path)
    with graph_level_records_path.open() as f:
        graph_level_data = read_records(
            f,
            id_column=NE_MENTION_NORM_DOC_ID,
            invalid_ids=invalid_ids,
        )
else:
    # Build the graph-level records from the raw mention records. The doc-level
    # excluded mentions are always loaded, because read_graph_level_records needs a
    # concrete set to filter against, whether or not the graph-level excluded file
    # already exists.
    inv_mentions = _read_non_empty_lines(excluded_set_path)
    with records_path.open() as f:
        graph_level_data, invalid_ids = read_graph_level_records(
            f,
            mention_column=NE_MENTION_NORM,
            doc_id_column=NE_DOC_ID,
            invalid_mentions=inv_mentions,
        )
    graph_level_data = {
        rec_id: add_mention_cluster_field(
            rec,
            clusters=doc_level_clusters,
            mention_field=NE_MENTION_NORM,
            cluster_field=NE_MENTION_CLUSTER,
            doc_id_field=NE_DOC_ID,
        )
        for rec_id, rec in graph_level_data.items()
    }
    if graph_level_excluded_set_path.exists():
        invalid_ids = _read_non_empty_lines(graph_level_excluded_set_path)
    else:
        # Sorted so the file content does not depend on (hash-seeded) set order.
        graph_level_excluded_set_path.write_text("\n".join(sorted(invalid_ids)))

if N_SAMPLES is not None:
    # Sort by the same doc-id column the record ids were built from.
    graph_level_data = sample_data(
        graph_level_data, n_samples=N_SAMPLES, sort_keys=[NE_DOC_ID]
    )

In [67]:
from typing import Set
from dedupe._typing import TrainingData
from dedupe import read_training


def filter_training_set(
    labeled_pairs: TrainingData,
    invalid: Set[str],
    *,
    mention_field: str = NE_MENTION_NORM,
) -> TrainingData:
    """Return a copy of ``labeled_pairs`` without the pairs involving an invalid mention.

    A ``"distinct"`` or ``"match"`` pair is dropped as soon as either of its records
    has ``record[mention_field]`` in ``invalid``. The input is deep-copied first, so
    it is never mutated, and any other key it holds is carried over unchanged.

    :param labeled_pairs: training data with ``"distinct"`` and ``"match"`` pair lists
    :param invalid: mentions to exclude
    :param mention_field: record field compared against ``invalid``
    """
    filtered = deepcopy(labeled_pairs)
    for label in ("distinct", "match"):
        filtered[label] = [
            (left, right)
            for left, right in filtered[label]
            if left[mention_field] not in invalid
            and right[mention_field] not in invalid
        ]
    return filtered

# Compute graph-level features

In [68]:
from neo4j_app.ml.graph_dedupe import write_dataset

# Cache the graph-level records; the loading cell reuses this file when it exists.
# The write is skipped when working on a sample (N_SAMPLES set): persisting the subset
# would make every later run, even with N_SAMPLES = None, reload only that subset.
if N_SAMPLES is None:
    with graph_level_records_path.open("w") as f:
        write_dataset(
            graph_level_data.values(), fieldnames=GRAPH_FIELDNAMES, dataset_f=f
        )

In [69]:
import functools
from dedupe import write_training

# Build the graph-level training set once from the doc-level labeled pairs: drop pairs
# involving an excluded mention, then add the mention-cluster field and the
# (mention, doc id) id to every record.
if not graph_level_training_path.exists():
    with excluded_set_path.open() as xf:
        inv_mentions = {line.strip() for line in xf if line.strip()}
    add_mention_cluster_field_fn = functools.partial(
        add_mention_cluster_field,
        doc_id_field=NE_DOC_ID,
        cluster_field=NE_MENTION_CLUSTER,
        mention_field=NE_MENTION_NORM,
        clusters=doc_level_clusters,
    )
    with training_set_path.open() as f:
        training_set = filter_training_set(read_training(f), invalid=inv_mentions)
    new_training_set = augment_training_set(training_set, add_mention_cluster_field_fn)
    with graph_level_training_path.open("w") as f:
        write_training(new_training_set, f)

# Make the most of the already annotated data

Keep the existing labeled pairs and only compute the new graph-level features for their records.

In [70]:
import functools
from neo4j_app.ml.graph_dedupe import (
    ConfigurableClassifierDedupe,
    person_fields,
    run_training,
)

# Train the graph-level model on the graph-level records and training set, using
# person fields built with `inside_docs=False`.
training_sample_size = 50_000  # TODO: increase
target_recall = 0.8
clf_args = {"max_iter": 100_000}

model = run_training(
    graph_level_data,
    dedupe_getter=functools.partial(ConfigurableClassifierDedupe, clf_args=clf_args),
    fields_getter=functools.partial(person_fields, inside_docs=False),
    training_path=graph_level_training_path,
    model_path=graph_level_trained_model_path,
    excluded_path=excluded_set_path,
    # NOTE(review): graph-level records are keyed by NE_MENTION_NORM_DOC_ID; confirm
    #  that run_training really expects the mention column (and the doc-level
    #  excluded file) here.
    id_column=NE_MENTION_NORM,
    sample_size=training_sample_size,
    recall=target_recall,
)

reading training from file


KeyboardInterrupt: 