
Some questions about different losses #25

Closed
BaohaoLiao opened this issue Feb 8, 2023 · 3 comments


BaohaoLiao commented Feb 8, 2023

Hi,
thank you for releasing such clean code. My questions are based on the modified, ICLR-accepted version of the paper.

I have some questions about the different losses. If I have misunderstood your code, I would really appreciate a correction.

  1. It seems you calculate the cross-entropy loss over both x and y rather than just y, which differs from your paper: "Note that although in the first term we only compute the loss w.r.t y0, due to the attention mechanism in the transformer, the reconstruction of y0 also takes x0 into account, thus the gradients from the first term will also affect the learning of x0."

In the code, your final loss uses decoder_nll rather than terms["nll"], and the calculation of decoder_nll doesn't use input_ids_mask (see the masked-loss sketch after this list).

  2. A related question to 1: it seems the rounding loss isn't used in the final loss. You calculate the rounding loss as:

    terms["nll"] = self._token_discrete_loss(model_out_x_start, get_logits, input_ids_x, mask=input_ids_mask, truncate=True, t=t)

    but the final loss is:

    decoder_nll = self._token_discrete_loss(x_start, get_logits, input_ids_x)
    terms["loss"] = terms["mse"] + decoder_nll + tT_loss

    so the rounding loss terms["nll"] doesn't appear in the final loss.

  3. Why do you calculate decoder_nll at all? The input to self._token_discrete_loss for decoder_nll is x_start, which is a noisy word embedding (the word embedding with Gaussian noise added), so it should already be very close to the word embedding itself.

  4. Why don't you learn sigma? The improved DDPM paper reports that a learnable sigma is beneficial.
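
For reference, here is a minimal, self-contained sketch (not the repository's code; the function name masked_token_nll and the tensor shapes are my own assumptions) of what a y-only token-level cross-entropy would look like if input_ids_mask were applied, assuming the mask is 1 on y positions and 0 on x positions:

import torch
import torch.nn.functional as F

def masked_token_nll(logits, input_ids, input_ids_mask):
    # logits: (batch, seq_len, vocab) rounding logits, e.g. get_logits(predicted x_0)
    # input_ids: (batch, seq_len) gold token ids
    # input_ids_mask: (batch, seq_len), 1 on y positions, 0 on x positions
    nll = F.cross_entropy(
        logits.transpose(1, 2),  # cross_entropy expects (batch, vocab, seq_len)
        input_ids,
        reduction="none",
    )  # (batch, seq_len) per-token loss
    mask = input_ids_mask.float()
    # average only over the y positions of each sample
    return (nll * mask).sum(dim=-1) / mask.sum(dim=-1).clamp(min=1)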

BaohaoLiao changed the title from "Some questions about the implementation details" to "Some questions about different losses" on Feb 8, 2023
BaohaoLiao reopened this on Feb 10, 2023
@summmeer (Collaborator) commented

Hi,

  1. "we only compute the loss w.r.t y0" refers to the mse loss, the x part is transformed to the regularization term.
  2. "nll" loss is for tracing the generation quality
  3. you can regard the "decoder_nll" as a regularization term of the embedding vectors.
  4. we didn't try this setting, you're free to try it.
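
For what it's worth, one way to read this answer against the loss line quoted in the question is the annotated summary below; the comments are my interpretation, not the repository's documentation:

terms["loss"] = terms["mse"] + decoder_nll + tT_loss
# terms["mse"]: reconstruction term; per point 1, the y part is the reconstruction target and the x part enters as a regularization/anchoring term
# decoder_nll: rounding cross-entropy computed from x_start; per point 3, a regularizer on the embedding vectors
# tT_loss: keeps the final-step distribution q(x_T | x_0) close to the Gaussian prior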

@BaohaoLiao (Author) commented

I see. Thank you!


swave-demo commented Jun 3, 2024

Hi,

  1. "we only compute the loss w.r.t y0" refers to the mse loss, the x part is transformed to the regularization term.
  2. "nll" loss is for tracing the generation quality
  3. you can regard the "decoder_nll" as a regularization term of the embedding vectors.
  4. we didn't try this setting, you're free to try it.

According to my understanding, the calculation of the MSE loss also involves the x part:

target = x_start
model_output = model(x_t, self._scale_timesteps(t), **model_kwargs)
terms["mse"] = mean_flat((target - model_output) ** 2)

model_out_x_start = self._x0_helper(model_output, x_t, t)['pred_xstart'] # predicted_xstart = model_output
t0_mask = (t == 0)
t0_loss = mean_flat((x_start_mean - model_out_x_start) ** 2)
terms["mse"] = th.where(t0_mask, t0_loss, terms["mse"])

Since predict_xstart is True in config.json, model_output is actually the estimated x_start. You directly calculate the MSE loss between x_start and model_output without applying input_ids_mask, so the x part is also involved in the MSE loss. I print the result with the following code:

print((target[0] - model_output[0]) ** 2, input_ids_mask[0]) # print the first sentence of one batch

The output is:

tensor([[2.0462e+00, 3.8795e-01, 3.2121e-03,  ..., 2.4803e-01, 1.7676e-01,
         4.3906e-01],
        [4.9620e+00, 5.4831e+00, 1.3603e+00,  ..., 4.6070e+00, 3.6652e+00,
         4.0038e-01],
        [3.2369e-03, 8.0437e-01, 3.5606e-01,  ..., 1.2198e-01, 5.0738e-01,
         6.2278e-03],
        ...,
        [3.9527e-01, 1.4870e+00, 6.3621e+00,  ..., 2.2396e-02, 6.8237e-03,
         1.6679e-01],
        [6.9479e-01, 3.2860e-02, 7.0464e+00,  ..., 1.8594e-01, 3.4218e-01,
         7.5128e-03],
        [1.6733e+00, 4.8956e-01, 5.6478e+00,  ..., 3.3924e-02, 1.5616e-02,
         1.0451e-01]], device='cuda:0', grad_fn=<PowBackward0>) tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0')

We can see that the loss on the x part (the positions where input_ids_mask is 0) is nonzero, which contradicts the paper. If I have misunderstood anything, I hope someone can correct me. Thanks a lot!
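
If it helps to quantify this, here is a small sketch (my own addition, reusing the target, model_output, and input_ids_mask tensors from the snippet above, and again assuming the mask is 1 on y positions and 0 on x positions) that splits the squared error into its x-part and y-part averages:

def split_mse_by_part(target, model_output, input_ids_mask):
    # element-wise squared error, (batch, seq_len, dim)
    sq_err = (target - model_output) ** 2
    # broadcast the token mask over the embedding dimension
    mask = input_ids_mask.unsqueeze(-1).float()
    dim = sq_err.size(-1)
    # mean squared error over y positions and over x positions separately
    y_mse = (sq_err * mask).sum() / (mask.sum() * dim).clamp(min=1)
    x_mse = (sq_err * (1 - mask)).sum() / ((1 - mask).sum() * dim).clamp(min=1)
    return x_mse, y_mse

x_mse, y_mse = split_mse_by_part(target, model_output, input_ids_mask)
print(x_mse.item(), y_mse.item())  # a nonzero x_mse confirms the x part contributes to terms["mse"]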
