In [None]:
from urllib.request import urlopen
from PIL import Image
import timm


img = Image.open('./image.png')

model = timm.create_model(
    'fastvit_sa12.apple_in1k',
    pretrained=True,
    features_only=True,
)
model = model.eval()

# get model specific transforms (normalization, resize)
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)

output = model(transforms(img).unsqueeze(0))  # unsqueeze single image into batch of 1

for o in output:
    # print shape of each feature map in output
    # e.g.:
    #  torch.Size([1, 64, 64, 64])
    #  torch.Size([1, 128, 32, 32])
    #  torch.Size([1, 256, 16, 16])
    #  torch.Size([1, 512, 8, 8])

    print(o.shape)


  from .autonotebook import tqdm as notebook_tqdm


torch.Size([1, 64, 64, 64])
torch.Size([1, 128, 32, 32])
torch.Size([1, 256, 16, 16])
torch.Size([1, 512, 8, 8])


In [4]:
# Check FastVit's normalization requirements
import timm

model = timm.create_model(
    'fastvit_sa12.apple_in1k',
    pretrained=True,
    features_only=True,
)

data_config = timm.data.resolve_model_data_config(model)
print("FastVit normalization:")
print(f"  Mean: {data_config['mean']}")
print(f"  Std: {data_config['std']}")

# Compare with standard ImageNet normalization
print("\nStandard ImageNet normalization:")
print(f"  Mean: (0.485, 0.456, 0.406)")
print(f"  Std: (0.229, 0.224, 0.225)")

print("\nAre they the same?", data_config['mean'] == (0.485, 0.456, 0.406) and data_config['std'] == (0.229, 0.224, 0.225))

FastVit normalization:
  Mean: (0.485, 0.456, 0.406)
  Std: (0.229, 0.224, 0.225)

Standard ImageNet normalization:
  Mean: (0.485, 0.456, 0.406)
  Std: (0.229, 0.224, 0.225)

Are they the same? True


In [5]:
# Check what backbone_feature_dim would be
import timm

model = timm.create_model(
    'fastvit_sa12.apple_in1k',
    pretrained=True,
    features_only=True,
)

print("Feature info for all layers:")
for i, info in enumerate(model.feature_info):
    print(f"  Layer {i}: {info['num_chs']} channels (module: {info['module']})")

print(f"\nLast layer channels (what ACT will use): {model.feature_info[-1]['num_chs']}")

Feature info for all layers:
  Layer 0: 64 channels (module: stages.0)
  Layer 1: 128 channels (module: stages.1)
  Layer 2: 256 channels (module: stages.2)
  Layer 3: 512 channels (module: stages.3)

Last layer channels (what ACT will use): 512


In [2]:
# CRITICAL: Compare spatial resolutions (this is the VRAM killer!)
import torch
import torchvision
import timm

# Simulate your dataset's image size
img_size = (224, 224)  # Change to your actual image size
dummy_input = torch.randn(1, 3, *img_size)

# ResNet18
resnet = torchvision.models.resnet18(pretrained=False)
from torchvision.models._utils import IntermediateLayerGetter
resnet_backbone = IntermediateLayerGetter(resnet, return_layers={"layer4": "feature_map"})
resnet_out = resnet_backbone(dummy_input)["feature_map"]

# FastVit
fastvit = timm.create_model('fastvit_sa12.apple_in1k', pretrained=False, features_only=True)
fastvit_out = fastvit(dummy_input)[-1]

print("=" * 60)
print("SPATIAL RESOLUTION COMPARISON (This is why you run out of VRAM!)")
print("=" * 60)
print(f"ResNet18 output shape:  {resnet_out.shape}")
print(f"FastVit output shape:   {fastvit_out.shape}")
print()
print(f"ResNet18 tokens per image:  {resnet_out.shape[2] * resnet_out.shape[3]}")
print(f"FastVit tokens per image:   {fastvit_out.shape[2] * fastvit_out.shape[3]}")
print()
ratio = (fastvit_out.shape[2] * fastvit_out.shape[3]) / (resnet_out.shape[2] * resnet_out.shape[3])
print(f"FastVit has {ratio:.1f}x MORE tokens!")
print(f"Attention memory scales with O(n²), so ~{ratio**2:.1f}x more VRAM!")
print("=" * 60)



SPATIAL RESOLUTION COMPARISON (This is why you run out of VRAM!)
ResNet18 output shape:  torch.Size([1, 512, 7, 7])
FastVit output shape:   torch.Size([1, 512, 7, 7])

ResNet18 tokens per image:  49
FastVit tokens per image:   49

FastVit has 1.0x MORE tokens!
Attention memory scales with O(n²), so ~1.0x more VRAM!


In [5]:
import torch.nn as nn
# test custom wrapper
class TimmFeatureExtractorWrapper(nn.Module):
    """Wrapper for timm models to match IntermediateLayerGetter's return format."""
    def __init__(self, model):
        super().__init__()
        self.model = model
    
    def forward(self, x):
        # timm features_only models return a list of feature maps
        # With out_indices=[-1], we only get the last feature map
        features = self.model(x)
        return {"feature_map": features[0] if len(features) == 1 else features[-1]}
    

image_feature_extractor = timm.create_model(
    'fastvit_sa12.apple_in1k',
    pretrained=True,
    features_only=True,
)
wrapped_extractor = TimmFeatureExtractorWrapper(image_feature_extractor)
wrapped_extractor.eval()
dummy_input = torch.randn(1, 3, 224, 224)
output = wrapped_extractor(dummy_input)
print("Wrapped extractor output keys:", output.keys())

Wrapped extractor output keys: dict_keys(['feature_map'])


## Tesing `Shuffle Net`

In [18]:
# suffle net
import torch
from torch import nn
import torchvision
from PIL import Image
# intermediate layer getter
from torchvision.models._utils import IntermediateLayerGetter

weights = torchvision.models.ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1
model = torchvision.models.shufflenet_v2_x1_0(weights=weights)
backbone = IntermediateLayerGetter(model, return_layers={"stage4": "feature_map"})

image = Image.open('./image.png')
# preprocess
preprocess = weights.transforms()
input_tensor = preprocess(image).unsqueeze(0)  # create a mini-batch as
input_batch = input_tensor  # single image batch
# inference
with torch.no_grad():
    output = backbone(input_batch)

print(model)

ShuffleNetV2(
  (conv1): Sequential(
    (0): Conv2d(3, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (stage2): Sequential(
    (0): InvertedResidual(
      (branch1): Sequential(
        (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=24, bias=False)
        (1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): Conv2d(24, 58, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (3): BatchNorm2d(58, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (4): ReLU(inplace=True)
      )
      (branch2): Sequential(
        (0): Conv2d(24, 58, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(58, eps=1e-05, momentum=0.1, affine=True, track_running_

In [None]:
backbone_feature_dim = model.conv5[0].in_channels
print("Backbone feature dim:", backbone_feature_dim)
# using conv1d for 
input_proj = nn.Conv2d(
                backbone_feature_dim, 512, kernel_size=1
            )
input_proj_output = input_proj(output["feature_map"])
print("Input projection output shape:", input_proj_output.shape)

Backbone feature dim: 464


In [2]:
#  dino 
import timm 

backbone_model = timm.create_model(
                'vit_small_patch16_224.dino',
                pretrained=True,
                features_only=True,
                out_indices=[-1],  # Only extract last layer to avoid memory leak
            )
backbone_feature_dim = backbone_model.feature_info[-1]['num_chs']

In [3]:
backbone_feature_dim

384