predicted bpp is Nan when training #92
Comments
Hi. The warning mentions the input tensor. Could the input data used for training be corrupt? It is hard to help without information on the input motion, residuals, and training data. |
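One way to rule out corrupt input data is to validate each batch before it reaches the network. The sketch below is only an illustration: `check_batch` is a hypothetical helper, not part of CompressAI, and the expected range of [0, 1] is an assumption that matches images scaled by 1/255.

```python
import torch


def check_batch(name, tensor, low=0.0, high=1.0):
    """Return a list of problems found in one input tensor (hypothetical helper)."""
    problems = []
    if not torch.isfinite(tensor).all():
        # NaN or Inf anywhere in the batch
        problems.append(f"{name}: contains NaN or Inf")
    elif tensor.min() < low or tensor.max() > high:
        # Finite, but outside the range assumed for normalized images
        problems.append(f"{name}: values outside [{low}, {high}]")
    return problems
```

Calling it on both `input_image` and `ref_image` at the top of each training iteration, and printing the batch index when the returned list is non-empty, would show whether a specific sample is at fault.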
Thank you for your answer. I will try my best to describe some information, but I don't know if it is useful.
def __getitem__(self, index):
    input_image = imageio.imread(self.image_input_list[index])
    ref_image = imageio.imread(self.image_ref_list[index])
    input_image = input_image.astype(np.float32) / 255.0
    ref_image = ref_image.astype(np.float32) / 255.0
    input_image = input_image.transpose(2, 0, 1)
    ref_image = ref_image.transpose(2, 0, 1)
    input_image = torch.from_numpy(input_image).float()
    ref_image = torch.from_numpy(ref_image).float()
    input_image, ref_image = random_crop_and_pad_image_and_labels(input_image, ref_image, [self.im_height, self.im_width])
    input_image, ref_image = random_flip(input_image, ref_image)
    return input_image, ref_image
#About input motion and residuals
def forward(self, input_image, referframe):
    estmv = self.opticFlow(input_image, referframe)
    mv_fea = self.mvEncoder(estmv)
    mv_prior = self.mvpriorEncoder(mv_fea)
    quant_mvprior, mvprior_likelihoods = self.entropy_hyper_mv(mv_prior)
    recon_mv_sigma = self.mvpriorDecoder(quant_mvprior)
    quant_mv = self.entropy_bottleneck_mv.quantize(mv_fea, "noise" if self.training else "dequantize")
    _, mv_likelihoods = self.entropy_bottleneck_mv(mv_fea, recon_mv_sigma)
    recon_mv = self.mvDecoder(quant_mv)
    prediction, warpframe = self.motioncompensation(referframe, recon_mv)
    res = input_image - prediction
    res_fea = self.resEncoder(res)
    res_prior = self.respriorEncoder(res_fea)
    quant_resprior, resprior_likelihoods = self.entropy_hyper_res(res_prior)
    recon_res_sigma = self.respriorDecoder(quant_resprior)
    quant_res = self.entropy_bottleneck_res.quantize(res_fea, "noise" if self.training else "dequantize")
    _, res_likelihoods = self.entropy_bottleneck_res(res_fea, recon_res_sigma)
    recon_res = self.resDecoder(quant_res)
    recon_image = prediction + recon_res
    clipped_recon_image = recon_image.clamp(0., 1.)
    mse_loss = torch.mean((recon_image - input_image).pow(2))
    warploss = torch.mean((warpframe - input_image).pow(2))
    interloss = torch.mean((prediction - input_image).pow(2))
    im_shape = input_image.size()
    batch_size = res_fea.size()[0]
    bpp_mv = torch.log(mv_likelihoods).sum() / (-math.log(2) * batch_size * im_shape[2] * im_shape[3])
    bpp_mvprior = torch.log(mvprior_likelihoods).sum() / (-math.log(2) * batch_size * im_shape[2] * im_shape[3])
    bpp_res = torch.log(res_likelihoods).sum() / (-math.log(2) * batch_size * im_shape[2] * im_shape[3])
    bpp_resprior = torch.log(resprior_likelihoods).sum() / (-math.log(2) * batch_size * im_shape[2] * im_shape[3])
    bpp = bpp_mv + bpp_mvprior + bpp_res + bpp_resprior
    return clipped_recon_image, mse_loss, warploss, interloss, bpp
#About training
for batch_idx, input in enumerate(train_loader):
    global_step += 1
    bat_cnt += 1
    input_image, ref_image = Var(input[0]), Var(input[1])
    clipped_recon_image, mse_loss, warploss, interloss, bpp = net(input_image, ref_image)
    mse_loss, warploss, interloss, bpp = torch.mean(mse_loss), torch.mean(warploss), torch.mean(interloss), torch.mean(bpp)
    distribution_loss = bpp
    distortion = mse_loss + warp_weight * (warploss + interloss)
    rd_loss = train_lambda * distortion + distribution_loss
    optimizer.zero_grad()
    aux_optimizer.zero_grad()
    rd_loss.backward()
    def clip_gradient(optimizer, grad_clip):
        for group in optimizer.param_groups:
            for param in group["params"]:
                if param.grad is not None:
                    param.grad.data.clamp_(-grad_clip, grad_clip)
    clip_gradient(optimizer, 0.5)
    optimizer.step()
    aux_loss = net.aux_loss()
    aux_loss.backward()
    aux_optimizer.step()
#Additional information |
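The four bpp terms in the forward pass above share the same normalization, so they can be computed by one small function, which also makes it easy to test each term separately. This is a minimal sketch; `bits_per_pixel` is a hypothetical name, and `num_pixels` is assumed to be `batch_size * im_shape[2] * im_shape[3]` as in the code above.

```python
import math

import torch


def bits_per_pixel(likelihoods, num_pixels):
    """Estimated rate in bits per pixel: -sum(log2(likelihoods)) / num_pixels."""
    return torch.log(likelihoods).sum() / (-math.log(2) * num_pixels)


# Sanity check: four symbols, each with likelihood 0.5, cost one bit each.
example = bits_per_pixel(torch.full((1, 1, 2, 2), 0.5), num_pixels=4)
```

Note that a likelihood of exactly zero makes the logarithm infinite, so checking each of the four terms with `torch.isfinite` may help narrow down which branch (motion or residual, latent or hyper-latent) goes wrong first.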
Your last comment makes sense and is a good indication. I guess you can spot where the warnings first occur by printing more information and keeping the model as small and simple as possible. There seem to be three warnings at each problematic iteration. |
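As one way of printing more information, forward hooks can record which submodule first produces a non-finite output. The sketch below uses only standard PyTorch calls (`named_modules`, `register_forward_hook`, `torch.isfinite`); the helper name `add_nan_hooks` and the `report` list are hypothetical.

```python
import torch
import torch.nn as nn


def add_nan_hooks(model, report):
    """Register forward hooks that append a module's name to `report`
    whenever that module returns a tensor containing NaN or Inf."""
    handles = []
    for name, module in model.named_modules():
        # Bind `name` as a default argument so each hook keeps its own name.
        def hook(mod, inputs, output, name=name):
            outputs = output if isinstance(output, (tuple, list)) else (output,)
            for out in outputs:
                if torch.is_tensor(out) and not torch.isfinite(out).all():
                    report.append(name)
                    break
        handles.append(module.register_forward_hook(hook))
    return handles  # call handle.remove() on each to detach the hooks
```

Because hooks fire in execution order, the first entry in `report` after a forward pass names the earliest module whose output was non-finite. `torch.autograd.set_detect_anomaly(True)` is a complementary option for the backward pass, at the cost of slower training.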
Thanks for your advice! I will try to print more error information. If there are any new discoveries, I will update here. Thank you again! |
#Training data
#When the warnings happen first
def forward(self, input_image, referframe):
    estmv = self.opticFlow(input_image, referframe)
    mv_fea = self.mvEncoder(estmv)
    mv_prior = self.mvpriorEncoder(mv_fea)
    quant_mvprior, mvprior_likelihoods = self.entropy_hyper_mv(mv_prior)
    recon_mv_sigma = self.mvpriorDecoder(quant_mvprior)
    quant_mv = self.entropy_bottleneck_mv.quantize(mv_fea, "noise" if self.training else "dequantize")
    _, mv_likelihoods = self.entropy_bottleneck_mv(mv_fea, recon_mv_sigma)
    recon_mv = self.mvDecoder(quant_mv)
    prediction, warpframe = self.motioncompensation(referframe, recon_mv)
ResidualBlockWithStride and ResidualBlock are from CompressAI. The structure of my self.mvEncoder is as follows:
class mvAnalysis(nn.Module):
    def __init__(self):
        super(mvAnalysis, self).__init__()
        self.RB1 = ResidualBlockWithStride(2, out_channel, stride=2)
        self.RB2 = ResidualBlock(out_channel, out_channel)
        self.RB3 = ResidualBlockWithStride(out_channel, out_channel, stride=2)
        self.RB4 = ResidualBlock(out_channel, out_channel)
        self.RB5 = ResidualBlockWithStride(out_channel, out_channel, stride=2)
        self.conv = conv3x3(out_channel, out_channel, stride=2)
    def forward(self, x):
        x = self.RB1(x)
        x = self.RB2(x)
        x = self.RB3(x)
        x = self.RB4(x)
        x = self.RB5(x)
        out = self.conv(x)
        return out
#Additional information
In an experiment, while keeping the learning rate of aux_optimizer at 0.001, I set the learning rate of optimizer to 0. Even so, NaN still occurred. I don’t know if my training settings are wrong.
distribution_loss = bpp
distortion = mse_loss + warp_weight * (warploss + interloss)
rd_loss = train_lambda * distortion + distribution_loss
optimizer.zero_grad()
aux_optimizer.zero_grad()
rd_loss.backward()
def clip_gradient(optimizer, grad_clip):
    for group in optimizer.param_groups:
        for param in group["params"]:
            if param.grad is not None:
                param.grad.data.clamp_(-grad_clip, grad_clip)
clip_gradient(optimizer, 0.5)
optimizer.step()
aux_loss = net.aux_loss()
aux_loss.backward()
aux_optimizer.step() |
Not sure it's related, but have you tried disabling clip_gradient, or using torch.nn.utils.clip_grad_norm_ instead? |
Hello. When I disable gradient clipping, there are still cases where the predicted bpp is wrong. When I used torch.nn.utils.clip_grad_norm_ as an alternative, the situation improved. I have now reduced the learning rate and adopted torch.nn.utils.clip_grad_norm_, and training is working normally for the time being. I think exploding gradients may have caused the tensor to become NaN, and the previous gradient clipping may not have handled them well. Anyway, thanks for your help! I think CompressAI works well without problems. |
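The two clipping styles discussed here behave differently on a large gradient, which a toy example can show. This sketch uses a single made-up parameter with a hand-set gradient; the values and the threshold of 0.5 are chosen only for illustration.

```python
import torch

# Toy parameter with a hand-set gradient, to compare the two clipping styles.
w = torch.nn.Parameter(torch.zeros(2))

# Element-wise clamp, as in clip_gradient above: each component is cut
# independently, so the direction of the gradient changes.
w.grad = torch.tensor([30.0, 40.0])
w.grad.clamp_(-0.5, 0.5)
clamped = w.grad.clone()      # tensor([0.5, 0.5])

# Norm-based clipping: the whole gradient is rescaled so that its L2 norm
# is at most max_norm, which keeps the direction.
w.grad = torch.tensor([30.0, 40.0])
total_norm = torch.nn.utils.clip_grad_norm_([w], max_norm=0.5)
rescaled = w.grad.clone()     # approximately tensor([0.3, 0.4])
```

`clip_grad_norm_` returns the total norm measured before clipping (50 here), so logging that return value during training is a cheap way to see whether gradients are in fact exploding.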
OK, thanks for the feedback. I am going to close this, since it relates to a side use case and does not break image compression. Feel free to post in the Discussions section to get additional help from other users. |
Bug
Hello, I built a video compressor with CompressAI. Based on a simple hybrid coding framework, the video compressor uses a hyper-prior entropy model to compress motion and residuals separately. But during training, the predicted bpp randomly becomes NaN.
Error
Expected behavior
The hyper-prior entropy model should not predict bpp as NaN during training.
Environment
Additional context
I don't know why it appears, and I can't predict when it will appear. If I load a normal checkpoint and resume training, it may not appear. If I complete the whole training process in several interrupted runs, the entropy model also runs normally. Could you please provide me with some help?
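Since resuming from a good checkpoint can get past the problem, a common workaround (not something CompressAI provides, and not a fix for the root cause) is to skip the update whenever the loss is non-finite, so that one bad batch does not corrupt the weights. The sketch below assumes a single optimizer; `safe_step` and its `max_norm` default are hypothetical.

```python
import torch


def safe_step(loss, optimizer, parameters, max_norm=1.0):
    """Back-propagate and step only if the loss is finite (hypothetical helper).

    Returns True if the optimizer step was taken, False if it was skipped.
    """
    optimizer.zero_grad()
    if not torch.isfinite(loss):
        return False  # leave the weights untouched for this batch
    loss.backward()
    torch.nn.utils.clip_grad_norm_(parameters, max_norm)
    optimizer.step()
    return True
```

Counting how often it returns False, together with the batch index, would also give a record of when the NaN appears.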