314 fix transformer training #318
Conversation
Hi Mark, thanks for working on this tutorial. During the review, I found a few things that might be wrong in the inferer (besides the ones pointed out in the review). I will try to investigate further.
# if we have not covered the full sequence we continue with inefficient looping
-    if probs.shape[1] < latent.shape[1]:
+    if logits.shape[1] < latent.shape[1]:
I think something might be wrong here, because logits.shape[1] < latent.shape[1] will always be true: the logits have size spatial_shape[0] * spatial_shape[1], while the latent has that size + 1 (for the BOS token).
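To make the shape arithmetic concrete, here is a minimal sketch of the two cases under discussion. The values (spatial_shape = (8, 8), a vocabulary of 16) are assumptions for illustration only; this is not the inferer's actual code.

import torch

spatial_shape = (8, 8)
seq_len = spatial_shape[0] * spatial_shape[1]            # 64 quantized tokens
latent = torch.zeros(1, seq_len + 1, dtype=torch.long)   # 65: tokens plus the BOS token

# Case 1: max_seq_len == seq_len, so the transformer only returns 64 positions
logits_short = torch.zeros(1, seq_len, 16)
print(logits_short.shape[1] < latent.shape[1])            # True -> looping branch is taken

# Case 2: max_seq_len == seq_len + 1 (BOS included), full coverage in one pass
logits_full = torch.zeros(1, seq_len + 1, 16)
print(logits_full.shape[1] < latent.shape[1])             # False -> no extra looping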
Running the tests, I find the logits and the latent have the same shape unless transformer_model.max_seq_len < (spatial_shape[0] * spatial_shape[1]) + 1; that is, the logits also have shape (spatial_shape[0] * spatial_shape[1]) + 1.
Yes, but usually transformer.max_seq_len = spatial_shape[0] * spatial_shape[1]. Are you considering cases where max_seq_len = (spatial_shape[0] * spatial_shape[1]) + 1 because we prepend the BOS token?
Yeah, I've always been setting max_seq_len = (spatial_shape[0] * spatial_shape[1]) + 1 in my networks. Have you been doing it without the +1? In all the tests for the VQVAETransformerInferer it is set to (spatial_shape[0] * spatial_shape[1]) + 1.
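For reference, that convention looks like this (illustrative values only, not copied from the actual test file):

spatial_shape = (8, 8)
max_seq_len = spatial_shape[0] * spatial_shape[1] + 1  # 65: all latent tokens plus the BOS token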
Fixes #314