@@ -469,7 +469,7 @@ def _apply_learned_naflex_pos_embed_grid_sample(
469
469
Based on proposal by https://github.com/stas-sl
470
470
"""
471
471
device = x .device
472
- B , C = x .shape [ 0 : 2 ]
472
+ B , N , C = x .shape
473
473
474
474
def _make_coords (h , w ):
475
475
_y , _x = torch .meshgrid (
@@ -480,10 +480,12 @@ def _make_coords(h, w):
480
480
coord = torch .stack ([_y .flatten (), _x .flatten ()], dim = 1 )
481
481
return coord
482
482
483
- coords = pad_sequence (
484
- [_make_coords (h , w ) for h , w in naflex_grid_sizes ],
485
- batch_first = True ,
486
- )
483
+ coords = torch .zeros (B , N , 2 , dtype = torch .long , device = device )
484
+ for i , (h , w ) in enumerate (naflex_grid_sizes ):
485
+ coords_i = _make_coords (h , w ) # (h*w, 2)
486
+ coords [i , :coords_i .shape [0 ]] = coords_i # pad with zeros past h*w
487
+ # FIXME should we be masking?
488
+
487
489
shapes = coords .amax (1 ) + 1
488
490
theta = torch .zeros (B , 2 , 3 , dtype = torch .float32 , device = device )
489
491
if self .pos_embed_ar_preserving :
@@ -506,7 +508,7 @@ def _make_coords(h, w):
506
508
align_corners = False ,
507
509
padding_mode = 'border' ,
508
510
).to (dtype = x .dtype )
509
- bi = torch .arange (B , device = device ).unsqueeze (1 ). expand ( - 1 , coords . shape [ 1 ])
511
+ bi = torch .arange (B , device = device ).unsqueeze (1 )
510
512
# NOTE leave as '+=', do not change to .add_(...)
511
513
x += pos_embed [bi , :, coords [..., 0 ], coords [..., 1 ]]
512
514
0 commit comments