diff --git a/MinkowskiEngine/utils/quantization.py b/MinkowskiEngine/utils/quantization.py index a50c3caa..d9ff7a26 100644 --- a/MinkowskiEngine/utils/quantization.py +++ b/MinkowskiEngine/utils/quantization.py @@ -266,12 +266,12 @@ def sparse_quantize( else: discrete_coordinates = discrete_coordinates.int() - if device == "cpu": + if (type(device) == str and device == "cpu") or (type(device) == torch.device and device.type == "cpu"): manager = MEB.CoordinateMapManagerCPU() - elif "cuda" in device: + elif (type(device) == str and "cuda" in device) or (type(device) == torch.device and device.type == "cuda"): manager = MEB.CoordinateMapManagerGPU_c10() else: - raise ValueError("Invalid device. Only `cpu` or `cuda` supported.") + raise ValueError("Invalid device. Only `cpu`, `cuda` or torch.device supported.") # Return values accordingly if use_label: