From eaecb3d338e66c48609fec5d703c4f8d21f67fb7 Mon Sep 17 00:00:00 2001 From: Nathan Cassereau Date: Thu, 5 May 2022 17:23:07 +0200 Subject: [PATCH 1/2] Solve bug --- ot/backend.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/ot/backend.py b/ot/backend.py index 361ffba69..e4b48e176 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -1507,15 +1507,19 @@ class TorchBackend(Backend): def __init__(self): - self.rng_ = torch.Generator() + self.rng_ = torch.Generator("cpu") self.rng_.seed() self.__type_list__ = [torch.tensor(1, dtype=torch.float32), torch.tensor(1, dtype=torch.float64)] if torch.cuda.is_available(): + self.rng_cuda_ = torch.Generator("cuda") + self.rng_cuda_.seed() self.__type_list__.append(torch.tensor(1, dtype=torch.float32, device='cuda')) self.__type_list__.append(torch.tensor(1, dtype=torch.float64, device='cuda')) + else: + self.rng_cuda_ = torch.Generator("cpu") from torch.autograd import Function @@ -1761,20 +1765,26 @@ def reshape(self, a, shape): def seed(self, seed=None): if isinstance(seed, int): self.rng_.manual_seed(seed) + self.rng_cuda_.manual_seed(seed) elif isinstance(seed, torch.Generator): - self.rng_ = seed + if self.device_type(seed) == "GPU": + self.rng_cuda_ = seed + else: + self.rng_ = seed else: raise ValueError("Non compatible seed : {}".format(seed)) def rand(self, *size, type_as=None): if type_as is not None: - return torch.rand(size=size, generator=self.rng_, dtype=type_as.dtype, device=type_as.device) + generator = self.rng_cuda_ if self.device_type(type_as) == "GPU" else self.rng_ + return torch.rand(size=size, generator=generator, dtype=type_as.dtype, device=type_as.device) else: return torch.rand(size=size, generator=self.rng_) def randn(self, *size, type_as=None): if type_as is not None: - return torch.randn(size=size, dtype=type_as.dtype, generator=self.rng_, device=type_as.device) + generator = self.rng_cuda_ if self.device_type(type_as) == "GPU" else self.rng_ + return torch.randn(size=size, dtype=type_as.dtype, generator=generator, device=type_as.device) else: return torch.randn(size=size, generator=self.rng_) From 1f5b06e02f4bf65085785e6bf6456c48623d439d Mon Sep 17 00:00:00 2001 From: Nathan Cassereau Date: Thu, 5 May 2022 17:40:05 +0200 Subject: [PATCH 2/2] Update release file --- RELEASES.md | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/RELEASES.md b/RELEASES.md index be2192eb7..a336c51e7 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -1,5 +1,16 @@ # Releases +## 0.8.3dev + +#### New features + +- + +#### Closed issues + +- Fixed an issue where we could not ask TorchBackend to place a random tensor on GPU + (Issue #371, PR #373) + ## 0.8.2