Skip to content

Commit 98c6a4a

Browse files
committed
Add grid_sample pos_embed interpolation option
1 parent 4471cad commit 98c6a4a

File tree

1 file changed

+61
-14
lines changed

1 file changed

+61
-14
lines changed

timm/models/naflexvit.py

Lines changed: 61 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import torch
2626
import torch.nn as nn
2727
import torch.nn.functional as F
28+
from torch.nn.utils.rnn import pad_sequence
2829

2930
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
3031
from timm.layers import (
@@ -89,6 +90,7 @@ class NaFlexVitCfg:
8990
pos_embed_grid_size: Optional[Tuple[int, int]] = (16, 16) # Grid size for position embedding initialization
9091
pos_embed_interp_mode: str = 'bicubic' # Interpolation mode for position embedding resizing
9192
pos_embed_ar_preserving: bool = False # Whether to preserve aspect ratio during position embedding interpolation
93+
pos_embed_use_grid_sample: bool = False # Whether to use grid_sample for naflex position embedding interpolation
9294

9395
# Image processing
9496
dynamic_img_pad: bool = False # Whether to enable dynamic padding for variable resolution
@@ -221,6 +223,7 @@ def __init__(
221223
pos_embed_grid_size: Optional[Tuple[int, int]] = (14, 14),
222224
pos_embed_interp_mode: str = 'bicubic',
223225
pos_embed_ar_preserving: bool = False,
226+
pos_embed_use_grid_sample: bool = False,
224227
input_norm_layer: Optional[Type[nn.Module]] = None,
225228
proj_norm_layer: Union[bool, Optional[Type[nn.Module]]] = None,
226229
norm_layer: Optional[Type[nn.Module]] = None,
@@ -256,6 +259,7 @@ def __init__(
256259
self.num_reg_tokens = reg_tokens
257260
self.pos_embed_interp_mode = pos_embed_interp_mode
258261
self.pos_embed_ar_preserving = pos_embed_ar_preserving
262+
self.pos_embed_use_grid_sample = pos_embed_use_grid_sample
259263
self.patch_size = to_2tuple(patch_size)
260264
self.in_chans = in_chans
261265
self.embed_dim = embed_dim
@@ -438,18 +442,6 @@ def _interp2d(size):
438442
)[:, :, :size[0], :size[1]].flatten(2).transpose(1, 2)
439443
return pos_embed_flat.to(dtype=x.dtype)
440444

441-
# FIXME leaving alternative code commented here for now for comparisons
442-
# pos_embed_cache: Dict[Tuple[int, int], torch.Tensor] = {}
443-
# for i, s in enumerate(naflex_grid_sizes):
444-
# if s in pos_embed_cache:
445-
# pos_embed_flat = pos_embed_cache[s]
446-
# else:
447-
# pos_embed_flat = _interp(s)
448-
# pos_embed_cache[s] = pos_embed_flat
449-
#
450-
# seq_len = min(x.shape[1], pos_embed_flat.shape[1])
451-
# x[i, :seq_len] += pos_embed_flat[0, :seq_len]
452-
453445
# Determine unique grid sizes to avoid duplicate interpolation
454446
size_to_indices: Dict[Tuple[int, int], List[int]] = {}
455447
for bi, k in enumerate(naflex_grid_sizes):
@@ -467,6 +459,57 @@ def _interp2d(size):
467459
pos_embed_flat[:, :seq_len].expand(len(batch_indices), -1, -1)
468460
)
469461

462+
def _apply_learned_naflex_pos_embed_grid_sample(
463+
self,
464+
x: torch.Tensor,
465+
naflex_grid_sizes: List[Tuple[int, int]],
466+
):
467+
""" NaFlex 2D position embedding interpolation using F.grid_sample.
468+
469+
Based on proposal by https://github.com/stas-sl
470+
"""
471+
device = x.device
472+
B, C = x.shape[0:2]
473+
474+
def _make_coords(h, w):
475+
_y, _x = torch.meshgrid(
476+
torch.arange(h, device=device),
477+
torch.arange(w, device=device),
478+
indexing='ij',
479+
)
480+
coord = torch.stack([_y.flatten(), _x.flatten()], dim=1)
481+
return coord
482+
483+
coords = pad_sequence(
484+
[_make_coords(h, w) for h, w in naflex_grid_sizes],
485+
batch_first=True,
486+
)
487+
shapes = coords.amax(1) + 1
488+
theta = torch.zeros(B, 2, 3, dtype=torch.float32, device=device)
489+
if self.pos_embed_ar_preserving:
490+
shape_max = shapes.amax()
491+
grid_size = (shape_max, shape_max)
492+
L = shapes.amax(1)
493+
theta[:, 0, 0] = grid_size[1] / L # scale x
494+
theta[:, 1, 1] = grid_size[0] / L # scale y
495+
else:
496+
grid_size = shapes.amax(0)
497+
theta[:, 0, 0] = grid_size[1] / shapes[:, 1] # scale x
498+
theta[:, 1, 1] = grid_size[0] / shapes[:, 0] # scale y
499+
theta[:, 0, 2] = theta[:, 0, 0] - 1 # translate x
500+
theta[:, 1, 2] = theta[:, 1, 1] - 1 # translate y
501+
grid = F.affine_grid(theta, (B, C, *grid_size), align_corners=False)
502+
pos_embed = F.grid_sample(
503+
self.pos_embed.permute(0, 3, 1, 2).expand(B, -1, -1, -1).float(),
504+
grid,
505+
mode=self.pos_embed_interp_mode,
506+
align_corners=False,
507+
padding_mode='border',
508+
).to(dtype=x.dtype)
509+
bi = torch.arange(B, device=device).unsqueeze(1).expand(-1, coords.shape[1])
510+
# NOTE leave as '+=', do not change to .add_(...)
511+
x += pos_embed[bi, :, coords[..., 0], coords[..., 1]]
512+
470513
def _apply_learned_pos_embed(
471514
self,
472515
x: torch.Tensor,
@@ -516,7 +559,7 @@ def _apply_factorized_naflex_pos_embed(
516559
# Handle each batch element separately with its own grid size
517560
orig_h, orig_w = self.pos_embed_y.shape[1], self.pos_embed_x.shape[1]
518561

519-
# bucket samples that share the same (H,W) so we build each grid once
562+
# bucket samples that share the same (H, W) so we build each grid once
520563
size_to_indices: Dict[Tuple[int, int], List[int]] = {}
521564
for bi, k in enumerate(naflex_grid_sizes):
522565
size_to_indices.setdefault(k, []).append(bi)
@@ -630,7 +673,10 @@ def forward(
630673

631674
if self.pos_embed_type == 'learned':
632675
if naflex_grid_sizes is not None:
633-
self._apply_learned_naflex_pos_embed(x, naflex_grid_sizes=naflex_grid_sizes)
676+
if self.pos_embed_use_grid_sample:
677+
self._apply_learned_naflex_pos_embed_grid_sample(x, naflex_grid_sizes=naflex_grid_sizes)
678+
else:
679+
self._apply_learned_naflex_pos_embed(x, naflex_grid_sizes=naflex_grid_sizes)
634680
else:
635681
assert grid_size is not None
636682
self._apply_learned_pos_embed(x, grid_size=grid_size)
@@ -874,6 +920,7 @@ def __init__(
874920
pos_embed_grid_size=cfg.pos_embed_grid_size,
875921
pos_embed_interp_mode=cfg.pos_embed_interp_mode,
876922
pos_embed_ar_preserving=cfg.pos_embed_ar_preserving,
923+
pos_embed_use_grid_sample=cfg.pos_embed_use_grid_sample,
877924
proj_norm_layer=embed_norm_layer,
878925
pos_drop_rate=cfg.pos_drop_rate,
879926
patch_drop_rate=cfg.patch_drop_rate,

0 commit comments

Comments
 (0)