In [None]:
!pip install pandas
!pip install numpy

from tqdm import tqdm
import csv
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.models import resnet18, ResNet18_Weights, resnet50, ResNet50_Weights
from torch.utils.data import DataLoader, Subset  # 修正：导入 Subset
import numpy as np
# dataset and dataloader
def get_dataloaders(dataset_name="cifar10", batch_size=64, num_workers=2, train_ratio=1.0, seed=None):
    """Build CIFAR-10/100 train and test DataLoaders at 224x224 resolution.

    Images are resized to 224 (e.g. for ImageNet-pretrained backbones) and
    normalized with the per-channel statistics of the chosen dataset. A random
    fraction of the training split can optionally be used.

    Args:
        dataset_name: "cifar10" or "cifar100" (case-insensitive).
        batch_size: Batch size for both loaders.
        num_workers: Number of worker processes for each DataLoader.
        train_ratio: Fraction of the training split to keep, in (0, 1].
        seed: Optional seed for choosing the training subset. When None, the
            global NumPy RNG is used (previous behavior), so the subset
            depends on any earlier ``np.random.seed`` call.

    Returns:
        Tuple ``(train_loader, test_loader, num_classes)``.

    Raises:
        NotImplementedError: If ``dataset_name`` is not supported.
        ValueError: If ``train_ratio`` is outside (0, 1] or selects no samples.
    """
    # dataset name -> (num_classes, per-channel mean, per-channel std)
    specs = {
        "cifar10": (10, (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
        "cifar100": (100, (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
    }
    key = dataset_name.lower()
    if key not in specs:
        raise NotImplementedError(f"Dataset {dataset_name} not supported yet.")
    # Validate before downloading anything: a ratio > 1 makes
    # np.random.choice(replace=False) fail, a ratio <= 0 gives an empty subset.
    if not 0.0 < train_ratio <= 1.0:
        raise ValueError(f"train_ratio must be in (0, 1], got {train_ratio}")

    num_classes, mean, std = specs[key]
    dataset_cls = torchvision.datasets.CIFAR10 if key == "cifar10" else torchvision.datasets.CIFAR100

    # NOTE(review): RandomCrop(padding=4) runs after upsampling to 224, so
    # (assuming 32x32 CIFAR inputs) the jitter is ~0.6 original pixels rather
    # than the usual 4-pixel CIFAR shift. Kept as-is to preserve the recipe.
    transform_train = transforms.Compose([
        transforms.Resize(224),
        transforms.RandomCrop(224, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    transform_test = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    train_dataset = dataset_cls(root="./data", train=True, download=True, transform=transform_train)
    test_dataset = dataset_cls(root="./data", train=False, download=True, transform=transform_test)

    # Randomly choose a subset of the training data (without replacement).
    num_train_samples = int(len(train_dataset) * train_ratio)
    if num_train_samples < 1:
        raise ValueError(
            f"train_ratio={train_ratio} selects no samples out of {len(train_dataset)}"
        )
    rng = np.random if seed is None else np.random.default_rng(seed)
    train_indices = rng.choice(len(train_dataset), num_train_samples, replace=False)
    train_subset = Subset(train_dataset, train_indices)

    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return train_loader, test_loader, num_classes


def get_model(model_name="resnet18", num_classes=10, device="cpu"):
    """Create an ImageNet-pretrained ResNet whose classifier head fits ``num_classes``.

    Args:
        model_name: "resnet18" or "resnet50" (case-insensitive).
        num_classes: Output size of the replacement ``fc`` layer.
        device: Device the model is moved to before it is returned.

    Returns:
        The model on ``device``.

    Raises:
        NotImplementedError: If ``model_name`` is not supported.
    """
    model_name = model_name.lower()
    # Lambdas defer construction so only the selected model's weights are loaded.
    builders = {
        "resnet18": lambda: resnet18(weights=ResNet18_Weights.IMAGENET1K_V1),
        "resnet50": lambda: resnet50(weights=ResNet50_Weights.IMAGENET1K_V2),
    }
    build = builders.get(model_name)
    if build is None:
        raise NotImplementedError(f"Model {model_name} not supported yet.")

    net = build()
    # Replace the ImageNet classifier with a fresh linear layer for this task.
    net.fc = nn.Linear(net.fc.in_features, num_classes)
    return net.to(device)

###############################
# 训练与测试的函数
###############################
def train_one_epoch(model, train_loader, optimizer, device):
    """Run one optimization pass over ``train_loader`` with a tqdm progress bar.

    Args:
        model: Classifier; its outputs are fed to ``F.cross_entropy`` and
            ``argmax(dim=1)``, so presumably (batch, num_classes) logits.
        train_loader: Iterable yielding ``(inputs, labels)`` batches.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(epoch_loss, epoch_acc)``. ``epoch_loss`` is the mean
        per-sample cross-entropy, and ``epoch_acc`` is the accuracy in percent.

    Raises:
        ValueError: If ``train_loader`` yields no samples.
    """
    model.train()
    loss_sum = 0.0  # sum of per-sample losses over the epoch
    correct = 0
    total = 0

    loop = tqdm(train_loader, desc="Training", leave=True)
    for inputs, labels in loop:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = F.cross_entropy(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        batch_loss = loss.item()  # one host sync per step instead of two
        # cross_entropy returns the batch mean; weight it by the batch size so a
        # smaller final batch does not skew the epoch average.
        loss_sum += batch_loss * batch_size
        correct += outputs.argmax(dim=1).eq(labels).sum().item()
        total += batch_size

        loop.set_postfix(loss=batch_loss, acc=100.0 * correct / total)

    if total == 0:
        raise ValueError("train_loader yielded no samples")

    return loss_sum / total, 100.0 * correct / total


# 测试函数（添加 tqdm 进度条）
def test_one_epoch(model, test_loader, device):
    """Evaluate ``model`` on ``test_loader`` without gradients, showing a tqdm bar.

    Args:
        model: Classifier; its outputs are fed to ``F.cross_entropy`` and
            ``argmax(dim=1)``, so presumably (batch, num_classes) logits.
        test_loader: Iterable yielding ``(inputs, labels)`` batches.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(epoch_loss, epoch_acc)``. ``epoch_loss`` is the mean
        per-sample cross-entropy, and ``epoch_acc`` is the accuracy in percent.

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    model.eval()
    loss_sum = 0.0  # sum of per-sample losses over the epoch
    correct = 0
    total = 0

    loop = tqdm(test_loader, desc="Testing", leave=True)
    with torch.no_grad():
        for inputs, labels in loop:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)

            batch_size = labels.size(0)
            batch_loss = F.cross_entropy(outputs, labels).item()
            # Weight the batch-mean loss by batch size so the epoch mean is
            # per-sample, matching how accuracy is computed.
            loss_sum += batch_loss * batch_size
            correct += outputs.argmax(dim=1).eq(labels).sum().item()
            total += batch_size

            loop.set_postfix(loss=batch_loss, acc=100.0 * correct / total)

    if total == 0:
        raise ValueError("test_loader yielded no samples")

    return loss_sum / total, 100.0 * correct / total

import pandas as pd
import matplotlib.pyplot as plt
def plot_results(csv_filename):
    """Plot loss and accuracy curves from a training-log CSV.

    The CSV must contain the columns ``epoch``, ``train_loss``, ``test_loss``,
    ``train_acc`` and ``test_acc``. Two figures are drawn and shown in turn:
    train/test loss, then train/test accuracy (in percent).
    """
    history = pd.read_csv(csv_filename)

    # Pull every column up front so a missing column fails before any plotting.
    epochs = history['epoch']
    columns = {name: history[name] for name in ('train_loss', 'test_loss', 'train_acc', 'test_acc')}

    # (train column, test column, train label, test label, y-axis label, title)
    panels = (
        ('train_loss', 'test_loss', "Train Loss", "Test Loss", "Loss", "Loss over Epochs"),
        ('train_acc', 'test_acc', "Train Acc", "Test Acc", "Accuracy (%)", "Accuracy over Epochs"),
    )
    for train_col, test_col, train_label, test_label, y_label, title in panels:
        plt.figure()
        plt.plot(epochs, columns[train_col], label=train_label)
        plt.plot(epochs, columns[test_col], label=test_label)
        plt.xlabel("Epoch")
        plt.ylabel(y_label)
        plt.title(title)
        plt.legend()
        plt.show()

###############################
# 主函数
###############################
def main(dataset_name="cifar10", model_name="resnet50", epochs=50, lr=1e-4, train_ratio=1.0):
    """Fine-tune a pretrained ResNet on CIFAR, log metrics to CSV, and plot them.

    Args:
        dataset_name: "cifar10" or "cifar100".
        model_name: "resnet18" or "resnet50".
        epochs: Number of training epochs.
        lr: Adam learning rate.
        train_ratio: Fraction of the training split to use, passed to
            ``get_dataloaders``.

    The function writes ``results_{dataset_name}_{model_name}.csv``, which is
    overwritten on each run. It saves ``best_{dataset_name}_{model_name}.pth``
    whenever test accuracy improves. It then shows the loss and accuracy plots.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)

    train_loader, test_loader, num_classes = get_dataloaders(dataset_name=dataset_name, train_ratio=train_ratio)
    model = get_model(model_name=model_name,
                      num_classes=num_classes,
                      device=device)

    optimizer = optim.Adam(model.parameters(), lr=lr)

    csv_filename = f"results_{dataset_name}_{model_name}.csv"
    checkpoint_path = f"best_{dataset_name}_{model_name}.pth"
    best_acc = 0.0

    # Mode 'w' overwrites results from a previous run; use 'a' to append instead.
    with open(csv_filename, mode='w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(["epoch", "train_loss", "train_acc", "test_loss", "test_acc"])

        for epoch in range(1, epochs + 1):
            start_time = time.perf_counter()  # monotonic clock for durations
            train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, device)
            test_loss, test_acc = test_one_epoch(model, test_loader, device)
            epoch_time = time.perf_counter() - start_time

            print(f"[Epoch {epoch:02d}] "
                  f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | "
                  f"Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.2f}% | "
                  f"Time: {epoch_time:.2f}s")

            writer.writerow([epoch, train_loss, train_acc, test_loss, test_acc])
            # Flush every epoch so the CSV reflects progress even if the process
            # dies (e.g. a crashed notebook kernel) before the file is closed.
            f.flush()

            # Keep the checkpoint with the highest test accuracy seen so far.
            # NOTE(review): selecting on the test set leaks it into model
            # selection; a held-out validation split would be cleaner.
            if test_acc > best_acc:
                best_acc = test_acc
                torch.save(model.state_dict(), checkpoint_path)

    print(f"Training finished! Best Test Acc={best_acc:.2f}%")
    print(f"Results have been saved to {csv_filename}")
    # Training is done; draw the loss and accuracy curves from the CSV log.
    plot_results(csv_filename)


if __name__ == "__main__":
    main()
