Implemented hits@k metric for link-prediction #675
base: main
@@ -17,6 +17,7 @@
 """
 import logging
 import abc
+from collections import defaultdict
 from statistics import mean
 import torch as th

@@ -26,7 +27,6 @@
     EARLY_STOP_CONSECUTIVE_INCREASE_STRATEGY,
     LINK_PREDICTION_MAJOR_EVAL_ETYPE_ALL)
 from ..utils import get_rank, get_world_size, barrier
-from .utils import gen_mrr_score

 def early_stop_avg_increase_judge(val_score, val_perf_list, comparator):
     """
@@ -781,12 +781,14 @@ def val_perf_rank_list(self):


 class GSgnnMrrLPEvaluator(GSgnnLPEvaluator):
-    """ The class for link prediction evaluation using Mrr metric.
+    """ The class for link prediction evaluation using Mrr or Hits metrics.

     Parameters
     ----------
     eval_frequency: int
         The frequency (# of iterations) of doing evaluation.
+    eval_metric: list of string
+        Evaluation metric used during evaluation.
     data: GSgnnEdgeData
         The processed dataset
     num_negative_edges_eval: int
@@ -802,14 +804,16 @@ class GSgnnMrrLPEvaluator(GSgnnLPEvaluator):
     early_stop_strategy: str
         The early stop strategy. GraphStorm supports two strategies:
         1) consecutive_increase and 2) average_increase.
+    k: int
+        k used for computing metrics such as hits@k.
     """
-    def __init__(self, eval_frequency, data,
+    def __init__(self, eval_frequency, eval_metric, data,
                  num_negative_edges_eval, lp_decoder_type,
                  use_early_stop=False,
                  early_stop_burnin_rounds=0,
                  early_stop_rounds=3,
-                 early_stop_strategy=EARLY_STOP_AVERAGE_INCREASE_STRATEGY):
-        eval_metric = ["mrr"]
+                 early_stop_strategy=EARLY_STOP_AVERAGE_INCREASE_STRATEGY,
+                 k=100):
         super(GSgnnMrrLPEvaluator, self).__init__(eval_frequency,
             eval_metric, use_early_stop, early_stop_burnin_rounds,
             early_stop_rounds, early_stop_strategy)

Review thread on this constructor change:

Comment: It is better to define a hit@k evaluator.
Reply: I think that will lead to a lot of duplicated code. I saw […]
Reply: We have two choices: […] I would prefer the first one.
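For reference, hits@k over a tensor of 1-based ranks is simply the fraction of positive edges ranked within the top k. A minimal self-contained sketch of the metric (illustrative only, not GraphStorm's actual implementation):

```python
import torch as th

def hits_at_k(ranking: th.Tensor, k: int = 100) -> th.Tensor:
    """Fraction of positive edges whose 1-based rank is <= k."""
    return (ranking <= k).float().mean()

# Ranks of five positive edges among their sampled negatives:
ranks = th.tensor([1, 3, 120, 2, 50])
print(hits_at_k(ranks, k=100))  # tensor(0.8000): 4 of 5 ranks fall within the top 100
```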
@@ -819,7 +823,7 @@ def __init__(self, eval_frequency, data,
         self.num_negative_edges_eval = num_negative_edges_eval
         self.lp_decoder_type = lp_decoder_type

-        self.metrics_obj = LinkPredictionMetrics()
+        self.metrics_obj = LinkPredictionMetrics(k)

         self._best_val_score = {}
         self._best_test_score = {}
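The LinkPredictionMetrics class itself is not part of this diff; from its use here it apparently maps metric names to callables over a ranking tensor, with k fixed at construction time. A hedged sketch of that shape (the metric-name strings and exact interface are assumptions):

```python
import torch as th

class LinkPredictionMetrics:
    """Sketch only: metric name -> function over a tensor of 1-based ranks."""
    def __init__(self, k=100):
        self.metric_function = {
            "mrr": lambda ranking: (1.0 / ranking.float()).mean(),
            f"hits_at_{k}": lambda ranking: (ranking <= k).float().mean(),
        }
```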
@@ -834,7 +838,7 @@ def compute_score(self, rankings, train=False): # pylint:disable=unused-argument

         Parameters
         ----------
-        rankings: dict of tensors
+        rankings: dict of tensors or tensor
             Rankings of positive scores in format of {etype: ranking}
         train: bool
             TODO: Reserved for future use cases when we want to use different
@@ -845,25 +849,30 @@ def compute_score(self, rankings, train=False): # pylint:disable=unused-argument

         Returns
         -------
         Evaluation metric values: dict
         """
-        # We calculate global mrr, etype is ignored.
-        # User can develop its own per etype MRR evaluator
-        ranking = []
-        for _, rank in rankings.items():
-            ranking.append(rank)
-        ranking = th.cat(ranking, dim=0)
+        # We calculate global lp metrics, etype is ignored.
+        # User can develop its own per etype LP evaluator
+        if isinstance(rankings, dict):
+            ranking = []
+            for _, rank in rankings.items():
+                ranking.append(rank)
+            ranking = th.cat(ranking, dim=0)
+        else:
+            ranking = rankings

-        metrics = gen_mrr_score(ranking)
+        scores = {}
+        for metric in self._metric:
+            scores[metric] = self.metrics_obj.metric_function[metric](ranking)

+        return_scores = {}
         # When world size == 1, we do not need the barrier
         if get_world_size() > 1:
             barrier()
-            for _, metric_val in metrics.items():
+            for metric, metric_val in scores.items():
                 th.distributed.all_reduce(metric_val)
-        return_metrics = {}
-        for metric, metric_val in metrics.items():
+        for metric, metric_val in scores.items():
             return_metric = metric_val / get_world_size()
-            return_metrics[metric] = return_metric.item()
-        return return_metrics
+            return_scores[metric] = return_metric.item()
+        return return_scores

     def evaluate(self, val_scores, test_scores, total_iters):
         """ GSgnnLinkPredictionModel.fit() will call this function to do user defined evaluation.
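The reduction at the end of compute_score averages each metric across workers: every rank computes a local mean, all_reduce sums those means, and dividing by the world size recovers the average. A small single-process simulation of that arithmetic (no torch.distributed setup needed):

```python
import torch as th

# Each rank holds a local metric mean; all_reduce(SUM) would leave the sum
# of these on every rank, and dividing by world_size averages them.
world_size = 4
local_means = [th.tensor(0.40), th.tensor(0.36), th.tensor(0.44), th.tensor(0.40)]
summed = th.stack(local_means).sum()   # what all_reduce produces on each rank
global_mean = summed / world_size
print(global_mean.item())              # ≈ 0.4
```

Note this is the mean of the per-rank means, which equals the true global metric only when every rank evaluates the same number of edges, as with a balanced distributed sampler.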
@@ -888,7 +897,7 @@ def evaluate(self, val_scores, test_scores, total_iters):
         if test_scores is not None:
             test_score = self.compute_score(test_scores)
         else:
-            test_score = {"mrr": "N/A"} # Dummy
+            test_score = {metric: "N/A" for metric in self.metric} # Dummy

         if val_scores is not None:
             val_score = self.compute_score(val_scores)
@@ -902,7 +911,7 @@ def evaluate(self, val_scores, test_scores, total_iters):
                     self._best_test_score[metric] = test_score[metric]
                     self._best_iter[metric] = total_iters
         else:
-            val_score = {"mrr": "N/A"} # Dummy
+            val_score = {metric: "N/A" for metric in self.metric} # Dummy

         return val_score, test_score
@@ -915,6 +924,8 @@ class GSgnnPerEtypeMrrLPEvaluator(GSgnnMrrLPEvaluator):
     ----------
     eval_frequency: int
         The frequency (# of iterations) of doing evaluation.
+    eval_metric: list of string
+        Evaluation metric used during evaluation.
     data: GSgnnEdgeData
         The processed dataset
     num_negative_edges_eval: int
@@ -932,16 +943,19 @@ class GSgnnPerEtypeMrrLPEvaluator(GSgnnMrrLPEvaluator):
     early_stop_strategy: str
         The early stop strategy. GraphStorm supports two strategies:
         1) consecutive_increase and 2) average_increase.
+    k: int
+        k used for computing metrics such as hits@k.
     """
-    def __init__(self, eval_frequency, data,
+    def __init__(self, eval_frequency, eval_metric, data,
                  num_negative_edges_eval, lp_decoder_type,
                  major_etype = LINK_PREDICTION_MAJOR_EVAL_ETYPE_ALL,
                  use_early_stop=False,
                  early_stop_burnin_rounds=0,
                  early_stop_rounds=3,
-                 early_stop_strategy=EARLY_STOP_AVERAGE_INCREASE_STRATEGY):
+                 early_stop_strategy=EARLY_STOP_AVERAGE_INCREASE_STRATEGY,
+                 k=100):
         self.major_etype = major_etype
-        super(GSgnnPerEtypeMrrLPEvaluator, self).__init__(eval_frequency,
+        super(GSgnnPerEtypeMrrLPEvaluator, self).__init__(eval_frequency, eval_metric,
             data, num_negative_edges_eval, lp_decoder_type,
             use_early_stop, early_stop_burnin_rounds,
             early_stop_rounds, early_stop_strategy)
@@ -961,28 +975,16 @@ def compute_score(self, rankings, train=False): # pylint:disable=unused-argument

         Returns
         -------
-        Evaluation metric values: dict
+        Evaluation metric values: dict of dict
         """
-        # We calculate global mrr, etype is ignored.
-        # User can develop its own per etype MRR evaluator
-        metrics = {}
+        scores = {}
         for etype, rank in rankings.items():
-            metrics[etype] = gen_mrr_score(rank)
-
-        # When world size == 1, we do not need the barrier
-        if get_world_size() > 1:
-            barrier()
-            for _, metric in metrics.items():
-                for _, metric_val in metric.items():
-                    th.distributed.all_reduce(metric_val)
-
-        return_metrics = {}
-        for etype, metric in metrics.items():
-            for metric_key, metric_val in metric.items():
-                return_metric = metric_val / get_world_size()
-                if metric_key not in return_metrics:
-                    return_metrics[metric_key] = {}
-                return_metrics[metric_key][etype] = return_metric.item()
+            scores[etype] = super().compute_score(rank)
+        # reorganize the nested dict to be keyed by metric, then by etype:
+        return_metrics = defaultdict(dict)
+        for metric in self.metric:
+            for etype in rankings.keys():
+                return_metrics[metric][etype] = scores[etype][metric]
         return return_metrics

     def _get_major_score(self, score):
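The per-etype variant reuses the parent's compute_score for each edge type, then transposes the nesting so results are keyed by metric first. A small standalone illustration of that reshaping (the etype tuples, metric names, and values are made up):

```python
from collections import defaultdict

scores = {
    ("user", "rates", "movie"): {"mrr": 0.41, "hits_at_100": 0.83},
    ("user", "follows", "user"): {"mrr": 0.37, "hits_at_100": 0.79},
}
# Transpose {etype: {metric: value}} into {metric: {etype: value}}:
return_metrics = defaultdict(dict)
for etype, per_etype in scores.items():
    for metric, value in per_etype.items():
        return_metrics[metric][etype] = value

print(return_metrics["mrr"])
# {('user', 'rates', 'movie'): 0.41, ('user', 'follows', 'user'): 0.37}
```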
Review thread on the choice of k:

Comment: Is it possible to have a list of ks instead of just one k?
Reply: I think most of the LP/retrieval datasets only select one specific k.
Reply: Changed to eval_hit_k.