Some questions about different losses #25
Hi,
I see. Thank you!
The calculation of the MSE loss also involves the x part, according to my understanding:

```python
target = x_start
model_output = model(x_t, self._scale_timesteps(t), **model_kwargs)
terms["mse"] = mean_flat((target - model_output) ** 2)

model_out_x_start = self._x0_helper(model_output, x_t, t)['pred_xstart']  # predicted_xstart = model_output
t0_mask = (t == 0)
t0_loss = mean_flat((x_start_mean - model_out_x_start) ** 2)
terms["mse"] = th.where(t0_mask, t0_loss, terms["mse"])
```

When I print the squared error and the mask for the first sentence of one batch:

```python
print((target[0] - model_output[0]) ** 2, input_ids_mask[0])  # print the first sentence of one batch
```

The output is:

We can see that the loss on the x part is nonzero, which contradicts your paper. If I have any misunderstanding, I hope someone can correct me. Thanks a lot!
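One way to make this check explicit is to average the per-position squared error separately over the source and target positions. A minimal sketch (the mask convention, 0 for the source (x) part and 1 for the target (y) part, is an assumption for illustration):

```python
import torch as th

def mse_by_part(x_start, model_output, input_ids_mask):
    """Per-part MSE check. Assumes input_ids_mask is 0 on source (x) positions and
    1 on target (y) positions, and embeddings have shape [batch, seq_len, dim]."""
    per_pos = ((x_start - model_output) ** 2).mean(dim=-1)   # [batch, seq_len]
    src = (input_ids_mask == 0)
    return per_pos[src].mean(), per_pos[~src].mean()         # (x-part MSE, y-part MSE)
```

If the first value is clearly nonzero, the x part is indeed contributing to terms["mse"].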
Hi,
Thank you for releasing such clean code. My questions are based on the modified ICLR-accepted version of the paper.
Here I have some questions about the different losses. If I have misunderstood your code, I would really appreciate a correction.
At the line linked here, your final loss uses decoder_nll rather than term["nll_loss"], and the calculation of decoder_nll doesn't use the input_mask.
You calculate the rounding loss as:

```python
terms["nll"] = self._token_discrete_loss(model_out_x_start, get_logits, input_ids_x, mask=input_ids_mask, truncate=True, t=t)
```

But the final loss is:

The rounding loss isn't included in the final loss.
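For context, a rounding loss of this kind is typically a token-level cross-entropy between vocabulary logits computed from the predicted x0 embeddings and the reference token ids. A minimal sketch of that idea (the signature and the masking convention are assumptions, not the repo's exact _token_discrete_loss):

```python
import torch as th
import torch.nn.functional as F

def token_discrete_loss_sketch(pred_x_start, get_logits, input_ids, mask=None):
    """Rounding loss sketch: cross-entropy between logits computed from predicted
    embeddings and the reference token ids.
    pred_x_start: [batch, seq_len, dim]; input_ids: [batch, seq_len];
    mask: optional [batch, seq_len], 1 on positions that should count toward the loss."""
    logits = get_logits(pred_x_start)                                           # [batch, seq_len, vocab]
    nll = F.cross_entropy(logits.transpose(1, 2), input_ids, reduction='none')  # [batch, seq_len]
    if mask is not None:
        nll = nll * mask
        return nll.sum(dim=-1) / mask.sum(dim=-1).clamp(min=1)
    return nll.mean(dim=-1)
```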
Why do you calculate decoder_nll? The input to self._token_discrete_loss for decoder_nll is x_start, which is a noisy word embedding (the word embedding with Gaussian noise added), so it should already be very close to the word embedding.
Why don't you learn sigma? The improved DDPM paper reports that a learnable sigma is beneficial.
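For reference, the learned-variance trick from Nichol & Dhariwal's improved DDPM has the network predict, per dimension, an interpolation coefficient between log(beta_t) and the posterior variance log(beta_tilde_t). A minimal sketch (names and shapes are assumptions for illustration, not this repo's code):

```python
import torch as th

def learned_log_variance(model_var_values, log_beta_t, log_beta_tilde_t):
    """Learned-range variance (improved DDPM): the network outputs v in [-1, 1] per
    dimension, which interpolates between log(beta_t) (upper bound) and
    log(beta_tilde_t) (lower bound) to give the reverse-process log-variance."""
    frac = (model_var_values + 1) / 2                      # map v from [-1, 1] to [0, 1]
    return frac * log_beta_t + (1 - frac) * log_beta_tilde_t

# Training then uses a hybrid objective L_simple + lambda * L_vlb, where the VLB term
# supervises the variance and gradients through it to the predicted mean are stopped.
```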