In [4]:
import timm
import torch

In [5]:
# Instantiate timm's resnet10t with its pretrained weights, for architecture inspection.
convolutional_network = timm.create_model(model_name='resnet10t', pretrained=True)

In [6]:
# Random dummy batch in NCHW layout: 1 image, 3 channels (matching the stem's Conv2d(3, ...)), 84x84 pixels.
input_tensor = torch.randn((1, 3, 84, 84))

In [7]:
# Show the model's module tree: as the cell's last expression, its repr is rendered as the output.
convolutional_network

ResNet(
  (conv1): Sequential(
    (0): Conv2d(3, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(24, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  )
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act1): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (drop_block): Identity()
   

In [8]:
import timm
import torch
from torch import nn

# Load the pretrained resnet10t backbone from timm.
convolutional_network = timm.create_model('resnet10t', pretrained=True)

# Freeze all parameters so the backbone acts as a fixed feature extractor
# (equivalent to setting requires_grad = False on every parameter).
convolutional_network.requires_grad_(False)

# requires_grad=False only stops gradient updates. BatchNorm layers left in
# training mode (the nn.Module default) would still normalise with per-batch
# statistics and overwrite their running_mean/running_var on every forward
# pass; eval() makes them use the pretrained statistics unchanged.
convolutional_network.eval()

# 自定义返回中间层特征图的模型
class FeatureExtractor(nn.Module):
    def __init__(self, model, layers):
        super(FeatureExtractor, self).__init__()
        self.model = model
        self.layers = layers

    def forward(self, x):
        features = {}
        for name, module in self.model.named_children():
            x = module(x)
            if name in self.layers:
                features[name] = x
        return features

# Names of the top-level children whose outputs should be extracted.
layers_to_extract = ['layer1', 'layer2', 'layer3', 'layer4']
feature_extractor = FeatureExtractor(convolutional_network, layers_to_extract)
# Inference mode for BatchNorm: use the stored running statistics instead of
# per-batch statistics, and do not overwrite them with this random input's stats.
feature_extractor.eval()

# Random dummy input batch; only the printed feature-map shapes are inspected below.
input_tensor = torch.randn(1, 3, 224, 224)

# Extract the feature maps without recording an autograd graph.
with torch.no_grad():
    features = feature_extractor(input_tensor)
for layer_name, feature_map in features.items():
    print(f"{layer_name} 特征图形状: {feature_map.shape}")


layer1 特征图形状: torch.Size([1, 64, 56, 56])
layer2 特征图形状: torch.Size([1, 128, 28, 28])
layer3 特征图形状: torch.Size([1, 256, 14, 14])
layer4 特征图形状: torch.Size([1, 512, 7, 7])
