In [1]:
import time
from tqdm import tqdm
from transformers import ViTForImageClassification
import torch
import torch.optim as optim
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F

class CustomClassifier(nn.Module):
    """Three-layer MLP head that turns feature vectors into class logits.

    Layout: Linear -> ReLU -> Dropout -> Linear -> ReLU -> Dropout -> Linear.
    The attribute names (``fc1``, ``fc2``, ``fc3``) are part of the saved
    ``state_dict`` keys, so they must not be renamed.

    Args:
        input_dim: Size of the last dimension of the incoming features.
        num_classes: Number of logits produced per sample.
        hidden_dims: Widths of the two hidden layers, as a pair. Defaults to
            ``(512, 256)``.
        dropout: Drop probability applied after each hidden activation.
            Defaults to ``0.5``.

    Raises:
        ValueError: If ``hidden_dims`` does not contain exactly two widths.
    """

    def __init__(self, input_dim, num_classes, hidden_dims=(512, 256), dropout=0.5):
        super().__init__()
        if len(hidden_dims) != 2:
            raise ValueError(
                f"hidden_dims must contain exactly two widths, got {hidden_dims!r}"
            )
        first_hidden, second_hidden = hidden_dims
        self.fc1 = nn.Linear(input_dim, first_hidden)
        self.fc2 = nn.Linear(first_hidden, second_hidden)
        self.fc3 = nn.Linear(second_hidden, num_classes)
        # One stateless Dropout/ReLU instance is reused after both hidden layers.
        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return raw (un-normalised) logits for ``x``."""
        x = self.dropout(self.relu(self.fc1(x)))
        x = self.dropout(self.relu(self.fc2(x)))
        return self.fc3(x)

# Pick the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the pretrained ViT checkpoint; its stock classification head is
# replaced below.
model = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224',
    ignore_mismatched_sizes=True,
)

# Freeze the transformer backbone so that only the new head receives gradients.
model.vit.requires_grad_(False)

# Swap in the custom MLP head, sized from the stock head's input width.
# NOTE(review): 7 presumably matches the number of class folders under
# archive/train -- confirm against the dataset.
num_classes = 7
input_dim = model.classifier.in_features
model.classifier = CustomClassifier(input_dim, num_classes)

# A single move places both the backbone and the freshly created head on `device`.
model.to(device)

# Adam over the classifier head only; weight_decay adds L2 regularisation.
optimizer = optim.Adam(model.classifier.parameters(), lr=2e-4, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()

# ReduceLROnPlateau scheduler: multiply the learning rate by 0.2 once the
# monitored value (lower is better) has stopped improving for more than 3
# consecutive epochs.
# `verbose` is intentionally not passed: it has been deprecated since
# PyTorch 2.2 and newer releases no longer accept it. Read
# optimizer.param_groups[0]['lr'] when the current rate needs to be logged.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, patience=3)

# Training-time augmentation: random crop/flip/colour jitter, then tensor
# conversion and per-channel normalisation with mean 0.5 and std 0.5.
transform_train = transforms.Compose([
    transforms.RandomResizedCrop(size=224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])

# Deterministic evaluation pipeline: resize, centre-crop to 224, then apply
# the same normalisation as in training.
transform_test = transforms.Compose([
    transforms.Resize(size=256),
    transforms.CenterCrop(size=224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])

# ImageFolder derives class labels from the sub-directory names under each root.
train_dataset = datasets.ImageFolder('archive/train', transform=transform_train)
test_dataset = datasets.ImageFolder('archive/test', transform=transform_test)

# Only the training loader is shuffled; evaluation order is kept fixed.
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)

def train(model, dataloader, optimizer, criterion, device):
    """Run one optimisation pass over ``dataloader``.

    Args:
        model: Module whose forward output exposes a ``.logits`` attribute.
        dataloader: Iterable yielding ``(images, labels)`` batches.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss callable applied to ``(logits, labels)``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        The mean of the per-batch losses, i.e. their sum divided by
        ``len(dataloader)``.
    """
    model.train()
    batch_losses = []
    batches = tqdm(dataloader, desc='Training')
    for images, labels in batches:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        loss = criterion(model(images).logits, labels)
        loss.backward()
        optimizer.step()

        # Read the scalar once and reuse it for bookkeeping and the progress bar.
        batch_loss = loss.item()
        batch_losses.append(batch_loss)
        batches.set_postfix(loss=batch_loss)
    return sum(batch_losses) / len(dataloader)

def evaluate(model, dataloader, criterion, device):
    """Compute the mean batch loss and accuracy of ``model`` on ``dataloader``.

    Gradients are disabled and the model is put into eval mode for the pass.

    Args:
        model: Module whose forward output exposes a ``.logits`` attribute.
        dataloader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss callable applied to ``(logits, labels)``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``(mean_batch_loss, accuracy)``, where accuracy is the fraction of
        samples actually yielded by ``dataloader`` that were classified
        correctly.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    num_batches = 0
    num_correct = 0
    num_samples = 0
    progress_bar = tqdm(dataloader, desc='Evaluating')
    with torch.no_grad():
        for images, labels in progress_bar:
            images, labels = images.to(device), labels.to(device)
            logits = model(images).logits
            batch_loss = criterion(logits, labels).item()
            loss_sum += batch_loss
            num_batches += 1
            _, preds = torch.max(logits, 1)
            num_correct += (preds == labels).sum().item()
            # Count what was really evaluated: len(dataloader.dataset) is the
            # wrong denominator when the loader uses a sampler or drop_last.
            num_samples += labels.size(0)
            progress_bar.set_postfix(loss=batch_loss)
    if num_samples == 0:
        raise ValueError("evaluate() received no samples from the dataloader")
    return loss_sum / num_batches, num_correct / num_samples

num_epochs = 10
best_val_loss = float('inf')
for epoch in range(num_epochs):
    # Time one full epoch: the training pass plus the validation pass.
    start_time = time.time()
    train_loss = train(model, train_loader, optimizer, criterion, device)
    val_loss, val_accuracy = evaluate(model, test_loader, criterion, device)
    epoch_duration = time.time() - start_time

    print("\n".join([
        f"Epoch {epoch+1}/{num_epochs}",
        f"Time: {epoch_duration:.2f}s",
        f"Train Loss: {train_loss:.4f}",
        f"Validation Loss: {val_loss:.4f}",
        f"Validation Accuracy: {val_accuracy:.4f}",
        '-' * 30,
    ]))

    # The plateau scheduler is driven by the validation loss.
    scheduler.step(val_loss)

    # Checkpoint only when the validation loss reaches a new minimum.
    if not (val_loss < best_val_loss):
        continue
    best_val_loss = val_loss
    torch.save(model.state_dict(), 'best_vit_fer2013.pth')
    print(f"Model saved with validation loss: {val_loss:.4f}")

print(f"Best model saved with validation loss: {best_val_loss:.4f}")


  from .autonotebook import tqdm as notebook_tqdm
Training: 100%|██████████| 898/898 [05:30<00:00,  2.72it/s, loss=0.948]
Evaluating: 100%|██████████| 225/225 [00:47<00:00,  4.78it/s, loss=0.794]


Epoch 1/10
Time: 377.77s
Train Loss: 1.4960
Validation Loss: 1.1953
Validation Accuracy: 0.5483
------------------------------
Model saved with validation loss: 1.1953


Training: 100%|██████████| 898/898 [03:16<00:00,  4.57it/s, loss=1.91] 
Evaluating: 100%|██████████| 225/225 [00:37<00:00,  5.96it/s, loss=0.747]


Epoch 2/10
Time: 234.35s
Train Loss: 1.3858
Validation Loss: 1.1352
Validation Accuracy: 0.5698
------------------------------
Model saved with validation loss: 1.1352


Training: 100%|██████████| 898/898 [03:17<00:00,  4.54it/s, loss=1.72] 
Evaluating: 100%|██████████| 225/225 [00:37<00:00,  5.93it/s, loss=0.597]


Epoch 3/10
Time: 235.86s
Train Loss: 1.3584
Validation Loss: 1.1249
Validation Accuracy: 0.5644
------------------------------
Model saved with validation loss: 1.1249


Training:   9%|▉         | 85/898 [00:18<02:58,  4.55it/s, loss=1.39]


KeyboardInterrupt: 