In [4]:
import torch
import torch.nn as nn

from timm.models.vision_transformer import VisionTransformer, PatchEmbed
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import os

# VPT-Deep: a timm VisionTransformer with learnable prompt tokens injected before every block.
class VPT(VisionTransformer):
    """Vision Transformer with deep visual prompt tuning (VPT-Deep).

    Before each transformer block a dedicated set of ``prompt_num`` learnable
    tokens is appended to the token sequence; the block's outputs at those
    positions are discarded afterwards, so the sequence length seen by the
    next block (before its own prompts are appended) is unchanged.

    Args:
        image_size, patch_size, in_ch, num_classes, embed_dim, depth, mlp_ratio,
        qkv_bias, drop_rate, attn_drop_rate, drop_path_rate, embed_layer,
        norm_layer, act_layer, num_heads: forwarded to ``VisionTransformer``
            (``image_size`` -> ``img_size``, ``in_ch`` -> ``in_chans``).
        prompt_num: number of prompt tokens per block.
        state_dict: optional weights loaded with ``strict=False`` after
            construction (missing keys such as ``prompt`` are tolerated);
            ``head.*`` entries whose shape does not match this model's head
            are skipped so a checkpoint with a different class count can be used.
    """

    def __init__(self, image_size=224, patch_size=16, in_ch=3, num_classes=120, embed_dim=768,
                 depth=12, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
                 embed_layer=PatchEmbed, norm_layer=None, act_layer=None, prompt_num=100, state_dict=None, num_heads=12):

        super().__init__(img_size=image_size, patch_size=patch_size, in_chans=in_ch, num_classes=num_classes,
                         embed_dim=embed_dim, depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio,
                         qkv_bias=qkv_bias, drop_rate=drop_rate, attn_drop_rate=attn_drop_rate,
                         drop_path_rate=drop_path_rate, embed_layer=embed_layer,
                         norm_layer=norm_layer, act_layer=act_layer)

        self.prompt_num = prompt_num
        self.depth = depth
        # One (prompt_num, embed_dim) prompt matrix per block.
        self.prompt = nn.Parameter(torch.zeros(self.depth, self.prompt_num, embed_dim))
        # Random (not constant) init: prompts carry no positional embedding and
        # attention is symmetric in its keys/values, so identically-initialised
        # prompt tokens receive identical gradients and would stay identical
        # forever. Uniform Xavier-style bound as in the VPT reference code.
        ph, pw = (patch_size, patch_size) if isinstance(patch_size, int) else tuple(patch_size)
        bound = (6.0 / float(in_ch * ph * pw + embed_dim)) ** 0.5
        nn.init.uniform_(self.prompt, -bound, bound)
        # Fresh linear classifier for the target task (replaces the head timm built).
        self.head = nn.Linear(self.embed_dim, self.num_classes)
        if state_dict is not None:
            self.load_state_dict(self._drop_mismatched_head(state_dict), strict=False)

    def _drop_mismatched_head(self, state_dict):
        """Return ``state_dict`` without ``head.*`` tensors whose shape differs from this model's head.

        ``load_state_dict(strict=False)`` only tolerates missing/unexpected keys;
        a size mismatch still raises, which happens whenever the checkpoint's
        classifier was trained for a different number of classes.
        """
        own = self.state_dict()
        stale = {
            k for k, v in state_dict.items()
            if k.startswith("head.") and k in own
            and tuple(getattr(v, "shape", ())) != tuple(own[k].shape)
        }
        if not stale:
            return state_dict
        return {k: v for k, v in state_dict.items() if k not in stale}

    def Freeze(self):
        """Freeze the backbone; keep only the prompts and the classification head trainable."""
        self.requires_grad_(False)
        self.prompt.requires_grad_(True)
        self.head.requires_grad_(True)

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: image batch; presumably ``(B, in_ch, image_size, image_size)`` — TODO confirm.

        Returns:
            Logits of shape ``(B, num_classes)`` computed from the class token.
        """
        x = self.patch_embed(x)
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_token, x), dim=1)
        x = self.pos_drop(x + self.pos_embed)

        for i in range(self.depth):
            n_keep = x.shape[1]
            prompts = self.prompt[i].unsqueeze(0).expand(x.shape[0], -1, -1)
            x = self.blocks[i](torch.cat((x, prompts), dim=1))
            # Discard this layer's prompt outputs; the next layer gets its own prompts.
            x = x[:, :n_keep]

        # Every block has been applied exactly once above — do not run
        # ``self.blocks`` again here (that would apply each layer a second time).
        x = self.norm(x)
        x = self.fc_norm(x[:, 0, :])
        return self.head(x)

# Load the pretrained backbone weights, build the prompt-tuned model, and set up training.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
pretrained_path = "pytorch_model.bin"
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code;
# only load checkpoints from a trusted source (newer torch offers weights_only=True).
state_dict = torch.load(pretrained_path, map_location=device)

model = VPT(num_classes=120, state_dict=state_dict).to(device)
model.Freeze()  # only prompts + head remain trainable

criterion = nn.CrossEntropyLoss()
# Hand the optimizer just the parameters left trainable by Freeze().
optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad], lr=1e-4)

In [5]:
# Smoke test: push a batch of random "images" through the model and check the logit shape.
noise_gen = torch.Generator().manual_seed(0)  # fixed seed so the check is reproducible
random_noise = torch.randn(8, 3, 224, 224, generator=noise_gen).to(device)

# Inference-mode forward pass; restore the previous mode so later training cells are unaffected.
was_training = model.training
model.eval()
with torch.no_grad():
    output = model(random_noise)
model.train(was_training)

print("Random Noise Input Shape:", random_noise.shape)
print("Model Output Shape:", output.shape)
assert output.shape == (random_noise.shape[0], model.num_classes), output.shape


Random Noise Input Shape: torch.Size([8, 3, 224, 224])
Model Output Shape: torch.Size([8, 120])
