In [None]:
import os
import glob
import random
import math
from pathlib import Path
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models, backend as K
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score, f1_score
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

print('TensorFlow:', tf.__version__)

# Seed every RNG used in this notebook (TensorFlow, NumPy, stdlib random)
# with the same value so a fresh "Restart & Run All" is repeatable.
SEED = 42
tf.random.set_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# Dataset location; CLASS_NAMES=None lets the loader infer the classes
# from the sub-directory names.
DATA_ROOT = ''
CLASS_NAMES = None

# Image/Training Params
IMG_SIZE = (150, 150)
IMG_SHAPE = (*IMG_SIZE, 3)
NUM_CLASSES = 2
EPOCHS_BASELINE = 50

# Hierarchical
K_PROCESSES = 7        # K-2 of these receive a shard of the training data
T_LOCAL = 1            # epochs per local fit
M_DELETE_EVERY = 2     # a client is dropped every M rounds
MAX_ROUNDS = 12
LEARNING_RATE = 1e-3
BATCH_SIZE = 32

# Data Loader

In [None]:

def _list_images_in_class_dir(class_dir):
    exts = ('*.png', '*.jpg', '*.jpeg', '*.bmp', '*.tif', '*.tiff')
    files = []
    for e in exts:
        files.extend(glob.glob(os.path.join(class_dir, e)))
    return files

def _scan_classes(root, class_names=None):
    if class_names is None:
        classes = [d.name for d in Path(root).iterdir() if d.is_dir()]
        classes.sort()
    else:
        classes = class_names
    return classes

def _load_paths_labels(root, class_names):
    X, y = [], []
    for idx, cname in enumerate(class_names):
        cdir = os.path.join(root, cname)
        files = _list_images_in_class_dir(cdir)
        X.extend(files)
        y.extend([idx]*len(files))
    return np.array(X), np.array(y)

def _read_image(path, target_size):
    """Load the image at ``path`` resized to ``target_size`` and return it as an array.

    Uses ``tf.keras.utils.load_img`` followed by ``tf.keras.utils.img_to_array``.
    """
    pil_image = tf.keras.utils.load_img(path, target_size=target_size)
    return tf.keras.utils.img_to_array(pil_image)

def load_br35h(DATA_ROOT, class_names=None, img_size=(150,150), val_split=0.2, test_split=0.1):
    """Locate the dataset images and return train/val/test path+label arrays.

    Two directory layouts are supported:

    * ``DATA_ROOT/{train,val,test}/<class>/...`` - the existing split is used
      as is; the classes are taken from the ``train`` folder.
    * ``DATA_ROOT/<class>/...`` - a stratified random split (seed 42) is made,
      with ``val_split`` and ``test_split`` giving the fraction of all images
      assigned to validation and test.

    Args:
        DATA_ROOT: Dataset root directory.
        class_names: Optional explicit class ordering; inferred from the
            folder names when ``None``.
        img_size: Unused here (images are only listed, not decoded); kept for
            call-site compatibility.
        val_split: Fraction of the images used for validation.
        test_split: Fraction of the images used for testing.

    Returns:
        ``(X_train, y_train), (X_val, y_val), (X_test, y_test), classes``
        where each ``X`` is an array of file paths and each ``y`` an array of
        integer class indices.

    Raises:
        ValueError: If the split fractions are invalid, no image is found, or
            the dataset is too small for the requested split.
        FileNotFoundError: If ``DATA_ROOT`` does not exist.
    """
    if not (0 < val_split < 1 and 0 < test_split < 1 and val_split + test_split < 1):
        raise ValueError(
            "val_split and test_split must each be in (0, 1) and sum to less than 1; "
            f"got val_split={val_split}, test_split={test_split}"
        )

    data_root = Path(DATA_ROOT)
    if not data_root.exists():
        raise FileNotFoundError(f"DATA_ROOT not found: {DATA_ROOT}")

    split_dirs = ['train', 'val', 'test']
    if all((data_root / d).exists() for d in split_dirs):
        tr_classes = _scan_classes(data_root / 'train', class_names)
        X_train, y_train = _load_paths_labels(data_root / 'train', tr_classes)
        X_val, y_val = _load_paths_labels(data_root / 'val', tr_classes)
        X_test, y_test = _load_paths_labels(data_root / 'test', tr_classes)
        return (X_train, y_train), (X_val, y_val), (X_test, y_test), tr_classes

    classes = _scan_classes(DATA_ROOT, class_names)
    X, y = _load_paths_labels(DATA_ROOT, classes)

    n_total = len(X)
    if n_total == 0:
        raise ValueError(f"No images found under DATA_ROOT: {DATA_ROOT}")

    # Turn the fractions into absolute counts. Passing integers to
    # train_test_split avoids float artefacts (0.2 + 0.1 != 0.3) and gives
    # 600/300 for a 3000-image dataset.
    n_val = int(round(val_split * n_total))
    n_test = int(round(test_split * n_total))
    if n_val < 1 or n_test < 1 or n_val + n_test >= n_total:
        raise ValueError(
            f"Dataset of {n_total} images is too small for "
            f"val_split={val_split}, test_split={test_split}"
        )

    X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=n_val + n_test,
                                                        stratify=y, random_state=42, shuffle=True)
    X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=n_test,
                                                    stratify=y_temp, random_state=42, shuffle=True)
    return (X_train, y_train), (X_val, y_val), (X_test, y_test), classes

def make_tf_dataset(X_paths, y, batch_size=32, img_size=(150,150), shuffle=True, augment=False):
    """Build a batched ``tf.data`` pipeline that decodes images from file paths.

    Each element is ``(image, one_hot_label)`` where the image is resized to
    ``img_size``, has 3 channels, and is scaled to ``[0, 1]``; labels are
    one-hot encoded with depth ``NUM_CLASSES``.

    Args:
        X_paths: Sequence/array of image file paths.
        y: Sequence/array of integer class indices, same length as ``X_paths``.
        batch_size: Batch size.
        img_size: ``(height, width)`` target size.
        shuffle: Shuffle the paths (reshuffled on every iteration).
        augment: Apply random horizontal and vertical flips.

    Returns:
        A batched, prefetched ``tf.data.Dataset``.

    Raises:
        ValueError: If ``X_paths`` and ``y`` have different lengths.
    """
    img_size = tuple(img_size)  # tolerate a list so `img_size + (3,)` works
    paths = np.asarray(X_paths).astype('U')
    labels = np.asarray(y).astype(np.int32)
    if len(paths) != len(labels):
        raise ValueError(
            f"X_paths and y must have the same length, got {len(paths)} and {len(labels)}"
        )

    AUTOTUNE = tf.data.AUTOTUNE

    def _load_and_preprocess(path, label):
        # numpy_function hands the path over as bytes, hence the decode.
        img = tf.numpy_function(lambda p: _read_image(p.decode('utf-8'), img_size), [path], tf.float32)
        img.set_shape(img_size + (3,))
        img = img / 255.0
        if augment:
            img = tf.image.random_flip_left_right(img)
            img = tf.image.random_flip_up_down(img)
        return img, tf.one_hot(label, depth=NUM_CLASSES)

    ds = tf.data.Dataset.from_tensor_slices((paths, labels))
    if shuffle and len(paths) > 0:
        # Shuffling happens before decoding, so the buffer only holds path
        # strings. Using the full dataset size gives a uniform shuffle even
        # when the inputs are ordered class by class.
        ds = ds.shuffle(buffer_size=len(paths), reshuffle_each_iteration=True)
    ds = ds.map(_load_and_preprocess, num_parallel_calls=AUTOTUNE)
    ds = ds.batch(batch_size).prefetch(AUTOTUNE)
    return ds

# Resolve the three splits (file paths + integer labels) and the class list.
(X_train, y_train), (X_val, y_val), (X_test, y_test), CLASS_NAMES_INFER = load_br35h(
    DATA_ROOT, class_names=CLASS_NAMES, img_size=IMG_SIZE
)
print('Classes:', CLASS_NAMES_INFER)
print('Train/Val/Test sizes:', len(X_train), len(X_val), len(X_test))

# Only the training pipeline is shuffled and augmented; validation and test
# keep a fixed order.
train_ds = make_tf_dataset(X_train, y_train, batch_size=BATCH_SIZE, img_size=IMG_SIZE,
                           shuffle=True, augment=True)
val_ds = make_tf_dataset(X_val, y_val, batch_size=BATCH_SIZE, img_size=IMG_SIZE, shuffle=False)
test_ds = make_tf_dataset(X_test, y_test, batch_size=BATCH_SIZE, img_size=IMG_SIZE, shuffle=False)



## ACBAM Model (Atrous-CNN with CBAM)

- Input (150×150×3)
- **Atrous Conv**: 32 filters, 3×3, dilation=2, ReLU → MaxPool 2×2
- Three blocks (each: Conv 3×3 + MaxPool 2×2) with filters **64**, **128**, **128**
- **Channel Attention**:
  - Global **Max** and **Avg** pooling → shared MLP: Dense(16, ReLU) → Dense(128, ReLU); the two branch outputs are summed → **Sigmoid**
  - Multiply with feature map
- **Spatial Attention**:
  - Channel-wise **Max** and **Avg** → concat → **Atrous Conv** 7×7, dilation=2, **Sigmoid**
  - Multiply with feature map
- Flatten → Dense(512, ReLU) → Dense(k, **Softmax**)


In [None]:

def channel_attention(x, reduction_1=16, out_channels=128):
    """Channel-attention gate: re-weight the channels of ``x``.

    Global average- and max-pooled descriptors of ``x`` are passed through
    the same two-layer MLP (``Dense(reduction_1, relu)`` then
    ``Dense(out_channels, relu)``), summed, squashed with a sigmoid and
    multiplied back onto ``x`` channel by channel.

    Args:
        x: 4-D feature map; the last axis is the channel axis.
        reduction_1: Width of the MLP bottleneck layer.
        out_channels: Width of the MLP output; must equal the number of
            channels of ``x`` so the gate lines up with the feature map.

    Returns:
        Tensor with the same shape as ``x``.

    Raises:
        ValueError: If the channel count of ``x`` is known and differs from
            ``out_channels``.
    """
    in_channels = x.shape[-1]
    if in_channels is not None and int(in_channels) != int(out_channels):
        raise ValueError(
            f"out_channels ({out_channels}) must match the channel dimension "
            f"of the input feature map ({in_channels})"
        )

    avg_pool = layers.GlobalAveragePooling2D()(x)
    max_pool = layers.GlobalMaxPooling2D()(x)
    # One MLP instance applied to both descriptors => shared weights.
    mlp = keras.Sequential([
        layers.Dense(reduction_1, activation='relu'),
        layers.Dense(out_channels, activation='relu')
    ])
    avg_out = mlp(avg_pool)
    max_out = mlp(max_pool)
    attn = layers.Activation('sigmoid')(layers.Add()([avg_out, max_out]))
    # (B, C) -> (B, 1, 1, C) so the gate broadcasts over the spatial axes.
    attn = layers.Reshape((1,1,out_channels))(attn)
    return layers.Multiply()([x, attn])

def spatial_attention(x):
    """Spatial-attention gate: re-weight every spatial position of ``x``.

    The channel-wise mean and max of ``x`` are concatenated into a 2-channel
    map, passed through a 7x7 dilated (rate 2) convolution with a sigmoid,
    and the resulting single-channel mask is multiplied onto ``x``.

    Args:
        x: Feature map of shape (B, H, W, C).

    Returns:
        Tensor with the same shape as ``x``.
    """
    # The reductions are wrapped in Lambda layers rather than applied to the
    # symbolic tensor directly: Keras 3 rejects raw tf.* ops on KerasTensors,
    # while Lambda works on both Keras 2 and Keras 3.
    # NOTE: Lambda layers built from Python lambdas limit model
    # (de)serialisation; saving weights is unaffected.
    avg_out = layers.Lambda(lambda t: tf.reduce_mean(t, axis=-1, keepdims=True))(x)
    max_out = layers.Lambda(lambda t: tf.reduce_max(t, axis=-1, keepdims=True))(x)
    concat = layers.Concatenate(axis=-1)([avg_out, max_out])
    s = layers.Conv2D(1, kernel_size=7, dilation_rate=2, padding='same', activation='sigmoid')(concat)
    return layers.Multiply()([x, s])

def build_acbam_model(input_shape=(150,150,3), num_classes=2):
    """Build and compile the ACBAM classifier.

    Architecture: a dilated 3x3 convolution stem, three Conv+MaxPool blocks
    (64, 128, 128 filters), channel attention, spatial attention, then
    Flatten -> Dense(512, relu) -> Dense(num_classes, softmax).

    The model is compiled with Adam (learning rate ``LEARNING_RATE``),
    categorical cross-entropy and an accuracy metric.

    Returns:
        The compiled ``keras.Model`` named ``'ACBAM'``.
    """
    inputs = layers.Input(shape=input_shape)

    # Atrous stem
    features = layers.Conv2D(32, 3, padding='same', dilation_rate=2, activation='relu')(inputs)
    features = layers.MaxPooling2D(pool_size=(2,2))(features)

    # Plain convolutional blocks
    for n_filters in (64, 128, 128):
        features = layers.Conv2D(n_filters, 3, padding='same', activation='relu')(features)
        features = layers.MaxPooling2D(pool_size=(2,2))(features)

    # Attention: channel gate first, spatial gate second
    channel_refined = channel_attention(features, reduction_1=16, out_channels=128)
    attended = spatial_attention(channel_refined)

    # Classification head
    flattened = layers.Flatten()(attended)
    hidden = layers.Dense(512, activation='relu')(flattened)
    predictions = layers.Dense(num_classes, activation='softmax')(hidden)

    model = keras.Model(inputs=inputs, outputs=predictions, name='ACBAM')
    optimizer = keras.optimizers.Adam(learning_rate=LEARNING_RATE)
    model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
    return model

# Hierarchical strategy

In [None]:
def evaluate_model(model, ds, y_true=None):
    """Compute classification metrics for ``model`` on dataset ``ds``.

    Args:
        model: Keras model whose output is a per-class probability vector.
        ds: Dataset yielding ``(images, one_hot_labels)`` batches.
        y_true: Optional integer ground-truth labels, in the order in which
            ``ds`` yields its samples. When omitted, the labels are read from
            ``ds`` itself.

    Returns:
        ``(metrics, (y_true, y_pred))`` where ``metrics`` is a dict with keys
        ``'acc'``, ``'precision'``, ``'recall'``, ``'f1'`` and ``'cm'``
        (confusion matrix over all ``NUM_CLASSES`` classes).

    Raises:
        ValueError: If ``y_true`` is omitted and ``ds`` yields no batch.
    """
    if y_true is None:
        # Collect labels and predictions in the SAME pass over the dataset.
        # Two separate passes would pair predictions with the wrong labels
        # whenever ``ds`` reshuffles between iterations, and would decode
        # every image twice.
        true_batches, pred_batches = [], []
        for xb, yb in ds:
            probs = model.predict_on_batch(xb)
            pred_batches.append(np.argmax(probs, axis=1))
            true_batches.append(np.argmax(yb.numpy(), axis=1))
        if not pred_batches:
            raise ValueError("evaluate_model: dataset yielded no batches")
        y_pred = np.concatenate(pred_batches)
        y_true = np.concatenate(true_batches)
    else:
        y_prob = model.predict(ds, verbose=0)
        y_pred = np.argmax(y_prob, axis=1)
        y_true = np.asarray(y_true)

    average = 'binary' if NUM_CLASSES==2 else 'macro'
    acc = accuracy_score(y_true, y_pred)
    prec = precision_score(y_true, y_pred, average=average, zero_division=0)
    rec  = recall_score(y_true, y_pred, average=average, zero_division=0)
    f1   = f1_score(y_true, y_pred, average=average, zero_division=0)
    # Fix the label set so the matrix keeps its shape even when a class is
    # missing from both y_true and y_pred.
    cm   = confusion_matrix(y_true, y_pred, labels=list(range(NUM_CLASSES)))
    return {'acc':acc, 'precision':prec, 'recall':rec, 'f1':f1, 'cm':cm}, (y_true, y_pred)

In [None]:
def split_dataset_for_k_minus_2(X_train, y_train, k_processes, batch_size=32, img_size=(150,150)):
    """Randomly partition the training set into ``k_processes - 2`` shards.

    The sample indices are shuffled with NumPy's global RNG, split into
    near-equal parts, and each part is wrapped in a shuffled, augmented
    dataset via ``make_tf_dataset``.

    Args:
        X_train: Array of image paths.
        y_train: Array of integer labels, same length as ``X_train``.
        k_processes: Total number of processes; ``k_processes - 2`` of them
            receive a data shard.
        batch_size: Batch size of each shard dataset.
        img_size: ``(height, width)`` target size.

    Returns:
        List of ``k_processes - 2`` datasets.

    Raises:
        ValueError: If ``k_processes - 2 < 1``, if there are fewer samples
            than shards (which would create empty datasets), or if
            ``X_train`` and ``y_train`` differ in length.
    """
    parts = k_processes - 2
    if parts < 1:
        raise ValueError(f"k_processes must be at least 3, got {k_processes}")

    X_train = np.asarray(X_train)
    y_train = np.asarray(y_train)
    n_samples = len(X_train)
    if len(y_train) != n_samples:
        raise ValueError(
            f"X_train and y_train must have the same length, got {n_samples} and {len(y_train)}"
        )
    if n_samples < parts:
        raise ValueError(
            f"Cannot split {n_samples} samples into {parts} non-empty shards"
        )

    idxs = np.arange(n_samples)
    np.random.shuffle(idxs)
    return [
        make_tf_dataset(X_train[shard], y_train[shard], batch_size=batch_size,
                        img_size=img_size, shuffle=True, augment=True)
        for shard in np.array_split(idxs, parts)
    ]

def get_model_weights(model):
    """Return whatever ``model.get_weights()`` returns."""
    current_weights = model.get_weights()
    return current_weights

def set_model_weights(model, weights):
    """Hand ``weights`` to ``model.set_weights``; returns nothing."""
    model.set_weights(weights)
    return None

def average_weights_weighted(weight_list, coeffs):
    """Return the coefficient-weighted sum of several models' weight lists.

    Args:
        weight_list: One entry per model; each entry is that model's list of
            weight arrays. All entries must contain the same number of arrays.
        coeffs: One scalar coefficient per model.

    Returns:
        A list of arrays where element ``j`` is
        ``sum_i coeffs[i] * weight_list[i][j]``. Floating-point inputs keep
        their dtype; other dtypes yield ``float64`` results. An empty
        ``weight_list`` yields ``[]``.

    Raises:
        ValueError: If ``weight_list`` and ``coeffs`` differ in length, or if
            the models do not all have the same number of weight arrays.
            Without this check ``zip`` would silently drop the surplus.
    """
    weight_list = list(weight_list)
    coeffs = [float(c) for c in coeffs]
    if len(weight_list) != len(coeffs):
        raise ValueError(
            f"Got {len(weight_list)} weight sets but {len(coeffs)} coefficients"
        )
    if not weight_list:
        return []

    n_tensors = len(weight_list[0])
    if any(len(model_weights) != n_tensors for model_weights in weight_list):
        raise ValueError("All weight sets must contain the same number of arrays")

    averaged = []
    for tensors in zip(*weight_list):
        reference = np.asarray(tensors[0])
        # Accumulate in float64 for accuracy and so that integer-typed
        # tensors do not break the in-place addition.
        acc = np.zeros(reference.shape, dtype=np.float64)
        for tensor, coeff in zip(tensors, coeffs):
            acc += np.asarray(tensor, dtype=np.float64) * coeff
        out_dtype = reference.dtype if np.issubdtype(reference.dtype, np.floating) else np.float64
        averaged.append(acc.astype(out_dtype))
    return averaged

In [None]:
def _aggregation_scores_and_coeffs(metrics):
    """Turn per-client validation metrics into scores and averaging coefficients.

    A client's score is the mean of its accuracy, precision, recall and F1,
    each divided by the best value of that metric among the clients. A metric
    whose best value is 0 (e.g. every client predicts a single class)
    contributes 0 for everyone instead of dividing by zero. If the scores sum
    to 0 or are not finite, the coefficients fall back to a uniform average.

    Returns:
        ``(scores, coeffs)`` - two lists aligned with ``metrics``; ``coeffs``
        sums to 1.
    """
    keys = ('acc', 'precision', 'recall', 'f1')
    best = {k: max(float(m[k]) for m in metrics) for k in keys}

    scores = []
    for m in metrics:
        ratios = [float(m[k]) / best[k] if best[k] > 0 else 0.0 for k in keys]
        scores.append(sum(ratios) / len(keys))

    total = sum(scores)
    if not np.isfinite(total) or total <= 0:
        coeffs = [1.0 / len(scores)] * len(scores)
    else:
        coeffs = [s / total for s in scores]
    return scores, coeffs


def fedavg_with_removal(train_dsets, val_ds, test_ds, 
                        T_local=1, M_delete_every=2, max_rounds=12):
    """Train a global ACBAM model by weighted averaging of local client models.

    One client model is created per dataset in ``train_dsets``. Each round:

    1. every active client starts from the global weights and is fitted on its
       own shard for ``T_local`` epochs;
    2. every active client is evaluated on ``val_ds`` and scored from its
       accuracy / precision / recall / F1 relative to the best client;
    3. the global weights become the score-weighted average of the active
       clients' weights;
    4. every ``M_delete_every`` rounds the lowest-scoring client is removed,
       after the global model has been fitted on that client's shard.

    The loop ends after ``max_rounds`` rounds or when no client is left. The
    global model is finally fitted once more on the shard of the last removed
    client, if any.

    Args:
        train_dsets: List of per-client training datasets.
        val_ds: Validation dataset used for scoring and history.
        test_ds: Not used by this function; kept in the signature for callers.
        T_local: Epochs per ``fit`` call.
        M_delete_every: Remove one client every this many rounds.
        max_rounds: Maximum number of rounds.

    Returns:
        ``(global_model, client_histories, history_records)`` where
        ``client_histories[i]`` holds the per-round ``train_acc`` / ``val_acc``
        of client ``i`` and ``history_records`` holds the global validation
        accuracy per round.
    """
    num_clients = len(train_dsets)
    clients = [build_acbam_model(IMG_SHAPE, NUM_CLASSES) for _ in range(num_clients)]
    global_model = build_acbam_model(IMG_SHAPE, NUM_CLASSES)

    client_histories = {i: {"train_acc": [], "val_acc": []} for i in range(num_clients)}
    history_records = {'global_val_acc':[], 'round':[]}

    round_idx = 0
    active = list(range(num_clients))
    last_removed_ds = None

    while len(active) > 0 and round_idx < max_rounds:
        # 1. Local training, starting from the current global weights.
        for i in list(active):
            set_model_weights(clients[i], global_model.get_weights())
            hist = clients[i].fit(train_dsets[i], epochs=T_local, verbose=0, validation_data=val_ds)
            client_histories[i]["train_acc"].append(hist.history["accuracy"][-1])
            client_histories[i]["val_acc"].append(hist.history["val_accuracy"][-1])

        # 2. Score every active client on the validation set.
        metrics = [evaluate_model(clients[i], val_ds)[0] for i in active]
        w_sig, coeffs = _aggregation_scores_and_coeffs(metrics)

        # 3. Score-weighted average of the client weights -> global model.
        weight_list = [get_model_weights(clients[i]) for i in active]
        new_avg_weights = average_weights_weighted(weight_list, coeffs)
        set_model_weights(global_model, new_avg_weights)

        # 4. Periodically drop the lowest-scoring client.
        if (round_idx+1) % M_delete_every == 0 and len(active) > 0:
            worst_idx_in_active = int(np.argmin(w_sig))
            to_remove = active[worst_idx_in_active]

            global_model.fit(train_dsets[to_remove], epochs=T_local, verbose=0, validation_data=val_ds)
            last_removed_ds = train_dsets[to_remove]
            active.remove(to_remove)

        # evaluate() returns [loss, accuracy] for the compiled model.
        val_eval = global_model.evaluate(val_ds, verbose=0)
        history_records['global_val_acc'].append(val_eval[1])
        history_records['round'].append(round_idx)

        round_idx += 1

    if last_removed_ds is not None:
        global_model.fit(last_removed_ds, epochs=T_local, verbose=0, validation_data=val_ds)

    return global_model, client_histories, history_records

# Shard the training data across the K-2 local processes, then run the
# hierarchical averaging loop on those shards.
local_train_dsets = split_dataset_for_k_minus_2(X_train, y_train, K_PROCESSES, BATCH_SIZE, IMG_SIZE)
hier_model, client_histories, averaged_histories = fedavg_with_removal(
    local_train_dsets,
    val_ds,
    test_ds,
    T_local=T_LOCAL,
    M_delete_every=M_DELETE_EVERY,
    max_rounds=MAX_ROUNDS,
)

# Report the final global model on the held-out splits.
metrics_val_h = evaluate_model(hier_model, val_ds)[0]
metrics_test_h = evaluate_model(hier_model, test_ds)[0]
print('Hierarchical — Validation:', metrics_val_h)
print('Hierarchical — Test:', metrics_test_h)