<a href="https://colab.research.google.com/github/TheS1n233/Distributed-Learning-Project5/blob/main/Distributed_Learning_Project5.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Install torch, torchvision, and matplotlib

In [None]:
!pip install torch torchvision matplotlib




# Download the CIFAR-100 dataset

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms

# Preprocessing: PIL image -> float tensor in [0, 1], then per-channel
# (x - 0.5) / 0.5 so every RGB channel lands in [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# Download (if not already present) and load the CIFAR-100 train and test
# splits under ./data; the train split is built first, then the test split.
train_dataset, test_dataset = (
    torchvision.datasets.CIFAR100(
        root='./data', train=split_is_train, download=True, transform=transform
    )
    for split_is_train in (True, False)
)

# Mini-batch loaders: reshuffle the training data each pass, keep test order fixed.
train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, batch_size=64)
test_loader = torch.utils.data.DataLoader(test_dataset, shuffle=False, batch_size=64)

for split_name, split_data in (("Train", train_dataset), ("Test", test_dataset)):
    print(f"{split_name} dataset size: {len(split_data)}")


Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./data/cifar-100-python.tar.gz


100%|██████████| 169M/169M [00:12<00:00, 13.2MB/s]


Extracting ./data/cifar-100-python.tar.gz to ./data
Files already downloaded and verified
Train dataset size: 50000
Test dataset size: 10000


# Centralized baseline (disabled) and Local SGD simulation

In [None]:
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F


# Mini-batch size for each simulated worker's local DataLoader in
# local_sgd_simulation (the module-level train_loader/test_loader use 64).
batch_size = 128

# define LeNet-5 model
class LeNet5(nn.Module):
    """LeNet-5-style CNN for 3x32x32 images with a 100-way output.

    Pipeline: conv(3->6, 5x5) -> ReLU -> 2x2 max-pool -> conv(6->16, 5x5)
    -> ReLU -> 2x2 max-pool -> flatten (16*5*5) -> FC 120 -> ReLU
    -> FC 84 -> ReLU -> FC 100. Returns raw logits (no softmax).
    """

    def __init__(self):
        super().__init__()
        # Feature extractor. Registration order fixes the parameter and
        # state_dict order, which weight-averaging code relies on.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        # Classifier head: a 32x32 input leaves 16 feature maps of 5x5
        # after two (5x5 conv, 2x2 pool) stages.
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=100)

    def forward(self, x):
        for conv in (self.conv1, self.conv2):
            x = F.max_pool2d(F.relu(conv(x)), 2)
        x = torch.flatten(x, 1)
        for hidden in (self.fc1, self.fc2):
            x = F.relu(hidden(x))
        return self.fc3(x)

"""
# initial
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = LeNet5().to(device)

# Define the optimizer and loss function
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
criterion = nn.CrossEntropyLoss()

# tarin loop
for epoch in range(1, 3):  # 简单训练2个epoch
    model.train()
    for i, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        if i % 100 == 0:
            print(f"Epoch {epoch}, Batch {i}, Loss: {loss.item()}")

# test model
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

print(f"Test Accuracy: {100. * correct / total:.2f}%")
"""



# Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Local SGD Simulation
def local_sgd_simulation(model, train_loader, num_workers=4, local_steps=5, epochs=5,
                         local_batch_size=None, lr=0.01, momentum=0.9):
    """Simulate synchronous Local SGD (periodic model averaging) on one device.

    In every round the training set is randomly split into ``num_workers``
    disjoint shards. Each worker starts from the SAME snapshot of the current
    global model and trains on its own shard. The global model is then
    replaced by the element-wise mean of the workers' weights.

    Args:
        model: The ``nn.Module`` to train; it is moved to ``device`` and its
            weights are updated in place.
        train_loader: DataLoader whose ``.dataset`` is sharded among workers.
        num_workers: Number of simulated workers (must be >= 1).
        local_steps: Number of full passes each worker makes over its shard
            between synchronizations. NOTE(review): this counts local epochs,
            not individual SGD steps -- confirm this is the intended meaning.
        epochs: Number of communication rounds (one shard split per round).
        local_batch_size: Mini-batch size of each worker's DataLoader;
            defaults to the module-level ``batch_size``.
        lr: Learning rate of each worker's (freshly created) SGD optimizer.
        momentum: Momentum of each worker's SGD optimizer.

    Returns:
        The trained global model (the same module object as ``model``).

    Raises:
        ValueError: If ``num_workers`` is smaller than 1.
    """
    import copy

    if num_workers < 1:
        raise ValueError(f"num_workers must be >= 1, got {num_workers}")
    if local_batch_size is None:
        local_batch_size = batch_size  # module-level default

    model_global = model.to(device)
    criterion = nn.CrossEntropyLoss()
    dataset = train_loader.dataset

    for epoch in range(epochs):
        # Divide the dataset into `num_workers` disjoint partitions. The
        # remainder is spread over the first shards so the sizes always sum to
        # len(dataset), which random_split requires.
        base, remainder = divmod(len(dataset), num_workers)
        sizes = [base + (1 if i < remainder else 0) for i in range(num_workers)]
        data_partitions = torch.utils.data.random_split(dataset, sizes)

        summed_state = None  # running sum of the workers' state_dicts
        for worker_id, partition in enumerate(data_partitions):
            print(f"Worker {worker_id + 1}/{num_workers} processing...")

            # model_global is not modified until every worker of this round
            # has finished, so each worker starts from identical weights.
            model_local = copy.deepcopy(model_global)
            optimizer = optim.SGD(model_local.parameters(), lr=lr, momentum=momentum)

            # Local training
            local_loader = torch.utils.data.DataLoader(
                partition, batch_size=local_batch_size, shuffle=True
            )
            model_local.train()
            for _ in range(local_steps):
                for inputs, labels in local_loader:
                    inputs, labels = inputs.to(device), labels.to(device)
                    optimizer.zero_grad()
                    outputs = model_local(inputs)
                    loss = criterion(outputs, labels)
                    loss.backward()
                    optimizer.step()

            # Accumulate this worker's weights for the average.
            local_state = model_local.state_dict()
            if summed_state is None:
                summed_state = {k: v.detach().clone() for k, v in local_state.items()}
            else:
                for key, value in local_state.items():
                    if torch.is_floating_point(value):
                        summed_state[key] += value

        # Synchronize: global weights <- mean of the workers' weights.
        # Non-floating entries (e.g. integer counters) keep the first worker's value.
        with torch.no_grad():
            for value in summed_state.values():
                if torch.is_floating_point(value):
                    value.div_(num_workers)
        model_global.load_state_dict(summed_state)

        print(f"Epoch {epoch + 1}/{epochs} completed.")

    return model_global

# Train a freshly initialized LeNet-5 with the simulated Local SGD procedure,
# then report top-1 accuracy of the resulting global model on the test split.
model = LeNet5()
trained_model = local_sgd_simulation(model, train_loader)

trained_model.eval()
correct, total = 0, 0
with torch.no_grad():
    for images, targets in test_loader:
        images = images.to(device)
        targets = targets.to(device)
        predicted = trained_model(images).max(1).indices
        correct += predicted.eq(targets).sum().item()
        total += targets.size(0)

print(f"Test Accuracy: {100. * correct / total:.2f}%")

Worker 1/4 processing...
Worker 2/4 processing...
Worker 3/4 processing...
Worker 4/4 processing...
Epoch 1/5 completed.
Worker 1/4 processing...
Worker 2/4 processing...
Worker 3/4 processing...
Worker 4/4 processing...
Epoch 2/5 completed.
Worker 1/4 processing...
Worker 2/4 processing...
Worker 3/4 processing...
Worker 4/4 processing...
Epoch 3/5 completed.
Worker 1/4 processing...
Worker 2/4 processing...
Worker 3/4 processing...
Worker 4/4 processing...
Epoch 4/5 completed.
Worker 1/4 processing...
Worker 2/4 processing...
Worker 3/4 processing...
Worker 4/4 processing...
Epoch 5/5 completed.
Test Accuracy: 1.00%
