# Imports

In [1]:
import json
import os
from pathlib import Path

import torch
from sklearn.metrics import average_precision_score, roc_auc_score
from tqdm.notebook import tqdm

from agatha.ml.hypothesis_predictor import HypothesisPredictor
from agatha.ml.hypothesis_predictor.predicate_util import clean_coded_term


# Its likely that we're going to get a "source code has changed" warning
# We're okay with that
import warnings
warnings.filterwarnings('ignore')


# Configure and Load the Agatha Model

In [2]:
#################################
# YOU NEED TO CHANGE THESE PATHS#
ROOT =           Path("/zfs/safrolab/users/jsybran/agatha/data"
                      "/releases/2015/model_release")
MODEL_PATH =     ROOT.joinpath("model.pt")
GRAPH_DB_PATH =  ROOT.joinpath("predicate_graph.sqlite3")
ENTITY_DB_PATH = ROOT.joinpath("predicate_entities.sqlite3")
EMBEDDING_DIR =  ROOT.joinpath("predicate_embeddings")
# YOU NEED TO CHANGE THESE PATHS#

#MODEL_PATH=Path("/zfs/safrolab/users/jsybran/agatha/data/experimental/edge2vec/default_param/agatha_model.pt")
#GRAPH_DB_PATH=Path("/zfs/safrolab/users/jsybran/agatha/data/releases/2015/predicate_graph.sqlite3")
#ENTITY_DB_PATH=Path("/zfs/safrolab/users/jsybran/agatha/data/experimental/edge2vec/default_param/embeddings/entities.sqlite3")
#EMBEDDING_DIR=Path("/zfs/safrolab/users/jsybran/agatha/data/experimental/edge2vec/default_param/embeddings")

#################################


# Make sure all the paths were set properly
assert MODEL_PATH.is_file(), f"Cannot find {MODEL_PATH}"
assert MODEL_PATH.suffix == ".pt", f"Expecting model in pytorch model format, not ckpt."
assert GRAPH_DB_PATH.is_file(), f"Cannot find {GRAPH_DB_PATH}"
assert ENTITY_DB_PATH.is_file(), f"Cannot find {ENTITY_DB_PATH}"
assert EMBEDDING_DIR.is_dir(), f"Cannot find {EMBEDDING_DIR}"

In [3]:
# Load the Agatha model.
# NOTE: torch.load unpickles the whole model object, and unpickling can execute
# arbitrary code; only load model files from a trusted source.
# map_location="cpu" lets a model that was saved from a GPU load on a CPU-only
# machine; moving it to the GPU is left to the optional cell that follows.
# NOTE(review): newer PyTorch releases default torch.load to weights_only=True,
# which cannot unpickle a full module; on those versions pass weights_only=False.
model = torch.load(MODEL_PATH, map_location="cpu")
# Point the model at the graph / entity databases and the embedding directory
# validated in the configuration cell.
model.configure_paths(
    graph_db=GRAPH_DB_PATH,
    entity_db=ENTITY_DB_PATH,
    embedding_dir=EMBEDDING_DIR
)

In [4]:
# Move the model to GPU, optional step to improve performance.
# Only done when CUDA is available, so this cell no longer crashes on machines
# without a GPU; there the model stays on the device it was loaded onto.
if torch.cuda.is_available():
    model = model.cuda()

In [5]:
# Preload the model, optional step to improve performance.
# NOTE(review): presumably loads the graph / entity / embedding data registered
# via configure_paths into memory up front; confirm in HypothesisPredictor.preload.
# Warning: this can take a minute.
model.preload()

# Load Test Data

In [6]:
# PATHS TO TEST DATA
# The benchmark directory defaults to the original authors' location; set the
# AGATHA_BENCHMARK_DIR environment variable to use a different copy.
_DEFAULT_BENCHMARK_DIR = (
    "/zfs/safrolab/users/jsybran/agatha"
    "/data/benchmarks/predicates_2015/"
)
BENCHMARK_DIR = Path(
    os.environ.get("AGATHA_BENCHMARK_DIR") or _DEFAULT_BENCHMARK_DIR
)
TYPED_POPULAR_PREDICATES = BENCHMARK_DIR.joinpath(
    "all_pairs_top_20_types.json"
)
MOLIERE_BENCHMARK_POSITIVES = BENCHMARK_DIR.joinpath(
    "moliere_2015/published.txt"
)
MOLIERE_BENCHMARK_NEGATIVES = BENCHMARK_DIR.joinpath(
    "moliere_2015/noise.txt"
)

# Check that all files are in place before loading anything
for benchmark_file in [
    TYPED_POPULAR_PREDICATES,
    MOLIERE_BENCHMARK_POSITIVES,
    MOLIERE_BENCHMARK_NEGATIVES,
]:
    assert benchmark_file.is_file(), f"Cannot find {benchmark_file} file"

In [7]:
# LOAD DATA
#
# Schema of TYPED_POPULAR_PREDICATES:
# {
#   "<query_set_name>": [
#     {
#       "source": "<name>",
#       "target": "<name>",
#       "label": <0 or 1>
#     },
#     ...
#   ],
#   ...
# }
#
# The file TYPED_POPULAR_PREDICATES holds one query set per predicate type.
# Types include:
#   'aapp:dsyn', 'aapp:gngm', 'bpoc:aapp', 'gngm:neop', 'dsyn:dsyn',
#   'cell:aapp', 'gngm:aapp', 'dsyn:humn', 'gngm:celf', 'orch:gngm',
#   'phsu:dsyn', 'bacs:aapp', 'gngm:cell', 'gngm:dsyn', 'gngm:gngm',
#   'aapp:neop', 'aapp:aapp', 'topp:dsyn', 'bacs:gngm', 'aapp:cell'
# NOTE(review): an earlier version of this note said there are 100 predicate
# entries per type, but the per-set totals printed further down are in the
# thousands; confirm what the "100" referred to.
#
# Use a context manager so the file handle is closed right after parsing.
with open(TYPED_POPULAR_PREDICATES) as predicates_file:
    query_sets = json.load(predicates_file)

In [8]:
# Add the moliere benchmark as a query set.
#
# Each file has one "source|target|<third field>" record per line, e.g.:
#   C0454279|C0043251|2016
#   C1563740|C0729627|2016
#   C1522549|C0023759|2016
#   C0516977|C0454448|2017
#   ...
# Records from MOLIERE_BENCHMARK_POSITIVES get label 1, records from
# MOLIERE_BENCHMARK_NEGATIVES get label 0. The third field (apparently a year)
# is ignored. The raw codes are normalized explicitly with clean_coded_term below.
query_sets["moliere"] = []
for path, label in [
    (MOLIERE_BENCHMARK_POSITIVES, 1),
    (MOLIERE_BENCHMARK_NEGATIVES, 0),
]:
    with open(path) as benchmark_file:
        for line in benchmark_file:
            record = line.strip()
            if not record:
                # Skip blank lines (e.g. an extra newline at the end of the
                # file) instead of failing to unpack them into three fields.
                continue
            source, target, _ = record.split("|")
            # this replaces "C###" with "m:c###"
            source = clean_coded_term(source)
            target = clean_coded_term(target)
            query_sets["moliere"].append(dict(
                source=source,
                target=target,
                label=label
            ))

In [9]:
# Drop every predicate whose source or target term is not a key of this
# model's graph, and report how many positive / negative examples each query
# set loses. Each list in query_sets is updated in place.
valid_keys = model.graph.keys()
for set_name, predicates in query_sets.items():
    # Single pass: True -> both terms known (keep), False -> at least one unknown.
    split_by_known = {True: [], False: []}
    for pred in predicates:
        both_known = pred["source"] in valid_keys and pred["target"] in valid_keys
        split_by_known[both_known].append(pred)
    kept_predicates = split_by_known[True]
    removed_predicates = split_by_known[False]
    # Report what was dropped, only when something was dropped
    if removed_predicates:
        num_pos = sum(1 for p in removed_predicates if p["label"] == 1)
        num_neg = sum(1 for p in removed_predicates if p["label"] == 0)
        print(f"Removed {num_pos} positive and {num_neg} negative examples from the {set_name} set")
    # Overwrite the list's contents so query_sets[set_name] sees the filtered data
    predicates[:] = kept_predicates

Removed 3 positive and 655 negative examples from the moliere set


In [10]:
# Print out query set details: number of predicates and share of positives.
for set_name, predicates in query_sets.items():
    if not predicates:
        # Every predicate may have been filtered out above; avoid dividing by zero.
        print(f"{set_name}:\tTotal: 0\tPos: n/a")
        continue
    num_pos = sum(1 for p in predicates if p["label"] == 1)
    print(f"{set_name}:\tTotal: {len(predicates)}\tPos: {num_pos/len(predicates)*100:2.2f}%")

aapp:dsyn:	Total: 4108	Pos: 8.23%
aapp:gngm:	Total: 2550	Pos: 9.06%
bpoc:aapp:	Total: 4161	Pos: 8.41%
gngm:neop:	Total: 2496	Pos: 18.59%
dsyn:dsyn:	Total: 5223	Pos: 4.14%
cell:aapp:	Total: 2524	Pos: 8.84%
gngm:aapp:	Total: 3166	Pos: 8.97%
dsyn:humn:	Total: 3225	Pos: 7.13%
gngm:celf:	Total: 1018	Pos: 21.22%
orch:gngm:	Total: 4488	Pos: 9.00%
phsu:dsyn:	Total: 4895	Pos: 6.44%
bacs:aapp:	Total: 3614	Pos: 5.59%
gngm:cell:	Total: 2412	Pos: 9.70%
gngm:dsyn:	Total: 4638	Pos: 6.08%
gngm:gngm:	Total: 4286	Pos: 6.25%
aapp:neop:	Total: 3002	Pos: 12.46%
aapp:aapp:	Total: 4828	Pos: 4.04%
topp:dsyn:	Total: 6269	Pos: 4.40%
bacs:gngm:	Total: 2477	Pos: 8.44%
aapp:cell:	Total: 1677	Pos: 10.38%
moliere:	Total: 1342	Pos: 74.29%


# Evaluate Each Predicate

This step is slow: the model scores every predicate in every query set.

In [11]:
# This one is going to take a while.
# For each query set, score every (source, target) pair with the model and
# store the score on the predicate dict under the "prediction" key.
for set_name, predicates in query_sets.items():
    print("Predicting:", set_name)
    # pull out source and target for each predicate (tqdm shows progress)
    queries = tqdm([(p["source"], p["target"]) for p in predicates])
    predictions = list(model.predict_from_terms(queries, batch_size=64))
    # zip() would silently drop predicates if the model returned a different
    # number of scores, leaving them without a "prediction" key (and failing
    # much later during metric computation). Fail here with a clear message.
    assert len(predictions) == len(predicates), (
        f"Expected {len(predicates)} predictions for {set_name}, "
        f"got {len(predictions)}"
    )
    # Add the model prediction score to each predicate
    for prediction, predicate in zip(predictions, predicates):
        predicate["prediction"] = prediction

Predicting: aapp:dsyn


HBox(children=(FloatProgress(value=0.0, max=4108.0), HTML(value='')))


Predicting: aapp:gngm


HBox(children=(FloatProgress(value=0.0, max=2550.0), HTML(value='')))


Predicting: bpoc:aapp


HBox(children=(FloatProgress(value=0.0, max=4161.0), HTML(value='')))


Predicting: gngm:neop


HBox(children=(FloatProgress(value=0.0, max=2496.0), HTML(value='')))


Predicting: dsyn:dsyn


HBox(children=(FloatProgress(value=0.0, max=5223.0), HTML(value='')))


Predicting: cell:aapp


HBox(children=(FloatProgress(value=0.0, max=2524.0), HTML(value='')))


Predicting: gngm:aapp


HBox(children=(FloatProgress(value=0.0, max=3166.0), HTML(value='')))


Predicting: dsyn:humn


HBox(children=(FloatProgress(value=0.0, max=3225.0), HTML(value='')))


Predicting: gngm:celf


HBox(children=(FloatProgress(value=0.0, max=1018.0), HTML(value='')))


Predicting: orch:gngm


HBox(children=(FloatProgress(value=0.0, max=4488.0), HTML(value='')))


Predicting: phsu:dsyn


HBox(children=(FloatProgress(value=0.0, max=4895.0), HTML(value='')))


Predicting: bacs:aapp


HBox(children=(FloatProgress(value=0.0, max=3614.0), HTML(value='')))


Predicting: gngm:cell


HBox(children=(FloatProgress(value=0.0, max=2412.0), HTML(value='')))


Predicting: gngm:dsyn


HBox(children=(FloatProgress(value=0.0, max=4638.0), HTML(value='')))


Predicting: gngm:gngm


HBox(children=(FloatProgress(value=0.0, max=4286.0), HTML(value='')))


Predicting: aapp:neop


HBox(children=(FloatProgress(value=0.0, max=3002.0), HTML(value='')))


Predicting: aapp:aapp


HBox(children=(FloatProgress(value=0.0, max=4828.0), HTML(value='')))


Predicting: topp:dsyn


HBox(children=(FloatProgress(value=0.0, max=6269.0), HTML(value='')))


Predicting: bacs:gngm


HBox(children=(FloatProgress(value=0.0, max=2477.0), HTML(value='')))


Predicting: aapp:cell


HBox(children=(FloatProgress(value=0.0, max=1677.0), HTML(value='')))


Predicting: moliere


HBox(children=(FloatProgress(value=0.0, max=1342.0), HTML(value='')))




In [12]:
# Balance the Moliere set: keep the same number of positive and negative
# examples, taking the first n of each class in their current order.
moliere_predicates = query_sets["moliere"]
pos_samples = [r for r in moliere_predicates if r["label"] == 1]
neg_samples = [r for r in moliere_predicates if r["label"] == 0]
n_per_class = min(len(pos_samples), len(neg_samples))
query_sets["moliere"] = pos_samples[:n_per_class] + neg_samples[:n_per_class]

# Calculate Metrics

In [13]:
def reciprocal_rank(labels_in_order):
    """
    Reciprocal rank of the first positive label.

    Assumes that if i < j then labels_in_order[i] got a higher score than
    labels_in_order[j]. Returns 1 / (1-based position of the first label equal
    to 1), or 0 when no label equals 1.
    """
    hit_ranks = (
        rank
        for rank, lbl in enumerate(labels_in_order, start=1)
        if lbl == 1
    )
    first_rank = next(hit_ranks, None)
    return 0 if first_rank is None else 1 / first_rank

In [14]:
def precision_at_k(labels_in_order, k):
    """
    Fraction of the top-k ranked items that are positive.

    Assumes that if i < j then labels_in_order[i] got a higher score than
    labels_in_order[j], and that the positive label is 1 and the negative
    label is 0. The denominator is always k, even if fewer than k labels
    are given.
    """
    top_k_labels = labels_in_order[:k]
    hits = sum(top_k_labels)
    return hits / k

In [15]:
def average_precision_at_k(labels_in_order, k):
    """
    Average precision over the top-k ranked items.

    Assumes that if i < j then labels_in_order[i] got a higher score than
    labels_in_order[j], and that the positive label is 1 and the negative
    label is 0. The sum of precision@rank at each positive within the top k is
    divided by the number of positives found in the top k (not by the total
    number of positives). Returns 0 when the top k contains no positive.
    Lists shorter than k are handled by only looking at the labels present.
    """
    num_hits = 0
    precision_sum = 0
    # Slicing (instead of indexing range(k)) avoids an IndexError when there
    # are fewer than k labels.
    for rank, lbl in enumerate(labels_in_order[:k], start=1):
        if lbl == 1:
            num_hits += 1
            precision_sum += num_hits / rank
    if num_hits == 0:
        return 0
    return precision_sum / num_hits

In [16]:
query_set2metrics = {}
for set_name, predicates in query_sets.items():
    if not predicates:
        # zip(*[]) cannot be unpacked into two names; nothing to evaluate.
        print(f"Skipping {set_name}: no predicates to evaluate")
        continue
    # Rank predicates by model score, high to low; labels[i] belongs to
    # predictions[i]. Sort on the score ONLY: sorting the (prediction, label)
    # tuples would break score ties by label, pushing positives ahead of
    # negatives and inflating rr / p@k / ap@k. Python's sort is stable even
    # with reverse=True, so tied predicates keep their original order.
    ranked = sorted(
        ((p["prediction"], p["label"]) for p in predicates),
        key=lambda score_label: score_label[0],
        reverse=True,
    )
    predictions, labels = zip(*ranked)
    query_set2metrics[set_name] = dict(
        # Threshold-free metrics over the whole query set
        roc_auc=roc_auc_score(y_true=labels, y_score=predictions),
        pr_auc=average_precision_score(y_true=labels, y_score=predictions),
        # Rank-based metrics; these rely on `labels` being in ranked order
        rr=reciprocal_rank(labels),
        p_at_10=precision_at_k(labels, 10),
        p_at_100=precision_at_k(labels, 100),
        ap_at_10=average_precision_at_k(labels, 10),
        ap_at_100=average_precision_at_k(labels, 100),
    )

In [17]:
# Show the architecture of the loaded model
model_summary = str(model)
print(model_summary)

HypothesisPredictor(
  (embedding_transformation): Linear(in_features=512, out_features=512, bias=True)
  (encode_predicate_data): TransformerEncoder(
    (layers): ModuleList(
      (0): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): Linear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=1024, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=1024, out_features=512, bias=True)
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
      (1): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): Linear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_feature

In [18]:
import pandas

In [19]:
# One row per query set, one column per metric
metrics_by_query_set = pandas.DataFrame(query_set2metrics)
metrics_by_query_set.T

Unnamed: 0,roc_auc,pr_auc,rr,p_at_10,p_at_100,ap_at_10,ap_at_100
aapp:dsyn,0.75275,0.229522,0.5,0.7,0.37,0.604252,0.459445
aapp:gngm,0.747492,0.212106,1.0,0.2,0.27,0.642857,0.272728
bpoc:aapp,0.75501,0.215246,1.0,0.5,0.29,0.644444,0.364908
gngm:neop,0.708031,0.332809,0.25,0.4,0.42,0.317262,0.460145
dsyn:dsyn,0.795387,0.196792,1.0,0.6,0.31,0.877381,0.537078
cell:aapp,0.732875,0.213951,1.0,0.4,0.34,0.6,0.369473
gngm:aapp,0.750258,0.235004,1.0,0.4,0.37,0.527083,0.429077
dsyn:humn,0.784245,0.210316,1.0,0.5,0.27,0.696825,0.403617
gngm:celf,0.671406,0.313398,0.25,0.2,0.32,0.291667,0.307716
orch:gngm,0.77353,0.225531,0.333333,0.4,0.31,0.35,0.291865
