In [1]:
import torch
import timm
from timm import models
from timm.data import resolve_data_config, create_transform

In [2]:
# Installed timm version; the outputs recorded in this notebook were produced with it.
timm.__version__

'0.9.7'

# 获取模型列表，`pretrained=True` 表示只列出带有预训练权重的模型

In [3]:
# Registered model names matching the wildcard "*vit*" that have pretrained weights.
# With pretrained=True the names come back as "<architecture>.<weight tag>"
# (e.g. 'convit_base.fb_in1k' in the output below). Note the wildcard also
# matches non-ViT families such as convit / crossvit / davit / efficientvit.
pretrain_model = timm.list_models(
    filter=["*vit*"],
    pretrained=True,
    exclude_filters=[],
)
len(pretrain_model)

268

In [4]:
# Displays every matching "<architecture>.<weight tag>" name; the exported output
# below is truncated. Slice (e.g. pretrain_model[:20]) for a shorter listing.
pretrain_model

['convit_base.fb_in1k',
 'convit_small.fb_in1k',
 'convit_tiny.fb_in1k',
 'crossvit_9_240.in1k',
 'crossvit_9_dagger_240.in1k',
 'crossvit_15_240.in1k',
 'crossvit_15_dagger_240.in1k',
 'crossvit_15_dagger_408.in1k',
 'crossvit_18_240.in1k',
 'crossvit_18_dagger_240.in1k',
 'crossvit_18_dagger_408.in1k',
 'crossvit_base_240.in1k',
 'crossvit_small_240.in1k',
 'crossvit_tiny_240.in1k',
 'davit_base.msft_in1k',
 'davit_small.msft_in1k',
 'davit_tiny.msft_in1k',
 'efficientvit_b0.r224_in1k',
 'efficientvit_b1.r224_in1k',
 'efficientvit_b1.r256_in1k',
 'efficientvit_b1.r288_in1k',
 'efficientvit_b2.r224_in1k',
 'efficientvit_b2.r256_in1k',
 'efficientvit_b2.r288_in1k',
 'efficientvit_b3.r224_in1k',
 'efficientvit_b3.r256_in1k',
 'efficientvit_b3.r288_in1k',
 'efficientvit_m0.r224_in1k',
 'efficientvit_m1.r224_in1k',
 'efficientvit_m2.r224_in1k',
 'efficientvit_m3.r224_in1k',
 'efficientvit_m4.r224_in1k',
 'efficientvit_m5.r224_in1k',
 'fastvit_ma36.apple_dist_in1k',
 'fastvit_ma36.apple_in

In [5]:
# Bare architecture name: timm picks that architecture's default weight tag.
model = timm.create_model("vit_base_patch16_224", num_classes=5, pretrained=False)
# "<architecture>.<tag>" selects a specific weight tag (here the MAE one).
# pretrained=False in both calls, so only the architecture is instantiated;
# num_classes=5 sizes the classifier head for 5 classes.
model = timm.create_model("vit_base_patch16_224.mae", num_classes=5, pretrained=False)

In [6]:
# Two ways to build the same 5-class maxvit_nano_rw_256; the second assignment
# replaces the first, so `model` ends up bound to the direct-call result.
# (1) by registered name through the generic factory
model = timm.create_model("maxvit_nano_rw_256", num_classes=5, pretrained=False)
# (2) by calling the architecture's entry function in timm.models.maxxvit
model = models.maxxvit.maxvit_nano_rw_256(num_classes=5, pretrained=False)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


# 创建对应的图片预处理，配合 `PIL.Image.open('path').convert('RGB')` 读取的图片使用

In [7]:
# Preprocessing settings (input size, interpolation, mean/std, crop) resolved
# from `model`; the empty dict passes no user overrides.
config = resolve_data_config(args={}, model=model)
config


{'input_size': (3, 256, 256),
 'interpolation': 'bicubic',
 'mean': (0.5, 0.5, 0.5),
 'std': (0.5, 0.5, 0.5),
 'crop_pct': 0.95,
 'crop_mode': 'center'}

In [8]:
# Eval-time preprocessing pipeline built from the resolved data config
# (Resize -> CenterCrop -> ToTensor -> Normalize, as displayed below).
# Bound to a name so it can actually be applied, e.g.
#   transform(PIL.Image.open(path).convert('RGB')).unsqueeze(0)
transform = create_transform(**config)
transform

Compose(
    Resize(size=269, interpolation=bicubic, max_size=None, antialias=warn)
    CenterCrop(size=(256, 256))
    ToTensor()
    Normalize(mean=tensor([0.5000, 0.5000, 0.5000]), std=tensor([0.5000, 0.5000, 0.5000]))
)

# 推理

In [9]:
# Dummy batch of one image, shaped from the model's expected input size
# (config["input_size"] == (3, 256, 256) for maxvit_nano_rw_256) instead of a
# hard-coded copy that could drift if the model is changed.
x = torch.ones(1, *config["input_size"])

In [10]:
# Forward the dummy batch through the 5-class model in eval mode;
# inference_mode disables autograd tracking for this pass.
model.eval()
with torch.inference_mode():
    y = model(x)
# Shown as the cell's rich output (no print), one logit per class.
y.size()

torch.Size([1, 5])


# 获取特征

## 创建一个没有池化和分类层的模型(池化之前的特征)

In [11]:
# num_classes=0 drops the classifier and global_pool='' disables pooling, so the
# forward pass returns the last feature map before pooling (see output below).
model = timm.create_model(
    'maxvit_nano_rw_256',
    global_pool='',
    num_classes=0,
    pretrained=False,
).eval()  # nn.Module.eval() returns the module itself, so it can be chained
with torch.inference_mode():
    y = model(x)
y.shape

torch.Size([1, 512, 8, 8])

## 创建一个没有分类层的模型(池化之后的特征)

In [12]:
# num_classes=0 removes only the classifier head; global pooling is kept, so
# the output is one pooled feature vector per image (see output below).
model = timm.create_model(
    'maxvit_nano_rw_256',
    num_classes=0,
    pretrained=False,
).eval()  # nn.Module.eval() returns the module itself, so it can be chained
with torch.inference_mode():
    y = model(x)
y.shape

torch.Size([1, 512])

## 多尺度特征

In [13]:
# features_only=True turns the model into a feature extractor that returns a
# list of feature maps, one per feature level (five levels for this model).
model = timm.create_model('maxvit_nano_rw_256', pretrained=False, features_only=True)
model.eval()
with torch.inference_mode():
    y = model(x)
# Channel count and reduction (downsampling factor w.r.t. the input) per level
print(model.feature_info.channels())    # [64, 64, 128, 256, 512]
print(model.feature_info.reduction())   # [2, 4, 8, 16, 32]
for layer in y:
    print(layer.size())

[64, 64, 128, 256, 512]
[2, 4, 8, 16, 32]
torch.Size([1, 64, 128, 128])
torch.Size([1, 64, 64, 64])
torch.Size([1, 128, 32, 32])
torch.Size([1, 256, 16, 16])
torch.Size([1, 512, 8, 8])


多尺度特征

可以选择特定的特征图级别（out_indices）或限制步幅（output_stride）：

out_indices:    指定要返回哪些层级的特征图（feature maps），索引从 0 开始，索引 i 对应 C(i + 1) 特征层级（相对输入下采样 2^(i+1) 倍）。例如 out_indices=[2, 3, 4] 对应 C3、C4、C5，步幅分别为 8、16、32。

output_stride:  限制网络特征输出的步幅（也适用于分类模式），通过空洞卷积（dilated convolution）来控制网络的 output stride。大多数网络默认的 output stride 为 32。

In [14]:
# out_indices=[2, 3, 4] keeps only the last three feature levels (C3, C4, C5).
model = timm.create_model('maxvit_nano_rw_256', pretrained=False, features_only=True, out_indices=[2, 3, 4])
model.eval()
with torch.inference_mode():
    y = model(x)
# Channel count and reduction (downsampling factor w.r.t. the input) per kept level
print(model.feature_info.channels())    # [128, 256, 512]
print(model.feature_info.reduction())   # [8, 16, 32]
for layer in y:
    print(layer.size())

[128, 256, 512]
[8, 16, 32]
torch.Size([1, 128, 32, 32])
torch.Size([1, 256, 16, 16])
torch.Size([1, 512, 8, 8])
