In [1]:
import torch
import torch.nn as nn

from timm.models.vision_transformer import VisionTransformer, PatchEmbed
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import os

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
class ImagePrompt(nn.Module):
    def __init__(self):
        super().__init__()
        self.n = 16
        self.h = 16
        self.net1 = nn.Sequential(
            nn.Linear(3 * self.n * self.n, self.h),
            nn.ReLU(),
            nn.Linear(self.h, 3 * self.n * self.n)
        )
        self.dropout1 = torch.nn.Dropout(0.1)
        self.dropout2 = torch.nn.Dropout(0.1)
        self.net2 = nn.Sequential(
            nn.Conv2d(3, 16, 3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 3, 3, stride=1, padding=1)
        )
        self.w = nn.Parameter(torch.tensor(0.7))

    def get_local_prompts(self, x):
            # [64, 3, 224, 224]
        B = x.shape[0]
        n_patch = int(224 / self.n)
        x = x.reshape(B, 3, n_patch, self.n, n_patch, self.n) # [64, 3, 14, 16, 14, 16]
        x = x.permute(0, 2, 4, 1, 3, 5) # [64, 14, 14, 3, 16, 16]
        x = x.reshape(B, n_patch * n_patch, 3 * self.n * self.n)
        x = x.reshape(B * n_patch * n_patch, 3 * self.n * self.n)
        x = self.net1(x)
        x = x.reshape(B, n_patch, n_patch, 3, self.n, self.n)
        x = x.permute(0, 3, 1, 4, 2, 5) # [64, 3, 14, 16, 14, 16]
        x = x.reshape(B, 3, 224, 224)
        return self.dropout1(x)
    
    def forward(self, x):
        prompt1 = self.get_local_prompts(x)
        prompt2 = self.dropout2(self.net2(x))
        w = torch.sigmoid(self.w) # 保证在（0，1）
        return (1 - w) * prompt1 + w * prompt2 + x
    
# model = Local_prompt().to(device)
# random_noise = torch.randn(8, 3, 224, 224)
# random_noise = random_noise.to(device)

# # 前向传播测试
# with torch.no_grad():
#     output = model(random_noise)

# # 输出结果
# print("Random Noise Input Shape:", random_noise.shape)
# print("Model Output Shape:", output.shape)

In [3]:
class TokenPrompt(nn.Module):
    def __init__(self, prompt_num = 9):
        super().__init__()
        self.p = prompt_num
        self.conv1 = nn.Conv2d(3, self.p, kernel_size=7, padding=3)
        self.relu1 = nn.LeakyReLU()
        self.pool1 = nn.MaxPool2d(4, 4)
        self.dropout1 = nn.Dropout(0.1)
        self.conv2 = nn.Conv2d(self.p, 3 * self.p, kernel_size = 9, padding = 4)
        self.relu2 = nn.LeakyReLU()
        self.pool2 = nn.MaxPool2d(3, 3)
        self.dropout2 = nn.Dropout(0.1)
        self.conv3 = nn.Conv2d(3 * self.p, 3 * self.p, kernel_size = 3, padding = 1)
             
    def forward(self, x):
        x = self.conv1(x) 
        x = self.relu1(x)
        x = x[:,:,8:216,8:216]
        x = self.pool1(x) # [B, 9, 56, 56]
        x = self.dropout1(x)
#         print(x.shape)
        
        x = self.conv2(x) 
        x = self.relu2(x)
        x = x[:,:,2:50,2:50]
        x = self.pool2(x) # [B, 27, 16, 16]
        x = self.dropout2(x)
        
        x = self.conv3(x) 
        x = x.reshape(-1, self.p, 768)
#         print(x.shape)
        return x

# model = TokenPrompt().to(device)
# random_noise = torch.randn(8, 3, 224, 224)
# random_noise = random_noise.to(device)

# with torch.no_grad():
#     output = model(random_noise)

# # 输出结果
# print("Random Noise Input Shape:", random_noise.shape)
# print("Model Output Shape:", output.shape)

In [4]:
class InsVP(VisionTransformer):
    def __init__(self, image_size=224, patch_size=16, in_ch=3, num_classes=120, embed_dim=768,
                 depth=12, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
                 embed_layer=PatchEmbed, norm_layer=None, act_layer=None, prompt_num=9, state_dict=None, num_heads=12):

        super().__init__(img_size=image_size, patch_size=patch_size, in_chans=in_ch, num_classes=num_classes,
                         embed_dim=embed_dim, depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio,
                         qkv_bias=qkv_bias, drop_rate=drop_rate, attn_drop_rate=attn_drop_rate,
                         drop_path_rate=drop_path_rate, embed_layer=embed_layer,
                         norm_layer=norm_layer, act_layer=act_layer)
        
        self.prompt_num = prompt_num
        self.depth = depth
        self.insprompt = nn.Parameter(torch.zeros(self.depth, self.prompt_num, embed_dim))
        self.head = nn.Linear(self.embed_dim, self.num_classes)
        if state_dict is not None:
            self.load_state_dict(state_dict, strict=False)
        self.get_image_prompt = ImagePrompt()
        self.get_token_prompt = TokenPrompt(prompt_num = self.prompt_num)
        self.w = nn.Parameter(torch.tensor(0.5))

    def Freeze(self):
        for param in self.parameters():
            param.requires_grad = False
        self.insprompt.requires_grad = True
        self.w.requires_grad = True
        for param in self.head.parameters():
            param.requires_grad = True
        for param in self.get_image_prompt.parameters():
            param.requires_grad = True
        for param in self.get_token_prompt.parameters():
            param.requires_grad = True

    def forward(self, x):
        token_prompt = self.get_token_prompt(x)
        x = self.get_image_prompt(x)
        x = self.patch_embed(x)
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_token, x), dim=1)
        x = self.pos_drop(x + self.pos_embed)
        w = torch.sigmoid(self.w)
        
        for i in range(self.depth):
            prompt = w * self.insprompt[i].unsqueeze(0).expand(x.shape[0], -1, -1) + (1- w) * token_prompt
            x = torch.cat((x[:, :1, :], prompt, x[:, (1 + self.prompt_num):, :]), dim=1)
            num_tokens = x.shape[1]
            x = self.blocks[i](x)
#             print(x.shape)
            x = x[:, :num_tokens - self.prompt_num]
#             print(x.shape)
        
        x = self.blocks(x)
        x = self.norm(x)
        x = self.fc_norm(x[:, 0, :])
        x = self.head(x)
        return x
    
# model = InsVP(num_classes=102).to(device)
# random_noise = torch.randn(1, 3, 224, 224)
# random_noise = random_noise.to(device)

# with torch.no_grad():
#     output = model(random_noise)
#     predicted_class = torch.argmax(output, dim=1).item()

# print(output.shape)
# print(predicted_class)

In [5]:
pretrained_path = "vit_base_p16_224_in22k.pth"
state_dict = torch.load(pretrained_path, map_location=device)
# 移除预训练权重中的 head 层参数
state_dict.pop("head.weight", None)
state_dict.pop("head.bias", None)
model = InsVP(num_classes=102, state_dict = state_dict).to(device)
model.Freeze()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)

# model

  state_dict = torch.load(pretrained_path, map_location=device)


In [6]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )

transform_train = transforms.Compose(
            [
                transforms.Resize(256),
                transforms.RandomCrop(224),
                transforms.RandomHorizontalFlip(0.5),
                transforms.ToTensor(),
                normalize,
            ]
        )

transform_test = transforms.Compose(
            [
                transforms.Resize(256),
                transforms.RandomCrop(224),
#                 transforms.RandomHorizontalFlip(0.5),
                transforms.ToTensor(),
                normalize,
            ]
        )

train_dataset = datasets.ImageFolder(root='flower102//prepare_pic//train', transform=transform_train)
test_dataset = datasets.ImageFolder(root='flower102//prepare_pic//test', transform=transform_test)


train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=True)


In [7]:
criterion = nn.CrossEntropyLoss()

# 训练循环
num_epochs = 100
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        _, predicted = outputs.max(1)
        correct += predicted.eq(labels).sum().item()
        total += labels.size(0)

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss/len(train_loader):.4f}, Accuracy: {100 * correct / total:.2f}%")

# 保存训练后的模型
torch.save(model.state_dict(), "InsVP.pth")

Epoch [1/100], Loss: 4.6247, Accuracy: 0.98%
Epoch [2/100], Loss: 4.6094, Accuracy: 1.47%
Epoch [3/100], Loss: 4.5966, Accuracy: 2.84%
Epoch [4/100], Loss: 4.5835, Accuracy: 4.02%
Epoch [5/100], Loss: 4.5635, Accuracy: 6.67%
Epoch [6/100], Loss: 4.5429, Accuracy: 11.57%
Epoch [7/100], Loss: 4.5157, Accuracy: 15.98%
Epoch [8/100], Loss: 4.4819, Accuracy: 20.00%
Epoch [9/100], Loss: 4.4264, Accuracy: 26.76%
Epoch [10/100], Loss: 4.3388, Accuracy: 36.57%
Epoch [11/100], Loss: 4.2470, Accuracy: 49.51%
Epoch [12/100], Loss: 4.1509, Accuracy: 57.16%
Epoch [13/100], Loss: 4.0541, Accuracy: 66.67%
Epoch [14/100], Loss: 3.9482, Accuracy: 70.49%
Epoch [15/100], Loss: 3.8510, Accuracy: 76.18%
Epoch [16/100], Loss: 3.7462, Accuracy: 79.80%
Epoch [17/100], Loss: 3.6393, Accuracy: 82.75%
Epoch [18/100], Loss: 3.5330, Accuracy: 85.78%
Epoch [19/100], Loss: 3.4330, Accuracy: 87.25%
Epoch [20/100], Loss: 3.3162, Accuracy: 88.92%
Epoch [21/100], Loss: 3.2064, Accuracy: 90.39%
Epoch [22/100], Loss: 3.101

In [8]:
model.eval()
total_loss = 0
correct = 0
total = 0

with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        loss = criterion(outputs, labels)

        _, predicted = outputs.max(1)
        correct += predicted.eq(labels).sum().item()
        total += labels.size(0)

# 计算测试集上的平均损失和准确率
print(f"Accuracy: {100 * correct / total:.2f}%")

Accuracy: 96.81%
