# ResNet、Transformer (ViT) 与 CutMix 训练测试

In [1]:
import torch
from torch.utils.tensorboard import SummaryWriter
from data_preparation import get_data_loaders
from training import train, tune_hyperparameters
from model import get_resnet, get_vit  # 导入模型构造函数

## 设置设备与参数

In [2]:
# Use the CUDA device when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Settings shared by every experiment below.
batch_size = 128  # samples per mini-batch
epochs = 100      # training epochs per run
alpha = 1.0       # CutMix alpha

## 加载数据

In [3]:
# Build the training and test data loaders for the configured batch size.
(train_loader, test_loader) = get_data_loaders(batch_size)

Files already downloaded and verified
Files already downloaded and verified


## 超参数搜索范围
学习率与权重衰减（正则化强度）的候选值

In [None]:
# Candidate values searched by tune_hyperparameters: learning rates and
# weight decay (regularisation strength). Each list currently holds one value.
learning_rates = [0.002]
weight_decays = [0]

## ResNet 测试

In [8]:
# Build a 100-class ResNet-18, move it to the selected device and report
# its total number of parameters.
resnet_model = get_resnet(num_classes=100, variant='resnet18')
resnet_model = resnet_model.to(device)
resnet_params = sum(tensor.numel() for tensor in resnet_model.parameters())
print(f'ResNet-18 model parameters: {resnet_params}')

ResNet-18 model parameters: 11227812




In [None]:
# Tune learning rate / weight decay for ResNet-18 over the candidate lists.
# A zero-argument factory is passed instead of a model instance, presumably so
# each configuration starts from a fresh model — verify in tune_hyperparameters.
print("Tuning ResNet hyperparameters...")
best_resnet_params = tune_hyperparameters(
    lambda: get_resnet(num_classes=100, variant='resnet18'),  # model factory
    train_loader,
    test_loader,
    epochs,
    learning_rates,
    weight_decays,
    device,
    alpha,
)

In [9]:
# Manual override of the tuning result: (learning rate, weight decay) used for
# the final ResNet training run.
best_resnet_params = (0.001, 0)

In [10]:
# Train ResNet-18 with the chosen (learning rate, weight decay), log to
# TensorBoard under runs/ResNet_best, then save the trained weights.
# NOTE: this trains the existing `resnet_model` in place (it is not
# re-initialised here), so re-running the cell continues from its current
# weights rather than starting from scratch.
best_lr_resnet, best_wd_resnet = best_resnet_params
writer_resnet = SummaryWriter(log_dir='runs/ResNet_best')
optimizer_resnet = torch.optim.Adam(resnet_model.parameters(), lr=best_lr_resnet, weight_decay=best_wd_resnet)
criterion = torch.nn.CrossEntropyLoss()
try:
    train(resnet_model, train_loader, test_loader, optimizer_resnet, criterion, epochs, device, alpha, writer_resnet)
finally:
    # Close (and thereby flush) the writer even if training raises or is
    # interrupted, so the events logged so far are not lost.
    writer_resnet.close()

# Save only the model's state dict (parameters and buffers).
torch.save(resnet_model.state_dict(), 'resnet_model_state_dict.pth')

Epoch [1/100] Train Loss: 4.3509, Accuracy: 5.64%, Validation Loss: 3.8010, Accuracy: 12.91%
Epoch [2/100] Train Loss: 4.1365, Accuracy: 9.29%, Validation Loss: 3.4126, Accuracy: 18.89%
Epoch [3/100] Train Loss: 4.0385, Accuracy: 11.44%, Validation Loss: 3.4367, Accuracy: 19.79%
Epoch [4/100] Train Loss: 3.9381, Accuracy: 13.24%, Validation Loss: 3.0841, Accuracy: 24.39%
Epoch [5/100] Train Loss: 3.8682, Accuracy: 14.66%, Validation Loss: 3.1134, Accuracy: 25.56%
Epoch [6/100] Train Loss: 3.8059, Accuracy: 16.55%, Validation Loss: 2.8709, Accuracy: 29.27%
Epoch [7/100] Train Loss: 3.7284, Accuracy: 17.44%, Validation Loss: 2.8152, Accuracy: 30.75%
Epoch [8/100] Train Loss: 3.7023, Accuracy: 18.69%, Validation Loss: 2.7804, Accuracy: 32.13%
Epoch [9/100] Train Loss: 3.6701, Accuracy: 18.50%, Validation Loss: 2.5921, Accuracy: 34.87%
Epoch [10/100] Train Loss: 3.6399, Accuracy: 19.34%, Validation Loss: 2.5985, Accuracy: 36.19%
Epoch [11/100] Train Loss: 3.5569, Accuracy: 20.76%, Validati

## Transformer(ViT) 测试

In [4]:
# Build the ViT with the configuration below, move it to the selected device
# and report its total number of parameters.
vit_model = get_vit(
    image_size=32,
    patch_size=4,
    num_classes=100,
    dim=192,
    depth=25,
    heads=6,
    mlp_dim=384,
    dropout=0.1,
    emb_dropout=0.1,
)
vit_model = vit_model.to(device)
vit_params = sum(tensor.numel() for tensor in vit_model.parameters())
print(f'ViT model parameters: {vit_params}')

ViT model parameters: 11139844


In [None]:
# Tune learning rate / weight decay for the ViT over the candidate lists,
# passing a zero-argument factory that builds the same configuration as above.
print("Tuning ViT hyperparameters...")
best_vit_params = tune_hyperparameters(
    lambda: get_vit(
        image_size=32,
        patch_size=4,
        num_classes=100,
        dim=192,
        depth=25,
        heads=6,
        mlp_dim=384,
        dropout=0.1,
        emb_dropout=0.1,
    ),
    train_loader,
    test_loader,
    epochs,
    learning_rates,
    weight_decays,
    device,
    alpha,
)

In [5]:
# Manual override of the tuning result: (learning rate, weight decay) used for
# the final ViT training run.
best_vit_params = (0.0004, 0.00001)

In [7]:
# Train the ViT with the chosen (learning rate, weight decay), log to
# TensorBoard under runs/ViT_best, then save the trained weights.
# NOTE: this trains the existing `vit_model` in place (it is not
# re-initialised here), so re-running the cell continues from its current
# weights rather than starting from scratch.
best_lr_vit, best_wd_vit = best_vit_params
writer_vit = SummaryWriter(log_dir='runs/ViT_best')
optimizer_vit = torch.optim.Adam(vit_model.parameters(), lr=best_lr_vit, weight_decay=best_wd_vit)
criterion = torch.nn.CrossEntropyLoss()
try:
    train(vit_model, train_loader, test_loader, optimizer_vit, criterion, epochs, device, alpha, writer_vit)
finally:
    # Close (and thereby flush) the writer even if training raises or is
    # interrupted, so the events logged so far are not lost.
    writer_vit.close()

# Save only the model's state dict (parameters and buffers).
torch.save(vit_model.state_dict(), 'vit_model_state_dict.pth')

Epoch [1/100] Train Loss: 4.4377, Accuracy: 3.73%, Validation Loss: 4.0193, Accuracy: 7.84%
Epoch [2/100] Train Loss: 4.2628, Accuracy: 6.13%, Validation Loss: 3.7911, Accuracy: 11.80%
Epoch [3/100] Train Loss: 4.1497, Accuracy: 8.59%, Validation Loss: 3.6105, Accuracy: 15.22%
Epoch [4/100] Train Loss: 4.0782, Accuracy: 10.19%, Validation Loss: 3.4194, Accuracy: 18.21%
Epoch [5/100] Train Loss: 4.0060, Accuracy: 10.96%, Validation Loss: 3.3228, Accuracy: 20.36%
Epoch [6/100] Train Loss: 3.9524, Accuracy: 12.25%, Validation Loss: 3.2378, Accuracy: 22.00%
Epoch [7/100] Train Loss: 3.8981, Accuracy: 13.86%, Validation Loss: 3.0769, Accuracy: 26.08%
Epoch [8/100] Train Loss: 3.8440, Accuracy: 14.58%, Validation Loss: 3.0501, Accuracy: 25.73%
Epoch [9/100] Train Loss: 3.8471, Accuracy: 15.36%, Validation Loss: 2.9437, Accuracy: 28.61%
Epoch [10/100] Train Loss: 3.8070, Accuracy: 16.07%, Validation Loss: 2.9150, Accuracy: 28.59%
Epoch [11/100] Train Loss: 3.7515, Accuracy: 16.52%, Validation