Is `spectral_norm` supported by amp? I've recently been trying to get amp working on a GAN, but so far without luck. The forward pass through the discriminator fails with:
```
Traceback (most recent call last):
  File "train.py", line 14, in <module>
    trainer.train()
  File "/private/home/tldevries/projects/self-attention-GAN-pytorch/trainer_fp16.py", line 177, in train
    d_out_real = self.D(real_images + inst_noise, real_labels)
  File "/private/home/tldevries/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/private/home/tldevries/projects/self-attention-GAN-pytorch/sagan_models.py", line 240, in forward
    h0 = self.opt_block1(x) # n x d_conv_dim x 64 x 64
  File "/private/home/tldevries/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/private/home/tldevries/projects/self-attention-GAN-pytorch/sagan_models.py", line 173, in forward
    x = self.snconv2d1(x)
  File "/private/home/tldevries/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 485, in __call__
    hook(self, input)
  File "/private/home/tldevries/miniconda3/lib/python3.7/site-packages/torch/nn/utils/spectral_norm.py", line 100, in __call__
    setattr(module, self.name, self.compute_weight(module, do_power_iteration=module.training))
  File "/private/home/tldevries/miniconda3/lib/python3.7/site-packages/torch/nn/utils/spectral_norm.py", line 86, in compute_weight
    sigma = torch.dot(u, torch.mv(weight_mat, v))
RuntimeError: Expected object of scalar type Float but got scalar type Half for argument #2 'tensor'
```
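The failing line computes `sigma = torch.dot(u, torch.mv(weight_mat, v))`. One plausible reading, not confirmed in this report, is that under amp the `torch.mv` result comes back in half precision while the power-iteration buffer `u` is still float32, and `torch.dot` does not promote mismatched dtypes. The sketch below reproduces that mismatch with small stand-in tensors (the shapes, and the explicit `.half()` standing in for amp's cast, are assumptions for illustration) and shows that casting to a common dtype makes the dot product succeed:

```python
import torch

torch.manual_seed(0)

# Stand-ins for the spectral_norm state: a weight matrix and the
# power-iteration vectors u and v, kept in float32.
weight_mat = torch.randn(4, 3)
u = torch.randn(4)
v = torch.randn(3)

# Assumption: under amp the matrix-vector product is returned in half
# precision. Here we simulate that with an explicit cast.
wv_half = torch.mv(weight_mat, v).half()

# torch.dot requires both vectors to have the same dtype, so mixing
# float32 and float16 raises a RuntimeError like the one in the traceback.
try:
    torch.dot(u, wv_half)
    mismatch_raised = False
except RuntimeError as err:
    mismatch_raised = True
    print("RuntimeError:", err)

# Casting both operands to one dtype (float32 here) avoids the error.
sigma = torch.dot(u, wv_half.float())
print("sigma =", sigma.item())
```

This only illustrates the dtype mismatch; whether the right fix is keeping the spectral-norm computation in float32 or changing how amp handles `torch.dot` is a separate question.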