In [None]:
!pip install torchmetrics



In [None]:
import os, json, copy, time, random
from pathlib import Path
from collections import Counter, defaultdict

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics.functional as F_tm
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
from torch.utils.data import DataLoader, Dataset, Subset, Sampler
from sklearn.model_selection import StratifiedGroupKFold
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns


In [None]:
# Single seed shared by every RNG used in this notebook (Python, NumPy, torch CPU/CUDA).
SEED = 12
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_fn(SEED)
del _seed_fn

_has_cuda = torch.cuda.is_available()
if _has_cuda:
    # Trade cuDNN autotuning for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    print(f"Usando GPU: {torch.cuda.get_device_name(0)}")

device = torch.device("cuda") if _has_cuda else torch.device("cpu")

# Dataset root (ImageFolder layout: one sub-directory per class) — adjust to your own path.
data_path = '/content/drive/MyDrive/ADNI4'

Usando GPU: Tesla T4


In [None]:
os.makedirs("artifacts", exist_ok=True)

# Run metadata saved next to the other artifacts so each result can be traced back
# to the threat model, privacy settings and seed it was produced with.
# NOTE(review): the name says FedAvg, but the training-config cell sets mu = 0.01 (FedProx).
experiment_config = dict(
    experiment_name="SlitFedAlzheimer_FedAvg",
    threat_model=dict(
        adversary_type="passive",
        knowledge="white-box",
        compromised_clients=0.0,
        goal="inference",
    ),
    privacy=dict(
        use_dp=False,
        dp_epsilon=None,
        dp_delta=None,
        gradient_clipping=None,
        secure_aggregation=False,
    ),
    reproducibility=dict(seed=SEED),
)
with open("artifacts/experiment_config.json", "w") as f:
    json.dump(experiment_config, f, indent=2)

# Presumably the confusion-matrix period (in epochs) for the training loop — consumer not visible here.
CM_EVERY = 5

In [None]:
class RandomNoise:
    """Additive Gaussian noise applied with probability ``p``.

    Works on tensors (place it after ``ToTensor``). When triggered it returns
    ``tensor + N(0, 1) * std + mean``; otherwise the input is returned unchanged.
    """

    def __init__(self, p=0.5, mean=0., std=0.1):
        self.p = p
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        # One uniform draw per call decides whether noise is added.
        skip = not (torch.rand(1).item() < self.p)
        if skip:
            return tensor
        noise = torch.randn(tensor.size(), device=tensor.device)
        return tensor + noise * self.std + self.mean

    def __repr__(self):
        return f"{self.__class__.__name__}(p={self.p}, mean={self.mean}, std={self.std})"

class CustomDataset(Dataset):
    """Wraps a dataset of ``(sample, label)`` pairs and applies ``transform`` to the sample only."""

    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, index):
        sample, target = self.subset[index]
        return (self.transform(sample) if self.transform else sample), target

# Shared preprocessing: 224x224, grayscale replicated to 3 channels (the pretrained
# backbones take 3-channel input), then each channel mapped from [0, 1] to [-1, 1].
_NORM_MEAN = [0.5] * 3
_NORM_STD = [0.5] * 3

# Training-time augmentation: geometric jitter plus occasional Gaussian noise, the
# noise being added to the [0, 1] tensor, i.e. before normalization.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.Grayscale(num_output_channels=3),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(15),
    transforms.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.85, 1.15)),
    transforms.ToTensor(),
    RandomNoise(p=0.2, mean=0., std=0.08),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
])

# Deterministic evaluation pipeline (same geometry and normalization, no augmentation).
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
])

# Untransformed ImageFolder; per-split transforms are attached later via CustomDataset.
full_dataset = datasets.ImageFolder(root=data_path)

# Flat views over the (path, class_idx) pairs, used for patient grouping and splitting.
all_samples = full_dataset.samples
all_paths = [path for path, _cls in all_samples]
labels_all = np.array([cls_idx for _path, cls_idx in all_samples], dtype=int)
classes = full_dataset.classes
num_classes = len(classes)

In [None]:
import re

# NOTE(review): `root` is not referenced by any code visible in this notebook chunk;
# presumably kept for later cells — confirm before removing.
root = Path(data_path)
# ADNI subject IDs look like "002_S_0295". Patterns are tried in order and the first
# hit wins. The bare pattern is case-sensitive, while the prefixed ones use re.I and
# so can also match e.g. "PTID-123_s_4567".
_PATTERNS = [
    re.compile(r'(\d{3}_S_\d{4})'),               # 002_S_0295
    re.compile(r'ADNI[_-](\d{3}_S_\d{4})', re.I), # ADNI_002_S_0295
    re.compile(r'PTID[_-]?(\d{3}_S_\d{4})', re.I),
]


def extract_patient_id_from_filename(path_str: str) -> str:
    """Return the patient ID embedded in the file name of ``path_str``.

    Only the last path component is inspected. When no pattern matches, the file
    stem is returned, so that scan is treated as its own "patient".
    """
    fname = Path(path_str).name
    hit = next((m for m in (pat.search(fname) for pat in _PATTERNS) if m), None)
    if hit is None:
        return Path(fname).stem
    return hit.group(1)

# Patient ID per sample, aligned with full_dataset indices, plus an index -> patient lookup.
patient_ids = np.array(list(map(extract_patient_id_from_filename, all_paths)))
IDX_TO_PATIENT = dict(enumerate(patient_ids))

In [None]:
# Hold out one of 5 label-stratified, patient-grouped folds (roughly a fifth of the
# data) as the global test set.
sgkf = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=SEED)
all_idx = np.arange(len(full_dataset))
train_idx, test_idx = next(iter(sgkf.split(all_idx, y=labels_all, groups=patient_ids)))

# Sanity check: no patient may contribute scans to both splits.
train_pats = set(patient_ids[train_idx])
test_pats = set(patient_ids[test_idx])
assert train_pats.isdisjoint(test_pats), "Leakage: patient appears in both train and test!"

# Split-specific transforms layered on top of the untransformed ImageFolder.
main_train_dataset = CustomDataset(Subset(full_dataset, train_idx), transform=train_transform)
main_test_dataset = CustomDataset(Subset(full_dataset, test_idx), transform=test_transform)


In [None]:
def partition_by_patient_balanced_no_empty(idx_array, idx_to_patient, labels_all, desired_users, seed=1234):
    """Split sample indices into patient-exclusive, class-balanced client shards.

    Every patient goes to exactly one client. Patients are grouped by their
    majority label and shuffled within each class. They are then dealt out with a
    single round-robin cursor that carries over from one class to the next. Because
    the cursor never restarts at client 0, the following hold:
      * per class and overall, client patient counts differ by at most one;
      * no client is empty whenever there are at least as many patients as clients.

    Args:
        idx_array: iterable of dataset indices to partition (iterated once).
        idx_to_patient: mapping dataset index -> patient ID.
        labels_all: indexable of class labels addressed by dataset index.
        desired_users: requested number of clients. With fewer patients than this,
            the effective number of clients is reduced automatically.
        seed: seed for the per-class patient shuffle.

    Returns:
        ``(dict_users, eff_users)``.
          * ``dict_users`` maps client id -> set of dataset indices for ids
            ``0 .. desired_users - 1``.
          * Ids ``>= eff_users`` are padded with empty sets.
          * ``eff_users`` is the number of non-empty clients.
    """
    if desired_users <= 0:
        # No clients requested: nothing can be assigned.
        return {}, 0

    idx_list = list(idx_array)  # allow generators: the indices are walked twice below
    rng = np.random.default_rng(seed)
    pats = sorted({idx_to_patient[i] for i in idx_list})
    if not pats:
        return {c: set() for c in range(desired_users)}, 0

    pat_to_idxs = defaultdict(list)
    for i in idx_list:
        pat_to_idxs[idx_to_patient[i]].append(i)

    def maj_label(p):
        # Most frequent label among the patient's scans (ties: first seen).
        cnt = Counter(labels_all[j] for j in pat_to_idxs[p])
        return int(cnt.most_common(1)[0][0])

    pats_by_class = defaultdict(list)
    for p in pats:
        pats_by_class[maj_label(p)].append(p)
    for c in pats_by_class:
        rng.shuffle(pats_by_class[c])

    eff_users = min(desired_users, len(pats))
    buckets = [set() for _ in range(eff_users)]

    # Deal class by class with ONE cursor shared across classes. Restarting at
    # bucket 0 for each class would pile extra patients onto the first clients.
    # It could also leave trailing clients empty.
    cursor = 0
    for c in pats_by_class:
        queue = pats_by_class[c]
        while queue:
            buckets[cursor % eff_users].add(queue.pop())
            cursor += 1

    # Defensive: with eff_users <= len(pats) the deal above never leaves a bucket
    # empty, but drop empties anyway so the "no empty client" contract always holds.
    non_empty = [b for b in buckets if b]
    if len(non_empty) < eff_users:
        buckets = non_empty
        eff_users = len(buckets)

    out = {b: {i for p in buckets[b] for i in pat_to_idxs[p]} for b in range(eff_users)}
    for b in range(eff_users, desired_users):  # padding
        out[b] = set()
    return out, eff_users

def class_dist_for(idxs, labels_all, classes):
    """Count samples per class for the given dataset indices.

    Returns ``{class_name: count}`` in the order of ``classes``; every count is 0
    when ``idxs`` is empty.
    """
    if not len(idxs):
        return dict.fromkeys(classes, 0)
    picked = np.array([labels_all[i] for i in idxs], dtype=int)
    per_class = np.bincount(picked, minlength=len(classes)).tolist()
    return {name: n for name, n in zip(classes, per_class)}

In [None]:
class ResNet50_client_side(nn.Module):
    """Client half of a split ImageNet-pretrained ResNet-50: stem, layer1, layer2.

    The stem's conv1/bn1 parameters have ``requires_grad`` disabled. Only the
    gradients are frozen; bn1's running statistics still update in train mode.
    """

    def __init__(self):
        super().__init__()
        backbone = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        frozen_prefixes = ("conv1", "bn1")
        for pname, param in backbone.named_parameters():
            if pname.startswith(frozen_prefixes):
                param.requires_grad = False
        stages = [backbone.conv1, backbone.bn1, backbone.relu, backbone.maxpool,
                  backbone.layer1, backbone.layer2]
        self.features = nn.Sequential(*stages)

    def forward(self, x):
        return self.features(x)

class ResNet50_server_side(nn.Module):
    """Server half of a split ImageNet-pretrained ResNet-50.

    layer3 + layer4, global average pooling, then a dropout + linear head that
    produces ``num_classes`` logits.
    """

    def __init__(self, num_classes=3):
        super().__init__()
        backbone = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        self.features = nn.Sequential(backbone.layer3, backbone.layer4)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        in_dim = backbone.fc.in_features
        self.classifier = nn.Sequential(nn.Dropout(p=0.5), nn.Linear(in_dim, num_classes))

    def forward(self, x):
        pooled = self.pool(self.features(x))
        return self.classifier(pooled.flatten(1))

class CombinedModel(nn.Module):
    """Chains a client-side feature extractor and a server-side head into one model."""

    def __init__(self, client_model, server_model):
        super().__init__()
        self.client_model = client_model
        self.server_model = server_model

    def forward(self, x):
        smashed = self.client_model(x)
        return self.server_model(smashed)


In [None]:
class DenseNet169_client_side(nn.Module):
    """Client half of a split pretrained DenseNet-169.

    Contains the stem, denseblock1, transition1 and denseblock2. The stem's
    conv0/norm0 parameters have ``requires_grad`` disabled.
    """

    def __init__(self):
        super().__init__()
        dn = models.densenet169(weights=models.DenseNet169_Weights.DEFAULT)
        feats = dn.features  # Sequential
        # Freeze only the very beginning of the network (the stem).
        frozen_prefixes = ("features.conv0", "features.norm0")
        for pname, param in dn.named_parameters():
            if pname.startswith(frozen_prefixes):
                param.requires_grad = False
        self.features = nn.Sequential(
            feats.conv0, feats.norm0, feats.relu0, feats.pool0,
            feats.denseblock1, feats.transition1,
            feats.denseblock2,
        )

    def forward(self, x):
        return self.features(x)

class DenseNet169_server_side(nn.Module):
    """Server half of a split pretrained DenseNet-169.

    transition2 through norm5, then ReLU, global average pooling and a linear
    head that produces ``num_classes`` logits.
    """

    def __init__(self, num_classes=3):
        super().__init__()
        dn = models.densenet169(weights=models.DenseNet169_Weights.DEFAULT)
        feats = dn.features
        tail = [feats.transition2,
                feats.denseblock3, feats.transition3,
                feats.denseblock4, feats.norm5]
        self.features_tail = nn.Sequential(*tail)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        # DenseNet-169's final feature width is 1664 channels.
        self.classifier = nn.Linear(dn.classifier.in_features, num_classes)

    def forward(self, x):
        # Same ReLU -> global pool -> flatten sequence as torchvision's DenseNet head.
        activated = F.relu(self.features_tail(x), inplace=True)
        flat = torch.flatten(self.pool(activated), 1)
        return self.classifier(flat)

def make_models(backbone: str, num_classes: int):
    """Build the (client, server) halves of a split model and move them to ``device``.

    Args:
        backbone: "resnet50" or "densenet169" (case-insensitive).
        num_classes: number of logits produced by the server head.

    Returns:
        ``(net_glob_client, net_glob_server)``.

    Raises:
        ValueError: for an unknown backbone name.
    """
    builders = {
        "resnet50": (ResNet50_client_side, ResNet50_server_side),
        "densenet169": (DenseNet169_client_side, DenseNet169_server_side),
    }
    key = backbone.lower()
    if key not in builders:
        raise ValueError(f"Unknown backbone: {backbone}")
    client_cls, server_cls = builders[key]
    net_glob_client = client_cls().to(device)
    net_glob_server = server_cls(num_classes=num_classes).to(device)
    return net_glob_client, net_glob_server

In [None]:
class BalancedBatchSampler(Sampler):
    """Batch sampler yielding exactly ``n_per_class`` indices per class in every batch.

    Each batch holds ``n_classes * n_per_class`` indices, ordered class by class.
    Sampling rules:
      * Smaller classes are oversampled by repeating their index lists until they
        are at least as long as the largest class.
      * A class that is entirely absent borrows one index from the first
        non-empty class.
      * One pass yields ``largest_class_size // n_per_class`` batches.
    """

    def __init__(self, dataset: "CustomDataset", n_classes: int, n_per_class: int):
        self.labels = self._read_labels(dataset)
        self.class_to_idxs = defaultdict(list)
        for i, y in enumerate(self.labels):
            self.class_to_idxs[int(y)].append(i)

        max_len = max((len(v) for v in self.class_to_idxs.values() if len(v) > 0), default=0)
        if max_len == 0:
            self.length = 0
            self.n_classes = n_classes
            self.n_per_class = n_per_class
            return

        for c in range(n_classes):
            if len(self.class_to_idxs[c]) == 0:
                # Absent class: borrow a single index from the first non-empty class.
                for alt in range(n_classes):
                    if len(self.class_to_idxs[alt]) > 0:
                        self.class_to_idxs[c] = [self.class_to_idxs[alt][0]]
                        break
            # Oversample by doubling until the list is at least max_len long.
            while len(self.class_to_idxs[c]) < max_len:
                self.class_to_idxs[c].extend(self.class_to_idxs[c])

        self.n_classes = n_classes
        self.n_per_class = n_per_class
        # Floor division: the remainder of the largest class is skipped each pass.
        self.length = max_len // n_per_class

    @staticmethod
    def _read_labels(dataset):
        """Return the label of every item of ``dataset`` in index order.

        For ``CustomDataset(Subset(base))`` or ``Subset(base)`` where ``base``
        exposes ``targets`` and has no ``target_transform`` (e.g. torchvision
        ``ImageFolder``), labels come from ``base.targets``. That way no image is
        decoded or augmented just to read its label. Any other dataset is iterated.
        """
        subset = getattr(dataset, "subset", dataset)
        indices = getattr(subset, "indices", None)
        base = getattr(subset, "dataset", None)
        targets = getattr(base, "targets", None)
        if indices is not None and targets is not None and getattr(base, "target_transform", None) is None:
            return [targets[i] for i in indices]
        return [lbl for _, lbl in dataset]

    def __len__(self):
        return self.length

    def __iter__(self):
        import random as _random
        if self.length == 0:
            return iter([])
        # Fresh shuffled order per class for every pass.
        class_iters = {c: iter(_random.sample(idxs, len(idxs))) for c, idxs in self.class_to_idxs.items()}
        for _ in range(self.length):
            batch = []
            for c in range(self.n_classes):
                for _ in range(self.n_per_class):
                    try:
                        batch.append(next(class_iters[c]))
                    except StopIteration:
                        # Restart (unshuffled) if a class runs out mid-pass.
                        class_iters[c] = iter(self.class_to_idxs[c])
                        batch.append(next(class_iters[c]))
            yield batch

In [None]:
class EarlyStopping:
    """Early stopping on a validation loss, checkpointing both split-model halves.

    The loss counts as improved when it is at most ``best_loss - delta`` (with the
    default ``delta=0`` an equal loss also counts). Every improvement resets the
    patience counter and saves both state dicts to ``save_path``. ``early_stop``
    becomes True after ``patience`` consecutive non-improving calls. A NaN loss is
    never treated as an improvement.
    """

    def __init__(self, patience=20, verbose=True, delta=0, save_path='chkp.pt'):
        self.patience, self.verbose, self.delta, self.save_path = patience, verbose, delta, save_path
        self.counter, self.best_score, self.early_stop = 0, None, False
        # np.Inf was removed in NumPy 2.0; np.inf exists on every NumPy version.
        self.val_loss_min = np.inf

    def __call__(self, val_loss, models):
        """Record one validation loss.

        ``models`` is the ``(net_client, net_server)`` pair to checkpoint. The
        parameter name shadows the torchvision ``models`` module only inside this
        method.
        """
        net_client, net_server = models
        score = -val_loss
        # NaN is the only value unequal to itself. Without this guard a NaN loss
        # fails the "worse" comparison and would overwrite the best checkpoint.
        is_nan = score != score
        improved = (not is_nan) and (self.best_score is None or score >= self.best_score + self.delta)
        if improved:
            self.best_score = score
            self.save_checkpoint(val_loss, net_client, net_server)
            self.counter = 0
        else:
            self.counter += 1
            if self.verbose: print(f'EarlyStopping counter: {self.counter}/{self.patience}')
            if self.counter >= self.patience: self.early_stop = True

    def save_checkpoint(self, val_loss, net_client, net_server):
        """Save both halves' state dicts to ``save_path`` and remember ``val_loss``."""
        if self.verbose: print(f'Validation loss decreased ({self.val_loss_min:.6f}-->{val_loss:.6f}). Saving model to {self.save_path}')
        torch.save({'net_glob_client_state_dict': net_client.state_dict(),
                    'net_glob_server_state_dict': net_server.state_dict()}, self.save_path)
        self.val_loss_min = val_loss

def FedAvg_weighted(w_locals, counts):
    """Average client state dicts, weighting each client by its sample count.

    Args:
        w_locals: list of state dicts with identical keys.
        counts: per-client sample counts, aligned with ``w_locals``.

    Returns:
        A new state dict, built as follows:
          * Floating-point entries are the weighted mean, computed in float32 and
            cast back to the entry's dtype.
          * Non-floating entries (e.g. BatchNorm ``num_batches_tracked``) are
            copied from the first client.
          * If every count is 0, clients are weighted uniformly instead of zeroing
            the model.
          * An empty ``w_locals`` yields ``{}``.

    Raises:
        ValueError: if ``w_locals`` and ``counts`` differ in length.
    """
    if not w_locals:
        return {}
    if len(w_locals) != len(counts):
        raise ValueError(f"got {len(w_locals)} state dicts but {len(counts)} counts")
    import copy as _copy

    N = float(sum(counts))
    if N > 0:
        weights = [n / N for n in counts]
    else:
        # No sample counts available: plain mean (n / N would be 0 for everyone).
        weights = [1.0 / len(w_locals)] * len(w_locals)

    w_avg = _copy.deepcopy(w_locals[0])
    for k in w_avg.keys():
        if not w_avg[k].dtype.is_floating_point:
            continue
        mean_k = sum(w[k].float() * wt for w, wt in zip(w_locals, weights))
        w_avg[k] = mean_k.to(w_avg[k].dtype)
    return w_avg

# Max global L2 norm used when clipping the server-side gradients.
MAX_GRAD_NORM = 1.0

def train_server(fx_client, y, net_server, optimizer_server, criterion, num_classes):
    """Run one server-side optimization step on a batch of client activations.

    ``fx_client`` must be a leaf tensor with ``requires_grad=True``: its ``.grad``
    is read after the backward pass and sent back to the client.

    Returns:
        ``(dfx_client, loss_value, batch_accuracy)``:
          * the detached gradient w.r.t. ``fx_client``;
          * the scalar loss;
          * the accuracy on this batch only.
    """
    net_server.train()
    optimizer_server.zero_grad()
    logits = net_server(fx_client)
    loss = criterion(logits, y)
    with torch.no_grad():
        batch_acc = F_tm.accuracy(logits.argmax(dim=1), y, task='multiclass',
                                  num_classes=num_classes).item()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(net_server.parameters(), max_norm=MAX_GRAD_NORM)
    grad_to_client = fx_client.grad.clone().detach()
    optimizer_server.step()
    return grad_to_client, loss.item(), batch_acc

def evaluate_loader_aggregated(net_client, net_server, loader, criterion, num_classes, device):
    """Evaluate the split model over a whole loader and aggregate the metrics.

    Logits and labels from all batches are concatenated before anything is
    computed, so the loss and metrics describe the full set. The loss is the
    criterion applied once to all samples. This avoids the bias of averaging
    per-batch means when the last batch is smaller; it assumes a mean-reduced
    criterion.

    Returns:
        ``(avg_loss, [acc, macro_precision, macro_recall, macro_f1, auroc])``.
        An empty loader gives ``(0.0, [0.0] * 5)``. auroc is 0.0 when torchmetrics
        fails to compute it.
    """
    net_client.eval(); net_server.eval()
    all_logits, all_labels = [], []

    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            fx_server = net_server(net_client(images))
            all_logits.append(fx_server.detach().cpu())
            all_labels.append(labels.detach().cpu())

    if not all_logits:
        return 0.0, [0.0]*5

    logits = torch.cat(all_logits, dim=0)
    y_true = torch.cat(all_labels, dim=0)
    y_pred = logits.argmax(dim=1)

    with torch.no_grad():
        # Evaluated on `device` in case the criterion holds device tensors (e.g. class weights).
        avg_loss = float(criterion(logits.to(device), y_true.to(device)).item())

    acc  = F_tm.accuracy(y_pred, y_true, task='multiclass', num_classes=num_classes).item()
    prec = F_tm.precision(y_pred, y_true, average='macro', task='multiclass', num_classes=num_classes).item()
    rec  = F_tm.recall(y_pred, y_true, average='macro', task='multiclass', num_classes=num_classes).item()
    f1   = F_tm.f1_score(y_pred, y_true, average='macro', task='multiclass', num_classes=num_classes).item()
    try:
        auc = F_tm.auroc(torch.softmax(logits, dim=1), y_true, task='multiclass', num_classes=num_classes).item()
    except Exception:
        # Best effort: AUROC can be undefined (e.g. classes missing from the split).
        auc = 0.0

    return avg_loss, [acc, prec, rec, f1, auc]

def evaluate_accuracy(net,loader,device,return_conf_matrix=False,num_classes=3):
    """Evaluate a single (e.g. combined) model on ``loader``.

    Returns:
        ``(accuracy, macro_precision, macro_recall, macro_f1, auroc)``. When
        ``return_conf_matrix`` is True, a ``num_classes x num_classes`` confusion
        matrix is appended. An empty loader gives all-zero metrics and an all-zero
        matrix. auroc is 0.0 if torchmetrics raises ValueError.
    """
    net.eval()
    all_preds, all_labels, all_outputs = [], [], []
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs, 1)
            all_outputs.append(outputs.cpu())
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    if len(all_preds) == 0:
        if return_conf_matrix:
            return 0,0,0,0,0, np.zeros((num_classes,num_classes), dtype=int)
        return 0,0,0,0,0
    all_outputs = torch.cat(all_outputs, dim=0)
    all_preds = torch.tensor(all_preds); all_labels = torch.tensor(all_labels)
    accuracy  = F_tm.accuracy(all_preds,all_labels,task='multiclass',num_classes=num_classes).item()
    precision = F_tm.precision(all_preds,all_labels,average='macro',task='multiclass',num_classes=num_classes).item()
    recall    = F_tm.recall(all_preds,all_labels,average='macro',task='multiclass',num_classes=num_classes).item()
    f1        = F_tm.f1_score(all_preds,all_labels,average='macro',task='multiclass',num_classes=num_classes).item()
    try:
        auc = F_tm.auroc(F.softmax(all_outputs,dim=1),all_labels,task="multiclass",num_classes=num_classes).item()
    except ValueError:
        auc = 0.0
    if not return_conf_matrix:
        return accuracy,precision,recall,f1,auc
    # Pin the label set so the matrix is always num_classes x num_classes. This
    # holds even when a class is missing from both targets and predictions, and
    # matches the shape returned by the empty-loader branch above.
    cm = confusion_matrix(all_labels.numpy(), all_preds.numpy(), labels=list(range(num_classes)))
    return accuracy,precision,recall,f1,auc,cm

def _get_indices_from(dataset):
    if isinstance(dataset, Subset):
        return dataset.indices
    if isinstance(dataset, CustomDataset) and isinstance(dataset.subset, Subset):
        return dataset.subset.indices
    raise ValueError("evaluate_by_patient: preciso de Subset(...) ou CustomDataset(Subset(...)).")

def evaluate_by_patient(model, dataset_or_subset, idx_to_patient_map, batch_size=64, device="cuda"):
    """Patient-level evaluation by averaging per-scan softmax probabilities.

    Each patient's prediction is the argmax of the mean probability vector over
    its scans. Its ground truth is its most common scan label (ties: first seen).

    Returns:
        ``(accuracy, macro_precision, macro_recall, macro_f1, ovr_auroc,
        confusion_matrix)``. auroc is 0.0 when sklearn cannot compute it.
    """
    loader = DataLoader(dataset_or_subset, batch_size=batch_size, shuffle=False, num_workers=0)
    model.eval()
    prob_batches, scan_labels = [], []
    with torch.no_grad():
        for images, labels in loader:
            logits = model(images.to(device))
            prob_batches.append(torch.softmax(logits, dim=1).cpu().numpy())
            scan_labels.extend(labels.numpy())
    probs = np.concatenate(prob_batches, axis=0)
    n_cls = probs.shape[1]
    prob_cols = [f"p_{c}" for c in range(n_cls)]

    # Row order matches the underlying indices because the loader is not shuffled.
    patients = [idx_to_patient_map[i] for i in _get_indices_from(dataset_or_subset)]
    scans = pd.DataFrame(probs, columns=prob_cols)
    scans["y"] = np.array(scan_labels, dtype=int)
    scans["patient"] = patients

    by_patient = scans.groupby("patient")
    P = by_patient.agg({col: "mean" for col in prob_cols}).values
    y_true = by_patient["y"].agg(lambda s: Counter(s).most_common(1)[0][0]).values
    y_pred = P.argmax(axis=1)

    from sklearn import metrics as skm
    acc  = (y_pred == y_true).mean().item()
    prec = skm.precision_score(y_true, y_pred, average="macro", zero_division=0)
    rec  = skm.recall_score(y_true, y_pred, average="macro", zero_division=0)
    f1   = skm.f1_score(y_true, y_pred, average="macro", zero_division=0)
    try:
        auc = skm.roc_auc_score(y_true, P, multi_class="ovr")
    except Exception:
        auc = 0.0
    cm = skm.confusion_matrix(y_true, y_pred, labels=list(range(n_cls)))
    return acc, prec, rec, f1, auc, cm

In [None]:
class Client:
    """One split-learning client: its data shard, its loaders and a private server-side copy.

    ``train`` trains the client half passed in for the round, together with this
    client's copy of the server half. The copy is initialised from
    ``net_server_initial_weights`` and keeps its optimizer and scheduler across
    rounds. Both state dicts are returned for aggregation.
    """

    def __init__(self, idx, lr, device, idxs, idxs_test, net_server_initial_weights, weight_decay,
                 batch_size, num_classes, criterion, mu, use_balanced_batch=True, epochs_total=80, local_ep=1,
                 server_model_cls=None):
        """Build the client's datasets, loaders and server-side copy.

        Args:
            idx: client id.
            lr: learning rate of the server copy and of the client's main parameter group.
            device: torch device for both model halves.
            idxs: ``full_dataset`` indices of this client's training shard.
            idxs_test: ``full_dataset`` indices of this client's test shard.
            net_server_initial_weights: state dict that initialises the server copy.
            weight_decay: Adam weight decay for the server copy.
            batch_size: target batch size. Balanced batches hold
                ``num_classes * max(1, batch_size // num_classes)`` samples.
            num_classes: number of classes.
            criterion: loss used for training and evaluation.
            mu: FedProx proximal coefficient (0 disables the proximal term).
            use_balanced_batch: allow class-balanced batches when every class has
                at least one batch's worth of samples.
            epochs_total: ``T_max`` of the cosine learning-rate schedules.
            local_ep: local passes over the shard per call to :meth:`train`.
            server_model_cls: class of the server half. When None it is inferred
                from the weights: ``features_tail.*`` keys mean DenseNet-169;
                otherwise ResNet-50 is assumed.
        """
        self.idx,self.device,self.lr = idx,device,lr
        self.local_ep = local_ep
        self.num_classes = num_classes
        self.criterion = criterion
        self.mu = mu
        self.epochs_total = epochs_total

        # Materialise the index sets once so dataset order and label order match.
        train_list, test_list = list(idxs), list(idxs_test)
        self.train_dataset = CustomDataset(Subset(full_dataset, train_list), transform=train_transform)
        self.test_dataset  = CustomDataset(Subset(full_dataset, test_list),  transform=test_transform)

        WORKERS = 2

        # Labels straight from the ImageFolder metadata. Iterating self.train_dataset
        # would decode and augment every image just to read its label.
        labels_client = [full_dataset.targets[i] for i in train_list]
        counts = Counter(labels_client)
        n_per_class = max(1, batch_size // num_classes)
        use_balanced = (len(np.unique(labels_client)) == num_classes) and (min(counts.values()) >= n_per_class)

        self.ldr_train = None
        if use_balanced and use_balanced_batch:
            batch_sampler = BalancedBatchSampler(self.train_dataset, n_classes=num_classes, n_per_class=n_per_class)
            # A zero-length sampler yields no batches; fall back to plain shuffling.
            if len(batch_sampler) > 0:
                self.ldr_train = DataLoader(self.train_dataset, batch_sampler=batch_sampler,
                                            num_workers=WORKERS, pin_memory=True)
        if self.ldr_train is None:
            # DataLoader(shuffle=True) raises on an empty dataset, so only shuffle
            # when there is at least one sample.
            self.ldr_train = DataLoader(self.train_dataset, batch_size=batch_size,
                                        shuffle=len(self.train_dataset) > 0,
                                        num_workers=WORKERS, pin_memory=True)

        self.ldr_test = DataLoader(self.test_dataset, batch_size=batch_size, shuffle=False,
                                   num_workers=WORKERS, pin_memory=True)

        if server_model_cls is None:
            # DenseNet169_server_side keeps its trunk under "features_tail.*";
            # ResNet50_server_side uses "features.*".
            is_densenet = any(k.startswith("features_tail.") for k in net_server_initial_weights)
            server_model_cls = DenseNet169_server_side if is_densenet else ResNet50_server_side
        self.net_server_local_copy = server_model_cls(num_classes=num_classes).to(self.device)
        self.net_server_local_copy.load_state_dict(net_server_initial_weights)
        self.optimizer_server_local = torch.optim.Adam(self.net_server_local_copy.parameters(),
                                                       lr=self.lr, weight_decay=weight_decay)
        # Persistent across rounds: stepped once per train() call.
        self.scheduler_server = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer_server_local,
                                                                           T_max=self.epochs_total, eta_min=1e-6)

    def evaluate(self, net_client, net_server):
        """Evaluate the given halves on this client's test shard (see evaluate_loader_aggregated)."""
        return evaluate_loader_aggregated(net_client, net_server, self.ldr_test,
                                          self.criterion, self.num_classes, self.device)

    @staticmethod
    def _last_client_stage_params(net_client):
        """Parameters of the last module of the client's ``features`` Sequential.

        This is layer2 for ResNet50_client_side and denseblock2 for
        DenseNet169_client_side. A CombinedModel is unwrapped via ``client_model``.
        Returns an empty list when no such Sequential exists.
        """
        inner = getattr(net_client, "client_model", net_client)
        feats = getattr(inner, "features", None)
        if isinstance(feats, nn.Sequential) and len(feats) > 0:
            return list(feats[-1].parameters())
        return []

    def train(self, net_client, epoch_idx=None, progressive_unfreeze=False):
        """Run ``local_ep`` local passes of split training on this client's shard.

        The client half's last stage trains at lr 1e-5 and the remaining trainable
        client parameters at ``self.lr``. When ``mu > 0`` a FedProx proximal term
        pulls the client half toward the weights it had at the start of the call.

        Returns:
            ``(client_state_dict, server_state_dict, avg_loss,
            [avg_batch_acc, 0, 0, 0, 0], n_train_samples)``.
        """
        net_client.train(); net_client.to(self.device)

        # The last client stage is matched by module identity. Parameter names inside
        # the client's nn.Sequential are index-based ("features.5.*" / "features.6.*"),
        # so substring checks for "layer2"/"denseblock2" never hit that stage.
        # "layer2" would instead hit unrelated DenseNet "denselayer2" parameters.
        last_stage = self._last_client_stage_params(net_client)
        last_stage_ids = {id(p) for p in last_stage}

        if progressive_unfreeze and epoch_idx is not None and epoch_idx >= 10:
            # NOTE(review): the client constructors only freeze the stem, so this
            # matters only if a caller froze this stage beforehand.
            for p in last_stage:
                p.requires_grad = True

        low_lr, high_lr = 1e-5, self.lr
        low_group, high_group = [], []
        for p in net_client.parameters():
            if not p.requires_grad:
                continue
            (low_group if id(p) in last_stage_ids else high_group).append(p)
        client_params = []
        if low_group:  client_params.append({"params": low_group, "lr": low_lr})
        if high_group: client_params.append({"params": high_group, "lr": high_lr})
        if not client_params:
            client_params = [{"params": [p for p in net_client.parameters() if p.requires_grad], "lr": self.lr}]

        optimizer_client = torch.optim.Adam(client_params, weight_decay=1e-4)
        # NOTE(review): this optimizer and scheduler are rebuilt on every call, so the
        # single step() at the end never lowers the client LR in later rounds; only
        # the persistent server scheduler anneals.
        scheduler_client = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_client,
                                                                      T_max=self.epochs_total, eta_min=1e-6)

        # Anchor for the FedProx proximal term.
        global_weights = copy.deepcopy(net_client.state_dict())

        batch_losses, batch_accs = [], []
        for _ in range(self.local_ep):
            for images, labels in self.ldr_train:
                images, labels = images.to(self.device), labels.to(self.device)
                optimizer_client.zero_grad()

                fx = net_client(images)
                # Cut point: the server sees a detached leaf and returns its gradient.
                client_fx = fx.clone().detach().requires_grad_(True)

                dfx, loss, acc = train_server(client_fx, labels, self.net_server_local_copy,
                                              self.optimizer_server_local, self.criterion, self.num_classes)

                fx.backward(dfx)

                if self.mu > 0.0:
                    proximal_term = 0.0
                    for name, param in net_client.named_parameters():
                        if param.requires_grad:
                            proximal_term += torch.sum((param - global_weights[name].to(self.device))**2)
                    ((self.mu/2) * proximal_term).backward()

                torch.nn.utils.clip_grad_norm_(net_client.parameters(), max_norm=1.0)
                optimizer_client.step()

                batch_losses.append(loss); batch_accs.append(float(acc))

        scheduler_client.step(); self.scheduler_server.step()

        avg_loss = sum(batch_losses)/len(batch_losses) if batch_losses else 0.0
        avg_acc  = sum(batch_accs)/len(batch_accs)    if batch_accs  else 0.0
        return net_client.state_dict(), self.net_server_local_copy.state_dict(), avg_loss, [avg_acc,0,0,0,0], len(self.train_dataset)


In [None]:
def _safe_list(x):
    return x if (isinstance(x, (list, tuple)) and len(x) > 0) else [0.0]

def _plot_train_val(history, train_key, val_key, train_label, val_label, ylabel, title, out_path):
    """Draw one train-vs-validation line chart from ``history`` and save it to ``out_path``."""
    train_vals = _safe_list(history[train_key])
    val_vals = _safe_list(history[val_key])
    fig, ax = plt.subplots(figsize=(10, 6))
    # Each series gets its own x range so unequal lengths (e.g. an empty validation
    # history plotted as a single 0.0) do not make matplotlib raise.
    ax.plot(range(1, len(train_vals) + 1), train_vals, marker='o', linewidth=2, label=train_label)
    ax.plot(range(1, len(val_vals) + 1), val_vals, marker='s', linewidth=2, label=val_label)
    ax.set_xlabel('Epoch', fontsize=13)
    ax.set_ylabel(ylabel, fontsize=13)
    ax.set_title(title, fontsize=15)
    ax.legend()
    ax.grid(True, linestyle='--', alpha=0.4)
    fig.tight_layout()
    fig.savefig(out_path, dpi=300, bbox_inches='tight')
    plt.close(fig)


def plot_history(history, num_users, save_dir="artifacts"):
    """Save train-vs-validation curves for accuracy, loss, recall and AUC as PNG files.

    Writes ``{save_dir}/{accuracy,loss,recall,auc}_{num_users}c.png``.
    ``history`` must provide ``train_*`` / ``test_*`` lists for acc, loss, recall
    and auc; an empty list is drawn as a single 0.0 point.
    """
    Path(save_dir).mkdir(parents=True, exist_ok=True)
    panels = [
        ("acc", "Accuracy", 'Train Accuracy', 'Validation Accuracy', "accuracy"),
        ("loss", "Loss", 'Train Loss', 'Validation Loss', "loss"),
        ("recall", "Recall", 'Train Recall (macro)', 'Validation Recall (macro)', "recall"),
        ("auc", "AUC", 'Train AUC (macro)', 'Validation AUC (macro)', "auc"),
    ]
    for key, metric, train_label, val_label, stem in panels:
        _plot_train_val(history, f"train_{key}", f"test_{key}", train_label, val_label,
                        ylabel=metric,
                        title=f'{metric} over Epochs ({num_users} Clients)',
                        out_path=f"{save_dir}/{stem}_{num_users}c.png")

def plot_confusion_matrix(cm, class_names, num_users, save_dir="artifacts", suffix=""):
    """Render ``cm`` as an annotated heatmap.

    The figure is saved to ``{save_dir}/confusion_matrix_{num_users}c{suffix}.png``.
    ``cm`` must hold integers, because cells are formatted with ``fmt='d'``.
    """
    out_dir = Path(save_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    fig = plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cbar=True, cmap='Blues',
                xticklabels=class_names, yticklabels=class_names)
    plt.xlabel('Predicted Label', fontsize=13)
    plt.ylabel('True Label', fontsize=13)
    plt.title(f'Confusion Matrix ({num_users} Clients){suffix}', fontsize=15)
    plt.tight_layout()
    plt.savefig(f"{save_dir}/confusion_matrix_{num_users}c{suffix}.png", dpi=300, bbox_inches='tight')
    plt.close(fig)

def summarize_distribution(dict_users, labels_all, classes, num_clients):
    """Per-client class counts for one federated setup, as a DataFrame.

    One row per client, sorted by client id. The columns are the class counts
    (one per class), then ``client`` ("C1", ...), ``setup`` ("<n> clients") and
    ``group`` ("<n>c-C1", ...).
    """
    def _row(client_id):
        members = list(dict_users[client_id])
        if members:
            per_class = np.bincount([int(labels_all[i]) for i in members],
                                    minlength=len(classes)).tolist()
            counts = dict(zip(classes, per_class))
        else:
            counts = {name: 0 for name in classes}
        tag = f"C{client_id+1}"
        return {**counts, "client": tag, "setup": f"{num_clients} clients",
                "group": f"{num_clients}c-{tag}"}

    return pd.DataFrame([_row(cid) for cid in sorted(dict_users)])

def plot_single_distribution_357(snapshots, labels_all, classes, save_path="artifacts/distribution_3_5_7_clients.png"):
    """
    Create ONE grouped bar chart with the class distribution of every client,
    for every federated setup stored in ``snapshots`` (normally 3, 5 and 7 clients).

    Args:
        snapshots: mapping ``num_clients -> dict_users`` (client id -> dataset indices).
        labels_all: integer class label for each dataset index.
        classes: ordered class names; position i names label i.
        save_path: destination PNG; missing parent directories are created.

    Returns:
        None. When ``snapshots`` is empty a warning is printed and nothing is drawn.
    """
    setups = sorted(snapshots.keys())
    if not setups:
        print("[WARN] No distribution snapshot available for plotting.")
        return

    df_all = pd.concat(
        [summarize_distribution(snapshots[k], labels_all, classes, k) for k in setups],
        ignore_index=True,
    )

    # Long format: one row per (setup, client, class) with its sample count.
    df_melt = df_all.melt(id_vars=["setup", "client", "group"],
                          value_vars=classes, var_name="Class", value_name="Samples")

    # X-axis order: 3c-C1..C3, 5c-C1..C5, 7c-C1..C7 — numeric, not lexicographic,
    # so C10 sorts after C9. Passed explicitly to seaborn, which would otherwise
    # follow the order in which groups happen to appear in the data.
    def _sort_key(group):
        left, right = group.split("-C")   # "3c-C1" -> ("3c", "1")
        return (int(left.replace("c", "").strip()), int(right))
    group_order = sorted(df_melt["group"].unique(), key=_sort_key)

    # Title reflects the setups actually present ("3 / 5 / 7" for the default sweep).
    setups_label = " / ".join(str(k) for k in setups)

    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    fig, ax = plt.subplots(figsize=(14, 6))
    try:
        sns.barplot(data=df_melt, x="group", y="Samples", hue="Class", dodge=True,
                    order=group_order, hue_order=list(classes), ax=ax)
        ax.set_title(f"Class distribution per client for {setups_label} clients")
        ax.set_xlabel("Setup–Client (e.g., 3c-C1 = 3 clients, client 1)")
        ax.set_ylabel("Number of samples")
        plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
        ax.grid(axis="y", linestyle="--", alpha=0.3)
        fig.tight_layout()
        fig.savefig(save_path, dpi=300, bbox_inches="tight")
    finally:
        plt.close(fig)
    print(f"[OK] Distribution chart saved at: {save_path}")

In [None]:
# ---- Experiment hyper-parameters ------------------------------------------
client_counts = [3, 5, 7]        # federation sizes to sweep
batch_size = 64
epochs = 5                       # global communication rounds
frac = 1.0                       # fraction of non-empty clients sampled per round
lr = 1e-4
weight_decay = 1e-4
mu = 0.01                        # FedProx proximal coefficient (0.0 => plain FedAvg)
progressive_unfreeze = True
local_ep_default = 1             # local epochs per client per round

# ---- Class weights from the global TRAIN split ----------------------------
# NOTE(review): iterating the dataset calls __getitem__ for every sample (i.e. the
# full load/transform pipeline) just to read labels; if labels_all[train_idx]
# holds the same labels it would be far cheaper — confirm before switching.
labels_train_global = [target for _sample, target in main_train_dataset]
present_classes = np.unique(labels_train_global)
num_classes = len(classes)

# 'balanced' weights for classes seen in TRAIN; absent classes keep weight 1.0.
full_weights = np.ones(num_classes, dtype=np.float32)
if present_classes.size:
    partial_weights = compute_class_weight(class_weight='balanced', classes=present_classes, y=labels_train_global)
    full_weights[present_classes.astype(int)] = partial_weights

criterion_global = nn.CrossEntropyLoss(
    weight=torch.tensor(full_weights, dtype=torch.float32, device=device),
    label_smoothing=0.05,
)

def _mean_or_zero(values):
    """Arithmetic mean of ``values``; 0.0 when no client reported a value."""
    return sum(values) / len(values) if values else 0.0


def _column_means(rows, width=5):
    """Column-wise mean of equal-length metric tuples; ``[0.0] * width`` when ``rows`` is empty."""
    if not rows:
        return [0.0] * width
    return [sum(col) / len(col) for col in zip(*rows)]


def _save_validation_cm(net_client, net_server, w_server, users_train, users_test,
                        client_ids, n_users, suffix):
    """Pool argmax predictions of the global split model over the test loaders of
    ``client_ids`` and save the resulting confusion matrix under ``artifacts/``.

    The train/eval mode of both sub-networks is restored afterwards, so a later
    ``copy.deepcopy(net_client)`` for local training does not silently inherit
    the eval mode set here.
    """
    was_training = (net_client.training, net_server.training)
    y_true_all, y_pred_all = [], []
    net_client.eval(); net_server.eval()
    try:
        with torch.no_grad():
            for idx in client_ids:
                # Client is instantiated here only to reuse its test DataLoader.
                local_eval = Client(
                    idx, lr, device, users_train[idx], users_test[idx],
                    w_server, weight_decay, batch_size, num_classes, criterion_global, mu,
                    use_balanced_batch=True, epochs_total=epochs
                )
                for images, labels in local_eval.ldr_test:
                    logits = net_server(net_client(images.to(device)))
                    y_pred_all.append(logits.argmax(dim=1).cpu().numpy())
                    y_true_all.append(labels.numpy())
    finally:
        net_client.train(was_training[0])
        net_server.train(was_training[1])
    if y_true_all:
        cm = confusion_matrix(np.concatenate(y_true_all), np.concatenate(y_pred_all),
                              labels=list(range(num_classes)))
        plot_confusion_matrix(cm, classes, n_users, save_dir="artifacts", suffix=suffix)


dict_users_snapshots = {}

for num_users in client_counts:
    print("\n" + "="*80)
    print(f"STARTING TRAINING FOR {num_users} CLIENTS")
    print("="*80 + "\n")

    # Partition train/test indices by patient across num_users clients.
    dict_users, eff_train = partition_by_patient_balanced_no_empty(train_idx, IDX_TO_PATIENT, labels_all, num_users, seed=SEED)
    dict_users_test, eff_test = partition_by_patient_balanced_no_empty(test_idx,  IDX_TO_PATIENT, labels_all, num_users, seed=SEED)

    # Retry with other seeds (at most 5 times) while any client's TRAIN split is empty.
    tries = 0
    while any(len(dict_users[c]) == 0 for c in range(num_users)) and tries < 5:
        tries += 1
        alt_seed = SEED + 100 * tries
        dict_users, eff_train = partition_by_patient_balanced_no_empty(train_idx, IDX_TO_PATIENT, labels_all, num_users, seed=alt_seed)
        dict_users_test, eff_test = partition_by_patient_balanced_no_empty(test_idx,  IDX_TO_PATIENT, labels_all, num_users, seed=alt_seed)

    train_clients = [c for c in range(num_users) if len(dict_users[c]) > 0]
    empty_clients = [c for c in range(num_users) if len(dict_users[c]) == 0]
    if empty_clients:
        print(f"[Warn] Empty clients (ignored): {empty_clients}")

    # Audit print
    for c in range(num_users):
        tr_dist = class_dist_for(list(dict_users[c]), labels_all, classes)
        te_dist = class_dist_for(list(dict_users_test[c]), labels_all, classes)
        print(f"[TRAIN] client {c}: n={len(dict_users[c])}  dist={tr_dist}")
        print(f"[TEST ] client {c}: n={len(dict_users_test[c])} dist={te_dist}")

    dict_users_snapshots[num_users] = copy.deepcopy(dict_users)

    if not train_clients:
        # np.random.choice below would raise on an empty population.
        print(f"[WARN] No client received training data for {num_users} clients; skipping this setup.")
        continue

    # Evaluate every round on the SAME clients (all trainable clients that hold
    # test data) so the validation loss used by early stopping is comparable
    # across rounds, even when frac < 1 samples a different subset each round.
    eval_clients = [c for c in train_clients if len(dict_users_test[c]) > 0]

    # Models (DenseNet169 split)
    net_glob_client = DenseNet169_client_side().to(device)
    net_glob_server = DenseNet169_server_side(num_classes=num_classes).to(device)
    w_glob_client = net_glob_client.state_dict()
    w_glob_server = net_glob_server.state_dict()

    history = {
        'train_loss':[], 'train_acc':[], 'test_loss':[], 'test_acc':[],
        'train_recall':[], 'test_recall':[], 'train_auc':[], 'test_auc':[]
    }

    checkpoint_path = f'best_model_{num_users}_clients.pt'
    early_stopping = EarlyStopping(patience=20, verbose=True, save_path=checkpoint_path)

    # ===== TRAIN LOOP =====
    for iter_epoch in range(epochs):
        start_time = time.time()
        m = max(int(frac * len(train_clients)), 1)
        idxs_users = np.random.choice(train_clients, m, replace=False)

        w_locals_client, w_locals_server, ns = [], [], []
        round_train_loss, round_train_metrics = [], []

        # STEP 1: LOCAL CLIENT TRAINING
        for idx in idxs_users:
            local = Client(idx, lr, device, dict_users[idx], dict_users_test[idx],
                           w_glob_server, weight_decay, batch_size, num_classes, criterion_global, mu,
                           use_balanced_batch=True, epochs_total=epochs, local_ep=local_ep_default)
            w_c, w_s, t_loss, t_metrics, n_i = local.train(net_client=copy.deepcopy(net_glob_client),
                                                           epoch_idx=iter_epoch,
                                                           progressive_unfreeze=progressive_unfreeze)
            w_locals_client.append(w_c)
            w_locals_server.append(w_s)
            ns.append(n_i)
            round_train_loss.append(t_loss)
            round_train_metrics.append(t_metrics)

        # STEP 2: SERVER AGGREGATION (weighted by the n_i each client reported)
        w_glob_client = FedAvg_weighted(w_locals_client, ns)
        w_glob_server = FedAvg_weighted(w_locals_server, ns)
        net_glob_client.load_state_dict(w_glob_client)
        net_glob_server.load_state_dict(w_glob_server)

        # STEP 3: GLOBAL EVALUATION on the fixed eval_clients set
        round_test_loss, round_test_metrics = [], []
        for idx in eval_clients:
            local = Client(idx, lr, device, dict_users[idx], dict_users_test[idx],
                           w_glob_server, weight_decay, batch_size, num_classes, criterion_global, mu,
                           use_balanced_batch=True, epochs_total=epochs)
            test_loss, test_metrics = local.evaluate(net_glob_client, net_glob_server)
            round_test_loss.append(test_loss)
            round_test_metrics.append(test_metrics)

        # History: unweighted mean over clients.
        # Assumes client metric tuples are ordered (acc, prec, rec, f1, auc), like
        # evaluate_accuracy's outputs — TODO confirm against Client.train/evaluate.
        avg_train_metrics = _column_means(round_train_metrics)
        avg_test_metrics = _column_means(round_test_metrics)

        history['train_loss'].append(_mean_or_zero(round_train_loss))
        history['train_acc'].append(avg_train_metrics[0])
        history['train_recall'].append(avg_train_metrics[2])
        history['train_auc'].append(avg_train_metrics[4])

        history['test_loss'].append(_mean_or_zero(round_test_loss))
        history['test_acc'].append(avg_test_metrics[0])
        history['test_recall'].append(avg_test_metrics[2])
        history['test_auc'].append(avg_test_metrics[4])

        # Periodic validation confusion matrix (first epoch and every CM_EVERY epochs)
        if ((iter_epoch + 1) % CM_EVERY == 0) or (iter_epoch == 0):
            _save_validation_cm(net_glob_client, net_glob_server, w_glob_server,
                                dict_users, dict_users_test, eval_clients, num_users,
                                suffix=f"_epoch{iter_epoch+1:03d}")

        print(
            f"Epoch {iter_epoch+1}/{epochs} | "
            f"Train Loss: {history['train_loss'][-1]:.4f} Acc: {history['train_acc'][-1]:.4f} | "
            f"Test  Loss: {history['test_loss'][-1]:.4f} "
            f"Acc: {history['test_acc'][-1]:.4f} "
            f"Rec: {history['test_recall'][-1]:.4f} "
            f"AUC: {history['test_auc'][-1]:.4f} | "
            f"Time: {time.time()-start_time:.2f}s"
        )

        # NOTE(review): the checkpoint is selected on loss computed from the client
        # TEST partitions (built from test_idx), and final metrics below are reported
        # on main_test_dataset — presumably the same images, which would bias the
        # reported test metrics upward. Consider a separate validation split.
        early_stopping(history['test_loss'][-1], (net_glob_client, net_glob_server))
        if early_stopping.early_stop:
            print("Early stopping triggered.")
            try:
                _save_validation_cm(net_glob_client, net_glob_server, w_glob_server,
                                    dict_users, dict_users_test, eval_clients, num_users,
                                    suffix="_last")
            except Exception as e:
                print(f"[WARN] Could not save last confusion matrix: {e}")
            break

    print(f"\nTraining for {num_users} clients completed!")
    plot_history(history, num_users, save_dir="artifacts")

    # Final global evaluation (image-level) with the best checkpoint saved by EarlyStopping.
    try:
        checkpoint = torch.load(checkpoint_path, map_location=device)
        net_glob_client.load_state_dict(checkpoint['net_glob_client_state_dict'])
        net_glob_server.load_state_dict(checkpoint['net_glob_server_state_dict'])
        final_model = CombinedModel(net_glob_client, net_glob_server).to(device)

        test_loader_global = DataLoader(main_test_dataset, batch_size=batch_size, shuffle=False)
        acc, prec, rec, f1, auc, cm = evaluate_accuracy(final_model, test_loader_global,
                                                        device, return_conf_matrix=True, num_classes=num_classes)

        print(f"\n--- Final Metrics (image-level, {num_users} Clients) ---")
        print(f"Accuracy: {acc:.4f} | Precision: {prec:.4f} | Recall: {rec:.4f} | F1: {f1:.4f} | AUC: {auc:.4f}")
        plot_confusion_matrix(cm, classes, num_users, save_dir="artifacts", suffix="_image")

        # Patient-level
        try:
            acc_p, prec_p, rec_p, f1_p, auc_p, cm_p = evaluate_by_patient(
                final_model,
                main_test_dataset,
                IDX_TO_PATIENT,
                batch_size=batch_size,
                device=device
            )
            print(f"--- Final Metrics (patient-level, {num_users} Clients) ---")
            print(f"Accuracy: {acc_p:.4f} | Precision: {prec_p:.4f} | Recall: {rec_p:.4f} | F1: {f1_p:.4f} | AUC: {auc_p:.4f}")
            plot_confusion_matrix(cm_p, classes, num_users, save_dir="artifacts", suffix="_patient")
        except Exception as e:
            print(f"[WARN] Patient-level evaluation failed: {e}")

    except FileNotFoundError:
        print(f"[WARN] Checkpoint not found: {checkpoint_path}")
    except Exception as e:
        print(f"[ERROR] Final evaluation failed: {e}")

# Single chart with the per-client class distribution of every trained setup.
# Best-effort: any failure (including missing names) is reported, not raised.
try:
    distribution_kwargs = dict(
        snapshots=dict_users_snapshots,
        labels_all=labels_all,
        classes=classes,
        save_path="artifacts/distribuicao_3_5_7_clientes.png",
    )
    plot_single_distribution_357(**distribution_kwargs)
except Exception as err:
    print(f"[WARN] Falha ao gerar o gráfico único de distribuição: {err}")


STARTING TRAINING FOR 7 CLIENTS

[TRAIN] client 0: n=61  dist={'AD': 13, 'CN': 17, 'MCI': 31}
[TEST ] client 0: n=19 dist={'AD': 1, 'CN': 8, 'MCI': 10}
[TRAIN] client 1: n=59  dist={'AD': 14, 'CN': 24, 'MCI': 21}
[TEST ] client 1: n=19 dist={'AD': 5, 'CN': 2, 'MCI': 12}
[TRAIN] client 2: n=68  dist={'AD': 8, 'CN': 22, 'MCI': 38}
[TEST ] client 2: n=26 dist={'AD': 5, 'CN': 2, 'MCI': 19}
[TRAIN] client 3: n=44  dist={'AD': 9, 'CN': 14, 'MCI': 21}
[TEST ] client 3: n=17 dist={'AD': 4, 'CN': 8, 'MCI': 5}
[TRAIN] client 4: n=92  dist={'AD': 17, 'CN': 31, 'MCI': 44}
[TEST ] client 4: n=18 dist={'AD': 1, 'CN': 7, 'MCI': 10}
[TRAIN] client 5: n=48  dist={'AD': 14, 'CN': 19, 'MCI': 15}
[TEST ] client 5: n=11 dist={'AD': 0, 'CN': 2, 'MCI': 9}
[TRAIN] client 6: n=56  dist={'AD': 8, 'CN': 25, 'MCI': 23}
[TEST ] client 6: n=11 dist={'AD': 0, 'CN': 2, 'MCI': 9}




Epoch 1/5 | Train Loss: 1.0922 Acc: 0.3915 | Test  Loss: 1.1658 Acc: 0.2758 Rec: 0.3911 AUC: 0.4565 | Time: 69.95s
Validation loss decreased (inf-->1.165845). Saving model to best_model_7_clients.pt




Epoch 2/5 | Train Loss: 1.0787 Acc: 0.4265 | Test  Loss: 1.1296 Acc: 0.3523 Rec: 0.3897 AUC: 0.4877 | Time: 24.06s
Validation loss decreased (1.165845-->1.129592). Saving model to best_model_7_clients.pt
Epoch 3/5 | Train Loss: 1.0883 Acc: 0.3952 | Test  Loss: 1.1082 Acc: 0.4046 Rec: 0.4610 AUC: 0.5084 | Time: 24.07s
Validation loss decreased (1.129592-->1.108166). Saving model to best_model_7_clients.pt
Epoch 4/5 | Train Loss: 1.0601 Acc: 0.4567 | Test  Loss: 1.0904 Acc: 0.4446 Rec: 0.4570 AUC: 0.5461 | Time: 24.53s
Validation loss decreased (1.108166-->1.090390). Saving model to best_model_7_clients.pt
Epoch 5/5 | Train Loss: 1.0419 Acc: 0.4963 | Test  Loss: 1.0712 Acc: 0.4902 Rec: 0.4126 AUC: 0.5619 | Time: 31.80s
Validation loss decreased (1.090390-->1.071166). Saving model to best_model_7_clients.pt

Training for 7 clients completed!

--- Final Metrics (image-level, 7 Clients) ---
Accuracy: 0.4793 | Precision: 0.4099 | Recall: 0.4189 | F1: 0.3960 | AUC: 0.6234
--- Final Metrics (p