# Checking how Hits@K is computed
## Purpose
The Hits@K values reported by pykeen look wrong, so I compare them against my own implementation.

## modules

In [47]:
import torch
import json
import pandas as pd
import numpy as np
from IPython.display import display, HTML
from pykeen.datasets import get_dataset
from pykeen.evaluation import RankBasedEvaluator

## variables, functions and classes

In [67]:
@torch.no_grad()  # evaluation only, no gradients needed
def vanilla_hits_at_k(kge_model, mapped_triples, ks=(1, 3, 5, 10), batch_size=100, output_type='both'):
    """Compute unfiltered Hits@K for head and tail prediction.

    output_type: 'head', 'tail' or 'both'; None returns all three.
    """
    dict_isin_at_k = {
        'tail': {k: [] for k in ks},
        'head': {k: [] for k in ks},
    }
    dict_hits_at_k = {}

    n_all = mapped_triples.shape[0]

    for _type in ['head', 'tail']:

        for i in range(0, n_all, batch_size):

            j = min(i + batch_size, n_all)

            if _type == 'tail':
                # ids of the true tails
                nids = mapped_triples[i:j, 2]
                # score(h, r, *) for every candidate tail
                scores = kge_model.score_t(mapped_triples[i:j, :2])
            else:
                # ids of the true heads
                nids = mapped_triples[i:j, 0]
                # score(*, r, t) for every candidate head
                scores = kge_model.score_h(mapped_triples[i:j, 1:])

            # for each true head/tail
            for nid, _scores in zip(nids, scores):
                # entity ids sorted by descending score
                _sorted_nid = torch.argsort(_scores, descending=True)
                # check whether the true head/tail is in the top k
                for k in ks:
                    dict_isin_at_k[_type][k].append((nid in _sorted_nid[:k]))

        dict_hits_at_k[_type] = {}
        for k, list_isin in dict_isin_at_k[_type].items():
            dict_hits_at_k[_type][k] = sum(list_isin) / len(list_isin)

    # 'both' is the mean of the head-side and tail-side values
    dict_hits_at_k['both'] = {}
    for k in ks:
        v1 = dict_hits_at_k['tail'][k]
        v2 = dict_hits_at_k['head'][k]
        dict_hits_at_k['both'][k] = (v1 + v2) / 2.0

    if output_type is None:
        return dict_hits_at_k
    else:
        return dict_hits_at_k[output_type]
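
As a sanity check on the definition used above, here is a minimal pure-Python sketch that computes Hits@K directly from 1-based ranks. The helper name `hits_at_k_from_ranks` and the toy ranks are assumptions made for illustration; they are not part of pykeen or of the trained model.

```python
def hits_at_k_from_ranks(ranks, ks=(1, 3, 5, 10)):
    """Fraction of 1-based ranks that are <= k, for each k (hypothetical helper)."""
    return {k: sum(r <= k for r in ranks) / len(ranks) for k in ks}


# Toy ranks of the true entity (made-up numbers for illustration).
tail_ranks = [1, 2, 4, 11]
head_ranks = [3, 7, 20, 1]

tail_hits = hits_at_k_from_ranks(tail_ranks)
head_hits = hits_at_k_from_ranks(head_ranks)

# 'both' as the mean of the two sides, as in vanilla_hits_at_k.
both_hits = {k: (tail_hits[k] + head_hits[k]) / 2.0 for k in tail_hits}

# Both sides contain the same number of triples, so the mean of the two
# sides equals Hits@K over the pooled ranks.
pooled_hits = hits_at_k_from_ranks(tail_ranks + head_ranks)
print(both_hits)
```

Since head and tail prediction are evaluated on the same triples, averaging the two sides and pooling all ranks give the same number.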

## parameters

In [3]:
# name of the dataset
dataset = 'fb15k237'
# directory containing the trained model
dir_model = './models/20240606/fb15k237_transe_no_option'

## main

### Load the dataset

In [4]:
dataset = get_dataset(dataset=dataset,dataset_kwargs={'create_inverse_triples':True})

### Load the trained knowledge graph embedding model

In [5]:
# embedding model
kge_model = torch.load(f'{dir_model}/trained_model.pkl')

In [6]:
# embedding model information
with open(f'{dir_model}/results.json') as fin:
    dict_model_info = json.load(fin)

### Get the hits@k computed by pykeen

In [48]:
for _type in ['both']:
    display(HTML(f'<h4>{_type}</h4><hr>'))
    for k in [1,3,5,10]:
        val = dict_model_info['metrics'][_type]['realistic'][f'hits_at_{k}']
        display(HTML(f'Hits@{k}: {val}'))
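
The `realistic` key above is one of pykeen's rank types. As far as I understand the pykeen documentation, the optimistic rank assumes the true entity wins every tie, the pessimistic rank assumes it loses every tie, and the realistic rank is the mean of the two. The sketch below illustrates this with made-up scores; `rank_types` is a hypothetical helper, not a pykeen function.

```python
def rank_types(true_score, all_scores):
    """Optimistic, pessimistic and realistic 1-based ranks under ties.

    all_scores must include the true entity's own score.
    """
    # Optimistic: only strictly better candidates are placed ahead.
    optimistic = sum(s > true_score for s in all_scores) + 1
    # Pessimistic: every tied candidate is placed ahead as well.
    pessimistic = sum(s >= true_score for s in all_scores)
    realistic = (optimistic + pessimistic) / 2
    return optimistic, pessimistic, realistic


# Made-up scores; the true entity is index 0 and ties with two others.
scores = [0.5, 0.5, 0.5, 0.2]
print(rank_types(scores[0], scores))  # (1, 3, 2.0)
```

An argsort-based implementation orders tied scores arbitrarily, so small differences from pykeen's values are possible when ties occur.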

### Compute hits@k with my own implementation

In [81]:
dict_vanilla_hits_at_k = vanilla_hits_at_k(kge_model, dataset.testing.mapped_triples)

### Compute with pykeen's RankBasedEvaluator

Note that this uses the "filtered setting" described in the quotation below.

> The rank-based evaluation allows using the "filtered setting", proposed by [bordes2013], which is enabled by default. When evaluating the tail prediction for a triple $(h, r, t)$, i.e. scoring all triples $(h, r, e)$, there may be additional known triples $(h, r, t')$ for $t \neq t'$. If the model predicts a higher score for $(h, r, t')$, the rank will increase, and hence the measured model performance will decrease. However, giving $(h, r, t')$ a high score (and thus a low rank) is desirable since it is a true triple as well. Thus, the filtered evaluation setting ignores for a given triple $(h, r, t)$ the scores of all other *known* true triples $(h, r, t')$.

[Documentation](https://pykeen.readthedocs.io/en/stable/api/pykeen.evaluation.RankBasedEvaluator.html)
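
To make the quoted description concrete, the following pure-Python sketch compares the raw rank and the filtered rank of one true tail. The scores, the set of known tails and the helper `rank_of` are all made up for illustration. Ties are resolved optimistically here, which is a simplification and not necessarily what pykeen reports.

```python
def rank_of(true_id, scores, ignore=frozenset()):
    """1-based rank of true_id by descending score, skipping ids in `ignore`."""
    true_score = scores[true_id]
    better = sum(
        1
        for eid, s in enumerate(scores)
        if eid != true_id and eid not in ignore and s > true_score
    )
    return better + 1


# Toy scores for (h, r, e) over five candidate tails e = 0..4 (made-up values).
scores = [0.9, 0.8, 0.7, 0.1, 0.3]
true_tail = 2
known_other_tails = {0, 1}  # other known true triples (h, r, t') with t' != t

raw_rank = rank_of(true_tail, scores)
filtered_rank = rank_of(true_tail, scores, ignore=known_other_tails)
print(raw_rank, filtered_rank)  # 3 1
```

Filtering only removes competitors, so the filtered rank is never worse than the raw rank.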

In [70]:
evaluator = RankBasedEvaluator(filtered=True, metrics=['hits@k'])

In [85]:
dict_pykeen_hits_at_k = {}

#### self-filtered

In [86]:
%%capture --no-stderr
results = evaluator.evaluate(kge_model, dataset.testing.mapped_triples, 
                             additional_filter_triples=[dataset.training.mapped_triples])

In [87]:
dict_pykeen_hits_at_k['self-filtered'] = {k:results.get_metric(f'hits_at_{k}') for k in [1,3,5,10]}

#### filtered

In [88]:
%%capture --no-stderr
results = evaluator.evaluate(kge_model, dataset.testing.mapped_triples,
                             additional_filter_triples=[dataset.training.mapped_triples, dataset.validation.mapped_triples])

In [89]:
dict_pykeen_hits_at_k['filtered'] = {k:results.get_metric(f'hits_at_{k}') for k in [1,3,5,10]}

### Results

**Hits@k computed at model training time**

In [80]:
for k in [1,3,5,10]:
    val = dict_model_info['metrics']['both']['realistic'][f'hits_at_{k}']
    display(HTML(f'Hits@{k}: {val}'))

**Unfiltered (own implementation)**

In [82]:
for k, val in dict_vanilla_hits_at_k.items():
    display(HTML(f'Hits@{k}: {val}'))

**Filtered (test data only)**  
The scores are better than in the unfiltered case.

In [90]:
for k, val in dict_pykeen_hits_at_k['self-filtered'].items():
    display(HTML(f'Hits@{k}: {val}'))

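**Filtered (training, test and validation data)**  
- This matches the values computed at model training time.
- The Hits@k computed at training time therefore appears to be evaluated on the test data and filtered with the test triples themselves plus the training and validation data.
- In general, filtering raises Hits@k, because removing other known true triples from the candidates can only improve the rank of the target entity or leave it unchanged.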

In [91]:
for k, val in dict_pykeen_hits_at_k['filtered'].items():
    display(HTML(f'Hits@{k}: {val}'))
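
Instead of comparing the printed values by eye, the match with the training-time metrics could be checked programmatically. The sketch below is one possible way to do it, assuming both results are available as `{k: value}` dictionaries. The helper `same_hits`, the tolerance and the numbers are hypothetical and do not come from the actual run.

```python
import math


def same_hits(a, b, rel_tol=1e-6):
    """True if two {k: hits} dicts have the same keys and close values."""
    return a.keys() == b.keys() and all(
        math.isclose(a[k], b[k], rel_tol=rel_tol) for k in a
    )


# Made-up numbers standing in for the notebook's results.
train_time = {1: 0.10, 3: 0.20, 5: 0.25, 10: 0.30}
recomputed = {1: 0.10, 3: 0.20, 5: 0.25, 10: 0.30 + 1e-9}
unfiltered = {1: 0.05, 3: 0.15, 5: 0.20, 10: 0.28}

print(same_hits(train_time, recomputed))  # True
print(same_hits(train_time, unfiltered))  # False
```

A relative tolerance is used because the same metric recomputed in a different batch order can differ in the last floating-point digits.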