From 2c12a2674de85dae73489d55e433c5dad579167b Mon Sep 17 00:00:00 2001 From: Timothee Cour Date: Tue, 1 Oct 2019 10:05:48 -0700 Subject: [PATCH] fix https://github.com/facebookresearch/maskrcnn-benchmark/issues/802 --- apex/amp/lists/torch_overrides.py | 11 +++++++---- apex/amp/utils.py | 3 +++ 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/apex/amp/lists/torch_overrides.py b/apex/amp/lists/torch_overrides.py index 571745ebf..7dedb05a8 100644 --- a/apex/amp/lists/torch_overrides.py +++ b/apex/amp/lists/torch_overrides.py @@ -74,10 +74,13 @@ _bmms = ['addbmm', 'baddbmm', 'bmm'] -if utils.get_cuda_version() >= (9, 1, 0): - FP16_FUNCS.extend(_bmms) -else: - FP32_FUNCS.extend(_bmms) + +if utils.is_cuda_enabled(): + # workaround https://github.com/facebookresearch/maskrcnn-benchmark/issues/802 + if utils.get_cuda_version() >= (9, 1, 0): + FP16_FUNCS.extend(_bmms) + else: + FP32_FUNCS.extend(_bmms) # Multi-tensor fns that may need type promotion CASTS = [ diff --git a/apex/amp/utils.py b/apex/amp/utils.py index e76e28958..0590cd70a 100644 --- a/apex/amp/utils.py +++ b/apex/amp/utils.py @@ -5,6 +5,9 @@ import torch +def is_cuda_enabled(): + return torch.version.cuda is not None + def get_cuda_version(): return tuple(int(x) for x in torch.version.cuda.split('.'))