In [1]:
import sys
# enable importing the modules from probcalkge
sys.path.append('../')
sys.path.append('../probcalkge')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from ampligraph.latent_features import ComplEx
from ampligraph.utils import save_model, restore_model
from sklearn.metrics import accuracy_score
from sklearn.calibration import CalibrationDisplay
from scipy.special import expit

from probcalkge import Experiment, ExperimentResult
from probcalkge import get_calibrators
from probcalkge import get_datasets,  get_kgemodels
from probcalkge import brier_score, negative_log_loss, ks_error, ece

In [2]:
cals = get_calibrators()
kges = get_kgemodels()
ds = get_datasets()



In [5]:
exp = Experiment(
    cals=[cals.uncal, cals.platt, cals.isot, ], 
    datasets=[ds.yago39], 
    kges=[ kges.distMult, kges.hoLE], 
    metrics=[ece]
    )

In [3]:

cpm1 = restore_model('../saved_models/FB13k-ComplEx.pkl')
cpm2 = restore_model('complex_nll_reg.pkl')

cpm3 = restore_model('../saved_models/WN11-ComplEx.pkl')
cpm4 = restore_model('complex_nll_reg_wn.pkl')

cpm5 = restore_model('../saved_models/YAGO39-ComplEx.pkl')
cpm6 = restore_model('complex_nll_reg_yg.pkl')


In [6]:
from probcalkge import brier_score, get_cls_name
from sklearn.metrics import accuracy_score

def stats(model, data):
    print(get_cls_name(model), data.name)
    score = model.predict(data.X_test)
    probs = expit(score)
    return {
        'BS': brier_score(data.y_test, probs),
        'NLL': negative_log_loss(data.y_test, probs),
        'ECE': ece(data.y_test, probs),
        'ACC': accuracy_score(data.y_test, probs>0.5)
    }

In [10]:
print(stats(cpm1, ds.fb13))
print(stats(cpm2, ds.fb13))
print(stats(cpm3, ds.wn18))
print(stats(cpm4, ds.wn18))
print(stats(cpm5, ds.yago39))
print(stats(cpm6, ds.yago39))


ComplEx FB13k
{'BS': 0.34087604756408535, 'NLL': 1.1358389222927454, 'ECE': 0.2971215639851043, 'ACC': 0.5546519467385809}
ComplEx FB13k
{'BS': 0.28289896739985426, 'NLL': 0.8520768830593812, 'ECE': 0.22725086202027317, 'ACC': 0.6445937974043485}
ComplEx WN11
{'BS': 0.21823013386012838, 'NLL': 0.61923571632016, 'ECE': 0.09829177064005401, 'ACC': 0.6127575357759059}
ComplEx WN11
{'BS': 0.21148498136546337, 'NLL': 0.6180642181848163, 'ECE': 0.1139539661112475, 'ACC': 0.7053181772049122}
ComplEx YAGO39
{'BS': 0.236838019833028, 'NLL': 1.1180476823235181, 'ECE': 0.21355257483021275, 'ACC': 0.7251850651104987}
ComplEx YAGO39
{'BS': 0.20742494417307547, 'NLL': 0.7036543243983794, 'ECE': 0.19394461117915468, 'ACC': 0.7378289295942076}


In [11]:
import os
folder = r'F:\TREAT\kgcal\data-output\nll_100'
nll100 = [restore_model(os.path.join(folder, i)) for i in os.listdir(folder)]

folder = r'F:\TREAT\kgcal\data-output\nll_100_reg'
nll100reg = [restore_model(os.path.join(folder, i)) for i in os.listdir(folder)]

folder = r'F:\TREAT\kgcal\data-output\nll_200'
nll200 = [restore_model(os.path.join(folder, i)) for i in os.listdir(folder)]

folder = r'F:\TREAT\kgcal\data-output\nll_200_reg'
nll200reg = nll200 = [restore_model(os.path.join(folder, i)) for i in os.listdir(folder)]

In [12]:
for m, data in zip(nll100, [ds.fb13, ds.wn18, ds.yago39]):
    print(stats(m, data))
    

ComplEx FB13k
{'BS': 0.3385623476617665, 'NLL': 1.1238786534586995, 'ECE': 0.29698127123219886, 'ACC': 0.5605511545592449}
ComplEx WN11
{'BS': 0.21457739982520513, 'NLL': 0.611389728233577, 'ECE': 0.10035214100082576, 'ACC': 0.6267634223079265}
ComplEx YAGO39
{'BS': 0.233304019267225, 'NLL': 1.0921743873702785, 'ECE': 0.2100945055801553, 'ACC': 0.7254552331550225}


In [13]:
for m, data in zip(nll100reg, [ds.fb13, ds.wn18, ds.yago39]):
    print(stats(m, data))

ComplEx FB13k
{'BS': 0.33040020928842395, 'NLL': 1.05243193629922, 'ECE': 0.2972321411277657, 'ACC': 0.5562952974886229}
ComplEx WN11
{'BS': 0.21263925633156505, 'NLL': 0.6075558812332627, 'ECE': 0.11812873590065721, 'ACC': 0.6331066680198925}
ComplEx YAGO39
{'BS': 0.24254226955633879, 'NLL': 1.003494769776883, 'ECE': 0.22192509207115274, 'ACC': 0.7083265791322202}


In [None]:
for m, data in zip(nll200, [ds.fb13, ds.wn18, ds.yago39]):
    print(stats(m, data))

ComplEx FB13k
{'BS': 0.3692481154185115, 'NLL': 1.6349988018906685, 'ECE': 0.3423706144943827, 'ACC': 0.585243553008596}
ComplEx WN11
{'BS': 0.31509702934211065, 'NLL': 1.0806541427920484, 'ECE': 0.3113624292561844, 'ACC': 0.6099157616969451}
ComplEx YAGO39
{'BS': 0.18819364883996217, 'NLL': 0.8173243217274974, 'ECE': 0.1715817111932596, 'ACC': 0.7724644729021451}


In [None]:
folder = r'F:\TREAT\kgcal\data-output\am_100'
am100 = [restore_model(os.path.join(folder, i)) for i in os.listdir(folder)]

folder = r'F:\TREAT\kgcal\data-output\am_100_reg'
am100reg = [restore_model(os.path.join(folder, i)) for i in os.listdir(folder)]

In [None]:
for m, data in zip(nll200reg, [ds.fb13, ds.wn18, ds.yago39]):
    print(stats(m, data))

ComplEx FB13k
{'BS': 0.3692481154185115, 'NLL': 1.6349988018906685, 'ECE': 0.3423706144943827, 'ACC': 0.585243553008596}
ComplEx WN11
{'BS': 0.31509702934211065, 'NLL': 1.0806541427920484, 'ECE': 0.3113624292561844, 'ACC': 0.6099157616969451}
ComplEx YAGO39
{'BS': 0.18819364883996217, 'NLL': 0.8173243217274974, 'ECE': 0.1715817111932596, 'ACC': 0.7724644729021451}


In [None]:
for m, data in zip(am100, [ds.fb13, ds.wn18, ds.yago39]):
    print(stats(m, data))

ComplEx FB13k
{'BS': 0.3276006406724404, 'NLL': 2.281402174867122, 'ECE': 0.33469375273672003, 'ACC': 0.6015085117141412}
ComplEx WN11
{'BS': 0.22845736997785257, 'NLL': 0.646433796118873, 'ECE': 0.14880674689269527, 'ACC': 0.6182888460367401}
ComplEx YAGO39
{'BS': 0.48235713753169446, 'NLL': 7.1782716382449445, 'ECE': 0.48116350690792864, 'ACC': 0.5116442427189712}


In [3]:
import os
folder = r'C:\Users\s1904162\Downloads\kgcal\distmult'
distms = [restore_model(os.path.join(folder, i)) for i in os.listdir(folder)]


In [7]:
for m, data in zip([distms[1], distms[4], distms[7]], [ds.fb13, ds.wn18, ds.yago39]):
    print(stats(m, data))

DistMult FB13k
{'BS': 0.36901716330141854, 'NLL': 1.5688020553158053, 'ECE': 0.30874837615725315, 'ACC': 0.5756994774987358}
DistMult WN11
{'BS': 0.29666016749928664, 'NLL': 0.9519767140735105, 'ECE': 0.2673518698997735, 'ACC': 0.6163604993403025}
DistMult YAGO39
{'BS': 0.20485342387600067, 'NLL': 1.0795890836601485, 'ECE': 0.18260177722412904, 'ACC': 0.7677095153185282}


In [8]:
for m, data in zip([distms[2], distms[5], distms[8]], [ds.fb13, ds.wn18, ds.yago39]):
    print(stats(m, data))

DistMult FB13k
{'BS': 0.4052278433635866, 'NLL': 2.4744044963495777, 'ECE': 0.3994369438311742, 'ACC': 0.5630583178830272}
DistMult WN11
{'BS': 0.40066911178071246, 'NLL': 2.4889007664103446, 'ECE': 0.41685810815261776, 'ACC': 0.5716533035623668}
DistMult YAGO39
{'BS': 0.18163512318656264, 'NLL': 1.142190223670159, 'ECE': 0.16360297554170783, 'ACC': 0.7988869076565623}


In [9]:
for m, data in zip([distms[0], distms[3], distms[6]], [ds.fb13, ds.wn18, ds.yago39]):
    print(stats(m, data))

DistMult FB13k
{'BS': 0.36901716614915436, 'NLL': 1.5688020520369588, 'ECE': 0.30874837157583873, 'ACC': 0.5756994774987358}
DistMult WN11
{'BS': 0.2966601628173823, 'NLL': 0.9519766811659367, 'ECE': 0.26735185794098654, 'ACC': 0.6163604993403025}
DistMult YAGO39
{'BS': 0.20485342551980767, 'NLL': 1.0795891002068012, 'ECE': 0.1826017776801625, 'ACC': 0.7677095153185282}
