From 7fbfb1be43531db8dc5e47ac9ddd3df824058469 Mon Sep 17 00:00:00 2001 From: ain-soph Date: Mon, 29 Nov 2021 23:55:37 -0500 Subject: [PATCH] fix loss_weights dtype issue --- trojanzoo/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trojanzoo/models.py b/trojanzoo/models.py index 088c8bc9..1cfe416e 100644 --- a/trojanzoo/models.py +++ b/trojanzoo/models.py @@ -179,7 +179,7 @@ def __init__(self, name: str = 'model', self.layer_name_list: list[str] = None # ------------------------------ # - self.criterion = self.define_criterion(weight=to_tensor(loss_weights)) + self.criterion = self.define_criterion(weight=to_tensor(loss_weights, dtype=torch.float)) self.criterion_noreduction = self.define_criterion( weight=to_tensor(loss_weights), reduction='none') if isinstance(model, type):