<a href="https://colab.research.google.com/github/Luanmantegazine/TF/blob/main/FED_ADNIData.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import time
import random
import os
import copy
import torchmetrics
import numpy as np
import nibabel as nib
import matplotlib as plt
from torch import nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
import torch.nn as nn
import shutil
import math
import os.path
import pandas as pd
from sklearn.model_selection import train_test_split
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from PIL import Image
from glob import glob
from sklearn.metrics import confusion_matrix
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from torch.optim.lr_scheduler import StepLR
from imblearn.over_sampling import SMOTE


#Definindo o número de clientes

In [None]:
# Global experiment settings read by the rest of this file.
num_users = 3  # number of simulated clients
num_classes =  4  # number of target classes
epochs = 100  # number of global rounds executed by the training loop
frac = 1  # fraction of clients sampled in each round
lr= 0.0001  # learning rate handed to the Adam optimizers
batch_size = 64  # batch size of the data loaders
criterion= nn.CrossEntropyLoss()  # loss applied to the server-side logits

# Seed every random number generator used here so runs are repeatable.
SEED = 1234
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
if torch.cuda.is_available():
    # Ask cuDNN for deterministic kernels and report which GPU is in use.
    torch.backends.cudnn.deterministic = True
    print(torch.cuda.get_device_name(0))

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#Importando o dataset

In [None]:
# NOTE(review): absolute, machine-specific path; it must contain `train` and
# `test` sub-folders laid out one directory per class (ImageFolder layout).
data_path = '/Users/luanr/pycharm/TF/Alzheimers Dataset'

# Training set: random crop/flip/rotation/colour/perspective/translation
# augmentation, then tensor conversion and normalisation to roughly [-1, 1].
train_dataset = datasets.ImageFolder(
    root=data_path + '/train',
    transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(30),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
    transforms.RandomGrayscale(p=0.1),
    transforms.RandomPerspective(),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
)


# Test set: no augmentation, only a fixed resize plus the same normalisation.
test_dataset = datasets.ImageFolder(
    root=data_path + '/test',
    transform=transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
)


# Loaders over the complete datasets; only the training loader is shuffled.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
def count_samples_per_class(dataset):
    """Count how many samples each class of ``dataset`` contains.

    Args:
        dataset: Object exposing ``classes`` (sequence of class names whose
            position is the class index) and ``samples`` (sequence of
            ``(item, class_index)`` pairs).

    Returns:
        dict mapping every class name to its number of samples, in the order
        of ``dataset.classes``. Classes without samples map to ``0``.
    """
    from collections import Counter

    # One pass over the samples instead of re-scanning them once per class.
    per_index = Counter(label for _, label in dataset.samples)
    return {name: per_index[idx] for idx, name in enumerate(dataset.classes)}

# Report the class distribution of the training split.
print("Número de amostras por classe no treino:")
train_class_counts = count_samples_per_class(train_dataset)
for class_name, count in train_class_counts.items():
    print(f"{class_name}: {count}")

# Report the class distribution of the test split.
print("Número de amostras por classe no teste:")
test_class_counts = count_samples_per_class(test_dataset)
for class_name, count in test_class_counts.items():
    print(f"{class_name}: {count}")

#Definindo o treinamento do cliente

In [None]:
class ResNet50_client_side(nn.Module):
    """Client-side part of the split network: convolutional stem plus one stage.

    The module maps a 3-channel image batch to a 64-channel feature map
    (spatial size divided by 4 by the strided convolution and the max-pool).

    ``layer2`` produces ``64 * 4`` channels, so a 1x1 projection brings it
    back to 64 channels before the stem output is added as a skip connection.
    That projection is a registered sub-module: it is trained by the client
    optimizer, appears in ``state_dict()`` (and therefore in federated
    averaging), and makes ``forward`` deterministic. Building a fresh,
    randomly initialised convolution inside ``forward`` would do none of that.
    """

    def __init__(self):
        super(ResNet50_client_side, self).__init__()
        # Stem: 7x7 stride-2 convolution followed by a stride-2 max-pool.
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )
        self.layer2 = self._make_layer(64, 64, 3)
        # Learnable 1x1 projection from the stage width (64 * 4) back to the
        # stem width (64) so both tensors can be summed in ``forward``.
        self.shortcut_proj = nn.Conv2d(64 * 4, 64, kernel_size=1, bias=False)

    def _make_layer(self, in_planes, out_planes, blocks):
        """Stack ``blocks`` bottlenecks; the first takes ``in_planes`` channels.

        Every bottleneck after the first receives ``out_planes * 4`` channels,
        which is what the previous bottleneck emits.
        """
        layers = [self._make_bottleneck(in_planes, out_planes, stride=1)]
        for _ in range(1, blocks):
            layers.append(self._make_bottleneck(out_planes * 4, out_planes, stride=1))
        return nn.Sequential(*layers)

    def _make_bottleneck(self, in_planes, planes, stride=1):
        """Build a 1x1 -> 3x3 -> 1x1 convolution stack emitting ``planes * 4`` channels.

        The stack is a plain ``nn.Sequential``; it has no skip connection of
        its own.
        """
        expansion = 4
        out_planes = planes * expansion
        return nn.Sequential(
            nn.Conv2d(in_planes, planes, kernel_size=1, bias=False),
            nn.BatchNorm2d(planes),
            nn.ReLU(inplace=True),
            nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(planes),
            nn.ReLU(inplace=True),
            nn.Conv2d(planes, out_planes, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_planes),
        )

    def forward(self, x):
        """Return the 64-channel activation that is sent to the server side."""
        stem_out = self.layer1(x)
        out = self.layer2(stem_out)
        # Project to the stem width, add the skip connection, then activate.
        out = self.shortcut_proj(out)
        out = out + stem_out
        return F.relu(out)

# Global client-side model; it is wrapped in DataParallel when more than one
# GPU is visible.
net_glob_client = ResNet50_client_side()
if torch.cuda.device_count() > 1:
    print("We use", torch.cuda.device_count(), "GPUs")
    net_glob_client = nn.DataParallel(net_glob_client)


#Definindo o treinamento do servidor

In [None]:
class Baseblock(nn.Module):
    """Bottleneck residual block: 1x1 -> 3x3 -> 1x1 convolutions plus a shortcut.

    The block emits ``planes * expansion`` channels. ``dim_change`` is an
    optional module applied to the input so that it matches the shape of the
    main branch before the two are summed; when it is ``None`` the input is
    added unchanged.
    """

    expansion = 4

    def __init__(self, input_planes, planes, stride=1, dim_change=None):
        super().__init__()
        widened = planes * self.expansion
        # Sub-modules are created in a fixed order so parameter initialisation
        # consumes the random stream the same way on every construction.
        self.conv1 = nn.Conv2d(input_planes, planes, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, widened, kernel_size=1, stride=1, padding=0)
        self.bn3 = nn.BatchNorm2d(widened)
        self.relu = nn.ReLU(inplace=True)
        self.dim_change = dim_change

    def forward(self, x):
        """Run the three conv/BN stages, add the shortcut and apply ReLU."""
        shortcut = x if self.dim_change is None else self.dim_change(x)

        out = self.conv1(x)
        out = F.relu(self.bn1(out))
        out = self.conv2(out)
        out = F.relu(self.bn2(out))
        out = self.bn3(self.conv3(out))

        out += shortcut
        return F.relu(out)


class ResNet50_server_side(nn.Module):
    """Server-side part of the split network: four residual stages and a classifier.

    The module takes a 64-channel feature map, applies a 7x7 stride-2
    convolution with batch-norm, ReLU and max-pool, runs four stages built
    from ``block`` and ends with global average pooling and a linear layer
    producing ``num_classes`` logits.

    Args:
        block: Callable ``block(in_planes, planes, stride, dim_change)`` that
            also exposes an ``expansion`` attribute.
        num_layers: Number of blocks in each of the four stages.
        num_classes: Size of the output layer.
    """

    def __init__(self, block, num_layers, num_classes=4):
        super().__init__()
        # Channel count fed to the next stage; updated by ``_make_layer``.
        self.input_planes = 64
        self.conv1 = nn.Conv2d(64, 64, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, num_layers[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` blocks and advance ``self.input_planes``.

        Only the first block gets ``stride`` and, when the shape changes, a
        1x1 conv + batch-norm projection for its shortcut.
        """
        out_planes = planes * block.expansion

        if stride == 1 and self.input_planes == out_planes:
            shortcut = None
        else:
            shortcut = nn.Sequential(
                nn.Conv2d(self.input_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes)
            )

        stage = [block(self.input_planes, planes, stride, shortcut)]
        self.input_planes = out_planes
        stage.extend(block(self.input_planes, planes) for _ in range(1, blocks))

        return nn.Sequential(*stage)

    def forward(self, x):
        """Return the class logits for a batch of client-side activations."""
        features = self.maxpool(F.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)
        pooled = self.avgpool(features)
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(flat)


# Global server-side model with the ResNet-50 stage layout [3, 4, 6, 3].
net_glob_server = ResNet50_server_side(Baseblock, [3, 4, 6, 3], num_classes=4)

# Wrap in DataParallel when more than one GPU is visible.
if torch.cuda.device_count() > 1:
    print("We use", torch.cuda.device_count(), "GPUs")
    net_glob_server = nn.DataParallel(net_glob_server)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Module-level accumulators shared by train_server / evaluate_server.

# Per-batch training metrics of the client currently being trained.
batch_acc_train = []
batch_loss_train = []
batch_precision_train = []
batch_recall_train = []
# Per-client results of the current round (one entry per client).
acc_train_collect_user = []
loss_train_collect_user = []
loss_test_collect_user = []
batch_auc_train = []
# Per-batch test metrics of the client currently being evaluated.
batch_acc_test = []
batch_loss_test = []
batch_prec_test = []
batch_recall_test = []
# Round-level histories.
# NOTE(review): these start as one empty list per user, but train_server and
# evaluate_server append plain floats to the outer lists -- confirm the
# intended layout before plotting or indexing them.
loss_train_collect = [[] for _ in range(num_users)]
acc_train_collect = [[] for _ in range(num_users)]
loss_test_collect = [[] for _ in range(num_users)]
acc_test_collect = [[] for _ in range(num_users)]
auc_train_collect = [[] for _ in range(num_users)]
auc_test_collect = [[] for _ in range(num_users)]

# Batch counters: count1 is advanced by train_server, count2 by evaluate_server.
count1 = 0
count2 = 0

In [None]:
def FedAvg(w):
    """Average a list of state dicts entry by entry.

    Args:
        w: Non-empty list of mappings with identical keys whose values are
            tensors. The inputs are not modified; the first one is deep-copied
            and used as the accumulator.

    Returns:
        A new mapping holding, for every key, the sum of that entry over all
        elements of ``w`` divided by ``len(w)``.
    """
    n_models = len(w)
    averaged = copy.deepcopy(w[0])
    for key in averaged:
        # Accumulate the remaining models in their original order.
        for other in w[1:]:
            averaged[key] += other[key]
        averaged[key] = torch.div(averaged[key], n_models)
    return averaged

def calculate_metrics(fx, y, num_classes):
    """Compute multiclass classification metrics for one batch.

    Args:
        fx: Logits; the class axis is dimension 1.
        y: Integer class targets.
        num_classes: Number of classes given to every torchmetrics call.

    Returns:
        Tuple ``(accuracy, precision, recall, f1, auc)`` of tensors. Precision,
        recall and F1 are macro-averaged over the arg-max predictions; AUROC
        is computed from the softmax probabilities.
    """
    metrics = torchmetrics.functional
    multiclass = {"task": "multiclass", "num_classes": num_classes}
    macro = {"average": "macro", **multiclass}

    hard_preds = fx.argmax(dim=1)
    probabilities = F.softmax(fx, dim=1)

    accuracy = metrics.accuracy(hard_preds, y, **multiclass)
    precision = metrics.precision(hard_preds, y, **macro)
    recall = metrics.recall(hard_preds, y, **macro)
    f1 = metrics.f1_score(hard_preds, y, **macro)
    auc = metrics.auroc(probabilities, y, **multiclass)
    return accuracy, precision, recall, f1, auc

# Weights of the global server model and the per-client copies collected for
# federated averaging.
w_glob_server = net_glob_server.state_dict()
w_locals_server = []

# Clients that have finished their last local epoch in the current round.
idx_collect = []
# Flags set by train_server and consumed by evaluate_server.
l_epoch_check = False
fed_check = False

# One independent server-side model per client.
net_model_server = [ResNet50_server_side(Baseblock, [3, 4, 6, 3], num_classes) for i in range(num_users)]

# Initial working model/optimizer; train_server rebinds both on every call.
net_server = copy.deepcopy(net_model_server[0]).to(device)
optimizer_server = torch.optim.Adam(net_server.parameters(), lr=lr)

In [None]:
def train_server(fx_client, y, l_epoch_count, l_epoch, idx, len_batch, num_classes=4):
    """Run one server-side training step for client ``idx`` and return the
    gradient of the loss with respect to the client activations.

    Args:
        fx_client: Client-side activations for the batch. Its ``.grad`` is
            read after ``backward``, so it has to be a leaf tensor that
            requires grad.
        y: Target labels of the batch.
        l_epoch_count: Index of the current local epoch.
        l_epoch: Total number of local epochs.
        idx: Client index; selects the entry of ``net_model_server``.
        len_batch: Number of batches per local epoch; when ``count1`` reaches
            it the epoch statistics are printed and reset.
        num_classes: Number of classes passed to ``calculate_metrics``.

    Returns:
        A detached clone of ``fx_client.grad``.

    Side effects:
        Updates ``net_model_server[idx]``, the module-level metric lists and
        counters, and -- once every client has finished its last local epoch
        -- averages the server models with ``FedAvg`` and sets ``fed_check``.
    """
    global net_model_server, criterion, optimizer_server, device
    global batch_acc_train, batch_loss_train, batch_precision_train, batch_recall_train, batch_auc_train
    global l_epoch_check, fed_check, loss_train_collect, acc_train_collect, count1
    global acc_avg_all_user_train, loss_avg_all_user_train, idx_collect, w_locals_server, w_glob_server, net_server
    global loss_train_collect_user, acc_train_collect_user, lr

    net_server = net_model_server[idx].to(device)
    net_server.train()
    # NOTE(review): a new Adam optimizer is built for every batch, so its
    # moment estimates never carry over between steps -- confirm this is intended.
    optimizer_server = torch.optim.Adam(net_server.parameters(), lr=lr)

    optimizer_server.zero_grad()
    # NOTE(review): presumably fx_client already lives on `device`; if `.to`
    # had to copy it, the result would not be a leaf and `.grad` below would
    # be None -- verify against the caller.
    fx_client = fx_client.to(device)
    y = y.to(device)

    # Forward pass through the server part, loss and batch metrics.
    fx_server = net_server(fx_client)
    loss = criterion(fx_server, y)
    acc, precision, recall, f1, auc = calculate_metrics(fx_server, y, num_classes)

    # Backward pass; the gradient at the cut layer is handed back to the client.
    loss.backward()
    dfx_client = fx_client.grad.clone().detach()
    optimizer_server.step()

    batch_loss_train.append(loss.item())
    batch_acc_train.append(acc.item())
    batch_precision_train.append(precision.item())
    batch_recall_train.append(recall.item())
    batch_auc_train.append(auc.item())

    net_model_server[idx] = net_server

    count1 += 1
    if count1 == len_batch:
        # End of a local epoch: average the per-batch metrics and reset them.
        acc_avg_train = sum(batch_acc_train) / len(batch_acc_train)
        loss_avg_train = sum(batch_loss_train) / len(batch_loss_train)
        precision_avg_train = sum(batch_precision_train) / len(batch_precision_train)
        recall_avg_train = sum(batch_recall_train) / len(batch_recall_train)
        auc_avg_train = sum(batch_auc_train) / len(batch_auc_train)

        batch_acc_train.clear()
        batch_loss_train.clear()
        batch_precision_train.clear()
        batch_recall_train.clear()
        batch_auc_train.clear()
        count1 = 0

        print(f'Client{idx} Train => Local Epoch: {l_epoch_count} \tAcc: {acc_avg_train:.3f} \tLoss: {loss_avg_train:.4f} \tPrecision: {precision_avg_train:.3f} \tRecall: {recall_avg_train:.3f} \tAUC: {auc_avg_train:.3f}')

        w_server = net_server.state_dict()

        if l_epoch_count == l_epoch - 1:
            # Last local epoch of this client: keep its weights and metrics
            # for the round-level aggregation.
            l_epoch_check = True
            w_locals_server.append(copy.deepcopy(w_server))

            acc_avg_train_all = acc_avg_train
            loss_avg_train_all = loss_avg_train

            loss_train_collect_user.append(loss_avg_train_all)
            acc_train_collect_user.append(acc_avg_train_all)

            if idx not in idx_collect:
                idx_collect.append(idx)

        if len(idx_collect) == num_users:
            # Every client has reported: average the server models and give
            # each client a fresh copy of the averaged model.
            fed_check = True
            w_glob_server = FedAvg(w_locals_server)
            net_glob_server.load_state_dict(w_glob_server)
            net_model_server = [copy.deepcopy(net_glob_server) for _ in range(num_users)]
            w_locals_server.clear()
            idx_collect.clear()

            acc_avg_all_user_train = sum(acc_train_collect_user) / len(acc_train_collect_user)
            loss_avg_all_user_train = sum(loss_train_collect_user) / len(loss_train_collect_user)

            # Scalars are appended to the round-level histories here.
            loss_train_collect.append(loss_avg_all_user_train)
            acc_train_collect.append(acc_avg_all_user_train)

            acc_train_collect_user.clear()
            loss_train_collect_user.clear()

    return dfx_client

In [1]:
# Test accuracy reported by each client in the current round. It has to
# survive across calls so the round average covers every client.
acc_test_collect_user = []

# Per-batch test metrics of the client currently being evaluated. They are
# kept at module level so they accumulate over all batches of an evaluation
# pass instead of holding only the most recent batch.
_batch_test_extra = {"precision": [], "recall": [], "auc": [], "f1": []}


def _average(values):
    """Return the arithmetic mean of a non-empty list of numbers."""
    return sum(values) / len(values)


def evaluate_server(fx_client, y, idx, len_batch, ell, num_classes=4):
    """Evaluate one test batch on the server-side model of client ``idx``.

    Per-batch metrics are accumulated across calls. When ``len_batch`` batches
    have been seen, their averages are printed and the accumulators are
    reset. If the client has just finished its last local epoch
    (``l_epoch_check``), its averages are stored for the round. If the
    server-side federation has just happened (``fed_check``), the averages
    over all stored clients are printed and appended to the round histories.

    Args:
        fx_client: Client-side activations for the batch.
        y: Target labels of the batch.
        idx: Client index; selects the entry of ``net_model_server``.
        len_batch: Number of test batches per evaluation pass.
        ell: Round number, used only in the printed summary.
        num_classes: Number of classes passed to ``calculate_metrics``.

    Returns:
        None.
    """
    global count2, l_epoch_check, fed_check

    # Evaluation runs under no_grad in eval mode and does not alter the
    # weights, so the stored model is used directly instead of deep-copying
    # it for every batch. Its training flag is restored afterwards.
    net = net_model_server[idx].to(device)
    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            y = y.to(device)
            fx_server = net(fx_client.to(device))
            loss = criterion(fx_server, y)
            acc, precision, recall, f1, auc = calculate_metrics(fx_server, y, num_classes)
    finally:
        net.train(was_training)

    batch_loss_test.append(loss.item())
    batch_acc_test.append(acc.item())
    _batch_test_extra["precision"].append(precision.item())
    _batch_test_extra["recall"].append(recall.item())
    _batch_test_extra["auc"].append(auc.item())
    _batch_test_extra["f1"].append(f1.item())

    count2 += 1
    if count2 != len_batch:
        return

    # End of the evaluation pass: average over every batch, then reset.
    acc_avg_test = _average(batch_acc_test)
    loss_avg_test = _average(batch_loss_test)
    precision_avg_test = _average(_batch_test_extra["precision"])
    recall_avg_test = _average(_batch_test_extra["recall"])
    auc_avg_test = _average(_batch_test_extra["auc"])
    f1_avg_test = _average(_batch_test_extra["f1"])

    batch_acc_test.clear()
    batch_loss_test.clear()
    for values in _batch_test_extra.values():
        values.clear()
    count2 = 0

    print('Client{} Test =>                   \tAcc: {:.3f} \tLoss: {:.4f} \tPrecision: {:.3f} \tRecall: {:.3f} \tAUC: {:.3f} \tF1-Score: {:.3f}'.format(
        idx, acc_avg_test, loss_avg_test, precision_avg_test, recall_avg_test, auc_avg_test, f1_avg_test))

    if l_epoch_check:
        # This client just finished its last local epoch: keep its results.
        l_epoch_check = False
        loss_test_collect_user.append(loss_avg_test)
        acc_test_collect_user.append(acc_avg_test)

    if fed_check:
        # Server-side federation just happened: summarise the round.
        fed_check = False
        print("------------------------------------------------")
        print("------ Federation process at Server-Side ------- ")
        print("------------------------------------------------")

        acc_avg_all_user = _average(acc_test_collect_user)
        loss_avg_all_user = _average(loss_test_collect_user)

        loss_test_collect.append(loss_avg_all_user)
        acc_test_collect.append(acc_avg_all_user)
        acc_test_collect_user.clear()
        loss_test_collect_user.clear()

        print("====================== SERVER V1==========================")
        print(' Train: Round {:3d}, Avg Accuracy {:.3f} | Avg Loss {:.3f}'.format(ell, acc_avg_all_user_train, loss_avg_all_user_train))
        print(' Test: Round {:3d}, Avg Accuracy {:.3f} | Avg Loss {:.3f}'.format(ell, acc_avg_all_user, loss_avg_all_user))
        print("==========================================================")
    return

In [None]:
def evaluate_accuracy(net, loader, device, return_conf_matrix=False, num_classes=4):
    """Evaluate ``net`` on every batch of ``loader`` and return summary metrics.

    Predictions, class probabilities and labels are gathered over the whole
    loader, so every metric -- AUROC included -- is computed on the full set
    rather than on the last batch only.

    Args:
        net: Model mapping an input batch to logits (class axis is dim 1).
        loader: Iterable yielding ``(inputs, labels)`` batches.
        device: Device the inputs are moved to before the forward pass.
        return_conf_matrix: When true, the confusion matrix is appended to
            the returned tuple.
        num_classes: Number of classes given to the torchmetrics calls.

    Returns:
        ``(accuracy, precision, recall, f1, auc)`` as Python floats, followed
        by the confusion matrix when ``return_conf_matrix`` is true.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    net.eval()
    pred_chunks = []
    prob_chunks = []
    label_chunks = []

    with torch.no_grad():
        for inputs, labels in loader:
            outputs = net(inputs.to(device))
            pred_chunks.append(outputs.argmax(dim=1).cpu())
            # Keep the probabilities of every batch; AUROC needs one row per label.
            prob_chunks.append(F.softmax(outputs, dim=1).cpu())
            label_chunks.append(labels.cpu())

    if not label_chunks:
        raise ValueError("loader yielded no batches; nothing to evaluate")

    all_preds = torch.cat(pred_chunks)
    all_probs = torch.cat(prob_chunks)
    all_labels = torch.cat(label_chunks)

    accuracy = torchmetrics.functional.accuracy(all_preds, all_labels, task='multiclass', num_classes=num_classes).item()
    precision = torchmetrics.functional.precision(all_preds, all_labels, average='macro', task='multiclass', num_classes=num_classes).item()
    recall = torchmetrics.functional.recall(all_preds, all_labels, average='macro', task='multiclass', num_classes=num_classes).item()
    f1 = torchmetrics.functional.f1_score(all_preds, all_labels, average='macro', task='multiclass', num_classes=num_classes).item()
    auc = torchmetrics.functional.auroc(all_probs, all_labels, task='multiclass', num_classes=num_classes).item()

    if return_conf_matrix:
        conf_matrix = confusion_matrix(all_labels.numpy(), all_preds.numpy())
        return accuracy, precision, recall, f1, auc, conf_matrix
    return accuracy, precision, recall, f1, auc

class DatasetSplit(Dataset):
    """View of ``dataset`` restricted to the positions listed in ``idxs``.

    Item ``k`` of the split is item ``idxs[k]`` of the wrapped dataset, which
    must yield ``(image, label)`` pairs.
    """

    def __init__(self, dataset, idxs):
        self.dataset = dataset
        # Materialise the indices so they can be addressed by position.
        self.idxs = [*idxs]

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        source_index = self.idxs[item]
        image, label = self.dataset[source_index]
        return image, label

class Client(object):
    """One federated client: owns its data loaders and drives the split training.

    Args:
        net_client_model: Accepted for interface compatibility; not used.
        idx: Client index, forwarded to the server-side functions.
        lr: Learning rate of the client-side Adam optimizer.
        device: Device the batches are moved to.
        train_loader: Loader with the training data.
        test_loader: Loader with the test data.
        idxs: Optional iterable of dataset positions. When given, the client
            trains only on that subset of ``train_loader.dataset``.
        idxs_test: Same as ``idxs`` for ``test_loader``.

    When no index set is given the corresponding loader is used unchanged.
    """

    def __init__(self, net_client_model, idx, lr, device, train_loader=None, test_loader=None, idxs=None, idxs_test=None):
        self.idx = idx
        self.device = device
        self.lr = lr
        self.local_ep = 1
        # Restrict each loader to this client's share of the data so that
        # clients do not all iterate over the complete dataset.
        self.ldr_train = self._restrict_loader(train_loader, idxs)
        self.ldr_test = self._restrict_loader(test_loader, idxs_test)

    @staticmethod
    def _restrict_loader(loader, idxs):
        """Return a loader over the ``idxs`` subset of ``loader.dataset``.

        The batch size, worker count, collate function, pin-memory and
        drop-last settings are copied from ``loader``. Shuffling is kept when
        the original loader used a random sampler. The loader is returned
        untouched when there is nothing to restrict or when it has no
        ``batch_size`` (e.g. a custom batch sampler), because it cannot be
        rebuilt faithfully in that case.
        """
        if loader is None or idxs is None or loader.batch_size is None:
            return loader
        # Sorted plain ints give a deterministic order and avoid numpy scalars.
        positions = sorted(int(i) for i in idxs)
        subset = torch.utils.data.Subset(loader.dataset, positions)
        reshuffle = isinstance(loader.sampler, torch.utils.data.RandomSampler)
        return DataLoader(
            subset,
            batch_size=loader.batch_size,
            shuffle=reshuffle,
            num_workers=loader.num_workers,
            collate_fn=loader.collate_fn,
            pin_memory=loader.pin_memory,
            drop_last=loader.drop_last,
        )

    def train(self, net):
        """Train ``net`` for ``self.local_ep`` epochs and return its state dict.

        For every batch the client runs its forward pass, sends a detached
        copy of the activations to ``train_server`` and back-propagates the
        gradient that comes back.

        NOTE(review): the optimizer and the StepLR scheduler are created on
        every call; with ``local_ep = 1`` the scheduler steps once per call,
        so its ``step_size=30`` is never reached.
        """
        net.train()
        optimizer_client = torch.optim.Adam(net.parameters(), lr=self.lr)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer_client, step_size=30, gamma=0.1)

        for local_epoch in range(self.local_ep):
            len_batch = len(self.ldr_train)

            for batch_idx, (images, labels) in enumerate(self.ldr_train):
                images, labels = images.to(self.device), labels.to(self.device)
                optimizer_client.zero_grad()
                fx = net(images)
                # Leaf copy of the activations: the server reads its .grad.
                client_fx = fx.clone().detach().requires_grad_(True)
                dfx = train_server(client_fx, labels, local_epoch, self.local_ep, self.idx, len_batch)
                # Continue back-propagation through the client part.
                fx.backward(dfx)
                optimizer_client.step()

            scheduler.step()

        return net.state_dict()

    def evaluate(self, net, ell, num_classes=4):
        """Run ``net`` over the test loader and pass each batch to ``evaluate_server``.

        Args:
            net: Client-side model used for the forward pass.
            ell: Round number, forwarded for reporting.
            num_classes: Number of classes, forwarded to the server side.
        """
        net.eval()

        with torch.no_grad():
            len_batch = len(self.ldr_test)
            for batch_idx, (images, labels) in enumerate(self.ldr_test):
                images, labels = images.to(self.device), labels.to(self.device)
                fx = net(images)
                evaluate_server(fx, labels, self.idx, len_batch, ell, num_classes)

        return

def dataset_iid(dataset, num_users):
    """Split the positions of ``dataset`` into ``num_users`` disjoint random sets.

    Each user receives ``int(len(dataset) / num_users)`` positions drawn
    without replacement through ``np.random.choice``; leftover positions are
    not assigned to anyone.

    Returns:
        dict mapping the user number (``0 .. num_users - 1``) to its set of
        dataset positions.
    """
    share = int(len(dataset) / num_users)
    remaining = list(range(len(dataset)))
    partition = {}
    for user in range(num_users):
        picked = set(np.random.choice(remaining, share, replace=False))
        partition[user] = picked
        # Drop what was just handed out before drawing for the next user.
        remaining = list(set(remaining) - picked)
    return partition


##Treino!


In [None]:
# Wall-clock duration of every global round.
epoch_times = []

# Random, disjoint index partitions of the train and test sets, one per client.
dict_users = dataset_iid(train_dataset, num_users)
dict_users_test = dataset_iid(test_dataset, num_users)
net_glob_client.train()

w_glob_client = net_glob_client.state_dict()


# Global rounds: each selected client trains a copy of the global client
# model, then the client-side weights are averaged with FedAvg.
for iter in range(epochs):
    start_time = time.time()
    # Number of clients taking part in this round (at least one).
    m = max(int(frac * num_users), 1)
    idxs_users = np.random.choice(range(num_users), m, replace=False)
    w_locals_client = []

    for idx in idxs_users:
        local = Client(
            net_glob_client,
            idx,
            lr,
            device,
            train_loader=train_loader,
            test_loader=test_loader,
            idxs=dict_users[idx],
            idxs_test=dict_users_test[idx]
        )

        # Train a private copy of the global client model and keep its weights.
        w_client = local.train(net=copy.deepcopy(net_glob_client).to(device))
        w_locals_client.append(copy.deepcopy(w_client))

        # NOTE(review): evaluation receives a fresh copy of the global client
        # model, not the network that was just trained -- confirm this is intended.
        local.evaluate(net=copy.deepcopy(net_glob_client).to(device), ell=iter)

    print("-----------------------------------------------------------")
    print("------ FedServer: Federation process at Client-Side ------- ")
    print("-----------------------------------------------------------")

    # Average the client-side weights and load them into the global model.
    w_glob_client = FedAvg(w_locals_client)
    net_glob_client.load_state_dict(w_glob_client)

    epoch_time = time.time() - start_time
    epoch_times.append(epoch_time)
    print(f"Epoch {iter + 1}/{epochs} - Time taken: {epoch_time:.2f} seconds")

print("Training and Evaluation completed!")
