Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion monailabel/monaivista/lib/infers/vista_point_2pt5.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def pre_transforms(self, data=None) -> Sequence[Callable]:
]

def inferer(self, data=None) -> Inferer:
return VISTASliceInferer()
return VISTASliceInferer(device=data.get("device") if data else None)

def inverse_transforms(self, data=None):
return []
Expand Down
23 changes: 17 additions & 6 deletions monailabel/monaivista/lib/model/vista_point_2pt5/inferer.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,14 +238,21 @@ def update_slice(
continue

inputs = inputs_l[..., start_idx - (n_z_slices // 2) : start_idx + (n_z_slices // 2) + 1].permute(2, 0, 1)
if device and (device == "cuda" or isinstance(device, torch.device) and device.type == "cuda"):
inputs = inputs.cuda()
data, unique_labels = prepare_sam_val_input(
inputs.cuda(), class_prompts, point_prompts, start_idx, original_affine
inputs, class_prompts, point_prompts, start_idx, original_affine, device=device
)

predictor.eval()
with torch.cuda.amp.autocast():
outputs = predictor(data)
logit = outputs[0]["high_res_logits"]
if device == "cuda" or (isinstance(device, torch.device) and device.type == "cuda"):
with torch.cuda.amp.autocast():
outputs = predictor(data)
logit = outputs[0]["high_res_logits"]
else:
with torch.cpu.amp.autocast():
outputs = predictor(data)
logit = outputs[0]["high_res_logits"]

out_list = torch.unbind(logit, dim=0)
y_pred = torch.stack(post_pred_slice(out_list)).float()
Expand Down Expand Up @@ -290,11 +297,15 @@ def iterate_all(
)
for start_idx in start_range:
inputs = inputs_l[..., start_idx - n_z_slices // 2 : start_idx + n_z_slices // 2 + 1].permute(2, 0, 1)
data, unique_labels = prepare_sam_val_input(inputs.cuda(), class_prompts, point_prompts, start_idx)
if device == "cuda" or (isinstance(device, torch.device) and device.type == "cuda"):
inputs = inputs.cuda()
data, unique_labels = prepare_sam_val_input(inputs, class_prompts, point_prompts, start_idx, device=device)
predictor = predictor.eval()
with autocast():
if cachedEmbedding:
curr_embedding = cachedEmbedding[start_idx].cuda()
curr_embedding = cachedEmbedding[start_idx]
if device == "cuda" or (isinstance(device, torch.device) and device.type == "cuda"):
curr_embedding = curr_embedding.cuda()
outputs = predictor.get_mask_prediction(data, curr_embedding)
else:
outputs = predictor(data)
Expand Down
13 changes: 9 additions & 4 deletions monailabel/monaivista/lib/model/vista_point_2pt5/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,14 +80,16 @@ def distributed_all_gather(
return tensor_list_out


def prepare_sam_val_input(inputs, class_prompts, point_prompts, start_idx, original_affine=None):
def prepare_sam_val_input(inputs, class_prompts, point_prompts, start_idx, original_affine=None, device=None):
# Don't exclude background in val but will ignore it in metric calculation
H, W = inputs.shape[1:]
foreground_all = point_prompts["foreground"]
background_all = point_prompts["background"]

class_list = [[i + 1] for i in class_prompts]
unique_labels = torch.tensor(class_list).long().cuda()
unique_labels = torch.tensor(class_list).long()
if device == "cuda" or (isinstance(device, torch.device) and device.type == "cuda"):
unique_labels = unique_labels.cuda()

volume_point_coords = [cp for cp in foreground_all]
volume_point_labels = [1] * len(foreground_all)
Expand Down Expand Up @@ -129,8 +131,11 @@ def prepare_sam_val_input(inputs, class_prompts, point_prompts, start_idx, origi
prepared_input[0].update({"labels": unique_labels})

if point_coords:
point_coords = torch.tensor(point_coords).long().cuda()
point_labels = torch.tensor(point_labels).long().cuda()
point_coords = torch.tensor(point_coords).long()
point_labels = torch.tensor(point_labels).long()
if device == "cuda" or (isinstance(device, torch.device) and device.type == "cuda"):
point_coords = point_coords.cuda()
point_labels = point_labels.cuda()

prepared_input[0].update({"point_coords": point_coords, "point_labels": point_labels})

Expand Down