In [None]:
# %pip (unlike !pip) installs into the environment of the running kernel.
%pip install -q torchmetrics



In [None]:
import torch
import time
import random
import os
import copy
import torchmetrics
import numpy as np
import matplotlib as plt
from torch import nn
import torchvision.models as models
from torchvision import datasets, transforms
from sklearn.utils.class_weight import compute_class_weight
from collections import Counter
from torch.utils.data import DataLoader, Dataset, random_split, Subset
import torch.nn.functional as F
import torch.nn as nn
import pandas as pd
from sklearn.model_selection import train_test_split
from PIL import Image
from sklearn.utils.class_weight import compute_class_weight
from glob import glob
from sklearn.metrics import confusion_matrix
from torch.utils.data import Subset
from torch.optim.lr_scheduler import StepLR
from imblearn.over_sampling import SMOTE
from sklearn.metrics import precision_score, recall_score, f1_score

from google.colab import drive

# data_path below lives under /content/drive, so Drive must be mounted first.
# force_remount=False is the default, spelled out so re-running this cell is a no-op.
drive.mount('/content/drive', force_remount=False)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# Definindo o número de clientes e os hiperparâmetros

In [None]:
# --- Federation / training hyper-parameters ---
num_users = 3        # number of simulated clients
num_classes = 3      # number of output classes
epochs = 80          # number of federation rounds run by the training loop
frac = 1             # fraction of clients sampled per round
lr = 0.0001          # learning rate (client optimizer and FedSGD update step)
batch_size = 64

# --- Reproducibility: seed every RNG the notebook draws from ---
SEED = 1234
for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    seed_fn(SEED)

use_cuda = torch.cuda.is_available()
if use_cuda:
    torch.backends.cudnn.deterministic = True
    print(torch.cuda.get_device_name(0))

device = torch.device('cuda' if use_cuda else 'cpu')

Tesla T4


# Importando o dataset

In [None]:
data_path = '/content/drive/MyDrive/ADNI'


def _build_transform(*augmentations):
    """Compose: resize to 224x224 -> given augmentations -> tensor -> normalisation.

    The mean/std are the usual ImageNet channel statistics.
    """
    return transforms.Compose([
        transforms.Resize((224, 224)),
        *augmentations,
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])


# Training pipeline: random flips, rotation, colour jitter and occasional grayscale.
train_transform = _build_transform(
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(30),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
    transforms.RandomGrayscale(p=0.1),
)

# Evaluation pipeline: no augmentation.
test_transform = _build_transform()

# No transform here: CustomDataset below applies a different pipeline to train and test.
full_dataset = datasets.ImageFolder(root=data_path)

# Random 80/20 split of image indices, ordered by the seeded NumPy RNG.
# NOTE(review): the split is per image; if the folder holds several images per
# subject, one subject can land in both train and test — confirm and consider a
# subject-level split.
n_images = len(full_dataset)
all_indices = list(range(n_images))
np.random.shuffle(all_indices)

train_size = int(0.8 * n_images)
train_indices, test_indices = all_indices[:train_size], all_indices[train_size:]

train_dataset, test_dataset = Subset(full_dataset, train_indices), Subset(full_dataset, test_indices)

class CustomDataset(Dataset):
    """Expose an (image, label) dataset with an optional transform on the image.

    Lets one untransformed source (e.g. a Subset of ImageFolder) be served with
    different pipelines for training and testing.
    """

    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, index):
        image, label = self.subset[index]
        return (self.transform(image) if self.transform else image), label

# Transformed train/test views over the index split.
train_dataset = CustomDataset(train_dataset, transform=train_transform)
test_dataset = CustomDataset(test_dataset, transform=test_transform)

# Class-balanced loss weights from the training labels. The labels are read from
# ImageFolder.targets; iterating train_dataset would decode and augment every
# image only to throw it away.
labels = [full_dataset.targets[i] for i in train_indices]
classes = np.unique(labels)  # label values present in the training split
class_weights = compute_class_weight('balanced', classes=classes, y=labels)
class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)
criterion = torch.nn.CrossEntropyLoss(weight=class_weights).to(device)

# NOTE(review): the next cell replaces train_loader with an oversampled,
# class-balanced loader; combining that with this weighted criterion would
# correct the imbalance twice. The Client class builds its own loaders, so these
# loaders do not feed the federated training loop.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


In [None]:
def count_samples_per_class(dataset):
    """Return a Counter of labels over every (sample, label) item of ``dataset``.

    Note: this indexes every item, so for image datasets each image is loaded
    (and transformed).
    """
    return Counter(label for _, label in dataset)

# Labels of the training split in train_dataset order, read from
# ImageFolder.targets. Counting through train_dataset (twice) would decode and
# augment every image just to read its label.
train_labels = [full_dataset.targets[i] for i in train_indices]

train_class_counts = Counter(train_labels)
print("Initial class distribution in training set:", train_class_counts)

max_count = max(train_class_counts.values())

# Positions (into train_dataset) of the samples of each class.
class_to_indices = {class_idx: [] for class_idx in range(num_classes)}
for position, label in enumerate(train_labels):
    class_to_indices[label].append(position)

# Random oversampling with replacement: every class is topped up to max_count.
oversampled_indices = []
for class_idx, indices in class_to_indices.items():
    oversampled_indices.extend(indices)
    num_to_add = max_count - len(indices)
    # A class absent from the split has nothing to resample
    # (np.random.choice raises on an empty population).
    if num_to_add > 0 and indices:
        oversampled_indices.extend(np.random.choice(indices, size=num_to_add, replace=True))

print(f"Total samples after oversampling: {len(oversampled_indices)}")

oversampled_train_dataset = Subset(train_dataset, oversampled_indices)

# NOTE(review): the federated Client class builds its own loaders, so this
# oversampled loader is not used by the training loop below.
train_loader = DataLoader(oversampled_train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)



Initial class distribution in training set: Counter({2: 212, 1: 147, 0: 77})
Total samples after oversampling: 636


In [None]:
def get_class_distribution(loader):
    """Tally labels over one full pass of ``loader``.

    Each batch is an (inputs, labels) pair whose labels are a tensor; the
    returned Counter maps each label value to its number of occurrences.
    """
    tally = Counter()
    for _inputs, batch_labels in loader:
        tally.update(batch_labels.numpy())
    return tally

# Sanity check of the oversampled loader: one full pass (loads and augments every image).
new_train_class_counts = get_class_distribution(loader=train_loader)
print("New class distribution in training loader:", new_train_class_counts)


New class distribution in training loader: Counter({0: 212, 1: 212, 2: 212})


# Definindo o modelo do lado do cliente

In [None]:
class DenseNet_client_side(nn.Module):
    """Client half of a split DenseNet-121.

    Keeps the first eight children of ``densenet121().features`` (up to
    transition2); the server half continues from child index 8.
    """

    def __init__(self):
        super().__init__()
        backbone = models.densenet121(pretrained=False)
        feature_modules = list(backbone.features.children())
        self.features = nn.Sequential(*feature_modules[:8])

    def forward(self, x):
        return self.features(x)

# Global client-side model; wrapped in DataParallel only when several GPUs exist.
net_glob_client = DenseNet_client_side()
n_gpus = torch.cuda.device_count()
if n_gpus > 1:
    print("We use", n_gpus, "GPUs")
    net_glob_client = nn.DataParallel(net_glob_client)



# Definindo o modelo do lado do servidor

In [None]:
class DenseNet_server_side(nn.Module):
    """Server half of a split DenseNet-121 plus a small classification head.

    Takes the client's activations, runs the remaining DenseNet-121 feature
    modules (child index 8 onward), then ReLU -> global average pooling ->
    flatten -> MLP head with ``num_classes`` outputs.
    """

    def __init__(self, num_classes=3):
        super().__init__()
        backbone = models.densenet121(pretrained=False)
        self.features = nn.Sequential(*list(backbone.features.children())[8:])
        hidden_units = 512
        # DenseNet-121 has 1024 feature channels at the end of its feature extractor.
        self.classifier = nn.Sequential(
            nn.Linear(1024, hidden_units),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),  # dropout before the output layer
            nn.Linear(hidden_units, num_classes),
        )

    def forward(self, x):
        activated = F.relu(self.features(x), inplace=True)
        pooled = F.adaptive_avg_pool2d(activated, (1, 1))
        return self.classifier(torch.flatten(pooled, 1))

# Global server-side model; the per-client server replicas are copied from it.
net_glob_server = DenseNet_server_side(num_classes=num_classes)
if torch.cuda.device_count() > 1:
    print("We use", torch.cuda.device_count(), "GPUs")
    net_glob_server = nn.DataParallel(net_glob_server)
# Move to the target device in every case. This call used to sit inside the
# multi-GPU branch, which left the model on CPU on a single-GPU machine.
# `device` is already defined in the configuration cell.
net_glob_server = net_glob_server.to(device)


In [None]:
# Round-level histories (one entry per federation round).
loss_train_collect = []
acc_train_collect = []
loss_test_collect = []
acc_test_collect = []
auc_train_collect = []
auc_test_collect = []

# Per-batch training metrics, averaged and cleared by train_server after each local epoch.
batch_acc_train = []
batch_loss_train = []
batch_precision_train = []
batch_recall_train = []
batch_auc_train = []

# Aggregates across clients within a round.
acc_avg_all_user_train = []
loss_avg_all_user_train = []
loss_avg_train_all = []
acc_avg_train_all = []
acc_train_collect_user = []
loss_train_collect_user = []
loss_test_collect_user = []
# evaluate_server appends to this list but it was never initialised here.
acc_test_collect_user = []

# Per-batch test metrics accumulated by evaluate_server.
batch_acc_test = []
batch_loss_test = []
batch_prec_test = []
batch_recall_test = []
batch_precision_test = []
batch_auc_test = []
batch_f1_test = []

# Batch counters for train_server (count1) and evaluate_server (count2).
count1 = 0
count2 = 0

In [None]:
def FedSGD(gradients_list):
    """Element-wise average of a list of gradient dictionaries.

    Args:
        gradients_list: non-empty list of dicts mapping parameter name -> tensor;
            every dict must contain the keys of the first one.

    Returns:
        dict with the same keys (in the first dict's order), each value the mean
        tensor over all dicts, placed on the global ``device``.
    """
    n_contributors = len(gradients_list)
    return {
        name: sum(
            (grads[name].to(device) for grads in gradients_list),
            torch.zeros_like(reference, device=device),
        ) / n_contributors
        for name, reference in gradients_list[0].items()
    }

def calculate_metrics(fx, y, num_classes):
    """Batch classification metrics from raw scores.

    Args:
        fx: class scores with the class axis on dim 1 (argmax/softmax over dim 1).
        y: integer class labels.
        num_classes: number of classes.

    Returns:
        (accuracy, macro precision, macro recall, macro F1, AUROC) as torchmetrics
        tensors; callers read them with ``.item()``.
    """
    metrics = torchmetrics.functional
    multiclass = dict(task='multiclass', num_classes=num_classes)
    predicted = fx.argmax(dim=1)
    probabilities = F.softmax(fx, dim=1)

    accuracy = metrics.accuracy(predicted, y, **multiclass)
    precision = metrics.precision(predicted, y, average='macro', **multiclass)
    recall = metrics.recall(predicted, y, average='macro', **multiclass)
    f1 = metrics.f1_score(predicted, y, average='macro', **multiclass)
    auc = metrics.auroc(probabilities, y, **multiclass)
    return accuracy, precision, recall, f1, auc

# Server-side federation state shared by train_server, evaluate_server and the training loop.
w_glob_server = net_glob_server.state_dict()  # state_dict of the global server model
w_locals_server = []                          # per-client server gradients gathered in a round
weight_decay = 1e-4

idx_collect = []                              # clients that finished their last local epoch
l_epoch_check, fed_check = False, False

# One server-side replica per client, each a copy of net_glob_server.
net_model_server = [copy.deepcopy(net_glob_server).to(device) for _ in range(num_users)]

net_server = copy.deepcopy(net_model_server[0]).to(device)
# NOTE(review): the training loop shown here never calls step() on
# optimizer_server or scheduler_server — confirm whether they are still needed.
optimizer_server = torch.optim.Adam(net_server.parameters(), lr=lr, weight_decay=weight_decay)
scheduler_server = torch.optim.lr_scheduler.StepLR(optimizer_server, step_size=30, gamma=0.1)


In [None]:
class EarlyStopping:
    """Signal when a monitored validation loss stops improving.

    A value counts as an improvement only if it is at least ``min_delta`` below
    the best value seen so far. After ``patience`` consecutive non-improving
    calls ``early_stop`` becomes True; it is never reset to False.
    """

    def __init__(self, patience=5, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = None
        self.counter = 0
        self.early_stop = False

    def __call__(self, validation_loss):
        if self.best_loss is None:
            # The first observation only sets the baseline.
            self.best_loss = validation_loss
            return

        if validation_loss > self.best_loss - self.min_delta:
            # Not enough improvement: one more strike.
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
            return

        self.best_loss = validation_loss
        self.counter = 0


# Early-stopping instance: 7 non-improving checks, minimum improvement 0.01.
# NOTE(review): the training loop in this notebook never calls early_stopping.
early_stopping = EarlyStopping(patience=7, min_delta=0.01)

In [None]:
def train_server(fx_client, y, l_epoch_count, l_epoch, idx, len_batch, num_classes=3):
    """Server-side forward/backward pass for one client training batch.

    ``net_model_server[idx]`` receives the client's activations, computes the
    (class-weighted) ``criterion`` loss against ``y`` and back-propagates. No
    optimizer step is taken here: the parameter gradients are stored on the
    model as ``server_gradients`` and are applied by the FedSGD aggregation in
    the training loop.

    After ``len_batch`` calls the batch metrics are averaged and printed. When
    the last local epoch ends, the client's averages are recorded. Once all
    ``num_users`` clients have reported, the round averages are computed,
    ``fed_check`` is raised for ``evaluate_server``, and the per-round
    bookkeeping is reset.

    Args:
        fx_client: activations produced by the client-side network.
        y: integer class labels of the batch.
        l_epoch_count: current local epoch (0-based).
        l_epoch: number of local epochs per round.
        idx: client index into ``net_model_server``.
        len_batch: number of batches in the client's training loader.
        num_classes: number of classes.

    Returns:
        Detached gradient of the loss w.r.t. ``fx_client``, for the client to
        back-propagate through its own network.
    """
    global net_model_server, criterion, device
    global batch_acc_train, batch_loss_train, batch_precision_train, batch_recall_train, batch_auc_train
    global l_epoch_check, fed_check, loss_train_collect, acc_train_collect, count1
    global acc_avg_all_user_train, loss_avg_all_user_train, idx_collect, net_server
    global loss_train_collect_user, acc_train_collect_user

    net_server = net_model_server[idx].to(device)
    net_server.train()
    # A fresh Adam optimizer used to be created here on every batch but step()
    # was never called; only its zero_grad() had an effect.
    net_server.zero_grad()

    # Make the activations a leaf on the server device so .grad is populated even
    # if they arrive on another device (.to() would return a non-leaf copy).
    fx_client = fx_client.detach().to(device).requires_grad_(True)
    y = y.to(device)

    fx_server = net_server(fx_client)
    loss = criterion(fx_server, y)
    acc, precision, recall, f1, auc = calculate_metrics(fx_server, y, num_classes)
    loss.backward()

    # Gradients of this batch; the training loop averages the last stored set
    # across clients (FedSGD).
    gradients = {name: param.grad.clone().detach()
                 for name, param in net_server.named_parameters() if param.grad is not None}
    net_model_server[idx].server_gradients = gradients

    dfx_client = fx_client.grad.clone().detach()

    batch_loss_train.append(loss.item())
    batch_acc_train.append(acc.item())
    batch_precision_train.append(precision.item())
    batch_recall_train.append(recall.item())
    batch_auc_train.append(auc.item())

    count1 += 1
    if count1 == len_batch:
        acc_avg_train = sum(batch_acc_train) / len(batch_acc_train)
        loss_avg_train = sum(batch_loss_train) / len(batch_loss_train)
        precision_avg_train = sum(batch_precision_train) / len(batch_precision_train)
        recall_avg_train = sum(batch_recall_train) / len(batch_recall_train)
        auc_avg_train = sum(batch_auc_train) / len(batch_auc_train)

        for batch_list in (batch_acc_train, batch_loss_train, batch_precision_train,
                           batch_recall_train, batch_auc_train):
            batch_list.clear()
        count1 = 0

        print(f'Client{idx} Train => Local Epoch: {l_epoch_count} \tAcc: {acc_avg_train:.3f} \tLoss: {loss_avg_train:.4f} \tPrecision: {precision_avg_train:.3f} \tRecall: {recall_avg_train:.3f} \tAUC: {auc_avg_train:.3f}')

        if l_epoch_count == l_epoch - 1:
            l_epoch_check = True

            # Record this client's averages. These lines previously appended the
            # still-empty global lists loss_avg_train_all / acc_avg_train_all.
            loss_train_collect_user.append(loss_avg_train)
            acc_train_collect_user.append(acc_avg_train)

            if idx not in idx_collect:
                idx_collect.append(idx)

            if len(idx_collect) == num_users:
                fed_check = True
                # Round-level train averages over all clients (previously just the last client's).
                acc_avg_all_user_train = sum(acc_train_collect_user) / len(acc_train_collect_user)
                loss_avg_all_user_train = sum(loss_train_collect_user) / len(loss_train_collect_user)
                acc_train_collect.append(acc_avg_all_user_train)
                loss_train_collect.append(loss_avg_all_user_train)
                # Reset for the next round. idx_collect was never cleared, so
                # fed_check fired after every client from the second round on.
                acc_train_collect_user.clear()
                loss_train_collect_user.clear()
                idx_collect.clear()

    # Server gradients reach w_locals_server through the training loop
    # (get_server_gradients_from_train_server), so they are not appended here too.
    return dfx_client


In [None]:
# Per-batch test metrics, accumulated across evaluate_server calls until a
# client's test loader is exhausted. They must live at module level: when they
# were function locals, every printed "average" was really the last batch's
# value. acc_test_collect_user holds each client's test accuracy until the
# federation step; it used to be reset on every call.
batch_precision_test = []
batch_recall_test = []
batch_auc_test = []
batch_f1_test = []
acc_test_collect_user = []


def evaluate_server(fx_client, y, idx, len_batch, ell, num_classes=3):
    """Evaluate client ``idx``'s server-side model on one test batch.

    Batch metrics accumulate until ``len_batch`` calls have been made; then the
    averages are printed. If the client just finished its last local epoch
    (``l_epoch_check``), its averages are kept for the round. When
    ``fed_check`` is set, the round-level train/test averages are printed and the
    test averages are appended to ``loss_test_collect`` / ``acc_test_collect``.

    Args:
        fx_client: client-side activations for the batch.
        y: integer class labels.
        idx: client index into ``net_model_server``.
        len_batch: number of batches in the client's test loader.
        ell: round number (used only for printing).
        num_classes: number of classes.
    """
    global batch_acc_test, batch_loss_test, count2
    global batch_precision_test, batch_recall_test, batch_auc_test, batch_f1_test
    global l_epoch_check, fed_check
    global loss_test_collect_user, acc_test_collect_user

    # Evaluate the client's server replica in eval mode and restore its mode
    # afterwards, instead of deep-copying a DenseNet half for every test batch.
    net = net_model_server[idx].to(device)
    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            fx_server = net(fx_client.to(device))
            y = y.to(device)
            loss = criterion(fx_server, y)
            acc, precision, recall, f1, auc = calculate_metrics(fx_server, y, num_classes)
    finally:
        net.train(was_training)

    batch_loss_test.append(loss.item())
    batch_acc_test.append(acc.item())
    batch_precision_test.append(precision.item())
    batch_recall_test.append(recall.item())
    batch_auc_test.append(auc.item())
    batch_f1_test.append(f1.item())

    count2 += 1
    if count2 != len_batch:
        return

    def _mean(values):
        return sum(values) / len(values)

    acc_avg_test = _mean(batch_acc_test)
    loss_avg_test = _mean(batch_loss_test)
    precision_avg_test = _mean(batch_precision_test)
    recall_avg_test = _mean(batch_recall_test)
    auc_avg_test = _mean(batch_auc_test)
    f1_avg_test = _mean(batch_f1_test)

    for batch_list in (batch_acc_test, batch_loss_test, batch_precision_test,
                       batch_recall_test, batch_auc_test, batch_f1_test):
        batch_list.clear()
    count2 = 0

    print('Client{} Test =>                   \tAcc: {:.3f} \tLoss: {:.4f} \tPrecision: {:.3f} \tRecall: {:.3f} \tAUC: {:.3f} \tF1-Score: {:.3f}'.format(
        idx, acc_avg_test, loss_avg_test, precision_avg_test, recall_avg_test, auc_avg_test, f1_avg_test))

    if l_epoch_check:
        l_epoch_check = False
        loss_test_collect_user.append(loss_avg_test)
        acc_test_collect_user.append(acc_avg_test)

    if fed_check:
        fed_check = False
        print("------------------------------------------------")
        print("------ Federation process at Server-Side ------- ")
        print("------------------------------------------------")

        # Guard against a federation signal with no recorded client results.
        if acc_test_collect_user:
            acc_avg_all_user = _mean(acc_test_collect_user)
            loss_avg_all_user = _mean(loss_test_collect_user)

            loss_test_collect.append(loss_avg_all_user)
            acc_test_collect.append(acc_avg_all_user)
            acc_test_collect_user = []
            loss_test_collect_user = []

            print("====================== SERVER V1==========================")
            print(' Train: Round {:3d}, Avg Accuracy {:.3f} | Avg Loss {:.3f}'.format(ell, acc_avg_all_user_train, loss_avg_all_user_train))
            print(' Test: Round {:3d}, Avg Accuracy {:.3f} | Avg Loss {:.3f}'.format(ell, acc_avg_all_user, loss_avg_all_user))
            print("==========================================================")
    return

In [None]:
def evaluate_accuracy(net, loader, device, return_conf_matrix=False, num_classes=3):
    """Evaluate a complete classifier ``net`` over every batch of ``loader``.

    Args:
        net: model mapping an input batch to class scores (classes on dim 1).
        loader: iterable of (inputs, labels) batches.
        device: device to run ``net`` on.
        return_conf_matrix: also return the confusion matrix.
        num_classes: number of classes.

    Returns:
        (accuracy, macro precision, macro recall, macro F1, AUROC) as floats, plus
        a ``num_classes x num_classes`` confusion matrix when ``return_conf_matrix``.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    net.eval()
    pred_chunks, label_chunks, prob_chunks = [], [], []

    with torch.no_grad():
        for inputs, labels in loader:
            outputs = net(inputs.to(device))
            # Keep every batch's probabilities: AUROC must be computed over the
            # whole set. The original used only the last batch's outputs against
            # all labels, which mismatched as soon as there was more than one batch.
            prob_chunks.append(F.softmax(outputs, dim=1).cpu())
            pred_chunks.append(outputs.argmax(dim=1).cpu())
            label_chunks.append(labels.cpu())

    if not label_chunks:
        raise ValueError("evaluate_accuracy: loader produced no batches")

    all_preds = torch.cat(pred_chunks)
    all_labels = torch.cat(label_chunks)
    all_probs = torch.cat(prob_chunks)

    metrics = torchmetrics.functional
    multiclass = dict(task='multiclass', num_classes=num_classes)
    accuracy = metrics.accuracy(all_preds, all_labels, **multiclass).item()
    precision = metrics.precision(all_preds, all_labels, average='macro', **multiclass).item()
    recall = metrics.recall(all_preds, all_labels, average='macro', **multiclass).item()
    f1 = metrics.f1_score(all_preds, all_labels, average='macro', **multiclass).item()
    auc = metrics.auroc(all_probs, all_labels, **multiclass).item()

    if return_conf_matrix:
        # Fixed label set so the matrix is num_classes x num_classes even when a
        # class never appears in labels or predictions.
        conf_matrix = confusion_matrix(all_labels.numpy(), all_preds.numpy(),
                                       labels=list(range(num_classes)))
        return accuracy, precision, recall, f1, auc, conf_matrix
    return accuracy, precision, recall, f1, auc

class DatasetSplit(Dataset):
    """View of ``dataset`` restricted to ``indices``, with an optional image transform.

    Item ``i`` of the view is ``dataset[indices[i]]``, with ``transform``
    applied to the image part when a transform is given.
    """

    def __init__(self, dataset, indices, transform=None):
        self.dataset = dataset
        self.transform = transform
        self.indices = [i for i in indices]

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        image, label = self.dataset[self.indices[idx]]
        if not self.transform:
            return image, label
        return self.transform(image), label

class Client(object):
    """One simulated split-learning client.

    Holds the client's train/test loaders and runs the client half of the
    forward/backward pass, delegating the server half to ``train_server`` and
    ``evaluate_server``.
    """

    def __init__(self, net_client_model, idx, lr, device, train_loader=None, test_loader=None, idxs=None, idxs_test=None):
        """
        Args:
            net_client_model: not used by this class; kept for call compatibility.
            idx: client index (also selects the server-side replica).
            lr: learning rate of the client optimizer.
            device: device for the client-side computation.
            train_loader, test_loader: optional ready-made loaders, used as-is when given.
            idxs, idxs_test: indices into ``full_dataset`` for this client's train /
                test shard. When None, the whole train / test split is used.
        """
        self.idx = idx
        self.device = device
        self.lr = lr
        self.local_ep = 1
        # idxs / idxs_test used to be ignored, so every client trained and tested
        # on the complete split and the per-client shards had no effect.
        train_idxs = train_indices if idxs is None else idxs
        test_idxs = test_indices if idxs_test is None else idxs_test
        self.train_dataset = DatasetSplit(full_dataset, train_idxs, transform=train_transform)
        self.test_dataset = DatasetSplit(full_dataset, test_idxs, transform=test_transform)
        self.ldr_train = (train_loader if train_loader is not None
                          else DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True))
        self.ldr_test = (test_loader if test_loader is not None
                         else DataLoader(self.test_dataset, batch_size=batch_size, shuffle=False))

    def train(self, net):
        """Run ``self.local_ep`` local epochs of split training on ``net``.

        Each batch: client forward, server forward/backward via ``train_server``,
        client backward with the returned activation gradient, clipping (max
        norm 1.0), then an Adam step.

        Returns:
            dict name -> gradient tensor of ``net`` after the last batch (post
            clipping). Parameters without a gradient are omitted.
        """
        net.train()
        optimizer_client = torch.optim.Adam(net.parameters(), lr=self.lr)
        # NOTE(review): the optimizer and scheduler are re-created on every call
        # and local_ep is 1, so StepLR(step_size=30) never decays the learning rate.
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer_client, step_size=30, gamma=0.1)

        len_batch = len(self.ldr_train)
        for local_epoch in range(self.local_ep):
            for images, labels in self.ldr_train:
                images, labels = images.to(self.device), labels.to(self.device)
                optimizer_client.zero_grad()
                fx = net(images)
                # Hand the server a detached leaf copy so its backward pass stops at the split.
                client_fx = fx.clone().detach().requires_grad_(True)
                dfx = train_server(client_fx, labels, local_epoch, self.local_ep, self.idx, len_batch)
                fx.backward(dfx)
                torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)
                optimizer_client.step()

            scheduler.step()

        # Skip parameters that never received a gradient (e.g. an empty loader)
        # instead of crashing on None.clone().
        return {name: param.grad.clone()
                for name, param in net.named_parameters() if param.grad is not None}

    def get_server_gradients_from_train_server(self):
        """Gradients last stored by ``train_server`` for this client, or None if none yet."""
        return getattr(net_model_server[self.idx], 'server_gradients', None)

    def evaluate(self, net, ell, num_classes=3):
        """Forward every test batch through ``net`` and pass the activations to
        ``evaluate_server`` (which computes and reports the metrics)."""
        net.eval()
        len_batch = len(self.ldr_test)
        with torch.no_grad():
            for images, labels in self.ldr_test:
                images, labels = images.to(self.device), labels.to(self.device)
                fx = net(images)
                evaluate_server(fx, labels, self.idx, len_batch, ell, num_classes)
        return

def dataset_iid(indices, num_users):
    """Split ``indices`` into ``num_users`` contiguous shards, in the given order.

    Every shard has ``len(indices) // num_users`` items, except the last one,
    which also takes the remainder.

    Returns:
        dict user_id -> numpy array of indices.
    """
    pool = np.array(indices)
    share = len(pool) // num_users
    return {
        user: pool[user * share:((user + 1) * share if user < num_users - 1 else len(pool))]
        for user in range(num_users)
    }




## Treino


In [None]:
epoch_times = []

# Per-client index shards handed to each Client.
dict_users = dataset_iid(train_indices, num_users)
dict_users_test = dataset_iid(test_indices, num_users)

# Put both global halves on the target device once, instead of relying on the
# per-round updates to move their parameters.
net_glob_client = net_glob_client.to(device)
net_glob_server = net_glob_server.to(device)
net_glob_client.train()

# Independent copy of the initial client weights (a plain state_dict would alias
# the parameters that are updated in place below).
w_glob_client = copy.deepcopy(net_glob_client.state_dict())

clients_per_round = max(int(frac * num_users), 1)

# NOTE(review): Client.train runs a full local epoch of Adam steps on a throw-away
# copy of net_glob_client, and train_server takes no optimizer step. Only each
# client's final-batch gradients reach the global models, as a single plain SGD
# step of size lr per round; the locally trained weights are discarded. If
# SplitFed-V1 behaviour is intended, average the locally trained weights (FedAvg)
# instead.
for round_idx in range(epochs):  # not `iter`: that would shadow the builtin module-wide
    start_time = time.time()
    grad_locals_client = []
    w_locals_server = []
    idxs_users = np.random.choice(range(num_users), clients_per_round, replace=False)

    for idx in idxs_users:
        local = Client(
            net_glob_client,
            idx,
            lr,
            device,
            idxs=dict_users[idx],
            idxs_test=dict_users_test[idx]
        )

        # Client training on a copy of the global client model.
        gradients_client = local.train(net=copy.deepcopy(net_glob_client).to(device))
        if gradients_client:
            grad_locals_client.append(gradients_client)

        # Server-side gradients stored by train_server for this client.
        gradients_server = local.get_server_gradients_from_train_server()
        if gradients_server:
            w_locals_server.append(gradients_server)

        # Client evaluation.
        local.evaluate(net=copy.deepcopy(net_glob_client).to(device), ell=round_idx)

    print("-----------------------------------------------------------")
    print("------ Federation process at Client-Side ------- ")
    print("-----------------------------------------------------------")

    # FedSGD: one in-place SGD step with the averaged gradients. FedSGD already
    # returns its tensors on `device`.
    with torch.no_grad():
        if grad_locals_client:
            avg_gradients_client = FedSGD(grad_locals_client)
            for name, param in net_glob_client.named_parameters():
                if name in avg_gradients_client:
                    param.sub_(lr * avg_gradients_client[name])

        if w_locals_server:
            w_glob_server = FedSGD(w_locals_server)
            for name, param in net_glob_server.named_parameters():
                if name in w_glob_server:
                    param.sub_(lr * w_glob_server[name])

            # Every client's server replica restarts from the updated global model.
            net_model_server = [copy.deepcopy(net_glob_server).to(device) for _ in range(num_users)]

    epoch_time = time.time() - start_time
    epoch_times.append(epoch_time)
    print(f"Epoch {round_idx + 1}/{epochs} - Time taken: {epoch_time:.2f} seconds")

print("Training and Evaluation completed!")



Client2 Train => Local Epoch: 0 	Acc: 0.410 	Loss: 1.1012 	Precision: 0.319 	Recall: 0.351 	AUC: 0.505
Client2 Test =>                   	Acc: 0.501 	Loss: 1.1887 	Precision: 0.178 	Recall: 0.333 	AUC: 0.387 	F1-Score: 0.232
------------------------------------------------
------ Federation process at Server-Side ------- 
------------------------------------------------
 Train: Round   0, Avg Accuracy 0.410 | Avg Loss 1.101
 Test: Round   0, Avg Accuracy 0.501 | Avg Loss 1.189
Client1 Train => Local Epoch: 0 	Acc: 0.394 	Loss: 1.1011 	Precision: 0.348 	Recall: 0.338 	AUC: 0.513
Client1 Test =>                   	Acc: 0.501 	Loss: 1.1862 	Precision: 0.178 	Recall: 0.333 	AUC: 0.389 	F1-Score: 0.232
------------------------------------------------
------ Federation process at Server-Side ------- 
------------------------------------------------
 Train: Round   0, Avg Accuracy 0.394 | Avg Loss 1.101
 Test: Round   0, Avg Accuracy 0.501 | Avg Loss 1.186
Client0 Train => Local Epoch: 0 	Acc

KeyboardInterrupt: 