In [1]:
import torch
import torchvision
import torchvision.models as models

In [3]:
def test_resnet_output(model_name, input_size=(3, 224, 224)):
    """
    测试给定ResNet模型对任意输入尺寸的输出大小。
    
    参数：
    - model_name (str): 要测试的ResNet模型名称，如'resnet18', 'resnet50', 'resnet101'
    - input_size (tuple): 输入的图像尺寸，格式为 (channels, height, width)，默认为 (3, 224, 224)
    
    返回：
    - 输出特征图的尺寸
    """
    # 根据模型名称获取对应的ResNet模型
    model_dict = {
        'resnet18': models.resnet18,
        'resnet34': models.resnet34,
        'resnet50': models.resnet50,
        'resnet101': models.resnet101,
        'resnet152': models.resnet152
    }
    
    if model_name not in model_dict:
        raise ValueError(f"模型 {model_name} 不存在。可选模型有：{list(model_dict.keys())}")
    
    # 加载模型
    model = model_dict[model_name](pretrained=False)
    
    # 移除最后的全连接层以便查看特征图的尺寸
    model = torch.nn.Sequential(*list(model.children())[:-2])
    # model = torch.nn.Sequential(*list(model.children())[:])
    
    # 创建随机输入
    input_tensor = torch.randn(64, *input_size)  # batch_size=64
    
    # 前向传播
    with torch.no_grad():
        output = model(input_tensor)
    
    print(f"模型 {model_name} 对输入大小 {input_size} 的输出尺寸为: {output.shape}")
    
    return output.shape

# Usage example: report the backbone output shape of three ResNet variants,
# each fed 256x256 three-channel inputs.
if __name__ == "__main__":
    for backbone_name in ('resnet18', 'resnet101', 'resnet50'):
        test_resnet_output(backbone_name, input_size=(3, 256, 256))

模型 resnet18 对输入大小 (3, 256, 256) 的输出尺寸为: torch.Size([64, 512, 8, 8])
模型 resnet101 对输入大小 (3, 256, 256) 的输出尺寸为: torch.Size([64, 2048, 8, 8])
模型 resnet50 对输入大小 (3, 256, 256) 的输出尺寸为: torch.Size([64, 2048, 8, 8])


In [4]:
# ResNet-18 constructed with num_classes=128 (the "low_dim" output size).
# The bare `net` on the last line makes the notebook display the module tree.
net = models.resnet18(num_classes=128)
net

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [12]:
from torchvision.models import resnet18
from torchvision.models import resnet101

# model = resnet18(pretrained=False)
# print(model, "\n\n\n")
# model = torch.nn.Sequential(*list(model.children())[:-2])
# print(model)

# Build an untrained ResNet-101. The constructor is called without the
# deprecated `pretrained` flag: omitting it gives the same randomly
# initialised network and avoids the warning recent torchvision versions emit.
model = resnet101()
print(model, "\n\n\n")

# Keep only the first six child modules. Per the structure printed above and
# torchvision's ResNet layout these are conv1, bn1, relu, maxpool, layer1 and
# layer2, i.e. the stem plus the first two residual stages.
model = torch.nn.Sequential(*list(model.children())[:6])
print(model)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 