In [2]:
import sys
# enable importing the modules from probcalkge
sys.path.append('../')
sys.path.append('../probcalkge')

In [3]:
import importlib
from pprint import pprint
import numpy as np
import pandas as pd

In [60]:
from probcalkge import Experiment, DatasetWrapper, ExperimentDatasets
from probcalkge import get_datasets, get_calibrators, get_kgemodels
from probcalkge import brier_score, negative_log_loss, ks_error

In [22]:
# Load the registries the experiments below pick from. The entries are
# reached by attribute, e.g. ds.fb13, cals.platt, kges.transE.
ds = get_datasets()
cals = get_calibrators()
kges = get_kgemodels()



In [61]:
def dataset_stats(ds: ExperimentDatasets) -> pd.DataFrame:
    '''
    Tabulate the `stats` of every dataset in `ds`.

    Each dataset yielded by iterating `ds` becomes one row, labelled with
    its `name`; its `stats` value fills the 'train', 'valid' and 'test'
    columns, in that order.
    '''
    # Materialise once so both comprehensions see the same datasets.
    datasets = list(ds)
    return pd.DataFrame(
        [dataset.stats for dataset in datasets],
        index=[dataset.name for dataset in datasets],
        columns=['train', 'valid', 'test'],
    )

In [27]:
# Split sizes (train / valid / test) of every available dataset.
df = dataset_stats(ds)
df

Unnamed: 0,train,valid,test
FB13k,316232,11816,47464
WN11,110361,4877,19706
YAGO39,354994,18474,18514
DBpedia50,32388,246,4196
UMLS,5216,1304,1322
Kinship,8544,2136,2148
Nations,1592,398,402


In [28]:
import random
from probcalkge import DatasetWrapper

def _keep_known_triples(X, y, known_ents, known_rels):
    '''
    Keep the (triple, label) pairs whose entities (columns 0 and 2) are all
    in `known_ents` and whose relation (column 1) is in `known_rels`.
    '''
    kept_X, kept_y = [], []
    for triple, label in zip(X, y):
        if triple[0] in known_ents and triple[2] in known_ents and triple[1] in known_rels:
            kept_X.append(triple.tolist())
            kept_y.append(label)
    if not kept_X:
        # np.array([]) would be a 1-D float64 array; keep the column layout
        # and dtypes of the inputs so indexing such as X[:, 0] still works.
        X, y = np.asarray(X), np.asarray(y)
        return np.empty((0,) + X.shape[1:], dtype=X.dtype), np.empty((0,), dtype=y.dtype)
    return np.array(kept_X), np.array(kept_y)


def shrink_dataset(ds: DatasetWrapper, perc=0.5, seed=None) -> DatasetWrapper:
    '''
    Shrink a dataset by randomly sampling a fraction of its training triples.

    The validation and test splits are then filtered so that they only
    contain triples whose entities and relation still occur in the sampled
    training set.

    Parameters
    ----------
    ds : dataset whose X_train, X_valid/y_valid and X_test/y_test are used.
    perc : fraction of training triples to keep, in (0, 1].
    seed : optional seed for the sampling. When None (the default) the
        module-level `random` state is used, exactly as before, so a prior
        `random.seed(...)` call still controls the result.

    Returns
    -------
    A new DatasetWrapper named '<ds.name>-shrinked(<perc>)'.

    Raises
    ------
    ValueError : if `perc` is outside (0, 1] or selects no training triple.
    '''
    if not 0 < perc <= 1:
        raise ValueError(f'perc must be in (0, 1], got {perc}')
    n_keep = int(perc * len(ds.X_train))
    if n_keep < 1:
        # An empty sample has no columns, so the entity/relation extraction
        # below would fail with an opaque IndexError.
        raise ValueError(
            f'perc={perc} keeps no training triple out of {len(ds.X_train)}'
        )

    # A private generator keeps a seeded call from disturbing global state.
    rng = random if seed is None else random.Random(seed)
    new_X_train = np.array(rng.sample(ds.X_train.tolist(), n_keep))
    new_X_ents = set(new_X_train[:, 0]).union(set(new_X_train[:, 2]))
    new_X_rels = set(new_X_train[:, 1])

    new_X_valid, new_y_valid = _keep_known_triples(
        ds.X_valid, ds.y_valid, new_X_ents, new_X_rels)
    new_X_test, new_y_test = _keep_known_triples(
        ds.X_test, ds.y_test, new_X_ents, new_X_rels)

    return DatasetWrapper(f'{ds.name}-shrinked({perc})',
                            new_X_train,
                            new_X_valid, new_y_valid,
                            new_X_test, new_y_test
    )


In [41]:
# Seed the global RNG that shrink_dataset samples from, so the shrunken
# datasets (and every number derived from them) are reproducible on re-run.
random.seed(0)
new_fb = shrink_dataset(ds.fb13, perc=0.1)
# NOTE(review): the attribute is `wn18`, but this dataset reports itself as
# WN11 in the training logs -- confirm it is the intended dataset.
new_wn = shrink_dataset(ds.wn18, perc=0.1)

In [51]:
# Full grid: every calibrator x every KGE model x both shrunken datasets,
# scored with each of the three metrics.
exp = Experiment(
    cals=[
        cals.uncal,
        cals.platt,
        cals.isot,
        cals.histbin,
        cals.beta,
    ],
    datasets=[
        new_fb,
        new_wn,
    ],
    kges=[
        kges.transE,
        kges.complEx,
        kges.distMult,
        kges.hoLE,
    ],
    metrics=[
        brier_score,
        negative_log_loss,
        ks_error,
    ],
)

In [54]:
# Trains every KGE model and then the calibrators (slow: see the progress
# bars below). Bind the returned result to a name so later cells do not
# have to recover it from IPython's `_`, and still display it.
exp_res = exp.run()
exp_res

training TransE on FB13k-shrinked(0.1) ...


Average TransE Loss:   1.136442: 100%|██████████| 100/100 [04:08<00:00,  2.48s/epoch]


training TransE on WN11-shrinked(0.1) ...


Average TransE Loss:   0.979976: 100%|██████████| 100/100 [02:01<00:00,  1.21s/epoch]


training ComplEx on FB13k-shrinked(0.1) ...


Average ComplEx Loss:   0.068007: 100%|██████████| 100/100 [07:44<00:00,  4.65s/epoch]


training ComplEx on WN11-shrinked(0.1) ...


Average ComplEx Loss:   0.107224: 100%|██████████| 100/100 [03:45<00:00,  2.25s/epoch]


training DistMult on FB13k-shrinked(0.1) ...


Average DistMult Loss:   0.121999: 100%|██████████| 100/100 [03:53<00:00,  2.34s/epoch]


training DistMult on WN11-shrinked(0.1) ...


Average DistMult Loss:   0.493793: 100%|██████████| 100/100 [01:54<00:00,  1.14s/epoch]


training HolE on FB13k-shrinked(0.1) ...


Average HolE Loss:   0.530168: 100%|██████████| 100/100 [07:47<00:00,  4.67s/epoch]


training HolE on WN11-shrinked(0.1) ...


Average HolE Loss:   0.747235: 100%|██████████| 100/100 [03:47<00:00,  2.28s/epoch]


training various calibrators for TransE on FB13k-shrinked(0.1) ...
True
training various calibrators for ComplEx on FB13k-shrinked(0.1) ...
False
training various calibrators for DistMult on FB13k-shrinked(0.1) ...
False
training various calibrators for HolE on FB13k-shrinked(0.1) ...
True
training various calibrators for TransE on WN11-shrinked(0.1) ...
True
training various calibrators for ComplEx on WN11-shrinked(0.1) ...
True
training various calibrators for DistMult on WN11-shrinked(0.1) ...
True
training various calibrators for HolE on WN11-shrinked(0.1) ...
True
{'FB13k-shrinked(0.1)': {'TransE':                    UncalCalibrator  PlattCalibrator  IsotonicCalibrator  \
brier_score               0.421575         0.226644            0.224566   
negative_log_loss         1.266595         0.644602            0.652227   
ks_error                  0.437952         0.015020            0.003566   

                   HistogramBinningCalibrator  BetaCalibrator  
brier_score             

<experiment.ExperimentResult at 0x29aba5cb388>

In [56]:
# exp_res = exp.run_with_trained_kges()
# NOTE(review): `_` is IPython's "most recent output", so this only binds the
# experiment result when the cell that displayed it was the last one to
# produce output. It is not reproducible under Restart & Run All; prefer
# assigning the return value of the run directly.
exp_res = _

In [57]:
# Long-format view: one row per (dataset, kge, cal, metric), with the score
# in the `ExpRes` column.
exp_res.to_frame()

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,ExpRes
dataset,kge,cal,metric,Unnamed: 4_level_1
FB13k-shrinked(0.1),TransE,UncalCalibrator,brier_score,0.421575
FB13k-shrinked(0.1),TransE,UncalCalibrator,negative_log_loss,1.266595
FB13k-shrinked(0.1),TransE,UncalCalibrator,ks_error,0.437952
FB13k-shrinked(0.1),TransE,PlattCalibrator,brier_score,0.226644
FB13k-shrinked(0.1),TransE,PlattCalibrator,negative_log_loss,0.644602
...,...,...,...,...
WN11-shrinked(0.1),HolE,HistogramBinningCalibrator,negative_log_loss,0.710289
WN11-shrinked(0.1),HolE,HistogramBinningCalibrator,ks_error,0.012714
WN11-shrinked(0.1),HolE,BetaCalibrator,brier_score,0.245186
WN11-shrinked(0.1),HolE,BetaCalibrator,negative_log_loss,0.683545


In [58]:
# Reshape to wide form: one row per (dataset, kge, metric) and one column
# per calibrator, so calibrators can be compared side by side.
newdf = exp_res.to_frame().pivot_table(
    index=['dataset', 'kge', 'metric'],
    columns=['cal'],
    values='ExpRes',
)

In [59]:
# Calibrator with the smallest value in each (dataset, kge, metric) row.
# NOTE(review): this reads as "best calibrator" only if lower is better for
# all three metrics -- confirm against the metric definitions.
newdf.idxmin(axis=1)

dataset              kge       metric           
FB13k-shrinked(0.1)  TransE    brier_score                  IsotonicCalibrator
                               negative_log_loss                BetaCalibrator
                               ks_error                     IsotonicCalibrator
                     ComplEx   brier_score          HistogramBinningCalibrator
                               negative_log_loss    HistogramBinningCalibrator
                               ks_error                     IsotonicCalibrator
                     DistMult  brier_score          HistogramBinningCalibrator
                               negative_log_loss    HistogramBinningCalibrator
                               ks_error                     IsotonicCalibrator
                     HolE      brier_score          HistogramBinningCalibrator
                               negative_log_loss    HistogramBinningCalibrator
                               ks_error                     IsotonicCalibrator
WN1

In [62]:
# `stats` of the shrunken FB13k dataset (the values dataset_stats tabulates
# as train / valid / test).
new_fb.stats

(31623, 4425, 17749)

In [66]:
# `stats` of the shrunken WN11 dataset (the values dataset_stats tabulates
# as train / valid / test).
new_wn.stats

(11036, 1218, 5169)