In [12]:
import sys
# enable importing the modules from probcalkge
sys.path.append('../')
sys.path.append('../probcalkge')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from ampligraph.latent_features import ComplEx
from ampligraph.utils import save_model, restore_model
from sklearn.metrics import accuracy_score
from sklearn.calibration import CalibrationDisplay
from scipy.special import expit

from probcalkge import Experiment, ExperimentResult
from probcalkge import get_calibrators
from probcalkge import get_datasets,  get_kgemodels
from probcalkge import brier_score, negative_log_loss, ks_error, ece

In [2]:
cals = get_calibrators()
kges = get_kgemodels()
ds = get_datasets()



In [5]:
exp = Experiment(
    cals=[cals.uncal, cals.platt, cals.isot, ], 
    datasets=[ds.yago39], 
    kges=[ kges.distMult, kges.hoLE], 
    metrics=[ece]
    )

In [5]:
exp.train_kges()

training DistMult on YAGO39 ...


Average DistMult Loss:   0.375038: 100%|██████████| 100/100 [38:14<00:00, 22.95s/epoch]


training HolE on YAGO39 ...


Average HolE Loss:   0.774206: 100%|██████████| 100/100 [2:01:33<00:00, 72.94s/epoch] 


In [6]:
exp.save_trained_kges('../saved_models/newyago/')

made model directory: ../saved_models/newyago/
saved ../saved_models/newyago/YAGO39-DistMult.pkl.
saved ../saved_models/newyago/YAGO39-HolE.pkl.


In [7]:
exp = Experiment(
    cals=[cals.uncal, cals.platt, cals.isot, ], 
    datasets=[ds.yago39], 
    kges=[ kges.complEx], 
    metrics=[ece]
    )

training ComplEx on YAGO39 ...


Average ComplEx Loss:   0.165952: 100%|██████████| 100/100 [1:59:29<00:00, 71.70s/epoch] 


saved ../saved_models/newyago/YAGO39-ComplEx.pkl.


In [35]:
cpm4 = ComplEx(loss='nll', verbose=True, epochs=200)
# cpm6 = ComplEx(loss='nll', verbose=True, regularizer='LP', regularizer_params={'p': 3, 'lambda':0.1})

cpm4.fit(ds.wn18.X_train)
save_model(cpm4, 'complex_nll_wn-200.pkl')
# cpm6.fit(ds.yago39.X_train)
# save_model(cpm6, 'complex_nll_reg_yg.pkl')


Average ComplEx Loss:   0.004186: 100%|██████████| 200/200 [36:06<00:00, 10.83s/epoch]


In [73]:

cpm1 = restore_model('../saved_models/FB13k-ComplEx.pkl')
cpm2 = restore_model('complex_nll_reg.pkl')

cpm3 = restore_model('../saved_models/WN11-ComplEx.pkl')
cpm4 = restore_model('complex_nll_reg_wn.pkl')

cpm5 = restore_model('../saved_models/YAGO39-ComplEx.pkl')
cpm6 = restore_model('complex_nll_reg_yg.pkl')


In [74]:
scores1 = cpm1.predict(ds.fb13.X_test)
probs1 = expit(scores1)

scores2 = cpm2.predict(ds.fb13.X_test)
probs2 = expit(scores2)

scores3 = cpm3.predict(ds.wn18.X_test)
probs3 = expit(scores3)

scores4 = cpm4.predict(ds.wn18.X_test)
probs4 = expit(scores4)

scores5 = cpm5.predict(ds.yago39.X_test)
probs5 = expit(scores5)

scores6 = cpm6.predict(ds.yago39.X_test)
probs6 = expit(scores6)

In [76]:
print(stats(cpm1, ds.fb13))
stats(cpm2, ds.fb13)

ComplEx FB13k
{'BS': 0.40704260654273144, 'NLL': 2.2465560127404207, 'ECE': 0.41877499798679346, 'ACC': 0.5538934771616383}
ComplEx FB13k


{'BS': 0.28289896739985426,
 'NLL': 0.8520768830593812,
 'ECE': 0.22725086202027317,
 'ACC': 0.6445937974043485}

In [55]:
from probcalkge import brier_score, get_cls_name
from sklearn.metrics import accuracy_score

def stats(model, data):
    print(get_cls_name(model), data.name)
    score = model.predict(data.X_test)
    probs = expit(score)
    return {
        'BS': brier_score(data.y_test, probs),
        'NLL': negative_log_loss(data.y_test, probs),
        'ECE': ece(data.y_test, probs),
        'ACC': accuracy_score(data.y_test, probs>0.5)
    }

In [58]:
import os
folder = r'F:\TREAT\kgcal\data-output\nll_100'
nll100 = [restore_model(os.path.join(folder, i)) for i in os.listdir(folder)]

folder = r'F:\TREAT\kgcal\data-output\nll_100_reg'
nll100reg = [restore_model(os.path.join(folder, i)) for i in os.listdir(folder)]

folder = r'F:\TREAT\kgcal\data-output\nll_200'
nll200 = [restore_model(os.path.join(folder, i)) for i in os.listdir(folder)]

folder = r'F:\TREAT\kgcal\data-output\nll_200_reg'
nll200reg = nll200 = [restore_model(os.path.join(folder, i)) for i in os.listdir(folder)]

In [69]:
for m, data in zip(nll100, [ds.fb13, ds.wn18, ds.yago39]):
    print(stats(m, data))
    

ComplEx FB13k
{'BS': 0.3385623476617665, 'NLL': 1.1238786534586995, 'ECE': 0.29698127123219886, 'ACC': 0.5605511545592449}
ComplEx WN11
{'BS': 0.21457739982520513, 'NLL': 0.611389728233577, 'ECE': 0.10035214100082576, 'ACC': 0.6267634223079265}
ComplEx YAGO39
{'BS': 0.233304019267225, 'NLL': 1.0921743873702785, 'ECE': 0.2100945055801553, 'ACC': 0.7254552331550225}


In [71]:
for m, data in zip(nll200, [ds.fb13, ds.wn18, ds.yago39]):
    print(stats(m, data))

ComplEx FB13k
{'BS': 0.3692481154185115, 'NLL': 1.6349988018906685, 'ECE': 0.3423706144943827, 'ACC': 0.585243553008596}
ComplEx WN11
{'BS': 0.31509702934211065, 'NLL': 1.0806541427920484, 'ECE': 0.3113624292561844, 'ACC': 0.6099157616969451}
ComplEx YAGO39
{'BS': 0.18819364883996217, 'NLL': 0.8173243217274974, 'ECE': 0.1715817111932596, 'ACC': 0.7724644729021451}


In [70]:
for m, data in zip(nll100reg, [ds.fb13, ds.wn18, ds.yago39]):
    print(stats(m, data))

ComplEx FB13k
{'BS': 0.33040020928842395, 'NLL': 1.05243193629922, 'ECE': 0.2972321411277657, 'ACC': 0.5562952974886229}
ComplEx WN11
{'BS': 0.21263925633156505, 'NLL': 0.6075558812332627, 'ECE': 0.11812873590065721, 'ACC': 0.6331066680198925}
ComplEx YAGO39
{'BS': 0.24254226955633879, 'NLL': 1.003494769776883, 'ECE': 0.22192509207115274, 'ACC': 0.7083265791322202}


In [72]:
for m, data in zip(nll200reg, [ds.fb13, ds.wn18, ds.yago39]):
    print(stats(m, data))

ComplEx FB13k
{'BS': 0.3692481154185115, 'NLL': 1.6349988018906685, 'ECE': 0.3423706144943827, 'ACC': 0.585243553008596}
ComplEx WN11
{'BS': 0.31509702934211065, 'NLL': 1.0806541427920484, 'ECE': 0.3113624292561844, 'ACC': 0.6099157616969451}
ComplEx YAGO39
{'BS': 0.18819364883996217, 'NLL': 0.8173243217274974, 'ECE': 0.1715817111932596, 'ACC': 0.7724644729021451}


In [77]:
folder = r'F:\TREAT\kgcal\data-output\am_100'
am100 = [restore_model(os.path.join(folder, i)) for i in os.listdir(folder)]

folder = r'F:\TREAT\kgcal\data-output\am_100_reg'
am100reg = [restore_model(os.path.join(folder, i)) for i in os.listdir(folder)]

In [79]:
for m, data in zip(am100, [ds.fb13, ds.wn18, ds.yago39]):
    print(stats(m, data))

ComplEx FB13k
{'BS': 0.3276006406724404, 'NLL': 2.281402174867122, 'ECE': 0.33469375273672003, 'ACC': 0.6015085117141412}
ComplEx WN11
{'BS': 0.22845736997785257, 'NLL': 0.646433796118873, 'ECE': 0.14880674689269527, 'ACC': 0.6182888460367401}
ComplEx YAGO39
{'BS': 0.48235713753169446, 'NLL': 7.1782716382449445, 'ECE': 0.48116350690792864, 'ACC': 0.5116442427189712}
