In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet18
import numpy as np
import matplotlib.pyplot as plt
import random
from torch.utils.tensorboard import SummaryWriter

In [None]:
# Data augmentation: CutMix
def cutmix_data(images, labels, alpha, num_classes=None):
    """Apply CutMix: paste a random box from a shuffled copy of the batch into each image.

    The mixing ratio ``lam ~ Beta(alpha, alpha)`` sets the box size. The box
    centre is sampled uniformly over the image and the box is clipped to the
    image borders. The label weight is then recomputed from the *clipped* box
    area, so it matches the pixels that were actually pasted.

    Args:
        images: tensor indexed as ``images[:, :, h, w]`` (presumably
            N x C x H x W -- TODO confirm). The caller's tensor is not modified.
        labels: per-sample labels aligned with ``images`` along dim 0.
            Mixing is only meaningful for float (one-hot / soft) labels;
            pass ``num_classes`` to one-hot encode integer class indices.
        alpha: Beta-distribution concentration parameter (must be > 0).
        num_classes: optional. When given and ``labels`` is an integer
            tensor, labels are converted to float one-hot vectors of this
            width before mixing. The default ``None`` mixes labels as passed.

    Returns:
        ``(mixed_images, mixed_labels)``. Each label equals
        ``lam * own_label + (1 - lam) * partner_label``, where ``lam`` is the
        fraction of the image still taken from the original sample.
    """
    if num_classes is not None and not torch.is_floating_point(labels):
        labels = torch.nn.functional.one_hot(labels.long(), num_classes).float()

    # Pair every sample with a random partner from the same batch.
    indices = torch.randperm(images.size(0))
    shuffled_images = images[indices]
    shuffled_labels = labels[indices]

    h, w = images.size(2), images.size(3)

    # Box side lengths: the box covers roughly (1 - lam) of the image area.
    lam = np.random.beta(alpha, alpha)
    cut_h = int(h * np.sqrt(1 - lam))
    cut_w = int(w * np.sqrt(1 - lam))

    # Random box centre; box edges are clipped to the image borders.
    cy = np.random.randint(h)
    cx = np.random.randint(w)
    y1 = int(np.clip(cy - cut_h // 2, 0, h))
    y2 = int(np.clip(cy + cut_h // 2, 0, h))
    x1 = int(np.clip(cx - cut_w // 2, 0, w))
    x2 = int(np.clip(cx + cut_w // 2, 0, w))

    # Paste into a copy so the caller's tensor is left untouched.
    mixed_images = images.clone()
    mixed_images[:, :, y1:y2, x1:x2] = shuffled_images[:, :, y1:y2, x1:x2]

    # Fraction of pixels that still come from the original sample.
    lam = 1.0 - ((y2 - y1) * (x2 - x1)) / float(h * w)
    # The original sample keeps weight lam; the pasted partner gets 1 - lam.
    mixed_labels = lam * labels + (1.0 - lam) * shuffled_labels

    return mixed_images, mixed_labels

In [None]:
# Data augmentation: Cutout
def cutout_data(images, labels, n_holes, length):
    """Apply Cutout: zero out ``n_holes`` square patches in every image of the batch.

    Each image gets its own independently sampled hole centres. Holes are
    clipped at the image borders, so a hole near an edge is smaller than
    ``length`` x ``length``. All channels of an image share the same hole.

    Args:
        images: tensor indexed as ``images[i, :, h, w]`` (presumably
            N x C x H x W -- TODO confirm). The caller's tensor is not modified.
        labels: returned unchanged; Cutout does not alter the targets.
        n_holes: number of patches per image.
        length: nominal side length of each patch, in pixels.

    Returns:
        ``(masked_images, labels)``.
    """
    h = images.size(2)
    w = images.size(3)

    # Work on a copy so the caller's tensor is left untouched.
    masked = images.clone()

    # Sample holes per image: a single location shared by the whole batch would
    # give every sample the identical occlusion.
    for i in range(images.size(0)):
        for _ in range(n_holes):
            y = np.random.randint(h)
            x = np.random.randint(w)

            y1 = int(np.clip(y - length // 2, 0, h))
            y2 = int(np.clip(y + length // 2, 0, h))
            x1 = int(np.clip(x - length // 2, 0, w))
            x2 = int(np.clip(x + length // 2, 0, w))

            masked[i, :, y1:y2, x1:x2] = 0

    return masked, labels

In [None]:
def mixup_data(images, labels, alpha, num_classes=None):
    """Apply Mixup: blend each sample with a random partner from the same batch.

    ``lam ~ Beta(alpha, alpha)`` is folded to ``max(lam, 1 - lam)`` so the
    original sample always dominates (weight >= 0.5).

    Args:
        images: tensor whose dim 0 is the batch dimension.
        labels: per-sample labels aligned with ``images`` along dim 0.
            Interpolation only makes sense for float (one-hot / soft)
            labels; pass ``num_classes`` to one-hot encode integer class
            indices first.
        alpha: Beta-distribution parameter (must be > 0).
        num_classes: optional one-hot width used when ``labels`` is an
            integer tensor. The default ``None`` mixes labels as passed.

    Returns:
        ``(mixed_images, mixed_labels)``, both blended with the same ``lam``.
    """
    # Interpolating raw class indices (e.g. 3 and 7 -> 5.x) is meaningless,
    # so optionally switch to one-hot vectors first.
    if num_classes is not None and not torch.is_floating_point(labels):
        labels = torch.nn.functional.one_hot(labels.long(), num_classes).float()

    # Pick a random partner for every sample.
    indices = torch.randperm(images.size(0))
    shuffled_images = images[indices]
    shuffled_labels = labels[indices]

    # Mixing ratio, folded into [0.5, 1] so the original sample dominates.
    lam = np.random.beta(alpha, alpha)
    lam = float(max(lam, 1 - lam))

    mixed_images = lam * images + (1 - lam) * shuffled_images
    mixed_labels = lam * labels + (1 - lam) * shuffled_labels

    return mixed_images, mixed_labels

In [None]:
# Load the CIFAR-100 dataset.
# The DataLoaders below read batch_size when they are constructed, so it is
# defined here; this lets the notebook run top-to-bottom on a fresh kernel.
batch_size = 64

# Training pipeline: random crop with 4-px padding + horizontal flip, then
# ToTensor ([0, 1]) and per-channel Normalize(0.5, 0.5), which maps to [-1, 1].
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Test pipeline: no augmentation, same normalization as training.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

trainset = torchvision.datasets.CIFAR100(root='./data', train=True,
                                         download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR100(root='./data', train=False,
                                        download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

In [None]:
# Number of samples in each split. DataLoader has no get_dataset_size(), and
# len(loader) counts batches, so query the datasets directly.
print('训练数据集数量：', len(trainset))
print('测试数据集数量：', len(testset))

In [None]:
# Define the ResNet-18 model with a 100-way classifier head.
# Seed before construction so the random weight initialization is reproducible.
torch.manual_seed(42)
# `pretrained` is deprecated in newer torchvision; omitting it yields randomly
# initialized weights in both older and newer versions.
# NOTE(review): torchvision's resnet18 uses an ImageNet stem (7x7 stride-2 conv
# + max-pool); for 32x32 CIFAR inputs a 3x3 stride-1 stem is common -- consider
# it if accuracy matters.
model = resnet18(num_classes=100)

In [None]:
# Batch size for training and testing.
# NOTE(review): the DataLoaders are built from `batch_size` in an earlier cell,
# so this assignment only matters if it runs before that cell.
batch_size = 64

# Seed torch and numpy so shuffling and augmentation draws are repeatable.
torch.manual_seed(42)
np.random.seed(42)

# Cross-entropy objective, optimized by SGD with momentum and L2 weight decay.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    model.parameters(),
    lr=0.01,
    momentum=0.9,
    weight_decay=5e-4,
)

In [None]:
# Create the TensorBoard SummaryWriter.
writer = SummaryWriter()
# Train the model.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

num_epochs = 100

try:
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        correct = 0
        total = 0

        for batch_idx, (inputs, targets) in enumerate(trainloader):
            inputs, targets = inputs.to(device), targets.to(device)

            # Optional augmentations -- uncomment one.
            # NOTE(review): CutMix/Mixup return mixed (float) labels, while the loss
            # and the `predicted.eq(targets)` accuracy below assume integer class
            # indices; adapt both before enabling those two.
            # inputs, targets = cutmix_data(inputs, targets, alpha=1.0)
            # inputs, targets = cutout_data(inputs, targets, n_holes=1, length=16)
            # inputs, targets = mixup_data(inputs, targets, alpha=1.0)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            # Running totals for the current epoch.
            train_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            if (batch_idx + 1) % 100 == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Accuracy: {:.2f}%'
                      .format(epoch+1, num_epochs, batch_idx+1, len(trainloader),
                              train_loss/(batch_idx+1), 100. * correct / total))

        # Log one point per epoch. Writing inside the batch loop would emit
        # hundreds of values all at the same step (epoch).
        if total > 0:
            train_accuracy = 100. * correct / total
            writer.add_scalar('Train/Loss', train_loss / len(trainloader), epoch)
            writer.add_scalar('Train/Accuracy', train_accuracy, epoch)
finally:
    # Flush pending events and release the event file, even if training is interrupted.
    writer.close()

In [None]:
# Save the model: torch.save pickles the entire nn.Module object to this path.
# NOTE(review): saving model.state_dict() is usually more portable across code
# changes; switching would also change how the file must be loaded.
torch.save(obj=model, f='./baseline_ResNet-18')

In [None]:
# Test the model: switch to evaluation mode and zero the test counters.
model.eval()
test_loss = 0
test_correct = test_total = 0

In [None]:
# Number of batches per pass over each loader (train first, then test).
for loader in (trainloader, testloader):
    print(len(loader))

In [None]:
# Evaluate on the test set. Eval mode and the counters are set in this cell so
# re-running it gives the same result instead of accumulating onto stale totals.
model.eval()
test_loss = 0
test_correct = 0
test_total = 0

with torch.no_grad():
    for inputs, targets in testloader:
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        test_loss += loss.item()
        _, predicted = outputs.max(1)
        test_total += targets.size(0)
        test_correct += predicted.eq(targets).sum().item()

    # Loss is averaged over batches; accuracy over samples.
    print('Test Loss: {:.4f}, Test Accuracy: {:.2f}%'.format(test_loss / len(testloader),100. * test_correct / test_total))

In [None]:
# Visualize sample images.
def imshow(img):
    """Display a channel-first image tensor that was normalized with mean 0.5 / std 0.5.

    Args:
        img: 3-D tensor laid out (C, H, W); it may live on any device or
            track gradients -- it is detached and copied to CPU first.
    """
    img = img.detach().cpu() / 2 + 0.5  # undo Normalize(mean=0.5, std=0.5)
    # Clip float round-off so matplotlib does not warn about out-of-range values.
    npimg = np.clip(img.numpy(), 0.0, 1.0)
    plt.imshow(np.transpose(npimg, (1, 2, 0)))  # (C, H, W) -> (H, W, C)
    plt.axis('off')
    plt.show()

In [None]:
# Draw three random training samples to visualize the augmentations.
sample_loader = torch.utils.data.DataLoader(
    dataset=trainset,
    batch_size=3,
    shuffle=True,
    num_workers=2,
)
example_data, example_targets = next(iter(sample_loader))

# Each augmentation gets its own copy of the batch so the originals stay intact.
# Calls run in the order CutMix -> Cutout -> Mixup.
cutmix_images, cutmix_labels = cutmix_data(
    example_data.clone(), example_targets.clone(), alpha=1.0)
cutout_images, cutout_labels = cutout_data(
    example_data.clone(), example_targets.clone(), n_holes=1, length=16)
mixup_images, mixup_labels = mixup_data(
    example_data.clone(), example_targets.clone(), alpha=1.0)

In [None]:
# Show the original batch followed by each augmented version, each with a caption.
for caption, batch in [
    ("Original Images:", example_data),
    ("CutMix Images:", cutmix_images),
    ("Cutout Images:", cutout_images),
    ("Mixup Images:", mixup_images),
]:
    print(caption)
    imshow(torchvision.utils.make_grid(batch))