# Probability Calibration in KG Embedding
This experiment investigates which calibration technique is the most suitable for a given dataset and KG embedding model.

In this experiment, we compare the performance of 4 typical calibration techniques for 4 KGE models on 3 datasets:
- calibration techniques:
  - Platt Scaling
  - Isotonic Regression
  - Histogram Binning
  - Beta Calibration
- KG Embedding models
  - TransE
  - ComplEx
  - DistMult
  - HolE
- Datasets
  - FB13k
  - WN11
  - Yago39

In [1]:
import sys

# Add the parent directory and its probcalkge sub-directory to the module
# search path so the probcalkge modules can be imported by the cells below.
sys.path.extend(['../', '../probcalkge'])

In [2]:
import importlib
from pprint import pprint
import numpy as np
import pandas as pd

In [3]:
from ampligraph.latent_features import RandomBaseline, TransE
import probcalkge
# Re-execute the probcalkge package before the `from probcalkge import ...`
# lines below, so those names are bound from the reloaded module.
# NOTE(review): presumably this is here to pick up local edits without
# restarting the kernel. importlib.reload only re-executes the module it is
# given, not submodules that were already imported -- confirm this is enough.
importlib.reload(probcalkge)
from probcalkge import Experiment
from probcalkge import get_calibrators
from probcalkge import get_datasets, get_fb13, get_wn11, get_kgemodels
from probcalkge import brier_score, negative_log_loss, ks_error

In [4]:
# Fetch the calibrator and KGE-model containers; their attributes
# (e.g. cals.platt, kges.transE) are picked out when the Experiment is built.
# get_datasets() is intentionally not called here: the datasets are loaded
# individually via get_fb13() / get_wn11() in the next cell.
cals, kges = get_calibrators(), get_kgemodels()




In [5]:
# Assemble the experiment from the calibrators, datasets, KGE models and
# metrics to be compared.
exp = Experiment(
    cals=[
        cals.uncal,  # presumably the uncalibrated baseline -- confirm
        cals.platt,
        cals.isot,
        cals.histbin,
        cals.beta,
    ],
    # NOTE(review): the notebook introduction announces three datasets, but
    # only FB13 and WN11 are loaded here.
    datasets=[
        get_fb13(),
        get_wn11(),
    ],
    kges=[
        kges.transE,
        kges.complEx,
        kges.distMult,
        kges.hoLE,
    ],
    metrics=[
        brier_score,
        negative_log_loss,
        ks_error,
    ],
)

In [6]:
# Load the KGE models stored under a fixed, timestamp-named directory of an
# earlier run instead of training them again in this notebook.
# NOTE(review): the path is hard-coded; it has to exist relative to the
# notebook's working directory for this cell to run.
exp.load_trained_kges('../saved_models/07-16_15-18-09')
# Run the experiment on the loaded models; the returned result object is
# inspected in the following cells.
exp_res = exp.run_with_trained_kges()

Loaded models:
{'FB13k': {'ComplEx': <ampligraph.latent_features.models.ComplEx.ComplEx object at 0x000002445F53ED08>,
           'DistMult': <ampligraph.latent_features.models.DistMult.DistMult object at 0x00000244608407C8>,
           'HolE': <ampligraph.latent_features.models.HolE.HolE object at 0x00000244628878C8>,
           'TransE': <ampligraph.latent_features.models.TransE.TransE object at 0x00000244642B5A08>},
 'WN11': {'ComplEx': <ampligraph.latent_features.models.ComplEx.ComplEx object at 0x000002445F53EF08>,
          'DistMult': <ampligraph.latent_features.models.DistMult.DistMult object at 0x00000244608412C8>,
          'HolE': <ampligraph.latent_features.models.HolE.HolE object at 0x00000244616EC9C8>,
          'TransE': <ampligraph.latent_features.models.TransE.TransE object at 0x00000244628874C8>}}
{'FB13k': {'TransE':                    UncalCalibtator  PlattCalibtator  IsotonicCalibrator  \
brier_score               0.242014         0.212112            0.206325   
ne

In [7]:
# Display the results as a table; in the output below it is indexed by
# (dataset, kge, cal, metric) with a single `ExpRes` value column.
exp_res.to_frame()

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,ExpRes
dataset,kge,cal,metric,Unnamed: 4_level_1
FB13k,TransE,UncalCalibtator,brier_score,0.242014
FB13k,TransE,UncalCalibtator,negative_log_loss,0.676226
FB13k,TransE,UncalCalibtator,ks_error,0.098866
FB13k,TransE,PlattCalibtator,brier_score,0.212112
FB13k,TransE,PlattCalibtator,negative_log_loss,0.616217
...,...,...,...,...
WN11,HolE,HistogramBinningCalibtator,negative_log_loss,0.557775
WN11,HolE,HistogramBinningCalibtator,ks_error,0.004432
WN11,HolE,BetaCalibtator,brier_score,0.197292
WN11,HolE,BetaCalibtator,negative_log_loss,0.568626


In [9]:
# Show only the Platt-calibrated TransE results, per dataset and metric.
# NOTE(review): 'PlattCalibtator' is spelled exactly like the calibrator label
# in the results table; it is presumably matched as a key, so do not "correct"
# the spelling here without renaming the calibrator itself.
exp_res.slice(cal='PlattCalibtator', kge='TransE')

dataset,FB13k,WN11
metric,Unnamed: 1_level_1,Unnamed: 2_level_1
brier_score,0.212112,0.090147
negative_log_loss,0.616217,0.307826
ks_error,0.02772,0.024209


['BetaCalibtator']

{'FB13k': {'TransE': <ampligraph.latent_features.models.TransE.TransE at 0x219cd682648>,
  'ComplEx': <ampligraph.latent_features.models.ComplEx.ComplEx at 0x219cf734808>,
  'DistMult': <ampligraph.latent_features.models.DistMult.DistMult at 0x2198029cd48>,
  'HolE': <ampligraph.latent_features.models.HolE.HolE at 0x21982d9ed48>},
 'WN11': {'TransE': <ampligraph.latent_features.models.TransE.TransE at 0x219cf435f48>,
  'ComplEx': <ampligraph.latent_features.models.ComplEx.ComplEx at 0x219d0731e88>,
  'DistMult': <ampligraph.latent_features.models.DistMult.DistMult at 0x21982ae17c8>,
  'HolE': <ampligraph.latent_features.models.HolE.HolE at 0x219861d2f88>}}

In [None]:
import os
from datetime import datetime

from ampligraph.utils import save_model

# Each save goes into its own sub-directory named month-day_hour-minute-second.
# NOTE(review): the timestamp carries no year, so directory names can repeat
# across years.
SAVE_MODEL_PATH = os.path.join('../saved_models/', datetime.now().strftime('%m-%d_%H-%M-%S'))
# Directory of a previously saved run; not used in this cell.
LOAD_MODEL_PATH = '../saved_models/07-16_15-18-09'

# makedirs also creates the parent '../saved_models/' directory when it is
# missing (os.mkdir would raise FileNotFoundError on a fresh checkout), and
# exist_ok=True keeps a re-run within the same second from failing.
os.makedirs(SAVE_MODEL_PATH, exist_ok=True)
print('made model directory:', SAVE_MODEL_PATH)

# exp.trained_kge maps dataset name -> {model name -> trained model};
# every model is written to '<dataset>-<model>.pkl' inside SAVE_MODEL_PATH.
for ds, models in exp.trained_kge.items():
    for mname, model in models.items():
        save_model(model, os.path.join(SAVE_MODEL_PATH, f'{ds}-{mname}.pkl'))

made model directory: ../saved_models/07-16_15-18-09
