diff --git a/projects/CySGAN/cysgan/trainer.py b/projects/CySGAN/cysgan/trainer.py index 53d0af1d..9ffd31dc 100755 --- a/projects/CySGAN/cysgan/trainer.py +++ b/projects/CySGAN/cysgan/trainer.py @@ -170,7 +170,7 @@ def train(self): real_seg = torch.cat(targetX, 1).to(self.device) # concatenate over channel dim fake_seg = self.seg_handler(fakeXseg if random.random() > 0.5 else recYseg) loss_Ds = self.update_netD(self.Ds, real_seg, self.image_pool['Ds'].query(fake_seg), - self.optimizer['Dy'], self.lr_scheduler['Dy']) + self.optimizer['Ds'], self.lr_scheduler['Ds']) loss_D = loss_Dx + loss_Dy + loss_Ds # discriminator losses self.iter_time = time.perf_counter() - self.start_time