In [1]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Load CIFAR-10: convert images to tensors, then normalise every RGB channel
# with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# Both splits share the root directory, download flag and preprocessing;
# only the `train` flag differs (training split first, then the test split).
train_dataset, test_dataset = (
    datasets.CIFAR10(root='./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# Only the training batches are shuffled; evaluation order stays fixed.
train_loader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=64,
    shuffle=False,
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data\cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [03:04<00:00, 925747.72it/s] 


Extracting ./data\cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [2]:
from tqdm import tqdm  
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: cpu


In [3]:

# Simple baseline model: flatten each 32*32*3 input and feed it through a
# two-layer perceptron that produces 10 output scores.
mlp_layers = [
    torch.nn.Flatten(),
    torch.nn.Linear(32 * 32 * 3, 256),
    torch.nn.ReLU(),
    torch.nn.Linear(256, 10),
]
model = torch.nn.Sequential(*mlp_layers)
model = model.to(device)

# Optimizer and loss function.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
criterion = torch.nn.CrossEntropyLoss()

def train(model, loader, optimizer=None, criterion=None, device=None):
    """Run one training epoch over ``loader`` and print the mean batch loss.

    Args:
        model: Module to train; it is switched to training mode first.
        loader: Iterable yielding ``(images, labels)`` batches. It does not
            need to support ``len()``; batches are counted while iterating.
        optimizer: Optimizer to step. When omitted, the notebook-level
            ``optimizer`` variable is looked up at call time (the original
            behaviour).
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
            When omitted, the notebook-level ``criterion`` is used.
        device: Device the batches are moved to. When omitted, the
            notebook-level ``device`` is used.

    Returns:
        None. The average loss is printed, or a notice if ``loader`` yielded
        no batches (previously a ``ZeroDivisionError``).

    Note:
        Later cells rebind ``optimizer``/``criterion``/``device``. Pass them
        explicitly to be sure the optimizer belongs to ``model``.
    """
    # Fall back to the notebook globals only for arguments the caller omitted.
    notebook_ns = globals()
    if optimizer is None:
        optimizer = notebook_ns["optimizer"]
    if criterion is None:
        criterion = notebook_ns["criterion"]
    if device is None:
        device = notebook_ns["device"]

    model.train()
    running_loss = 0.0  # sum of per-batch losses
    num_batches = 0  # counted here so unsized iterables also work
    for images, labels in tqdm(loader, desc="Training", leave=True):  # tqdm wraps the loader
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        # Nothing was trained on; avoid dividing by zero.
        print("Average Loss: n/a (loader yielded no batches)")
        return
    print(f"Average Loss: {running_loss / num_batches:.4f}")

# Run locally: ten passes over the training data, announcing the start and
# the end of every epoch.
for epoch in range(0, 10):
    start_message = f"Epoch {epoch+1} / 10"
    done_message = f"Epoch {epoch+1} complete."

    print(start_message)
    train(model, train_loader)
    print(done_message)


Epoch 1 / 10


Training: 100%|██████████| 782/782 [00:17<00:00, 45.23it/s]


Average Loss: 1.6448
Epoch 1 complete.
Epoch 2 / 10


Training:  48%|████▊     | 377/782 [00:08<00:09, 44.47it/s]


KeyboardInterrupt: 

In [5]:
def pair(t):
    """Return ``t`` unchanged if it is a tuple, otherwise the 2-tuple ``(t, t)``."""
    if isinstance(t, tuple):
        return t
    return (t, t)

pair(4)

(4, 4)

In [6]:
from vit import ViT
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Vision Transformer hyper-parameters, gathered in one place.
vit_config = dict(
    image_size=32,
    patch_size=4,  # split into 4x4 patches: an 8x8 grid, 64 patches in total
    num_classes=10,
    dim=256,  # embedding dimension
    depth=6,  # number of Transformer layers
    heads=8,  # number of heads in multi-head attention
    mlp_dim=512,  # hidden dimension of the MLP
    pool='cls',  # use the CLS token
    channels=3,  # number of input channels
    dim_head=64,
    dropout=0.1,
    emb_dropout=0.1,
)
model = ViT(**vit_config)
model = model.to(device)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=3e-4)


Using device: cpu


In [8]:
def train(model, train_loader, optimizer, criterion, device, num_epochs=10):
    """Train ``model`` for ``num_epochs`` epochs, reporting loss and accuracy.

    Args:
        model: Module to optimise; switched to training mode once up front.
        train_loader: Iterable yielding ``(images, labels)`` batches. It does
            not need to support ``len()``; batches are counted per epoch.
        optimizer: Optimizer whose ``zero_grad``/``step`` drive the updates.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        device: Device every batch is moved to before the forward pass.
        num_epochs: Number of passes over ``train_loader``.

    Returns:
        None. One summary line is printed per epoch: the mean of the
        per-batch losses and the running accuracy in percent.

    An epoch in which ``train_loader`` yields no samples prints a notice
    instead of raising ``ZeroDivisionError``.
    """
    model.train()
    for epoch in range(num_epochs):
        total_loss = 0.0
        correct = 0
        total = 0
        num_batches = 0  # counted here so unsized iterables also work
        loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
        for images, labels in loop:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Statistics: read the scalar loss once and reuse it below.
            batch_loss = loss.item()
            total_loss += batch_loss
            num_batches += 1
            _, predicted = outputs.max(1)
            correct += predicted.eq(labels).sum().item()
            total += labels.size(0)

            # Update the progress bar; max() guards a zero-sized first batch.
            loop.set_postfix(loss=batch_loss, accuracy=100. * correct / max(total, 1))

        if num_batches == 0 or total == 0:
            # Nothing to average over; report it rather than divide by zero.
            print(f"Epoch {epoch+1}/{num_epochs}: no samples produced by train_loader")
            continue
        print(f"Epoch {epoch+1}/{num_epochs}: Loss: {total_loss/num_batches:.4f}, Accuracy: {100. * correct / total:.2f}%")

# **测试函数**
def test(model, test_loader, criterion, device):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        loop = tqdm(test_loader, desc="Testing")
        for images, labels in loop:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            # 统计
            total_loss += loss.item()
            _, predicted = outputs.max(1)
            correct += predicted.eq(labels).sum().item()
            total += labels.size(0)

    print(f"Test Loss: {total_loss/len(test_loader):.4f}, Accuracy: {100. * correct / total:.2f}%")
    return total_loss / len(test_loader), 100. * correct / total



In [9]:
# **Main program**
# Step 1: train the model.
train(
    model,
    train_loader,
    optimizer,
    criterion,
    device,
    num_epochs=10,
)

Epoch 1/10:   3%|▎         | 25/782 [00:27<13:50,  1.10s/it, accuracy=19.5, loss=2.11]


KeyboardInterrupt: 

In [None]:

# Step 2: evaluate the model on the held-out split and report the result.
test_loss, test_acc = test(
    model,
    test_loader,
    criterion,
    device,
)
print(
    f"Final Test Loss: {test_loss:.4f}, Final Test Accuracy: {test_acc:.2f}%"
)