<a href="https://colab.research.google.com/github/Luanmantegazine/FedAlzheimer/blob/main/FlowerAlzheimer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip -q install "flwr[simulation]>=1.20.0" "torchvision>=0.15" "torch>=2.0" "torchmetrics>=1.4"

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m617.6/617.6 kB[0m [31m37.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m110.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m81.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m64.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m5.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m15.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [2]:
import os, math, random
from contextlib import nullcontext
from collections import Counter, OrderedDict
from dataclasses import dataclass
from typing import List, Dict, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, Subset, WeightedRandomSampler
from torchvision import datasets, transforms
from torchvision.models import resnet18, ResNet18_Weights

import torchmetrics

import flwr as fl
from flwr.client import ClientApp
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
from flwr.server.strategy import FedAvgM
from flwr.simulation import run_simulation
from flwr.common import NDArrays, ndarrays_to_parameters, Context

from sklearn.model_selection import train_test_split


# Colab-only: mount Google Drive at /content/drive (the dataset path configured
# in `Config.data_path` lives under this mount point).
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
# Expose no CUDA devices to this process (the notebook is run on CPU).
os.environ["CUDA_VISIBLE_DEVICES"] = ""
# Remove the NVIDIA container variable if it exists; no error when it is absent.
os.environ.pop("NVIDIA_VISIBLE_DEVICES", None)
# TensorFlow GPU-memory setting.
# NOTE(review): TensorFlow is not imported in the visible cells — confirm this is still needed.
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"

Utils

In [4]:
class Config:
  """Central namespace for every tunable value used by the notebook.

  This is a plain class with annotated class attributes (it is not a
  dataclass): values are read as attributes of an instance, and the
  constructor takes no arguments.
  """

  # Root folder read by `datasets.ImageFolder` (one sub-folder per class).
  data_path: str = "/content/drive/MyDrive/TCC - Grupo SLD/Projeto 2/ADNI"
  # Side length (pixels) of the square images fed to the network.
  input_size: int = 224
  # When True, images go through `Grayscale(num_output_channels=3)`.
  grayscale3: bool = True

  # Federated learning
  num_clients: int = 3
  num_rounds: int = 10
  seed: int = 1234

  # Partitioning
  partition_strategy: str = "dirichlet"
  dirichlet_alpha: float = 0.5

  # Local training
  batch_size: int = 32
  local_epochs: int = 2
  learning_rate_head: float = 5e-4
  learning_rate_backbone: float = 1e-4
  weight_decay: float = 1e-4
  optimizer: str = "sgd"  # "sgd" or "adamw"
  momentum: float = 0.9
  grad_clip_norm: float = 1.0

  # Gradual fine-tuning
  head_only_first_rounds: int = 1  # initial rounds that train only the head
  dropout: float = 0.3
  label_smoothing: float = 0.05

  # Augmentation
  use_color_jitter: bool = True
  mixup_alpha: float = 0.2  # 0.0 disables MixUp

  # Loss
  use_focal_loss: bool = False
  focal_gamma: float = 2.0

  # Scheduler
  use_cosine_warmup: bool = True
  warmup_epochs: int = 2

  # EMA
  use_ema: bool = True
  ema_decay: float = 0.999

  # TTA (test-time augmentation)
  tta_hflip: bool = True

  # Server strategy
  server_strategy: str = "fedavgm"  # "fedavgm" or "fedadam"
  # FedAvgM
  # NOTE(review): 1e-3 looks very small for a server-side rate applied to the
  # aggregated client update — confirm this is intended.
  server_learning_rate: float = 1e-3
  server_momentum: float = 0.9

cfg = Config()

In [5]:
def seed_everything(seed: int):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch with the same value.

    Args:
        seed: Integer seed passed unchanged to each of the three libraries.
    """
    # Same call order as before: random, then NumPy, then PyTorch.
    for apply_seed in (random.seed, np.random.seed, torch.manual_seed):
        apply_seed(seed)

def get_device() -> torch.device:
    """Return the CPU device, clearing ``CUDA_VISIBLE_DEVICES`` as a side effect.

    Returns:
        ``torch.device("cpu")`` — this notebook always runs on CPU.
    """
    os.environ.update({"CUDA_VISIBLE_DEVICES": ""})
    cpu_device = torch.device("cpu")
    return cpu_device

def build_transforms():
    """Build the training and evaluation image pipelines from ``cfg``.

    Returns:
        ``(train_transform, test_transform)`` — two ``transforms.Compose`` objects.
        The training pipeline adds random crop / rotation / affine / blur (and
        optionally colour jitter); the test pipeline only resizes, optionally
        converts to 3-channel grayscale, converts to tensor and normalizes.
    """
    def to_normalized_tensor():
        # Tail shared by both pipelines; a fresh list per call.
        return [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]

    # ---- training pipeline ----
    upscale_side = max(256, cfg.input_size)
    train_steps = [
        transforms.Resize((upscale_side, upscale_side)),
        transforms.RandomResizedCrop(cfg.input_size, scale=(0.8, 1.0)),
    ]
    rotation = transforms.RandomRotation(10)
    if cfg.grayscale3:
        # Grayscale (then optional jitter) comes before the rotation.
        train_steps.append(transforms.Grayscale(num_output_channels=3))
        if cfg.use_color_jitter:
            train_steps.append(transforms.ColorJitter(brightness=0.2, contrast=0.2))
        train_steps.append(rotation)
    else:
        # Without grayscale, the jitter ends up right after the rotation.
        train_steps.append(rotation)
        if cfg.use_color_jitter:
            train_steps.append(transforms.ColorJitter(brightness=0.2, contrast=0.2))
    train_steps.extend([
        transforms.RandomAffine(degrees=0, translate=(0.05, 0.05), scale=(0.95, 1.05)),
        transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 1.0)),
    ])
    train_steps.extend(to_normalized_tensor())

    # ---- evaluation pipeline ----
    test_steps = [transforms.Resize((cfg.input_size, cfg.input_size))]
    if cfg.grayscale3:
        test_steps.append(transforms.Grayscale(num_output_channels=3))
    test_steps.extend(to_normalized_tensor())

    return transforms.Compose(train_steps), transforms.Compose(test_steps)

def partition_dirichlet(labels_all: List[int], idxs: np.ndarray, num_clients: int,
                        alpha: float, seed: int, min_size: int = 10,
                        max_attempts: int = 10_000):
    """Split ``idxs`` across clients with a per-class Dirichlet(alpha) allocation.

    For every class present in ``idxs`` the class's indices are shuffled and cut
    into ``num_clients`` contiguous chunks whose sizes follow a Dirichlet draw.
    The whole allocation is redrawn until every client holds at least
    ``min_size`` samples.

    Args:
        labels_all: Label of every sample in the full dataset, indexed by the
            values stored in ``idxs``.
        idxs: 1-D collection of dataset indices to distribute (array or list).
        num_clients: Number of partitions to create (must be >= 1).
        alpha: Dirichlet concentration; smaller values give more skewed splits.
        seed: Seed of the private ``RandomState`` used for shuffles and draws.
        min_size: Minimum number of samples each client must receive.
        max_attempts: Upper bound on redraws before giving up.

    Returns:
        Dict mapping client id (``0 .. num_clients - 1``) to an integer array
        of the dataset indices assigned to that client.

    Raises:
        ValueError: If ``num_clients < 1`` or there are not enough samples to
            give every client ``min_size`` of them.
        RuntimeError: If no valid allocation is found within ``max_attempts``.
    """
    if num_clients < 1:
        raise ValueError(f"num_clients must be >= 1, got {num_clients}")
    idxs = np.asarray(idxs)  # accept plain lists as well as arrays
    if idxs.size < num_clients * min_size:
        # The original retry loop would spin forever in this situation.
        raise ValueError(
            f"cannot give {num_clients} clients at least {min_size} samples each "
            f"from only {idxs.size} samples"
        )

    rng = np.random.RandomState(seed)
    labels = np.array([labels_all[i] for i in idxs])
    classes = np.unique(labels)

    for _ in range(max(1, max_attempts)):
        client_indices = {i: [] for i in range(num_clients)}
        for c in classes:
            idx_c = idxs[labels == c]  # boolean indexing returns a copy, safe to shuffle
            rng.shuffle(idx_c)
            proportions = rng.dirichlet(alpha=np.repeat(alpha, num_clients))
            proportions = proportions / proportions.sum()
            # Cumulative cut points; the last one (== len) is implied by np.split.
            splits = (np.cumsum(proportions) * len(idx_c)).astype(int)[:-1]
            for i, chunk in enumerate(np.split(idx_c, splits)):
                client_indices[i].extend(chunk.tolist())
        # Always draw at least once, even when min_size <= 0.
        if min(len(v) for v in client_indices.values()) >= min_size:
            return {i: np.array(v, dtype=int) for i, v in client_indices.items()}

    raise RuntimeError(
        f"no Dirichlet partition with min_size={min_size} found in {max_attempts} attempts; "
        "increase alpha or lower min_size"
    )

def make_weights_for_balanced_classes(indices: List[int], imagefolder: datasets.ImageFolder, num_classes: int):
    """Compute one sampling weight per sample so that classes are drawn evenly.

    Args:
        indices: Positions into ``imagefolder.samples`` that form the subset.
        imagefolder: Dataset whose ``samples`` entries are ``(path, label)`` pairs.
        num_classes: Total number of classes; labels are looked up in ``range(num_classes)``.

    Returns:
        A double-precision tensor with one weight per entry of ``indices``;
        each weight is ``len(subset) / (num_classes * count_of_its_class)``.
    """
    targets = [imagefolder.samples[i][1] for i in indices]
    per_class = Counter(targets)
    n_total = len(targets)

    weight_of = {}
    for cls in range(num_classes):
        n_cls = per_class[cls]
        # Classes absent from this subset get weight 0 (and are never looked up below).
        weight_of[cls] = n_total / (num_classes * n_cls) if n_cls > 0 else 0

    return torch.DoubleTensor([weight_of[t] for t in targets])

def mixup(x, y, alpha=0.2):
    """Apply MixUp to a batch.

    Args:
        x: Input batch tensor; mixing happens along dimension 0.
        y: Targets aligned with ``x`` along dimension 0.
        alpha: Beta-distribution parameter; ``None`` or a value ``<= 0`` disables mixing.

    Returns:
        ``(mixed_x, (y_a, y_b), lam)`` where ``mixed_x = lam * x + (1 - lam) * x[perm]``,
        ``y_a`` is ``y`` and ``y_b`` is ``y[perm]``. When mixing is disabled the
        inputs are returned untouched with ``lam == 1.0``.
    """
    if alpha is None or alpha <= 0:
        return x, (y, y), 1.0

    # One mixing coefficient for the whole batch, then a random pairing of samples.
    lam = float(np.random.beta(alpha, alpha))
    perm = torch.randperm(x.size(0), device=x.device)
    x_partner = x[perm]
    y_partner = y[perm]
    blended = lam * x + (1 - lam) * x_partner
    return blended, (y, y_partner), lam

class FocalLoss(nn.Module):
    """Multi-class focal loss: ``(1 - p_t) ** gamma * CE(logits, target)``.

    Args:
        gamma: Focusing exponent; ``0`` reduces to plain cross-entropy.
        alpha: Optional per-class weight tensor forwarded to ``nn.CrossEntropyLoss``.
        reduction: ``"mean"``, ``"sum"`` or ``"none"`` (per-sample losses).

    Raises:
        ValueError: If ``reduction`` is not one of the supported values.
    """

    _REDUCTIONS = ("mean", "sum", "none")

    def __init__(self, gamma=2.0, alpha=None, reduction="mean"):
        super().__init__()
        if reduction not in self._REDUCTIONS:
            # Previously any value other than "mean" silently behaved like "sum".
            raise ValueError(f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}")
        self.gamma, self.alpha, self.reduction = gamma, alpha, reduction
        self.ce = nn.CrossEntropyLoss(weight=alpha, reduction="none")

    def forward(self, logits, target):
        """Return the focal loss for ``logits`` of shape [N, C] and integer ``target`` of shape [N]."""
        ce = self.ce(logits, target)  # [N]
        # Probability assigned to the true class of each sample.
        pt = torch.softmax(logits, dim=1).gather(1, target.view(-1, 1)).squeeze(1)
        loss = (1 - pt).pow(self.gamma) * ce
        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        return loss

def build_model(num_classes: int):
    """Create a ResNet-18 with default pretrained weights and a new classification head.

    Args:
        num_classes: Number of output logits of the replacement head.

    Returns:
        The model, whose ``fc`` layer is ``Dropout(cfg.dropout)`` followed by a
        ``Linear`` layer mapping the backbone features to ``num_classes``.
    """
    net = resnet18(weights=ResNet18_Weights.DEFAULT)
    feature_dim = net.fc.in_features
    dropout = nn.Dropout(p=cfg.dropout)
    classifier = nn.Linear(feature_dim, num_classes)
    net.fc = nn.Sequential(dropout, classifier)
    return net

def get_params(model: nn.Module):
    """Export every ``state_dict`` entry of ``model`` as a NumPy array.

    Returns:
        List of arrays in ``state_dict`` order (parameters and buffers alike).
    """
    arrays = []
    for tensor in model.state_dict().values():
        arrays.append(tensor.detach().cpu().numpy())
    return arrays

def set_params(model: nn.Module, params: NDArrays):
    """Load a list of arrays into ``model``, pairing them with ``state_dict`` keys in order.

    Args:
        model: Module whose state is overwritten.
        params: Arrays in the same order as ``model.state_dict()``.
    """
    entry_names = model.state_dict().keys()
    restored = OrderedDict()
    for name, array in zip(entry_names, params):
        restored[name] = torch.tensor(array)
    # strict=True: a missing or unexpected key raises instead of being ignored.
    model.load_state_dict(restored, strict=True)

In [6]:
@dataclass
class Metrics:
    """Scalar evaluation results returned by ``evaluate_model``."""
    loss: float       # cross-entropy loss on the evaluated data
    accuracy: float   # multiclass accuracy
    precision: float  # macro-averaged precision
    recall: float     # macro-averaged recall
    f1: float         # macro-averaged F1 score
    auc: float        # multiclass AUROC; 0.0 when it could not be computed

def evaluate_model(model: nn.Module, loader: DataLoader, device: torch.device, num_classes: int) -> Metrics:
    """Run ``model`` over ``loader`` and compute classification metrics.

    When ``cfg.tta_hflip`` is set, the logits of each batch are averaged with
    the logits of the same batch flipped along dimension 3 (the last axis of a
    4-D image batch).

    Args:
        model: Network to evaluate; it is switched to eval mode.
        loader: Yields ``(inputs, integer_labels)`` batches.
        device: Device the batches are moved to before the forward pass.
        num_classes: Number of classes, forwarded to the torchmetrics functions.

    Returns:
        A ``Metrics`` instance. All fields are 0.0 when the loader yields no
        samples; ``auc`` is 0.0 when AUROC cannot be computed.
    """
    model.eval()
    logits_list, labels_list = [], []
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            out_logits = model(x)
            if cfg.tta_hflip:
                out_logits = (out_logits + model(torch.flip(x, dims=[3]))) / 2.0
            logits_list.append(out_logits.cpu())
            labels_list.append(y.cpu())

    logits = torch.cat(logits_list) if logits_list else torch.zeros((0, num_classes))
    labels = torch.cat(labels_list) if labels_list else torch.zeros((0,), dtype=torch.long)
    if logits.shape[0] == 0:
        return Metrics(0.0, 0.0, 0.0, 0.0, 0.0, 0.0)

    # Loss over all samples at once: averaging per-batch means would overweight
    # a smaller final batch.
    mean_loss = float(F.cross_entropy(logits, labels).item())

    preds = logits.argmax(dim=1)
    prob = F.softmax(logits, dim=1)

    try:
        auc_score = torchmetrics.functional.auroc(prob, labels, task="multiclass", num_classes=num_classes).item()
    except Exception as exc:
        # Best effort: keep evaluating, but do not hide why AUROC is missing.
        import warnings
        warnings.warn(f"AUROC could not be computed, reporting 0.0: {exc}")
        auc_score = 0.0

    return Metrics(
        loss=mean_loss,
        accuracy=torchmetrics.functional.accuracy(preds, labels, task="multiclass", num_classes=num_classes).item(),
        precision=torchmetrics.functional.precision(preds, labels, average="macro", task="multiclass", num_classes=num_classes, zero_division=0).item(),
        recall=torchmetrics.functional.recall(preds, labels, average="macro", task="multiclass", num_classes=num_classes, zero_division=0).item(),
        f1=torchmetrics.functional.f1_score(preds, labels, average="macro", task="multiclass", num_classes=num_classes, zero_division=0).item(),
        auc=auc_score,
    )

class EMA:
    """Exponential moving average of a model's trainable parameters.

    ``shadow`` holds the averaged copy of every parameter that required grad at
    construction time; ``backup`` temporarily holds the live weights while the
    shadow is swapped in via ``apply_shadow``.
    """

    def __init__(self, model: nn.Module, decay: float = 0.999):
        self.decay = decay
        self.shadow = {}
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.detach().clone()
        self.backup = {}

    @staticmethod
    def _trainable_in(model: nn.Module, store: dict):
        """Yield ``(name, param)`` for params that are in ``store`` and currently require grad."""
        for name, param in model.named_parameters():
            if name in store and param.requires_grad:
                yield name, param

    def update(self, model: nn.Module):
        """Blend the current weights into the shadow: ``s = decay * s + (1 - decay) * p``."""
        blend = 1.0 - self.decay
        for name, param in self._trainable_in(model, self.shadow):
            self.shadow[name].mul_(self.decay).add_(param.detach(), alpha=blend)

    def apply_shadow(self, model: nn.Module):
        """Save the live weights into ``backup`` and overwrite them with the shadow."""
        self.backup = {}
        for name, param in self._trainable_in(model, self.shadow):
            self.backup[name] = param.detach().clone()
            param.data.copy_(self.shadow[name].data)

    def restore(self, model: nn.Module):
        """Put back the weights saved by ``apply_shadow`` and clear ``backup``."""
        for name, param in self._trainable_in(model, self.backup):
            param.data.copy_(self.backup[name].data)
        self.backup = {}


In [7]:
# Reproducibility, device and image pipelines.
seed_everything(cfg.seed)
device = get_device()
train_tf, test_tf = build_transforms()

# No transform on the ImageFolder itself: the train/test transforms are applied
# later through TransformingSubset wrappers.
full_dataset = datasets.ImageFolder(root=cfg.data_path)
labels_all = [lbl for _, lbl in full_dataset.samples]
num_classes = len(full_dataset.classes)

# Stratified 80/20 split over dataset indices (train_test_split is imported in
# the imports cell; the duplicate in-cell import was removed).
idx_all = np.arange(len(full_dataset))
idx_train, idx_test = train_test_split(
    idx_all, test_size=0.2, stratify=labels_all, random_state=cfg.seed
)

class TransformingSubset(Dataset):
    """Dataset wrapper that applies ``transform`` to the input of each ``(x, y)`` item.

    Args:
        subset: Any indexable object with ``__len__`` whose items are ``(x, y)`` pairs.
        transform: Callable applied to ``x``; a falsy value leaves ``x`` unchanged.
    """

    def __init__(self, subset, transform):
        self.subset, self.transform = subset, transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, idx):
        sample, target = self.subset[idx]
        if not self.transform:
            return sample, target
        return self.transform(sample), target

# Held-out test split wrapped with the evaluation transform.
test_dataset: Dataset = TransformingSubset(Subset(full_dataset, idx_test), test_tf)

# Dirichlet partitioning of the training split across the clients
parts_train = partition_dirichlet(labels_all, idx_train, cfg.num_clients, cfg.dirichlet_alpha, cfg.seed)


In [8]:
# This cell used to be a verbatim copy of the previous one (it re-read the
# dataset, redefined TransformingSubset and recomputed the same split and
# partition). Everything it needs already exists, so it now only validates
# the split/partition and prints a per-client summary.
assert len(idx_train) + len(idx_test) == len(full_dataset), "train/test split does not cover the dataset"
assert np.intersect1d(idx_train, idx_test).size == 0, "train and test indices overlap"
assert len(parts_train) == cfg.num_clients, "unexpected number of client partitions"

assigned = np.concatenate([parts_train[client_id] for client_id in range(cfg.num_clients)])
assert np.array_equal(np.sort(assigned), np.sort(idx_train)), "partitions must cover the training split exactly once"

for client_id in range(cfg.num_clients):
    class_counts = Counter(labels_all[i] for i in parts_train[client_id])
    print(f"client {client_id}: {len(parts_train[client_id])} samples, class counts = {dict(sorted(class_counts.items()))}")


In [11]:
class AlzheimerClient(fl.client.NumPyClient):
    """Flower ``NumPyClient`` that fine-tunes a ResNet-18 on one client's partition.

    The client owns a class-balanced training loader over ``train_idx``, an
    evaluation loader over ``test_idx``, the model, the loss, an optimizer with
    separate learning rates for the ``fc`` head and the backbone, an optional
    warmup + cosine LR schedule, and an optional EMA of the trainable weights.
    """

    def __init__(self, cid: int, full_dataset, train_idx, test_idx, num_classes,
                 lr_head, lr_backbone, batch_size, local_epochs, device):
        """Build loaders, model, loss, optimizer, scheduler and EMA.

        Args:
            cid: Client/partition id (stored for reference only).
            full_dataset: Untransformed dataset indexed by ``train_idx`` / ``test_idx``.
            train_idx: Indices of this client's training samples.
            test_idx: Indices used for this client's evaluation.
            num_classes: Number of output classes.
            lr_head: Learning rate of the ``fc`` head parameters.
            lr_backbone: Learning rate of all other parameters.
            batch_size: Batch size of both loaders.
            local_epochs: Default number of local epochs per ``fit`` call.
            device: Device the model and batches are placed on.
        """
        self.cid = cid
        self.num_classes = num_classes
        self.device = device
        self.local_epochs = local_epochs
        self.batch_size = batch_size

        sub_train = TransformingSubset(Subset(full_dataset, train_idx), train_tf)
        sub_test  = TransformingSubset(Subset(full_dataset, test_idx),  test_tf)

        # Class-balanced sampling (with replacement) when the labels are available
        # through an ImageFolder and the partition is not empty.
        sampler = None
        base_subset = sub_train.subset
        if isinstance(base_subset, Subset) and isinstance(base_subset.dataset, datasets.ImageFolder) and len(base_subset.indices) > 0:
            weights = make_weights_for_balanced_classes(base_subset.indices, base_subset.dataset, num_classes)
            sampler = WeightedRandomSampler(weights, num_samples=len(base_subset.indices), replacement=True)

        # `shuffle` and `sampler` are mutually exclusive in DataLoader.
        self.train_loader = DataLoader(sub_train, batch_size=batch_size, shuffle=(sampler is None),
                                       sampler=sampler, num_workers=0, persistent_workers=False, pin_memory=False)
        self.test_loader  = DataLoader(sub_test,  batch_size=batch_size, shuffle=False,
                                       num_workers=0, persistent_workers=False, pin_memory=False)

        self.model = build_model(num_classes).to(self.device)

        if cfg.use_focal_loss:
            self.criterion = FocalLoss(gamma=cfg.focal_gamma).to(self.device)
        else:
            self.criterion = nn.CrossEntropyLoss(label_smoothing=cfg.label_smoothing).to(self.device)

        # Two parameter groups: the new `fc` head and the pretrained backbone.
        head_params, backbone_params = [], []
        for n, p in self.model.named_parameters():
            (head_params if n.startswith("fc.") else backbone_params).append(p)

        if cfg.optimizer.lower() == "sgd":
            self.optimizer = torch.optim.SGD([
                {"params": backbone_params, "lr": lr_backbone, "weight_decay": cfg.weight_decay, "momentum": cfg.momentum},
                {"params": head_params,     "lr": lr_head,     "weight_decay": cfg.weight_decay, "momentum": cfg.momentum},
            ])
        else:
            self.optimizer = torch.optim.AdamW([
                {"params": backbone_params, "lr": lr_backbone, "weight_decay": cfg.weight_decay},
                {"params": head_params,     "lr": lr_head,     "weight_decay": cfg.weight_decay},
            ])

        if cfg.use_cosine_warmup:
            def lr_lambda(ep):
                # Linear warmup for `warmup_epochs`, then cosine decay over the rest.
                if ep < cfg.warmup_epochs:
                    return (ep + 1) / max(1, cfg.warmup_epochs)
                t = (ep - cfg.warmup_epochs) / max(1, self.local_epochs - cfg.warmup_epochs)
                return 0.5 * (1 + math.cos(math.pi * t))
            self.scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_lambda)
        else:
            self.scheduler = None

        self.ema = EMA(self.model, decay=cfg.ema_decay) if cfg.use_ema else None

    def _freeze_backbone(self, freeze: bool = True):
        """Enable/disable gradients for every parameter outside the ``fc`` head."""
        for n, p in self.model.named_parameters():
            if not n.startswith("fc."):
                p.requires_grad = not freeze

    def get_parameters(self, config):
        """Return the current model state as a list of NumPy arrays."""
        return get_params(self.model)

    def fit(self, parameters, config):
        """Train locally starting from the received global parameters.

        Args:
            parameters: Global model state as a list of arrays.
            config: Optional dict; ``"round"`` selects head-only training for the
                first ``cfg.head_only_first_rounds`` rounds and ``"local_epochs"``
                overrides the default number of epochs.

        Returns:
            ``(updated_parameters, num_train_examples, metrics)`` where metrics
            holds ``accuracy_local`` and ``f1_local`` measured with the EMA
            weights (when EMA is enabled) on this client's evaluation loader.
        """
        set_params(self.model, parameters)

        current_round = int(config.get("round", 1)) if config else 1
        head_only = current_round <= cfg.head_only_first_rounds
        self._freeze_backbone(head_only)

        # Re-anchor the EMA to the weights just received. Without this the shadow
        # still holds the construction-time weights, which the few local steps
        # cannot overcome with a decay close to 1.
        if self.ema is not None:
            self.ema = EMA(self.model, decay=cfg.ema_decay)

        epochs = int(config.get("local_epochs", self.local_epochs)) if config else self.local_epochs
        amp_ctx = nullcontext()  # CPU: no autocast

        self.model.train()
        for ep in range(epochs):
            for x, y in self.train_loader:
                x, y = x.to(self.device), y.to(self.device)
                self.optimizer.zero_grad(set_to_none=True)

                if cfg.mixup_alpha and cfg.mixup_alpha > 0:
                    x, (y_a, y_b), lam = mixup(x, y, alpha=cfg.mixup_alpha)
                    with amp_ctx:
                        logits = self.model(x)
                        # MixUp loss: convex combination of the two targets' losses.
                        loss = lam * self.criterion(logits, y_a) + (1 - lam) * self.criterion(logits, y_b)
                else:
                    with amp_ctx:
                        logits = self.model(x)
                        loss = self.criterion(logits, y)

                loss.backward()
                if cfg.grad_clip_norm and cfg.grad_clip_norm > 0:
                    nn.utils.clip_grad_norm_(self.model.parameters(), cfg.grad_clip_norm)
                self.optimizer.step()

                if self.ema is not None:
                    self.ema.update(self.model)

            if self.scheduler:
                self.scheduler.step()

        # Local metrics use the EMA weights; the raw weights are restored before
        # they are sent back to the server.
        if self.ema is not None:
            self.ema.apply_shadow(self.model)
        m = evaluate_model(self.model, self.test_loader, self.device, self.num_classes)
        if self.ema is not None:
            self.ema.restore(self.model)

        return get_params(self.model), len(self.train_loader.dataset), {
            "accuracy_local": m.accuracy, "f1_local": m.f1
        }

    def evaluate(self, parameters, config):
        """Evaluate the received global parameters on this client's evaluation loader.

        The EMA shadow is intentionally not applied here: it tracks this
        client's local weights, and swapping it in would replace the global
        parameters being evaluated.

        Returns:
            ``(loss, num_examples, {"accuracy": ..., "f1": ...})``.
        """
        set_params(self.model, parameters)
        m = evaluate_model(self.model, self.test_loader, self.device, self.num_classes)
        return m.loss, len(self.test_loader.dataset), {"accuracy": m.accuracy, "f1": m.f1}



In [12]:
def get_evaluate_fn(test_dataset: Dataset, batch_size: int, num_classes: int, device: torch.device):
    """Create the server-side (centralized) evaluation callback for the strategy.

    Args:
        test_dataset: Dataset evaluated after every round.
        batch_size: Batch size of the evaluation loader.
        num_classes: Number of classes of the model.
        device: Device used for evaluation.

    Returns:
        ``evaluate(server_round, parameters, config_server)`` returning
        ``(loss, metrics_dict)`` with accuracy, precision, recall, f1 and auc.
    """
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
    # Built once and reused: `set_params` overwrites the complete state dict on
    # every call, so re-creating the pretrained network each round is wasted work.
    model = build_model(num_classes).to(device)

    def evaluate(server_round: int, parameters: NDArrays, config_server: Dict[str, float]):
        set_params(model, parameters)
        m = evaluate_model(model, test_loader, device, num_classes)
        return m.loss, {"accuracy": m.accuracy, "precision": m.precision, "recall": m.recall, "f1": m.f1, "auc": m.auc}
    return evaluate

def weighted_average(client_metrics: List[Tuple[int, Dict[str, float]]]) -> Dict[str, float]:
    """Aggregate per-client metrics, weighting each client by its example count.

    Args:
        client_metrics: ``(num_examples, metrics_dict)`` pairs, one per client.

    Returns:
        Dict with the example-weighted mean of every metric key seen on any
        client (a client missing a key contributes 0.0 for it). Returns ``{}``
        for empty input, and all-zero values for the first client's keys when
        the total example count is 0.
    """
    if not client_metrics:
        return {}

    total = 0
    for num_examples, _ in client_metrics:
        total += num_examples
    if total == 0:
        return dict.fromkeys(client_metrics[0][1].keys(), 0.0)

    all_keys = set().union(*(m.keys() for _, m in client_metrics))
    averaged = {}
    for key in all_keys:
        weighted_sum = 0
        for num_examples, metrics in client_metrics:
            weighted_sum += num_examples * metrics.get(key, 0.0)
        averaged[key] = weighted_sum / total
    return averaged


In [13]:
def client_fn(context: Context):
    """Build the Flower client for the simulated node described by ``context``.

    The partition is taken from ``context.node_config["partition-id"]`` when the
    simulation provides it, so each node gets a distinct partition. The previous
    mapping (``node_id % num_clients``) is kept only as a fallback, because
    node ids are arbitrary integers and several nodes can collide on the same
    partition.

    Returns:
        The ``AlzheimerClient`` for that partition, converted with ``to_client()``.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = ""  # enforce CPU inside the simulation actors
    try:
        torch.set_num_threads(1)
    except Exception:
        # Best effort only: thread tuning must never prevent the client from starting.
        pass

    node_config = getattr(context, "node_config", None) or {}
    partition_id = node_config.get("partition-id")
    if partition_id is None:
        partition_id = context.node_id
    cid = int(partition_id) % cfg.num_clients

    train_idx = parts_train[cid]
    return AlzheimerClient(
        cid=cid,
        full_dataset=full_dataset,
        train_idx=train_idx,
        test_idx=idx_test,
        num_classes=num_classes,
        lr_head=cfg.learning_rate_head,
        lr_backbone=cfg.learning_rate_backbone,
        batch_size=cfg.batch_size,
        local_epochs=cfg.local_epochs,
        device=device,
    ).to_client()

client_app = ClientApp(client_fn)

# ====== ServerApp ======
def server_fn(context: Context) -> ServerAppComponents:
    """Create the server components: round configuration and the FedAvgM strategy.

    The strategy samples every client for both fit and evaluate, evaluates the
    global model centrally after each round, sends the round number and local
    epoch count to the clients, and aggregates both fit and evaluate metrics
    with ``weighted_average``.

    Returns:
        ``ServerAppComponents`` holding the strategy and the server config.
    """
    if cfg.server_strategy.lower() != "fedavgm":
        # Only FedAvgM is implemented below; make the mismatch visible instead of silent.
        import warnings
        warnings.warn(
            f"server_strategy={cfg.server_strategy!r} is not implemented; using FedAvgM"
        )

    server_cfg = ServerConfig(num_rounds=cfg.num_rounds)
    initial_parameters = ndarrays_to_parameters(get_params(build_model(num_classes)))
    strategy = FedAvgM(
        fraction_fit=1.0,
        fraction_evaluate=1.0,
        min_available_clients=cfg.num_clients,
        evaluate_fn=get_evaluate_fn(test_dataset, cfg.batch_size, num_classes, device),
        on_fit_config_fn=lambda rnd: {"local_epochs": cfg.local_epochs, "round": rnd},
        # Clients return accuracy_local / f1_local from fit(); aggregate them too.
        fit_metrics_aggregation_fn=weighted_average,
        evaluate_metrics_aggregation_fn=weighted_average,
        server_learning_rate=cfg.server_learning_rate,
        server_momentum=cfg.server_momentum,
        initial_parameters=initial_parameters,
    )
    # Version compatibility: try the `config` keyword first and fall back to the
    # alternative `server_config` spelling.
    try:
        return ServerAppComponents(strategy=strategy, config=server_cfg)
    except TypeError:
        return ServerAppComponents(strategy=strategy, server_config=server_cfg)

server_app = ServerApp(server_fn=server_fn)


In [None]:
# Launch the federated simulation with one simulated node per client.
print("Apps prontos. Iniciando simulação...")
# NOTE(review): `run_simulation` appears to return None in recent Flower
# releases, in which case `history` carries no results — confirm before using it.
history = run_simulation(
    server_app=server_app,
    client_app=client_app,
    num_supernodes=cfg.num_clients,
    # Each simulated client is given 1 CPU and no GPU.
    backend_config={"client_resources": {"num_cpus": 1, "num_gpus": 0.0}},
)
print("Simulação concluída.")

DEBUG:flwr:Asyncio event loop already running.


Apps prontos. Iniciando simulação...


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 205MB/s]
[92mINFO [0m:      Starting Flower ServerApp, config: num_rounds=10, no round_timeout
[92mINFO [0m:      
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Using initial global parameters provided by strategy
[92mINFO [0m:      Starting evaluation of initial global parameters
  img = Image.fromarray(np_img, "RGB")
[36m(pid=1665)[0m 2025-08-13 17:21:49.907695: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
[36m(pid=1665)[0m E0000 00:00:1755105709.973253    1665 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
[36m(pid=1665)[0m E0000 00:00:1755105709.992572    1665 cuda_blas.c