# FedAvg com Transferência de Aprendizado

## Clientes com dados completos

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader, Subset
import copy
import torchvision.models as models
import random

# Seed every RNG this notebook draws from: Python's `random` (used for subset
# sampling) and torch on CPU and GPU, so runs are repeatable.
# NOTE(review): cuDNN kernels may still be nondeterministic on GPU.
SEED = 123
random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

# Use the GPU whenever one is visible; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Função para selecionar o modelo pré-treinado
def select_model(architecture):
    if architecture == 'alexnet':
        model = models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1)
        model.classifier[6] = nn.Linear(model.classifier[6].in_features, 10)
    elif architecture == 'vgg11':
        model = models.vgg11(weights=models.VGG11_Weights.IMAGENET1K_V1)
        model.classifier[6] = nn.Linear(model.classifier[6].in_features, 10)
    elif architecture == 'resnet18':
        model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        model.fc = nn.Linear(model.fc.in_features, 10)
    elif architecture == 'mobilenet_v2':
        model = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1)
        model.classifier[1] = nn.Linear(model.classifier[1].in_features, 10)
    elif architecture == 'squeezenet':
        model = models.squeezenet1_1(weights=models.SqueezeNet1_1_Weights.IMAGENET1K_V1)
        model.classifier[1] = nn.Conv2d(512, 10, kernel_size=(1, 1), stride=(1, 1))
        model.num_classes = 10
    else:
        raise ValueError("Arquitetura não suportada")

    return model.to(device)

# Draw a random per-client view of a dataset.
def create_subset(dataset, subset_size):
    """Return a ``Subset`` of ``subset_size`` distinct, randomly chosen items of ``dataset``.

    Indices are drawn without replacement via the global ``random`` module, so the
    selection follows ``random.seed``. Raises ValueError if ``subset_size`` exceeds
    ``len(dataset)`` (propagated from ``random.sample``).
    """
    chosen = random.sample(range(len(dataset)), k=subset_size)
    return Subset(dataset, chosen)

# A federated client that trains its own copy of the model on local data.
class Client:
    """Federated-learning participant that fine-tunes a private copy of a model.

    Args:
        model: template network; it is deep-copied, so the template is never modified.
        dataloader: iterable of ``(images, labels)`` batches forming the local data.
        device: device that the local model and every batch are moved to.
        lr: Adam learning rate (default 1e-4).
        weight_decay: Adam weight decay (default 1e-5).
    """

    def __init__(self, model, dataloader, device, lr=0.0001, weight_decay=1e-5):
        self.model = copy.deepcopy(model).to(device)
        self.dataloader = dataloader
        # NOTE(review): the optimizer is created once and is not reset when
        # set_parameters() loads new global weights, so Adam's moment estimates
        # carry over between rounds — confirm this is intended.
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr, weight_decay=weight_decay)
        self.device = device

    def local_train(self, criterion, num_epochs=1):
        """Run ``num_epochs`` passes over the local data, one optimizer step per batch."""
        self.model.train()
        for _ in range(num_epochs):
            for images, labels in self.dataloader:
                images, labels = images.to(self.device), labels.to(self.device)
                self.optimizer.zero_grad()
                loss = criterion(self.model(images), labels)
                loss.backward()
                self.optimizer.step()

    def get_parameters(self):
        """Return a detached copy of the full state dict (parameters and buffers)."""
        return {name: tensor.clone().detach() for name, tensor in self.model.state_dict().items()}

    def set_parameters(self, global_parameters):
        """Load ``global_parameters`` into the local model.

        ``load_state_dict`` copies values into the existing tensors, so the
        optimizer keeps referencing the same parameter objects.
        """
        self.model.load_state_dict(global_parameters)

# Função principal de treinamento federado usando FedAvg
def federated_training_fedavg(architecture, num_clients, num_rounds):
    # Inicializar modelo base e critério de perda
    base_model = select_model(architecture)
    criterion = nn.CrossEntropyLoss()

    # Criar DataLoaders para cada cliente
    trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    trainloaders = [DataLoader(create_subset(trainset, 5000), batch_size=64, shuffle=True) for _ in range(num_clients)]

    testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
    testloader = DataLoader(create_subset(testset, 1000), batch_size=64, shuffle=False)

    # Inicializar clientes com dados locais
    clients = [Client(base_model, trainloaders[i], device) for i in range(num_clients)]
    global_model = copy.deepcopy(base_model)

    # Treinamento federado com FedAvg
    for round_num in range(num_rounds):
        client_models = []

        # Treinamento local em cada cliente
        for client in clients:
            client.set_parameters(global_model.state_dict())
            client.local_train(criterion, num_epochs=5)
            client_models.append(client.get_parameters())

        # Agregação FedAvg
        new_global_parameters = {}
        for name in client_models[0].keys():
            #new_global_parameters[name] = torch.mean(torch.stack([client[name] for client in client_models]), dim=0)
            new_global_parameters[name] = torch.mean(torch.stack([client[name].float() for client in client_models]), dim=0)

        # Atualizar modelo global
        global_model.load_state_dict(new_global_parameters)

        # Avaliação do modelo global
        test_loss, test_accuracy = evaluate_model(global_model, testloader, criterion)
        print(f'Round {round_num + 1}/{num_rounds}, Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}')

    # Avaliação final no conjunto de teste
    print("Treinamento FedAvg concluído.")
    test_loss, test_accuracy = evaluate_model(global_model, testloader, criterion)
    print(f'Avaliação final no conjunto de teste - Loss: {test_loss:.4f}, Accuracy: {test_accuracy:.4f}')

# Função de avaliação do modelo
def evaluate_model(model, dataloader, criterion):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    avg_loss = total_loss / len(dataloader)
    accuracy = correct / total
    return avg_loss, accuracy

# Preprocessing applied to both CIFAR-10 splits: resize to 224x224, convert to a
# tensor in [0, 1], then map each channel to [-1, 1] via (x - 0.5) / 0.5.
# NOTE(review): torchvision documents ImageNet mean/std normalization for the
# IMAGENET1K_V1 weights loaded in select_model; (0.5, 0.5, 0.5) differs from
# that — confirm this is intentional.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
    ]
)



In [2]:
# Federation size and number of communication rounds shared by every experiment below.
num_clients, num_rounds = 10, 30

## SqueezeNet

In [3]:
# Options: 'alexnet', 'vgg11', 'resnet18', 'mobilenet_v2' or 'squeezenet'.
architecture = 'squeezenet'
federated_training_fedavg(architecture=architecture, num_clients=num_clients, num_rounds=num_rounds)

Files already downloaded and verified
Files already downloaded and verified
Round 1/30, Test Loss: 1.2536, Test Accuracy: 0.5850
Round 2/30, Test Loss: 0.8812, Test Accuracy: 0.6990
Round 3/30, Test Loss: 0.7423, Test Accuracy: 0.7620
Round 4/30, Test Loss: 0.6602, Test Accuracy: 0.7810
Round 5/30, Test Loss: 0.6009, Test Accuracy: 0.7900
Round 6/30, Test Loss: 0.5519, Test Accuracy: 0.8140
Round 7/30, Test Loss: 0.5286, Test Accuracy: 0.8130
Round 8/30, Test Loss: 0.5040, Test Accuracy: 0.8350
Round 9/30, Test Loss: 0.4899, Test Accuracy: 0.8370
Round 10/30, Test Loss: 0.4654, Test Accuracy: 0.8410
Round 11/30, Test Loss: 0.4569, Test Accuracy: 0.8430
Round 12/30, Test Loss: 0.4413, Test Accuracy: 0.8590
Round 13/30, Test Loss: 0.4306, Test Accuracy: 0.8560
Round 14/30, Test Loss: 0.4195, Test Accuracy: 0.8570
Round 15/30, Test Loss: 0.3952, Test Accuracy: 0.8640
Round 16/30, Test Loss: 0.3951, Test Accuracy: 0.8650
Round 17/30, Test Loss: 0.3880, Test Accuracy: 0.8710
Round 18/30, Te

## MobileNet_v2

In [4]:
# Options: 'alexnet', 'vgg11', 'resnet18', 'mobilenet_v2' or 'squeezenet'.
architecture = 'mobilenet_v2'
federated_training_fedavg(architecture=architecture, num_clients=num_clients, num_rounds=num_rounds)

Files already downloaded and verified
Files already downloaded and verified
Round 1/30, Test Loss: 0.5690, Test Accuracy: 0.8420
Round 2/30, Test Loss: 0.3711, Test Accuracy: 0.8770
Round 3/30, Test Loss: 0.3142, Test Accuracy: 0.8960
Round 4/30, Test Loss: 0.2908, Test Accuracy: 0.9030
Round 5/30, Test Loss: 0.2738, Test Accuracy: 0.9090
Round 6/30, Test Loss: 0.2645, Test Accuracy: 0.9090
Round 7/30, Test Loss: 0.2603, Test Accuracy: 0.9130
Round 8/30, Test Loss: 0.2619, Test Accuracy: 0.9120
Round 9/30, Test Loss: 0.2614, Test Accuracy: 0.9140
Round 10/30, Test Loss: 0.2604, Test Accuracy: 0.9150
Round 11/30, Test Loss: 0.2569, Test Accuracy: 0.9160
Round 12/30, Test Loss: 0.2656, Test Accuracy: 0.9170
Round 13/30, Test Loss: 0.2603, Test Accuracy: 0.9210
Round 14/30, Test Loss: 0.2721, Test Accuracy: 0.9170
Round 15/30, Test Loss: 0.2812, Test Accuracy: 0.9180
Round 16/30, Test Loss: 0.2760, Test Accuracy: 0.9150
Round 17/30, Test Loss: 0.2847, Test Accuracy: 0.9160
Round 18/30, Te

## ResNet18

In [5]:
# Options: 'alexnet', 'vgg11', 'resnet18', 'mobilenet_v2' or 'squeezenet'.
architecture = 'resnet18'
federated_training_fedavg(architecture=architecture, num_clients=num_clients, num_rounds=num_rounds)

Files already downloaded and verified
Files already downloaded and verified
Round 1/30, Test Loss: 0.3328, Test Accuracy: 0.8930
Round 2/30, Test Loss: 0.2283, Test Accuracy: 0.9230
Round 3/30, Test Loss: 0.1908, Test Accuracy: 0.9330
Round 4/30, Test Loss: 0.1807, Test Accuracy: 0.9360
Round 5/30, Test Loss: 0.1716, Test Accuracy: 0.9420
Round 6/30, Test Loss: 0.1654, Test Accuracy: 0.9330
Round 7/30, Test Loss: 0.1636, Test Accuracy: 0.9390
Round 8/30, Test Loss: 0.1576, Test Accuracy: 0.9420
Round 9/30, Test Loss: 0.1684, Test Accuracy: 0.9420
Round 10/30, Test Loss: 0.1715, Test Accuracy: 0.9420
Round 11/30, Test Loss: 0.1729, Test Accuracy: 0.9380
Round 12/30, Test Loss: 0.1675, Test Accuracy: 0.9420
Round 13/30, Test Loss: 0.1677, Test Accuracy: 0.9410
Round 14/30, Test Loss: 0.1686, Test Accuracy: 0.9390
Round 15/30, Test Loss: 0.1703, Test Accuracy: 0.9400
Round 16/30, Test Loss: 0.1669, Test Accuracy: 0.9440
Round 17/30, Test Loss: 0.1741, Test Accuracy: 0.9400
Round 18/30, Te

## AlexNet

In [6]:
# Options: 'alexnet', 'vgg11', 'resnet18', 'mobilenet_v2' or 'squeezenet'.
architecture = 'alexnet'
federated_training_fedavg(architecture=architecture, num_clients=num_clients, num_rounds=num_rounds)

Files already downloaded and verified
Files already downloaded and verified
Round 1/30, Test Loss: 0.5776, Test Accuracy: 0.7860
Round 2/30, Test Loss: 0.4299, Test Accuracy: 0.8420
Round 3/30, Test Loss: 0.3839, Test Accuracy: 0.8550
Round 4/30, Test Loss: 0.3490, Test Accuracy: 0.8720
Round 5/30, Test Loss: 0.3330, Test Accuracy: 0.8780
Round 6/30, Test Loss: 0.3010, Test Accuracy: 0.8840
Round 7/30, Test Loss: 0.2957, Test Accuracy: 0.8830
Round 8/30, Test Loss: 0.2809, Test Accuracy: 0.9000
Round 9/30, Test Loss: 0.2770, Test Accuracy: 0.9000
Round 10/30, Test Loss: 0.2636, Test Accuracy: 0.9060
Round 11/30, Test Loss: 0.2648, Test Accuracy: 0.9040
Round 12/30, Test Loss: 0.2668, Test Accuracy: 0.9050
Round 13/30, Test Loss: 0.2542, Test Accuracy: 0.9090
Round 14/30, Test Loss: 0.2515, Test Accuracy: 0.9150
Round 15/30, Test Loss: 0.2510, Test Accuracy: 0.9160
Round 16/30, Test Loss: 0.2536, Test Accuracy: 0.9160
Round 17/30, Test Loss: 0.2609, Test Accuracy: 0.9070
Round 18/30, Te

In [7]:
# Report the parameter count and parameter dtype of every supported architecture.
architectures = ['alexnet', 'vgg11', 'resnet18', 'mobilenet_v2', 'squeezenet']

for arch in architectures:
    model = select_model(arch)
    # Total number of scalar weights across all parameter tensors.
    num_params = sum(torch.numel(p) for p in model.parameters())
    # dtype of the first parameter tensor, reported as representative of the model.
    dtype = next(iter(model.parameters())).dtype

    print(f"Arquitetura: {arch}")
    print(f"{arch}: {num_params} parâmetros")
    print(f"Tipo de dado dos parâmetros: {dtype}")

Arquitetura: alexnet
alexnet: 57044810 parâmetros
Tipo de dado dos parâmetros: torch.float32
Arquitetura: vgg11
vgg11: 128807306 parâmetros
Tipo de dado dos parâmetros: torch.float32
Arquitetura: resnet18
resnet18: 11181642 parâmetros
Tipo de dado dos parâmetros: torch.float32
Arquitetura: mobilenet_v2
mobilenet_v2: 2236682 parâmetros
Tipo de dado dos parâmetros: torch.float32
Arquitetura: squeezenet
squeezenet: 727626 parâmetros
Tipo de dado dos parâmetros: torch.float32


In [None]:
''