In [1]:
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import default_collate, DataLoader
from torchvision.transforms import v2
from torchvision.datasets import CIFAR100
import torchvision.models as models
from torch.utils.tensorboard import SummaryWriter
import timm
import numpy as np


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
train_preproc = v2.Compose([
    v2.PILToTensor(),
    v2.RandomResizedCrop(size=(32, 32), antialias=True),
    v2.RandomHorizontalFlip(p=0.5),
    v2.ToDtype(torch.float32, scale=True),  # to float32 in [0, 1]
    v2.Normalize(mean=[0.5071, 0.4865, 0.4409], std=[0.2673, 0.2564, 0.2762]),
])

test_preproc = v2.Compose([
    v2.PILToTensor(),
    v2.ToDtype(torch.float32, scale=True),  # to float32 in [0, 1]
    v2.Normalize(mean=[0.5071, 0.4865, 0.4409], std=[0.2673, 0.2564, 0.2762]),
])

# CutMix in torchvision
cutmix = v2.CutMix(num_classes=100)
def collate_fn(batch):
    return cutmix(*default_collate(batch))

# 定义 Batch_size 大小
batch_size = 512

# 数据集加载
trainset = CIFAR100(root='./data', train=True, download=True, transform=train_preproc)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4, persistent_workers=True, collate_fn=collate_fn)
testset = CIFAR100(root='./data', train=False, download=True, transform=test_preproc)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4, persistent_workers=True)
# CutMix对比组
trainset_c1 = CIFAR100(root='./data', train=True, download=True, transform=train_preproc)
trainloader_c1 = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4, persistent_workers=True)
testset_c1 = CIFAR100(root='./data', train=False, download=True, transform=test_preproc)
testloader_c1 = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4, persistent_workers=True)
# 其它数据增强对比组
trainset_c2 = CIFAR100(root='./data', train=True, download=True, transform=test_preproc)
trainloader_c2 = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4, persistent_workers=True)
testset_c2 = CIFAR100(root='./data', train=False, download=True, transform=test_preproc)
testloader_c2 = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4, persistent_workers=True)


Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [3]:
# 自定义 CutMix
def rand_bbox(size, lam):
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2

def cutmix_data(x, y, alpha=1.0):
    lam = np.random.beta(alpha, alpha)
    rand_index = torch.randperm(x.size()[0])
    target_a = y
    target_b = y[rand_index]
    bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), lam)
    x[:, :, bbx1:bbx2, bby1:bby2] = x[rand_index, :, bbx1:bbx2, bby1:bby2]
    return x, target_a, target_b, lam


In [4]:
# ResNet18
class ResNet18_CIFAR100(torch.nn.Module):
    def __init__(self):
        super(ResNet18_CIFAR100, self).__init__()
        self.model = models.resnet18(pretrained=False, num_classes=100)

    def forward(self, x):
        return self.model(x)

# Vision Transformer
class ViT_CIFAR100(torch.nn.Module):
    def __init__(self):
        super(ViT_CIFAR100, self).__init__()
        self.model = timm.create_model(
            'vit_base_patch16_224',
            pretrained=False,
            img_size=32,
            patch_size=4,
            embed_dim=256,
            depth=12,
            num_heads=4,
            mlp_ratio=4.0,  # mlp hidden size = 256 * 4 = 1024
            num_classes=100
        )

    def forward(self, x):
        return self.model(x)

In [5]:
import time
# 定义训练函数
def train_model(model, trainloader, testloader, epochs=100, lr=0.1, wd=1e-4, log_dir='./logs'):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    writer = SummaryWriter(log_dir=log_dir)

    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    
    # 优化器设置
    # optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=wd)
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
    # 学习率调度器
    # scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=80, gamma=0.1)
    # WarmUp的epoch数量
    warmup_epochs = 5  

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        start = time.time()

        # WarmUp 策略
        if epoch < warmup_epochs:
            lr_scale = min(1., float(epoch + 1) / warmup_epochs)
            for pg in optimizer.param_groups:
                pg['lr'] = lr_scale * lr

        for inputs, targets in trainloader:
            inputs, targets = inputs.to(device), targets.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, targets)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()  
        
        train_loss = running_loss / len(trainloader)
        writer.add_scalar('Loss/train', train_loss, epoch)
        print(f'Epoch [{epoch + 1}/{epochs}], Training Loss: {train_loss}')
        # scheduler.step()

        # 验证模型
        model.eval()
        running_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, targets in testloader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                running_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                total += targets.size(0)
                correct += (predicted == targets).sum().item()
        test_loss = running_loss / len(testloader)
        test_accuracy = 100 * correct / total
        writer.add_scalar('Loss/test', test_loss, epoch)
        writer.add_scalar('Accuracy/test', test_accuracy, epoch)
        print(f'Validation Loss: {test_loss}, Validation Accuracy: {test_accuracy}%')
        end = time.time()
        t = end - start
        et = (epochs - epoch - 1) * t
        hours = int(et // 3600)
        minutes = int((et % 3600) // 60)
        print(f"Time: {t:.2f} seconds, Expected time: {hours} hours and {minutes} minutes.")


    writer.close()


In [6]:
# 设置参数
Max_epoch = 200
base_lr_list = [0.00001]
weight_decay_list = [0]

for base_lr in base_lr_list:
    for weight_decay in weight_decay_list:
        # 训练ResNet-18模型
        torch.cuda.empty_cache()
        print("="*50)
        print(f"Training ResNet-18: Adam_bt-{batch_size}_lr{base_lr}_wd-{weight_decay}")
        print("="*50)
        resnet18_model = ResNet18_CIFAR100()
        train_model(resnet18_model, trainloader, testloader, Max_epoch, lr=base_lr, wd=weight_decay, log_dir=f'./logs/resnet18/Adam_bt-{batch_size}_lr{base_lr}_wd-{weight_decay}_e200')

        # 训练ViT模型
        torch.cuda.empty_cache()
        print("="*50)
        print(f"Training ViT: Adam_bt-{batch_size}_lr{base_lr}_wd-{weight_decay}")
        print("="*50)
        vit_model = ViT_CIFAR100()
        train_model(vit_model, trainloader, testloader, Max_epoch, lr=base_lr, wd=weight_decay, log_dir=f'./logs/vit/Adam_bt-{batch_size}_lr{base_lr}_wd-{weight_decay}_e200')


Training ViT: Adam_bt-512_lr1e-05_wd-0
Epoch [1/200], Training Loss: 4.56467791479461
Validation Loss: 4.4666100025177, Validation Accuracy: 3.53%
Time: 25.44 seconds, Expected time: 1 hours and 24 minutes.
Epoch [2/200], Training Loss: 4.456883668899536
Validation Loss: 4.3368000984191895, Validation Accuracy: 5.41%
Time: 25.06 seconds, Expected time: 1 hours and 22 minutes.
Epoch [3/200], Training Loss: 4.386797014547854
Validation Loss: 4.2487212181091305, Validation Accuracy: 7.04%
Time: 25.16 seconds, Expected time: 1 hours and 22 minutes.
Epoch [4/200], Training Loss: 4.349663097031263
Validation Loss: 4.184719395637512, Validation Accuracy: 7.62%
Time: 25.18 seconds, Expected time: 1 hours and 22 minutes.
Epoch [5/200], Training Loss: 4.303785197588862
Validation Loss: 4.117759728431702, Validation Accuracy: 8.79%
Time: 25.22 seconds, Expected time: 1 hours and 21 minutes.
Epoch [6/200], Training Loss: 4.248463903154645
Validation Loss: 4.049191689491272, Validation Accuracy: 10

In [7]:
# # 设置参数
# Max_epoch = 200
# base_lr_list = [0.0005]
# weight_decay_list = [0]

# for base_lr in base_lr_list:
#     for weight_decay in weight_decay_list:
#         # 训练ResNet-18模型
#         torch.cuda.empty_cache()
#         print("="*50)
#         print(f"Training ResNet-18: Adam_bt-{batch_size}_lr{base_lr}_wd-{weight_decay}")
#         print("="*50)
#         resnet18_model = ResNet18_CIFAR100()
#         train_model(resnet18_model, trainloader_c1, testloader_c1, Max_epoch, lr=base_lr, wd=weight_decay, log_dir=f'./logs/resnet18/Adam_bt-{batch_size}_lr{base_lr}_wd-{weight_decay}_e200_noCutMix')

#         # 训练ViT模型
#         torch.cuda.empty_cache()
#         print("="*50)
#         print(f"Training ViT: Adam_bt-{batch_size}_lr{base_lr}_wd-{weight_decay}")
#         print("="*50)
#         vit_model = ViT_CIFAR100()
#         train_model(vit_model, trainloader_c1, testloader_c1, Max_epoch, lr=base_lr, wd=weight_decay, log_dir=f'./logs/vit/Adam_bt-{batch_size}_lr{base_lr}_wd-{weight_decay}_e200_noCutMix')

# for base_lr in base_lr_list:
#     for weight_decay in weight_decay_list:
#         # 训练ResNet-18模型
#         torch.cuda.empty_cache()
#         print("="*50)
#         print(f"Training ResNet-18: Adam_bt-{batch_size}_lr{base_lr}_wd-{weight_decay}")
#         print("="*50)
#         resnet18_model = ResNet18_CIFAR100()
#         train_model(resnet18_model, trainloader_c2, testloader_c2, Max_epoch, lr=base_lr, wd=weight_decay, log_dir=f'./logs/resnet18/Adam_bt-{batch_size}_lr{base_lr}_wd-{weight_decay}_e200_noEnhance')


#         # 训练ViT模型
#         torch.cuda.empty_cache()
#         print("="*50)
#         print(f"Training ViT: Adam_bt-{batch_size}_lr{base_lr}_wd-{weight_decay}")
#         print("="*50)
#         vit_model = ViT_CIFAR100()
#         train_model(vit_model, trainloader_c2, testloader_c2, Max_epoch, lr=base_lr, wd=weight_decay, log_dir=f'./logs/vit/Adam_bt-{batch_size}_lr{base_lr}_wd-{weight_decay}_e200_noEnhance')
