I can't help but wonder why the NormalizingFlow class uses the flows' inverse method when computing forward_kl, whereas NormalizingFlowVAE uses the flows' forward method.
When I try to fit MNIST with NormalizingFlow and pass a training batch of images of shape (64, 784), I get the following error:
34 for i in range(len(self.flows) - 1, -1, -1):
35 z, log_det = self.flows[i].inverse(z)
---> 36 log_q += log_det
37 log_q += self.q0.log_prob(z)
38 return -torch.mean(log_q)
RuntimeError: output with shape [64] doesn't match the broadcast shape [1, 64]
Any help or suggestions?
In this package, flows are defined as maps from the latent space to the observation space. To compute the forward KL divergence, you have to map the observations to the latent space, i.e. apply the inverse.
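In symbols: if the layers f_1, ..., f_K map a latent z_0 to an observation x = f_K(...f_1(z_0)), the change-of-variables formula gives log p(x) = log q0(z_0) + Σ_k log|det J_k^{-1}|, where J_k^{-1} is the Jacobian of the inverse of layer k. So the forward KL, which up to a constant is the expected negative log-likelihood of the data, only needs the inverse maps. A minimal NumPy sketch, assuming a hypothetical elementwise affine layer `AffineFlow` and a standard normal base q0 (neither is the package's API); the loop mirrors the one in the traceback.

```python
import numpy as np


class AffineFlow:
    """Hypothetical elementwise affine layer z -> z * exp(log_scale) + shift.

    forward/inverse return (output, log_det) with one log|det| per sample,
    i.e. log_det has shape (batch,).
    """

    def __init__(self, dim, rng):
        # Parameters of shape (dim,) broadcast over the batch axis.
        self.log_scale = 0.1 * rng.standard_normal(dim)
        self.shift = 0.1 * rng.standard_normal(dim)

    def forward(self, z):
        # Latent -> observation direction.
        x = z * np.exp(self.log_scale) + self.shift
        log_det = np.full(z.shape[0], self.log_scale.sum())
        return x, log_det

    def inverse(self, x):
        # Observation -> latent direction; the log|det| of the inverse is negated.
        z = (x - self.shift) * np.exp(-self.log_scale)
        log_det = np.full(x.shape[0], -self.log_scale.sum())
        return z, log_det


def std_normal_log_prob(z):
    # Log-density of the standard normal base q0, summed over features.
    d = z.shape[1]
    return -0.5 * np.sum(z**2, axis=1) - 0.5 * d * np.log(2 * np.pi)


def forward_kl(flows, x):
    # Map observations back to the latent space (last layer first),
    # accumulating the log|det| terms, then add the base log-density.
    log_q = np.zeros(x.shape[0])
    z = x
    for i in range(len(flows) - 1, -1, -1):
        z, log_det = flows[i].inverse(z)
        log_q += log_det
    log_q += std_normal_log_prob(z)
    # Monte Carlo estimate of the forward KL, up to an additive constant.
    return -np.mean(log_q)


rng = np.random.default_rng(0)
flows = [AffineFlow(784, rng) for _ in range(3)]
x = rng.standard_normal((64, 784))
print(forward_kl(flows, x))
```

Because every layer returns a log_det of shape (batch,), the in-place `log_q += log_det` never has to change the shape of `log_q`.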
In a variational autoencoder, on the other hand, you sample from the latent space (the prior) and transform the samples with the flow layers, i.e. use the forward direction.
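In the sampling direction, the same change-of-variables formula runs forwards: with z_0 ~ q0 and z_k = f_k(z_{k-1}), log q(z_K) = log q0(z_0) - Σ_k log|det ∂f_k/∂z_{k-1}|. Layers used this way need a cheap forward pass and log-determinant, but not necessarily a closed-form inverse. The NumPy sketch below uses a planar layer, f(z) = z + u tanh(wᵀz + b), which has no closed-form inverse, as an example. `PlanarFlow` and `sample_with_log_q` are hypothetical names, not the package's API, and the small random initialization is assumed (not enforced) to satisfy the planar-flow invertibility condition wᵀu ≥ -1.

```python
import numpy as np


class PlanarFlow:
    """Hypothetical planar layer f(z) = z + u * tanh(w.z + b).

    Its log|det J| is cheap to evaluate, but it has no closed-form inverse,
    so it is only used in the forward (sampling) direction here.
    """

    def __init__(self, dim, rng):
        # Small initialization, assumed to keep w.u >= -1 (not enforced).
        self.u = 0.1 * rng.standard_normal(dim)
        self.w = 0.1 * rng.standard_normal(dim)
        self.b = 0.0

    def forward(self, z):
        a = z @ self.w + self.b              # shape (batch,)
        h = np.tanh(a)
        z_new = z + np.outer(h, self.u)      # shape (batch, dim)
        # J = I + u psi^T with psi = (1 - tanh(a)^2) w, so det J = 1 + psi.u
        det = 1.0 + (1.0 - h**2) * (self.w @ self.u)
        log_det = np.log(np.abs(det))        # shape (batch,)
        return z_new, log_det


def sample_with_log_q(flows, num_samples, dim, rng):
    # Draw from the standard normal base q0 and record its log-density.
    z = rng.standard_normal((num_samples, dim))
    log_q = -0.5 * np.sum(z**2, axis=1) - 0.5 * dim * np.log(2 * np.pi)
    # Push the samples through the layers in the forward direction; each
    # layer lowers the log-density by its log|det J|.
    for flow in flows:
        z, log_det = flow.forward(z)
        log_q -= log_det
    return z, log_q


rng = np.random.default_rng(0)
flows = [PlanarFlow(2, rng) for _ in range(8)]
samples, log_q = sample_with_log_q(flows, num_samples=64, dim=2, rng=rng)
print(samples.shape, log_q.shape)  # (64, 2) (64,)
```

Note the sign flip compared with the forward-KL direction: there the log-determinants of the inverses are added, here the log-determinants of the forward maps are subtracted.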
In any case, the error is not related to whether you use the forward or the inverse map. It is probably due to a misspecification in a flow layer, e.g. some of the parameters you use to initialize the flow layer do not have the correct shape.
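The traceback says that broadcasting log_q (shape [64]) against log_det yields shape [1, 64], so log_det carries an extra leading axis (for example shape [1, 64] or [1, 1]). An in-place `+=` cannot enlarge its left operand through broadcasting; PyTorch documents this rule for in-place tensor operations, and NumPy enforces the same rule with a ValueError, which makes it easy to reproduce without a model. A quick way to find the culprit is to run the inverse pass layer by layer and check the shapes. The NumPy sketch below uses a hypothetical `ElementwiseScale` layer and a hypothetical `check_inverse_shapes` helper; giving the parameter shape (1, 1, 784) is just one assumed way to produce a log_det of shape (1, 64), not necessarily what happens in the original code.

```python
import numpy as np


class ElementwiseScale:
    """Hypothetical elementwise scaling layer parameterized by log_scale."""

    def __init__(self, param_shape):
        self.log_scale = np.zeros(param_shape)

    def inverse(self, x):
        z = x * np.exp(-self.log_scale)
        # One log|det| per sample: sum the broadcast log-scales over features.
        log_det = -np.sum(np.broadcast_to(self.log_scale, z.shape), axis=-1)
        return z, log_det


def check_inverse_shapes(flows, x):
    """Run the inverse pass and report the first layer whose shapes go wrong."""
    batch = x.shape[0]
    z = x
    for i in range(len(flows) - 1, -1, -1):
        z, log_det = flows[i].inverse(z)
        if z.shape != x.shape or log_det.shape != (batch,):
            raise ValueError(
                f"layer {i}: z has shape {z.shape}, log_det has shape "
                f"{log_det.shape}; expected {x.shape} and {(batch,)}"
            )
    return True


x = np.ones((64, 784))
good = [ElementwiseScale((784,))]
bad = [ElementwiseScale((1, 1, 784))]  # extra leading axes: z becomes (1, 64, 784)

print(check_inverse_shapes(good, x))  # True
try:
    check_inverse_shapes(bad, x)
except ValueError as err:
    print(err)

# Reproduce the failing update from the traceback: an in-place add cannot
# grow log_q from shape (64,) to the broadcast shape (1, 64).
log_q = np.zeros(64)
_, log_det = bad[0].inverse(x)
try:
    log_q += log_det
except ValueError as err:
    print("in-place add failed:", err)
```

Running the same kind of check on the real model's layers, with the actual MNIST batch, should point at the layer whose parameters need reshaping.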
I'll close the issue for now, but if this does not resolve your problem, feel free to reopen it and add more details about what you are doing and the context in which the error occurs.