# Probability Calibration in KG Embedding
This experiment investigates which calibration technique is most suitable for a given dataset and KG embedding model.

In this experiment, we compare the performance of 4 common calibration techniques for 4 KGE models on 3 datasets:
- Calibration techniques:
  - Platt Scaling
  - Isotonic Regression
  - Histogram Binning
  - Beta Calibration
- KG embedding models:
  - TransE
  - ComplEx
  - DistMult
  - HolE
- Datasets:
  - FB13k
  - WN11
  - YAGO39

In [12]:
import sys
# enable importing the modules from probcalkge
sys.path.append('../')
sys.path.append('../probcalkge')

In [13]:
import importlib
from pprint import pprint
import numpy as np
import pandas as pd

In [14]:
from ampligraph.latent_features import RandomBaseline, TransE
import probcalkge
importlib.reload(probcalkge)
from probcalkge import Experiment
from probcalkge import get_calibrators
from probcalkge import get_datasets, get_fb13, get_wn11, get_kgemodels, get_yago39
from probcalkge import brier_score, negative_log_loss, ks_error
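The three metrics imported from probcalkge are not defined in this notebook. The sketch below shows the standard definitions they are presumably based on, under hypothetical names (`brier`, `nll`, `ks_cal_error`) so they do not shadow the imports. probcalkge's own implementations may differ in details such as clipping or normalisation.

```python
import numpy as np

def brier(y, p):
    """Mean squared difference between predicted probability and the 0/1 label."""
    y, p = np.asarray(y, float), np.asarray(p, float)
    return float(np.mean((p - y) ** 2))

def nll(y, p, eps=1e-15):
    """Mean binary cross-entropy; p is clipped away from 0 and 1 to keep the log finite."""
    y = np.asarray(y, float)
    p = np.clip(np.asarray(p, float), eps, 1 - eps)
    return float(-np.mean(y * np.log(p) + (1 - y) * np.log(1 - p)))

def ks_cal_error(y, p):
    """Kolmogorov-Smirnov-style calibration error: after sorting by predicted
    probability, the largest gap between cumulative predicted and cumulative
    observed positive mass (both divided by the number of triples)."""
    y, p = np.asarray(y, float), np.asarray(p, float)
    order = np.argsort(p, kind="stable")
    n = len(p)
    gap = np.cumsum(p[order]) / n - np.cumsum(y[order]) / n
    return float(np.max(np.abs(gap)))
```

For all three, lower is better. A constant prediction of 0.5 on balanced labels gives a Brier score of 0.25 and a log loss of ln 2 ≈ 0.693, a useful reference point for the uncalibrated scores in the results table.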

In [11]:
ds = get_datasets()
cals = get_calibrators()
kges = get_kgemodels()


In [18]:
exp = Experiment(
    cals=[cals.uncal, cals.platt, cals.isot, cals.histbin, cals.beta, cals.temperature], 
    datasets=[ds.fb13, ds.wn18, ds.yago39, ds.dp50, ds.nations, ds.kinship, ds.umls], 
    kges=[kges.transE, kges.complEx, kges.distMult, kges.hoLE], 
    metrics=[brier_score, negative_log_loss, ks_error]
    )
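`Experiment` evaluates the full cross product of its arguments. With the configuration above, that is 7 datasets × 4 models = 28 KGE trainings and 7 × 4 × 6 × 3 = 504 metric values. The loop below is a hypothetical sketch of that grid. `score_fn` stands in for "fit the calibrator on validation scores and evaluate it on test scores", and the tuple keys mirror the four index levels (dataset, kge, cal, metric) of the results table.

```python
from itertools import product

def run_grid(datasets, kges, cals, metrics, score_fn):
    """Evaluate every (dataset, kge, calibrator, metric) combination.

    score_fn(dataset, kge, cal, metric) -> float is a hypothetical stand-in for
    the actual calibrate-and-evaluate step.
    """
    results = {}
    for d, k, c, m in product(datasets, kges, cals, metrics):
        results[(d, k, c, m)] = score_fn(d, k, c, m)
    return results

# 7 datasets, 4 KGE models, 6 calibrators, 3 metrics, as configured above.
res = run_grid(range(7), range(4), range(6), range(3), lambda *args: 0.0)
print(len(res))
```

In practice each KGE model would be trained once per dataset and reused across all calibrators and metrics, which is why training and calibration are separate steps in this notebook.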

In [19]:
# exp.load_trained_kges('../saved_models/')
exp.train_kges()

training TransE on FB13k ...


Average TransE Loss:   1.088063: 100%|██████████| 100/100 [12:49<00:00,  7.70s/epoch]


training TransE on WN11 ...


Average TransE Loss:   0.960328: 100%|██████████| 100/100 [05:52<00:00,  3.52s/epoch]


training TransE on YAGO39 ...


Average TransE Loss:   0.908081: 100%|██████████| 100/100 [10:04<00:00,  6.04s/epoch]


training TransE on DBpedia50 ...


Average TransE Loss:   1.027427: 100%|██████████| 100/100 [03:23<00:00,  2.04s/epoch]


training TransE on Nations ...


Average TransE Loss:   1.373538: 100%|██████████| 100/100 [00:07<00:00, 13.75epoch/s]


training TransE on Kinship ...


Average TransE Loss:   1.363471: 100%|██████████| 100/100 [00:10<00:00,  9.58epoch/s]


training TransE on UMLS ...


Average TransE Loss:   1.205289: 100%|██████████| 100/100 [00:09<00:00, 10.34epoch/s]


training ComplEx on FB13k ...


Average ComplEx Loss:   1.064549:   7%|▋         | 7/100 [02:16<30:17, 19.55s/epoch]


KeyboardInterrupt: 
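Training is the expensive step here (about 13 minutes for TransE on FB13k alone), and the run above was interrupted partway through ComplEx on FB13k. The commented-out `exp.load_trained_kges('../saved_models/')` line suggests probcalkge can already reload saved models. A minimal sketch of saving each model as soon as it finishes, so a restart skips completed pairs, is shown below. `train_fn` is a hypothetical zero-argument training callable, and `pickle` is only a default. AmpliGraph 1.x models wrap TensorFlow state and may not pickle cleanly, in which case AmpliGraph's `save_model`/`restore_model` helpers could be passed as `save`/`load` (check their signatures against the installed version).

```python
import os
import pickle

def _pickle_save(obj, path):
    with open(path, "wb") as f:
        pickle.dump(obj, f)

def _pickle_load(path):
    with open(path, "rb") as f:
        return pickle.load(f)

def train_or_load(dataset_name, model_name, train_fn, cache_dir,
                  save=_pickle_save, load=_pickle_load):
    """Return the cached model for (dataset, model) if present; otherwise train and cache it."""
    os.makedirs(cache_dir, exist_ok=True)
    path = os.path.join(cache_dir, f"{dataset_name}_{model_name}.pkl")
    if os.path.exists(path):
        return load(path)
    model = train_fn()
    save(model, path)  # persist immediately so an interrupt loses at most one model
    return model
```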

In [17]:
exp_res = exp.run_with_trained_kges()

training various calibrators for TransE on FB13k ...
True
[]
[]
training various calibrators for ComplEx on FB13k ...
True
[]
[]
training various calibrators for DistMult on FB13k ...
True
[]
[]
training various calibrators for HolE on FB13k ...
True
[]
[]
training various calibrators for TransE on WN11 ...
True
[]
[]
training various calibrators for ComplEx on WN11 ...
True
[]
[]
training various calibrators for DistMult on WN11 ...
True
[]
[]
training various calibrators for HolE on WN11 ...
True
[]
[]
training various calibrators for TransE on YAGO39 ...
True
[]
[]
training various calibrators for ComplEx on YAGO39 ...
False
[]
[1. 1. 1. 1. 1. 1. 1. 1. 1.]


RuntimeError: all elements of input should be between 0 and 1
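The ComplEx/YAGO39 run stops with `RuntimeError: all elements of input should be between 0 and 1`, right after the notebook prints `False` and an array of ones. The message text matches the range check in PyTorch's binary cross-entropy. A plausible but unconfirmed cause is that non-finite or out-of-range scores from the trained ComplEx model reach one of the calibrators (a NaN would also fail such a range check). One defensive sketch, with hypothetical helper names, is to validate scores before fitting any calibrator:

```python
import numpy as np

def check_scores(scores, name="scores"):
    """Raise early with a readable message if scores contain NaN or infinity."""
    scores = np.asarray(scores, float)
    bad = ~np.isfinite(scores)
    if bad.any():
        raise ValueError(f"{name}: {bad.sum()} of {scores.size} values are NaN or infinite")
    return scores

def safe_probs(p, eps=1e-7):
    """Check probabilities are finite, then clip them into [eps, 1 - eps]."""
    return np.clip(check_scores(p, "probabilities"), eps, 1 - eps)
```

Clipping only absorbs small numerical overshoot. NaN or infinite values are reported instead, since `np.clip` leaves NaN unchanged and clipping infinities would silently produce meaningless probabilities.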

In [None]:
exp_res.to_frame()

| dataset | kge | cal | metric | ExpRes |
|---|---|---|---|---|
| FB13k | TransE | UncalCalibrator | brier_score | 0.242014 |
| FB13k | TransE | UncalCalibrator | negative_log_loss | 0.676226 |
| FB13k | TransE | UncalCalibrator | ks_error | 0.098866 |
| FB13k | TransE | PlattCalibrator | brier_score | 0.212112 |
| FB13k | TransE | PlattCalibrator | negative_log_loss | 0.616217 |
| ... | ... | ... | ... | ... |
| UMLS | HolE | HistogramBinningCalibrator | negative_log_loss | 0.321492 |
| UMLS | HolE | HistogramBinningCalibrator | ks_error | 0.016966 |
| UMLS | HolE | BetaCalibrator | brier_score | 0.096088 |
| UMLS | HolE | BetaCalibrator | negative_log_loss | 0.317385 |
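Since the question is which calibrator suits each dataset and model, a natural summary is the best calibrator per (dataset, kge, metric). The pure-Python sketch below uses only the fully printed FB13k/TransE rows (the elided rows are not reproduced). With the complete results, the same grouping could be done directly on the DataFrame.

```python
# The fully shown FB13k / TransE rows of the results table.
rows = [
    ("FB13k", "TransE", "UncalCalibrator", "brier_score", 0.242014),
    ("FB13k", "TransE", "UncalCalibrator", "negative_log_loss", 0.676226),
    ("FB13k", "TransE", "UncalCalibrator", "ks_error", 0.098866),
    ("FB13k", "TransE", "PlattCalibrator", "brier_score", 0.212112),
    ("FB13k", "TransE", "PlattCalibrator", "negative_log_loss", 0.616217),
]

def best_calibrator(rows):
    """Map each (dataset, kge, metric) to the (calibrator, value) pair with the lowest value.

    Brier score, negative log loss and KS error are all errors, so lower is better.
    """
    best = {}
    for dataset, kge, cal, metric, value in rows:
        key = (dataset, kge, metric)
        if key not in best or value < best[key][1]:
            best[key] = (cal, value)
    return best

for key, (cal, value) in sorted(best_calibrator(rows).items()):
    print(key, cal, value)
```

On these rows, Platt scaling lowers both the Brier score (0.242014 to 0.212112) and the negative log loss (0.676226 to 0.616217) relative to the uncalibrated TransE scores. Its KS error row falls in the elided part of the table, so only the uncalibrated value appears for that metric.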


In [None]:
exp._train_cal_and_eval(exp.trained_kges['YAGO39']['ComplEx'], ds.yago39)

training various calibrators for ComplEx on YAGO39 ...


RuntimeError: all elements of input should be between 0 and 1