In [1]:
import numpy as np
import itertools
import json

from sklearn.isotonic import IsotonicRegression
from sklearn.calibration import calibration_curve, _SigmoidCalibration, _sigmoid_calibration
from ampligraph.evaluation import evaluate_performance, mr_score, mrr_score, hits_at_n_score, generate_corruptions_for_eval
from sklearn.metrics import brier_score_loss, log_loss, accuracy_score
from scipy.special import expit

from ampligraph.datasets import load_wn11
from ampligraph.latent_features.models import TransE, ComplEx, DistMult
import types

In [2]:
from generate_corruptions import generate_corruptions, calibration_loss, pos_iso

In [3]:
%env CUDA_VISIBLE_DEVICES=0

env: CUDA_VISIBLE_DEVICES=0


In [4]:
X = load_wn11()

In [5]:
X_valid_pos = X['valid'][X['valid_labels']]
X_valid_neg = X['valid'][~X['valid_labels']]

X_test_pos = X['test'][X['test_labels']]
X_test_neg = X['test'][~X['test_labels']]

In [6]:
losses =  ['self_adversarial', 'pairwise', 'nll', 'multiclass_nll']
models = [TransE, DistMult, ComplEx]

results = []

for m, l in itertools.product(models, losses):
    model = m(batches_count=64, seed=0, epochs=1000, k=100, eta=20,
                   optimizer='adam', optimizer_params={'lr':0.0001},
                   loss=l, verbose=False)

    model.fit(X['train'])
    
    scores = model.predict(X['test'])

    model.calibrate(X_valid_pos, batches_count=10, epochs=1000, positive_base_rate=0.5)
    print("pos", model.calibration_parameters)
    probas1 = model.predict_proba(X['test'])

    model.calibrate(X_valid_pos, X_valid_neg)
    print("pos neg", model.calibration_parameters)
    probas2 = model.predict_proba(X['test'])
    
    val_scores = model.predict(X['valid'])
    ir = IsotonicRegression(out_of_bounds='clip')
    ir.fit(np.squeeze(val_scores).astype(float), (X['valid_labels']).astype(float))
    probas3 = ir.predict(np.squeeze(scores).astype(float))
    
    model.generate_corruptions = types.MethodType(generate_corruptions, model)
    corruptions = model.generate_corruptions(X_valid_pos, batches_count=10, epochs=1000)
    val_pos_scores = np.squeeze(model.predict(X_valid_pos))
    iso_pos = pos_iso(IsotonicRegression(out_of_bounds='clip'), val_pos_scores, corruptions, positive_base_rate=0.5)
    probas4 = iso_pos.predict(np.squeeze(scores).astype(float))

    sc_pos = pos_iso(_SigmoidCalibration(), val_pos_scores, corruptions, positive_base_rate=0.5)
    print("pos sc", sc_pos.a_, sc_pos.b_)
    probas5 = sc_pos.predict(np.squeeze(scores).astype(float))
    
    val_neg_scores = np.squeeze(model.predict(X_valid_neg))
    sc_pos_neg = pos_iso(_SigmoidCalibration(), val_pos_scores, val_neg_scores, positive_base_rate=0.5)
    print("pos neg sc", sc_pos_neg.a_, sc_pos_neg.b_)
    probas6 = sc_pos_neg.predict(np.squeeze(scores).astype(float))
    
    thresholds = {r: np.median(np.sort(val_scores[X['valid'][:, 1] == r])) for r in np.unique(X['valid'][:, 1])}
    thresholds_test = np.vectorize(thresholds.get)(X['test'][:, 1])
    per_relation_acc = accuracy_score(X['test_labels'], scores > thresholds_test)

    acc_uncalib = accuracy_score(X['test_labels'], expit(scores) > 0.5)
    
    acc1 = accuracy_score(X['test_labels'], probas1 > 0.5)
    acc2 = accuracy_score(X['test_labels'], probas2 > 0.5)
    acc3 = accuracy_score(X['test_labels'], probas3 > 0.5)
    acc4 = accuracy_score(X['test_labels'], probas4 > 0.5)
    acc5 = accuracy_score(X['test_labels'], probas5 > 0.5)
    acc6 = accuracy_score(X['test_labels'], probas6 > 0.5)
    
    filter_triples = np.concatenate((X['train'], X_valid_pos, X_test_pos))
    ranks = evaluate_performance(X_test_pos, 
                                 model=model, 
                                 filter_triples=filter_triples,
                                 use_default_protocol=True, 
                                 verbose=False)

    results.append({
        'model': m.__name__,
        'loss': l,
        'brier_score_scores': brier_score_loss(X['test_labels'], expit(scores)),
        'log_loss_scores': log_loss(X['test_labels'], expit(scores), eps=1e-7),
        'brier_score_probas_pos': brier_score_loss(X['test_labels'], probas1),
        'log_loss_probas_pos': log_loss(X['test_labels'], probas1, eps=1e-7),
        'brier_score_probas_pos_neg': brier_score_loss(X['test_labels'], probas2),
        'log_loss_probas_pos_neg': log_loss(X['test_labels'], probas2, eps=1e-7),
        'brier_score_probas_pos_neg_iso': brier_score_loss(X['test_labels'], probas3),
        'log_loss_probas_pos_neg_iso': log_loss(X['test_labels'], probas3, eps=1e-7),
        'brier_score_probas_pos_iso': brier_score_loss(X['test_labels'], probas4),
        'log_loss_probas_pos_iso': log_loss(X['test_labels'], probas4, eps=1e-7),
        'brier_score_probas_pos_sc': brier_score_loss(X['test_labels'], probas5),
        'log_loss_probas_pos_sc': log_loss(X['test_labels'], probas5, eps=1e-7),
        'brier_score_probas_pos_neg_sc': brier_score_loss(X['test_labels'], probas6),
        'log_loss_probas_pos_neg_sc': log_loss(X['test_labels'], probas6, eps=1e-7),
        'ece_scores': calibration_loss(X['test_labels'], expit(scores)),
        'ece_probas_pos': calibration_loss(X['test_labels'], probas1),
        'ece_probas_pos_neg': calibration_loss(X['test_labels'], probas2),
        'ece_probas_pos_neg_iso': calibration_loss(X['test_labels'], probas3),
        'ece_probas_pos_iso': calibration_loss(X['test_labels'], probas4),
        'ece_probas_pos_sc': calibration_loss(X['test_labels'], probas5),
        'ece_probas_pos_neg_sc': calibration_loss(X['test_labels'], probas6),
        'metrics_mrr': mrr_score(ranks), 
        'metrics_hits@10': hits_at_n_score(ranks, n=10),
        'metrics_mr': mr_score(ranks),
        'accuracy_per_relation': per_relation_acc,
        'accuracy_uncalib': acc_uncalib,
        'accuracy_pos': acc1,
        'accuracy_pos_neg': acc2,
        'accuracy_pos_neg_iso': acc3,
        'accuracy_pos_iso': acc4,
        'accuracy_pos_sc': acc5,
        'accuracy_pos_neg_sc': acc6
    })
        
    print(json.dumps(results[-1], indent=2))

Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
tf.py_func is deprecated in TF V2. Instead, use
    tf.py_function, which takes a python function which manipulates tf eager
    tensors instead of numpy arrays. It's easy to convert a tf eager tensor to
    an ndarray (just call tensor.numpy()) but having access to eager tensors
    means `tf.py_function`s can use accelerators such as GPUs as well as
    being differentiable using a gradient tape.
    
Instructions for updating:
Use tf.cast instead.
Instructions for updating:
Use tf.cast instead.
pos [-1.1747085, -6.8984]
pos neg [-1.411434, -8.393471]
Instructions for updating:
Use tf.random.categorical instead.
pos sc -1.2899766086088724 -7.571272657176108
pos neg sc -1.4084979444808459 -8.398550811429823


  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count


{
  "model": "TransE",
  "loss": "self_adversarial",
  "brier_score_scores": 0.44346406162225727,
  "log_loss_scores": 1.959084630629783,
  "brier_score_probas_pos": 0.09109558292235946,
  "log_loss_probas_pos": 0.30825534177589387,
  "brier_score_probas_pos_neg": 0.08912546028875036,
  "log_loss_probas_pos_neg": 0.30161135223550545,
  "brier_score_probas_pos_neg_iso": 0.08731137728045221,
  "log_loss_probas_pos_neg_iso": 0.29518257552306393,
  "brier_score_probas_pos_iso": 0.08744558557348009,
  "log_loss_probas_pos_iso": 0.2956627700025445,
  "brier_score_probas_pos_sc": 0.08980081978081432,
  "log_loss_probas_pos_sc": 0.30405143901998294,
  "brier_score_probas_pos_neg_sc": 0.08923570744979144,
  "log_loss_probas_pos_neg_sc": 0.30169943718357545,
  "ece_scores": 0.46545801992212155,
  "ece_probas_pos": 0.052743666863835886,
  "ece_probas_pos_neg": 0.03610154305522629,
  "ece_probas_pos_neg_iso": 0.01189540098861764,
  "ece_probas_pos_iso": 0.017002793966942147,
  "ece_probas_pos_sc":

  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count


{
  "model": "TransE",
  "loss": "pairwise",
  "brier_score_scores": 0.49250805747539333,
  "log_loss_scores": 5.233913629923759,
  "brier_score_probas_pos": 0.20856483052251926,
  "log_loss_probas_pos": 0.6063374306009015,
  "brier_score_probas_pos_neg": 0.20163406525426084,
  "log_loss_probas_pos_neg": 0.5907020842221917,
  "brier_score_probas_pos_neg_iso": 0.19837366718026453,
  "log_loss_probas_pos_neg_iso": 0.5853817281047783,
  "brier_score_probas_pos_iso": 0.20042963411258044,
  "log_loss_probas_pos_iso": 0.589109237515802,
  "brier_score_probas_pos_sc": 0.2023423084954731,
  "log_loss_probas_pos_sc": 0.591511554629423,
  "brier_score_probas_pos_neg_sc": 0.20175148381168845,
  "log_loss_probas_pos_neg_sc": 0.590642492531469,
  "ece_scores": 0.4925165175801786,
  "ece_probas_pos": 0.0777746192285957,
  "ece_probas_pos_neg": 0.04748816719082277,
  "ece_probas_pos_neg_iso": 0.009526206933827162,
  "ece_probas_pos_iso": 0.03673938266035769,
  "ece_probas_pos_sc": 0.05566147125937992

  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count


{
  "model": "TransE",
  "loss": "nll",
  "brier_score_scores": 0.22201063268476454,
  "log_loss_scores": 0.6700502351221408,
  "brier_score_probas_pos": 0.09324533034497749,
  "log_loss_probas_pos": 0.34213021432200946,
  "brier_score_probas_pos_neg": 0.09293337585657196,
  "log_loss_probas_pos_neg": 0.34168480309285526,
  "brier_score_probas_pos_neg_iso": 0.08769640172253451,
  "log_loss_probas_pos_neg_iso": 0.2984752555893492,
  "brier_score_probas_pos_iso": 0.08780424652325028,
  "log_loss_probas_pos_iso": 0.3003654371237818,
  "brier_score_probas_pos_sc": 0.0938658249957217,
  "log_loss_probas_pos_sc": 0.3436618233623619,
  "brier_score_probas_pos_neg_sc": 0.09317422275530084,
  "log_loss_probas_pos_neg_sc": 0.3419303718689412,
  "ece_scores": 0.29278871084631586,
  "ece_probas_pos": 0.06000074725168968,
  "ece_probas_pos_neg": 0.057673270092825496,
  "ece_probas_pos_neg_iso": 0.012591624148130889,
  "ece_probas_pos_iso": 0.01544802083695062,
  "ece_probas_pos_sc": 0.0638676505392

  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count
  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count
  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count


{
  "model": "TransE",
  "loss": "multiclass_nll",
  "brier_score_scores": 0.49253953365446085,
  "log_loss_scores": 7.659103808423022,
  "brier_score_probas_pos": 0.2039921039598542,
  "log_loss_probas_pos": 0.5990255623699515,
  "brier_score_probas_pos_neg": 0.20369393981063982,
  "log_loss_probas_pos_neg": 0.5989555854895341,
  "brier_score_probas_pos_neg_iso": 0.1884424220536425,
  "log_loss_probas_pos_neg_iso": 0.548983498091133,
  "brier_score_probas_pos_iso": 0.18850522450100338,
  "log_loss_probas_pos_iso": 0.5501956245633723,
  "brier_score_probas_pos_sc": 0.20398060043126742,
  "log_loss_probas_pos_sc": 0.5989871088313099,
  "brier_score_probas_pos_neg_sc": 0.20398757410390173,
  "log_loss_probas_pos_neg_sc": 0.5990378068166818,
  "ece_scores": 0.49253983358484593,
  "ece_probas_pos": 0.10515479389137633,
  "ece_probas_pos_neg": 0.10444988028889171,
  "ece_probas_pos_neg_iso": 0.006049936858999714,
  "ece_probas_pos_iso": 0.010660992214912684,
  "ece_probas_pos_sc": 0.1059573

  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count
  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count


{
  "model": "DistMult",
  "loss": "self_adversarial",
  "brier_score_scores": 0.488378335223354,
  "log_loss_scores": 5.624642671142175,
  "brier_score_probas_pos": 0.21361848624384808,
  "log_loss_probas_pos": 0.618242221305673,
  "brier_score_probas_pos_neg": 0.213457232654354,
  "log_loss_probas_pos_neg": 0.6182098549992107,
  "brier_score_probas_pos_neg_iso": 0.20790025293582376,
  "log_loss_probas_pos_neg_iso": 0.6036818747023027,
  "brier_score_probas_pos_iso": 0.20831913386270018,
  "log_loss_probas_pos_iso": 0.6011250361904669,
  "brier_score_probas_pos_sc": 0.2137826948514346,
  "log_loss_probas_pos_sc": 0.618371617960791,
  "brier_score_probas_pos_neg_sc": 0.21361054090289502,
  "log_loss_probas_pos_neg_sc": 0.6182369070459045,
  "ece_scores": 0.4898895212844896,
  "ece_probas_pos": 0.06165496879309724,
  "ece_probas_pos_neg": 0.05788633545848737,
  "ece_probas_pos_neg_iso": 0.009702682545481746,
  "ece_probas_pos_iso": 0.0245852114724739,
  "ece_probas_pos_sc": 0.0641194935

  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count
  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count
  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count


{
  "model": "DistMult",
  "loss": "pairwise",
  "brier_score_scores": 0.22372001908672912,
  "log_loss_scores": 0.6359504224647637,
  "brier_score_probas_pos": 0.2173336239180292,
  "log_loss_probas_pos": 0.6210303633462188,
  "brier_score_probas_pos_neg": 0.21710696780029962,
  "log_loss_probas_pos_neg": 0.6210857484874194,
  "brier_score_probas_pos_neg_iso": 0.21069154105198049,
  "log_loss_probas_pos_neg_iso": 0.6064811887291971,
  "brier_score_probas_pos_iso": 0.21147651220390692,
  "log_loss_probas_pos_iso": 0.6056897792501762,
  "brier_score_probas_pos_sc": 0.2173884561558497,
  "log_loss_probas_pos_sc": 0.6216384920391862,
  "brier_score_probas_pos_neg_sc": 0.21727827970633276,
  "log_loss_probas_pos_neg_sc": 0.6211982888069971,
  "ece_scores": 0.098327869590038,
  "ece_probas_pos": 0.06725364103784756,
  "ece_probas_pos_neg": 0.06406736091394828,
  "ece_probas_pos_neg_iso": 0.006629287268029925,
  "ece_probas_pos_iso": 0.022242030858661443,
  "ece_probas_pos_sc": 0.06283050465

  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count
  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count
  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count
  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count
  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count
  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count


{
  "model": "DistMult",
  "loss": "nll",
  "brier_score_scores": 0.46911904981480246,
  "log_loss_scores": 5.620262919216982,
  "brier_score_probas_pos": 0.22434492049124316,
  "log_loss_probas_pos": 0.6378381105222464,
  "brier_score_probas_pos_neg": 0.22416027474010272,
  "log_loss_probas_pos_neg": 0.6378677665758535,
  "brier_score_probas_pos_neg_iso": 0.21344538289612403,
  "log_loss_probas_pos_neg_iso": 0.6111088902401539,
  "brier_score_probas_pos_iso": 0.21378028224171725,
  "log_loss_probas_pos_iso": 0.6076548273666721,
  "brier_score_probas_pos_sc": 0.22447560059998464,
  "log_loss_probas_pos_sc": 0.6378917057042592,
  "brier_score_probas_pos_neg_sc": 0.2243521270280964,
  "log_loss_probas_pos_neg_sc": 0.6380260116509009,
  "ece_scores": 0.47490819940515716,
  "ece_probas_pos": 0.08531189923067208,
  "ece_probas_pos_neg": 0.0845508687716141,
  "ece_probas_pos_neg_iso": 0.00682326807080619,
  "ece_probas_pos_iso": 0.016599455026154033,
  "ece_probas_pos_sc": 0.088953909330603,

  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count
  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count


{
  "model": "DistMult",
  "loss": "multiclass_nll",
  "brier_score_scores": 0.2624940331507355,
  "log_loss_scores": 0.79260674442257,
  "brier_score_probas_pos": 0.21212749424148317,
  "log_loss_probas_pos": 0.6085720189380257,
  "brier_score_probas_pos_neg": 0.2119198666748226,
  "log_loss_probas_pos_neg": 0.6087810543790794,
  "brier_score_probas_pos_neg_iso": 0.20489276532866985,
  "log_loss_probas_pos_neg_iso": 0.5878044108458256,
  "brier_score_probas_pos_iso": 0.2053603739395194,
  "log_loss_probas_pos_iso": 0.5910422088102314,
  "brier_score_probas_pos_sc": 0.2122140078457557,
  "log_loss_probas_pos_sc": 0.6091975243698445,
  "brier_score_probas_pos_neg_sc": 0.21209820981345825,
  "log_loss_probas_pos_neg_sc": 0.6088747173666084,
  "ece_scores": 0.20361488409651168,
  "ece_probas_pos": 0.07143433517723968,
  "ece_probas_pos_neg": 0.0685412009500627,
  "ece_probas_pos_neg_iso": 0.0117051686757507,
  "ece_probas_pos_iso": 0.01925589791540624,
  "ece_probas_pos_sc": 0.07015249906

  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count
  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count
  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count
  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count
  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count
  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count


{
  "model": "ComplEx",
  "loss": "self_adversarial",
  "brier_score_scores": 0.48998108347093694,
  "log_loss_scores": 6.0610655380430405,
  "brier_score_probas_pos": 0.2399739136666372,
  "log_loss_probas_pos": 0.6736364674216707,
  "brier_score_probas_pos_neg": 0.23989405978095957,
  "log_loss_probas_pos_neg": 0.6735588625507477,
  "brier_score_probas_pos_neg_iso": 0.22821497362818002,
  "log_loss_probas_pos_neg_iso": 0.6506634153019699,
  "brier_score_probas_pos_iso": 0.22841327222789629,
  "log_loss_probas_pos_iso": 0.6501271398063591,
  "brier_score_probas_pos_sc": 0.2399559600900672,
  "log_loss_probas_pos_sc": 0.673707433347442,
  "brier_score_probas_pos_neg_sc": 0.24004025312876176,
  "log_loss_probas_pos_neg_sc": 0.6737047290363949,
  "ece_scores": 0.4906965382968086,
  "ece_probas_pos": 0.09525912788983082,
  "ece_probas_pos_neg": 0.08895132536937975,
  "ece_probas_pos_neg_iso": 0.007095154239283816,
  "ece_probas_pos_iso": 0.014366896759496352,
  "ece_probas_pos_sc": 0.0977

  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count
  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count


{
  "model": "ComplEx",
  "loss": "pairwise",
  "brier_score_scores": 0.22595793554892946,
  "log_loss_scores": 0.642541800152352,
  "brier_score_probas_pos": 0.21302253434501378,
  "log_loss_probas_pos": 0.6110140760243306,
  "brier_score_probas_pos_neg": 0.21285017148308993,
  "log_loss_probas_pos_neg": 0.6109207767385638,
  "brier_score_probas_pos_neg_iso": 0.2079395479136505,
  "log_loss_probas_pos_neg_iso": 0.5975388311422987,
  "brier_score_probas_pos_iso": 0.20819877219227456,
  "log_loss_probas_pos_iso": 0.5956650929131967,
  "brier_score_probas_pos_sc": 0.2128966714842871,
  "log_loss_probas_pos_sc": 0.610996562327206,
  "brier_score_probas_pos_neg_sc": 0.21302010060502194,
  "log_loss_probas_pos_neg_sc": 0.6110030684320242,
  "ece_scores": 0.11875036996136606,
  "ece_probas_pos": 0.06041101492033216,
  "ece_probas_pos_neg": 0.05949448428516471,
  "ece_probas_pos_neg_iso": 0.008257719137435608,
  "ece_probas_pos_iso": 0.021844404378771852,
  "ece_probas_pos_sc": 0.059226524029

  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count
  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count
  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count
  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count
  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count
  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count


{
  "model": "ComplEx",
  "loss": "nll",
  "brier_score_scores": 0.47546571242906077,
  "log_loss_scores": 5.797461947070429,
  "brier_score_probas_pos": 0.2329036154193148,
  "log_loss_probas_pos": 0.6569482001384209,
  "brier_score_probas_pos_neg": 0.23272727754355887,
  "log_loss_probas_pos_neg": 0.6568614785949567,
  "brier_score_probas_pos_neg_iso": 0.21862520044078643,
  "log_loss_probas_pos_neg_iso": 0.6244158019674226,
  "brier_score_probas_pos_iso": 0.2187998019896913,
  "log_loss_probas_pos_iso": 0.6219035625076065,
  "brier_score_probas_pos_sc": 0.23290018845724947,
  "log_loss_probas_pos_sc": 0.6570058795577409,
  "brier_score_probas_pos_neg_sc": 0.2328908664032037,
  "log_loss_probas_pos_neg_sc": 0.6569923446147136,
  "ece_scores": 0.479877076584059,
  "ece_probas_pos": 0.10471958104068466,
  "ece_probas_pos_neg": 0.10603957987794728,
  "ece_probas_pos_neg_iso": 0.006298873763638806,
  "ece_probas_pos_iso": 0.01166845821523682,
  "ece_probas_pos_sc": 0.10542885298819801,
 

  avg_pred_true = y_true[i_start:i_end].sum() / delta_count
  bin_centroid = y_prob[i_start:i_end].sum() / delta_count


In [7]:
import pandas as pd

In [8]:
def highlight_min(s):
    is_min = s == s.min()
    return ['font-weight: bold' if v else '' for v in is_min]

In [9]:
df = pd.DataFrame(results).set_index(['model', 'loss'])

In [10]:
bs = df[(c for c in df.columns if c.startswith('brier'))]
bs.columns = [c[len("brier_score_"):] for c in bs.columns]
bs.style.apply(highlight_min, axis=1)

Unnamed: 0_level_0,Unnamed: 1_level_0,probas_pos,probas_pos_iso,probas_pos_neg,probas_pos_neg_iso,probas_pos_neg_sc,probas_pos_sc,scores
model,loss,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
TransE,self_adversarial,0.0910956,0.0874456,0.0891255,0.0873114,0.0892357,0.0898008,0.443464
TransE,pairwise,0.208565,0.20043,0.201634,0.198374,0.201751,0.202342,0.492508
TransE,nll,0.0932453,0.0878042,0.0929334,0.0876964,0.0931742,0.0938658,0.222011
TransE,multiclass_nll,0.203992,0.188505,0.203694,0.188442,0.203988,0.203981,0.49254
DistMult,self_adversarial,0.213618,0.208319,0.213457,0.2079,0.213611,0.213783,0.488378
DistMult,pairwise,0.217334,0.211477,0.217107,0.210692,0.217278,0.217388,0.22372
DistMult,nll,0.224345,0.21378,0.22416,0.213445,0.224352,0.224476,0.469119
DistMult,multiclass_nll,0.212127,0.20536,0.21192,0.204893,0.212098,0.212214,0.262494
ComplEx,self_adversarial,0.239974,0.228413,0.239894,0.228215,0.24004,0.239956,0.489981
ComplEx,pairwise,0.213023,0.208199,0.21285,0.20794,0.21302,0.212897,0.225958


In [11]:
ll = df[(c for c in df.columns if c.startswith('log_loss'))]
ll.columns = [c[len("log_loss_"):] for c in ll.columns]
ll.style.apply(highlight_min, axis=1)

Unnamed: 0_level_0,Unnamed: 1_level_0,probas_pos,probas_pos_iso,probas_pos_neg,probas_pos_neg_iso,probas_pos_neg_sc,probas_pos_sc,scores
model,loss,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
TransE,self_adversarial,0.308255,0.295663,0.301611,0.295183,0.301699,0.304051,1.95908
TransE,pairwise,0.606337,0.589109,0.590702,0.585382,0.590642,0.591512,5.23391
TransE,nll,0.34213,0.300365,0.341685,0.298475,0.34193,0.343662,0.67005
TransE,multiclass_nll,0.599026,0.550196,0.598956,0.548983,0.599038,0.598987,7.6591
DistMult,self_adversarial,0.618242,0.601125,0.61821,0.603682,0.618237,0.618372,5.62464
DistMult,pairwise,0.62103,0.60569,0.621086,0.606481,0.621198,0.621638,0.63595
DistMult,nll,0.637838,0.607655,0.637868,0.611109,0.638026,0.637892,5.62026
DistMult,multiclass_nll,0.608572,0.591042,0.608781,0.587804,0.608875,0.609198,0.792607
ComplEx,self_adversarial,0.673636,0.650127,0.673559,0.650663,0.673705,0.673707,6.06107
ComplEx,pairwise,0.611014,0.595665,0.610921,0.597539,0.611003,0.610997,0.642542


In [12]:
print((bs.reset_index()
 .query("loss == 'self_adversarial' ")
 [['model', 'scores', 'probas_pos_neg', 'probas_pos_neg_iso', 'probas_pos', 'probas_pos_iso']]
 .reset_index(drop=True)
 .round(3)
 .to_latex()))

\begin{tabular}{llrrrrr}
\toprule
{} &     model &  scores &  probas\_pos\_neg &  probas\_pos\_neg\_iso &  probas\_pos &  probas\_pos\_iso \\
\midrule
0 &    TransE &   0.443 &           0.089 &               0.087 &       0.091 &           0.087 \\
1 &  DistMult &   0.488 &           0.213 &               0.208 &       0.214 &           0.208 \\
2 &   ComplEx &   0.490 &           0.240 &               0.228 &       0.240 &           0.228 \\
\bottomrule
\end{tabular}



In [13]:
print(ll.reset_index()
 .query("loss == 'self_adversarial' ")
 [['model', 'scores', 'probas_pos_neg', 'probas_pos_neg_iso', 'probas_pos', 'probas_pos_iso']]
 .reset_index(drop=True)
  .round(3)
 .to_latex())

\begin{tabular}{llrrrrr}
\toprule
{} &     model &  scores &  probas\_pos\_neg &  probas\_pos\_neg\_iso &  probas\_pos &  probas\_pos\_iso \\
\midrule
0 &    TransE &   1.959 &           0.302 &               0.295 &       0.308 &           0.296 \\
1 &  DistMult &   5.625 &           0.618 &               0.604 &       0.618 &           0.601 \\
2 &   ComplEx &   6.061 &           0.674 &               0.651 &       0.674 &           0.650 \\
\bottomrule
\end{tabular}



In [14]:
print(bs.reset_index()
 .query("model == 'TransE' ")
 [['loss',  'probas_pos_neg', 'probas_pos_neg_iso', 'probas_pos', 'probas_pos_iso']]
 .reset_index(drop=True)
  .round(3)
 .to_latex())

\begin{tabular}{llrrrr}
\toprule
{} &              loss &  probas\_pos\_neg &  probas\_pos\_neg\_iso &  probas\_pos &  probas\_pos\_iso \\
\midrule
0 &  self\_adversarial &           0.089 &               0.087 &       0.091 &           0.087 \\
1 &          pairwise &           0.202 &               0.198 &       0.209 &           0.200 \\
2 &               nll &           0.093 &               0.088 &       0.093 &           0.088 \\
3 &    multiclass\_nll &           0.204 &               0.188 &       0.204 &           0.189 \\
\bottomrule
\end{tabular}



In [15]:
print(ll.reset_index()
 .query("model == 'TransE' ")
 [['loss', 'probas_pos_neg', 'probas_pos_neg_iso', 'probas_pos', 'probas_pos_iso']]
 .reset_index(drop=True)
  .round(3)
 .to_latex())

\begin{tabular}{llrrrr}
\toprule
{} &              loss &  probas\_pos\_neg &  probas\_pos\_neg\_iso &  probas\_pos &  probas\_pos\_iso \\
\midrule
0 &  self\_adversarial &           0.302 &               0.295 &       0.308 &           0.296 \\
1 &          pairwise &           0.591 &               0.585 &       0.606 &           0.589 \\
2 &               nll &           0.342 &               0.298 &       0.342 &           0.300 \\
3 &    multiclass\_nll &           0.599 &               0.549 &       0.599 &           0.550 \\
\bottomrule
\end{tabular}



In [16]:
print((acc*100).reset_index()
 .query("loss == 'self_adversarial' ")
 [['model', 'pos_neg', 'pos_neg_iso', 'pos', 'pos_iso', 'uncalib', 'per_relation']]
 .reset_index(drop=True)
  .round(1)
 .to_latex())

NameError: name 'acc' is not defined

In [None]:
metrics = df[(c for c in df.columns if c.startswith('metrics'))]
metrics.columns = [c[len("metrics_"):] for c in metrics.columns]
metrics

In [None]:
def highlight_max(s):
    is_min = s == s.max()
    return ['font-weight: bold' if v else '' for v in is_min]

acc = df[(c for c in df.columns if c.startswith('accuracy'))]
acc.columns = [c[len("accuracy_"):] for c in acc.columns]
acc.style.apply(highlight_max, axis=1)

In [None]:
df.corr(method='spearman').reset_index().query("index.str.startswith('accuracy')")[['index', 'log_loss_probas_pos_neg', 'log_loss_probas_pos_neg_iso', 'log_loss_probas_pos', 'log_loss_probas_pos_iso']]

In [None]:
df.corr(method='spearman').reset_index().query("index.str.startswith('accuracy')")[['index', 'brier_score_probas_pos_neg', 'brier_score_probas_pos_neg_iso', 'brier_score_probas_pos', 'brier_score_probas_pos_iso']]