In [1]:
!pip install torch torchvision timm

Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Collecting timm
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/e7/0e/ef97f6d8c399bf5842af0dd5a4f5ac55b2f169d62e29ecbf7663e1cb1438/timm-1.0.9-py3-none-any.whl (2.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.3/2.3 MB[0m [31m301.9 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface_hub (from timm)
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/5f/f1/15dc793cb109a801346f910a6b350530f2a763a6e83b221725a0bcc1e297/huggingface_hub-0.25.1-py3-none-any.whl (436 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m436.4/436.4 kB[0m [31m312.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting safetensors (from timm)
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/b9/df/6f766b56690709d22e83836e4067a1109a7d84ea152a6deb5692743a2805/safetensors-0.4.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (435 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import timm  # 导入timm库

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Per-channel normalization constants shared by both pipelines.
# NOTE(review): 0.5/0.5 are not the classic ImageNet statistics
# (0.485, 0.456, 0.406 / 0.229, 0.224, 0.225); presumably they mirror the
# preprocessing the pretrained checkpoint was trained with - confirm against
# the model's pretrained config.
_CHANNEL_MEAN = [0.5, 0.5, 0.5]
_CHANNEL_STD = [0.5, 0.5, 0.5]

# Evaluation pipeline: resize to 224 -> tensor -> normalize (no randomness).
transform_test = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
])

# Training pipeline: the same steps plus a random horizontal flip applied
# right after the resize.
transform_train = transforms.Compose([
    transforms.Resize(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
])

In [4]:
# Where the CIFAR-10 archive is downloaded to / read from.
_DATA_ROOT = './data'
# Options common to the training and the test loader.
_LOADER_OPTS = dict(batch_size=32, num_workers=2)

# Training split, wrapped in a loader that reshuffles the samples.
trainset = torchvision.datasets.CIFAR10(
    root=_DATA_ROOT, train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(
    trainset, shuffle=True, **_LOADER_OPTS)

# Test split, iterated in a fixed order.
testset = torchvision.datasets.CIFAR10(
    root=_DATA_ROOT, train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(
    testset, shuffle=False, **_LOADER_OPTS)

# Human-readable class names; the per-class report looks them up by label index.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:16<00:00, 10466228.35it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [5]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pretrained ViT model.
model = timm.create_model('vit_base_patch16_224', pretrained=True)

# Replace the classification head with a new linear layer that keeps the
# head's input width but produces 10 outputs.
num_ftrs = model.head.in_features
model.head = nn.Linear(in_features=num_ftrs, out_features=10)

# Move the parameters to the selected device (Module.to returns the module).
model = model.to(device)

# Last expression of the cell: display the model architecture.
model

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)
        (norm): Identity(

In [6]:
# Cross-entropy loss applied to the model outputs and the integer labels.
criterion = nn.CrossEntropyLoss()

# Adam over every parameter of the model with a learning rate of 1e-4.
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

In [None]:
num_epochs = 5
log_every = 100  # print the mean loss once per this many mini-batches

for epoch in range(1, num_epochs + 1):
    model.train()
    window_loss = 0.0
    for step, batch in enumerate(trainloader, start=1):
        inputs = batch[0].to(device)
        labels = batch[1].to(device)

        # One optimisation step: clear old gradients, forward pass,
        # back-propagate the loss, update the parameters.
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()

        # Accumulate the loss and report its mean over the last window.
        window_loss += loss.item()
        if step % log_every == 0:
            print('[%d, %5d] loss: %.3f' %
                  (epoch, step, window_loss / log_every))
            window_loss = 0.0

# Message text means "training finished".
print('训练完成')

[1,   100] loss: 0.445
[1,   200] loss: 0.315
[1,   300] loss: 0.263
[1,   400] loss: 0.224
[1,   500] loss: 0.246
[1,   600] loss: 0.262
[1,   700] loss: 0.222
[1,   800] loss: 0.237
[1,   900] loss: 0.202
[1,  1000] loss: 0.214
[1,  1100] loss: 0.206
[1,  1200] loss: 0.204
[1,  1300] loss: 0.191
[1,  1400] loss: 0.192
[1,  1500] loss: 0.197
[2,   100] loss: 0.151
[2,   200] loss: 0.126
[2,   300] loss: 0.128
[2,   400] loss: 0.137
[2,   500] loss: 0.160
[2,   600] loss: 0.161
[2,   700] loss: 0.160
[2,   800] loss: 0.137
[2,   900] loss: 0.145
[2,  1000] loss: 0.180
[2,  1100] loss: 0.169
[2,  1200] loss: 0.144
[2,  1300] loss: 0.148
[2,  1400] loss: 0.175
[2,  1500] loss: 0.163
[3,   100] loss: 0.108
[3,   200] loss: 0.125
[3,   300] loss: 0.131
[3,   400] loss: 0.112
[3,   500] loss: 0.133
[3,   600] loss: 0.143
[3,   700] loss: 0.117
[3,   800] loss: 0.107
[3,   900] loss: 0.117
[3,  1000] loss: 0.113
[3,  1100] loss: 0.125
[3,  1200] loss: 0.143
[3,  1300] loss: 0.104
[3,  1400] 

In [9]:
# Overall top-1 accuracy over the whole test loader.
model.eval()
total = 0
correct = 0
with torch.no_grad():  # gradients are not needed while evaluating
    for batch in testloader:
        images = batch[0].to(device)
        labels = batch[1].to(device)
        logits = model(images)
        # Predicted class = index of the largest output along dim 1.
        predicted = torch.max(logits, dim=1).indices
        total += labels.shape[0]
        correct += predicted.eq(labels).sum().item()

# The %d conversion drops the fractional part of the percentage.
accuracy_pct = 100 * correct / total
print('在10000张测试图片上的准确率为: %d %%' % accuracy_pct)

在10000张测试图片上的准确率为: 74 %


In [10]:
# Per-class top-1 accuracy over the test loader.
num_classes = len(classes)
class_correct = [0.0] * num_classes  # correct predictions per true label
class_total = [0.0] * num_classes    # samples seen per true label

# Put the model in evaluation mode here too, so this cell does not depend on
# an earlier cell having called model.eval().
model.eval()
with torch.no_grad():
    for data in testloader:
        images, labels = data[0].to(device), data[1].to(device)
        outputs = model(images)
        predicted = torch.max(outputs, 1).indices
        # The comparison already has one entry per sample. Squeezing it would
        # turn a batch of size 1 into a 0-d tensor that cannot be indexed.
        # Convert each batch to Python lists once instead of calling .item()
        # for every single sample.
        hits = (predicted == labels).tolist()
        for label, hit in zip(labels.tolist(), hits):
            class_correct[label] += hit
            class_total[label] += 1

for i in range(num_classes):
    print('类别 %5s 的准确率: %2d %%' % (
        classes[i], 100 * class_correct[i] / class_total[i]))

类别 plane 的准确率: 80 %
类别   car 的准确率: 92 %
类别  bird 的准确率: 67 %
类别   cat 的准确率: 56 %
类别  deer 的准确率: 83 %
类别   dog 的准确率: 60 %
类别  frog 的准确率: 78 %
类别 horse 的准确率: 74 %
类别  ship 的准确率: 84 %
类别 truck 的准确率: 72 %
