In [1]:
import numpy as np
import itertools
import json

from sklearn.isotonic import IsotonicRegression
from sklearn.calibration import calibration_curve, _SigmoidCalibration, _sigmoid_calibration
from ampligraph.evaluation import evaluate_performance, mr_score, mrr_score, hits_at_n_score, generate_corruptions_for_eval
from sklearn.metrics import brier_score_loss, log_loss, accuracy_score
from scipy.special import expit

from ampligraph.datasets import load_wn11
from ampligraph.latent_features.models import TransE, ComplEx, DistMult, ConvKB, HolE
import types

In [2]:
from generate_corruptions import generate_corruptions, calibration_loss, pos_iso

In [3]:
# Set CUDA_VISIBLE_DEVICES to 0 for this kernel, so only CUDA device 0 is exposed to GPU libraries.
%env CUDA_VISIBLE_DEVICES=0

env: CUDA_VISIBLE_DEVICES=0


In [4]:
# Load the WN11 dataset; later cells read its 'train', 'valid', 'valid_labels',
# 'test' and 'test_labels' entries.
X = load_wn11()

In [5]:
# Split the labelled validation and test triples into positives and negatives
# using their label masks.
valid_is_pos = X['valid_labels']
test_is_pos = X['test_labels']

X_valid_pos, X_valid_neg = X['valid'][valid_is_pos], X['valid'][~valid_is_pos]
X_test_pos, X_test_neg = X['test'][test_is_pos], X['test'][~test_is_pos]

In [None]:
# Training losses and embedding model classes to compare; every (model, loss)
# pair is trained and evaluated once.
losses =  ['self_adversarial', 'pairwise', 'nll', 'multiclass_nll']
models = [TransE, DistMult, ComplEx, HolE]

# One flat dict of metrics per (model, loss) pair.
results = []

# Loop-invariant inputs, built once instead of once per (model, loss) pair.
y_test = X['test_labels']
valid_relations = X['valid'][:, 1]
test_relations = X['test'][:, 1]
filter_triples = np.concatenate((X['train'], X_valid_pos, X_test_pos))

for m, l in itertools.product(models, losses):
    model = m(batches_count=64, seed=0, epochs=1000, k=100, eta=20,
              optimizer='adam', optimizer_params={'lr':0.0001},
              loss=l, verbose=False)

    model.fit(X['train'])

    # Raw test scores, plus the squeezed float copy fed to the scikit-learn calibrators below.
    scores = model.predict(X['test'])
    scores_1d = np.squeeze(scores).astype(float)

    # (1) Built-in calibration from validation positives only (positive base rate 0.5).
    model.calibrate(X_valid_pos, batches_count=10, epochs=1000, positive_base_rate=0.5)
    print("pos", model.calibration_parameters)
    probas1 = model.predict_proba(X['test'])

    # (2) Built-in calibration from validation positives and negatives.
    #     calibrate() is called again here, so probas1 must be taken before this point.
    model.calibrate(X_valid_pos, X_valid_neg)
    print("pos neg", model.calibration_parameters)
    probas2 = model.predict_proba(X['test'])

    # (3) Isotonic regression fitted on the labelled validation scores.
    val_scores = model.predict(X['valid'])
    ir = IsotonicRegression(out_of_bounds='clip')
    ir.fit(np.squeeze(val_scores).astype(float), (X['valid_labels']).astype(float))
    probas3 = ir.predict(scores_1d)

    # (4) Isotonic regression fitted through the local pos_iso helper on validation
    #     positives and the output of generate_corruptions (bound to the model as a method).
    model.generate_corruptions = types.MethodType(generate_corruptions, model)
    corruptions = model.generate_corruptions(X_valid_pos, batches_count=10, epochs=1000)
    val_pos_scores = np.squeeze(model.predict(X_valid_pos))
    iso_pos = pos_iso(IsotonicRegression(out_of_bounds='clip'), val_pos_scores, corruptions, positive_base_rate=0.5)
    probas4 = iso_pos.predict(scores_1d)

    # (5) scikit-learn sigmoid calibrator on the same positives + corruptions.
    sc_pos = pos_iso(_SigmoidCalibration(), val_pos_scores, corruptions, positive_base_rate=0.5)
    print("pos sc", sc_pos.a_, sc_pos.b_)
    probas5 = sc_pos.predict(scores_1d)

    # (6) scikit-learn sigmoid calibrator on validation positives + validation negatives.
    val_neg_scores = np.squeeze(model.predict(X_valid_neg))
    sc_pos_neg = pos_iso(_SigmoidCalibration(), val_pos_scores, val_neg_scores, positive_base_rate=0.5)
    print("pos neg sc", sc_pos_neg.a_, sc_pos_neg.b_)
    probas6 = sc_pos_neg.predict(scores_1d)

    # Baseline classifier: one threshold per relation, the median validation score of
    # that relation. np.median does not need sorted input, so no explicit sort is done.
    thresholds = {r: np.median(val_scores[valid_relations == r]) for r in np.unique(valid_relations)}
    thresholds_test = np.vectorize(thresholds.get)(test_relations)
    per_relation_acc = accuracy_score(y_test, scores > thresholds_test)

    # Ranking metrics on the positive test triples, filtering all known positives.
    ranks = evaluate_performance(X_test_pos,
                                 model=model,
                                 filter_triples=filter_triples,
                                 use_default_protocol=True,
                                 verbose=False)

    # Probability estimates keyed by the suffix used in the result column names.
    # Insertion order fixes the column order of the results.
    calibrated = {
        'probas_pos': probas1,
        'probas_pos_neg': probas2,
        'probas_pos_neg_iso': probas3,
        'probas_pos_iso': probas4,
        'probas_pos_sc': probas5,
        'probas_pos_neg_sc': probas6,
    }
    # 'scores' is the uncalibrated baseline: the logistic sigmoid of the raw scores.
    all_probas = {'scores': expit(scores)}
    all_probas.update(calibrated)

    record = {'model': m.__name__, 'loss': l}
    for name, probas in all_probas.items():
        record['brier_score_' + name] = brier_score_loss(y_test, probas)
        record['log_loss_' + name] = log_loss(y_test, probas, eps=1e-7)
    for name, probas in all_probas.items():
        record['ece_' + name] = calibration_loss(y_test, probas)

    record['metrics_mrr'] = mrr_score(ranks)
    record['metrics_hits@10'] = hits_at_n_score(ranks, n=10)
    record['metrics_mr'] = mr_score(ranks)

    # Accuracies use a fixed 0.5 decision threshold on each probability estimate.
    record['accuracy_per_relation'] = per_relation_acc
    record['accuracy_uncalib'] = accuracy_score(y_test, all_probas['scores'] > 0.5)
    for name, probas in calibrated.items():
        record['accuracy_' + name[len('probas_'):]] = accuracy_score(y_test, probas > 0.5)

    results.append(record)

    # default=float is only consulted for values json cannot encode natively (e.g. a
    # non-float64 numpy scalar), so a stray dtype cannot abort the run after training.
    print(json.dumps(results[-1], indent=2, default=float))

In [7]:
import pandas as pd

In [8]:
def highlight_min(s):
    """Return one CSS string per element of `s`, bolding every entry equal to its minimum.

    Intended for `Styler.apply`; entries that are not the minimum get an empty style.
    """
    lowest = s.min()
    styles = []
    for value in s:
        styles.append('font-weight: bold' if value == lowest else '')
    return styles

In [17]:
# Collect the per-run metric dicts into one frame keyed by (model, loss).
df = pd.DataFrame(results).set_index(['model', 'loss'])
# Keep the index when exporting: after set_index, 'model' and 'loss' exist only in the
# index, so index=False would write metric rows that cannot be attributed to a run.
df.to_csv("main_results_wn11.csv")

In [10]:
# Brier-score columns only, with the common "brier_score_" prefix stripped from the names.
brier_columns = [c for c in df.columns if c.startswith('brier')]
bs = df[brier_columns].rename(columns=lambda c: c[len("brier_score_"):])
# Bold the lowest value in each row.
bs.style.apply(highlight_min, axis=1)

Unnamed: 0_level_0,Unnamed: 1_level_0,probas_pos,probas_pos_iso,probas_pos_neg,probas_pos_neg_iso,probas_pos_neg_sc,probas_pos_sc,scores
model,loss,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
TransE,self_adversarial,0.0911227,0.0873972,0.0891628,0.0872974,0.0892708,0.0898369,0.443472
TransE,pairwise,0.208565,0.20043,0.201634,0.198374,0.201751,0.202342,0.492508
TransE,nll,0.0932402,0.0877708,0.0929306,0.0877687,0.0931713,0.0938609,0.222008
TransE,multiclass_nll,0.20406,0.188635,0.203787,0.18849,0.204075,0.204056,0.492539
DistMult,self_adversarial,0.213618,0.208319,0.213457,0.2079,0.213611,0.213783,0.488378
DistMult,pairwise,0.217338,0.211455,0.217111,0.210665,0.217282,0.217392,0.223721
DistMult,nll,0.224345,0.21378,0.22416,0.213445,0.224352,0.224476,0.469119
DistMult,multiclass_nll,0.212127,0.20536,0.21192,0.204893,0.212098,0.212214,0.262494
ComplEx,self_adversarial,0.239974,0.228413,0.239894,0.228215,0.24004,0.239956,0.489981
ComplEx,pairwise,0.213033,0.20825,0.212861,0.208004,0.213031,0.212907,0.225967


In [11]:
# Log-loss columns only, with the common "log_loss_" prefix stripped from the names.
log_loss_columns = [c for c in df.columns if c.startswith('log_loss')]
ll = df[log_loss_columns].rename(columns=lambda c: c[len("log_loss_"):])
# Bold the lowest value in each row.
ll.style.apply(highlight_min, axis=1)

Unnamed: 0_level_0,Unnamed: 1_level_0,probas_pos,probas_pos_iso,probas_pos_neg,probas_pos_neg_iso,probas_pos_neg_sc,probas_pos_sc,scores
model,loss,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
TransE,self_adversarial,0.308351,0.295608,0.301742,0.29519,0.301824,0.304179,1.95926
TransE,pairwise,0.606337,0.589109,0.590702,0.585382,0.590642,0.591512,5.23391
TransE,nll,0.342109,0.298839,0.34167,0.298653,0.341915,0.34364,0.670003
TransE,multiclass_nll,0.599069,0.549829,0.599019,0.549195,0.5991,0.599037,7.65297
DistMult,self_adversarial,0.618242,0.601125,0.61821,0.603682,0.618237,0.618372,5.62464
DistMult,pairwise,0.621037,0.605643,0.621093,0.606392,0.621206,0.621645,0.635952
DistMult,nll,0.637838,0.607655,0.637868,0.611109,0.638026,0.637892,5.62026
DistMult,multiclass_nll,0.608572,0.591042,0.608781,0.587805,0.608875,0.609198,0.792607
ComplEx,self_adversarial,0.673636,0.650127,0.673559,0.650664,0.673705,0.673707,6.06107
ComplEx,pairwise,0.611038,0.595796,0.610944,0.597727,0.611027,0.61102,0.642561


In [12]:
# LaTeX table: Brier scores of the self-adversarial runs, one row per model, 3 decimals.
bs_table_columns = ['model', 'scores', 'probas_pos_neg', 'probas_pos_neg_iso', 'probas_pos', 'probas_pos_iso']
bs_self_adv = bs.reset_index().query("loss == 'self_adversarial' ")
bs_self_adv_table = bs_self_adv[bs_table_columns].reset_index(drop=True).round(3)
print(bs_self_adv_table.to_latex())

\begin{tabular}{llrrrrr}
\toprule
{} &     model &  scores &  probas\_pos\_neg &  probas\_pos\_neg\_iso &  probas\_pos &  probas\_pos\_iso \\
\midrule
0 &    TransE &   0.443 &           0.089 &               0.087 &       0.091 &           0.087 \\
1 &  DistMult &   0.488 &           0.213 &               0.208 &       0.214 &           0.208 \\
2 &   ComplEx &   0.490 &           0.240 &               0.228 &       0.240 &           0.228 \\
3 &      HolE &   0.474 &           0.235 &               0.235 &       0.235 &           0.236 \\
\bottomrule
\end{tabular}



In [13]:
# LaTeX table: log losses of the self-adversarial runs, one row per model, 3 decimals.
ll_table_columns = ['model', 'scores', 'probas_pos_neg', 'probas_pos_neg_iso', 'probas_pos', 'probas_pos_iso']
ll_self_adv = ll.reset_index().query("loss == 'self_adversarial' ")
ll_self_adv_table = ll_self_adv[ll_table_columns].reset_index(drop=True).round(3)
print(ll_self_adv_table.to_latex())

\begin{tabular}{llrrrrr}
\toprule
{} &     model &  scores &  probas\_pos\_neg &  probas\_pos\_neg\_iso &  probas\_pos &  probas\_pos\_iso \\
\midrule
0 &    TransE &   1.959 &           0.302 &               0.295 &       0.308 &           0.296 \\
1 &  DistMult &   5.625 &           0.618 &               0.604 &       0.618 &           0.601 \\
2 &   ComplEx &   6.061 &           0.674 &               0.651 &       0.674 &           0.650 \\
3 &      HolE &   2.731 &           0.663 &               0.661 &       0.663 &           0.668 \\
\bottomrule
\end{tabular}



In [14]:
# LaTeX table: Brier scores of the TransE runs, one row per training loss, 3 decimals.
bs_transe_columns = ['loss',  'probas_pos_neg', 'probas_pos_neg_iso', 'probas_pos', 'probas_pos_iso']
bs_transe = bs.reset_index().query("model == 'TransE' ")
bs_transe_table = bs_transe[bs_transe_columns].reset_index(drop=True).round(3)
print(bs_transe_table.to_latex())

\begin{tabular}{llrrrr}
\toprule
{} &              loss &  probas\_pos\_neg &  probas\_pos\_neg\_iso &  probas\_pos &  probas\_pos\_iso \\
\midrule
0 &  self\_adversarial &           0.089 &               0.087 &       0.091 &           0.087 \\
1 &          pairwise &           0.202 &               0.198 &       0.209 &           0.200 \\
2 &               nll &           0.093 &               0.088 &       0.093 &           0.088 \\
3 &    multiclass\_nll &           0.204 &               0.188 &       0.204 &           0.189 \\
\bottomrule
\end{tabular}



In [15]:
# LaTeX table: log losses of the TransE runs, one row per training loss, 3 decimals.
ll_transe_columns = ['loss', 'probas_pos_neg', 'probas_pos_neg_iso', 'probas_pos', 'probas_pos_iso']
ll_transe = ll.reset_index().query("model == 'TransE' ")
ll_transe_table = ll_transe[ll_transe_columns].reset_index(drop=True).round(3)
print(ll_transe_table.to_latex())

\begin{tabular}{llrrrr}
\toprule
{} &              loss &  probas\_pos\_neg &  probas\_pos\_neg\_iso &  probas\_pos &  probas\_pos\_iso \\
\midrule
0 &  self\_adversarial &           0.302 &               0.295 &       0.308 &           0.296 \\
1 &          pairwise &           0.591 &               0.585 &       0.606 &           0.589 \\
2 &               nll &           0.342 &               0.299 &       0.342 &           0.299 \\
3 &    multiclass\_nll &           0.599 &               0.549 &       0.599 &           0.550 \\
\bottomrule
\end{tabular}



In [25]:
# Accuracy columns only, with the common "accuracy_" prefix stripped from the names.
accuracy_columns = [c for c in df.columns if c.startswith('accuracy')]
acc = df[accuracy_columns].rename(columns=lambda c: c[len("accuracy_"):])
# Accuracy is higher-is-better, so bold the row maximum; bolding the minimum (as done
# for the Brier-score and log-loss tables) would emphasise the worst method per row.
acc.style.apply(lambda s: ['font-weight: bold' if v else '' for v in s == s.max()], axis=1)

Unnamed: 0_level_0,Unnamed: 1_level_0,per_relation,pos,pos_iso,pos_neg,pos_neg_iso,pos_neg_sc,pos_sc,uncalib
model,loss,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
TransE,self_adversarial,0.882066,0.888663,0.888917,0.888054,0.888968,0.887648,0.888613,0.50746
TransE,pairwise,0.705623,0.691008,0.68913,0.690551,0.695626,0.689282,0.689282,0.50746
TransE,nll,0.879326,0.878565,0.882625,0.879123,0.882422,0.878565,0.878362,0.50746
TransE,multiclass_nll,0.697859,0.679387,0.715061,0.68365,0.715011,0.67959,0.67888,0.50746
DistMult,self_adversarial,0.67213,0.664062,0.671217,0.664671,0.67208,0.664823,0.660865,0.508373
DistMult,pairwise,0.625241,0.63808,0.658683,0.641328,0.656907,0.638181,0.637928,0.601644
DistMult,nll,0.611235,0.61905,0.646808,0.623465,0.649295,0.619456,0.616969,0.521719
DistMult,multiclass_nll,0.639957,0.653659,0.671775,0.657921,0.671927,0.655587,0.654471,0.611337
ComplEx,self_adversarial,0.595605,0.599564,0.624378,0.606059,0.624328,0.599259,0.597026,0.508069
ComplEx,pairwise,0.639247,0.650513,0.663706,0.652644,0.663554,0.650208,0.651527,0.609155


In [26]:
# LaTeX table: accuracies (as percentages, 1 decimal) of the self-adversarial runs.
acc_table_columns = ['model', 'pos_neg', 'pos_neg_iso', 'pos', 'pos_iso', 'uncalib', 'per_relation']
acc_percent = (acc*100).reset_index()
acc_self_adv = acc_percent.query("loss == 'self_adversarial' ")
acc_self_adv_table = acc_self_adv[acc_table_columns].reset_index(drop=True).round(1)
print(acc_self_adv_table.to_latex())

\begin{tabular}{llrrrrrr}
\toprule
{} &     model &  pos\_neg &  pos\_neg\_iso &   pos &  pos\_iso &  uncalib &  per\_relation \\
\midrule
0 &    TransE &     88.8 &         88.9 &  88.9 &     88.9 &     50.7 &          88.2 \\
1 &  DistMult &     66.5 &         67.2 &  66.4 &     67.1 &     50.8 &          67.2 \\
2 &   ComplEx &     60.6 &         62.4 &  60.0 &     62.4 &     50.8 &          59.6 \\
3 &      HolE &     59.3 &         59.0 &  59.3 &     59.0 &     50.9 &          60.8 \\
\bottomrule
\end{tabular}



In [21]:
# Ranking-metric columns only, with the common "metrics_" prefix stripped from the names.
metric_columns = [c for c in df.columns if c.startswith('metrics')]
metrics = df[metric_columns].rename(columns=lambda c: c[len("metrics_"):])
metrics

Unnamed: 0_level_0,Unnamed: 1_level_0,hits@10,mr,mrr
model,loss,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
TransE,self_adversarial,0.308572,2288.616938,0.155053
TransE,pairwise,0.134247,7259.422007,0.057914
TransE,nll,0.226973,2514.538224,0.133938
TransE,multiclass_nll,0.234134,7455.075314,0.108405
DistMult,self_adversarial,0.08129,9999.898825,0.044532
DistMult,pairwise,0.092881,12092.537606,0.050656
DistMult,nll,0.092108,12594.588038,0.049014
DistMult,multiclass_nll,0.108902,11518.330105,0.060369
ComplEx,self_adversarial,0.093859,13814.839738,0.053883
ComplEx,pairwise,0.114053,11376.982124,0.067699


In [22]:
# LaTeX table: ranking metrics of the self-adversarial runs, one row per model, 3 decimals.
metrics_table_columns = ['model', 'mr', 'mrr', 'hits@10']
metrics_self_adv = metrics.reset_index().query("loss == 'self_adversarial' ")
metrics_self_adv_table = metrics_self_adv[metrics_table_columns].reset_index(drop=True).round(3)
print(metrics_self_adv_table.to_latex())

\begin{tabular}{llrrr}
\toprule
{} &     model &         mr &    mrr &  hits@10 \\
\midrule
0 &    TransE &   2288.617 &  0.155 &    0.309 \\
1 &  DistMult &   9999.899 &  0.045 &    0.081 \\
2 &   ComplEx &  13814.840 &  0.054 &    0.094 \\
3 &      HolE &  13354.699 &  0.017 &    0.035 \\
\bottomrule
\end{tabular}



In [None]:
def highlight_max(s):
    """Return one CSS string per element of `s`, bolding every entry equal to its maximum.

    Intended for `Styler.apply`; entries that are not the maximum get an empty style.
    """
    bold, plain = 'font-weight: bold', ''
    peak = s.max()
    return [bold if value == peak else plain for value in s]

# Accuracy columns only, with the common "accuracy_" prefix stripped from the names.
accuracy_columns = [c for c in df.columns if c.startswith('accuracy')]
acc = df[accuracy_columns].rename(columns=lambda c: c[len("accuracy_"):])
# Bold the highest accuracy in each row.
acc.style.apply(highlight_max, axis=1)

In [None]:
# Spearman rank correlation between each accuracy column (rows) and selected log-loss columns.
log_loss_corr_columns = ['index', 'log_loss_probas_pos_neg', 'log_loss_probas_pos_neg_iso', 'log_loss_probas_pos', 'log_loss_probas_pos_iso']
spearman_for_log_loss = df.corr(method='spearman').reset_index()
spearman_for_log_loss.query("index.str.startswith('accuracy')")[log_loss_corr_columns]

In [None]:
# Spearman rank correlation between each accuracy column (rows) and selected Brier-score columns.
brier_corr_columns = ['index', 'brier_score_probas_pos_neg', 'brier_score_probas_pos_neg_iso', 'brier_score_probas_pos', 'brier_score_probas_pos_iso']
spearman_for_brier = df.corr(method='spearman').reset_index()
spearman_for_brier.query("index.str.startswith('accuracy')")[brier_corr_columns]

In [18]:
# Train a single TransE model with the self-adversarial loss to inspect the
# per-relation decision thresholds.
transe_hyperparams = dict(batches_count=64, seed=0, epochs=1000, k=100, eta=20,
                          optimizer='adam', optimizer_params={'lr':0.0001},
                          loss='self_adversarial', verbose=False)
model = TransE(**transe_hyperparams)
model.fit(X['train'])

scores = model.predict(X['test'])

# Calibration from validation positives only (positive base rate 0.5).
model.calibrate(X_valid_pos, batches_count=10, epochs=1000, positive_base_rate=0.5)
print("pos", model.calibration_parameters)
probas1 = model.predict_proba(X['test'])

# Calibration from validation positives and negatives; calibrate() is called again,
# so probas1 above had to be taken first.
model.calibrate(X_valid_pos, X_valid_neg)
print("pos neg", model.calibration_parameters)
probas2 = model.predict_proba(X['test'])

val_scores = model.predict(X['valid'])

# One threshold per relation: the median validation score of that relation's triples.
valid_relations = X['valid'][:, 1]
thresholds = {}
for r in np.unique(valid_relations):
    thresholds[r] = np.median(np.sort(val_scores[valid_relations == r]))

# Look up each test triple's relation threshold and classify by score > threshold.
lookup_threshold = np.vectorize(thresholds.get)
thresholds_test = lookup_threshold(X['test'][:, 1])
per_relation_acc = accuracy_score(X['test_labels'], scores > thresholds_test)

print(thresholds)


pos [-1.1740658, -6.8939967]
pos neg [-1.4084864, -8.3745365]
{'_domain_region': -6.0069733, '_domain_topic': -5.5207396, '_has_instance': -6.2901406, '_has_part': -5.673306, '_member_holonym': -6.3117476, '_member_meronym': -5.982978, '_part_of': -5.798244, '_similar_to': -6.852225, '_subordinate_instance_of': -5.4750223, '_synset_domain_topic': -6.6392403, '_type_of': -6.743014}


In [20]:
# Display the per-relation-threshold accuracy computed for the TransE model above.
per_relation_acc

0.8819648837917385