diff --git a/ranx/meta/compare.py b/ranx/meta/compare.py index 3c12987..5e93488 100644 --- a/ranx/meta/compare.py +++ b/ranx/meta/compare.py @@ -65,7 +65,7 @@ def compare( Report: See report. """ metrics = format_metrics(metrics) - assert all(type(m) == str for m in metrics), "Metrics error" + assert all(isinstance(m, str) for m in metrics), "Metrics error" model_names = [] results = defaultdict(dict) diff --git a/ranx/meta/evaluate.py b/ranx/meta/evaluate.py index f99560a..09df620 100644 --- a/ranx/meta/evaluate.py +++ b/ranx/meta/evaluate.py @@ -11,7 +11,7 @@ def format_metrics(metrics: Union[List[str], str]) -> List[str]: - if type(metrics) == str: + if isinstance(metrics, str): metrics = [metrics] return metrics @@ -43,7 +43,7 @@ def extract_metric_and_params(metric): def convert_qrels(qrels): if type(qrels) == Qrels: return qrels.to_typed_list() - elif type(qrels) == dict: + elif isinstance(qrels, dict): return python_dict_to_typed_list(qrels, sort=True) return qrels @@ -51,7 +51,7 @@ def convert_qrels(qrels): def convert_run(run): if type(run) == Run: return run.to_typed_list() - elif type(run) == dict: + elif isinstance(run, dict): return python_dict_to_typed_list(run, sort=True) return run @@ -133,7 +133,7 @@ def evaluate( _qrels = convert_qrels(qrels) _run = convert_run(run) metrics = format_metrics(metrics) - assert all(type(m) == str for m in metrics), "Metrics error" + assert all(isinstance(m, str) for m in metrics), "Metrics error" # Compute metrics ---------------------------------------------------------- metric_scores_dict = {} diff --git a/ranx/meta/plot.py b/ranx/meta/plot.py index eca7433..fd43646 100644 --- a/ranx/meta/plot.py +++ b/ranx/meta/plot.py @@ -25,7 +25,7 @@ def plot( _qrels = qrels.to_typed_list() - if type(runs) == list: + if isinstance(runs, list): _runs = [run.to_typed_list() for run in runs] names = [ run.name if run.name is not None else f"run_{i+1}"