<a href="https://colab.research.google.com/github/Luanmantegazine/FedAlzheimer/blob/main/TTF(FedAvg).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [39]:
!pip install "tensorflow>=2.12" "tensorflow-federated>=0.71.0"



In [40]:
!pip install tensorflow-addons



In [54]:
import os, math, random, json, itertools
from dataclasses import dataclass
from collections import defaultdict, Counter
from typing import Dict, List, Tuple

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

try:
    import nest_asyncio, asyncio
    nest_asyncio.apply()
except Exception:
    pass


Utils

In [42]:
class Config:

  data_path: str = "/content/drive/MyDrive/TCC - Grupo SLD/Projeto 2/ADNI"

  img_size: int = 224
  grayscale_to_3: bool = True
  per_image_standardization: bool = True
  num_clients: int = 5
  alpha_dirichlet: float = 0.5
  test_size: float = 0.2
  seed: int = 42

  batch_size: int = 16
  local_epochs: int = 1
  rounds: int = 15
  clients_per_round: int = 0  # 0 => usa todos

  # Otimizadores (TFF)
  client_lr: float = 1e-3
  client_momentum: float = 0.9
  server_lr: float = 1.0

  # Perda
  label_smoothing: float = 0.05
  use_focal_loss: bool = False

cfg = Config()

In [43]:
def seed_all(seed: int):
    random.seed(seed); np.random.seed(seed); tf.random.set_seed(seed)
seed_all(cfg.seed)


In [44]:
IMG_EXTS = (".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff")

def discover_imagefolder(root: str):
    if not os.path.isdir(root):
        raise FileNotFoundError(f"data_path inexistente: {root}")
    classes = sorted([d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d))])
    if not classes:
        raise ValueError(f"Nenhuma subpasta de classe encontrada em {root}")
    class_to_idx = {c: i for i, c in enumerate(classes)}
    idx_to_class = {i: c for c, i in class_to_idx.items()}
    filepaths, labels = [], []
    for c in classes:
        cdir = os.path.join(root, c)
        for dp, _, fns in os.walk(cdir):
            for f in fns:
                if f.lower().endswith(IMG_EXTS):
                    filepaths.append(os.path.join(dp, f))
                    labels.append(class_to_idx[c])
    if not filepaths:
        raise ValueError("Nenhuma imagem encontrada.")
    return filepaths, labels, idx_to_class, class_to_idx

def extract_subject_id(path_str: str) -> str:
    base = os.path.basename(path_str)
    stem = os.path.splitext(base)[0]
    return stem.split("_")[0] if "_" in stem else stem

In [45]:
def stratified_subject_split(filepaths, labels, test_size: float, seed: int):
    subject_to_indices = defaultdict(list)
    for idx, fp in enumerate(filepaths):
        sid = extract_subject_id(fp)
        subject_to_indices[sid].append(idx)

    subject_label = {}
    for sid, idcs in subject_to_indices.items():
        maj = Counter([labels[i] for i in idcs]).most_common(1)[0][0]
        subject_label[sid] = maj

    rng = np.random.default_rng(seed)
    subjects_by_class = defaultdict(list)
    for sid, c in subject_label.items():
        subjects_by_class[c].append(sid)

    train_subjects, test_subjects = [], []
    for c, sids in subjects_by_class.items():
        s = sids[:]; rng.shuffle(s)
        n = len(s); n_test = max(1, int(round(test_size * n))) if n > 1 else 1
        if n - n_test == 0 and n > 1: n_test -= 1
        test_subjects.extend(s[:n_test]); train_subjects.extend(s[n_test:])

    train_idx, test_idx = [], []
    for sid in train_subjects:
        train_idx.extend(subject_to_indices[sid])
    for sid in test_subjects:
        test_idx.extend(subject_to_indices[sid])

    return train_idx, test_idx, subject_label, subject_to_indices

def dirichlet_partition_by_subjects(train_subjects_by_class, alpha: float, num_clients: int, seed: int):
    rng = np.random.default_rng(seed)
    client_to_subjects = {i: [] for i in range(num_clients)}
    for c, subj_list in train_subjects_by_class.items():
        subj_list = subj_list[:]; rng.shuffle(subj_list); n = len(subj_list)
        if n == 0: continue
        prop = rng.dirichlet(alpha=[alpha]*num_clients)
        counts = np.floor(prop * n).astype(int)
        diff = n - counts.sum()
        if diff > 0:
            rem = prop * n - counts
            for i in np.argsort(rem)[-diff:]: counts[i] += 1
        elif diff < 0:
            for i in np.argsort(prop)[:abs(diff)]:
                if counts[i] > 0: counts[i] -= 1
        s = 0
        for cid, k in enumerate(counts):
            if k <= 0: continue
            client_to_subjects[cid].extend(subj_list[s:s+k]); s += k

    empties = [cid for cid, lst in client_to_subjects.items() if not lst]
    avail = [cid for cid, lst in client_to_subjects.items() if len(lst) > 1]
    for cid in empties:
        if not avail: break
        donor = avail.pop()
        moved = client_to_subjects[donor].pop()
        client_to_subjects[cid].append(moved)
        if len(client_to_subjects[donor]) > 1: avail.append(donor)
    return client_to_subjects

In [46]:
filepaths, labels, idx_to_class, class_to_idx = discover_imagefolder(cfg.data_path)
num_classes = len(idx_to_class)
print(f"[INFO] {len(filepaths)} imagens | classes: {idx_to_class}")

train_idx, test_idx, subject_label, subject_to_indices = stratified_subject_split(
    filepaths, labels, cfg.test_size, cfg.seed
)

train_subjects_by_class = defaultdict(list)
for sid, c in subject_label.items():
    if subject_to_indices[sid][0] in train_idx:
        train_subjects_by_class[c].append(sid)

client_to_subjects = dirichlet_partition_by_subjects(
    train_subjects_by_class, cfg.alpha_dirichlet, cfg.num_clients, cfg.seed
)

client_to_files = {i: [] for i in range(cfg.num_clients)}
for cid, sids in client_to_subjects.items():
    for sid in sids:
        for i in subject_to_indices[sid]:
            if i in train_idx:
                client_to_files[cid].append((filepaths[i], labels[i]))

test_files = [(filepaths[i], labels[i]) for i in test_idx]
print(f"[INFO] Treino total: {sum(len(v) for v in client_to_files.values())} | Teste: {len(test_files)}")

[INFO] 545 imagens | classes: {0: 'AD', 1: 'CN', 2: 'MCI'}
[INFO] Treino total: 157 | Teste: 388


In [47]:
AUTOTUNE = tf.data.AUTOTUNE

def _decode(fp, training: bool):
    img_bytes = tf.io.read_file(fp)
    img = tf.io.decode_image(img_bytes, channels=1 if cfg.grayscale_to_3 else 3, expand_animations=False)
    img = tf.image.convert_image_dtype(img, tf.float32)
    img = tf.image.resize(img, [cfg.img_size, cfg.img_size])
    if cfg.grayscale_to_3:
        img = tf.image.grayscale_to_rgb(img)
    if cfg.per_image_standardization:
        img = tf.image.per_image_standardization(img)
    return img

def build_tfds(files_and_labels: List[Tuple[str, int]], training: bool) -> tf.data.Dataset:
    if not files_and_labels:
        ds = tf.data.Dataset.from_tensors((tf.zeros([cfg.img_size, cfg.img_size, 3], tf.float32), 0))
        return ds.batch(cfg.batch_size).map(lambda x,y: {'x': x, 'y': tf.cast(y, tf.int32)})
    fpaths = tf.constant([fp for fp,_ in files_and_labels])
    labs   = tf.constant([y  for _,y in files_and_labels], dtype=tf.int32)
    ds = tf.data.Dataset.from_tensor_slices((fpaths, labs))
    if training:
        ds = ds.shuffle(min(len(files_and_labels), 1000), seed=cfg.seed, reshuffle_each_iteration=True)
    ds = ds.map(lambda fp, y: (_decode(fp, training), y), num_parallel_calls=AUTOTUNE)
    ds = ds.batch(cfg.batch_size, drop_remainder=False).prefetch(AUTOTUNE)
    ds = ds.map(lambda x, y: {'x': x, 'y': y})
    if training and cfg.local_epochs > 1:
        ds = ds.repeat(cfg.local_epochs)
    return ds

test_ds = build_tfds(test_files, training=False)

In [48]:
def build_keras_model(num_classes: int) -> tf.keras.Model:
    inp = tf.keras.Input(shape=(cfg.img_size, cfg.img_size, 3))
    x = tf.keras.layers.RandomRotation(0.05)(inp)
    x = tf.keras.layers.RandomTranslation(0.02, 0.02)(x)
    x = tf.keras.layers.RandomZoom(0.02, 0.02)(x)
    base = tf.keras.applications.EfficientNetB0(include_top=False, weights="imagenet", input_tensor=x)
    x = tf.keras.layers.GlobalAveragePooling2D()(base.output)
    x = tf.keras.layers.Dropout(0.2)(x)
    out = tf.keras.layers.Dense(num_classes, activation="softmax")(x)
    return tf.keras.Model(inputs=inp, outputs=out)

class SparseCategoricalFocalLoss(tf.keras.losses.Loss):
    def __init__(self, gamma=2.0, alpha=None, label_smoothing=0.0, name="sparse_categorical_focal_loss"):
        super().__init__(name=name)
        self.gamma, self.alpha, self.label_smoothing = gamma, alpha, label_smoothing
    def call(self, y_true, y_pred):
        y_true = tf.cast(y_true, tf.int32)
        y_true_oh = tf.one_hot(y_true, depth=num_classes, dtype=y_pred.dtype)
        if self.label_smoothing > 0.0:
            y_true_oh = y_true_oh * (1.0 - self.label_smoothing) + self.label_smoothing / num_classes
        y_pred = tf.clip_by_value(y_pred, 1e-8, 1.0)  # modelo já sai com softmax
        pt = tf.reduce_sum(y_true_oh * y_pred, axis=-1)
        alpha_factor = 1.0
        if self.alpha is not None:
            alpha_vec = tf.constant(self.alpha, dtype=y_pred.dtype)
            alpha_factor = tf.reduce_sum(y_true_oh * alpha_vec, axis=-1)
        loss = -alpha_factor * tf.pow(1.0 - pt, self.gamma) * tf.math.log(pt)
        return tf.reduce_mean(loss)

class SparseCCEWithLabelSmoothing(tf.keras.losses.Loss):
    def __init__(self, label_smoothing=0.1, name="sparse_cce_with_ls"):
        super().__init__(name=name)
        self.label_smoothing = float(label_smoothing)
        self._cce = tf.keras.losses.CategoricalCrossentropy(
            from_logits=False, label_smoothing=label_smoothing
        )
    def call(self, y_true, y_pred):
        y_true = tf.cast(y_true, tf.int32)
        y_true_oh = tf.one_hot(y_true, depth=num_classes, dtype=y_pred.dtype)
        return self._cce(y_true_oh, y_pred)

def make_loss():
    if cfg.use_focal_loss:
        return SparseCategoricalFocalLoss(label_smoothing=cfg.label_smoothing)
    else:
        if cfg.label_smoothing > 0.0:
            return SparseCCEWithLabelSmoothing(label_smoothing=cfg.label_smoothing)
        else:
            return tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)


In [49]:
_sample_client = 0
sample_ds = build_tfds(client_to_files[_sample_client], training=True)
input_spec = sample_ds.element_spec

def model_fn_tff():
    keras_model = build_keras_model(num_classes)
    return tff.learning.models.from_keras_model(
        keras_model=keras_model,
        loss=make_loss(),
        input_spec=input_spec,
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy")]
    )


In [50]:
def to_serializable(obj):
    if isinstance(obj, dict):
        return {k: to_serializable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [to_serializable(v) for v in obj]
    if isinstance(obj, tf.Tensor):
        v = obj.numpy()
        return v.item() if np.ndim(v) == 0 else v.tolist()
    if isinstance(obj, (np.floating, np.integer)):
        return obj.item()
    return obj

In [51]:
learning_process = tff.learning.algorithms.build_weighted_fed_avg(
    model_fn=model_fn_tff,
    client_optimizer_fn=tff.learning.optimizers.build_sgdm(learning_rate=cfg.client_lr, momentum=cfg.client_momentum),
    server_optimizer_fn=tff.learning.optimizers.build_sgdm(learning_rate=cfg.server_lr),
)
state = learning_process.initialize()

client_ids = list(client_to_files.keys())
clients_per_round = (min(cfg.clients_per_round, len(client_ids))
                     if cfg.clients_per_round and cfg.clients_per_round > 0
                     else len(client_ids))
print(f"[INFO] Rounds: {cfg.rounds} | clientes/rodada: {clients_per_round}/{len(client_ids)}")

for rnd in range(1, cfg.rounds + 1):
    if clients_per_round < len(client_ids):
        selected = random.sample(client_ids, k=clients_per_round)
    else:
        selected = client_ids
    federated_train_data = [build_tfds(client_to_files[cid], training=True) for cid in selected]

    state, metrics = learning_process.next(state, federated_train_data)
    try:
        metrics_dict = {k: float(v) if hasattr(v, "numpy") else v for k, v in metrics.items()}
    except Exception:
        metrics_dict = metrics
    print(f"[Round {rnd:02d}] train=", json.dumps(to_serializable(metrics), indent=2))

[INFO] Rounds: 15 | clientes/rodada: 5/5
[Round 01] train= {
  "distributor": [],
  "client_work": {
    "train": {
      "accuracy": 0.23566879332065582,
      "loss": 1.2202403545379639,
      "num_examples": 157,
      "num_batches": 12
    }
  },
  "aggregator": {
    "mean_value": [],
    "mean_weight": []
  },
  "finalizer": {
    "update_non_finite": 0
  }
}
[Round 02] train= {
  "distributor": [],
  "client_work": {
    "train": {
      "accuracy": 0.35668790340423584,
      "loss": 1.1421669721603394,
      "num_examples": 157,
      "num_batches": 12
    }
  },
  "aggregator": {
    "mean_value": [],
    "mean_weight": []
  },
  "finalizer": {
    "update_non_finite": 0
  }
}
[Round 03] train= {
  "distributor": [],
  "client_work": {
    "train": {
      "accuracy": 0.5095541477203369,
      "loss": 1.0746499300003052,
      "num_examples": 157,
      "num_batches": 12
    }
  },
  "aggregator": {
    "mean_value": [],
    "mean_weight": []
  },
  "finalizer": {
    "update_

In [52]:
final_model = build_keras_model(num_classes)
weights = learning_process.get_model_weights(state)
try:
    weights.assign_weights_to(final_model)
except Exception:
    # fallback (se necessário)
    try:
        final_model.set_weights(list(weights.trainable) + list(weights.non_trainable))
    except Exception as e:
        print("[WARN] Falha ao transferir pesos:", e)

final_model.save("tff_alzheimer_model.h5")
print("[INFO] Modelo salvo: tff_alzheimer_model.h5")

# Avaliação central detalhada (accuracy, macro-P/R/F1, CM)
def evaluate_central(model: tf.keras.Model, files_and_labels: List[Tuple[str,int]]):
    ds = build_tfds(files_and_labels, training=False)
    y_true_all, y_pred_all = [], []
    for batch in ds:
        x, y = batch['x'], batch['y']
        probs = model(x, training=False).numpy()
        y_pred = probs.argmax(axis=-1)
        y_true_all.append(y.numpy()); y_pred_all.append(y_pred)
    y_true = np.concatenate(y_true_all, axis=0); y_pred = np.concatenate(y_pred_all, axis=0)
    acc = float((y_true == y_pred).mean())
    cm = tf.math.confusion_matrix(y_true, y_pred, num_classes=num_classes).numpy()
    prec_c, rec_c, f1_c = [], [], []
    for c in range(num_classes):
        tp = cm[c,c]; fp = cm[:,c].sum() - tp; fn = cm[c,:].sum() - tp
        prec = tp / (tp + fp + 1e-12); rec = tp / (tp + fn + 1e-12)
        f1 = 2*prec*rec / (prec + rec + 1e-12)
        prec_c.append(float(prec)); rec_c.append(float(rec)); f1_c.append(float(f1))
    return {
        "accuracy": acc,
        "macro_precision": float(np.mean(prec_c)),
        "macro_recall": float(np.mean(rec_c)),
        "macro_f1": float(np.mean(f1_c)),
        "per_class": {
            "precision": prec_c, "recall": rec_c, "f1": f1_c,
            "class_names": [idx_to_class[i] for i in range(num_classes)]
        },
        "confusion_matrix": cm.tolist()
    }

central_eval = evaluate_central(final_model, test_files)
print("[CENTRAL TEST] ", json.dumps({k:v for k,v in central_eval.items() if k!='confusion_matrix'}, indent=2))

  saving_api.save_model(


[INFO] Modelo salvo: tff_alzheimer_model.h5
[CENTRAL TEST]  {
  "accuracy": 0.4329896907216495,
  "macro_precision": 0.14432989690721612,
  "macro_recall": 0.33333333333333137,
  "macro_f1": 0.20143884892072197,
  "per_class": {
    "precision": [
      0.0,
      0.0,
      0.43298969072164833
    ],
    "recall": [
      0.0,
      0.0,
      0.9999999999999941
    ],
    "f1": [
      0.0,
      0.0,
      0.604316546762166
    ],
    "class_names": [
      "AD",
      "CN",
      "MCI"
    ]
  }
}
