In [108]:
import pandas as pd
import random

# Read the triples from the file
df = pd.read_csv('drkg.tsv', sep='\t', header=None, names=['head', 'relation', 'tail'])

In [109]:
df.info

<bound method DataFrame.info of                 head                        relation         tail
0         Gene::2157  bioarx::HumGenHumGen:Gene:Gene   Gene::2157
1         Gene::2157  bioarx::HumGenHumGen:Gene:Gene   Gene::5264
2         Gene::2157  bioarx::HumGenHumGen:Gene:Gene   Gene::2158
3         Gene::2157  bioarx::HumGenHumGen:Gene:Gene   Gene::3309
4         Gene::2157  bioarx::HumGenHumGen:Gene:Gene  Gene::28912
...              ...                             ...          ...
5874256  Gene::29099     STRING::REACTION::Gene:Gene   Gene::1643
5874257  Gene::51645     STRING::REACTION::Gene:Gene   Gene::3183
5874258    Gene::865    STRING::CATALYSIS::Gene:Gene    Gene::983
5874259   Gene::1066      STRING::BINDING::Gene:Gene   Gene::7365
5874260   Gene::6118      STRING::BINDING::Gene:Gene   Gene::1111

[5874261 rows x 3 columns]>

In [110]:
from pykeen.triples import TriplesFactory

In [111]:
#Get previously used training triples
rules_training_triples = TriplesFactory.from_path_binary('TransE + rules/training_triples')
no_rules_training_triples = TriplesFactory.from_path_binary('TransE (no rules)/training_triples')\

In [112]:
#Filter out entities/relations that weren't in the training data ()
df_rules_sample = df.loc[df['head'].isin(rules_training_triples.entity_to_id)]
df_rules_sample = df_sample.loc[df['tail'].isin(rules_training_triples.entity_to_id)]
df_rules_sample = df_sample.loc[df['relation'].isin(rules_training_triples.relation_to_id)]

In [113]:
df_no_rules_sample = df.loc[df['head'].isin(no_rules_training_triples.entity_to_id)]
df_no_rules_sample = df_sample.loc[df['tail'].isin(no_rules_training_triples.entity_to_id)]
df_no_rules_sample = df_sample.loc[df['relation'].isin(no_rules_training_triples.relation_to_id)]

In [114]:
# Check if there are at least 10,000 records
if len(df_rules_sample) < 10000:
    print("The file has fewer than 10,000 records.")
else:
    # Randomly select 10,000 records
    df_rules_sample = df_rules_sample.sample(n=10000, random_state=random.randint(1, 100))

In [115]:
# Check if there are at least 10,000 records
if len(df_no_rules_sample) < 10000:
    print("The file has fewer than 10,000 records.")
else:
    # Randomly select 10,000 records
    df_no_rules_sample = df_no_rules_sample.sample(n=10000, random_state=random.randint(1, 100))

In [116]:
df_rules_sample.head(10)

Unnamed: 0,head,relation,tail
634752,Compound::DB00700,DRUGBANK::ddi-interactor-in::Compound:Compound,Compound::DB01366
409437,Compound::DB00425,DRUGBANK::ddi-interactor-in::Compound:Compound,Compound::DB01175
156610,Compound::DB00105,DRUGBANK::ddi-interactor-in::Compound:Compound,Compound::DB13203
3518360,Compound::DB01059,Hetionet::CcSE::Compound:Side Effect,Side Effect::C0011991
3155335,Gene::6608,Hetionet::Gr>G::Gene:Gene,Gene::1066
5138217,Gene::7039,STRING::BINDING::Gene:Gene,Gene::9919
1371227,Compound::DB09273,DRUGBANK::ddi-interactor-in::Compound:Compound,Compound::DB12338
1257977,Compound::DB06738,DRUGBANK::ddi-interactor-in::Compound:Compound,Compound::DB13234
1984127,Gene::528,Hetionet::GpBP::Gene:Biological Process,Biological Process::GO:0007169
2459037,Gene::6310,Hetionet::GiG::Gene:Gene,Gene::23051


In [117]:
# Create a triples factory from the DataFrame
testing_rules_factory = TriplesFactory.from_labeled_triples(df_rules_sample.values, 
    entity_to_id=rules_training_triples.entity_to_id, relation_to_id=rules_training_triples.relation_to_id)
testing_no_rules_factory = TriplesFactory.from_labeled_triples(df_no_rules_sample.values, 
    entity_to_id=no_rules_training_triples.entity_to_id, relation_to_id=no_rules_training_triples.relation_to_id)

In [118]:
#Load in pretrained models

import torch

model_with_rules = torch.load('TransE + rules/trained_model.pkl')
model_without_rules = torch.load('TransE (no rules)/trained_model.pkl')

Recommended to do Precision Recall curves according to this paper: https://arxiv.org/abs/1505.04094
Rank based evaluation seems to be overall worse.

Pykeen rank based evaluator paper here: https://github.com/pykeen/ranking-metrics-manuscript
with associated paper: https://arxiv.org/abs/2203.07544

In [119]:
# Pick an evaluator

from pykeen.evaluation import RankBasedEvaluator
from pykeen.evaluation import ClassificationEvaluator
rank_evaluator = RankBasedEvaluator()
prauc_evaluator = ClassificationEvaluator()

In [120]:
# Evaluate base model without rules
results_without_rules =  rank_evaluator.evaluate(
    model=model_without_rules,
    mapped_triples=testing_no_rules_factory.mapped_triples,
    additional_filter_triples=[
        no_rules_training_triples.mapped_triples
    ],
)

Evaluating on cpu:   0%|          | 0.00/10.0k [00:00<?, ?triple/s]

In [121]:
results_without_rules.to_dict()

{'head': {'optimistic': {'adjusted_arithmetic_mean_rank_index': 0.15849821275947829,
   'adjusted_inverse_harmonic_mean_rank': 0.004017321663151037,
   'harmonic_mean_rank': 203.08676565276016,
   'variance': 10638845.85220839,
   'inverse_arithmetic_mean_rank': 0.00021922069629272156,
   'adjusted_arithmetic_mean_rank': 0.8415310271521191,
   'standard_deviation': 3261.724367908544,
   'z_arithmetic_mean_rank': 27.450162850319174,
   'inverse_median_rank': 0.0002346591575736243,
   'arithmetic_mean_rank': 4561.6131,
   'median_rank': 4261.5,
   'z_geometric_mean_rank': 34.302017914958,
   'count': 10000.0,
   'inverse_geometric_mean_rank': 0.00038098525955144015,
   'adjusted_geometric_mean_rank_index': 0.34227114981202744,
   'z_inverse_harmonic_mean_rank': 32.6728742428985,
   'inverse_harmonic_mean_rank': 0.004924003771421572,
   'median_absolute_deviation': 4186.86866505982,
   'geometric_mean_rank': 2624.773465454721,
   'hits_at_1': 0.0008,
   'hits_at_3': 0.0045,
   'hits_at_5'

In [123]:
results_with_rules = rank_evaluator.evaluate(
    model=model_with_rules,
    mapped_triples=testing_rules_factory.mapped_triples,
    additional_filter_triples=[
        rules_training_triples.mapped_triples
    ],
)

Evaluating on cpu:   0%|          | 0.00/10.0k [00:00<?, ?triple/s]

In [124]:
results_with_rules.to_dict()

{'head': {'optimistic': {'adjusted_arithmetic_mean_rank_index': 0.16059114200344415,
   'adjusted_inverse_harmonic_mean_rank': 0.004338985543246143,
   'harmonic_mean_rank': 190.64414294104554,
   'variance': 10766531.17705399,
   'inverse_arithmetic_mean_rank': 0.00021976718348069857,
   'adjusted_arithmetic_mean_rank': 0.8394384840160561,
   'standard_deviation': 3281.239274581174,
   'z_arithmetic_mean_rank': 27.81263538262155,
   'inverse_median_rank': 0.00023866348448687351,
   'arithmetic_mean_rank': 4550.2699,
   'median_rank': 4190.0,
   'z_geometric_mean_rank': 34.96052539244559,
   'count': 10000.0,
   'inverse_geometric_mean_rank': 0.0003848282479039973,
   'adjusted_geometric_mean_rank_index': 0.34884184517327277,
   'z_inverse_harmonic_mean_rank': 35.28896505347471,
   'inverse_harmonic_mean_rank': 0.005245374888381639,
   'median_absolute_deviation': 4247.655356018549,
   'geometric_mean_rank': 2598.5618401107317,
   'hits_at_1': 0.0012,
   'hits_at_3': 0.0046,
   'hits_a