2023.11.19
尝试使用PyTorch内置ViT分类模型

In [1]:
%matplotlib inline
import os
os.environ['TORCH_HOME'] = 'weights'

import torch
import torchvision
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms

import argparse
# 用于jupyter实时绘图
import pylab as pl
from IPython import display

In [2]:
""" 定义超参 """
# parser = argparse.ArgumentParser(description='ViT分类模型 2023.11.19')
# parser.add_argument('--batch', default=256, help='')

# opt = parser.parse_args()
# print(opt)

EPOCHS = 100
BATCH = 64
NUM_WORKERS = 4
DROP_LAST = True
LR = 0.005

In [3]:
""" 构建数据集 """
process_data = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize([224, 224]),        # int 图片短边会缩放到int，长边相应缩放，不是我们想要的正方形
])
# CIFAR10 32x32 colour images in 10 classes
train_dataset = torchvision.datasets.CIFAR10(root='../data/', train=True, transform=process_data, download=True)
test_dataset = torchvision.datasets.CIFAR10(root='../data/', transform=process_data, download=True)
train_iter = DataLoader(dataset=train_dataset, batch_size=BATCH, shuffle=True, num_workers=NUM_WORKERS, drop_last=DROP_LAST)
test_iter = DataLoader(dataset=test_dataset, batch_size=BATCH, shuffle=False, num_workers=NUM_WORKERS, drop_last=DROP_LAST)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
""" 定义ViT_model """
ViT_model = torchvision.models.vit_b_16()
ViT_model.heads = nn.Linear(768, train_dataset.classes.__len__())
""" 测试ViT输入输出 """
with torch.no_grad():
    ViT_model.eval()        # 影响BN、dropout层
    x = torch.rand(size=(10, 3, 224, 224))
    y = ViT_model(x)
    print(y.softmax(dim=-1), torch.argmax(torch.softmax(y, dim=-1), dim=-1), sep='\n')

tensor([[0.2709, 0.0695, 0.0289, 0.0786, 0.0685, 0.0554, 0.0802, 0.0568, 0.1684,
         0.1228],
        [0.2728, 0.0701, 0.0303, 0.0795, 0.0695, 0.0550, 0.0764, 0.0569, 0.1679,
         0.1215],
        [0.2709, 0.0743, 0.0314, 0.0792, 0.0707, 0.0549, 0.0782, 0.0570, 0.1620,
         0.1214],
        [0.2669, 0.0714, 0.0301, 0.0766, 0.0724, 0.0549, 0.0771, 0.0586, 0.1656,
         0.1264],
        [0.2769, 0.0711, 0.0303, 0.0786, 0.0699, 0.0550, 0.0844, 0.0559, 0.1570,
         0.1210],
        [0.2593, 0.0714, 0.0306, 0.0797, 0.0692, 0.0559, 0.0785, 0.0609, 0.1689,
         0.1257],
        [0.2595, 0.0746, 0.0304, 0.0812, 0.0686, 0.0558, 0.0826, 0.0568, 0.1651,
         0.1255],
        [0.2670, 0.0715, 0.0302, 0.0799, 0.0724, 0.0564, 0.0799, 0.0555, 0.1610,
         0.1261],
        [0.2714, 0.0701, 0.0300, 0.0778, 0.0706, 0.0550, 0.0754, 0.0555, 0.1668,
         0.1273],
        [0.2631, 0.0771, 0.0303, 0.0803, 0.0721, 0.0572, 0.0770, 0.0587, 0.1661,
         0.1181]])
tensor([0

In [5]:
""" 定义loss、optim """
loss_func = nn.CrossEntropyLoss()
optimizer = optim.SGD(params=ViT_model.parameters(), lr=LR)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=train_iter.__len__())

In [8]:
def evaluate(model, test_iter, device):
    """ 使用验证集评估模型 """
    with torch.no_grad():
        model.eval()
        correct = 0
        total = 0
        for i, (x, y) in enumerate(test_iter):
            x, y = x.to(device), y.to(device)
            y_hat = model(x)
            correct += torch.sum(y_hat.softmax(dim=-1).argmax(dim=-1) == y)
            total += x.shape[0]
            acc = correct / total
            print(f'---------->evaluate acc:{acc.item() * 100:.7f}% ')
        return acc.item()

In [9]:
""" 训练 """
lr_iter = []
acc_iter = []
correct = 0
total = 0
iter = 0
device = torch.device('mps')
ViT_model.to(device=device)
for epoch in range(EPOCHS):
    ViT_model.train()               # 切换到train模式
    for i, (x, y) in enumerate(train_iter):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        y_hat = ViT_model(x)
        # 统计精度
        correct += torch.sum(y_hat.softmax(dim=-1).argmax(dim=-1) == y)
        total += x.shape[0]
        acc = correct / total
        acc_iter.append(acc.item())
        lr_iter.append(scheduler.get_last_lr()[0])
        iter += 1
        print(f'epoch {epoch}/{EPOCHS} iter {i}/{len(train_iter)} lr {lr_iter[-1]:.7f} acc {acc * 100:.4f}% ')
        # backward update
        loss = loss_func(y_hat, y)
        loss.backward()
        optimizer.step()
        scheduler.step()
        # draw acc
        pl.clf()
        pl.plot(acc_iter)
        pl.plot(lr_iter)
        display.display(pl.gcf())
        display.clear_output(True)
    # evalute model
    _ = evaluate(ViT_model, test_iter, device)



epoch 0/100 iter 0/781 lr 0.0049299 acc 10.9375% 


VisionTransformer(
  (conv_proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  (encoder): Encoder(
    (dropout): Dropout(p=0.0, inplace=False)
    (layers): Sequential(
      (encoder_layer_0): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (dropout): Dropout(p=0.0, inplace=False)
        (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=3072, out_features=768, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )
      (encoder_layer_1): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_a