import torch
import torch.nn as nn
import torch.nn.functional as F
+from torch.nn.utils.rnn import pad_sequence

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.layers import (
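For reference, `pad_sequence` (newly imported above) right-pads a list of variable-length tensors to a common length; a minimal illustration of the behavior relied on further below (values here are made up):

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Two samples with different numbers of (y, x) patch coordinates.
a = torch.arange(6).reshape(3, 2)   # 3 patches
b = torch.arange(10).reshape(5, 2)  # 5 patches

# Right-pad to the longest sample; padding value defaults to 0.
coords = pad_sequence([a, b], batch_first=True)
print(coords.shape)  # torch.Size([2, 5, 2])
```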
@@ -89,6 +90,7 @@ class NaFlexVitCfg:
    pos_embed_grid_size: Optional[Tuple[int, int]] = (16, 16)  # Grid size for position embedding initialization
    pos_embed_interp_mode: str = 'bicubic'  # Interpolation mode for position embedding resizing
    pos_embed_ar_preserving: bool = False  # Whether to preserve aspect ratio during position embedding interpolation
+    pos_embed_use_grid_sample: bool = False  # Whether to use grid_sample for NaFlex position embedding interpolation

    # Image processing
    dynamic_img_pad: bool = False  # Whether to enable dynamic padding for variable resolution
@@ -221,6 +223,7 @@ def __init__(
            pos_embed_grid_size: Optional[Tuple[int, int]] = (14, 14),
            pos_embed_interp_mode: str = 'bicubic',
            pos_embed_ar_preserving: bool = False,
+            pos_embed_use_grid_sample: bool = False,
            input_norm_layer: Optional[Type[nn.Module]] = None,
            proj_norm_layer: Union[bool, Optional[Type[nn.Module]]] = None,
            norm_layer: Optional[Type[nn.Module]] = None,
@@ -256,6 +259,7 @@ def __init__(
        self.num_reg_tokens = reg_tokens
        self.pos_embed_interp_mode = pos_embed_interp_mode
        self.pos_embed_ar_preserving = pos_embed_ar_preserving
+        self.pos_embed_use_grid_sample = pos_embed_use_grid_sample
        self.patch_size = to_2tuple(patch_size)
        self.in_chans = in_chans
        self.embed_dim = embed_dim
@@ -438,18 +442,6 @@ def _interp2d(size):
            )[:, :, :size[0], :size[1]].flatten(2).transpose(1, 2)
            return pos_embed_flat.to(dtype=x.dtype)

-        # FIXME leaving alternative code commented here for now for comparisons
-        # pos_embed_cache: Dict[Tuple[int, int], torch.Tensor] = {}
-        # for i, s in enumerate(naflex_grid_sizes):
-        #     if s in pos_embed_cache:
-        #         pos_embed_flat = pos_embed_cache[s]
-        #     else:
-        #         pos_embed_flat = _interp(s)
-        #         pos_embed_cache[s] = pos_embed_flat
-        #
-        #     seq_len = min(x.shape[1], pos_embed_flat.shape[1])
-        #     x[i, :seq_len] += pos_embed_flat[0, :seq_len]
-
        # Determine unique grid sizes to avoid duplicate interpolation
        size_to_indices: Dict[Tuple[int, int], List[int]] = {}
        for bi, k in enumerate(naflex_grid_sizes):
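For context, the retained non-grid_sample path buckets batch indices by unique grid size so each (H, W) is interpolated only once; a toy sketch of that bucketing pattern (example sizes are invented):

```python
from typing import Dict, List, Tuple

naflex_grid_sizes = [(8, 12), (10, 10), (8, 12), (6, 16)]

# Bucket batch indices by their (H, W) grid so each unique size is interpolated once.
size_to_indices: Dict[Tuple[int, int], List[int]] = {}
for bi, k in enumerate(naflex_grid_sizes):
    size_to_indices.setdefault(k, []).append(bi)

print(size_to_indices)  # {(8, 12): [0, 2], (10, 10): [1], (6, 16): [3]}
```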
@@ -467,6 +459,57 @@ def _interp2d(size):
                pos_embed_flat[:, :seq_len].expand(len(batch_indices), -1, -1)
            )

+    def _apply_learned_naflex_pos_embed_grid_sample(
+            self,
+            x: torch.Tensor,
+            naflex_grid_sizes: List[Tuple[int, int]],
+    ):
+        """ NaFlex 2D position embedding interpolation using F.grid_sample.
+
+        Based on a proposal by https://github.com/stas-sl
+        """
+        device = x.device
+        B, C = x.shape[0:2]
+
+        def _make_coords(h, w):
+            _y, _x = torch.meshgrid(
+                torch.arange(h, device=device),
+                torch.arange(w, device=device),
+                indexing='ij',
+            )
+            coord = torch.stack([_y.flatten(), _x.flatten()], dim=1)
+            return coord
+
+        coords = pad_sequence(
+            [_make_coords(h, w) for h, w in naflex_grid_sizes],
+            batch_first=True,
+        )
+        shapes = coords.amax(1) + 1
+        theta = torch.zeros(B, 2, 3, dtype=torch.float32, device=device)
+        if self.pos_embed_ar_preserving:
+            shape_max = shapes.amax()
+            grid_size = (shape_max, shape_max)
+            L = shapes.amax(1)
+            theta[:, 0, 0] = grid_size[1] / L  # scale x
+            theta[:, 1, 1] = grid_size[0] / L  # scale y
+        else:
+            grid_size = shapes.amax(0)
+            theta[:, 0, 0] = grid_size[1] / shapes[:, 1]  # scale x
+            theta[:, 1, 1] = grid_size[0] / shapes[:, 0]  # scale y
+        theta[:, 0, 2] = theta[:, 0, 0] - 1  # translate x
+        theta[:, 1, 2] = theta[:, 1, 1] - 1  # translate y
+        grid = F.affine_grid(theta, (B, C, *grid_size), align_corners=False)
+        pos_embed = F.grid_sample(
+            self.pos_embed.permute(0, 3, 1, 2).expand(B, -1, -1, -1).float(),
+            grid,
+            mode=self.pos_embed_interp_mode,
+            align_corners=False,
+            padding_mode='border',
+        ).to(dtype=x.dtype)
+        bi = torch.arange(B, device=device).unsqueeze(1).expand(-1, coords.shape[1])
+        # NOTE leave as '+=', do not change to .add_(...)
+        x += pos_embed[bi, :, coords[..., 0], coords[..., 1]]
+
    def _apply_learned_pos_embed(
            self,
            x: torch.Tensor,
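A minimal standalone sketch of what the affine `theta` above encodes, with made-up shapes: scaling by `W_max / w` and translating by `scale - 1` maps the top-left `h x w` block of the shared output grid onto the full embedding table, so a single batched `grid_sample` reproduces a per-sample `F.interpolate` resize (bilinear is used here for a clean comparison):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Learned position embedding table, stored as (1, H, W, C) as in the model above.
H0, W0, C = 16, 16, 8
pos_embed = torch.randn(1, H0, W0, C)

h, w = 7, 12           # target NaFlex grid for one sample
H_max, W_max = 14, 14  # shared output grid (max over the batch)

# Affine theta mapping the top-left h x w block of the output grid onto the
# full embedding table, mirroring the construction in the new method.
theta = torch.zeros(1, 2, 3)
theta[:, 0, 0] = W_max / w           # scale x
theta[:, 1, 1] = H_max / h           # scale y
theta[:, 0, 2] = theta[:, 0, 0] - 1  # translate x
theta[:, 1, 2] = theta[:, 1, 1] - 1  # translate y

grid = F.affine_grid(theta, (1, C, H_max, W_max), align_corners=False)
sampled = F.grid_sample(
    pos_embed.permute(0, 3, 1, 2),  # -> (1, C, H, W)
    grid,
    mode='bilinear',
    align_corners=False,
    padding_mode='border',
)[:, :, :h, :w]  # only the top-left h x w block is ever gathered

# Reference: plain resize of the table to (h, w).
reference = F.interpolate(
    pos_embed.permute(0, 3, 1, 2), size=(h, w),
    mode='bilinear', align_corners=False,
)

print((sampled - reference).abs().max())  # expected to be tiny (float error)
```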
@@ -516,7 +559,7 @@ def _apply_factorized_naflex_pos_embed(
        # Handle each batch element separately with its own grid size
        orig_h, orig_w = self.pos_embed_y.shape[1], self.pos_embed_x.shape[1]

-        # bucket samples that share the same (H,W) so we build each grid once
+        # bucket samples that share the same (H, W) so we build each grid once
        size_to_indices: Dict[Tuple[int, int], List[int]] = {}
        for bi, k in enumerate(naflex_grid_sizes):
            size_to_indices.setdefault(k, []).append(bi)
@@ -630,7 +673,10 @@ def forward(

        if self.pos_embed_type == 'learned':
            if naflex_grid_sizes is not None:
-                self._apply_learned_naflex_pos_embed(x, naflex_grid_sizes=naflex_grid_sizes)
+                if self.pos_embed_use_grid_sample:
+                    self._apply_learned_naflex_pos_embed_grid_sample(x, naflex_grid_sizes=naflex_grid_sizes)
+                else:
+                    self._apply_learned_naflex_pos_embed(x, naflex_grid_sizes=naflex_grid_sizes)
            else:
                assert grid_size is not None
                self._apply_learned_pos_embed(x, grid_size=grid_size)
@@ -874,6 +920,7 @@ def __init__(
            pos_embed_grid_size=cfg.pos_embed_grid_size,
            pos_embed_interp_mode=cfg.pos_embed_interp_mode,
            pos_embed_ar_preserving=cfg.pos_embed_ar_preserving,
+            pos_embed_use_grid_sample=cfg.pos_embed_use_grid_sample,
            proj_norm_layer=embed_norm_layer,
            pos_drop_rate=cfg.pos_drop_rate,
            patch_drop_rate=cfg.patch_drop_rate,
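Usage-wise, the new path is opt-in via the config flag. A rough sketch, assuming the `NaFlexVit` model class accepts a `NaFlexVitCfg` and lives at this import path (both are assumptions, not confirmed by this diff):

```python
# Hypothetical usage; import path and constructor signature are assumed, not from this diff.
from timm.models.naflexvit import NaFlexVit, NaFlexVitCfg

cfg = NaFlexVitCfg(
    pos_embed_grid_size=(16, 16),
    pos_embed_interp_mode='bicubic',
    pos_embed_ar_preserving=False,
    pos_embed_use_grid_sample=True,  # enable the new grid_sample interpolation path
)
model = NaFlexVit(cfg=cfg, num_classes=1000)
```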