
About Rate Distortion Loss #199

Closed
Indraa145 opened this issue Feb 5, 2023 · 4 comments
Comments

@Indraa145

Hello, thank you for the work. I'd like to ask about the difference between the Rate Distortion Loss formula in your custom training documentation and the one in your RateDistortionLoss class.

In your custom training documentation, the Rate Distortion Loss is defined as:

$$L=D+\lambda*R$$
```python
import math

import torch
import torch.nn.functional as F

x = torch.rand(1, 3, 64, 64)
net = Network()
x_hat, y_likelihoods = net(x)

# bitrate of the quantized latent
N, _, H, W = x.size()
num_pixels = N * H * W
bpp_loss = torch.log(y_likelihoods).sum() / (-math.log(2) * num_pixels)

# mean square error
mse_loss = F.mse_loss(x, x_hat)

# final loss term (lmbda: rate-distortion tradeoff hyperparameter, defined elsewhere)
loss = mse_loss + lmbda * bpp_loss
```
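The bpp formula above can be sanity-checked on synthetic likelihoods. This is a hedged illustration only; the latent shape (192 channels, 16x spatial downsampling) is an assumption for the example, not taken from the documentation:

```python
import math

import torch

# If every latent element has likelihood 0.5, each costs exactly 1 bit,
# so the total bit count equals the number of latent elements.
N, H, W = 1, 64, 64
num_pixels = N * H * W
y_likelihoods = torch.full((N, 192, H // 16, W // 16), 0.5)  # assumed latent shape
bpp_loss = torch.log(y_likelihoods).sum() / (-math.log(2) * num_pixels)
# 192 * 4 * 4 = 3072 bits over 64 * 64 = 4096 pixels -> 0.75 bpp
```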

Meanwhile, in your RateDistortionLoss class, which is used in examples/train.py, it is:

$$L=\lambda*255^2*D+R$$
```python
N, _, H, W = target.size()
out = {}
num_pixels = N * H * W

out["bpp_loss"] = sum(
    (torch.log(likelihoods).sum() / (-math.log(2) * num_pixels))
    for likelihoods in output["likelihoods"].values()
)
out["mse_loss"] = self.mse(output["x_hat"], target)
out["loss"] = self.lmbda * 255**2 * out["mse_loss"] + out["bpp_loss"]

return out
```

I also noticed a difference in the bpp_loss calculation. In the RateDistortionLoss class, you sum over all the likelihoods. I'd like to know why this is the case: are you summing the bpp_loss across all the batches?

I'm wondering which loss is better to use, and is there a paper I can refer to regarding this? Thank you very much.

@YodaEmbedding
Contributor

YodaEmbedding commented Feb 5, 2023

  1. The bpp_loss calculation in both code samples is exactly the same for a model that only outputs $y$ and has total rate $R_y$ (e.g. the bmshj2018-factorized model). However, the second code sample will also correctly calculate $R_y + R_z$ if the model additionally outputs $z$. Note that output["likelihoods"].values() == (y_likelihoods, z_likelihoods) in that case.

  2. Both code samples estimate the average rate over the batches. This is what $R$ typically refers to -- the average or expected value of the log likelihoods over the entire sample space of images $\mathcal{X} = \{x_1, x_2, \ldots \}$ and their corresponding latents $\mathcal{Y} = \{g_a(x) : x \in \mathcal{X}\} = \{y_1, y_2, \ldots \}$,

    $$R_y = H(Y) = \mathbb{E}_{y \in \mathcal{Y}}[-\log p(y)]$$

    In this case, we are taking a 16-sample (i.e. batch size) Monte Carlo estimate of the entropy.

  3. Optimizing a "16-sample average" $R,D$ is better than a "1-sample" $R,D$ since our goal is good average performance over a test dataset such as Kodak. In practice, I believe there isn't actually that much of a difference with SGD since it effectively tends to optimize for the average loss anyways. If we were to strongly optimize for the individual R-D performance via e.g. single sample R-D with an exaggerated loss such as $L = (R + D)^2$ for $L > 1$, the average R-D performance over the dataset might suffer. Or perhaps not.


Figure: Single sample R-D points (blue) and their average (orange) over the Kodak dataset.
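The first two points above can be illustrated together on synthetic likelihoods. This is a hedged sketch; the tensor shapes and the dict keys "y" and "z" are assumptions for the example:

```python
import math

import torch

# A hyperprior-style model returns likelihoods for both y and z, so the
# generator expression sums R_y + R_z. Dividing by N * H * W also averages
# the rate over the N images in the batch.
N, H, W = 16, 64, 64
num_pixels = N * H * W  # includes the batch dimension N
likelihoods = {
    "y": torch.full((N, 192, H // 16, W // 16), 0.5),   # 1 bit per element
    "z": torch.full((N, 128, H // 64, W // 64), 0.25),  # 2 bits per element
}
bpp_loss = sum(
    torch.log(l).sum() / (-math.log(2) * num_pixels)
    for l in likelihoods.values()
)
# per image: y = 3072 bits, z = 256 bits; 3328 bits / 4096 pixels = 0.8125 bpp,
# independent of N, since the batch dimension cancels in the division
```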

@Indraa145
Author

Thank you for the answer, but I'm still a bit confused by the different formulas for the final loss term. In the custom training documentation, it's defined as:

$$L=D+\lambda*R$$

While in the RateDistortionLoss class, it's defined as:

$$L=\lambda*255^2*D+R$$

Why is $\lambda$ multiplied by $255^2$ and $D$ here, as opposed to being multiplied by $R$ as in the custom training documentation example?

@YodaEmbedding
Contributor

YodaEmbedding commented Feb 6, 2023

It's just a scaling constant. It has no tangible effect, as long as $\lambda$ is accordingly rescaled by the same amount.

Why 255? An 8-bit pixel has intensities in the interval $[0, 255]$, where $255 = 2^8 - 1$. The input image is "renormalized" so that $x \in [0, 1]$ instead. Then,

$$\lambda 255^2 D = \lambda 255^2 \text{MSE}(x, \hat{x}) = \lambda \text{MSE}(255 x, 255 \hat{x}) = \lambda D'$$

where $D'$ is a more typical "distortion" value that compression engineers may be used to.

It's not necessary to include $255$, though the $\lambda$ values given in the documentation assume the $255$ is there.
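A quick numeric check of the identity above, since MSE is quadratic in its inputs (the $\lambda$ value here is arbitrary, purely for illustration):

```python
import torch
import torch.nn.functional as F

# Scaling both inputs of MSE by 255 is the same as scaling the MSE by 255**2.
torch.manual_seed(0)
x = torch.rand(1, 3, 8, 8)
x_hat = torch.rand(1, 3, 8, 8)
lmbda = 0.01  # arbitrary illustrative value

a = lmbda * 255**2 * F.mse_loss(x, x_hat)
b = lmbda * F.mse_loss(255 * x, 255 * x_hat)
assert torch.allclose(a, b)
```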

@Indraa145
Author

Ah, I see. Thank you for the explanation.
