In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import timm
from timm.models.vision_transformer import VisionTransformer

In [10]:
# Preprocessing shared by both CIFAR-10 splits:
#   1. resize every image to the 224x224 input size the ViT below is built for,
#   2. convert the image to a float tensor,
#   3. normalise each RGB channel with mean 0.5 and std 0.5, i.e. (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

In [11]:
# ViT-Small/32-sized classifier, constructed directly (no pretrained weights loaded):
#   - 224x224 inputs cut into 32x32 patches -> (224 // 32) ** 2 = 49 patch tokens
#   - 384-dim embeddings, 12 transformer blocks, 6 attention heads (384 / 6 = 64 dims each)
#   - 10 outputs, one per CIFAR-10 class
model = VisionTransformer(
    img_size=224,
    patch_size=32,
    num_classes=10,
    embed_dim=384,
    depth=12,
    num_heads=6,
)

In [12]:
# Load the CIFAR-10 training and test splits from the local copy.
# download=False, so the files must already exist under ./datasets/cifar10/.
train_dataset, test_dataset = (
    datasets.CIFAR10(root='./datasets/cifar10/', train=is_train, download=False, transform=transform)
    for is_train in (True, False)
)

# Only the training split is shuffled; evaluation order stays fixed.
train_loader, test_loader = (
    DataLoader(split, batch_size=256, shuffle=is_train, num_workers=8, pin_memory=True)
    for split, is_train in ((train_dataset, True), (test_dataset, False))
)

In [13]:
# Run on the current CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)

In [14]:
# Cross-entropy on the model's raw class scores, optimised with plain Adam at a
# fixed learning rate of 1e-4 (no weight decay, no learning-rate schedule).
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

In [15]:
def train_model(model, train_loader, criterion, optimizer, device, num_epochs=10, test_loader=None):
    """Train ``model`` for ``num_epochs`` epochs, evaluating after every epoch.

    Args:
        model: the module to train; it is kept in train mode for the updates.
        train_loader: iterable of ``(images, labels)`` batches whose ``len()``
            is the number of batches per epoch.
        criterion: loss function called as ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        device: device each batch is moved to before the forward pass.
        num_epochs: number of passes over ``train_loader``.
        test_loader: loader handed to ``test_model`` after each epoch. When
            omitted, falls back to a module-level ``test_loader`` (the original
            behaviour of this notebook); if none exists, evaluation is skipped.
    """
    # The parameter shadows the global name, so the legacy fallback must go
    # through globals() explicitly instead of relying on hidden notebook state.
    eval_loader = test_loader if test_loader is not None else globals().get("test_loader")
    num_batches = len(train_loader)

    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        for i, (images, labels) in enumerate(train_loader):
            images, labels = images.to(device), labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward pass and parameter update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            running_loss += batch_loss

            # Log the current batch loss every 100 steps.
            if (i + 1) % 100 == 0:
                print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{num_batches}], Loss: {batch_loss:.4f}')

        # Mean batch loss over the epoch; an empty loader reports nan instead of
        # raising ZeroDivisionError.
        epoch_loss = running_loss / num_batches if num_batches else float("nan")
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")

        if eval_loader is not None:
            test_model(model, eval_loader, device)
            # test_model switches the model to eval mode; switch back for training.
            model.train()

def test_model(model, test_loader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f'Test Accuracy: {accuracy:.2f}%')

In [16]:
import time

# perf_counter is monotonic, so the measured duration is unaffected by system
# clock adjustments (unlike time.time()).
start_time = time.perf_counter()
try:
    train_model(model, train_loader, criterion, optimizer, device, num_epochs=50)
finally:
    # Report the elapsed time even when training is interrupted (e.g. KeyboardInterrupt);
    # the exception still propagates after printing.
    end_time = time.perf_counter()
    print(f"Training time for ViT32: {end_time - start_time:.2f} seconds")

Epoch [1/50], Step [100/196], Loss: 1.8989
Epoch [1/50], Loss: 1.8007
Test Accuracy: 38.68%
Epoch [2/50], Step [100/196], Loss: 1.5655
Epoch [2/50], Loss: 1.5458
Test Accuracy: 45.99%
Epoch [3/50], Step [100/196], Loss: 1.4332
Epoch [3/50], Loss: 1.3836
Test Accuracy: 49.80%
Epoch [4/50], Step [100/196], Loss: 1.2335
Epoch [4/50], Loss: 1.2740
Test Accuracy: 51.18%
Epoch [5/50], Step [100/196], Loss: 1.0467
Epoch [5/50], Loss: 1.1717
Test Accuracy: 52.58%
Epoch [6/50], Step [100/196], Loss: 1.0237
Epoch [6/50], Loss: 1.0882
Test Accuracy: 57.42%
Epoch [7/50], Step [100/196], Loss: 0.8496
Epoch [7/50], Loss: 1.0033
Test Accuracy: 60.06%
Epoch [8/50], Step [100/196], Loss: 0.8872
Epoch [8/50], Loss: 0.9301
Test Accuracy: 62.80%
Epoch [9/50], Step [100/196], Loss: 0.7402
Epoch [9/50], Loss: 0.8521
Test Accuracy: 62.78%
Epoch [10/50], Step [100/196], Loss: 0.6923
Epoch [10/50], Loss: 0.7905
Test Accuracy: 63.95%
Epoch [11/50], Step [100/196], Loss: 0.7881
Epoch [11/50], Loss: 0.7259
Test A

KeyboardInterrupt: 