Skip to content

Commit

Permalink
Fix Loss to device
Browse files Browse the repository at this point in the history
  • Loading branch information
SCiarella committed Mar 17, 2024
1 parent ef478f1 commit 9abc481
Showing 1 changed file with 0 additions and 7 deletions.
7 changes: 0 additions & 7 deletions speckcn2/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,6 @@ def __init__(self, config: dict, nz: Normalizer, device: torch.device):
self.recover_tag = nz.recover_tag
# Move tensors to the device
self.h = self.h.to(self.device)
self.k = self.k.to(self.device)
self.cosz = torch.tensor(self.cosz, device=self.device)
self.secz = torch.tensor(self.secz, device=self.device)
self.L = torch.tensor(self.L, device=self.device)
self.p_fr = torch.tensor(self.p_fr, device=self.device)
self.p_iso = torch.tensor(self.p_iso, device=self.device)
self.p_scw = torch.tensor(self.p_scw, device=self.device)

def forward(self, pred: torch.Tensor,
target: torch.Tensor) -> tuple[torch.Tensor, dict]:
Expand Down

0 comments on commit 9abc481

Please sign in to comment.