In [1]:
import time

import timm
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from timm.models.vision_transformer import VisionTransformer
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

In [2]:
# Architecture to train; must be one of the names handled by get_model:
# 'vit', 'resnet', 'vit4', 'vit2' or 'vit4tiny'.
model_type = 'vit'

def get_model(model_type):
    """Build a 10-class classifier and the input transform that goes with it.

    Supported ``model_type`` values:

    * ``'vit'``      -- timm ``vit_base_patch16_224`` with pretrained weights.
    * ``'resnet'``   -- timm ``resnet18`` with pretrained weights.
    * ``'vit4'``     -- ``VisionTransformer`` on 32x32 inputs, patch 4,
      embed_dim 192, depth 12, 6 heads (randomly initialised).
    * ``'vit2'``     -- ``VisionTransformer`` on 32x32 inputs, patch 2,
      embed_dim 96, depth 12, 6 heads (randomly initialised).
    * ``'vit4tiny'`` -- ``VisionTransformer`` on 32x32 inputs, patch 4,
      embed_dim 192, depth 6, 3 heads (randomly initialised).

    The two pretrained models get a transform that resizes images to
    224x224; the 32x32 models get the same transform without the resize.
    Every transform converts to a tensor and normalises each channel with
    mean 0.5 and std 0.5.

    Args:
        model_type: Name of the architecture to build (see the list above).

    Returns:
        A ``(model, transform)`` tuple.

    Raises:
        ValueError: If ``model_type`` is not one of the supported names.
    """
    # timm model names for the architectures loaded with pretrained weights.
    pretrained_names = {
        'vit': 'vit_base_patch16_224',
        'resnet': 'resnet18',
    }
    # Hyper-parameters of the small ViTs trained from scratch on 32x32 inputs.
    scratch_vit_configs = {
        'vit4': dict(patch_size=4, embed_dim=192, depth=12, num_heads=6),
        'vit2': dict(patch_size=2, embed_dim=96, depth=12, num_heads=6),
        'vit4tiny': dict(patch_size=4, embed_dim=192, depth=6, num_heads=3),
    }

    # Validate before building anything: without this check an unknown name
    # would only fail at the return statement with an UnboundLocalError.
    if model_type not in pretrained_names and model_type not in scratch_vit_configs:
        supported = sorted(list(pretrained_names) + list(scratch_vit_configs))
        raise ValueError(
            f"Unknown model_type {model_type!r}; expected one of {supported}"
        )

    uses_pretrained = model_type in pretrained_names

    # The pretrained models need 224x224 inputs, so only they get the upscaling step.
    steps = [transforms.Resize((224, 224))] if uses_pretrained else []
    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
    transform = transforms.Compose(steps)

    if uses_pretrained:
        model = timm.create_model(
            pretrained_names[model_type], pretrained=True, num_classes=10
        )
    else:
        model = VisionTransformer(
            img_size=32, num_classes=10, **scratch_vit_configs[model_type]
        )

    return (model, transform)

# Experiment settings, gathered in one place so they are easy to tune.
DATA_ROOT = '~/datasets/cifar10'
BATCH_SIZE = 50
NUM_WORKERS = 20
LEARNING_RATE = 0.0001

model, transform = get_model(model_type)

# Load the CIFAR-10 training and test splits. download=False means the data
# must already be present under DATA_ROOT.
train_dataset = datasets.CIFAR10(root=DATA_ROOT, train=True, download=False, transform=transform)
test_dataset = datasets.CIFAR10(root=DATA_ROOT, train=False, download=False, transform=transform)
# Only the training batches are shuffled; evaluation order is kept fixed.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

# Move the model to the GPU when one is available, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Give every architecture its own TensorBoard run directory so that runs of
# different model types do not end up in the same log directory
# (for the default 'vit' this is still ./runs/vit_experiment).
writer = SummaryWriter(f'./runs/{model_type}_experiment')

def train_model(model, train_loader, criterion, optimizer, device, num_epochs=10):
    """Train ``model`` and evaluate it on the test set after every epoch.

    Besides its arguments, this function uses three module-level names:
    ``writer`` (TensorBoard writer), ``test_loader`` and ``test_model``.

    Args:
        model: Network to optimise; it is switched to train mode at the
            start of every epoch.
        train_loader: Iterable of ``(images, labels)`` batches; must
            support ``len()``.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer that updates the parameters of ``model``.
        device: Device the batches are moved to before the forward pass.
        num_epochs: Number of passes over ``train_loader``.

    Side effects:
        Prints the mean training loss per epoch, logs 'training loss'
        (every 100th batch) and 'test accuracy' (every epoch) to ``writer``,
        and closes ``writer`` when training ends.
    """
    try:
        batches_per_epoch = len(train_loader)
        for epoch in range(num_epochs):
            model.train()
            running_loss = 0.0
            for i, (images, labels) in enumerate(train_loader):
                images, labels = images.to(device), labels.to(device)

                # Forward pass
                outputs = model(images)
                loss = criterion(outputs, labels)

                # Backward pass and parameter update
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Read the scalar once and reuse it for both the running
                # total and the TensorBoard log.
                batch_loss = loss.item()
                running_loss += batch_loss

                if (i + 1) % 100 == 0:
                    writer.add_scalar('training loss', batch_loss, epoch * batches_per_epoch + i)

            print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/batches_per_epoch:.4f}")
            acc = test_model(model, test_loader, device)
            writer.add_scalar('test accuracy', acc, epoch)
    finally:
        # Close the writer even if training fails or is interrupted from the
        # notebook, so the events logged so far are written out.
        writer.close()

def test_model(model, test_loader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f'Test Accuracy: {accuracy:.2f}%')
    return accuracy

In [None]:
# Time the full training run. perf_counter() is a monotonic clock, so the
# measured interval is not affected by system clock adjustments.
start_time = time.perf_counter()
train_model(model, train_loader, criterion, optimizer, device, num_epochs=10)
end_time = time.perf_counter()
print(f"Training time for {model_type}: {end_time - start_time:.2f} seconds")

In [None]:
# Final evaluation of the trained model on the test loader; test_model prints
# the accuracy and returns it as a percentage.
test_model(model, test_loader, device)