In [1]:
import os, copy, random, math, time, glob
from collections import Counter

import numpy as np
import pandas as pd
from PIL import Image
from matplotlib import pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision import datasets, transforms, models
import torchmetrics

from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score

In [None]:
# ---- Federated-learning topology ----
num_users = 3        # number of simulated clients
num_classes = 3      # output classes of the classifier head
epochs = 80          # communication rounds of the training loop
frac = 1             # fraction of clients sampled each round
local_epochs = 1     # local passes over each client's data per round

# ---- Local optimisation ----
batch_size = 32
lr = 1e-4
weight_decay = 5e-4
mu = 0.1             # proximal-term coefficient; 0 disables it
clip_grad = 1.0      # max gradient norm; a falsy value disables clipping

# ---- Server-side FedAvgM ----
server_lr = 0.7
server_momentum = 0.8

# Unweighted placeholder; the data cell replaces it with a class-weighted loss
# once the training labels are known.
criterion = nn.CrossEntropyLoss()

# ---- Reproducibility ----
SEED = 1234
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    # Trade cuDNN autotuning speed for run-to-run determinism.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    print(torch.cuda.get_device_name(0))

In [None]:
# Root of an ImageFolder tree (one sub-directory per class). Defaults to the
# Colab/Drive location; set ADNI_DATA_PATH to run elsewhere without editing code.
data_path = os.environ.get('ADNI_DATA_PATH', '/content/drive/MyDrive/ADNI')

# Training pipeline: resize, random horizontal flip, ImageNet normalisation
# (the backbone is loaded with torchvision's pretrained ResNet-50 weights).
train_tf = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# Evaluation pipeline: same as training but deterministic (no flip).
test_tf = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Loaded without a transform; transforms are attached per split via CustomDataset.
full_dataset = datasets.ImageFolder(root=data_path)
# num_classes sizes the classifier head and the per-class partition loop, so the
# folder layout must agree with it.
assert len(full_dataset.classes) == num_classes, (
    f"{data_path} has {len(full_dataset.classes)} class folders "
    f"{full_dataset.classes}, but num_classes={num_classes}")

# Stratified 80/20 split so both splits keep the overall class proportions.
train_idx, test_idx = train_test_split(
    np.arange(len(full_dataset)),
    test_size=0.20,
    random_state=SEED,
    stratify=full_dataset.targets
)
train_subset = Subset(full_dataset, train_idx)
test_subset = Subset(full_dataset, test_idx)

class CustomDataset(Dataset):
    """Wrap an indexable ``(image, label)`` dataset and apply an optional transform.

    Args:
        subset: any sized, indexable object whose items are ``(image, label)`` pairs.
        transform: callable applied to the image when truthy; labels pass through untouched.
    """

    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, i):
        sample, label = self.subset[i]
        return (self.transform(sample) if self.transform else sample), label

# Centralised loaders over the full train/test splits (train view is augmented).
train_loader = DataLoader(
    CustomDataset(train_subset, train_tf),
    batch_size=batch_size,
    shuffle=True,
)
test_loader = DataLoader(
    CustomDataset(test_subset, test_tf),
    batch_size=batch_size,
    shuffle=False,
)

# 'balanced' class weights are computed from the training labels only, so the
# test split does not influence the loss.
labels_train = [full_dataset.targets[j] for j in train_idx]
cls_weights = compute_class_weight(
    'balanced',
    classes=np.unique(labels_train),
    y=labels_train,
)
criterion = nn.CrossEntropyLoss(
    weight=torch.tensor(cls_weights, dtype=torch.float, device=device)
)

In [None]:
def FedAvg(w):
    """Return the unweighted element-wise mean of a list of state dicts.

    Every key of ``w[0]`` is averaged across all entries of ``w``. The inputs are
    not modified; the result starts from a deep copy of ``w[0]``.

    Args:
        w: non-empty list of state dicts that share the same keys.

    Returns:
        A new dict mapping each key to the mean tensor.
    """
    n_models = len(w)
    averaged = copy.deepcopy(w[0])
    for key in averaged.keys():
        running = averaged[key]
        # Accumulate in place into the deep-copied tensor, in list order.
        for state in w[1:]:
            running += state[key]
        averaged[key] = torch.div(running, n_models)
    return averaged

In [None]:
def FedAvgM(w_locals, sizes, w_glob, v_prev, beta=0.9, server_lr=1.0):
    """Aggregate client state dicts with server-side momentum (FedAvgM).

    Floating-point tensors of the clients are averaged with weights proportional
    to ``sizes``. The difference between that average and the current global
    state acts as a pseudo-gradient that is accumulated into a momentum buffer:

        v_k   <- beta * v_k + (avg_k - glob_k)
        new_k <- glob_k + server_lr * v_k

    Non-floating entries of ``w_glob`` (e.g. BatchNorm ``num_batches_tracked``)
    are passed through unchanged.

    Args:
        w_locals: list of client state dicts with the same keys as ``w_glob``.
        sizes: local sample count per client, aligned one-to-one with ``w_locals``.
        w_glob: current global state dict.
        v_prev: momentum buffers keyed like ``w_glob``; updated IN PLACE
            (missing keys start from zero).
        beta: momentum coefficient.
        server_lr: server step size applied to the momentum buffer.

    Returns:
        ``(new_state, v_prev)``: the new global state dict and the same,
        mutated, momentum dict.

    Raises:
        ValueError: if ``w_locals`` is empty, if ``sizes`` does not have exactly
            one entry per client, or if the sizes do not sum to a positive value.
    """
    if not w_locals:
        raise ValueError("FedAvgM needs at least one client state dict")
    # Misaligned sizes would silently mis-weight clients and make the weights
    # stop summing to 1, so refuse them instead of truncating via zip().
    if len(sizes) != len(w_locals):
        raise ValueError(
            f"got {len(w_locals)} client state dicts but {len(sizes)} sizes; "
            "they must be aligned one-to-one")
    total = sum(sizes)
    if total <= 0:
        raise ValueError("client sizes must sum to a positive number")

    w_avg = {k: torch.zeros_like(p) for k, p in w_locals[0].items()
             if torch.is_floating_point(p)}
    for w, s in zip(w_locals, sizes):
        weight = s / total
        for k in w_avg:
            w_avg[k] += w[k] * weight

    new_state = {}
    for k, p_glob in w_glob.items():
        if not torch.is_floating_point(p_glob):
            # Integer buffers (counters) are not averaged.
            new_state[k] = p_glob
            continue
        delta = w_avg[k] - p_glob
        v_prev[k] = beta * v_prev.get(k, torch.zeros_like(delta)) + delta
        new_state[k] = p_glob + server_lr * v_prev[k]

    return new_state, v_prev

In [None]:
def reset_bn_stats(model, loader, device=None):
    """Re-estimate the BatchNorm running statistics of ``model`` from ``loader``.

    After weight aggregation the BN buffers no longer correspond to the new
    weights. This clears them and recomputes them as a cumulative (equal-weight)
    average over every batch of ``loader``. Gradients are disabled, so no
    autograd graph is built. Each BN layer's original ``momentum`` is restored
    afterwards. The model is always left in eval mode.

    Args:
        model: module whose BatchNorm1d/2d/3d layers are re-estimated.
        loader: iterable of ``(inputs, targets)`` batches; targets are ignored.
        device: where to move inputs. Defaults to the device of the model's first
            parameter (CPU if the model has no parameters).
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    bn_layers = [m for m in model.modules()
                 if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d))]
    saved_momentum = [m.momentum for m in bn_layers]
    for m in bn_layers:
        m.reset_running_stats()
        m.momentum = None  # None => cumulative moving average over all batches

    model.train()
    try:
        with torch.no_grad():
            for x, _ in loader:
                model(x.to(device))
    finally:
        for m, momentum in zip(bn_layers, saved_momentum):
            m.momentum = momentum
        model.eval()
    
@torch.no_grad()
def evaluate_global(model, loader):
    """Evaluate ``model`` over every batch of ``loader``.

    Uses the notebook-level ``criterion``, ``device`` and ``num_classes``.
    Metrics are computed with torchmetrics in multiclass mode; F1 is macro-averaged
    and AUROC is computed from softmax probabilities.

    Args:
        model: classifier returning logits of shape (batch, num_classes)
            (assumed from the softmax/argmax over dim 1).
        loader: iterable of ``(inputs, targets)`` batches.

    Returns:
        ``(loss, accuracy, macro_f1, auroc)``. All four are NaN when the loader
        yields no samples, e.g. a non-IID client with an empty test split.
    """
    model.eval()
    loss_sum, prob_batches, y_all, pred_all = 0.0, [], [], []
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        out = model(x)
        # criterion averages within the batch; multiply back by the batch size to
        # accumulate a per-sample sum (approximate when the criterion is class-weighted).
        loss_sum += criterion(out, y).item() * x.size(0)
        prob_batches.append(F.softmax(out, 1).cpu())
        y_all.extend(y.cpu().tolist())
        pred_all.extend(out.argmax(1).cpu().tolist())

    n_samples = len(y_all)
    if n_samples == 0:
        # torch.cat([]) would raise and the loss would divide by zero.
        nan = float('nan')
        return nan, nan, nan, nan

    y_t = torch.tensor(y_all)
    pred_t = torch.tensor(pred_all)
    prob_t = torch.cat(prob_batches)
    kwargs = {'task': 'multiclass', 'num_classes': num_classes}

    acc = torchmetrics.functional.accuracy(pred_t, y_t, **kwargs).item()
    f1 = torchmetrics.functional.f1_score(pred_t, y_t, average='macro', **kwargs).item()
    auc = torchmetrics.functional.auroc(prob_t, y_t, **kwargs).item()
    return loss_sum / n_samples, acc, f1, auc

In [None]:
class ResNet50_client_side(nn.Module):
    """Client half of a split ResNet-50: the stem plus ``layer1`` and ``layer2``.

    The stages are taken from torchvision's ``resnet50`` loaded with
    ``ResNet50_Weights.DEFAULT`` and stored, in order, in one ``nn.Sequential``
    named ``features``. That order fixes the ``state_dict`` keys
    (``features.0`` ... ``features.5``).
    """

    def __init__(self):
        super().__init__()
        backbone = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        stage_names = ('conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2')
        self.features = nn.Sequential(*(getattr(backbone, name) for name in stage_names))

    def forward(self, x):
        activations = self.features(x)
        return activations

class ResNet50_server_side(nn.Module):
    """Server half of a split ResNet-50: ``layer3``, ``layer4``, global pooling, new head.

    ``features`` holds ``layer3``, ``layer4`` and ``avgpool`` of a torchvision
    ``resnet50`` loaded with ``ResNet50_Weights.DEFAULT``. ``classifier`` is
    Dropout(0.5) followed by a fresh Linear layer with ``num_classes`` outputs.

    NOTE(review): this instantiates a second full pretrained ResNet-50 only to
    keep its later stages; the stem and layers 1-2 are discarded.
    """

    def __init__(self, num_classes=3):
        super().__init__()
        backbone = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        head_in = backbone.fc.in_features
        self.features = nn.Sequential(*[backbone.layer3, backbone.layer4, backbone.avgpool])
        self.classifier = nn.Sequential(nn.Dropout(p=0.5), nn.Linear(head_in, num_classes))

    def forward(self, x):
        pooled = self.features(x)
        return self.classifier(pooled.flatten(1))


class CombinedModel(nn.Module):
    """Chain a client-side module and a server-side module into one model.

    The parts are registered as ``client`` and ``server``, so combined
    ``state_dict`` keys are prefixed with those names.
    """

    def __init__(self, client, server):
        super().__init__()
        self.client = client
        self.server = server

    def forward(self, x):
        smashed = self.client(x)
        return self.server(smashed)
    
class Client:
    """A federated participant with private copies of both halves of the split model.

    On construction the client:
    - deep-copies the current global client-side and server-side networks;
    - builds its own train/test loaders from positions into ``train_subset`` /
      ``test_subset``;
    - creates one Adam optimiser per half.

    It relies on notebook-level settings: ``device``, ``batch_size``, ``lr``,
    ``weight_decay``, ``local_epochs``, ``mu``, ``clip_grad``, ``criterion``,
    ``train_tf``/``test_tf`` and the train/test subsets.

    Args:
        idx: client id (kept as ``self.idx``; not used in any computation).
        train_ids: positions into ``train_subset`` owned by this client.
        test_ids: positions into ``test_subset`` owned by this client.
        net_c_glob: current global client-side network. Besides being copied, it
            anchors the proximal term in ``train``.
        net_s_glob: current global server-side network.
    """
    def __init__(self, idx, train_ids, test_ids, net_c_glob, net_s_glob):
        self.idx = idx
        self.device = device
        # Map client-local positions to indices of the underlying dataset.
        tr_sub = Subset(train_subset.dataset,
                        [train_subset.indices[i] for i in train_ids])
        te_sub = Subset(test_subset.dataset,
                        [test_subset.indices[i] for i in test_ids])
        self.tr_loader = DataLoader(CustomDataset(tr_sub, train_tf),
                                    batch_size=batch_size, shuffle=True)
        self.te_loader = DataLoader(CustomDataset(te_sub, test_tf),
                                    batch_size=batch_size, shuffle=False)

        # Keep a reference (not a copy) to the global client-side net so the
        # proximal term is anchored to the model this client was built from,
        # rather than to whatever the notebook global happens to hold.
        self.net_c_glob = net_c_glob
        self.net_c = copy.deepcopy(net_c_glob).to(self.device)
        self.net_s = copy.deepcopy(net_s_glob).to(self.device)

        self.opt_c = torch.optim.Adam(self.net_c.parameters(),
                                      lr=lr, weight_decay=weight_decay)
        self.opt_s = torch.optim.Adam(self.net_s.parameters(),
                                      lr=lr, weight_decay=weight_decay)

    def train(self):
        """Run ``local_epochs`` passes over the local data, updating both halves.

        When ``mu > 0`` a proximal gradient ``mu * (w_local - w_global)`` is
        added to the client-side parameters only.
        NOTE(review): the server-side half gets no proximal term; confirm this is intended.

        Returns:
            ``(client_state_dict, server_state_dict, mean_loss, accuracy)``; loss
            and accuracy are NaN when the client has no training samples.
        """
        self.net_c.train(); self.net_s.train()
        # Loop-invariant: the global anchor weights do not change during local training.
        glob_params = [p.detach().to(self.device)
                       for p in self.net_c_glob.parameters()]
        loss_sum, correct, total = 0.0, 0, 0
        for _ in range(local_epochs):
            for x, y in self.tr_loader:
                x, y = x.to(self.device), y.to(self.device)
                self.opt_c.zero_grad(); self.opt_s.zero_grad()
                out = self.net_s(self.net_c(x))
                loss = criterion(out, y)
                loss.backward()

                if mu > 0:
                    with torch.no_grad():
                        for pg, pl in zip(glob_params, self.net_c.parameters()):
                            if pl.grad is not None:  # unused params have no grad
                                pl.grad.add_(mu * (pl - pg))

                if clip_grad:
                    torch.nn.utils.clip_grad_norm_(self.net_c.parameters(),
                                                   clip_grad)
                    torch.nn.utils.clip_grad_norm_(self.net_s.parameters(),
                                                   clip_grad)

                self.opt_c.step(); self.opt_s.step()

                loss_sum += loss.item() * x.size(0)
                correct += (out.argmax(1) == y).sum().item()
                total += y.size(0)

        if total == 0:
            # Possible with a Dirichlet non-IID split; avoid ZeroDivisionError.
            return (self.net_c.state_dict(), self.net_s.state_dict(),
                    float('nan'), float('nan'))
        return (self.net_c.state_dict(), self.net_s.state_dict(),
                loss_sum / total, correct / total)

    @torch.no_grad()
    def evaluate(self, net_c_glob, net_s_glob):
        """Return ``(loss, accuracy)`` of the given global halves on this client's test split."""
        model = CombinedModel(net_c_glob, net_s_glob).to(self.device)
        return evaluate_global(model, self.te_loader)[:2]

def dataset_noniid(idx_array, num_users, alpha=0.5):
    """Split ``idx_array`` among ``num_users`` clients with a per-class Dirichlet draw.

    For every class label in ``range(num_classes)``, the positions of that class
    are shuffled and cut into ``num_users`` chunks. The chunk sizes follow
    proportions drawn from ``Dirichlet(alpha, ..., alpha)``. Each client's list is
    shuffled at the end. A fresh generator seeded with ``SEED`` makes the result
    deterministic.

    Args:
        idx_array: dataset indices (into ``full_dataset``) to partition.
        num_users: number of clients.
        alpha: Dirichlet concentration; smaller values give more skewed label mixes.

    Returns:
        dict mapping client id -> list of POSITIONS into ``idx_array``
        (not raw dataset indices).
    """
    rng = np.random.default_rng(SEED)
    class_of = np.array(full_dataset.targets)[idx_array]
    positions = np.arange(len(idx_array))
    dict_users = {user: [] for user in range(num_users)}

    for label in range(num_classes):
        members = positions[class_of == label]
        rng.shuffle(members)
        shares = rng.dirichlet(np.repeat(alpha, num_users))
        cuts = (np.cumsum(shares) * len(members)).astype(int)[:-1]
        for user, chunk in enumerate(np.split(members, cuts)):
            dict_users[user].extend(chunk.tolist())

    for user in dict_users:
        rng.shuffle(dict_users[user])
    return dict_users

# Per-client partitions; values are positions into train_idx / test_idx respectively.
dict_users_train, dict_users_test = (
    dataset_noniid(train_idx, num_users),
    dataset_noniid(test_idx, num_users),
)



In [None]:
net_glob_c = ResNet50_client_side().to(device)
net_glob_s = ResNet50_server_side(num_classes).to(device)

# FedAvgM momentum buffers, one zero tensor per state-dict entry.
v_c = {k: torch.zeros_like(p) for k, p in net_glob_c.state_dict().items()}
v_s = {k: torch.zeros_like(p) for k, p in net_glob_s.state_dict().items()}

w_c_locals, w_s_locals, sizes_local = [], [], []

client_metrics = {m: {i: [] for i in range(num_users)}
                  for m in ('train_loss', 'train_acc', 'test_loss', 'test_acc')}
server_metrics = {m: [] for m in ('train_loss', 'train_acc', 'test_loss', 'test_acc')}

# Per-client test loaders never change between rounds. Build them once instead of
# instantiating a throw-away Client (deep copies of both ResNet halves plus two
# Adam optimisers) for every client in every round just to evaluate.
client_test_loaders = {
    i: DataLoader(
        CustomDataset(Subset(test_subset.dataset,
                             [test_subset.indices[j] for j in dict_users_test[i]]),
                      test_tf),
        batch_size=batch_size, shuffle=False)
    for i in range(num_users)
}
# Non-augmented view of the training split for reporting server train metrics
# (train_loader applies RandomHorizontalFlip).
train_eval_loader = DataLoader(CustomDataset(train_subset, test_tf),
                               batch_size=batch_size, shuffle=False)

print("\n===== Iniciando treinamento (FedAvgM) =====")
for rnd in range(1, epochs + 1):
    print(f"\n--- Rodada {rnd}/{epochs} ---")
    selected = np.random.choice(range(num_users),
                                max(int(frac * num_users), 1), replace=False)
    # All three per-round buffers are reset together: FedAvgM pairs
    # w_*_locals[i] with sizes_local[i] and normalises by sum(sizes_local).
    w_c_locals, w_s_locals, sizes_local = [], [], []

    for idx in selected:
        user = Client(idx,
                      dict_users_train[idx], dict_users_test[idx],
                      net_glob_c, net_glob_s)
        w_c, w_s, l, a = user.train()
        w_c_locals.append(w_c); w_s_locals.append(w_s)
        sizes_local.append(len(dict_users_train[idx]))
        client_metrics['train_loss'][idx].append(l)
        client_metrics['train_acc'][idx].append(a)
        print(f"Cliente {idx}: loss={l:.4f} acc={a:.4f}")

    # --- FedAvgM aggregation -------------------------------------------------
    new_c, v_c = FedAvgM(w_c_locals, sizes_local,
                         net_glob_c.state_dict(), v_c,
                         beta=server_momentum, server_lr=server_lr)
    new_s, v_s = FedAvgM(w_s_locals, sizes_local,
                         net_glob_s.state_dict(), v_s,
                         beta=server_momentum, server_lr=server_lr)
    net_glob_c.load_state_dict(new_c)
    net_glob_s.load_state_dict(new_s)
    # NOTE(review): BN statistics are re-estimated on the pooled training split,
    # which a real federated server would not hold; confirm this is intended.
    reset_bn_stats(CombinedModel(net_glob_c, net_glob_s), train_loader)
    print("Agregação (FedAvgM) concluída.")

    g_model = CombinedModel(net_glob_c.eval(), net_glob_s.eval())
    tr_l, tr_a, _, _ = evaluate_global(g_model, train_eval_loader)
    te_l, te_a, _, _ = evaluate_global(g_model, test_loader)
    server_metrics['train_loss'].append(tr_l)
    server_metrics['train_acc'].append(tr_a)
    server_metrics['test_loss'].append(te_l)
    server_metrics['test_acc'].append(te_a)
    print(f"[Servidor] TrainAcc:{tr_a:.4f}  TestAcc:{te_a:.4f}")

    # --- Per-client evaluation of the aggregated model -----------------------
    for idx in range(num_users):
        if idx not in selected:
            client_metrics['train_loss'][idx].append(None)
            client_metrics['train_acc'][idx].append(None)
        # Same computation as Client.evaluate, without constructing a Client.
        c_l, c_a = evaluate_global(g_model, client_test_loaders[idx])[:2]
        client_metrics['test_loss'][idx].append(c_l)
        client_metrics['test_acc'][idx].append(c_a)

print("\n===== Treinamento concluído =====")