From 02fc608bea4c0549b0a7b00ca1bf15dee4a0b228 Mon Sep 17 00:00:00 2001 From: Yiming Zhang <49868620+eamonn-zh@users.noreply.github.com> Date: Wed, 22 Mar 2023 12:19:25 -0700 Subject: [PATCH] add torch.device support for quantization (#533) --- MinkowskiEngine/utils/quantization.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/MinkowskiEngine/utils/quantization.py b/MinkowskiEngine/utils/quantization.py index a50c3caa..d9ff7a26 100644 --- a/MinkowskiEngine/utils/quantization.py +++ b/MinkowskiEngine/utils/quantization.py @@ -266,12 +266,12 @@ def sparse_quantize( else: discrete_coordinates = discrete_coordinates.int() - if device == "cpu": + if (type(device) == str and device == "cpu") or (type(device) == torch.device and device.type == "cpu"): manager = MEB.CoordinateMapManagerCPU() - elif "cuda" in device: + elif (type(device) == str and "cuda" in device) or (type(device) == torch.device and device.type == "cuda"): manager = MEB.CoordinateMapManagerGPU_c10() else: - raise ValueError("Invalid device. Only `cpu` or `cuda` supported.") + raise ValueError("Invalid device. Only `cpu`, `cuda` or torch.device supported.") # Return values accordingly if use_label: