<a href="https://colab.research.google.com/github/TheS1n233/Distributed-Learning-Project5/blob/main/Distributed_Learning_Project5.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Install torch, torchvision and matplotlib

In [1]:
!pip install torch torchvision matplotlib




# Download and prepare the CIFAR-100 dataset

In [2]:
import torch
import torchvision
import torchvision.transforms as transforms

# --- Data preprocessing -------------------------------------------------------
# ToTensor() produces values in [0, 1]; normalizing every RGB channel with
# mean 0.5 and std 0.5 then maps them to [-1, 1].
_half = (0.5, 0.5, 0.5)
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(_half, _half)])

# --- Download CIFAR-100 (train and test splits) ------------------------------
_cifar100_common = {'root': './data', 'download': True, 'transform': transform}
train_dataset = torchvision.datasets.CIFAR100(train=True, **_cifar100_common)
test_dataset = torchvision.datasets.CIFAR100(train=False, **_cifar100_common)

# --- Train / validation split (80% / 20%) ------------------------------------
# NOTE: random_split draws from the global torch RNG, so the split changes
# between runs unless a seed is set beforehand.
train_size = int(0.8 * len(train_dataset))
val_size = len(train_dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
    train_dataset, [train_size, val_size]
)

# --- Validation loader (fixed order) ------------------------------------------
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=64, shuffle=False)

# Per-worker data: split the training set into one shard per worker.

def split_dataset(dataset, num_workers):
    """Randomly partition *dataset* into ``num_workers`` disjoint subsets.

    Every sample is assigned to exactly one subset. When ``len(dataset)`` is
    not divisible by ``num_workers``, the first ``len(dataset) % num_workers``
    subsets receive one extra sample, so subset sizes differ by at most one.

    Args:
        dataset: A map-style dataset (supports ``len`` and indexing) accepted
            by ``torch.utils.data.random_split``.
        num_workers: Number of subsets to create; must be positive.

    Returns:
        A list of ``torch.utils.data.Subset`` objects, one per worker.

    Raises:
        ValueError: If ``num_workers`` is not positive.
    """
    if num_workers <= 0:
        raise ValueError(f"num_workers must be positive, got {num_workers}")
    base, remainder = divmod(len(dataset), num_workers)
    # random_split requires the lengths to sum to len(dataset) exactly, so the
    # remainder is spread over the first shards instead of being dropped.
    lengths = [base + 1 if i < remainder else base for i in range(num_workers)]
    return torch.utils.data.random_split(dataset, lengths)

# Distributed setup: how many workers there are and their local batch size.
K = 4  # number of worker nodes
Bloc = 64  # local (per-worker) batch size
worker_datasets = split_dataset(train_dataset, K)
# One shuffled DataLoader per worker, each reading only its own shard.
worker_loaders = list(
    map(
        lambda shard: torch.utils.data.DataLoader(shard, batch_size=Bloc, shuffle=True),
        worker_datasets,
    )
)


NameError: name 'train_dataset' is not defined

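# Centralized baseline (note: the cell below actually runs distributed Local SGD — K worker replicas with periodic parameter averaging)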

In [3]:
import torch.nn as nn
import torch.nn.functional as F


# LeNet-5 style CNN for 3-channel 32x32 images (CIFAR-100 by default).
class LeNet5(nn.Module):
    """LeNet-5 convolutional network.

    The layer sizes assume 3x32x32 inputs. Two 5x5 convolutions, each followed
    by 2x2 max pooling, shrink the spatial size 32 -> 28 -> 14 -> 10 -> 5. That
    is why ``fc1`` expects ``16 * 5 * 5`` input features.

    Args:
        num_classes: Size of the output layer. Defaults to 100 (CIFAR-100).
    """

    def __init__(self, num_classes=100):
        super().__init__()
        # Layers are created in the same order as before, so random
        # initialisation and state_dict keys are unchanged.
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        """Return raw class scores of shape ``(batch, num_classes)``."""
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Global model: receives the averaged worker parameters after every round.
global_model = LeNet5().to(device)

# One local replica per worker, all starting from the global model's weights.
worker_models = []
for _ in range(K):
    worker_models.append(LeNet5().to(device))
initial_state = global_model.state_dict()
for model in worker_models:
    model.load_state_dict(initial_state)

# Local update: the optimizer steps a worker runs between synchronizations.
def local_update(model, dataloader, criterion, optimizer, local_steps):
    """Train ``model`` in place for exactly ``local_steps`` optimizer steps.

    Each step consumes one mini-batch from ``dataloader``. If the loader is
    exhausted before ``local_steps`` steps are done, a new pass over it is
    started. The step count therefore does not depend on the loader's length.
    If the loader yields no batches at all, no step is taken.

    Args:
        model: The worker's ``nn.Module``; it is switched to training mode.
        dataloader: Iterable of ``(inputs, labels)`` mini-batches.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        optimizer: Optimizer over ``model``'s parameters.
        local_steps: Number of optimizer steps (mini-batches) to perform.

    Returns:
        ``model.state_dict()`` after training. Per PyTorch semantics, its
        tensors share storage with the model's parameters.
    """
    model.train()
    # Move batches to wherever the model's parameters live, rather than
    # relying on a module-level ``device`` variable.
    device = next(model.parameters()).device
    batches = iter(dataloader)
    for _ in range(local_steps):
        batch = next(batches, None)
        if batch is None:
            # End of the current pass: start a new one (the DataLoader
            # reshuffles if it was built with shuffle=True).
            batches = iter(dataloader)
            batch = next(batches, None)
            if batch is None:  # empty loader: nothing to train on
                break
        inputs, labels = batch
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()
    return model.state_dict()

# Synchronization: average the workers' states into the global model.
def average_parameters(global_model, local_states):
    """Load the element-wise mean of ``local_states`` into ``global_model``.

    Every floating-point (or complex) entry of ``global_model.state_dict()`` is
    replaced by the mean of the matching tensors across all worker state
    dicts. ``Tensor.mean`` rejects integer tensors (e.g. BatchNorm's
    ``num_batches_tracked`` counter), so non-floating entries are copied from
    the first worker instead.

    Args:
        global_model: Module that receives the averaged state (modified in
            place).
        local_states: Non-empty sequence of state dicts with the same keys as
            ``global_model.state_dict()``.

    Raises:
        ValueError: If ``local_states`` is empty.
    """
    if not local_states:
        raise ValueError("local_states must contain at least one state dict")
    # Start from the model's own state dict and overwrite its values.
    merged = global_model.state_dict()
    for key in list(merged):
        first = local_states[0][key]
        if first.is_floating_point() or first.is_complex():
            merged[key] = torch.stack(
                [state[key] for state in local_states], dim=0
            ).mean(dim=0)
        else:
            merged[key] = first
    global_model.load_state_dict(merged)

# Loss function and Local SGD schedule.
criterion = nn.CrossEntropyLoss()
H = 4  # number of local update steps per round (passed to local_update)
num_rounds = 10  # number of global synchronization rounds

for round_idx in range(num_rounds):
    print(f"Round {round_idx + 1}/{num_rounds}")

    # Local phase: every worker trains its own replica on its own loader.
    # NOTE: a fresh SGD optimizer is built each round, so momentum buffers
    # are not carried over from one round to the next.
    local_states = []
    for model, loader in zip(worker_models, worker_loaders):
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
        local_states.append(local_update(model, loader, criterion, optimizer, H))

    # Synchronization: average the worker states into the global model ...
    average_parameters(global_model, local_states)

    # ... then broadcast the averaged weights back to every worker replica.
    synced_state = global_model.state_dict()
    for model in worker_models:
        model.load_state_dict(synced_state)

# Evaluate the averaged global model on the CIFAR-100 test split.
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=Bloc, shuffle=False)
correct = total = 0
global_model.eval()
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        # Predicted class = index of the highest score in each row.
        predicted = global_model(inputs).max(dim=1).indices
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Global Model Test Accuracy: {100. * correct / total:.2f}%")


Worker 1/4 processing...
Worker 2/4 processing...
Worker 3/4 processing...
Worker 4/4 processing...
Epoch 1/5 completed.
Worker 1/4 processing...
Worker 2/4 processing...
Worker 3/4 processing...
Worker 4/4 processing...
Epoch 2/5 completed.
Worker 1/4 processing...
Worker 2/4 processing...
Worker 3/4 processing...
Worker 4/4 processing...
Epoch 3/5 completed.
Worker 1/4 processing...
Worker 2/4 processing...
Worker 3/4 processing...
Worker 4/4 processing...
Epoch 4/5 completed.
Worker 1/4 processing...
Worker 2/4 processing...
Worker 3/4 processing...
Worker 4/4 processing...
Epoch 5/5 completed.
Test Accuracy: 1.00%
