diff --git a/monailabel/monaivista/lib/infers/vista_point_2pt5.py b/monailabel/monaivista/lib/infers/vista_point_2pt5.py
index 0a9cc1a..ee35286 100644
--- a/monailabel/monaivista/lib/infers/vista_point_2pt5.py
+++ b/monailabel/monaivista/lib/infers/vista_point_2pt5.py
@@ -64,7 +64,7 @@ def pre_transforms(self, data=None) -> Sequence[Callable]:
         ]

     def inferer(self, data=None) -> Inferer:
-        return VISTASliceInferer()
+        return VISTASliceInferer(device=data.get("device") if data else None)

     def inverse_transforms(self, data=None):
         return []
diff --git a/monailabel/monaivista/lib/model/vista_point_2pt5/inferer.py b/monailabel/monaivista/lib/model/vista_point_2pt5/inferer.py
index cd81678..77e52f7 100644
--- a/monailabel/monaivista/lib/model/vista_point_2pt5/inferer.py
+++ b/monailabel/monaivista/lib/model/vista_point_2pt5/inferer.py
@@ -238,14 +238,21 @@ def update_slice(
                 continue
             inputs = inputs_l[..., start_idx - (n_z_slices // 2) : start_idx + (n_z_slices // 2) + 1].permute(2, 0, 1)
+            if device and (device == "cuda" or isinstance(device, torch.device) and device.type == "cuda"):
+                inputs = inputs.cuda()
             data, unique_labels = prepare_sam_val_input(
-                inputs.cuda(), class_prompts, point_prompts, start_idx, original_affine
+                inputs, class_prompts, point_prompts, start_idx, original_affine, device=device
             )
             predictor.eval()
-            with torch.cuda.amp.autocast():
-                outputs = predictor(data)
-                logit = outputs[0]["high_res_logits"]
+            if device == "cuda" or (isinstance(device, torch.device) and device.type == "cuda"):
+                with torch.cuda.amp.autocast():
+                    outputs = predictor(data)
+                    logit = outputs[0]["high_res_logits"]
+            else:
+                with torch.cpu.amp.autocast():
+                    outputs = predictor(data)
+                    logit = outputs[0]["high_res_logits"]

             out_list = torch.unbind(logit, dim=0)
             y_pred = torch.stack(post_pred_slice(out_list)).float()
@@ -290,11 +297,15 @@ def iterate_all(
     )
     for start_idx in start_range:
         inputs = inputs_l[..., start_idx - n_z_slices // 2 : start_idx + n_z_slices // 2 + 1].permute(2, 0, 1)
-        data, unique_labels = prepare_sam_val_input(inputs.cuda(), class_prompts, point_prompts, start_idx)
+        if device == "cuda" or (isinstance(device, torch.device) and device.type == "cuda"):
+            inputs = inputs.cuda()
+        data, unique_labels = prepare_sam_val_input(inputs, class_prompts, point_prompts, start_idx, device=device)
         predictor = predictor.eval()
         with autocast():
             if cachedEmbedding:
-                curr_embedding = cachedEmbedding[start_idx].cuda()
+                curr_embedding = cachedEmbedding[start_idx]
+                if device == "cuda" or (isinstance(device, torch.device) and device.type == "cuda"):
+                    curr_embedding = curr_embedding.cuda()
                 outputs = predictor.get_mask_prediction(data, curr_embedding)
             else:
                 outputs = predictor(data)
diff --git a/monailabel/monaivista/lib/model/vista_point_2pt5/utils/utils.py b/monailabel/monaivista/lib/model/vista_point_2pt5/utils/utils.py
index cc83dc7..695aeee 100644
--- a/monailabel/monaivista/lib/model/vista_point_2pt5/utils/utils.py
+++ b/monailabel/monaivista/lib/model/vista_point_2pt5/utils/utils.py
@@ -80,14 +80,16 @@ def distributed_all_gather(
     return tensor_list_out


-def prepare_sam_val_input(inputs, class_prompts, point_prompts, start_idx, original_affine=None):
+def prepare_sam_val_input(inputs, class_prompts, point_prompts, start_idx, original_affine=None, device=None):
     # Don't exclude background in val but will ignore it in metric calculation
     H, W = inputs.shape[1:]
     foreground_all = point_prompts["foreground"]
     background_all = point_prompts["background"]
     class_list = [[i + 1] for i in class_prompts]
-    unique_labels = torch.tensor(class_list).long().cuda()
+    unique_labels = torch.tensor(class_list).long()
+    if device == "cuda" or (isinstance(device, torch.device) and device.type == "cuda"):
+        unique_labels = unique_labels.cuda()

     volume_point_coords = [cp for cp in foreground_all]
     volume_point_labels = [1] * len(foreground_all)
@@ -129,8 +131,11 @@ def prepare_sam_val_input(inputs, class_prompts, point_prompts, start_idx, origi
     prepared_input[0].update({"labels": unique_labels})

     if point_coords:
-        point_coords = torch.tensor(point_coords).long().cuda()
-        point_labels = torch.tensor(point_labels).long().cuda()
+        point_coords = torch.tensor(point_coords).long()
+        point_labels = torch.tensor(point_labels).long()
+        if device == "cuda" or (isinstance(device, torch.device) and device.type == "cuda"):
+            point_coords = point_coords.cuda()
+            point_labels = point_labels.cuda()
         prepared_input[0].update({"point_coords": point_coords, "point_labels": point_labels})