# Imports

In [1]:
from agatha.ml.hypothesis_predictor import HypothesisPredictor
from agatha.ml.hypothesis_predictor.predicate_util import clean_coded_term
from pathlib import Path
import json
import torch
from tqdm.notebook import tqdm
from sklearn.metrics import average_precision_score, roc_auc_score


# Its likely that we're going to get a "source code has changed" warning
# We're okay with that
import warnings
warnings.filterwarnings('ignore')

# Configure and Load the Agatha Model

In [2]:
# ------------------------------------------------------------------
# YOU NEED TO CHANGE THESE PATHS to match your local Agatha release.
# ------------------------------------------------------------------
MODEL_PATH = Path(
    "/zfs/safrolab/users/jsybran/agatha/data/releases/2020"
    "/hypothesis_predictor/2020_05_04.pt"
)
GRAPH_DB_PATH = Path(
    "/zfs/safrolab/users/jsybran/agatha/data/releases/2020"
    "/predicate_graph.sqlite3"
)
ENTITY_DB_PATH = Path(
    "/zfs/safrolab/users/jsybran/agatha/data/releases/2020"
    "/predicate_entities.sqlite3"
)
EMBEDDING_DIR = Path(
    "/zfs/safrolab/users/jsybran/agatha/data/releases/2020"
    "/embeddings"
)

# Fail fast on a bad path, before any expensive loading happens.
assert MODEL_PATH.is_file(), f"Cannot find {MODEL_PATH}"
# Only the exported .pt model format is accepted (e.g. not a .ckpt file).
assert MODEL_PATH.suffix == ".pt", "Expecting model in pytorch model format, not ckpt."
assert GRAPH_DB_PATH.is_file(), f"Cannot find {GRAPH_DB_PATH}"
assert ENTITY_DB_PATH.is_file(), f"Cannot find {ENTITY_DB_PATH}"
assert EMBEDDING_DIR.is_dir(), f"Cannot find {EMBEDDING_DIR}"

In [3]:
# Load the Agatha model
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code, so only load model files from a trusted source. On newer PyTorch
# releases (2.6+), where torch.load defaults to weights_only=True, loading a
# whole pickled model object like this may require weights_only=False.
# Confirm against the installed torch version.
model = torch.load(MODEL_PATH)
# Point the loaded model at the graph DB, entity DB and embedding directory
# configured in the paths cell.
model.configure_paths(
    graph_db=GRAPH_DB_PATH,
    entity_db=ENTITY_DB_PATH,
    embedding_dir=EMBEDDING_DIR
)

In [4]:
# Move the model to GPU, optional step to improve performance.
# Guarded so the notebook still runs end-to-end (Restart & Run All) on
# machines without a CUDA device, where model.cuda() would raise.
if torch.cuda.is_available():
    model = model.cuda()

In [5]:
# Preload the model, optional step to improve performance
# Warning, will take a minute
# NOTE(review): what preload() loads is defined in the agatha library, not
# here; presumably the graph / embedding data configured above. Confirm.
model.preload()

# Load Test Data

In [6]:
# FIXED PATHS TO TEST DATA
BENCHMARK_DIR = Path("/zfs/safrolab/users/jsybran/agatha/data/benchmarks/predicates_2015/")

# Typed predicate query sets (JSON) and the Moliere 2015 positive/negative lists.
TYPED_POPULAR_PREDICATES = BENCHMARK_DIR / "all_pairs_top_20_types.json"
MOLIERE_BENCHMARK_POSITIVES = BENCHMARK_DIR / "moliere_2015/published.txt"
MOLIERE_BENCHMARK_NEGATIVES = BENCHMARK_DIR / "moliere_2015/noise.txt"

# Check that all files are in place before anything tries to read them.
for benchmark_file in (
    TYPED_POPULAR_PREDICATES,
    MOLIERE_BENCHMARK_POSITIVES,
    MOLIERE_BENCHMARK_NEGATIVES,
):
    assert benchmark_file.is_file(), f"Cannot find {benchmark_file} file"

In [7]:
# LOAD DATA
"""
Schema:
{
  "<query_set_name>": [
    {
      "source": "<name>",
      "target": "<name>",
      "label": <0 or 1>
    },
    ...
  ],
  ...
}

The file TYPED_POPULAR_PREDICATES groups predicates by predicate type
(reportedly 100 predicate entries per type).
NOTE(review): the per-set totals printed after filtering are in the
thousands, so confirm how many entries per type this file actually holds.
Types include:
  'aapp:dsyn', 'aapp:gngm', 'bpoc:aapp', 'gngm:neop', 'dsyn:dsyn',
  'cell:aapp', 'gngm:aapp', 'dsyn:humn', 'gngm:celf', 'orch:gngm',
  'phsu:dsyn', 'bacs:aapp', 'gngm:cell', 'gngm:dsyn', 'gngm:gngm',
  'aapp:neop', 'aapp:aapp', 'topp:dsyn', 'bacs:gngm', 'aapp:cell'
"""
# Use a context manager so the file handle is closed promptly, and decode as
# UTF-8 (JSON's standard encoding) rather than the platform default.
with open(TYPED_POPULAR_PREDICATES, encoding="utf-8") as json_file:
    query_sets = json.load(json_file)

In [8]:
# Add the moliere benchmark as a query set
"""
These files look like:
  C0454279|C0043251|2016
  C1563740|C0729627|2016
  C1522549|C0023759|2016
  C0516977|C0454448|2017
  ...

Note, Agatha can handle these names automatically
"""
query_sets["moliere"] = []
for path, label in [
    (MOLIERE_BENCHMARK_POSITIVES, 1),
    (MOLIERE_BENCHMARK_NEGATIVES, 0),
]:
    with open(path) as file:
        for line in file:
            line = line.strip()
            # Skip blank lines (e.g. a trailing newline at end of file);
            # they would otherwise fail the 3-way unpack below.
            if not line:
                continue
            # Third field (a year, per the sample above) is not used.
            source, target, _ = line.split("|")
            # This replaces "C###" with "m:c###" (per the original note;
            # the conversion itself lives in clean_coded_term).
            source = clean_coded_term(source)
            target = clean_coded_term(target)
            query_sets["moliere"].append(dict(
                source=source,
                target=target,
                label=label
            ))

In [9]:
# Drop every predicate whose source or target term is missing from this
# model's graph. Each query set's list is rewritten in place.
# NOTE(review): if model.graph.keys() returns a list rather than a set-like
# view, each `in` check below is a linear scan. Consider set(...) if slow.
valid_keys = model.graph.keys()
for set_name, predicates in query_sets.items():
    kept_predicates, removed_predicates = [], []
    for pred in predicates:
        is_known = pred["source"] in valid_keys and pred["target"] in valid_keys
        bucket = kept_predicates if is_known else removed_predicates
        bucket.append(pred)
    # Report how many examples of each label were dropped from this set.
    if removed_predicates:
        num_pos = sum(1 for p in removed_predicates if p["label"] == 1)
        num_neg = sum(1 for p in removed_predicates if p["label"] == 0)
        print(f"Removed {num_pos} positive and {num_neg} negative examples from the {set_name} set")
    predicates[:] = kept_predicates

Removed 2 positive and 644 negative examples from the moliere set


In [10]:
# Print out query set details: total size and share of positive examples.
for set_name, predicates in query_sets.items():
    num_pos = sum(1 for p in predicates if p["label"] == 1)
    # A set can be left empty by the term filter; avoid dividing by zero.
    pos_pct = num_pos / len(predicates) * 100 if predicates else 0.0
    print(f"{set_name}:\tTotal: {len(predicates)}\tPos: {pos_pct:2.2f}%")

aapp:dsyn:	Total: 4108	Pos: 8.23%
aapp:gngm:	Total: 2550	Pos: 9.06%
bpoc:aapp:	Total: 4161	Pos: 8.41%
gngm:neop:	Total: 2496	Pos: 18.59%
dsyn:dsyn:	Total: 5223	Pos: 4.14%
cell:aapp:	Total: 2524	Pos: 8.84%
gngm:aapp:	Total: 3166	Pos: 8.97%
dsyn:humn:	Total: 3225	Pos: 7.13%
gngm:celf:	Total: 1018	Pos: 21.22%
orch:gngm:	Total: 4488	Pos: 9.00%
phsu:dsyn:	Total: 4895	Pos: 6.44%
bacs:aapp:	Total: 3614	Pos: 5.59%
gngm:cell:	Total: 2412	Pos: 9.70%
gngm:dsyn:	Total: 4638	Pos: 6.08%
gngm:gngm:	Total: 4286	Pos: 6.25%
aapp:neop:	Total: 3002	Pos: 12.46%
aapp:aapp:	Total: 4828	Pos: 4.04%
topp:dsyn:	Total: 6269	Pos: 4.40%
bacs:gngm:	Total: 2477	Pos: 8.44%
aapp:cell:	Total: 1677	Pos: 10.38%
moliere:	Total: 1354	Pos: 73.71%


# Evaluate Each Predicate

This step is slow: the model scores every predicate in every query set.

In [None]:
# Score every predicate in every query set. This is the slow step.
for set_name, predicates in query_sets.items():
    print("Predicting:", set_name)
    # (source, target) term pairs, in the same order as `predicates`. The
    # tqdm wrapper shows progress, presumably as predict_from_terms iterates.
    term_pairs = [(pred["source"], pred["target"]) for pred in predicates]
    scores = model.predict_from_terms(tqdm(term_pairs), batch_size=64)
    # Attach each model score to its predicate dict.
    for score, predicate in zip(scores, predicates):
        predicate["prediction"] = score

Predicting: aapp:dsyn


HBox(children=(FloatProgress(value=0.0, max=4108.0), HTML(value='')))




# Calculate Metrics

In [16]:
def reciprocal_rank(labels_in_order):
    """
    Reciprocal of the 1-based rank of the first positive (label == 1).

    ``labels_in_order`` must already be ranked by descending score: if
    i < j then item i got a higher score than item j. Returns 0 when no
    positive label is present.
    """
    first_hit_rank = next(
        (rank for rank, lbl in enumerate(labels_in_order, start=1) if lbl == 1),
        None,
    )
    if first_hit_rank is None:
        return 0
    return 1 / first_hit_rank

In [18]:
def precision_at_k(labels_in_order, k):
    """
    Fraction of the top ``k`` ranked items that are positive.

    ``labels_in_order`` must already be ranked by descending score (if
    i < j, item i scored higher than item j), with positives labelled 1 and
    negatives 0. The denominator is always ``k``, even when fewer than ``k``
    labels are supplied.
    """
    hits_in_top_k = sum(labels_in_order[:k])
    return hits_in_top_k / k

In [21]:
def average_precision_at_k(labels_in_order, k):
    """
    Average precision over the top ``k`` ranked items.

    Assumes that if i < j then labels_in_order[i] got a higher score than
    labels_in_order[j], and that positives are labelled 1 and negatives 0.

    Precision is taken at the rank of each positive within the top ``k`` and
    averaged over those positives, so the result is normalized by the number
    of hits in the top ``k`` (not by the total number of positives).
    Returns 0 when the top ``k`` contains no positives.

    If fewer than ``k`` labels are supplied, all of them are used instead of
    raising IndexError. A non-positive ``k`` yields 0.
    """
    numerator = 0
    gtp = 0
    # max(k, 0): a negative slice bound would otherwise select items, while
    # the original range(k) loop selected none.
    for rank, label in enumerate(labels_in_order[:max(k, 0)], start=1):
        if label == 1:
            gtp += 1
            numerator += gtp / rank
    if gtp == 0:
        return 0
    return numerator / gtp

In [22]:
# Report ranking metrics for every query set.
for set_name, predicates in query_sets.items():
    print(set_name)
    if not predicates:
        # Every predicate may have been filtered out; unpacking zip(*[])
        # below would then fail, so report and move on.
        print("\tNo predicates left to evaluate")
        continue
    # Rank by score only, high to low. Sorting the (score, label) tuples
    # directly breaks score ties by label, which floats positives above tied
    # negatives and inflates RR / P@k / AP@k. Python's sort is stable even
    # with reverse=True, so tied items instead keep their original order.
    ranked = sorted(
        ((p["prediction"], p["label"]) for p in predicates),
        key=lambda score_label: score_label[0],
        reverse=True,
    )
    # predictions is sorted high to low; predictions[i] corresponds to labels[i]
    predictions, labels = zip(*ranked)
    if len(set(labels)) > 1:
        roc_auc = roc_auc_score(y_true=labels, y_score=predictions)
        pr_auc = average_precision_score(y_true=labels, y_score=predictions)
    else:
        # roc_auc_score raises when only one class is present; report nan.
        roc_auc = pr_auc = float("nan")
    rr = reciprocal_rank(labels)
    p_at_10 = precision_at_k(labels, 10)
    p_at_100 = precision_at_k(labels, 100)
    ap_at_10 = average_precision_at_k(labels, 10)
    ap_at_100 = average_precision_at_k(labels, 100)
    print(f"\tROC AUC: {roc_auc:0.4f}")
    print(f"\tPR AUC:  {pr_auc:0.4f}")
    print(f"\tRR:      {rr:0.4f}")
    print(f"\tP@10:    {p_at_10:0.4f}")
    print(f"\tP@100:   {p_at_100:0.4f}")
    print(f"\tAP@10:   {ap_at_10:0.4f}")
    print(f"\tAP@100:  {ap_at_100:0.4f}")

aapp:dsyn
	ROC AUC: 0.8058
	PR AUC:  0.2827
	RR:      1.0000
	P@10:    0.6000
	P@100:   0.4700
	AP@10:   0.6815
	AP@100:  0.5207
aapp:gngm
	ROC AUC: 0.7689
	PR AUC:  0.2542
	RR:      0.5000
	P@10:    0.6000
	P@100:   0.3100
	AP@10:   0.6347
	AP@100:  0.4680
bpoc:aapp
	ROC AUC: 0.7777
	PR AUC:  0.2532
	RR:      0.3333
	P@10:    0.6000
	P@100:   0.3800
	AP@10:   0.5653
	AP@100:  0.4619
gngm:neop
	ROC AUC: 0.8124
	PR AUC:  0.4743
	RR:      1.0000
	P@10:    0.9000
	P@100:   0.5800
	AP@10:   0.8412
	AP@100:  0.7232
dsyn:dsyn
	ROC AUC: 0.7989
	PR AUC:  0.1919
	RR:      0.5000
	P@10:    0.7000
	P@100:   0.3100
	AP@10:   0.7407
	AP@100:  0.4827
cell:aapp
	ROC AUC: 0.7830
	PR AUC:  0.2709
	RR:      1.0000
	P@10:    0.6000
	P@100:   0.3500
	AP@10:   0.8302
	AP@100:  0.4994
gngm:aapp
	ROC AUC: 0.7576
	PR AUC:  0.2593
	RR:      1.0000
	P@10:    0.6000
	P@100:   0.3400
	AP@10:   0.8524
	AP@100:  0.4997
dsyn:humn
	ROC AUC: 0.8031
	PR AUC:  0.2189
	RR:      1.0000
	P@10:    0.5000
	P@100:   0.2400
	A