In [1]:
import numpy as np
import itertools
import json

from sklearn.isotonic import IsotonicRegression
from sklearn.calibration import calibration_curve, _SigmoidCalibration
from ampligraph.evaluation import evaluate_performance, mr_score, mrr_score, hits_at_n_score, generate_corruptions_for_eval
from sklearn.metrics import brier_score_loss, log_loss, accuracy_score
from scipy.special import expit

from ampligraph.datasets import load_fb13
from ampligraph.latent_features.models import TransE, ComplEx, DistMult, ConvKB, HolE
import types

In [2]:
from generate_corruptions import generate_corruptions, calibration_loss, pos_iso

In [3]:
%env CUDA_VISIBLE_DEVICES=1

env: CUDA_VISIBLE_DEVICES=1


In [4]:
# Load the FB13 benchmark through AmpliGraph. Later cells index the returned
# object with the keys 'train', 'valid', 'valid_labels', 'test' and 'test_labels'.
X = load_fb13()

In [5]:
# Partition the validation and test triples into positives and negatives,
# using each split's label array as a mask (and its negation for negatives).
X_valid_pos, X_valid_neg = (
    X['valid'][keep] for keep in (X['valid_labels'], ~X['valid_labels'])
)
X_test_pos, X_test_neg = (
    X['test'][keep] for keep in (X['test_labels'], ~X['test_labels'])
)

In [6]:
# Benchmark grid: every model class in `models` is trained once with every
# loss name in `losses`; each successful run appends one dict of metrics to
# `results` and echoes it as JSON.
losses =  ['self_adversarial', 'pairwise', 'nll', 'multiclass_nll']
models = [TransE, DistMult, ComplEx, HolE]


# One entry per (model, loss) combination that completed without raising.
results = []

for m, l in itertools.product(models, losses):
    # Same hyper-parameters for every combination; only the model class and
    # the loss name vary. seed=0 is fixed for all runs.
    model = m(batches_count=32, seed=0, epochs=1000, k=100, eta=20,
               optimizer='adam', optimizer_params={'lr':0.0001},
               loss=l, verbose=False)
    
    try:
        model.fit(X['train'])

        # Raw (uncalibrated) model scores for the whole labelled test split.
        scores = model.predict(X['test'])

        # probas1: model-level calibration given validation positives only,
        # with positive_base_rate=0.5.
        # NOTE(review): presumably the library synthesises negatives internally
        # in this mode -- confirm against the AmpliGraph `calibrate` docs.
        # predict_proba is called right after each calibrate() call; keep that
        # ordering, since the second calibrate() presumably replaces the
        # parameters printed here.
        model.calibrate(X_valid_pos, batches_count=10, epochs=1000, positive_base_rate=0.5)
        print("pos", model.calibration_parameters)
        probas1 = model.predict_proba(X['test'])

        # probas2: model-level calibration given both validation positives and
        # validation negatives.
        model.calibrate(X_valid_pos, X_valid_neg)
        print("pos neg", model.calibration_parameters)
        probas2 = model.predict_proba(X['test'])

        # probas3: isotonic regression fitted on validation scores against the
        # validation labels, then applied to the test scores.
        val_scores = model.predict(X['valid'])
        ir = IsotonicRegression(out_of_bounds='clip')
        ir.fit(np.squeeze(val_scores).astype(float), (X['valid_labels']).astype(float))
        probas3 = ir.predict(np.squeeze(scores).astype(float))

        # Bind the project-local `generate_corruptions` function to this model
        # instance so it can be invoked as a method.
        model.generate_corruptions = types.MethodType(generate_corruptions, model)
        corruptions = model.generate_corruptions(X_valid_pos, batches_count=10, epochs=1000)
        val_pos_scores = np.squeeze(model.predict(X_valid_pos))
        # probas4: `pos_iso` is a project-local helper whose body is not visible
        # here; as used below it returns a calibrator exposing `.predict`.
        # Here it receives an isotonic regressor, the validation-positive
        # scores and the output of generate_corruptions.
        iso_pos = pos_iso(IsotonicRegression(out_of_bounds='clip'), val_pos_scores, corruptions, positive_base_rate=0.5)
        probas4 = iso_pos.predict(np.squeeze(scores).astype(float))

        # probas5: same inputs as probas4, but with scikit-learn's private
        # _SigmoidCalibration as the calibrator (a_ and b_ are printed).
        sc_pos = pos_iso(_SigmoidCalibration(), val_pos_scores, corruptions, positive_base_rate=0.5)
        print("pos sc", sc_pos.a_, sc_pos.b_)
        probas5 = sc_pos.predict(np.squeeze(scores).astype(float))

        # probas6: sigmoid calibrator again, but the scores of the real
        # validation negatives are passed where the corruptions were before.
        val_neg_scores = np.squeeze(model.predict(X_valid_neg))
        sc_pos_neg = pos_iso(_SigmoidCalibration(), val_pos_scores, val_neg_scores, positive_base_rate=0.5)
        print("pos neg sc", sc_pos_neg.a_, sc_pos_neg.b_)
        probas6 = sc_pos_neg.predict(np.squeeze(scores).astype(float))

        # Per-group decision threshold: the median validation score among the
        # triples whose column 1 equals r (column 1 presumably holds the
        # relation -- confirm). A test triple is predicted positive when its
        # score exceeds the threshold looked up for its own column-1 value.
        # NOTE(review): dict.get yields None for a value never seen in the
        # validation split; verify every test relation also occurs there.
        thresholds = {r: np.median(np.sort(val_scores[X['valid'][:, 1] == r])) for r in np.unique(X['valid'][:, 1])}
        thresholds_test = np.vectorize(thresholds.get)(X['test'][:, 1])
        per_relation_acc = accuracy_score(X['test_labels'], scores > thresholds_test)
        
        # Uncalibrated baseline: logistic sigmoid of the raw score, cut at 0.5.
        acc_uncalib = accuracy_score(X['test_labels'], expit(scores) > 0.5)

        # Accuracy of each calibrated probability at the 0.5 cut-off.
        acc1 = accuracy_score(X['test_labels'], probas1 > 0.5)
        acc2 = accuracy_score(X['test_labels'], probas2 > 0.5)
        acc3 = accuracy_score(X['test_labels'], probas3 > 0.5)
        acc4 = accuracy_score(X['test_labels'], probas4 > 0.5)
        acc5 = accuracy_score(X['test_labels'], probas5 > 0.5)
        acc6 = accuracy_score(X['test_labels'], probas6 > 0.5)
        
        # Ranking evaluation on the test positives; train, validation positives
        # and test positives are passed together as the filter set.
        filter_triples = np.concatenate((X['train'], X_valid_pos, X_test_pos))
        ranks = evaluate_performance(X_test_pos, 
                                     model=model, 
                                     filter_triples=filter_triples,
                                     use_default_protocol=True, 
                                     verbose=False)
    except Exception as e:
        # Any failure inside the try block is printed and this (model, loss)
        # combination is skipped: no row is appended to `results` for it.
        print("Exception: {}".format(e))
        continue
        
    # Key suffix legend (see the assignments above):
    #   scores             -> expit(raw scores), no calibration
    #   probas_pos         -> probas1    probas_pos_neg     -> probas2
    #   probas_pos_neg_iso -> probas3    probas_pos_iso     -> probas4
    #   probas_pos_sc      -> probas5    probas_pos_neg_sc  -> probas6
    # `calibration_loss` is a project-local helper; the 'ece_' prefix suggests
    # expected calibration error -- confirm in generate_corruptions.py.
    results.append({
        'model': m.__name__,
        'loss': l,
        'brier_score_scores': brier_score_loss(X['test_labels'], expit(scores)),
        'log_loss_scores': log_loss(X['test_labels'], expit(scores), eps=1e-7),
        'brier_score_probas_pos': brier_score_loss(X['test_labels'], probas1),
        'log_loss_probas_pos': log_loss(X['test_labels'], probas1, eps=1e-7),
        'brier_score_probas_pos_neg': brier_score_loss(X['test_labels'], probas2),
        'log_loss_probas_pos_neg': log_loss(X['test_labels'], probas2, eps=1e-7),
        'brier_score_probas_pos_neg_iso': brier_score_loss(X['test_labels'], probas3),
        'log_loss_probas_pos_neg_iso': log_loss(X['test_labels'], probas3, eps=1e-7),
        'brier_score_probas_pos_iso': brier_score_loss(X['test_labels'], probas4),
        'log_loss_probas_pos_iso': log_loss(X['test_labels'], probas4, eps=1e-7),
        'brier_score_probas_pos_sc': brier_score_loss(X['test_labels'], probas5),
        'log_loss_probas_pos_sc': log_loss(X['test_labels'], probas5, eps=1e-7),
        'brier_score_probas_pos_neg_sc': brier_score_loss(X['test_labels'], probas6),
        'log_loss_probas_pos_neg_sc': log_loss(X['test_labels'], probas6, eps=1e-7),
        'ece_scores': calibration_loss(X['test_labels'], expit(scores)),
        'ece_probas_pos': calibration_loss(X['test_labels'], probas1),
        'ece_probas_pos_neg': calibration_loss(X['test_labels'], probas2),
        'ece_probas_pos_neg_iso': calibration_loss(X['test_labels'], probas3),
        'ece_probas_pos_iso': calibration_loss(X['test_labels'], probas4),
        'ece_probas_pos_sc': calibration_loss(X['test_labels'], probas5),
        'ece_probas_pos_neg_sc': calibration_loss(X['test_labels'], probas6),
        'metrics_mrr': mrr_score(ranks), 
        'metrics_hits@10': hits_at_n_score(ranks, n=10),
        'metrics_mr': mr_score(ranks),
        'accuracy_per_relation': per_relation_acc,
        'accuracy_uncalib': acc_uncalib,
        'accuracy_pos': acc1,
        'accuracy_pos_neg': acc2,
        'accuracy_pos_neg_iso': acc3,
        'accuracy_pos_iso': acc4,
        'accuracy_pos_sc': acc5,
        'accuracy_pos_neg_sc': acc6
    })
        
    # Echo the row just recorded so progress is visible while the grid runs.
    print(json.dumps(results[-1], indent=2))

Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
tf.py_func is deprecated in TF V2. Instead, use
    tf.py_function, which takes a python function which manipulates tf eager
    tensors instead of numpy arrays. It's easy to convert a tf eager tensor to
    an ndarray (just call tensor.numpy()) but having access to eager tensors
    means `tf.py_function`s can use accelerators such as GPUs as well as
    being differentiable using a gradient tape.
    
Instructions for updating:
Use tf.cast instead.
Instructions for updating:
Use tf.cast instead.
pos [-1.4291335, -5.5321407]
pos neg [-2.7499628, -9.8842945]
Instructions for updating:
Use tf.random.categorical instead.
pos sc -1.5018368343229849 -5.927595038309995
pos neg sc -2.750017213431433 -9.884593929146579


  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count


{
  "model": "TransE",
  "loss": "self_adversarial",
  "brier_score_scores": 0.445572232212395,
  "log_loss_scores": 1.533928791554572,
  "brier_score_probas_pos": 0.14212788206615748,
  "log_loss_probas_pos": 0.4463757120985599,
  "brier_score_probas_pos_neg": 0.12403854710103454,
  "log_loss_probas_pos_neg": 0.389833326626873,
  "brier_score_probas_pos_neg_iso": 0.1240950381452728,
  "log_loss_probas_pos_neg_iso": 0.3898854834610226,
  "brier_score_probas_pos_iso": 0.14107614825076797,
  "log_loss_probas_pos_iso": 0.44215528119862046,
  "brier_score_probas_pos_sc": 0.14510261125934348,
  "log_loss_probas_pos_sc": 0.45124740933776486,
  "brier_score_probas_pos_neg_sc": 0.12403848918708212,
  "log_loss_probas_pos_neg_sc": 0.3898356655070431,
  "ece_scores": 0.4625255433962153,
  "ece_probas_pos": 0.10802915995225082,
  "ece_probas_pos_neg": 0.01067582282653529,
  "ece_probas_pos_neg_iso": 0.00699252277435059,
  "ece_probas_pos_iso": 0.08946030833209556,
  "ece_probas_pos_sc": 0.1130495

  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count
  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count
  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count
  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count


{
  "model": "TransE",
  "loss": "pairwise",
  "brier_score_scores": 0.4999600556860562,
  "log_loss_scores": 5.233593539454379,
  "brier_score_probas_pos": 0.22494239533363558,
  "log_loss_probas_pos": 0.6373193005651496,
  "brier_score_probas_pos_neg": 0.22469787795731397,
  "log_loss_probas_pos_neg": 0.6362553057238051,
  "brier_score_probas_pos_neg_iso": 0.20299433345446838,
  "log_loss_probas_pos_neg_iso": 0.5820868916678338,
  "brier_score_probas_pos_iso": 0.20803145528953107,
  "log_loss_probas_pos_iso": 0.5936251368204077,
  "brier_score_probas_pos_sc": 0.2284992373799987,
  "log_loss_probas_pos_sc": 0.6437693785737651,
  "brier_score_probas_pos_neg_sc": 0.22469730589396622,
  "log_loss_probas_pos_neg_sc": 0.6362553495353314,
  "ece_scores": 0.4999776426222399,
  "ece_probas_pos": 0.10634056276933913,
  "ece_probas_pos_neg": 0.156289998909715,
  "ece_probas_pos_neg_iso": 0.006800646281727416,
  "ece_probas_pos_iso": 0.06314431733619326,
  "ece_probas_pos_sc": 0.1514007931267311

  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count
  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count
  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count
  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count
  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count
  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count


{
  "model": "TransE",
  "loss": "nll",
  "brier_score_scores": 0.23567338666195467,
  "log_loss_scores": 0.663440711386646,
  "brier_score_probas_pos": 0.23965421321529687,
  "log_loss_probas_pos": 0.6763748084414889,
  "brier_score_probas_pos_neg": 0.20938845351015112,
  "log_loss_probas_pos_neg": 0.6144616960668199,
  "brier_score_probas_pos_neg_iso": 0.20316403105512415,
  "log_loss_probas_pos_neg_iso": 0.5923197087975974,
  "brier_score_probas_pos_iso": 0.24372140451844315,
  "log_loss_probas_pos_iso": 0.6846926645080735,
  "brier_score_probas_pos_sc": 0.2531269934159015,
  "log_loss_probas_pos_sc": 0.7037352964205206,
  "brier_score_probas_pos_neg_sc": 0.2093907411634531,
  "log_loss_probas_pos_neg_sc": 0.6144741272361858,
  "ece_scores": 0.14335939491576738,
  "ece_probas_pos": 0.16020022397215725,
  "ece_probas_pos_neg": 0.04415973265979175,
  "ece_probas_pos_neg_iso": 0.0090836734701206,
  "ece_probas_pos_iso": 0.17886891692122034,
  "ece_probas_pos_sc": 0.1870886979369448,
  

  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count


{
  "model": "TransE",
  "loss": "multiclass_nll",
  "brier_score_scores": 0.500020870354592,
  "log_loss_scores": 7.901857966070853,
  "brier_score_probas_pos": 0.16184198433056446,
  "log_loss_probas_pos": 0.5003637373407782,
  "brier_score_probas_pos_neg": 0.14617441819292532,
  "log_loss_probas_pos_neg": 0.45488400247670174,
  "brier_score_probas_pos_neg_iso": 0.1456411774965459,
  "log_loss_probas_pos_neg_iso": 0.45374960293180117,
  "brier_score_probas_pos_iso": 0.1586214491769169,
  "log_loss_probas_pos_iso": 0.48997647680508644,
  "brier_score_probas_pos_sc": 0.15659950329930739,
  "log_loss_probas_pos_sc": 0.4850075244687033,
  "brier_score_probas_pos_neg_sc": 0.14617440723072703,
  "log_loss_probas_pos_neg_sc": 0.4548842209901208,
  "ece_scores": 0.500020962877039,
  "ece_probas_pos": 0.11361927245130592,
  "ece_probas_pos_neg": 0.022855867649157452,
  "ece_probas_pos_neg_iso": 0.007904925558477308,
  "ece_probas_pos_iso": 0.07605843723007416,
  "ece_probas_pos_sc": 0.0842709

  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count
  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count


{
  "model": "DistMult",
  "loss": "self_adversarial",
  "brier_score_scores": 0.4729471412739745,
  "log_loss_scores": 2.1774196678224476,
  "brier_score_probas_pos": 0.18645320999775755,
  "log_loss_probas_pos": 0.5535015160421436,
  "brier_score_probas_pos_neg": 0.17807651772507113,
  "log_loss_probas_pos_neg": 0.5334941857489571,
  "brier_score_probas_pos_neg_iso": 0.17049598097100216,
  "log_loss_probas_pos_neg_iso": 0.517608292686097,
  "brier_score_probas_pos_iso": 0.19185121871829725,
  "log_loss_probas_pos_iso": 0.5668575806850646,
  "brier_score_probas_pos_sc": 0.1851380969269908,
  "log_loss_probas_pos_sc": 0.5500556126594139,
  "brier_score_probas_pos_neg_sc": 0.17807713363817593,
  "log_loss_probas_pos_neg_sc": 0.53349467308897,
  "ece_scores": 0.48212609757927805,
  "ece_probas_pos": 0.07637187688362332,
  "ece_probas_pos_neg": 0.06214156248065968,
  "ece_probas_pos_neg_iso": 0.006862596424947496,
  "ece_probas_pos_iso": 0.08560017145607507,
  "ece_probas_pos_sc": 0.07359

  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count
  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count


{
  "model": "DistMult",
  "loss": "pairwise",
  "brier_score_scores": 0.23753961116592057,
  "log_loss_scores": 0.6670091431740247,
  "brier_score_probas_pos": 0.21438765500491325,
  "log_loss_probas_pos": 0.6190712009264032,
  "brier_score_probas_pos_neg": 0.21403612631466135,
  "log_loss_probas_pos_neg": 0.6188065476391061,
  "brier_score_probas_pos_neg_iso": 0.203998687110917,
  "log_loss_probas_pos_neg_iso": 0.5940850779361583,
  "brier_score_probas_pos_iso": 0.21958758244628124,
  "log_loss_probas_pos_iso": 0.6286109215073183,
  "brier_score_probas_pos_sc": 0.21530682117297284,
  "log_loss_probas_pos_sc": 0.620546083187935,
  "brier_score_probas_pos_neg_sc": 0.21403602519643847,
  "log_loss_probas_pos_neg_sc": 0.618806552067774,
  "ece_scores": 0.15019018089873745,
  "ece_probas_pos": 0.09270147937316225,
  "ece_probas_pos_neg": 0.08561137947006649,
  "ece_probas_pos_neg_iso": 0.004676083815562329,
  "ece_probas_pos_iso": 0.08302598630499797,
  "ece_probas_pos_sc": 0.090938775366

  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count
  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count
  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count
  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count
  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count
  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count


{
  "model": "DistMult",
  "loss": "nll",
  "brier_score_scores": 0.4382576676826786,
  "log_loss_scores": 3.087698253423287,
  "brier_score_probas_pos": 0.23801610641051038,
  "log_loss_probas_pos": 0.6691468268083168,
  "brier_score_probas_pos_neg": 0.23767329471797705,
  "log_loss_probas_pos_neg": 0.6687838055087311,
  "brier_score_probas_pos_neg_iso": 0.2262255686393773,
  "log_loss_probas_pos_neg_iso": 0.644588759389898,
  "brier_score_probas_pos_iso": 0.23646773454628486,
  "log_loss_probas_pos_iso": 0.6700626209445736,
  "brier_score_probas_pos_sc": 0.2383986951678472,
  "log_loss_probas_pos_sc": 0.6698704066745614,
  "brier_score_probas_pos_neg_sc": 0.23767361875040668,
  "log_loss_probas_pos_neg_sc": 0.6687839653700602,
  "ece_scores": 0.4445666813557051,
  "ece_probas_pos": 0.07794582965340642,
  "ece_probas_pos_neg": 0.07578706950798177,
  "ece_probas_pos_neg_iso": 0.006630927281390513,
  "ece_probas_pos_iso": 0.04808071010383346,
  "ece_probas_pos_sc": 0.07995826337918524,


  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count
  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count


{
  "model": "DistMult",
  "loss": "multiclass_nll",
  "brier_score_scores": 0.3140456660629251,
  "log_loss_scores": 1.2732361626009205,
  "brier_score_probas_pos": 0.18316063474984462,
  "log_loss_probas_pos": 0.5568240050750015,
  "brier_score_probas_pos_neg": 0.17989036581115075,
  "log_loss_probas_pos_neg": 0.5535077047667407,
  "brier_score_probas_pos_neg_iso": 0.17391330744516786,
  "log_loss_probas_pos_neg_iso": 0.5253263371622415,
  "brier_score_probas_pos_iso": 0.19501113019684788,
  "log_loss_probas_pos_iso": 0.5753216504126762,
  "brier_score_probas_pos_sc": 0.1829010773176935,
  "log_loss_probas_pos_sc": 0.5561097902997354,
  "brier_score_probas_pos_neg_sc": 0.1798908536749029,
  "log_loss_probas_pos_neg_sc": 0.5535076556850803,
  "ece_scores": 0.3222176159450904,
  "ece_probas_pos": 0.08348978130366937,
  "ece_probas_pos_neg": 0.06240462999650886,
  "ece_probas_pos_neg_iso": 0.0064616998852716654,
  "ece_probas_pos_iso": 0.10261534535769529,
  "ece_probas_pos_sc": 0.08102

  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count
  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count


{
  "model": "ComplEx",
  "loss": "self_adversarial",
  "brier_score_scores": 0.4811223178274732,
  "log_loss_scores": 2.392532135475111,
  "brier_score_probas_pos": 0.18203577624335152,
  "log_loss_probas_pos": 0.5456178826873788,
  "brier_score_probas_pos_neg": 0.17749806921630634,
  "log_loss_probas_pos_neg": 0.5342846515320441,
  "brier_score_probas_pos_neg_iso": 0.16956978699251063,
  "log_loss_probas_pos_neg_iso": 0.515994842173244,
  "brier_score_probas_pos_iso": 0.18941839172044928,
  "log_loss_probas_pos_iso": 0.565039163884012,
  "brier_score_probas_pos_sc": 0.18158115303277242,
  "log_loss_probas_pos_sc": 0.5439933569081571,
  "brier_score_probas_pos_neg_sc": 0.1774978835552258,
  "log_loss_probas_pos_neg_sc": 0.5342847729083797,
  "ece_scores": 0.48747884693745314,
  "ece_probas_pos": 0.06332614075046515,
  "ece_probas_pos_neg": 0.04111855560102128,
  "ece_probas_pos_neg_iso": 0.006724254758265716,
  "ece_probas_pos_iso": 0.07309470863688503,
  "ece_probas_pos_sc": 0.059653

  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count


{
  "model": "ComplEx",
  "loss": "pairwise",
  "brier_score_scores": 0.23555795939871763,
  "log_loss_scores": 0.6617420081052304,
  "brier_score_probas_pos": 0.22445463897347503,
  "log_loss_probas_pos": 0.6414263702001465,
  "brier_score_probas_pos_neg": 0.2240541206254878,
  "log_loss_probas_pos_neg": 0.6412851086107535,
  "brier_score_probas_pos_neg_iso": 0.21080926402743322,
  "log_loss_probas_pos_neg_iso": 0.6086382159343459,
  "brier_score_probas_pos_iso": 0.23469225029084576,
  "log_loss_probas_pos_iso": 0.6630263830311592,
  "brier_score_probas_pos_sc": 0.22481584602219065,
  "log_loss_probas_pos_sc": 0.6421728329535827,
  "brier_score_probas_pos_neg_sc": 0.2240517294342283,
  "log_loss_probas_pos_neg_sc": 0.6412859067751632,
  "ece_scores": 0.12309747593772162,
  "ece_probas_pos": 0.09198144371104088,
  "ece_probas_pos_neg": 0.08788703204447033,
  "ece_probas_pos_neg_iso": 0.0038032603625317188,
  "ece_probas_pos_iso": 0.11256971996891857,
  "ece_probas_pos_sc": 0.0954670685

  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count
  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count
  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count
  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count


{
  "model": "ComplEx",
  "loss": "nll",
  "brier_score_scores": 0.4625924680726371,
  "log_loss_scores": 3.6493684307800405,
  "brier_score_probas_pos": 0.2346758330157365,
  "log_loss_probas_pos": 0.6622780543662069,
  "brier_score_probas_pos_neg": 0.228931059151346,
  "log_loss_probas_pos_neg": 0.6518663220069411,
  "brier_score_probas_pos_neg_iso": 0.21632703333282008,
  "log_loss_probas_pos_neg_iso": 0.6238048192279823,
  "brier_score_probas_pos_iso": 0.22694111626950392,
  "log_loss_probas_pos_iso": 0.649013100657048,
  "brier_score_probas_pos_sc": 0.23198145694462233,
  "log_loss_probas_pos_sc": 0.6567289160569362,
  "brier_score_probas_pos_neg_sc": 0.22893179685984202,
  "log_loss_probas_pos_neg_sc": 0.6518661446038495,
  "ece_scores": 0.46941578844315723,
  "ece_probas_pos": 0.09649594844809153,
  "ece_probas_pos_neg": 0.08862517486027677,
  "ece_probas_pos_neg_iso": 0.012830829586233241,
  "ece_probas_pos_iso": 0.08303049091925563,
  "ece_probas_pos_sc": 0.07176920694503337,


  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count


{
  "model": "ComplEx",
  "loss": "multiclass_nll",
  "brier_score_scores": 0.29856060284901625,
  "log_loss_scores": 1.1261303584927793,
  "brier_score_probas_pos": 0.19942511223510748,
  "log_loss_probas_pos": 0.5921297527611232,
  "brier_score_probas_pos_neg": 0.19661015981974805,
  "log_loss_probas_pos_neg": 0.5900205869263693,
  "brier_score_probas_pos_neg_iso": 0.18650290391319285,
  "log_loss_probas_pos_neg_iso": 0.5537825671675888,
  "brier_score_probas_pos_iso": 0.20885314019041265,
  "log_loss_probas_pos_iso": 0.6062831198962915,
  "brier_score_probas_pos_sc": 0.19870629973719495,
  "log_loss_probas_pos_sc": 0.5910888272852521,
  "brier_score_probas_pos_neg_sc": 0.1966109922900161,
  "log_loss_probas_pos_neg_sc": 0.5900206758028981,
  "ece_scores": 0.292833538174368,
  "ece_probas_pos": 0.09742225395895751,
  "ece_probas_pos_neg": 0.08382435012140989,
  "ece_probas_pos_neg_iso": 0.008990535432408799,
  "ece_probas_pos_iso": 0.11580457371941622,
  "ece_probas_pos_sc": 0.095225

  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count
  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count


{
  "model": "HolE",
  "loss": "self_adversarial",
  "brier_score_scores": 0.4524849789534678,
  "log_loss_scores": 1.6806720229416392,
  "brier_score_probas_pos": 0.2420076983890822,
  "log_loss_probas_pos": 0.6767263242704648,
  "brier_score_probas_pos_neg": 0.2294761132125967,
  "log_loss_probas_pos_neg": 0.6499658735541506,
  "brier_score_probas_pos_neg_iso": 0.22819237325108171,
  "log_loss_probas_pos_neg_iso": 0.6512607852769633,
  "brier_score_probas_pos_iso": 0.2630681991220293,
  "log_loss_probas_pos_iso": 0.724805277586707,
  "brier_score_probas_pos_sc": 0.2386355637187181,
  "log_loss_probas_pos_sc": 0.6689784901307039,
  "brier_score_probas_pos_neg_sc": 0.22947579823980108,
  "log_loss_probas_pos_neg_sc": 0.6499649660219388,
  "ece_scores": 0.45997472404655065,
  "ece_probas_pos": 0.10799009383689208,
  "ece_probas_pos_neg": 0.026804119471593058,
  "ece_probas_pos_neg_iso": 0.006375218728987515,
  "ece_probas_pos_iso": 0.15997969244594482,
  "ece_probas_pos_sc": 0.092485022

  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count
  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count


{
  "model": "HolE",
  "loss": "pairwise",
  "brier_score_scores": 0.23291515060139195,
  "log_loss_scores": 0.6535404976821143,
  "brier_score_probas_pos": 0.2149373772754216,
  "log_loss_probas_pos": 0.6210353871879349,
  "brier_score_probas_pos_neg": 0.21448199521229427,
  "log_loss_probas_pos_neg": 0.6206827404878698,
  "brier_score_probas_pos_neg_iso": 0.19967466235375098,
  "log_loss_probas_pos_neg_iso": 0.5834483158920643,
  "brier_score_probas_pos_iso": 0.2233954721226985,
  "log_loss_probas_pos_iso": 0.6354967023709988,
  "brier_score_probas_pos_sc": 0.21543123310995757,
  "log_loss_probas_pos_sc": 0.6216989794455341,
  "brier_score_probas_pos_neg_sc": 0.21447989833067485,
  "log_loss_probas_pos_neg_sc": 0.6206832952574213,
  "ece_scores": 0.14755065314950178,
  "ece_probas_pos": 0.09450239506696906,
  "ece_probas_pos_neg": 0.09453886579587169,
  "ece_probas_pos_neg_iso": 0.005358524883773092,
  "ece_probas_pos_iso": 0.12063815282859892,
  "ece_probas_pos_sc": 0.09445204505748

  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count
  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count
  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count
  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count
  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count
  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count


{
  "model": "HolE",
  "loss": "nll",
  "brier_score_scores": 0.33461173199179783,
  "log_loss_scores": 1.18848447930473,
  "brier_score_probas_pos": 0.24927695708889283,
  "log_loss_probas_pos": 0.6916871320512149,
  "brier_score_probas_pos_neg": 0.24927398402253484,
  "log_loss_probas_pos_neg": 0.6916804365159468,
  "brier_score_probas_pos_neg_iso": 0.2282332366747824,
  "log_loss_probas_pos_neg_iso": 0.6433693097728314,
  "brier_score_probas_pos_iso": 0.2338295060006483,
  "log_loss_probas_pos_iso": 0.657914567622181,
  "brier_score_probas_pos_sc": 0.24927900507396344,
  "log_loss_probas_pos_sc": 0.6916916541303273,
  "brier_score_probas_pos_neg_sc": 0.2492739583845122,
  "log_loss_probas_pos_neg_sc": 0.69168033400879,
  "ece_scores": 0.2687403892518376,
  "ece_probas_pos": 0.045207817395641975,
  "ece_probas_pos_neg": 0.04453298290282581,
  "ece_probas_pos_neg_iso": 0.004970803743821933,
  "ece_probas_pos_iso": 0.05295066778403932,
  "ece_probas_pos_sc": 0.04570512295254479,
  "ece

  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count
  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count


In [7]:
import pandas as pd

In [8]:
def highlight_min(s):
    """Return one CSS string per entry of ``s``, bolding every entry equal to the minimum.

    Intended for ``Styler.apply``: entries that tie for the smallest value all
    receive ``'font-weight: bold'``; every other entry receives ``''``.
    """
    bold, plain = 'font-weight: bold', ''
    lowest = s.min()
    return [bold if hit else plain for hit in (s == lowest)]

In [18]:
# Collect the per-configuration result records into one table keyed by
# (model, loss); every remaining column is a metric.
df = pd.DataFrame(results).set_index(['model', 'loss'])
# 'model' and 'loss' now live in the index, so the index must be written out:
# with index=False the exported rows could not be attributed to any configuration.
df.to_csv("main_results_fb13.csv")

In [10]:
# Brier-score view of the results: keep only the 'brier*' columns and strip the
# shared "brier_score_" prefix so the headers name just the calibration variant.
bs = (
    df.loc[:, [name for name in df.columns if name.startswith('brier')]]
    .rename(columns=lambda name: name[len("brier_score_"):])
)
# Bold the minimum entry in each (model, loss) row.
bs.style.apply(highlight_min, axis=1)

Unnamed: 0_level_0,Unnamed: 1_level_0,probas_pos,probas_pos_iso,probas_pos_neg,probas_pos_neg_iso,probas_pos_neg_sc,probas_pos_sc,scores
model,loss,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
TransE,self_adversarial,0.142128,0.141076,0.124039,0.124095,0.124038,0.145103,0.445572
TransE,pairwise,0.224942,0.208031,0.224698,0.202994,0.224697,0.228499,0.49996
TransE,nll,0.239654,0.243721,0.209388,0.203164,0.209391,0.253127,0.235673
TransE,multiclass_nll,0.161842,0.158621,0.146174,0.145641,0.146174,0.1566,0.500021
DistMult,self_adversarial,0.186453,0.191851,0.178077,0.170496,0.178077,0.185138,0.472947
DistMult,pairwise,0.214388,0.219588,0.214036,0.203999,0.214036,0.215307,0.23754
DistMult,nll,0.238016,0.236468,0.237673,0.226226,0.237674,0.238399,0.438258
DistMult,multiclass_nll,0.183161,0.195011,0.17989,0.173913,0.179891,0.182901,0.314046
ComplEx,self_adversarial,0.182036,0.189418,0.177498,0.16957,0.177498,0.181581,0.481122
ComplEx,pairwise,0.224455,0.234692,0.224054,0.210809,0.224052,0.224816,0.235558


In [11]:
# Log-loss view of the results: keep only the 'log_loss*' columns and strip the
# shared "log_loss_" prefix so the headers name just the calibration variant.
ll = (
    df.loc[:, [name for name in df.columns if name.startswith('log_loss')]]
    .rename(columns=lambda name: name[len("log_loss_"):])
)
# Bold the minimum entry in each (model, loss) row.
ll.style.apply(highlight_min, axis=1)

Unnamed: 0_level_0,Unnamed: 1_level_0,probas_pos,probas_pos_iso,probas_pos_neg,probas_pos_neg_iso,probas_pos_neg_sc,probas_pos_sc,scores
model,loss,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
TransE,self_adversarial,0.446376,0.442155,0.389833,0.389885,0.389836,0.451247,1.53393
TransE,pairwise,0.637319,0.593625,0.636255,0.582087,0.636255,0.643769,5.23359
TransE,nll,0.676375,0.684693,0.614462,0.59232,0.614474,0.703735,0.663441
TransE,multiclass_nll,0.500364,0.489976,0.454884,0.45375,0.454884,0.485008,7.90186
DistMult,self_adversarial,0.553502,0.566858,0.533494,0.517608,0.533495,0.550056,2.17742
DistMult,pairwise,0.619071,0.628611,0.618807,0.594085,0.618807,0.620546,0.667009
DistMult,nll,0.669147,0.670063,0.668784,0.644589,0.668784,0.66987,3.0877
DistMult,multiclass_nll,0.556824,0.575322,0.553508,0.525326,0.553508,0.55611,1.27324
ComplEx,self_adversarial,0.545618,0.565039,0.534285,0.515995,0.534285,0.543993,2.39253
ComplEx,pairwise,0.641426,0.663026,0.641285,0.608638,0.641286,0.642173,0.661742


In [12]:
# LaTeX table of Brier scores (rounded to 3 decimals), one row per model,
# restricted to the runs trained with the self_adversarial loss.
print(
    bs.reset_index()
    .query("loss == 'self_adversarial' ")
    .loc[:, ['model', 'scores', 'probas_pos_neg', 'probas_pos_neg_iso', 'probas_pos', 'probas_pos_iso']]
    .reset_index(drop=True)
    .round(3)
    .to_latex()
)

\begin{tabular}{llrrrrr}
\toprule
{} &     model &  scores &  probas\_pos\_neg &  probas\_pos\_neg\_iso &  probas\_pos &  probas\_pos\_iso \\
\midrule
0 &    TransE &   0.446 &           0.124 &               0.124 &       0.142 &           0.141 \\
1 &  DistMult &   0.473 &           0.178 &               0.170 &       0.186 &           0.192 \\
2 &   ComplEx &   0.481 &           0.177 &               0.170 &       0.182 &           0.189 \\
3 &      HolE &   0.452 &           0.229 &               0.228 &       0.242 &           0.263 \\
\bottomrule
\end{tabular}



In [13]:
# LaTeX table of log losses (rounded to 3 decimals), one row per model,
# restricted to the runs trained with the self_adversarial loss.
print(
    ll.reset_index()
    .query("loss == 'self_adversarial' ")
    .loc[:, ['model', 'scores', 'probas_pos_neg', 'probas_pos_neg_iso', 'probas_pos', 'probas_pos_iso']]
    .reset_index(drop=True)
    .round(3)
    .to_latex()
)

\begin{tabular}{llrrrrr}
\toprule
{} &     model &  scores &  probas\_pos\_neg &  probas\_pos\_neg\_iso &  probas\_pos &  probas\_pos\_iso \\
\midrule
0 &    TransE &   1.534 &           0.390 &               0.390 &       0.446 &           0.442 \\
1 &  DistMult &   2.177 &           0.533 &               0.518 &       0.554 &           0.567 \\
2 &   ComplEx &   2.393 &           0.534 &               0.516 &       0.546 &           0.565 \\
3 &      HolE &   1.681 &           0.650 &               0.651 &       0.677 &           0.725 \\
\bottomrule
\end{tabular}



In [15]:
# LaTeX table of Brier scores (rounded to 3 decimals) for TransE only,
# one row per training loss.
print(
    bs.reset_index()
    .query("model == 'TransE' ")
    .loc[:, ['loss', 'probas_pos_neg', 'probas_pos_neg_iso', 'probas_pos', 'probas_pos_iso']]
    .reset_index(drop=True)
    .round(3)
    .to_latex()
)

\begin{tabular}{llrrrr}
\toprule
{} &              loss &  probas\_pos\_neg &  probas\_pos\_neg\_iso &  probas\_pos &  probas\_pos\_iso \\
\midrule
0 &  self\_adversarial &           0.124 &               0.124 &       0.142 &           0.141 \\
1 &          pairwise &           0.225 &               0.203 &       0.225 &           0.208 \\
2 &               nll &           0.209 &               0.203 &       0.240 &           0.244 \\
3 &    multiclass\_nll &           0.146 &               0.146 &       0.162 &           0.159 \\
\bottomrule
\end{tabular}



In [16]:
# LaTeX table of log losses (rounded to 3 decimals) for TransE only,
# one row per training loss.
print(
    ll.reset_index()
    .query("model == 'TransE' ")
    .loc[:, ['loss', 'probas_pos_neg', 'probas_pos_neg_iso', 'probas_pos', 'probas_pos_iso']]
    .reset_index(drop=True)
    .round(3)
    .to_latex()
)

\begin{tabular}{llrrrr}
\toprule
{} &              loss &  probas\_pos\_neg &  probas\_pos\_neg\_iso &  probas\_pos &  probas\_pos\_iso \\
\midrule
0 &  self\_adversarial &           0.390 &               0.390 &       0.446 &           0.442 \\
1 &          pairwise &           0.636 &               0.582 &       0.637 &           0.594 \\
2 &               nll &           0.614 &               0.592 &       0.676 &           0.685 \\
3 &    multiclass\_nll &           0.455 &               0.454 &       0.500 &           0.490 \\
\bottomrule
\end{tabular}



In [22]:
# Accuracy view of the results: keep only the 'accuracy*' columns and strip the
# shared "accuracy_" prefix so the headers name just the classification variant.
acc = df[[c for c in df.columns if c.startswith('accuracy')]]
acc.columns = [c[len("accuracy_"):] for c in acc.columns]
# Accuracy is higher-is-better, so bold the per-row maximum; bolding the minimum
# (as done for the Brier-score and log-loss tables) would highlight the worst variant.
acc.style.apply(
    lambda row: ['font-weight: bold' if is_best else '' for is_best in row == row.max()],
    axis=1,
)

Unnamed: 0_level_0,Unnamed: 1_level_0,per_relation,pos,pos_iso,pos_neg,pos_neg_iso,pos_neg_sc,pos_sc,uncalib
model,loss,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
TransE,self_adversarial,0.820685,0.806738,0.801639,0.824183,0.824056,0.824183,0.795951,0.499979
TransE,pairwise,0.804526,0.600287,0.640022,0.554462,0.642234,0.554547,0.610189,0.499979
TransE,nll,0.662902,0.589731,0.618068,0.678261,0.679399,0.678282,0.550544,0.499979
TransE,multiclass_nll,0.825721,0.789693,0.784995,0.796793,0.795656,0.796814,0.787228,0.499979
DistMult,self_adversarial,0.80817,0.72141,0.702216,0.724907,0.73167,0.724907,0.722864,0.500864
DistMult,pairwise,0.682644,0.680958,0.651041,0.685951,0.692693,0.685951,0.667285,0.570116
DistMult,nll,0.609704,0.600371,0.626433,0.610147,0.628223,0.610105,0.586508,0.540473
DistMult,multiclass_nll,0.764917,0.747535,0.708305,0.748167,0.749178,0.748167,0.743237,0.60252
ComplEx,self_adversarial,0.83577,0.741846,0.724001,0.737696,0.74172,0.737675,0.741783,0.500611
ComplEx,pairwise,0.636967,0.640528,0.607028,0.648808,0.67192,0.648808,0.61672,0.553936


In [23]:
# LaTeX table of accuracies scaled by 100 (rounded to 1 decimal), one row per
# model, restricted to the runs trained with the self_adversarial loss.
print(
    acc.mul(100)
    .reset_index()
    .query("loss == 'self_adversarial' ")
    .loc[:, ['model', 'pos_neg', 'pos_neg_iso', 'pos', 'pos_iso', 'uncalib', 'per_relation']]
    .reset_index(drop=True)
    .round(1)
    .to_latex()
)

\begin{tabular}{llrrrrrr}
\toprule
{} &     model &  pos\_neg &  pos\_neg\_iso &   pos &  pos\_iso &  uncalib &  per\_relation \\
\midrule
0 &    TransE &     82.4 &         82.4 &  80.7 &     80.2 &     50.0 &          82.1 \\
1 &  DistMult &     72.5 &         73.2 &  72.1 &     70.2 &     50.1 &          80.8 \\
2 &   ComplEx &     73.8 &         74.2 &  74.2 &     72.4 &     50.1 &          83.6 \\
3 &      HolE &     60.3 &         60.6 &  57.8 &     54.3 &     50.0 &          62.6 \\
\bottomrule
\end{tabular}



In [17]:
# Ranking-metric view of the results: keep only the 'metrics*' columns and strip
# the shared "metrics_" prefix from their headers.
metrics = (
    df.loc[:, [name for name in df.columns if name.startswith('metrics')]]
    .rename(columns=lambda name: name[len("metrics_"):])
)
metrics

Unnamed: 0_level_0,Unnamed: 1_level_0,hits@10,mr,mrr
model,loss,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
TransE,self_adversarial,0.394156,3430.869359,0.295835
TransE,pairwise,0.374099,7631.636561,0.282413
TransE,nll,0.273269,6799.552016,0.202138
TransE,multiclass_nll,0.402267,5050.017486,0.308702
DistMult,self_adversarial,0.320103,5998.061349,0.181442
DistMult,pairwise,0.162011,11409.057599,0.095039
DistMult,nll,0.054945,16951.142734,0.028533
DistMult,multiclass_nll,0.25829,7125.716049,0.171373
ComplEx,self_adversarial,0.337273,6667.319092,0.183393
ComplEx,pairwise,0.156702,12744.418257,0.092388


In [21]:
# LaTeX table of the mr / mrr / hits@10 columns (rounded to 3 decimals), one row
# per model, restricted to the runs trained with the self_adversarial loss.
print(
    metrics.reset_index()
    .query("loss == 'self_adversarial' ")
    .loc[:, ['model', 'mr', 'mrr', 'hits@10']]
    .reset_index(drop=True)
    .round(3)
    .to_latex()
)

\begin{tabular}{llrrr}
\toprule
{} &     model &        mr &    mrr &  hits@10 \\
\midrule
0 &    TransE &  3430.869 &  0.296 &    0.394 \\
1 &  DistMult &  5998.061 &  0.181 &    0.320 \\
2 &   ComplEx &  6667.319 &  0.183 &    0.337 \\
3 &      HolE &  8937.297 &  0.018 &    0.039 \\
\bottomrule
\end{tabular}



In [None]:
def highlight_max(s):
    """Return one CSS string per entry of ``s``, bolding every entry equal to the maximum.

    Intended for ``Styler.apply``: entries that tie for the largest value all
    receive ``'font-weight: bold'``; every other entry receives ``''``.
    """
    bold, plain = 'font-weight: bold', ''
    highest = s.max()
    return [bold if hit else plain for hit in (s == highest)]

# Accuracy view of the results: keep only the 'accuracy*' columns and strip the
# shared "accuracy_" prefix from their headers.
acc = (
    df.loc[:, [name for name in df.columns if name.startswith('accuracy')]]
    .rename(columns=lambda name: name[len("accuracy_"):])
)
# Bold the maximum entry in each (model, loss) row.
acc.style.apply(highlight_max, axis=1)

In [None]:
# Spearman correlation between every accuracy column (rows) and four of the
# log-loss columns (columns), computed across all result rows in df.
(
    df.corr(method='spearman')
    .reset_index()
    .query("index.str.startswith('accuracy')")
    .loc[:, ['index', 'log_loss_probas_pos_neg', 'log_loss_probas_pos_neg_iso',
             'log_loss_probas_pos', 'log_loss_probas_pos_iso']]
)

In [None]:
# Spearman correlation between every accuracy column (rows) and four of the
# Brier-score columns (columns), computed across all result rows in df.
(
    df.corr(method='spearman')
    .reset_index()
    .query("index.str.startswith('accuracy')")
    .loc[:, ['index', 'brier_score_probas_pos_neg', 'brier_score_probas_pos_neg_iso',
             'brier_score_probas_pos', 'brier_score_probas_pos_iso']]
)

In [None]:
# Number of distinct values in column 1 of the validation triples
# (the column the per-relation thresholds are keyed on).
np.unique(X['valid'][:, 1]).size

In [19]:
# Train a single TransE model with the self_adversarial loss, calibrate it two
# ways, and measure classification accuracy on the test triples using one
# decision threshold per relation derived from the validation scores.
model = TransE(batches_count=64, seed=0, epochs=1000, k=100, eta=20,
               optimizer='adam', optimizer_params={'lr':0.0001},
               loss='self_adversarial', verbose=False)

model.fit(X['train'])

# Raw model scores for the test triples; these are what get thresholded below.
scores = model.predict(X['test'])

# First calibration: positive validation triples only, with positive_base_rate=0.5.
model.calibrate(X_valid_pos, batches_count=10, epochs=1000, positive_base_rate=0.5)
print("pos", model.calibration_parameters)
# NOTE(review): probas1 is not read again within this cell.
probas1 = model.predict_proba(X['test'])

# Second calibration: both positive and negative validation triples.
# This replaces the calibration parameters printed above.
model.calibrate(X_valid_pos, X_valid_neg)
print("pos neg", model.calibration_parameters)
# NOTE(review): probas2 is not read again within this cell.
probas2 = model.predict_proba(X['test'])

val_scores = model.predict(X['valid'])

# One threshold per distinct value of column 1 of the validation triples: the
# median validation score among the triples carrying that value.
# NOTE(review): np.median does not require sorted input, so np.sort looks redundant.
thresholds = {r: np.median(np.sort(val_scores[X['valid'][:, 1] == r])) for r in np.unique(X['valid'][:, 1])}
# Look up the threshold for each test triple via its column-1 value.
# NOTE(review): dict.get yields None for a value absent from the validation set — confirm that cannot occur.
thresholds_test = np.vectorize(thresholds.get)(X['test'][:, 1])
# A test triple is predicted positive when its score exceeds its relation's threshold.
per_relation_acc = accuracy_score(X['test_labels'], scores > thresholds_test)

print(thresholds)


pos [-1.4380144, -5.5623884]
pos neg [-2.752314, -9.8970175]
{'cause_of_death': -3.5680597, 'ethnicity': -3.4997067, 'gender': -3.4051323, 'institution': -3.547462, 'nationality': -3.8507419, 'profession': -3.7040129, 'religion': -3.5918012}


In [20]:
# Display the test accuracy obtained with the per-relation median thresholds.
per_relation_acc

0.8216332378223495