mv adapt implementation from gitlab; upgrade to 3D #13
Conversation
forward_time = forward_end-model_start
if torch.isnan(loss) or not torch.isfinite(loss):
    print(f"NaN detected in loss at batch {batch_idx}. Skipping batch...")
    continue
I'm wondering how this skip would affect the loss if gradient accumulation is turned on. I think it might lead to an incorrect loss calculation for the samples near the skipped ones, but I'm not sure how it would work exactly.
Good point, it will break gradient accumulation. Let me disable the skip when gradient accumulation is turned on.
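To make the interaction concrete, here is a minimal sketch of a training loop with the guard discussed above. The loop structure and names (`train_epoch`, `accum_steps`, the MSE loss) are illustrative, not the repo's actual API: with accumulation, each micro-batch loss is divided by `accum_steps`, so silently dropping one micro-batch leaves the accumulated gradient under-scaled and shifts which batches share an optimizer step.

```python
import torch

def train_epoch(model, optimizer, loader, accum_steps=1):
    # Hypothetical loop illustrating the fix: the NaN skip is only
    # allowed when no gradient accumulation is happening.
    optimizer.zero_grad()
    for batch_idx, (x, y) in enumerate(loader):
        loss = torch.nn.functional.mse_loss(model(x), y)
        if torch.isnan(loss) or not torch.isfinite(loss):
            if accum_steps == 1:
                print(f"NaN detected in loss at batch {batch_idx}. Skipping batch...")
                continue
            # Skipping here would leave the accumulated gradient scaled by
            # (k/accum_steps) for k < accum_steps micro-batches; fail loudly.
            raise RuntimeError("non-finite loss while accumulating gradients")
        (loss / accum_steps).backward()  # average over the accumulation window
        if (batch_idx + 1) % accum_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
```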
matey/data_utils/shared_utils.py
Outdated
for idim, dim in enumerate(space_dims):
    ntokendim.append(dim//ps[idim])
num_tokens=reduce(mul, ntokendim)
#T,B,C,D,H,W-->T,B,C,ntz,ntx,nty,psz,psx,psy->B,C,ntz,ntx,nty->B,ntz,nty,nty
Last dimensions in the comment should be B,ntz,ntx,nty?
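For reference, the token-count arithmetic in the snippet can be sketched in isolation as follows. The function name and signature are hypothetical; it assumes each spatial dimension divides evenly by its patch size, matching the integer division above.

```python
from functools import reduce
from operator import mul

def count_tokens(space_dims, patch_size):
    # Mirrors the ntokendim / num_tokens logic quoted above: each spatial
    # axis is split into non-overlapping patches, and the total token
    # count is the product of the per-axis counts.
    ntokendim = [dim // ps for dim, ps in zip(space_dims, patch_size)]
    num_tokens = reduce(mul, ntokendim)
    return ntokendim, num_tokens
```

For example, `count_tokens([8, 16, 16], (2, 4, 4))` returns `([4, 4, 4], 64)`.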
matey/data_utils/shared_utils.py
Outdated
variance = xdata.unfold(2,ps[0],ps[0]).unfold(3,ps[1],ps[1]).unfold(4,ps[2],ps[2]).var(dim=(0,-3,-2,-1)).mean(dim=0)
assert ntokendim==list(variance.shape)
variance = rearrange(variance, 'ntz ntx nty -> (ntz ntx nty)')
#T,B,C,D,H,W-->T,B,C,ntz,ntx,nty,psz,psx,psy->B,C,ntz,ntx,nty->B,ntz,nty,nty
last dimension: B,ntz,ntx,nty?
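The unfold/variance pipeline quoted above can be sketched as follows. This is an assumption-laden sketch, not the repo's code: it takes `xdata` as `(B, C, D, H, W)` (which is what makes the snippet's `assert ntokendim == list(variance.shape)` hold), and replaces `einops.rearrange` with `reshape(-1)` to keep the sketch dependency-free.

```python
import torch

def per_token_variance(xdata, ps):
    # Sketch (hypothetical name): xdata is assumed (B, C, D, H, W) and
    # ps = (psz, psx, psy) must divide the spatial dims evenly.
    v = (xdata.unfold(2, ps[0], ps[0])   # tile D into non-overlapping patches
              .unfold(3, ps[1], ps[1])   # tile H
              .unfold(4, ps[2], ps[2])   # -> (B, C, ntz, ntx, nty, psz, psx, psy)
              .var(dim=(0, -3, -2, -1))  # variance over batch + within-patch dims
              .mean(dim=0))              # average over channels -> (ntz, ntx, nty)
    return v.reshape(-1)                 # flatten to (ntz*ntx*nty,)
```

This yields one variance score per token, which is what an adaptive tokenizer can threshold on.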
#mask2d_padding = repeat(mask_padding, 'b c1 len -> b (len1 c1) len', len1=len).unsqueeze(1)
valid = mask_padding.squeeze(1).to(torch.bool) #(B, L), True=meaningful
mask2d_padding = valid[:, None, None, :] #(B,1,1,L)
#without this backend, ran into RuntimeError: Function 'ScaledDotProductEfficientAttentionBackward0' returned nan values in its 0th output.
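The mask construction above can be shown in isolation. This sketch (function name hypothetical) assumes `mask_padding` is `(B, 1, L)` with nonzero entries marking real tokens; the `(B, 1, 1, L)` result broadcasts against `(B, heads, L, L)` attention scores, masking out padded keys for every query.

```python
import torch

def padding_mask_2d(mask_padding):
    # mask_padding: (B, 1, L), nonzero = meaningful token (assumption).
    valid = mask_padding.squeeze(1).to(torch.bool)  # (B, L)
    return valid[:, None, None, :]                  # (B, 1, 1, L)
```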
Seems like we could use this SDPA backend selector to select a flash attention backend instead of using flash_attn_func if we want. It's not clear whether they are exactly the same, though.
Also, F.scaled_dot_product_attention is called on line 268 as well, in the forward of Attention2DBlock, so maybe add the backend there too for consistency?
We could, but this is a temporary fix. It might be related to the PyTorch version; see pytorch/pytorch#119320. We will come back and test it later.
Re line 268: we do not want to change it. We still want the more efficient implementations as long as they do not break the runs.
This PR is the long-delayed move of the adaptive tokenization implementation from the GitLab repo to here.
To test the code, run `sbatch submit_batch_tests.sh` and `sbatch submit_batch_adapt.sh` on Frontier.