diff --git a/deeptrack/noises.py b/deeptrack/noises.py index eaaa5afe1..cce4bc123 100644 --- a/deeptrack/noises.py +++ b/deeptrack/noises.py @@ -217,7 +217,11 @@ def get( # For a Torch backend. elif self.get_backend() == "torch": - noisy_image = mu + image + torch.randn(*image.shape) * sigma + noisy_image = ( + mu + + image + + torch.randn(*image.shape, device=image.device) * sigma + ) return noisy_image @@ -300,8 +304,8 @@ def get( # For a Torch backend. elif self.get_backend() == "torch": - real_noise = torch.randn(*image.shape) - imag_noise = torch.randn(*image.shape) * 1j + real_noise = torch.randn(*image.shape, device=image.device) + imag_noise = torch.randn(*image.shape, device=image.device) * 1j noisy_image = mu + image + (real_noise + imag_noise) * sigma return noisy_image