From 88e68af97f25c39add359724158791cf7cd40ed0 Mon Sep 17 00:00:00 2001 From: cfli <545999961@qq.com> Date: Tue, 26 Mar 2024 04:50:58 +0000 Subject: [PATCH] add_device_set --- FlagEmbedding/flag_reranker.py | 87 ++++++++++++++++++++++------------ 1 file changed, 58 insertions(+), 29 deletions(-) diff --git a/FlagEmbedding/flag_reranker.py b/FlagEmbedding/flag_reranker.py index 994454d8..21b478e0 100644 --- a/FlagEmbedding/flag_reranker.py +++ b/FlagEmbedding/flag_reranker.py @@ -151,21 +151,30 @@ def __init__( self, model_name_or_path: str = None, use_fp16: bool = False, - cache_dir: str = None + cache_dir: str = None, + device: Union[str, int] = None ) -> None: self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, cache_dir=cache_dir) self.model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, cache_dir=cache_dir) - if torch.cuda.is_available(): - self.device = torch.device("cuda") - elif torch.backends.mps.is_available(): - self.device = torch.device("mps") - elif is_torch_npu_available(): - self.device = torch.device("npu") + if device and isinstance(device, str): + self.device = torch.device(device) + if device == 'cpu': + use_fp16 = False else: - self.device = torch.device("cpu") - use_fp16 = False + if torch.cuda.is_available(): + if device is not None: + self.device = torch.device(f"cuda:{device}") + else: + self.device = torch.device("cuda") + elif torch.backends.mps.is_available(): + self.device = torch.device("mps") + elif is_torch_npu_available(): + self.device = torch.device("npu") + else: + self.device = torch.device("cpu") + use_fp16 = False if use_fp16: self.model.half() @@ -173,10 +182,13 @@ def __init__( self.model.eval() - self.num_gpus = torch.cuda.device_count() - if self.num_gpus > 1: - print(f"----------using {self.num_gpus}*GPUs----------") - self.model = torch.nn.DataParallel(self.model) + if device is None: + self.num_gpus = torch.cuda.device_count() + if self.num_gpus > 1: + print(f"----------using {self.num_gpus}*GPUs----------") + self.model = torch.nn.DataParallel(self.model) + else: + self.num_gpus = 1 @torch.no_grad() def compute_score(self, sentence_pairs: Union[List[Tuple[str, str]], Tuple[str, str]], batch_size: int = 256, @@ -218,7 +230,7 @@ def __init__( use_fp16: bool = False, use_bf16: bool = False, cache_dir: str = None, - device: int = 0 + device: Union[str, int] = None ) -> None: self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, cache_dir=cache_dir, @@ -231,14 +243,21 @@ def __init__( self.model_name_or_path = model_name_or_path self.cache_dir = cache_dir - if torch.cuda.is_available(): - torch.cuda.set_device(device) - self.device = torch.device('cuda') - elif torch.backends.mps.is_available(): - self.device = torch.device('mps') + if device and isinstance(device, str): + self.device = torch.device(device) else: - self.device = torch.device('cpu') - use_fp16 = False + device = 0 if device is None else device + if torch.cuda.is_available(): + torch.cuda.set_device(device) + self.device = torch.device("cuda") + elif torch.backends.mps.is_available(): + self.device = torch.device("mps") + elif is_torch_npu_available(): + self.device = torch.device("npu") + else: + self.device = torch.device("cpu") + use_fp16 = False + if use_fp16 and use_bf16 is False: self.model.half() @@ -311,7 +330,7 @@ def __init__( use_fp16: bool = False, use_bf16: bool = False, cache_dir: str = None, - device: int = 0 + device: Union[str, int] = None ) -> None: self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, cache_dir=cache_dir, @@ -329,14 +348,24 @@ def __init__( self.model_name_or_path = model_name_or_path self.cache_dir = cache_dir - if torch.cuda.is_available(): - torch.cuda.set_device(device) - self.device = torch.device('cuda') - elif torch.backends.mps.is_available(): - self.device = torch.device('mps') + if device and isinstance(device, str): + if device == 'cpu': + warnings.warn('The LLM-based layer-wise reranker does not support CPU; it has been set to CUDA.') + device = 'cuda' + self.device = torch.device(device) else: - self.device = torch.device('cpu') - use_fp16 = False + device = 0 if device is None else device + if torch.cuda.is_available(): + torch.cuda.set_device(device) + self.device = torch.device("cuda") + elif torch.backends.mps.is_available(): + self.device = torch.device("mps") + elif is_torch_npu_available(): + self.device = torch.device("npu") + else: + self.device = torch.device("cpu") + use_fp16 = False + if use_fp16 and use_bf16 is False: self.model.half()