diff --git a/trojanzoo/models.py b/trojanzoo/models.py index 088c8bc9..1cfe416e 100644 --- a/trojanzoo/models.py +++ b/trojanzoo/models.py @@ -179,7 +179,7 @@ def __init__(self, name: str = 'model', self.layer_name_list: list[str] = None # ------------------------------ # - self.criterion = self.define_criterion(weight=to_tensor(loss_weights)) + self.criterion = self.define_criterion(weight=to_tensor(loss_weights, dtype=torch.float)) self.criterion_noreduction = self.define_criterion( weight=to_tensor(loss_weights), reduction='none') if isinstance(model, type):