2023.11.19
尝试使用PyTorch内置ViT分类模型

In [2]:
%matplotlib inline
import os
os.environ['TORCH_HOME'] = 'weights'

import torch
import torchvision
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms

import argparse
import torchmetrics     # 用于精度指标计算
# 用于jupyter实时绘图
import pylab as pl
from IPython import display

In [8]:
""" 定义超参 """
# parser = argparse.ArgumentParser(description='ViT分类模型 2023.11.19')
# parser.add_argument('--batch', default=256, help='')

# opt = parser.parse_args()
# print(opt)

EPOCHS = 100
BATCH = 16
NUM_WORKERS = 2
DROP_LAST = True
LR = 0.005
SAVE_DIR = '/Users/sunchengcheng/Projects/D2L/ViT/weights'

In [9]:
""" 构建数据集 """
process_data = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize([224, 224]),        # int 图片短边会缩放到int，长边相应缩放，不是我们想要的正方形
])
# CIFAR10 32x32 colour images in 10 classes
train_dataset = torchvision.datasets.CIFAR10(root='../data/', train=True, transform=process_data, download=True)
test_dataset = torchvision.datasets.CIFAR10(root='../data/', transform=process_data, download=True)
train_iter = DataLoader(dataset=train_dataset, batch_size=BATCH, shuffle=True, num_workers=NUM_WORKERS, drop_last=DROP_LAST)
test_iter = DataLoader(dataset=test_dataset, batch_size=BATCH, shuffle=False, num_workers=NUM_WORKERS, drop_last=DROP_LAST)

Files already downloaded and verified
Files already downloaded and verified


In [19]:
""" 定义ViT_model """
# ViT_model = torchvision.models.vit_b_16(weights=torchvision.models.ViT_B_16_Weights.IMAGENET1K_V1)
ViT_model = torchvision.models.vgg11(num_classes=10)
# ViT_model.heads = nn.Linear(768, train_dataset.classes.__len__())
""" 测试ViT输入输出 """
with torch.no_grad():
    ViT_model.eval()        # 影响BN、dropout层
    x = torch.rand(size=(10, 3, 224, 224))
    y = ViT_model(x)
    print(y.softmax(dim=-1), torch.argmax(torch.softmax(y, dim=-1), dim=-1), sep='\n')

tensor([[0.1035, 0.0967, 0.1085, 0.0962, 0.0985, 0.1016, 0.0970, 0.0992, 0.1006,
         0.0982],
        [0.1040, 0.0973, 0.1080, 0.0958, 0.0985, 0.1017, 0.0968, 0.0991, 0.1006,
         0.0982],
        [0.1039, 0.0970, 0.1084, 0.0959, 0.0984, 0.1013, 0.0969, 0.0992, 0.1007,
         0.0982],
        [0.1039, 0.0965, 0.1084, 0.0959, 0.0986, 0.1016, 0.0972, 0.0991, 0.1006,
         0.0982],
        [0.1035, 0.0970, 0.1085, 0.0961, 0.0985, 0.1016, 0.0970, 0.0989, 0.1005,
         0.0984],
        [0.1036, 0.0970, 0.1081, 0.0964, 0.0984, 0.1014, 0.0968, 0.0994, 0.1007,
         0.0982],
        [0.1039, 0.0966, 0.1085, 0.0961, 0.0987, 0.1015, 0.0969, 0.0992, 0.1004,
         0.0983],
        [0.1035, 0.0971, 0.1084, 0.0960, 0.0986, 0.1017, 0.0969, 0.0991, 0.1005,
         0.0982],
        [0.1038, 0.0967, 0.1083, 0.0960, 0.0985, 0.1017, 0.0968, 0.0991, 0.1006,
         0.0983],
        [0.1035, 0.0969, 0.1085, 0.0961, 0.0983, 0.1017, 0.0969, 0.0992, 0.1008,
         0.0980]])
tensor([2

In [20]:
""" 定义loss、optim """
loss_func = nn.CrossEntropyLoss()
optimizer = optim.SGD(params=ViT_model.parameters(), lr=LR)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=train_iter.__len__())

In [21]:
def evaluate(model, test_iter, device):
    metric_f1 = torchmetrics.F1Score(task='multiclass', num_classes=len(test_iter.dataset.classes)).to(device)
    """ 使用验证集评估模型 """
    with torch.no_grad():
        model.eval()
        correct = 0
        total = 0
        for i, (x, y) in enumerate(test_iter):
            x, y = x.to(device), y.to(device)
            y_hat = model(x)
            correct += torch.sum(y_hat.softmax(dim=-1).argmax(dim=-1) == y)
            total += x.shape[0]
            acc = correct / total
            batch_f1 = metric_f1(y_hat.softmax(dim=-1).argmax(dim=-1), y)
            print(f'---------->evaluate {i}/{len(test_iter)} acc:{acc.item() * 100:.7f}% batch_f1 {batch_f1 * 100:.4f}% ')
        eval_f1 = metric_f1.compute()
        print(f'---------->evaluate eval_f1:{eval_f1 * 100:.4f}%')
        return acc.item(), eval_f1

In [22]:
""" 训练 """
max_f1 = 0      # 记录最大f1分数，保存模型
lr_iter = []
acc_iter = []
correct = 0
total = 0
iter = 0
device = torch.device('mps')
# 精度指标计算
metric_acc = torchmetrics.Accuracy(task='multiclass', num_classes=len(train_dataset.classes)).to(device)
metric_f1 = torchmetrics.F1Score(task='multiclass', num_classes=len(train_dataset.classes)).to(device)
metric_precision = torchmetrics.Precision
ViT_model.to(device=device)
for epoch in range(EPOCHS):
    ViT_model.train()               # 切换到train模式
    for i, (x, y) in enumerate(train_iter):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        y_hat = ViT_model(x)
        # 统计精度
        correct += torch.sum(y_hat.softmax(dim=-1).argmax(dim=-1) == y)
        total += x.shape[0]
        acc = correct / total
        acc_iter.append(acc.item())
        batch_acc = metric_acc(y_hat.softmax(dim=-1).argmax(dim=-1), y)     # 计算batch的精度；已验证metric_acc.compute()与自己写的全局acc同
        batch_f1 = metric_f1(y_hat.softmax(dim=-1).argmax(dim=-1), y)
        lr_iter.append(scheduler.get_last_lr()[0] * 200)        # 收集学习率放大200倍，方便显示
        iter += 1
        print(f'epoch {epoch}/{EPOCHS} iter {i}/{len(train_iter)} lr {lr_iter[-1]:.7f} \
              acc {acc * 100:.4f}% batch_acc {batch_acc * 100:.4f}% batch_f1 {batch_f1 * 100:.4f} ')
        # backward update
        loss = loss_func(y_hat, y)
        loss.backward()
        optimizer.step()
        scheduler.step()
        # draw acc
        # pl.clf()
        # pl.plot(acc_iter)
        # pl.plot(lr_iter)
        # display.display(pl.gcf())
        # display.clear_output(True)
    # evalute model
    epoch_acc = metric_acc.compute()
    epoch_f1 = metric_f1.compute()
    metric_acc.reset()
    metric_f1.reset()
    print(f'epoch {epoch}/{EPOCHS} epoch performance: epoch_acc {epoch_acc * 100:.4f} epoch_f1 {epoch_f1 * 100:.4f}')
    acc, eval_f1 = evaluate(ViT_model, test_iter, device)
    if eval_f1 > max_f1:
        max_f1 = eval_f1
        print(f'*************** Find Better Model, Saving Model to {SAVE_DIR} *****************')
        torch.save(ViT_model.state_dict(), os.path.join(SAVE_DIR, f'best.pth'))

  tp = tp.sum(dim=0 if multidim_average == "global" else 1)


epoch 0/100 iter 0/3125 lr 1.0000000               acc 18.7500% batch_acc 18.7500% batch_f1 18.7500 
epoch 0/100 iter 1/3125 lr 0.9999997               acc 9.3750% batch_acc 0.0000% batch_f1 0.0000 
epoch 0/100 iter 2/3125 lr 0.9999990               acc 8.3333% batch_acc 6.2500% batch_f1 6.2500 
epoch 0/100 iter 3/3125 lr 0.9999977               acc 7.8125% batch_acc 6.2500% batch_f1 6.2500 
epoch 0/100 iter 4/3125 lr 0.9999960               acc 6.2500% batch_acc 0.0000% batch_f1 0.0000 
epoch 0/100 iter 5/3125 lr 0.9999937               acc 6.2500% batch_acc 6.2500% batch_f1 6.2500 
epoch 0/100 iter 6/3125 lr 0.9999909               acc 8.0357% batch_acc 18.7500% batch_f1 18.7500 
epoch 0/100 iter 7/3125 lr 0.9999876               acc 7.0312% batch_acc 0.0000% batch_f1 0.0000 
epoch 0/100 iter 8/3125 lr 0.9999838               acc 6.2500% batch_acc 0.0000% batch_f1 0.0000 
epoch 0/100 iter 9/3125 lr 0.9999795               acc 6.8750% batch_acc 12.5000% batch_f1 12.5000 
epoch 0/100 i



---------->evaluate 0/3125 acc:37.5000000% batch_f1 37.5000% 
---------->evaluate 1/3125 acc:34.3750000% batch_f1 31.2500% 
---------->evaluate 2/3125 acc:37.5000000% batch_f1 43.7500% 
---------->evaluate 3/3125 acc:34.3750000% batch_f1 25.0000% 
---------->evaluate 4/3125 acc:38.7499988% batch_f1 56.2500% 
---------->evaluate 5/3125 acc:34.3750000% batch_f1 12.5000% 
---------->evaluate 6/3125 acc:35.7142866% batch_f1 43.7500% 
---------->evaluate 7/3125 acc:35.1562500% batch_f1 31.2500% 
---------->evaluate 8/3125 acc:35.4166657% batch_f1 37.5000% 
---------->evaluate 9/3125 acc:33.7500006% batch_f1 18.7500% 
---------->evaluate 10/3125 acc:32.9545468% batch_f1 25.0000% 
---------->evaluate 11/3125 acc:33.8541657% batch_f1 43.7500% 
---------->evaluate 12/3125 acc:35.5769217% batch_f1 56.2500% 
---------->evaluate 13/3125 acc:36.6071433% batch_f1 50.0000% 
---------->evaluate 14/3125 acc:37.0833337% batch_f1 43.7500% 
---------->evaluate 15/3125 acc:37.1093750% batch_f1 37.5000% 
--



epoch 1/100 iter 0/3125 lr 0.0000000               acc 26.0617% batch_acc 25.0000% batch_f1 25.0000 
epoch 1/100 iter 1/3125 lr 0.0000003               acc 26.0613% batch_acc 25.0000% batch_f1 25.0000 


RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

In [None]:
ViT_model.to(torch.device('mps'))
evaluate(ViT_model, test_iter, torch.device('mps'))



---------->evaluate acc:4.6875000% 
---------->evaluate acc:5.4687500% 
---------->evaluate acc:6.2500000% 
---------->evaluate acc:6.6406250% 
---------->evaluate acc:7.1874999% 
---------->evaluate acc:8.3333336% 
---------->evaluate acc:9.1517858% 
---------->evaluate acc:9.7656250% 
---------->evaluate acc:9.3750000% 
---------->evaluate acc:10.0000001% 
---------->evaluate acc:10.2272727% 
---------->evaluate acc:10.5468750% 
---------->evaluate acc:10.3365384% 
---------->evaluate acc:9.9330358% 
---------->evaluate acc:10.0000001% 
---------->evaluate acc:9.8632812% 
---------->evaluate acc:9.5588237% 
---------->evaluate acc:9.6354164% 
---------->evaluate acc:9.8684214% 
---------->evaluate acc:10.0781247% 
---------->evaluate acc:10.1190478% 
---------->evaluate acc:9.9431820% 
---------->evaluate acc:10.0543477% 
---------->evaluate acc:10.0911461% 
---------->evaluate acc:10.0625001% 
---------->evaluate acc:9.9759616% 
---------->evaluate acc:10.1273149% 
---------->evalua

In [18]:
ViT_model

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (11): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): ReLU(inplace=True)
    (13): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (14): ReLU(inplace=True)
    (15): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
 