In [1]:
import sys
# enable importing the modules from probcalkge
sys.path.append('../')
sys.path.append('../probcalkge')

import numpy as np
import pandas as pd

from probcalkge import Experiment, ExperimentResult
from probcalkge import get_calibrators
from probcalkge import get_datasets,  get_kgemodels
from probcalkge import brier_score, negative_log_loss, ks_error, ece

In [2]:
cals = get_calibrators()
kges = get_kgemodels()
ds = get_datasets()



In [3]:
exp = Experiment(
    cals=[cals.uncal, cals.platt, cals.isot, ], 
    datasets=[ds.fb13, ds.wn18, ds.yago39, ds.dp50, ds.nations, ds.kinship, ds.umls], 
    kges=[kges.transE, kges.complEx, kges.distMult, kges.hoLE], 
    metrics=[ece]
    )

In [4]:
exp.load_trained_kges('../saved_models/')

Loaded models:
{'DBpedia50': {'ComplEx': <ampligraph.latent_features.models.ComplEx.ComplEx object at 0x000002134F864908>,
               'DistMult': <ampligraph.latent_features.models.DistMult.DistMult object at 0x00000213008DB048>,
               'HolE': <ampligraph.latent_features.models.HolE.HolE object at 0x0000021303F9FFC8>,
               'TransE': <ampligraph.latent_features.models.TransE.TransE object at 0x0000021307CA4048>},
 'FB13k': {'ComplEx': <ampligraph.latent_features.models.ComplEx.ComplEx object at 0x00000213005C6D88>,
           'DistMult': <ampligraph.latent_features.models.DistMult.DistMult object at 0x0000021304185208>,
           'HolE': <ampligraph.latent_features.models.HolE.HolE object at 0x0000021379817D48>,
           'TransE': <ampligraph.latent_features.models.TransE.TransE object at 0x0000021379777508>},
 'Kinship': {'ComplEx': <ampligraph.latent_features.models.ComplEx.ComplEx object at 0x0000021303F9F948>,
             'DistMult': <ampligraph.latent_fea

In [5]:
exp_res = exp.run_with_trained_kges()

training various calibrators for TransE on FB13k ...
training various calibrators for ComplEx on FB13k ...
training various calibrators for DistMult on FB13k ...
training various calibrators for HolE on FB13k ...
training various calibrators for TransE on WN11 ...
training various calibrators for ComplEx on WN11 ...
training various calibrators for DistMult on WN11 ...
training various calibrators for HolE on WN11 ...
training various calibrators for TransE on YAGO39 ...
training various calibrators for ComplEx on YAGO39 ...
training various calibrators for DistMult on YAGO39 ...
training various calibrators for HolE on YAGO39 ...
training various calibrators for TransE on DBpedia50 ...
training various calibrators for ComplEx on DBpedia50 ...
training various calibrators for DistMult on DBpedia50 ...
training various calibrators for HolE on DBpedia50 ...
training various calibrators for TransE on Nations ...
training various calibrators for ComplEx on Nations ...
training various cali

In [6]:
from sklearn.metrics import accuracy_score
for data in ds:
    if data.name == 'YAGO_ET' or data.name == 'DBpedia_ET':
        continue
    for kge in exp.trained_kges[data.name].values():
        for cal in exp.trained_cals[data.name][kge.name].values():
            try:
                scores = kge.predict(data.X_test)
                pred = cal.predict(scores) > 0.5
                true = data.y_test
                print(data.name, kge.name, cal.name, accuracy_score(true, pred))
            except Exception as e:
                print(e)

    

FB13k TransE UncalCalibrator 0.4999789314006405
FB13k TransE PlattCalibrator 0.6719829765717175
FB13k TransE IsotonicCalibrator 0.6721936625653127
FB13k ComplEx UncalCalibrator 0.5538934771616383
FB13k ComplEx PlattCalibrator 0.6476908815101972
FB13k ComplEx IsotonicCalibrator 0.6927987527389179
FB13k DistMult UncalCalibrator 0.5748356649249958
FB13k DistMult PlattCalibrator 0.6133912017529075
FB13k DistMult IsotonicCalibrator 0.6421077026799258
FB13k HolE UncalCalibrator 0.5574119332546772
FB13k HolE PlattCalibrator 0.4878644867689196
FB13k HolE IsotonicCalibrator 0.6391159615708748
WN11 TransE UncalCalibrator 0.5074596569572719
WN11 TransE PlattCalibrator 0.882624581345783
WN11 TransE IsotonicCalibrator 0.8812544402719984
WN11 ComplEx UncalCalibrator 0.5592205419669136
WN11 ComplEx PlattCalibrator 0.5898711052471328
WN11 ComplEx IsotonicCalibrator 0.6226529990865726
WN11 DistMult UncalCalibrator 0.566020501370141
WN11 DistMult PlattCalibrator 0.6024561047396731
WN11 DistMult Isotonic

In [7]:
from scipy.special import expit
scores

array([ 0.37969097, -0.73925155,  2.517384  , -1.3643224 ,  1.8427471 ,
        1.7951121 , -0.68656415, -0.32101703,  2.9898565 , -0.30531   ,
        1.6717205 , -1.1341385 ,  2.8170815 , -1.6118786 ,  0.8553423 ,
       -0.00628156,  0.19329514, -0.16634853, -1.25345   , -0.37864637,
        1.9667399 , -0.3548869 ,  0.271867  , -0.3320169 , -1.1637987 ,
       -0.14545198,  2.6770198 , -0.61331856,  0.32798922, -1.4120852 ,
       -0.23974207, -1.045959  , -0.508881  ,  1.3710321 ,  0.82276124,
       -0.91255826, -0.6156652 , -0.6748142 ,  0.7169288 , -1.0101277 ,
       -1.4975625 ,  2.1100893 , -0.59926057, -1.6275283 , -0.7638622 ,
       -1.8272467 ,  0.49131042, -0.6721886 ,  2.1571662 , -1.9922723 ,
       -0.23249047,  2.1627636 , -0.5254793 ,  0.83464986,  1.2686589 ,
        0.80709255, -1.4557071 ,  3.3463652 ,  2.3658786 ,  2.249756  ,
        1.4055388 , -0.02009424, -1.2117125 , -1.4751886 , -1.4585468 ,
        1.4466853 , -0.21765012,  1.6046041 ,  0.56723434, -0.26

In [8]:
exp2 = Experiment(
    cals=[cals.uncal, cals.platt, cals.isot, ], 
    datasets=[ds.yago_et, ds.dbpedia_et], 
    kges=[kges.transE, kges.complEx, kges.distMult, kges.hoLE], 
    metrics=[ece]
    )

In [12]:
exp2.train_kges()
exp2.save_trained_kges('../saved_models/')

training TransE on YAGO_ET ...


Average TransE Loss:   0.954750: 100%|██████████| 100/100 [34:47<00:00, 20.87s/epoch]


training TransE on DBpedia_ET ...


Average TransE Loss:   0.971890: 100%|██████████| 100/100 [1:13:58<00:00, 44.39s/epoch]


training ComplEx on YAGO_ET ...


Average ComplEx Loss:   0.206788: 100%|██████████| 100/100 [1:57:58<00:00, 70.78s/epoch]


training ComplEx on DBpedia_ET ...


Average ComplEx Loss:   0.332118: 100%|██████████| 100/100 [6:30:26<00:00, 234.26s/epoch]   


training DistMult on YAGO_ET ...


Average DistMult Loss:   0.369770: 100%|██████████| 100/100 [31:56<00:00, 19.17s/epoch]


training DistMult on DBpedia_ET ...


Average DistMult Loss:   0.638364: 100%|██████████| 100/100 [1:13:08<00:00, 43.89s/epoch]


training HolE on YAGO_ET ...


Average HolE Loss:   0.964465: 100%|██████████| 100/100 [2:08:18<00:00, 76.99s/epoch]


training HolE on DBpedia_ET ...


Average HolE Loss:   1.370011:  30%|███       | 30/100 [1:18:55<3:04:08, 157.84s/epoch]


KeyboardInterrupt: 

In [13]:
exp2.save_trained_kges('../saved_models/')

saved ../saved_models/YAGO_ET-TransE.pkl.
saved ../saved_models/YAGO_ET-ComplEx.pkl.
saved ../saved_models/YAGO_ET-DistMult.pkl.
saved ../saved_models/YAGO_ET-HolE.pkl.
saved ../saved_models/DBpedia_ET-TransE.pkl.
saved ../saved_models/DBpedia_ET-ComplEx.pkl.
saved ../saved_models/DBpedia_ET-DistMult.pkl.


AttributeError: 'NoneType' object has no attribute 'all_params'