
314 fix transformer training #318

Merged: 5 commits merged into main on Mar 20, 2023

Conversation

marksgraham (Collaborator)

Fixes #314

@Warvito self-requested a review March 17, 2023 21:04
@Warvito (Collaborator) left a comment


Hi Mark, thanks for working on this tutorial. During the review, I found a few things that might be wrong in the inferer (besides the ones pointed out in the review). I will try to investigate further.

generative/inferers/inferer.py (review thread, resolved)
generative/inferers/inferer.py (review thread, outdated, resolved)

  # if we have not covered the full sequence we continue with inefficient looping
- if probs.shape[1] < latent.shape[1]:
+ if logits.shape[1] < latent.shape[1]:
@Warvito (Collaborator)

I think there might be something wrong here, because the condition logits.shape[1] < latent.shape[1] will always be true, since the logits have size spatial_shape[0] * spatial_shape[1] and the latent will be that + 1 (the BOS token).

@marksgraham (Collaborator, Author)

Running the tests, I find the logits and the latents have the same shape unless transformer_model.max_seq_len < (spatial_shape[0] * spatial_shape[1]) + 1; that is, the logits also have shape (spatial_shape[0] * spatial_shape[1]) + 1.

@Warvito (Collaborator) Mar 17, 2023

Yes, but usually transformer.max_seq_len = (spatial_shape[0] * spatial_shape[1]). Here, are you considering cases where max_seq_len = (spatial_shape[0] * spatial_shape[1]) + 1 because we prepend the BOS token?

@marksgraham (Collaborator, Author)

Yeah, I've always been setting max_seq_len = (spatial_shape[0] * spatial_shape[1])+1 in my networks. Have you been doing it without the +1? In all the tests for the VQVAETransformerInferer it is set to (spatial_shape[0] * spatial_shape[1])+1
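
To make the shape arithmetic above concrete, here is a minimal sketch (with illustrative variable names only, not the actual VQVAETransformerInferer code) of why the logits.shape[1] < latent.shape[1] branch is only taken when max_seq_len is set without the +1 for the BOS token:

```python
# Minimal sketch of the shape relationship discussed above. Names are
# illustrative; this is not the VQVAETransformerInferer implementation.
spatial_shape = (8, 8)
num_latent_tokens = spatial_shape[0] * spatial_shape[1]

# A BOS token is prepended to the flattened latent, so the sequence handed to
# the transformer has num_latent_tokens + 1 positions.
latent_len = num_latent_tokens + 1


def logits_len(max_seq_len: int) -> int:
    # The transformer only sees up to max_seq_len tokens, so its logits cover
    # min(latent_len, max_seq_len) positions.
    return min(latent_len, max_seq_len)


# With max_seq_len = num_latent_tokens + 1 (as in the inferer tests), logits
# and latent have the same length and the "inefficient looping" branch is
# skipped.
assert logits_len(num_latent_tokens + 1) == latent_len

# With max_seq_len = num_latent_tokens (no +1 for the BOS token), the logits
# are one position shorter than the latent and the branch is always taken.
assert logits_len(num_latent_tokens) < latent_len
```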

@marksgraham merged commit 78fde33 into main on Mar 20, 2023
@Warvito deleted the 314_fix_transformer_training branch on March 20, 2023 at 22:43