diff --git a/corenet/cli/main_eval_llmadapters.py b/corenet/cli/main_eval_llmadapters.py
index 9fdb315..be551b8 100644
--- a/corenet/cli/main_eval_llmadapters.py
+++ b/corenet/cli/main_eval_llmadapters.py
@@ -226,7 +226,7 @@ class CoreNetLMEvalWrapper(HFLM):
 
     def __init__(self, opts: argparse.Namespace) -> None:
-        device = "cuda" if torch.cuda.is_available() else "cpu"
+        device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
         hf_config = _CorenetToHFPretrainedConfig(**vars(opts))
         tokenizer_path = getattr(opts, f"text_tokenizer.sentence_piece.model_path")
         tokenizer_path = get_local_path(opts, tokenizer_path)
@@ -261,7 +261,7 @@ def main_eval_llmadapters(args: Optional[List[str]] = None) -> None:
     model_eval_wrapper = CoreNetLMEvalWrapper(opts)
     tasks = getattr(opts, "llmadapters_evaluation.datasets")
     dataset_dir = getattr(opts, "llmadapters_evaluation.dataset_dir")
-    device = "cuda" if torch.cuda.is_available() else "cpu"
+    device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
     results_loc = getattr(opts, "common.results_loc")
     limit = getattr(opts, "llmadapters_evaluation.limit")
     if limit is None:
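
Note: the new CUDA -> MPS -> CPU selection expression is duplicated verbatim at both call sites. A minimal sketch of factoring it into a shared helper (the name _resolve_device is hypothetical and not part of this patch; assumes torch >= 1.12, where torch.backends.mps.is_available() was introduced):

    import torch

    def _resolve_device() -> torch.device:
        # Prefer CUDA, then Apple-silicon MPS, then fall back to CPU.
        if torch.cuda.is_available():
            return torch.device("cuda")
        if torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")

Both call sites would then reduce to `device = _resolve_device()`, keeping the fallback order in one place if more backends are added later.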