2023.11.19
尝试使用PyTorch内置ViT分类模型

In [2]:
%matplotlib inline
import os
os.environ['TORCH_HOME'] = 'weights'

import torch
import torchvision
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms

import argparse
import torchmetrics     # 用于精度指标计算
# 用于jupyter实时绘图
import pylab as pl
from IPython import display

In [3]:
""" 定义超参 """
# parser = argparse.ArgumentParser(description='ViT分类模型 2023.11.19')
# parser.add_argument('--batch', default=256, help='')

# opt = parser.parse_args()
# print(opt)

EPOCHS = 100
BATCH = 128
NUM_WORKERS = 4
DROP_LAST = True
LR = 0.005
SAVE_DIR = '/Users/sunchengcheng/Projects/D2L/ViT/weights'

In [4]:
""" 构建数据集 """
process_data = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize([224, 224]),        # int 图片短边会缩放到int，长边相应缩放，不是我们想要的正方形
])
# CIFAR10 32x32 colour images in 10 classes
train_dataset = torchvision.datasets.CIFAR10(root='../data/', train=True, transform=process_data, download=True)
test_dataset = torchvision.datasets.CIFAR10(root='../data/', transform=process_data, download=True)
train_iter = DataLoader(dataset=train_dataset, batch_size=BATCH, shuffle=True, num_workers=NUM_WORKERS, drop_last=DROP_LAST)
test_iter = DataLoader(dataset=test_dataset, batch_size=BATCH, shuffle=False, num_workers=NUM_WORKERS, drop_last=DROP_LAST)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
""" 定义ViT_model """
# ViT_model = torchvision.models.vit_b_16(weights=torchvision.models.ViT_B_16_Weights.IMAGENET1K_V1)
# 测试vgg模型，后期注释
ViT_model = torchvision.models.vgg11(num_classes=len(train_dataset.classes))
weights_vgg = torch.load(f='weights/hub/checkpoints/vgg11-8a719046.pth')
del weights_vgg['classifier.6.weight']
del weights_vgg['classifier.6.bias']
ViT_model.load_state_dict(weights_vgg, strict=False)

# ViT_model.heads = nn.Linear(768, train_dataset.classes.__len__())
# """ 测试ViT输入输出 """
# with torch.no_grad():
#     ViT_model.eval()        # 影响BN、dropout层
#     x = torch.rand(size=(10, 3, 224, 224))
#     y = ViT_model(x)
#     print(y.softmax(dim=-1), torch.argmax(torch.softmax(y, dim=-1), dim=-1), sep='\n')

_IncompatibleKeys(missing_keys=['classifier.6.weight', 'classifier.6.bias'], unexpected_keys=[])

In [6]:
""" 定义loss、optim """
loss_func = nn.CrossEntropyLoss()
optimizer = optim.SGD(params=ViT_model.parameters(), lr=LR)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=train_iter.__len__())

In [7]:
def evaluate(model, test_iter, device):
    metric_f1 = torchmetrics.F1Score(task='multiclass', num_classes=len(test_iter.dataset.classes)).to(device)
    """ 使用验证集评估模型 """
    with torch.no_grad():
        model.eval()
        correct = 0
        total = 0
        for i, (x, y) in enumerate(test_iter):
            x, y = x.to(device), y.to(device)
            y_hat = model(x)
            correct += torch.sum(y_hat.softmax(dim=-1).argmax(dim=-1) == y)
            total += x.shape[0]
            acc = correct / total
            batch_f1 = metric_f1(y_hat.softmax(dim=-1).argmax(dim=-1), y)
            print(f'---------->evaluate {i}/{len(test_iter)} acc:{acc.item() * 100:.7f}% batch_f1 {batch_f1 * 100:.4f}% ')
        eval_f1 = metric_f1.compute()
        print(f'---------->evaluate eval_f1:{eval_f1 * 100:.4f}%')
        return acc.item(), eval_f1

In [8]:
""" 训练 """
max_f1 = 0      # 记录最大f1分数，保存模型
lr_iter = []
acc_iter = []
batch_acc_iter = []
batch_f1_iter = []
correct = 0
total = 0
iter = 0
device = torch.device('mps')
# 精度指标计算
metric_acc = torchmetrics.Accuracy(task='multiclass', num_classes=len(train_dataset.classes)).to(device)
metric_f1 = torchmetrics.F1Score(task='multiclass', num_classes=len(train_dataset.classes)).to(device)
metric_precision = torchmetrics.Precision
ViT_model.to(device=device)
for epoch in range(EPOCHS):
    ViT_model.train()               # 切换到train模式
    for i, (x, y) in enumerate(train_iter):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        y_hat = ViT_model(x)
        # 统计精度
        correct += torch.sum(y_hat.softmax(dim=-1).argmax(dim=-1) == y)
        total += x.shape[0]
        acc = correct / total
        acc_iter.append(acc.item())
        batch_acc = metric_acc(y_hat.softmax(dim=-1).argmax(dim=-1), y)     # 计算batch的精度；已验证metric_acc.compute()与自己写的全局acc同
        batch_f1 = metric_f1(y_hat.softmax(dim=-1).argmax(dim=-1), y)
        lr_iter.append(scheduler.get_last_lr()[0] * 200)        # 收集学习率放大200倍，方便显示
        batch_acc_iter.append(batch_acc.item())
        batch_f1_iter.append(batch_f1.item())
        iter += 1
        print(f'epoch {epoch}/{EPOCHS} iter {i}/{len(train_iter)} lr {lr_iter[-1]:.7f} \
              acc {acc * 100:.4f}% batch_acc {batch_acc * 100:.4f}% batch_f1 {batch_f1 * 100:.4f} ')
        # backward update
        loss = loss_func(y_hat, y)
        loss.backward()
        optimizer.step()
        scheduler.step()
        # draw acc
        pl.clf()
        pl.plot(acc_iter, label='acc')
        pl.plot(lr_iter, label='lr * 200')
        pl.plot(batch_acc_iter, label='batch_acc')
        pl.plot(batch_f1_iter, label='batch_f1')
        pl.legend(loc='upper right')                # 必须设置，否则pl.plot()的label参数显示不出来
        pl.xlabel(f'iters')
        display.display(pl.gcf())
        display.clear_output(True)
    # evalute model
    epoch_acc = metric_acc.compute()
    epoch_f1 = metric_f1.compute()
    metric_acc.reset()
    metric_f1.reset()
    print(f'epoch {epoch}/{EPOCHS} epoch performance: epoch_acc {epoch_acc * 100:.4f} epoch_f1 {epoch_f1 * 100:.4f}')
    acc, eval_f1 = evaluate(ViT_model, test_iter, device)
    if eval_f1 > max_f1:
        max_f1 = eval_f1
        print(f'*************** Find Better Model, Saving Model to {SAVE_DIR} *****************')
        torch.save(ViT_model.state_dict(), os.path.join(SAVE_DIR, f'best.pth'))

epoch 0/100 iter 26/390 lr 0.9890738               acc 23.9583% batch_acc 42.1875% batch_f1 42.1875 


In [None]:
ViT_model.to(torch.device('mps'))
evaluate(ViT_model, test_iter, torch.device('mps'))



---------->evaluate acc:4.6875000% 
---------->evaluate acc:5.4687500% 
---------->evaluate acc:6.2500000% 
---------->evaluate acc:6.6406250% 
---------->evaluate acc:7.1874999% 
---------->evaluate acc:8.3333336% 
---------->evaluate acc:9.1517858% 
---------->evaluate acc:9.7656250% 
---------->evaluate acc:9.3750000% 
---------->evaluate acc:10.0000001% 
---------->evaluate acc:10.2272727% 
---------->evaluate acc:10.5468750% 
---------->evaluate acc:10.3365384% 
---------->evaluate acc:9.9330358% 
---------->evaluate acc:10.0000001% 
---------->evaluate acc:9.8632812% 
---------->evaluate acc:9.5588237% 
---------->evaluate acc:9.6354164% 
---------->evaluate acc:9.8684214% 
---------->evaluate acc:10.0781247% 
---------->evaluate acc:10.1190478% 
---------->evaluate acc:9.9431820% 
---------->evaluate acc:10.0543477% 
---------->evaluate acc:10.0911461% 
---------->evaluate acc:10.0625001% 
---------->evaluate acc:9.9759616% 
---------->evaluate acc:10.1273149% 
---------->evalua

torch.Size([4096, 4096])

In [None]:
ViT_model

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (11): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): ReLU(inplace=True)
    (13): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (14): ReLU(inplace=True)
    (15): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
 