SSIM has values larger than 1 #2327
Comments
Hi! Thanks for your contribution, great first issue!
Hi @TuanDTr, thanks for reporting this issue.
@SkafteNicki Thank you for your quick response. Please find below the methods for `forward`, the training step, and the validation steps, where I initialize the metrics. Basically, the metrics are initialized in `on_validation_model_eval`:

```python
def training_step(self, batch: Union[Tuple, torch.Tensor], batch_idx: int) -> torch.Tensor:
    if self.use_profiler:
        self.profiler.step()
    x = batch["t1c"]
    z = self.get_latent_code(x)
    z_cond = []
    for m in self.hparams.cond_modality:
        x_cond = batch[m]
        z_cond.append(self.get_latent_code(x_cond))
    z_cond = torch.cat(z_cond, dim=1)
    noise = torch.randn_like(z).to(self.device)
    timesteps = torch.randint(0, self.hparams.num_train_steps, (z.shape[0],), device=self.device).long()
    noisy_z = self._scheduler.add_noise(original_samples=z, noise=noise, timesteps=timesteps)
    noise_pred = self._unet(torch.cat((noisy_z, z_cond), dim=1), timesteps=timesteps)
    loss = self.criterion(noise_pred, noise)
    self.log("train/noise_recons_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
    return loss

@torch.inference_mode()
def forward(self, z_cond):
    self._ema.ema_model.eval()
    z_dim = [z_cond.shape[0], self._unet.out_channels, *z_cond.shape[2:]]
    z = torch.randn(z_dim, device=self.device)
    self._scheduler.set_timesteps(num_inference_steps=self.hparams.num_inference_steps)
    for t in range(self.hparams.num_inference_steps):
        model_output = self._ema.ema_model(
            torch.cat((z, z_cond), dim=1),
            timesteps=torch.Tensor((t,)).to(self.device).long()
        )
        z, _ = self._scheduler.step(model_output, t, z)
    x = self.decode_from_latent_code(z)
    return x, z

def on_validation_model_eval(self) -> None:
    """Prepare before validation."""
    self.metrics = {
        "PSNR": PeakSignalNoiseRatio(data_range=None).to(self.device),
        "SSIM": StructuralSimilarityIndexMeasure(data_range=None).to(self.device),
        "MAE_image": MeanAbsoluteError().to(self.device),
        "MAE_latent": MeanAbsoluteError().to(self.device),
    }
    super().on_validation_model_eval()

def validation_step(self, batch: Union[Tuple, torch.Tensor], batch_idx: int) -> None:
    x = batch["t1c"]
    if not self.hparams.preloaded_latent:
        z = self.get_latent_code(x)
    else:
        z = batch["latent_t1c"]
    z_cond = []
    for m in self.hparams.cond_modality:
        if not self.hparams.preloaded_latent:
            x_cond = batch[m]
            z_cond.append(self.get_latent_code(x_cond))
        else:
            z_cond.append(batch[f"latent_{m}"])
    z_cond = torch.cat(z_cond, dim=1)
    preds, latents = self.forward(z_cond)
    # Compute scores
    self.metrics["PSNR"].update(preds, x)
    self.metrics["SSIM"].update(preds, x)
    self.metrics["MAE_image"].update(preds, x)
    self.metrics["MAE_latent"].update(latents, z)
    # Inverse transform
    inverse_transform = BatchInverseTransform(self.val_dataloader().dataset.transforms, self.val_dataloader())
    with allow_missing_keys_mode(self.val_dataloader().dataset.transforms):
        preds = inverse_transform({"latent_t1c": preds})
    self.save_to_h5_dataset(preds)

def on_validation_epoch_end(self) -> None:
    psnr = self.metrics["PSNR"].compute()
    ssim = self.metrics["SSIM"].compute()
    mae_image = self.metrics["MAE_image"].compute()
    mae_latent = self.metrics["MAE_latent"].compute()
    self.log("val/psnr", psnr, on_epoch=True, logger=True, prog_bar=True)
    self.log("val/ssim", ssim, on_epoch=True, logger=True, prog_bar=True)
    self.log("val/mae_image", mae_image, on_epoch=True, logger=True, prog_bar=True)
    self.log("val/mae_latent", mae_latent, on_epoch=True, logger=True, prog_bar=True)
    self.metrics["PSNR"].reset()
    self.metrics["SSIM"].reset()
    self.metrics["MAE_image"].reset()
    self.metrics["MAE_latent"].reset()
```
Hi @TuanDTr,

```python
def on_validation_epoch_end(self) -> None:
    ssim = self.metrics["SSIM"].compute()
    if ssim > 1:
        torch.save(self.metrics["SSIM"].metric_state, "ssim_state.pt")
```
Hi @SkafteNicki, here is the state when the error happens:

I'll further scale all input tensors to the same range and see if this still occurs, and will follow up with you.
@TuanDTr I have been trying to further debug the issue on my end and I am still unable to reproduce the problem. From the output you sent, it is very clear to me that …
Hi @SkafteNicki, I have tried rescaling the inputs to the range [0, 1] (see below) but I still encounter SSIM > 1. I am clamping the output in `forward`:

```python
def forward(self, z_cond):
    ...
    return x.clamp(0, 1), z
```
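As an aside, clamping only the prediction still leaves the two tensors on potentially different scales. A min-max rescale of each tensor into a common [0, 1] range can be sketched as follows; this is a plain-Python illustration of the idea, not code from this project (`minmax_rescale` is a hypothetical helper):

```python
def minmax_rescale(values):
    """Min-max normalize a sequence of floats into [0, 1]."""
    lo, hi = min(values), max(values)
    if hi == lo:
        # Constant input: map everything to 0.0 to avoid division by zero.
        return [0.0 for _ in values]
    return [(v - lo) / (hi - lo) for v in values]

# A synthetic image in [-1, 1] and a target in [0, 1] end up on the same scale:
print(minmax_rescale([-1.0, 0.0, 1.0]))   # [0.0, 0.5, 1.0]
print(minmax_rescale([0.0, 0.25, 1.0]))   # [0.0, 0.25, 1.0]
```

After rescaling both tensors this way, `data_range=1.0` can be passed to the metric explicitly, so it no longer has to infer the range from each batch.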
@SkafteNicki Hello, and I am sorry for the late update. I might have an idea why SSIM is larger than 1. I inspected my evaluation script and found that the tensors were in float16. If I change them to float32, I get the correct results. However, I cannot reproduce this issue outside my training script. I think my setup, which uses mixed-precision training, could have something to do with this. Do you have any idea how to inspect this further? Thanks!
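The float16 explanation is plausible: half precision keeps only about three significant decimal digits, so intermediate statistics (means, variances, covariances in the SSIM formula) accumulated at that precision need not satisfy the exact-arithmetic inequality that caps SSIM at 1, and a ratio that should be exactly 1 can tip slightly above it. A stdlib-only sketch of the precision gap, using `struct`'s IEEE 754 half-precision format code `'e'` (illustrative, not from this thread):

```python
import struct

def roundtrip(fmt, x):
    """Round a Python float through an IEEE 754 format ('e' = float16, 'f' = float32)."""
    return struct.unpack(fmt, struct.pack(fmt, x))[0]

x = 0.2
fp16_err = abs(roundtrip('e', x) - x)  # ~5e-5: only ~3 decimal digits survive
fp32_err = abs(roundtrip('f', x) - x)  # ~3e-9: float32 is ~4 orders of magnitude tighter
print(fp16_err, fp32_err)
```

With errors of this size in both the numerator and the denominator, a quotient near 1.0 can easily land a few 1e-4 above it, which matches seeing SSIM values slightly greater than 1 only under mixed precision.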
🐛 Bug
It seems like SSIM can have values larger than 1 when computed over an epoch. I cannot reproduce this error; I only observe it in TensorBoard after training.
Environment
Additional context
I am comparing a synthetic and a target image. The target image has values between 0 and 1, while the synthetic image has values between -1 and 1. The data_range is set to None. I know the ranges of the synthetic and target images differ, but shouldn't SSIM still produce a maximum score of 1?
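For reference, the per-window SSIM of Wang et al. is

```latex
\mathrm{SSIM}(x, y) =
  \frac{(2\mu_x\mu_y + C_1)\,(2\sigma_{xy} + C_2)}
       {(\mu_x^2 + \mu_y^2 + C_1)\,(\sigma_x^2 + \sigma_y^2 + C_2)}
```

and since $(\mu_x - \mu_y)^2 \ge 0$ gives $2\mu_x\mu_y \le \mu_x^2 + \mu_y^2$, and likewise $2\sigma_{xy} \le \sigma_x^2 + \sigma_y^2$, each numerator factor is at most the corresponding denominator factor whenever $C_1, C_2 \ge 0$. So in exact arithmetic SSIM is bounded by 1 regardless of how the data range (and hence $C_1$, $C_2$) is chosen; values above 1 point to numerical error in computing the statistics rather than to the range mismatch itself.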