Skip to content

Commit c2ba04c

Browse files
committed
Slight tweak
1 parent 98c6a4a commit c2ba04c

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

timm/models/naflexvit.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -487,9 +487,9 @@ def _make_coords(h, w):
487487
shapes = coords.amax(1) + 1
488488
theta = torch.zeros(B, 2, 3, dtype=torch.float32, device=device)
489489
if self.pos_embed_ar_preserving:
490-
shape_max = shapes.amax()
491-
grid_size = (shape_max, shape_max)
492490
L = shapes.amax(1)
491+
grid_max = L.amax()
492+
grid_size = (grid_max, grid_max)
493493
theta[:, 0, 0] = grid_size[1] / L # scale x
494494
theta[:, 1, 1] = grid_size[0] / L # scale y
495495
else:

0 commit comments

Comments
 (0)