KLD Weight #56

Open
abyildirim opened this issue Apr 16, 2022 · 5 comments

@abyildirim

Hi,

In the VAE paper (https://arxiv.org/pdf/1312.6114.pdf), the VAE loss function has no additional weight parameter for the KLD loss:

[image: vae_loss]

However, in the implementation of the Vanilla VAE model, the loss function is written as below:

loss = recons_loss + kld_weight * kld_loss

When I set "kld_weight" to 1 in my model, it could not learn how to reconstruct the images. If I understand correctly, the "kld_weight" reduces the effect of the KLD loss to balance it with the reconstruction loss. However, as I mentioned, it is not defined in the VAE paper. Could anyone please explain to me why this parameter is used and why it is set to 0.00025 by default?

@wonjunior

It is defined in equation 8 of the paper.

@abyildirim

In Equation 8, I see that the MSE loss is also scaled by N/M. However, only the KLD loss is scaled in the code. Shouldn't we scale both of them according to the equation, @wonjunior?


@dorazhang93

Hi,
I was also confused about the kld_weight here. But I think I found the proper interpretation in this paper, Beta-VAE.
Given that the reconstruction loss is averaged over pixels and the KLD loss is averaged over latent dimensions, the M here is the dimensionality of z and N is the dimensionality of input (for images, W*H). In this implementation, however, the KLD loss is summed over all latent dimensions, so kld_weight is actually 1/N = 1/4096 ≈ 0.00025.
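
To make the averaging argument concrete, here is a minimal sketch in plain Python. It only checks the arithmetic: a per-pixel mean MSE plus a summed KLD weighted by 1/N is the same as the unweighted sum of both terms divided by N. The shapes are assumptions (a single-channel 64x64 input, hence N = 4096, and a hypothetical latent size of 8), and the helper names are made up for illustration, not taken from the repository.

```python
import math
import random

def kld_sum(mu, logvar):
    """KL(N(mu, exp(logvar)) || N(0, 1)), summed over latent dimensions."""
    return -0.5 * sum(1 + lv - m * m - math.exp(lv) for m, lv in zip(mu, logvar))

def sum_squared_error(x, x_hat):
    """Squared error summed (not averaged) over all input elements."""
    return sum((a - b) ** 2 for a, b in zip(x, x_hat))

random.seed(0)
N = 64 * 64  # assumed input dimensionality (W * H, one channel)
M = 8        # hypothetical latent dimensionality

# Random stand-ins for an image, its reconstruction, and the encoder outputs.
x = [random.random() for _ in range(N)]
x_hat = [random.random() for _ in range(N)]
mu = [random.gauss(0, 1) for _ in range(M)]
logvar = [random.gauss(0, 0.5) for _ in range(M)]

sse = sum_squared_error(x, x_hat)
kld = kld_sum(mu, logvar)

# Form used in the code: mean MSE plus weighted, summed KLD.
kld_weight = 1 / N
loss_as_in_code = sse / N + kld_weight * kld

# Unweighted objective (summed error + summed KLD), rescaled by 1/N.
loss_unweighted_scaled = (sse + kld) / N

print(kld_weight)
print(math.isclose(loss_as_in_code, loss_unweighted_scaled))
```

Under these assumptions the two losses differ only by the constant factor 1/N, so they share the same minimizer; the weight compensates for averaging the reconstruction term rather than summing it.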

@angelusualle

angelusualle commented Sep 7, 2023

> N is the dimensionality of input (for images, W*H)

Couldn't it be W * H * channels? Another part of that doc says:

> over the individual pixels xn
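
For a sense of scale, the two candidate values of 1/N can be compared directly. This is only arithmetic under assumed shapes: 64x64 comes from the 1/4096 figure quoted above, and 3 channels is a guess for RGB data.

```python
# Assumed shapes: 64x64 images (from 1/4096 above); 3 channels is a guess for RGB.
W, H, C = 64, 64, 3

weight_pixels = 1 / (W * H)        # N counts pixels only
weight_elements = 1 / (W * H * C)  # N counts every pixel-channel element

print(f"1/(W*H)   = {weight_pixels:.6f}")
print(f"1/(W*H*C) = {weight_elements:.6f}")
```

The first value is about 0.000244 and the second about 0.000081, so the default of 0.00025 is numerically closer to 1/(W*H). That does not by itself settle which definition of N was intended.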

@bkkm78

bkkm78 commented Mar 30, 2024

This weight is needed when you use L2 loss as the reconstruction loss. L2 loss (aka MSE) means that you're assuming a Gaussian $p_{\theta}(x|z)$, for which you need to specify a $\sigma$ for the Gaussian distribution as a hyperparameter. This is where the relative weight between the reconstruction loss and the KL divergence comes from. If you instead assume a Bernoulli distribution and thus apply a (per pixel per channel) binary cross-entropy loss, this relative weight is not necessary.

You can refer to section 2.4.3 of Carl Doersch's tutorial on VAE for more details.
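
The $\sigma$ argument above can be written out as a short derivation. This is a sketch under stated assumptions: an isotropic Gaussian decoder $p_{\theta}(x|z) = \mathcal{N}(x; \hat{x}, \sigma^2 I)$ with fixed $\sigma$, where $\hat{x}$ (the decoder output), $D$ (the number of input elements), and $q_{\phi}$ (the encoder) are notation introduced here rather than taken from the thread.

$$-\log p_{\theta}(x|z) = \frac{1}{2\sigma^2}\lVert x - \hat{x}\rVert^2 + \frac{D}{2}\log(2\pi\sigma^2)$$

Using a single sample of $z$, the negative ELBO is then, up to a constant that does not depend on the model parameters,

$$\mathcal{L} = \frac{1}{2\sigma^2}\lVert x - \hat{x}\rVert^2 + D_{KL}\big(q_{\phi}(z|x)\,\|\,p(z)\big) + \text{const}$$

Multiplying by $2\sigma^2 / D$, which does not change the minimizer, gives a mean squared error plus a weighted KL term:

$$\frac{2\sigma^2}{D}\,\mathcal{L} = \frac{1}{D}\lVert x - \hat{x}\rVert^2 + \frac{2\sigma^2}{D}\, D_{KL}\big(q_{\phi}(z|x)\,\|\,p(z)\big) + \text{const}$$

Under these assumptions, a per-element mean MSE combined with a summed KLD corresponds to a weight of $2\sigma^2 / D$, so choosing `kld_weight` amounts to choosing $\sigma$.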
