diff --git a/delphi/sparse_coders/custom/gemmascope.py b/delphi/sparse_coders/custom/gemmascope.py index 27511d58..5db9ceb7 100644 --- a/delphi/sparse_coders/custom/gemmascope.py +++ b/delphi/sparse_coders/custom/gemmascope.py @@ -104,6 +104,8 @@ def from_pretrained(cls, model_name_or_path, position, device): pt_params = {k: torch.from_numpy(v) for k, v in params.items()} model = cls(params["W_enc"].shape[0], params["W_enc"].shape[1]) model.load_state_dict(pt_params) - if device == "cuda": + if device == "cuda" or ( + isinstance(device, torch.device) and device.type == "cuda" + ): model.cuda() return model