From 9abc48140a16d4de04e1b433d68dcb8fc153f497 Mon Sep 17 00:00:00 2001 From: SCiarella Date: Sun, 17 Mar 2024 17:59:17 +0100 Subject: [PATCH] Fix Loss to device --- speckcn2/loss.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/speckcn2/loss.py b/speckcn2/loss.py index f8d8c58..599139f 100644 --- a/speckcn2/loss.py +++ b/speckcn2/loss.py @@ -61,13 +61,6 @@ def __init__(self, config: dict, nz: Normalizer, device: torch.device): self.recover_tag = nz.recover_tag # Move tensors to the device self.h = self.h.to(self.device) - self.k = self.k.to(self.device) - self.cosz = torch.tensor(self.cosz, device=self.device) - self.secz = torch.tensor(self.secz, device=self.device) - self.L = torch.tensor(self.L, device=self.device) - self.p_fr = torch.tensor(self.p_fr, device=self.device) - self.p_iso = torch.tensor(self.p_iso, device=self.device) - self.p_scw = torch.tensor(self.p_scw, device=self.device) def forward(self, pred: torch.Tensor, target: torch.Tensor) -> tuple[torch.Tensor, dict]: