In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, Subset
import matplotlib.pyplot as plt
import numpy as np
from tqdm.notebook import tqdm  # 导入 tqdm.notebook 进度条

# TODO：正常的协同训练流程，仅在所有客户端训练完成后聚合模型


# 定义MLP模型
class MLP(nn.Module):
    """Fully connected classifier with two hidden layers.

    The input is reshaped to rows of 28 * 28 = 784 features, passed through
    hidden layers of 256 and 64 units (each followed by ReLU), and mapped to
    10 raw output scores by the final linear layer.
    """

    def __init__(self):
        super().__init__()
        in_features = 28 * 28
        self.fc1 = nn.Linear(in_features, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        """Return a tensor with 10 scores per flattened 784-element row of ``x``."""
        flat = x.view(-1, 28 * 28)
        hidden = self.fc1(flat).relu()
        hidden = self.fc2(hidden).relu()
        # No activation on the last layer: the raw scores are returned as-is.
        return self.fc3(hidden)


# 数据集加载器
def get_data_loaders(dataset_name, batch_size, data_dir="./data", num_workers=10):
    """Build one training loader per client plus a single shared test loader.

    The training set is cut into ``num_workers`` contiguous, equally sized
    shards (client ``i`` gets indices ``[i * size, (i + 1) * size)``). Samples
    left over after the integer division are not assigned to any client.

    Note that ``num_workers`` here is the number of clients/shards; it is not
    forwarded to ``DataLoader`` as a worker-process count.

    Args:
        dataset_name: ``"MNIST"`` or ``"CIFAR10"``.
        batch_size: Batch size used by every returned loader.
        data_dir: Directory the datasets are downloaded to / read from.
            Created if it does not exist.
        num_workers: Number of client shards to create (must be >= 1).

    Returns:
        ``(train_loaders, test_loader)`` where ``train_loaders`` is a list of
        ``num_workers`` shuffling loaders and ``test_loader`` iterates the
        full test split in order.

    Raises:
        ValueError: If ``dataset_name`` is unsupported, ``num_workers`` is
            smaller than 1, or there are more clients than training samples.
    """
    # Validate first so that bad arguments fail before a directory is created
    # or a dataset download is triggered.
    if dataset_name not in ("MNIST", "CIFAR10"):
        raise ValueError("Unsupported dataset")
    if num_workers < 1:
        raise ValueError("num_workers must be at least 1")

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(data_dir, exist_ok=True)

    if dataset_name == "MNIST":
        transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )
        train_dataset = datasets.MNIST(
            data_dir, train=True, download=True, transform=transform
        )
        test_dataset = datasets.MNIST(
            data_dir, train=False, download=True, transform=transform
        )
    else:  # "CIFAR10" -- the only other name accepted above
        normalize = transforms.Normalize(
            (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
        )
        # Random crop/flip augmentation is applied to the training split only.
        transform_train = transforms.Compose(
            [
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]
        )
        transform_test = transforms.Compose([transforms.ToTensor(), normalize])
        train_dataset = datasets.CIFAR10(
            data_dir, train=True, download=True, transform=transform_train
        )
        test_dataset = datasets.CIFAR10(
            data_dir, train=False, download=True, transform=transform_test
        )

    # Split the training set into equal contiguous shards, one per client.
    subset_size = len(train_dataset) // num_workers
    if subset_size == 0:
        raise ValueError(
            "num_workers exceeds the number of training samples; "
            "every client shard would be empty"
        )
    train_loaders = [
        DataLoader(
            Subset(train_dataset, range(start, start + subset_size)),
            batch_size=batch_size,
            shuffle=True,
        )
        for start in range(0, num_workers * subset_size, subset_size)
    ]
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loaders, test_loader


# 本地训练函数
def local_train(model, device, train_loader, optimizer, criterion, local_epochs=1):
    """Train ``model`` in place on one client's data.

    Puts the model in training mode and makes ``local_epochs`` full passes
    over ``train_loader``, taking one optimizer step per batch. Nothing is
    returned; the updated weights live in ``model``.

    Args:
        model: Module to optimise.
        device: Device each batch is moved to before the forward pass.
        train_loader: Iterable yielding ``(inputs, labels)`` batches.
        optimizer: Optimizer bound to ``model``'s parameters.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        local_epochs: Number of passes over ``train_loader``.
    """
    model.train()
    for _ in range(local_epochs):
        for inputs, labels in train_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            # Clear gradients left over from the previous step.
            optimizer.zero_grad()
            batch_loss = criterion(model(inputs), labels)
            batch_loss.backward()
            optimizer.step()


# 模型聚合函数
def aggregate_models(global_model, local_models):
    """Overwrite ``global_model`` with the element-wise mean of client states.

    Every entry of the global state dict is replaced by the unweighted
    average of the same entry across ``local_models``. Averaging is done in
    float32 and the result is cast back to the entry's original dtype, so
    integer buffers keep their dtype.

    Args:
        global_model: Module whose state is replaced in place.
        local_models: Non-empty sequence of state dicts (not modules), each
            containing every key of ``global_model.state_dict()``.

    Raises:
        ValueError: If ``local_models`` is empty.
    """
    if not local_models:
        raise ValueError("local_models must contain at least one state dict")

    global_dict = global_model.state_dict()
    # Only values of existing keys are reassigned, so iterating is safe.
    for name, reference in global_dict.items():
        stacked = torch.stack([state[name].float() for state in local_models], dim=0)
        global_dict[name] = stacked.mean(dim=0).to(reference.dtype)
    global_model.load_state_dict(global_dict)


def test(model, device, test_loader, criterion):
    """Return the top-1 accuracy of ``model`` on ``test_loader`` in percent.

    The model is switched to evaluation mode and run without gradient
    tracking. A prediction is the index of the largest output along
    dimension 1.

    Args:
        model: Module to evaluate.
        device: Device each batch is moved to before the forward pass.
        test_loader: Loader yielding ``(data, target)`` batches and exposing
            a sized ``dataset`` attribute, used as the accuracy denominator.
        criterion: Unused. Kept so existing callers keep working; the loss
            that used to be accumulated here was never returned and was
            normalised incorrectly (batch means divided by the sample count).

    Returns:
        Accuracy as a float in ``[0, 100]``.
    """
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            predictions = model(data).argmax(dim=1)
            correct += predictions.eq(target.view_as(predictions)).sum().item()

    return 100.0 * correct / len(test_loader.dataset)


# This is an evaluation helper, not a test case: its name matches pytest's
# default collection pattern, so opt it out explicitly.
test.__test__ = False


# 通用评估函数
def eval(
    model_name,
    dataset_name,
    batch_size=32,
    total_data_limit=20000,
    log_interval=10,
    data_dir="./data",
    lr=0.01,
    momentum=0.9,
    num_workers=10,
    local_epochs=1,
):
    """Run synchronous federated averaging and track test accuracy.

    In each round every client starts from a copy of the global weights,
    trains for ``local_epochs`` passes over its own shard with SGD, and the
    global model is then replaced by the average of all client states and
    evaluated on the test set. Rounds repeat until the cumulative number of
    processed training samples reaches ``total_data_limit``; the last round
    always runs to completion, so the final count may exceed the limit.

    NOTE: this function shadows the builtin ``eval`` within this module; the
    name is kept because existing callers use it.

    Args:
        model_name: ``"MLP"`` (with ``"MNIST"``) or ``"ResNet18"`` (with
            ``"CIFAR10"``).
        dataset_name: Dataset to train and evaluate on.
        batch_size: Batch size for all loaders.
        total_data_limit: Sample budget that ends training once reached.
        log_interval: Unused; kept for backward compatibility.
        data_dir: Dataset directory passed to ``get_data_loaders``.
        lr: SGD learning rate for client training.
        momentum: SGD momentum for client training.
        num_workers: Number of clients (training shards).
        local_epochs: Passes each client makes over its shard per round.

    Returns:
        ``(all_data_amounts, all_accuracies)``: two parallel lists with one
        entry per round -- cumulative samples processed and test accuracy in
        percent.

    Raises:
        ValueError: If the model/dataset combination is unsupported or
            ``local_epochs`` is smaller than 1.
    """
    # Validate before loading data so bad arguments fail without a download.
    if model_name == "MLP" and dataset_name == "MNIST":
        build_model = MLP
    elif model_name == "ResNet18" and dataset_name == "CIFAR10":

        def build_model():
            return models.resnet18(num_classes=10)

    else:
        raise ValueError("Unsupported model or dataset combination")
    if local_epochs < 1:
        # With zero epochs no samples are ever counted and the loop below
        # would never terminate.
        raise ValueError("local_epochs must be at least 1")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_loaders, test_loader = get_data_loaders(
        dataset_name, batch_size, data_dir, num_workers
    )

    global_model = build_model().to(device)
    criterion = nn.CrossEntropyLoss()
    all_data_amounts = []
    all_accuracies = []
    total_data = 0

    with tqdm(total=total_data_limit, desc="Training", unit="sample") as pbar:
        while total_data < total_data_limit:
            local_states = []
            for train_loader in train_loaders:
                # Each client starts the round from the current global weights
                # with a fresh optimizer (momentum is not carried over).
                local_model = build_model().to(device)
                local_model.load_state_dict(global_model.state_dict())
                optimizer = optim.SGD(
                    local_model.parameters(), lr=lr, momentum=momentum
                )
                local_train(
                    local_model,
                    device,
                    train_loader,
                    optimizer,
                    criterion,
                    local_epochs,
                )
                local_states.append(local_model.state_dict())

                # Advance the counter and the progress bar by the same amount
                # so the bar stays in sync with the reported total.
                samples_seen = len(train_loader.dataset) * local_epochs
                total_data += samples_seen
                pbar.update(samples_seen)
                pbar.set_postfix({"Total Data": total_data})

            # Aggregate only after every client has finished the round.
            aggregate_models(global_model, local_states)
            accuracy = test(global_model, device, test_loader, criterion)
            print(
                f"Trained [{total_data} / {total_data_limit} samples]\tAccuracy: {accuracy:.2f}%"
            )
            all_accuracies.append(accuracy)
            all_data_amounts.append(total_data)

    return all_data_amounts, all_accuracies


# 主函数调用
# all_data_amounts, all_accuracies = eval(
#     model_name='MLP',
#     dataset_name='MNIST',
#     batch_size=32,
#     total_data_limit=120000,
#     num_workers=10,
#     local_epochs=1
# )

# Run the federated experiment: ResNet18 on CIFAR10 with 10 clients, two
# local epochs per round, and a budget of 2,000,000 training samples.
all_data_amounts, all_accuracies = eval(
    dataset_name="CIFAR10",
    model_name="ResNet18",
    num_workers=10,
    local_epochs=2,
    batch_size=32,
    total_data_limit=2000000,
)

# Persist the per-round curves to disk.
results = {"data_amounts": all_data_amounts, "accuracies": all_accuracies}
torch.save(results, "eval_results.pt")

# Read the file back to confirm that the results round-trip.
loaded_data = torch.load("eval_results.pt")
loaded_data_amounts, loaded_accuracies = (
    loaded_data["data_amounts"],
    loaded_data["accuracies"],
)

# Show the reloaded values for a visual check.
print(loaded_data_amounts)
print(loaded_accuracies)

Files already downloaded and verified
Files already downloaded and verified


Training:   0%|          | 0/2000000 [00:00<?, ?sample/s]