In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split, Subset
import numpy as np
import matplotlib.pyplot as plt

In [2]:
# Use the GPU when available, and seed the RNGs so model initialisation,
# data shuffling and augmentation are reproducible across kernel restarts.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [13]:
# Load the CIFAR-10 dataset and define preprocessing.
# Random augmentation (flip / crop) is applied to the training set only; the test
# set uses a deterministic transform so the reported accuracy is not perturbed by
# random crops and flips and is stable from one evaluation to the next.
normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

transform = transforms.Compose(
    [transforms.RandomHorizontalFlip(),
     transforms.RandomCrop(32, padding=4),
     transforms.ToTensor(),
     normalize])

test_transform = transforms.Compose(
    [transforms.ToTensor(),
     normalize])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=test_transform)

Files already downloaded and verified
Files already downloaded and verified


In [14]:
# Number of federated clients
num_clients = 3

# Split the training set into num_clients near-equal shards; the first
# len(trainset) % num_clients shards receive one extra sample.
base_size, remainder = divmod(len(trainset), num_clients)
lengths = [base_size + (1 if i < remainder else 0) for i in range(num_clients)]

# Dedicated, seeded generator so the client partition is reproducible.
SPLIT_SEED = 0
client_datasets = random_split(trainset, lengths,
                               generator=torch.Generator().manual_seed(SPLIT_SEED))

trainloaders = [DataLoader(client_dataset, batch_size=100, shuffle=True) for client_dataset in client_datasets]
testloader = DataLoader(testset, batch_size=100, shuffle=False)

In [15]:
class SimpleCNN(nn.Module):
    """Small CNN for 3x32x32 images producing 10 output scores per image.

    Pipeline: conv(3->32) -> ReLU -> 2x2 max-pool -> conv(32->64) -> ReLU ->
    2x2 max-pool -> flatten to 64*8*8 -> fc(512) -> ReLU -> fc(128) -> ReLU -> fc(10).
    The 3x3 convolutions use padding=1, so only the two pools shrink the spatial
    size (32 -> 16 -> 8), which is where the 64*8*8 flatten size comes from.
    """

    def __init__(self):
        super().__init__()
        # Attribute names are kept as-is: they become the state_dict keys that the
        # federated averaging step matches across models.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(64 * 8 * 8, 512)
        self.fc2 = nn.Linear(512, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        # Two identical conv -> ReLU -> pool stages.
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        flat = x.view(-1, 64 * 8 * 8)
        hidden = F.relu(self.fc2(F.relu(self.fc1(flat))))
        return self.fc3(hidden)

In [16]:
# Model construction and weight-copy helpers
def get_model():
    """Build a freshly initialised SimpleCNN and move it to the global ``device``."""
    model = SimpleCNN()
    return model.to(device)

def copy_model(target_model, source_model):
    """Overwrite ``target_model``'s parameters and buffers with ``source_model``'s.

    Uses ``load_state_dict`` with its default strict matching, so both models need
    the same state_dict keys and shapes. Returns None.
    """
    weights = source_model.state_dict()
    target_model.load_state_dict(weights)

In [19]:
def federated_learning(num_rounds, num_epochs, learning_rate):
    """Train a global model by federated averaging over the client shards.

    In each round, every client model starts from the current global weights. It is
    trained locally for ``num_epochs`` epochs on its loader in ``trainloaders``, using
    SGD with momentum 0.9 and cross-entropy loss. The server then replaces the global
    weights with the element-wise mean of the client weights, copies them back to
    every client, and evaluates the global model with ``test_model``.

    Args:
        num_rounds: Number of communication rounds.
        num_epochs: Local training epochs per client per round.
        learning_rate: SGD learning rate used by every client.

    Returns:
        Tuple ``(global_model, accuracy_list)`` where ``accuracy_list[r]`` is the
        value returned by ``test_model`` for the global model after round ``r + 1``.
    """
    global_model = get_model()
    client_models = [get_model() for _ in range(num_clients)]
    # Each get_model() call draws an independent random initialisation. Sync all
    # clients to the global weights before round 1; otherwise the first
    # aggregation averages unrelated networks.
    for client_model in client_models:
        copy_model(client_model, global_model)

    criterion = nn.CrossEntropyLoss()
    accuracy_list = []

    for round_idx in range(num_rounds):  # not `round`, to avoid shadowing the builtin
        print(f'Round {round_idx+1}/{num_rounds}')

        # Local training on each client
        for client_idx, client_model in enumerate(client_models):
            loader = trainloaders[client_idx]
            client_model.train()
            # A new optimizer each round, so momentum buffers start fresh every round.
            optimizer = optim.SGD(client_model.parameters(), lr=learning_rate, momentum=0.9)

            for epoch in range(num_epochs):
                running_loss = 0.0
                for inputs, labels in loader:
                    inputs, labels = inputs.to(device), labels.to(device)

                    optimizer.zero_grad()
                    outputs = client_model(inputs)
                    loss = criterion(outputs, labels)
                    loss.backward()
                    optimizer.step()

                    running_loss += loss.item()

                # max(..., 1) avoids ZeroDivisionError for an empty client shard.
                print(f'Client {client_idx+1}, Epoch {epoch+1}, Loss: {running_loss/max(len(loader), 1)}')

        # Server-side aggregation: unweighted element-wise mean of the client weights.
        # Presumably the shards are near-equal in size, so this approximates
        # sample-weighted FedAvg; verify if trainloaders is built differently.
        # state_dict() is taken once per client instead of once per key.
        client_states = [client_model.state_dict() for client_model in client_models]
        averaged_state = {
            # Average in float, then cast back to the tensor's original dtype.
            key: torch.stack([state[key].float() for state in client_states], 0).mean(0).to(value.dtype)
            for key, value in global_model.state_dict().items()
        }
        global_model.load_state_dict(averaged_state)

        # Broadcast the new global weights to every client
        for client_model in client_models:
            copy_model(client_model, global_model)

        # Evaluate the global model after every round
        accuracy = test_model(global_model)
        accuracy_list.append(accuracy)
        print(f'Round {round_idx+1}, Global Model Accuracy: {accuracy:.2f}%')

    return global_model, accuracy_list

In [18]:
# 테스트 함수 정의
def test_model(model):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f'Accuracy of the global model on the 10000 test images: {accuracy:.2f}%')
    return accuracy

In [None]:
# Run federated learning and evaluate the global model after each round
global_model, accuracy_list = federated_learning(num_rounds=5, num_epochs=3, learning_rate=0.01)


# Plot the global model's test accuracy per round
rounds = list(range(1, len(accuracy_list) + 1))
fig, ax = plt.subplots()
ax.plot(rounds, accuracy_list, marker='o')
ax.set_xlabel('Round')
ax.set_ylabel('Accuracy (%)')
ax.set_title('Global Model Accuracy on CIFAR-10 Test Set per Round')
ax.grid(True)
plt.show()

Round 1/5
Client 1, Epoch 1, Loss: 2.194660859907459
Client 1, Epoch 2, Loss: 1.868167899325936
Client 1, Epoch 3, Loss: 1.6541572937708415
