# Probability Calibration in KG Embedding
This experiment investigates which calibration technique is most suitable for a given dataset and KG embedding model.

In this experiment, we compare the performance of 4 common calibration techniques across 4 KGE models and 3 datasets:
- Calibration techniques
  - Platt Scaling
  - Isotonic Regression
  - Histogram Binning
  - Beta Calibration
- KG embedding models
  - TransE
  - ComplEx
  - DistMult
  - HolE
- Datasets
  - FB13k
  - WN11
  - YAGO39
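
To make two of the techniques listed above concrete, here is a minimal NumPy sketch of Platt scaling and histogram binning. It is an illustration under stated assumptions, not the code behind `get_calibrators()`: the function names (`fit_platt`, `apply_platt`, `fit_histogram_binning`, `apply_histogram_binning`) are hypothetical, Platt scaling is fitted with plain gradient descent rather than a library solver, and histogram binning assumes its inputs were already mapped into [0, 1] (here by a sigmoid) and uses equal-width bins. The toy scores and labels are made up.

```python
import numpy as np


def fit_platt(scores, labels, lr=0.1, n_iter=2000):
    """Fit p = sigmoid(a * score + b) by gradient descent on the log loss."""
    s = np.asarray(scores, dtype=float)
    y = np.asarray(labels, dtype=float)
    a, b = 1.0, 0.0
    for _ in range(n_iter):
        p = 1.0 / (1.0 + np.exp(-(a * s + b)))
        a -= lr * np.mean((p - y) * s)  # gradient of the mean log loss w.r.t. a
        b -= lr * np.mean(p - y)        # gradient of the mean log loss w.r.t. b
    return a, b


def apply_platt(scores, a, b):
    """Map raw scores to probabilities with the fitted sigmoid."""
    s = np.asarray(scores, dtype=float)
    return 1.0 / (1.0 + np.exp(-(a * s + b)))


def fit_histogram_binning(probs, labels, n_bins=10):
    """Learn one output probability per equal-width bin of [0, 1]."""
    p = np.asarray(probs, dtype=float)
    y = np.asarray(labels, dtype=float)
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    bin_ids = np.digitize(p, edges[1:-1])  # bin index in 0 .. n_bins - 1
    bin_means = np.empty(n_bins)
    for i in range(n_bins):
        mask = bin_ids == i
        # An empty bin falls back to its midpoint (an arbitrary choice).
        bin_means[i] = y[mask].mean() if mask.any() else (edges[i] + edges[i + 1]) / 2
    return edges, bin_means


def apply_histogram_binning(probs, edges, bin_means):
    """Replace each probability by the learned value of its bin."""
    p = np.asarray(probs, dtype=float)
    return bin_means[np.digitize(p, edges[1:-1])]


# Toy data, made up for illustration: raw model scores and 0/1 triple labels.
raw_scores = np.array([-2.0, -1.0, -0.5, 0.5, 1.0, 2.0])
labels = np.array([0, 0, 1, 0, 1, 1])

a, b = fit_platt(raw_scores, labels)
platt_probs = apply_platt(raw_scores, a, b)

uncal_probs = apply_platt(raw_scores, 1.0, 0.0)  # plain sigmoid, no fitting
edges, bin_means = fit_histogram_binning(uncal_probs, labels, n_bins=2)
binned_probs = apply_histogram_binning(uncal_probs, edges, bin_means)
```

Platt scaling keeps the ranking of the scores and only reshapes them with two parameters, while histogram binning outputs one value per bin, so ties become common. In practice the calibrator would be fitted on a held-out split rather than on the data it is evaluated on.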

In [1]:
import sys
# enable importing the modules from probcalkge
sys.path.append('../')
sys.path.append('../probcalkge')

In [2]:
import numpy as np
import pandas as pd

In [3]:
from ampligraph.latent_features import RandomBaseline, TransE
import probcalkge

from probcalkge import Experiment
from probcalkge import get_calibrators
from probcalkge import get_datasets, get_fb13, get_wn11, get_kgemodels
from probcalkge import brier_score, negative_log_loss, ks_error
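
For reference, the three metrics imported above can be sketched as follows. These are textbook-style definitions written for illustration, with hypothetical names (`brier_sketch`, `nll_sketch`, `ks_sketch`); the implementations in `probcalkge` may differ in details such as clipping or normalisation. In particular, the KS error shown here is one common definition from the calibration literature (the largest gap between the cumulative predicted probability and the cumulative label after sorting by prediction), and it is an assumption that `ks_error` follows it.

```python
import numpy as np


def brier_sketch(probs, labels):
    """Mean squared difference between predicted probability and 0/1 label."""
    p = np.asarray(probs, dtype=float)
    y = np.asarray(labels, dtype=float)
    return float(np.mean((p - y) ** 2))


def nll_sketch(probs, labels, eps=1e-15):
    """Mean negative log-likelihood of the labels under the predictions."""
    p = np.clip(np.asarray(probs, dtype=float), eps, 1 - eps)  # avoid log(0)
    y = np.asarray(labels, dtype=float)
    return float(-np.mean(y * np.log(p) + (1 - y) * np.log(1 - p)))


def ks_sketch(probs, labels):
    """Largest gap between cumulative predictions and cumulative labels."""
    p = np.asarray(probs, dtype=float)
    y = np.asarray(labels, dtype=float)
    order = np.argsort(p)  # sort samples by predicted probability
    gap = np.cumsum(p[order] - y[order]) / len(p)
    return float(np.max(np.abs(gap)))


labels = np.array([0, 0, 1, 1])
constant = np.full(4, 0.5)  # a predictor that always says 0.5

brier_const = brier_sketch(constant, labels)
nll_const = nll_sketch(constant, labels)
ks_example = ks_sketch(np.array([0.2, 0.4, 0.6, 0.8]), labels)
```

A predictor that always outputs 0.5 gets a Brier score of 0.25 and a negative log loss of ln 2 ≈ 0.693 regardless of the labels, which is a useful reference point when reading result tables. Lower is better for all three metrics.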

In [4]:
# ds = get_datasets()
cals = get_calibrators()
kges = get_kgemodels()




In [5]:
exp = Experiment(
    cals=[cals.uncal, cals.platt, cals.isot, cals.histbin, cals.beta],
    datasets=[get_fb13(), get_wn11()],
    kges=[kges.transE, kges.complEx, kges.distMult, kges.hoLE],
    metrics=[brier_score, negative_log_loss, ks_error],
)

In [6]:
exp.run()

training $TransE on $FB13k ...


Average TransE Loss:   1.088374: 100%|██████████| 100/100 [18:48<00:00, 11.28s/epoch]


                   UncalCalibtator  PlattCalibtator  IsotonicCalibrator  \
brier_score               0.242014         0.212112            0.206325   
negative_log_loss         0.676226         0.616217            0.600623   
ks_error                  0.098866         0.027720            0.004463   

                   HistogramBinningCalibtator  BetaCalibtator  
brier_score                          0.215394        0.209044  
negative_log_loss                    0.618863        0.609719  
ks_error                             0.004041        0.016290  
training $TransE on $WN11 ...


Average TransE Loss:   0.960578: 100%|██████████| 100/100 [08:23<00:00,  5.04s/epoch]


                   UncalCalibtator  PlattCalibtator  IsotonicCalibrator  \
brier_score               0.241891         0.090147            0.087408   
negative_log_loss         0.700540         0.307826            0.298104   
ks_error                  0.310065         0.024209            0.003957   

                   HistogramBinningCalibtator  BetaCalibtator  
brier_score                          0.087826        0.089906  
negative_log_loss                    0.299751        0.308226  
ks_error                             0.004747        0.018527  
training $ComplEx on $FB13k ...


Average ComplEx Loss:   0.190559: 100%|██████████| 100/100 [44:16<00:00, 26.56s/epoch]


                   UncalCalibtator  PlattCalibtator  IsotonicCalibrator  \
brier_score               0.410170         0.222095            0.208440   
negative_log_loss         2.252936         0.633720            0.602848   
ks_error                  0.421845         0.049897            0.002977   

                   HistogramBinningCalibtator  BetaCalibtator  
brier_score                          0.225157        0.225954  
negative_log_loss                    0.637211        0.642698  
ks_error                             0.003548        0.048580  
training $ComplEx on $WN11 ...


Average ComplEx Loss:   0.008841: 100%|██████████| 100/100 [18:48<00:00, 11.29s/epoch]


                   UncalCalibtator  PlattCalibtator  IsotonicCalibrator  \
brier_score               0.419791         0.225499            0.223953   
negative_log_loss         2.787404         0.634156            0.631768   
ks_error                  0.433548         0.016142            0.002854   

                   HistogramBinningCalibtator  BetaCalibtator  
brier_score                          0.225254        0.226225  
negative_log_loss                    0.637924        0.639116  
ks_error                             0.002236        0.014323  
training $DistMult on $FB13k ...


Average DistMult Loss:   0.215396: 100%|██████████| 100/100 [18:36<00:00, 11.16s/epoch]


                   UncalCalibtator  PlattCalibtator  IsotonicCalibrator  \
brier_score               0.376787         0.228464            0.222810   
negative_log_loss         1.872856         0.644144            0.633132   
ks_error                  0.372544         0.026884            0.003783   

                   HistogramBinningCalibtator  BetaCalibtator  
brier_score                          0.225781        0.229218  
negative_log_loss                    0.638461        0.646620  
ks_error                             0.003232        0.030662  
training $DistMult on $WN11 ...


Average DistMult Loss:   0.026890: 100%|██████████| 100/100 [08:28<00:00,  5.08s/epoch]


                   UncalCalibtator  PlattCalibtator  IsotonicCalibrator  \
brier_score               0.401315         0.222539            0.221629   
negative_log_loss         2.289132         0.628291            0.626881   
ks_error                  0.415413         0.016646            0.006268   

                   HistogramBinningCalibtator  BetaCalibtator  
brier_score                          0.222973        0.223681  
negative_log_loss                    0.633523        0.634880  
ks_error                             0.007008        0.015897  
training $HolE on $FB13k ...


Average HolE Loss:   0.731045: 100%|██████████| 100/100 [44:07<00:00, 26.48s/epoch]


                   UncalCalibtator  PlattCalibtator  IsotonicCalibrator  \
brier_score               0.314147         0.248683            0.216840   
negative_log_loss         0.935742         0.690483            0.616730   
ks_error                  0.169842         0.065722            0.001976   

                   HistogramBinningCalibtator  BetaCalibtator  
brier_score                          0.176534        0.235856  
negative_log_loss                    0.530080        0.660579  
ks_error                             0.001829        0.082448  
training $HolE on $WN11 ...


Average HolE Loss:   0.724753: 100%|██████████| 100/100 [18:22<00:00, 11.02s/epoch]


                   UncalCalibtator  PlattCalibtator  IsotonicCalibrator  \
brier_score               0.210079         0.201418            0.192535   
negative_log_loss         0.597883         0.582792            0.553706   
ks_error                  0.095622         0.025250            0.004744   

                   HistogramBinningCalibtator  BetaCalibtator  
brier_score                          0.192816        0.197292  
negative_log_loss                    0.557775        0.568626  
ks_error                             0.004432        0.016271  

ValueError: conflicting sizes for dimension 'cal': length 4 on the data but length 5 on coordinate 'cal'

In [None]:
import probcalkge
import importlib
importlib.reload(probcalkge)

from probcalkge import ExperimentResult



df = pd.DataFrame([[1,2],[3,4]], index=['bs', 'nll'], columns=['uncal', 'platt'])

res = {
    'transE': {
        'fb13': df,
        'wn11': df
    },
    'DistMult': {
        'fb13': df,
        'wn11': df
    },
}

expres = ExperimentResult(exp, res)

[[array([[1, 2],
       [3, 4]]), array([[1, 2],
       [3, 4]])], [array([[1, 2],
       [3, 4]]), array([[1, 2],
       [3, 4]])]]
{'cal': ['UncalCalibtator', 'PlattCalibtator'], 'kge': ['TransE', 'HolE'], 'dataset': ['FB13k'], 'metric': ['brier_score', 'negative_log_loss']}


ValueError: conflicting sizes for dimension 'dataset': length 2 on the data but length 1 on coordinate 'dataset'
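
The `ValueError` above reports that a coordinate list and the matching axis of the data have different lengths. One way to rule this out, sketched here under assumptions, is to derive the axis labels from the nested result dictionary itself instead of from a separate object. The helper `stack_results` is hypothetical and not part of `probcalkge`; it assumes every model has the same dataset keys and every inner value is a `DataFrame` with the same index (metrics) and columns (calibrators).

```python
import numpy as np
import pandas as pd


def stack_results(res):
    """Stack {kge: {dataset: DataFrame}} into a 4-D array plus matching labels."""
    kge_names = list(res)
    dataset_names = list(res[kge_names[0]])
    first = res[kge_names[0]][dataset_names[0]]
    # Axis order: kge, dataset, metric (DataFrame rows), cal (DataFrame columns).
    data = np.array([
        [res[k][d].to_numpy() for d in dataset_names]
        for k in kge_names
    ])
    coords = {
        'kge': kge_names,
        'dataset': dataset_names,
        'metric': list(first.index),
        'cal': list(first.columns),
    }
    # Fail early with a readable message if an axis and its labels disagree.
    for axis, (name, names) in enumerate(coords.items()):
        if data.shape[axis] != len(names):
            raise ValueError(
                f"axis {name!r}: {data.shape[axis]} entries but {len(names)} labels"
            )
    return data, coords


df = pd.DataFrame([[1, 2], [3, 4]], index=['bs', 'nll'], columns=['uncal', 'platt'])
res = {
    'transE': {'fb13': df, 'wn11': df},
    'DistMult': {'fb13': df, 'wn11': df},
}
data, coords = stack_results(res)
```

Because the labels come from the same dictionary as the values, the two cannot drift apart the way they do when the labels are read from `exp`. The resulting pair can be handed to a labelled-array constructor such as `xarray.DataArray(data, coords=coords, dims=list(coords))`, which appears to be where the error message originates.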

In [None]:
print(exp.datasets)

[<caldatasets.DatasetWrapper object at 0x7f5ff95f16d0>]


In [None]:
from probcalkge.calutils import get_cls_name
[get_cls_name(i) for i in exp.cals]

['BetaCalibtator']