In [None]:
# data pre-process
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Subset

# Training-time augmentation: random flips/rotation, then scale pixels to [-1, 1].
# NOTE(review): vertical flips turn natural CIFAR-10 photos upside down; confirm
# this augmentation is intended.
train_transform = transforms.Compose(
    [
    transforms.RandomHorizontalFlip(p=0.5),  # randomly flip horizontally with probability 0.5
    transforms.RandomVerticalFlip(p=0.5),    # randomly flip vertically with probability 0.5
    transforms.RandomRotation(degrees=30),   # random rotation between -30° and 30°
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

# Evaluation transform: deterministic (no augmentation), same normalization as training,
# so validation/test metrics are reproducible and measured on unaltered images.
eval_transform = transforms.Compose(
    [
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

# Backward-compatible alias for any later cell that still refers to `transform`.
transform = train_transform

batch_size = 30
NUM_WORKERS = 2
DATA_ROOT = './data'

trainset = torchvision.datasets.CIFAR10(root=DATA_ROOT, train=True, download=True, transform=train_transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)

# The official CIFAR-10 test split is used as the validation set.
valset = torchvision.datasets.CIFAR10(root=DATA_ROOT, train=False, download=True, transform=eval_transform)
valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=False, num_workers=NUM_WORKERS)

# Small smoke-test subset: the first 10 images of the validation split.
# NOTE(review): these images are also used for model selection, so this is not a held-out test.
test_indices = list(range(10))
testset = Subset(valset, test_indices)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=NUM_WORKERS)

classes = ('airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:03<00:00, 48.2MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch
import torchvision.models as models
from torchsummary import summary

# Use the GPU when PyTorch can see one; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def _run_epoch(model, loader, criterion, optimizer=None):
    """Run one pass over ``loader`` and return ``(mean_batch_loss, accuracy)``.

    If ``optimizer`` is given, the model is put in training mode and updated
    after every batch. Otherwise it is put in eval mode with gradients disabled.
    Both returned values are 0.0 when the loader yields no batches.
    Batches are moved to the module-level ``device``.
    """
    is_train = optimizer is not None
    model.train(is_train)  # train(False) is equivalent to eval()
    running_loss = 0.0
    correct = 0
    total = 0
    n_batches = 0
    with torch.set_grad_enabled(is_train):
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            if is_train:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if is_train:
                loss.backward()
                optimizer.step()
            running_loss += loss.item()

            _, predicted = torch.max(outputs, 1)  # index of the highest score = predicted class
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
            n_batches += 1
    # Guard against empty loaders instead of raising ZeroDivisionError.
    mean_loss = running_loss / n_batches if n_batches else 0.0
    accuracy = correct / total if total else 0.0
    return mean_loss, accuracy


def _plot_history(model_name, train_values, val_values, train_label, val_label,
                  ylabel, title_suffix, file_suffix):
    """Plot per-epoch training vs. validation curves, save them as a PNG, and display."""
    n = min(len(train_values), len(val_values))
    epoch_axis = range(1, n + 1)
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(epoch_axis, train_values[:n], label=train_label, marker="o")
    ax.plot(epoch_axis, val_values[:n], label=val_label, marker="o")
    ax.set_xlabel("Epochs")
    ax.set_ylabel(ylabel)
    ax.set_title(model_name + title_suffix)
    ax.legend()
    ax.grid(True)
    fig.savefig(f"{model_name}_{file_suffix}.png", dpi=300, bbox_inches="tight")
    plt.show()
    plt.close(fig)  # release the figure so repeated runs don't accumulate open figures


def train_model(model_name, model, train_loader, val_loader, criterion, optimizer, epochs,
                model_path=None):
    """Train ``model``, checkpoint the best-validation weights, and plot learning curves.

    Args:
        model_name: used in log/plot titles, PNG file names and the default checkpoint name.
        model: network to train; it is moved to the module-level ``device``.
        train_loader: iterable of ``(inputs, labels)`` training batches.
        val_loader: iterable of ``(inputs, labels)`` validation batches.
        criterion: loss called as ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        epochs: number of passes over ``train_loader``.
        model_path: file that receives the ``state_dict`` whenever validation
            accuracy improves. Defaults to ``f"{model_name}.pth"`` so different
            models no longer overwrite one shared checkpoint.

    Returns:
        dict with per-epoch lists ``train_loss``, ``val_loss``, ``train_acc``,
        ``val_acc`` and the scalar ``best_val_acc``.

    Note:
        On return the model holds the LAST epoch's weights. Load ``model_path``
        to get the best-validation weights.
    """
    if model_path is None:
        model_path = f"{model_name}.pth"
    model = model.to(device)
    train_losses, val_losses = [], []
    train_acc, val_acc = [], []
    best_val_acc = 0.0

    for epoch in range(epochs):
        epoch_train_loss, epoch_train_acc = _run_epoch(model, train_loader, criterion, optimizer)
        epoch_val_loss, epoch_val_acc = _run_epoch(model, val_loader, criterion)
        train_losses.append(epoch_train_loss)
        train_acc.append(epoch_train_acc)
        val_losses.append(epoch_val_loss)
        val_acc.append(epoch_val_acc)

        if epoch_val_acc > best_val_acc:
            best_val_acc = epoch_val_acc
            torch.save(model.state_dict(), model_path)  # save model parameters
            print(f"New best model saved with Val Acc: {best_val_acc:.4f}")

        print(f'Epoch {epoch+1}/{epochs}, '
              f'Train Loss: {train_losses[-1]:.4f}, '
              f'Val Loss: {val_losses[-1]:.4f}, '
              f'Train Acc: {train_acc[-1]:.4f}, '
              f'Val Acc: {val_acc[-1]:.4f}')

    _plot_history(model_name, train_losses, val_losses,
                  "Training Loss", "Validation Loss", "Loss",
                  " Training and Validation Loss", "loss")
    _plot_history(model_name, train_acc, val_acc,
                  "Training accuracy", "Validation accuracy", "Accuracy",
                  " Training and Validation Accuracy", "acc")

    return {
        "train_loss": train_losses,
        "val_loss": val_losses,
        "train_acc": train_acc,
        "val_acc": val_acc,
        "best_val_acc": best_val_acc,
    }


def test_model(model_name,model, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
          inputs, labels = inputs.to(device), labels.to(device)
          outputs = model(inputs)
          _, predicted = torch.max(outputs, 1) # 獲取預測結果
          correct += (predicted == labels).sum().item()
          total += labels.size(0)

# VGG19 with batch normalization, trained from scratch with a 10-class head for CIFAR-10.
MODEL_NAME = "vgg19_bn_ver2"
EPOCHS = 80
LEARNING_RATE = 1e-4
SEED = 0

# Seed torch's global RNG so weight init and DataLoader shuffle order repeat across runs
# (GPU kernels may still introduce some nondeterminism).
torch.manual_seed(SEED)

# Initialize the VGG19_BN model with 10 output classes.
model = models.vgg19_bn(num_classes=10)
# Print the model architecture for a 3x32x32 input.
summary(model.to(device), input_size=(3, 32, 32))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
train_model(MODEL_NAME, model, trainloader, valloader, criterion, optimizer, epochs=EPOCHS)

# train_model checkpoints the best-validation weights to f"{MODEL_NAME}.pth";
# evaluate those rather than the last-epoch weights left in `model`.
try:
    model.load_state_dict(torch.load(f"{MODEL_NAME}.pth", map_location=device))
except FileNotFoundError:
    print("No checkpoint found; evaluating last-epoch weights.")
test_model(MODEL_NAME, model, testloader)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 32, 32]           1,792
       BatchNorm2d-2           [-1, 64, 32, 32]             128
              ReLU-3           [-1, 64, 32, 32]               0
            Conv2d-4           [-1, 64, 32, 32]          36,928
       BatchNorm2d-5           [-1, 64, 32, 32]             128
              ReLU-6           [-1, 64, 32, 32]               0
         MaxPool2d-7           [-1, 64, 16, 16]               0
            Conv2d-8          [-1, 128, 16, 16]          73,856
       BatchNorm2d-9          [-1, 128, 16, 16]             256
             ReLU-10          [-1, 128, 16, 16]               0
           Conv2d-11          [-1, 128, 16, 16]         147,584
      BatchNorm2d-12          [-1, 128, 16, 16]             256
             ReLU-13          [-1, 128, 16, 16]               0
        MaxPool2d-14            [-1, 12

KeyboardInterrupt: 