From 6100c733c671ac6dfc15bda9cde58643ee2d25cd Mon Sep 17 00:00:00 2001 From: Yoav Katz Date: Mon, 6 May 2024 09:48:40 +0300 Subject: [PATCH] Add cache to metric prediction_type to speedup loading of tasks. Signed-off-by: Yoav Katz --- src/unitxt/task.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/unitxt/task.py b/src/unitxt/task.py index 79c8dcf6e..6797362d6 100644 --- a/src/unitxt/task.py +++ b/src/unitxt/task.py @@ -1,3 +1,4 @@ +from functools import lru_cache from typing import Any, Dict, List, Optional, Union from .artifact import fetch_artifact @@ -75,11 +76,16 @@ def verify(self): augmentable_input in self.inputs ), f"augmentable_input {augmentable_input} is not part of {self.inputs}" + @staticmethod + @lru_cache(maxsize=None) + def get_metric_prediction_type(metric_id: str): + metric = fetch_artifact(metric_id)[0] + return metric.get_prediction_type() + def check_metrics_type(self) -> None: prediction_type = parse_type_string(self.prediction_type) - for metric_name in self.metrics: - metric = fetch_artifact(metric_name)[0] - metric_prediction_type = metric.get_prediction_type() + for metric_id in self.metrics: + metric_prediction_type = FormTask.get_metric_prediction_type(metric_id) if ( prediction_type == metric_prediction_type @@ -93,7 +99,7 @@ def check_metrics_type(self) -> None: continue raise ValueError( - f"The task's prediction type ({prediction_type}) and '{metric_name}' " + f"The task's prediction type ({prediction_type}) and '{metric_id}' " f"metric's prediction type ({metric_prediction_type}) are different." )