Commit 4e3cba8

Fix silly shape bug, and fix an issue with pad_sequence when none of the shapes in the batch use the full sequence length.
1 parent c2ba04c · commit 4e3cba8

File tree

1 file changed: +8 -6 lines


timm/models/naflexvit.py

Lines changed: 8 additions & 6 deletions
@@ -469,7 +469,7 @@ def _apply_learned_naflex_pos_embed_grid_sample(
         Based on proposal by https://github.com/stas-sl
         """
         device = x.device
-        B, C = x.shape[0:2]
+        B, N, C = x.shape
 
         def _make_coords(h, w):
             _y, _x = torch.meshgrid(
@@ -480,10 +480,12 @@ def _make_coords(h, w):
             coord = torch.stack([_y.flatten(), _x.flatten()], dim=1)
             return coord
 
-        coords = pad_sequence(
-            [_make_coords(h, w) for h, w in naflex_grid_sizes],
-            batch_first=True,
-        )
+        coords = torch.zeros(B, N, 2, dtype=torch.long, device=device)
+        for i, (h, w) in enumerate(naflex_grid_sizes):
+            coords_i = _make_coords(h, w)  # (h*w, 2)
+            coords[i, :coords_i.shape[0]] = coords_i  # pad with zeros past h*w
+        # FIXME should we be masking?
+
         shapes = coords.amax(1) + 1
         theta = torch.zeros(B, 2, 3, dtype=torch.float32, device=device)
         if self.pos_embed_ar_preserving:
@@ -506,7 +508,7 @@ def _make_coords(h, w):
             align_corners=False,
             padding_mode='border',
         ).to(dtype=x.dtype)
-        bi = torch.arange(B, device=device).unsqueeze(1).expand(-1, coords.shape[1])
+        bi = torch.arange(B, device=device).unsqueeze(1)
         # NOTE leave as '+=', do not change to .add_(...)
         x += pos_embed[bi, :, coords[..., 0], coords[..., 1]]

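For context, a minimal sketch of the pad_sequence issue the commit message describes (the sizes and helper name below are invented for illustration, not taken from the commit): pad_sequence only pads to the longest sequence actually present in the batch, so when none of the grids reaches the full padded sequence length N, coords comes out shorter than (B, N, 2) and no longer lines up with the (B, N, C) token tensor x. Pre-allocating a zero tensor of shape (B, N, 2), as the new code does, avoids that.

import torch
from torch.nn.utils.rnn import pad_sequence

# Illustrative sizes only; B, N and the grid sizes are made up.
B, N = 2, 16                        # batch size, padded sequence length
grid_sizes = [(2, 3), (3, 3)]       # h*w = 6 and 9, both shorter than N

def make_coords(h, w):
    # Simplified stand-in for _make_coords in the diff above.
    yy, xx = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij')
    return torch.stack([yy.flatten(), xx.flatten()], dim=1)

# Old approach: pads only up to the longest grid in the batch (9 tokens),
# not up to N (16), so the sequence axis no longer matches x's.
coords_old = pad_sequence([make_coords(h, w) for h, w in grid_sizes], batch_first=True)
print(coords_old.shape)             # torch.Size([2, 9, 2]), not (2, 16, 2)

# New approach: allocate (B, N, 2) up front and leave zeros past each h*w.
coords = torch.zeros(B, N, 2, dtype=torch.long)
for i, (h, w) in enumerate(grid_sizes):
    ci = make_coords(h, w)
    coords[i, :ci.shape[0]] = ci
print(coords.shape)                 # torch.Size([2, 16, 2])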
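A second sketch, again with invented sizes and assuming pos_embed is the (B, C, H, W) output of the grid_sample call shown above, of why the .expand(-1, coords.shape[1]) on the batch index could be dropped: PyTorch advanced indexing broadcasts the (B, 1) batch index against the (B, N) coordinate indices on its own, and the gathered result already has shape (B, N, C) to match x.

import torch

# Invented sizes; pos_embed shape assumed from the grid_sample output above.
B, N, C, H, W = 2, 16, 8, 4, 4
pos_embed = torch.randn(B, C, H, W)               # per-sample resampled pos embed grid
coords = torch.zeros(B, N, 2, dtype=torch.long)   # zero-padded (y, x) per token

bi = torch.arange(B).unsqueeze(1)                 # (B, 1); no .expand(-1, N) needed
gathered = pos_embed[bi, :, coords[..., 0], coords[..., 1]]
print(gathered.shape)                             # torch.Size([2, 16, 8]) == (B, N, C)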