mv adapt implementation from gitlab; upgrade to 3D #13

Merged
pzhanggit merged 2 commits into ORNL:main from pzhanggit:cleanup-adapttoken
Jan 12, 2026

Conversation

@pzhanggit
Collaborator

This is the long-delayed move of the adaptive tokenization implementation from the GitLab repo to here:

  • move adaptive tokens from the data pipeline to the model forward pass
  • support 3D and multiple levels (not tested yet)
  • clean up the code

To test the code, run sbatch submit_batch_tests.sh and sbatch submit_batch_adapt.sh on Frontier.

forward_time = forward_end - model_start
if torch.isnan(loss) or not torch.isfinite(loss):
    print(f"NaN detected in loss at batch {batch_idx}. Skipping batch...")
    continue
Collaborator


I'm wondering how this skip would affect the loss when gradient accumulation is turned on. I think it might lead to an incorrect loss calculation for the samples near the skipped ones, but I'm not sure exactly how it would behave.

Collaborator Author


Good point, it will break gradient accumulation. Let me disable the skip when gradient accumulation is enabled.

for idim, dim in enumerate(space_dims):
    ntokendim.append(dim//ps[idim])
num_tokens = reduce(mul, ntokendim)
#T,B,C,D,H,W-->T,B,C,ntz,ntx,nty,psz,psx,psy->B,C,ntz,ntx,nty->B,ntz,nty,nty
Collaborator


Last dimensions in the comment should be B,ntz,ntx,nty?

variance = xdata.unfold(2,ps[0],ps[0]).unfold(3,ps[1],ps[1]).unfold(4,ps[2],ps[2]).var(dim=(0,-3,-2,-1)).mean(dim=0)
assert ntokendim==list(variance.shape)
variance = rearrange(variance, 'ntz ntx nty -> (ntz ntx nty)')
#T,B,C,D,H,W-->T,B,C,ntz,ntx,nty,psz,psx,psy->B,C,ntz,ntx,nty->B,ntz,nty,nty
Collaborator


last dimension: B,ntz,ntx,nty?

#mask2d_padding = repeat(mask_padding, 'b c1 len -> b (len1 c1) len', len1=len).unsqueeze(1)
valid = mask_padding.squeeze(1).to(torch.bool) #(B, L),True=meaningful
mask2d_padding = valid[:, None, None, :] #(B,1,1,L)
#without this backend, ran into RuntimeError: Function 'ScaledDotProductEfficientAttentionBackward0' returned nan values in its 0th output.
Collaborator


Seems like we could use this SDPA backend selector to select a flash attention backend instead of using flash_attn_func, if we want. It's not clear whether they are exactly the same, though.
Also, F.scaled_dot_product_attention is called on line 268 as well, in the forward of Attention2DBlock, so maybe add the backend there too for consistency?

Collaborator Author


We could, but this is a temporary fix. It might be related to the PyTorch version (pytorch/pytorch#119320). We will come back and test it later.

Re line 268: we do not want to change it. We still want the more efficient implementations as long as they do not break the runs.

@pzhanggit pzhanggit requested a review from TsChala January 12, 2026 20:42