In [1]:
import torch
from torchvision import models, transforms
from PIL import Image

# Run on the GPU when one is available; otherwise fall back to the CPU so the
# notebook still executes on machines without CUDA (the original hard-coded 'cuda').
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ImageNet-pretrained ResNet-50.
# NOTE(review): `pretrained=True` is deprecated since torchvision 0.13 in favour of
# `weights=models.ResNet50_Weights.DEFAULT`; kept here for compatibility with older installs.
resnet = models.resnet50(pretrained=True)
# Keep every child except the last two so the model returns the final convolutional
# feature map (e.g. torch.Size([1, 2048, 7, 7]) for one 224x224 image) instead of class logits.
resnet = torch.nn.Sequential(*list(resnet.children())[:-2]).to(device)
# Evaluation mode: BatchNorm layers use their running statistics instead of batch statistics.
resnet.eval()



Sequential(
  (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (4): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)


In [2]:
# Image preprocessing pipeline, normalised with ImageNet channel statistics.
IMAGENET_MEAN = [0.485, 0.456, 0.406]  # per-channel (R, G, B) mean of ImageNet
IMAGENET_STD = [0.229, 0.224, 0.225]   # per-channel (R, G, B) standard deviation of ImageNet

preprocess = transforms.Compose([
    transforms.Resize(256),      # rescale so the shorter side is 256 pixels
    transforms.CenterCrop(224),  # keep the central 224x224 region
    transforms.ToTensor(),       # PIL image -> (C, H, W) tensor
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),  # per-channel standardisation
])

# Load an image
def load_image(image_path, device=None):
    """Read an image file and turn it into a preprocessed single-image batch.

    Args:
        image_path: Path of the image file to read.
        device: Device the returned tensor is moved to. When None (the default),
            CUDA is used if available, otherwise the CPU.

    Returns:
        The `preprocess`ed image with a leading batch dimension of size 1
        (shape (1, 3, 224, 224) with the preprocessing pipeline defined in this cell).
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Image.open keeps the underlying file open; the context manager closes it once
    # convert() has produced an independent, fully loaded RGB copy.
    with Image.open(image_path) as img:
        rgb_image = img.convert("RGB")  # guarantee a three-channel RGB image
    tensor = preprocess(rgb_image)
    # Add the batch dimension expected by the model and move to the target device.
    return tensor.unsqueeze(0).to(device)

# Extract image features
def extract_features(image_path):
    """Run the `resnet` backbone on a single image file.

    Args:
        image_path: Path of the image; it is loaded and preprocessed by `load_image`.

    Returns:
        The backbone's output for the one-image batch, computed without gradient
        tracking (e.g. torch.Size([1, 2048, 7, 7]) in the example below).
    """
    batch = load_image(image_path)
    # Inference only: don't record the autograd graph, saving memory and compute.
    with torch.no_grad():
        return resnet(batch)



In [3]:
# Example usage -- point image_path at your own image.
image_path = 'obj_detector/faster_r_cnn/images/img0.jpg'
features = extract_features(image_path)
feature_shape = features.size()  # same torch.Size object type as features.shape
print(feature_shape)  # inspect the shape of the feature map

torch.Size([1, 2048, 7, 7])
