**Library Imports**

In [1]:
import flwr as fl
import torch
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import ImageFolder
from torchvision import transforms
import numpy as np
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler
from sklearn.metrics.pairwise import cosine_similarity
from scipy.spatial.distance import euclidean
from collections import defaultdict, Counter
import random
import uuid
import warnings
import json
import matplotlib.pyplot as plt
import umap.umap_ as UMAP
from sklearn.decomposition import PCA
from kneed import KneeLocator
from flwr.common import Context, ndarrays_to_parameters, FitIns, EvaluateIns
from flwr.client import NumPyClient
from flwr.server.strategy import FedAvg
import os
from sklearn.linear_model import LogisticRegression
from scipy.stats import skew, kurtosis
from collections import defaultdict
import numpy as np
import json

**Utility Functions**

In [2]:

# Allocator options are read when PyTorch's CUDA caching allocator is first
# initialized, so set them before anything touches torch.cuda.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE: this silences *all* warnings process-wide (sklearn/UMAP included);
# comment it out when debugging.
warnings.filterwarnings("ignore")

# Dataset setup: resize to 224x224 and normalize with the ImageNet channel
# statistics (inverted later by `unnormalize_image`).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# Override with the FEDCAPE_DATA_ROOT environment variable to run outside the
# original cluster; the default is the original dataset location.
DATA_ROOT = os.environ.get("FEDCAPE_DATA_ROOT", "/gpfs/helios/home/mahmouds/Thesis/data/ILSVRC2012")
imagenet_subset = ImageFolder(root=DATA_ROOT, transform=transform)

# Dirichlet sampling
def create_dirichlet_clients(dataset, num_clients=3, alpha=1, seed=42):
    """Split ``dataset`` across clients with per-class Dirichlet proportions.

    For every class, its sample indices are shuffled and cut into
    ``num_clients`` consecutive chunks whose relative sizes are drawn from
    ``Dirichlet(alpha * ones(num_clients))``. Smaller ``alpha`` gives a more
    skewed (non-IID) split.

    Args:
        dataset: Dataset exposing ``classes`` and yielding ``(sample, label)``
            pairs. If it also has a ``targets`` sequence (as
            ``torchvision.datasets.ImageFolder`` does), labels are read from it
            directly instead of loading and transforming every sample.
        num_clients: Number of clients to split into.
        alpha: Dirichlet concentration parameter.
        seed: Seed for NumPy's *global* RNG. It is re-seeded here (as before),
            so later ``np.random`` calls stay reproducible.

    Returns:
        ``defaultdict(list)`` mapping client id ``0..num_clients-1`` to a list
        of dataset indices.
    """
    np.random.seed(seed)
    num_classes = len(dataset.classes)

    targets = getattr(dataset, "targets", None)
    if targets is None:
        # Fallback: iterating loads every sample just to read its label.
        targets = [label for _, label in dataset]

    data_indices = [[] for _ in range(num_classes)]
    for idx, label in enumerate(targets):
        data_indices[int(label)].append(idx)

    client_indices = defaultdict(list)
    for class_idxs in data_indices:
        idxs = np.array(class_idxs)
        np.random.shuffle(idxs)
        proportions = np.random.dirichlet(alpha * np.ones(num_clients))
        # Cumulative cut points into this class's shuffled indices. The final
        # point (== len(idxs)) is dropped so np.split yields num_clients chunks.
        cuts = (np.cumsum(proportions) * len(idxs)).astype(int)[:-1]
        for client_id, chunk in enumerate(np.split(idxs, cuts)):
            client_indices[client_id].extend(chunk.tolist())
    return client_indices



# Visualization
def plot_real_client_distributions(client_indices, num_clients=3, client_name="client", dataset=None):
    """Save a stacked bar chart of each client's class distribution.

    Each bar is one client. Its segments are the per-class fractions of that
    client's samples, annotated with the raw counts. The figure is written to
    ``"{client_name} Data Distribution.png"``. Plotting is best-effort:
    failures are printed and swallowed so that training continues.

    Args:
        client_indices: Mapping ``client_id -> indices into dataset``. Client
            ids are used as row indices, so they must be ``0..num_clients-1``.
        num_clients: Number of bars to draw.
        client_name: Prefix of the output file name. The original code
            referenced ``self.cid`` here, which does not exist in a free
            function, so every call failed with NameError and nothing was saved.
        dataset: Dataset the indices refer to. Defaults to the module-level
            ``imagenet_subset``.
    """
    fig = None
    try:
        if dataset is None:
            dataset = imagenet_subset
        idx_to_class = dict(enumerate(dataset.classes))
        num_classes = len(idx_to_class)
        # Prefer the label list (e.g. ImageFolder.targets) over dataset[idx],
        # which loads and transforms every image just to read its label.
        targets = getattr(dataset, "targets", None)

        class_counts = np.zeros((num_clients, num_classes), dtype=int)
        for client_id, indices in client_indices.items():
            for idx in indices:
                label = targets[idx] if targets is not None else dataset[idx][1]
                class_counts[client_id, label] += 1

        # Totals do not depend on the class. Clamping to 1 makes an empty
        # client draw an empty bar instead of NaN proportions.
        total_per_client = np.maximum(class_counts.sum(axis=1), 1)

        fig, ax = plt.subplots(figsize=(12, 6))
        bottom = np.zeros(num_clients)
        for cls in range(num_classes):
            heights = class_counts[:, cls]
            proportions = heights / total_per_client
            bars = ax.bar(range(num_clients), proportions, bottom=bottom, label=idx_to_class[cls])
            for i, bar in enumerate(bars):
                if heights[i] > 0:
                    ax.text(
                        bar.get_x() + bar.get_width() / 2,
                        bottom[i] + proportions[i] / 2,
                        str(heights[i]),
                        ha='center', va='center', fontsize=8, color='white'
                    )
            bottom += proportions
        ax.set_title('Client-Class Distribution (Proportion + Raw Count)')
        ax.set_xlabel('Client ID')
        ax.set_ylabel('Proportion of Samples')
        ax.set_ylim(0, 1.1)
        ax.legend(ncol=2, bbox_to_anchor=(1.05, 1), loc='upper left')
        fig.tight_layout()
        fig.savefig(f"{client_name} Data Distribution.png")
    except Exception as e:
        # Best-effort: a plotting failure must not abort federated training.
        print(f"Client can't plot the data distribution because of : {e}")
    finally:
        if fig is not None:
            plt.close(fig)
        

def unnormalize_image(tensor_img):
    """Invert the ImageNet ``transforms.Normalize`` step of ``transform``.

    Args:
        tensor_img: Normalized image tensor whose channel dimension (size 3)
            is third from last, e.g. ``(3, H, W)`` or ``(B, 3, H, W)``. It may
            live on any device.

    Returns:
        Tensor of the same shape, mapped back to pixel space and clamped to
        ``[0, 1]``.
    """
    # Build the statistics on the input's device so CUDA tensors work too.
    # Previously they were always CPU tensors, so CUDA inputs failed.
    mean = torch.tensor([0.485, 0.456, 0.406], device=tensor_img.device).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=tensor_img.device).view(3, 1, 1)
    img = tensor_img * std + mean
    return torch.clamp(img, 0, 1)

def visualize_clusters(client_clusters, dataset_classes, sample_per_class=500, client="client_0", top_k_images=5):
    """Save per-class cluster overviews and per-cluster example grids.

    For every class in ``client_clusters`` this writes:

    * one figure with a scatter of ``features_2d`` coloured by cluster id,
      next to the class's first image;
    * one figure per cluster showing the ``top_k_images`` images whose raw
      features are closest to that cluster's centroid.

    Args:
        client_clusters: ``class_id -> dict`` as produced by
            ``cluster_per_class``. Keys used: ``features_raw``,
            ``cluster_ids``, ``images`` and ``features_2d``.
        dataset_classes: Sequence mapping a class id to its class name.
        sample_per_class: Currently unused; kept for API compatibility.
        client: Client name used as the file-name prefix.
        top_k_images: Number of centroid-nearest images shown per cluster.
    """
    to_pil = transforms.ToPILImage()
    for class_id, data in client_clusters.items():
        class_name = dataset_classes[class_id]
        features = np.array(data['features_raw'])
        cluster_ids = np.array(data['cluster_ids'])
        imgs = data['images']
        features_2d = data['features_2d']

        # Overview: 2-D scatter coloured by cluster, plus one example image.
        fig, axs = plt.subplots(1, 2, figsize=(12, 5), gridspec_kw={'width_ratios': [4, 1]})
        try:
            scatter = axs[0].scatter(features_2d[:, 0], features_2d[:, 1], c=cluster_ids, cmap='tab10', s=15, alpha=0.8)
            axs[0].set_title(f"Class {class_id} - {class_name} - Cluster View")
            axs[0].set_xlabel("UMAP 1")
            axs[0].set_ylabel("UMAP 2")
            axs[0].grid(True)
            fig.colorbar(scatter, ax=axs[0], ticks=np.unique(cluster_ids))
            axs[1].imshow(to_pil(unnormalize_image(imgs[0])))
            axs[1].set_title(f"Sample\n({class_name})")
            axs[1].axis("off")
            fig.tight_layout()
            # This file name used to start with a stray space (" In ...").
            fig.savefig(f"In {client} : Class {class_id} ({class_name}) Clustering with example")
        finally:
            plt.close(fig)

        for cluster_id in np.unique(cluster_ids):
            mask = cluster_ids == cluster_id
            cluster_feats = features[mask]
            cluster_imgs = [img for img, keep in zip(imgs, mask) if keep]
            # Show the images nearest to the cluster centroid in raw feature space.
            centroid = cluster_feats.mean(axis=0)
            dists = np.linalg.norm(cluster_feats - centroid, axis=1)
            top_indices = np.argsort(dists)[:top_k_images]

            # squeeze=False always returns a 2-D array of Axes, even for one panel.
            fig_cluster, axs_grid = plt.subplots(1, top_k_images, figsize=(top_k_images * 3, 3), squeeze=False)
            try:
                row = axs_grid[0]
                # Hide every panel first, so panels left unused by clusters
                # smaller than top_k_images stay blank.
                for ax in row:
                    ax.axis("off")
                for ax, idx in zip(row, top_indices):
                    ax.imshow(to_pil(unnormalize_image(cluster_imgs[idx])))
                fig_cluster.suptitle(f"Class {class_id} ({class_name}) - Cluster {cluster_id}")
                fig_cluster.tight_layout()
                fig_cluster.savefig(f"In {client} : Class {class_id} ({class_name}) - Cluster {cluster_id}")
            finally:
                plt.close(fig_cluster)

# Feature extraction
@torch.no_grad()
def extract_features_dinov2(dataloader, model, device="cuda"):
    """Run ``model`` over ``dataloader`` and collect features, images and labels.

    Args:
        dataloader: Iterable of ``(images, labels)`` batches.
        model: Feature extractor. ``model(images)`` must return one feature
            row per image.
        device: Device the model and each batch are moved to.

    Returns:
        ``(features, images, labels)``: CPU tensors concatenated over all
        batches, with labels 1-D.

    Raises:
        ValueError: if the loader is empty, the labels are not 1-D, or the
            output sizes disagree.
        Exception: any error raised by the model.

    Errors are printed and then re-raised. Previously they were swallowed and
    ``None`` was returned, which made the caller's tuple unpacking fail with
    an unrelated TypeError.
    """
    try:
        model = model.to(device).eval()
        features, images, labels = [], [], []
        for imgs, lbls in dataloader:
            # Keep a CPU copy up front instead of copying back from the device.
            images.append(imgs.cpu())
            labels.append(lbls.cpu())
            features.append(model(imgs.to(device, non_blocking=True)).cpu())
        if not features:
            raise ValueError("dataloader yielded no batches")
        features = torch.cat(features)
        images = torch.cat(images)
        labels = torch.cat(labels)
        if labels.dim() != 1:
            raise ValueError(f"Labels must be 1D, got shape {labels.shape}")
        if not (images.shape[0] == labels.shape[0] == features.shape[0]):
            raise ValueError("Mismatched output sizes")
        return features, images, labels
    except Exception as e:
        print(f"Client can't extract the Dino features because of : {e}")
        raise
    finally:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    
# Clustering
def cluster_per_class(images, labels, features, sample_fraction=1.0, max_k=10):
    """Cluster each class's feature vectors into candidate concepts.

    For every class present in ``labels``:

    1. keep the first ``max(int(n * sample_fraction), 10)`` samples;
    2. embed the raw features with a 5-D UMAP;
    3. pick k in ``[2, max_k]`` at the elbow of the KMeans inertia curve
       (falling back to 2);
    4. assign clusters with a final KMeans.

    Classes with fewer than 5 samples, or for which UMAP or the final KMeans
    fail, are skipped with a message.

    Args:
        images: Sequence of image tensors aligned with ``labels``.
        labels: Sequence of scalar label tensors (``.item()`` is called).
        features: Sequence of CPU feature tensors (``.numpy()`` is called).
        sample_fraction: Fraction of each class's samples to keep (at least 10).
        max_k: Largest number of clusters tried.

    Returns:
        ``defaultdict(dict)`` mapping a class id to ``{"cluster_ids",
        "features" (PCA keeping 95% variance), "features_2d" (2-D projection
        for plotting), "features_raw", "umap" (5-D embedding), "images"}``.
    """
    client_clusters = defaultdict(dict)
    data = defaultdict(list)
    for img, lbl, feat in zip(images, labels, features):
        data[lbl.item()].append((img, feat))

    for class_id, items in data.items():
        if len(items) < 5:
            print(f"Skipping class {class_id}: only {len(items)} samples")
            continue
        # Keep at least min(len(items), 10) samples. A former "too few after
        # sampling" check was unreachable here, because len(items) >= 5.
        sample_count = max(int(len(items) * sample_fraction), 10)
        items = items[:sample_count]
        imgs = [img for img, _ in items]
        feats = np.stack([feat.numpy() for _, feat in items])
        # NOTE(review): the PCA projection is only stored; both UMAP fits below
        # use the raw features. Confirm which space was intended.
        features_pca = PCA(n_components=0.95).fit_transform(feats)
        safe_n_neighbors = min(20, max(2, len(features_pca) - 1))
        try:
            features_umap = UMAP.UMAP(n_components=5, n_neighbors=safe_n_neighbors, random_state=42).fit_transform(feats)
        except Exception as e:
            print(f"UMAP failed on class {class_id}: {e}")
            continue

        # Elbow method: choose the k where the inertia curve bends.
        k_range = list(range(2, min(max_k, len(features_umap)) + 1))
        inertias = [KMeans(n_clusters=k, random_state=42).fit(features_umap).inertia_ for k in k_range]
        try:
            elbow = KneeLocator(k_range, inertias, curve='convex', direction='decreasing').elbow
            num_clusters = elbow if elbow is not None else 2
        except Exception:
            num_clusters = 2

        try:
            cluster_ids = KMeans(n_clusters=num_clusters, random_state=42).fit(features_umap).labels_
        except Exception as e:
            print(f"KMeans failed on class {class_id}: {e}")
            continue

        # 2-D projection for visualization only. An error here used to escape
        # the function (unlike the other steps) and abort the client. Fall
        # back to the first two UMAP components so the class is still
        # clustered and scored.
        try:
            features_2d = UMAP.UMAP(n_components=2, n_neighbors=safe_n_neighbors, random_state=42).fit_transform(feats)
        except Exception as e:
            print(f"2-D UMAP failed on class {class_id}: {e}; using the first two UMAP components")
            features_2d = features_umap[:, :2]

        client_clusters[class_id] = {
            "cluster_ids": cluster_ids,
            "features": features_pca,
            "features_2d": features_2d,
            "features_raw": feats,
            "umap": features_umap,
            "images": imgs
        }
    return client_clusters
    
def compute_saliency(img_tensor, model):
    """Return a signed, robustly normalized saliency map for one image.

    The saliency is the gradient of the model's top-class logit with respect
    to the input pixels, averaged over channels (signs are kept). It is then
    divided by its 99th-percentile absolute value and clamped to ``[-1, 1]``.

    Args:
        img_tensor: Single image of shape ``(C, H, W)``.
        model: Classifier returning logits ``(1, num_classes)`` for a batch of one.

    Returns:
        CPU tensor of shape ``(H, W)`` with values in ``[-1, 1]``.
    """
    # Run on the model's device instead of assuming CUDA.
    try:
        model_device = next(model.parameters()).device
    except StopIteration:
        model_device = torch.device("cpu")
    x = img_tensor.detach().unsqueeze(0).to(model_device).requires_grad_(True)

    # enable_grad so this also works when called inside a no_grad context.
    with torch.enable_grad():
        output = model(x)
        pred_class = output.argmax(dim=1).item()
        score = output[0, pred_class]
        # autograd.grad computes only the input gradient. backward() would
        # also accumulate .grad on every model parameter as a side effect.
        (grad,) = torch.autograd.grad(score, x)

    saliency = grad.squeeze(0).mean(dim=0).cpu()  # (H, W)

    # Robust normalization by the 99th percentile of |saliency|.
    vmax = torch.quantile(saliency.abs(), 0.99)
    saliency = saliency / (vmax + 1e-10)
    return torch.clamp(saliency, -1, 1)

def overlay_saliency_on_image(image_tensor, saliency_map, alpha=0.5):
    """Blend an RdYlGn-coloured saliency map over a normalized image.

    Args:
        image_tensor: Normalized image of shape ``(3, H, W)``, as produced by
            ``transform``. It may be on any device.
        saliency_map: ``(H, W)`` tensor with values in ``[-1, 1]``, e.g. from
            ``compute_saliency``.
        alpha: Weight of the saliency colours in the blend (0 gives the image
            only, 1 gives the saliency only).

    Returns:
        ``PIL.Image`` of the blended overlay.
    """
    import matplotlib.cm as cm

    # Move to CPU before un-normalizing; the result is already in [0, 1].
    image = unnormalize_image(image_tensor.detach().cpu())

    # Map [-1, 1] to [0, 1] for the colormap and drop the alpha channel.
    sal = saliency_map.detach().cpu().numpy()
    rgba = cm.RdYlGn((sal + 1) / 2.0)
    saliency_color = torch.tensor(rgba[..., :3], dtype=image.dtype).permute(2, 0, 1)  # (C, H, W)

    overlay = torch.clamp((1 - alpha) * image + alpha * saliency_color, 0, 1)
    # `to_pil_image` is not imported by this notebook (NameError); use the
    # torchvision transform that is already imported.
    return transforms.ToPILImage()(overlay)


def compute_tcav_score(cav, classwise_feats):
    """Score how strongly a concept direction is present in each class.

    For each class, the score is the fraction of that class's feature vectors
    whose dot product with ``cav`` is positive, rounded to 4 decimals. This is
    a directional, activation-space variant of TCAV; no model gradients are
    involved.

    Args:
        cav: Concept activation vector of shape ``(D,)``.
        classwise_feats: Mapping ``class_id -> array (N, D)`` of features.

    Returns:
        ``{str(class_id): score}``. Classes with no features score ``0.0``.
    """
    scores = {}
    for class_id, feats in classwise_feats.items():
        if len(feats) == 0:
            scores[str(class_id)] = 0.0
            continue
        # Only the sign of each dot product matters, and it is invariant to
        # positive rescaling, so the former whole-matrix normalization never
        # affected the score.
        dot_products = np.dot(feats, cav)
        scores[str(class_id)] = round(float((dot_products > 0).mean()), 4)
    # The old message interpolated the loop variable `class_id` here, which
    # raised NameError for an empty input and named only the last class.
    print(f"[compute_tcav_score] TCAV scores - relevance of the concept to each class -: {scores}")
    return scores


def get_classwise_feats_from_clusters(client_clusters):
    """Map each clustered class id to a stacked array of its raw features.

    Args:
        client_clusters: ``class_id -> dict`` containing a ``features_raw`` entry.

    Returns:
        Dict ``class_id -> np.ndarray`` produced by ``np.stack(features_raw)``.
    """
    classwise = {}
    for cls_id, cluster_info in client_clusters.items():
        classwise[cls_id] = np.stack(cluster_info['features_raw'])
    return classwise

def train_cluster_cav(cluster_feats, negative_feats):
    """Train a linear concept classifier and return its unit-norm direction.

    Args:
        cluster_feats: ``(P, D)`` features of the concept (label 1).
        negative_feats: ``(Q, D)`` counter-example features (label 0).

    Returns:
        ``(cav, clf)``. ``cav`` is the logistic-regression weight vector
        scaled to unit L2 norm (the ``1e-8`` term guards against a zero
        vector). ``clf`` is the fitted ``LogisticRegression``.

    Raises:
        ValueError: if either set is empty, since the classifier needs both labels.
    """
    if len(cluster_feats) == 0 or len(negative_feats) == 0:
        raise ValueError("train_cluster_cav needs at least one positive and one negative sample")

    X = np.vstack([cluster_feats, negative_feats])
    y = np.array([1] * len(cluster_feats) + [0] * len(negative_feats))

    clf = LogisticRegression(max_iter=1000).fit(X, y)
    cav = clf.coef_[0]
    cav_normalized = cav / (np.linalg.norm(cav) + 1e-8)

    print(f"[train_cluster_cav] Positive samples: {len(cluster_feats)}, Negative samples: {len(negative_feats)}")
    return cav_normalized, clf

def build_concept_signature(cluster_feats):
    """Summarize a concept cluster by the moments of all its feature values.

    All features are flattened into a single distribution, which is described
    by its mean, variance, skewness and (Fisher) kurtosis, each rounded to 5
    decimals.

    Args:
        cluster_feats: NumPy array of the cluster's features (any shape).

    Returns:
        Dict with keys ``mean``, ``variance``, ``skewness`` and ``kurtosis``.
    """
    values = cluster_feats.reshape(-1)
    moment_fns = (
        ("mean", np.mean),
        ("variance", np.var),
        ("skewness", skew),
        ("kurtosis", kurtosis),
    )
    signature = {name: round(fn(values), 5) for name, fn in moment_fns}
    print(f"[build_concept_signature] Concept signature: {signature}")
    return signature


def process_clusters_for_tcav(client_clusters, classwise_feats, model, client_id="client_0"):
    """Turn every cluster of every class into a JSON-serializable concept payload.

    For each cluster, the raw features are the positives. The negatives are
    ``len(positives)`` rows drawn with replacement from NumPy's global RNG:
    about half from the other clusters of the same class and the rest from
    all other classes (an empty pool contributes nothing). A CAV is trained
    on them, scored against every class with ``compute_tcav_score``, and
    summarized with ``build_concept_signature``. Clusters with fewer than 2
    positives, or with no negatives available, are skipped.

    Args:
        client_clusters: ``class_id -> dict`` from ``cluster_per_class``.
        classwise_feats: ``class_id -> (N, D)`` features used for TCAV scoring.
        model: Unused; kept for API compatibility.
        client_id: Prefix of each ``concept_id``.

    Returns:
        List of dicts with ``concept_id`` (``"<client>_class<k>_cluster<j>"``),
        ``concept_size``, ``signature``, ``tcav_scores``, ``true_class`` and
        ``concept_accuracy``. ``concept_accuracy`` is 1.0 when the
        highest-scoring class is the concept's own class, and 0.0 otherwise.
    """
    concept_payloads = []
    print(f"\n🔍 [START] Processing clusters for TCAV (Client: {client_id})")
    total_clusters = 0
    skipped_clusters = 0

    for class_id, cluster_data in client_clusters.items():
        feats = np.array(cluster_data['features_raw'])
        cluster_ids = np.array(cluster_data['cluster_ids'])
        dim = feats.shape[1]
        # Inter-class negatives depend only on the class, so stack them once
        # per class instead of once per cluster.
        other_feats = [data['features_raw'] for other_class, data in client_clusters.items() if other_class != class_id]
        inter_neg_feats = np.vstack(other_feats) if other_feats else np.empty((0, dim))

        for cluster_idx in np.unique(cluster_ids):
            total_clusters += 1
            pos_mask = cluster_ids == cluster_idx
            pos_feats = feats[pos_mask]
            print(f"  🔸 Cluster {cluster_idx}: {len(pos_feats)} samples")
            if len(pos_feats) < 2:
                print(f"⏩ Skipping cluster {cluster_idx} in class {class_id}: too few positives")
                skipped_clusters += 1
                continue

            # Intra-class negatives come from the other clusters of the same class.
            intra_neg_feats = feats[~pos_mask]
            intra_count = len(pos_feats) // 2
            inter_count = len(pos_feats) - intra_count
            if len(intra_neg_feats) > 0:
                intra_sample = intra_neg_feats[np.random.choice(len(intra_neg_feats), size=intra_count, replace=True)]
            else:
                intra_sample = np.empty((0, dim))
            if len(inter_neg_feats) > 0:
                inter_sample = inter_neg_feats[np.random.choice(len(inter_neg_feats), size=inter_count, replace=True)]
            else:
                inter_sample = np.empty((0, dim))
            neg_feats = np.vstack([intra_sample, inter_sample])
            if len(neg_feats) == 0:
                # The reason here is missing negatives, not too few positives.
                print(f"⏩ Skipping cluster {cluster_idx} in class {class_id}: no negative samples available")
                skipped_clusters += 1
                continue

            print("Computing CAV...")
            cav, _ = train_cluster_cav(pos_feats, neg_feats)
            print("Computing TCAV...")
            tcav_scores = compute_tcav_score(cav, classwise_feats)
            # The signature is built from feature moments, not from gradients.
            print("Computing concept signature...")
            signature = build_concept_signature(pos_feats)

            top_class = max(tcav_scores.items(), key=lambda item: item[1])[0]
            concept_payloads.append({
                "concept_id": f"{client_id}_class{class_id}_cluster{cluster_idx}",
                "concept_size": len(pos_feats),
                "signature": {k: float(v) for k, v in signature.items()},
                "tcav_scores": {k: float(v) for k, v in tcav_scores.items()},
                "true_class": str(class_id),
                "concept_accuracy": float(top_class == str(class_id)),
            })

    print(f"✅ [END] {client_id}: {len(concept_payloads)} concepts from {total_clusters} clusters ({skipped_clusters} skipped)")
    return concept_payloads

def map_local_concepts_to_global(local_concepts, global_dict, threshold=1):
    """Match each local concept to its nearest global concept by signature.

    Signatures are compared as 4-vectors (mean, variance, skewness, kurtosis)
    using Euclidean distance. A local concept is mapped to its closest global
    concept (the first one wins on ties) if that distance is ``<= threshold``;
    otherwise it is reported as unmatched. Several local concepts may map to
    the same global concept.

    NOTE(review): ``FederatedConceptServer`` builds global signatures by
    averaging StandardScaler-normalized metrics, whereas
    ``build_concept_signature`` produces raw moments, so these distances may
    compare different scales. Confirm this is intended.

    Args:
        local_concepts: Payload dicts with ``concept_id`` and ``signature``.
        global_dict: ``gc_id -> {"signature": {...}, ...}``.
        threshold: Maximum distance for a match.

    Returns:
        ``(mapping, unmatched_local, unmatched_global)``: ``local id -> gc id``,
        the list of unmatched local ids, and the list of global ids that no
        local concept matched.
    """
    sig_keys = ("mean", "variance", "skewness", "kurtosis")
    gc_signature_map = {
        gcid: [info["signature"][key] for key in sig_keys]
        for gcid, info in global_dict.items()
    }

    mapping = {}
    unmatched_local = []
    match_details = []
    for concept in local_concepts:
        concept_vec = [concept['signature'][key] for key in sig_keys]
        nearest_id, nearest_dist = None, float('inf')
        for gcid, gc_vec in gc_signature_map.items():
            candidate_dist = euclidean(concept_vec, gc_vec)
            if candidate_dist < nearest_dist:
                nearest_id, nearest_dist = gcid, candidate_dist

        if nearest_dist <= threshold:
            mapping[concept["concept_id"]] = nearest_id
            match_details.append({
                "local_concept": concept["concept_id"],
                "matched_global_concept": nearest_id,
                "distance": round(nearest_dist, 4),
            })
        else:
            unmatched_local.append(concept["concept_id"])

    matched_ids = set(mapping.values())
    unmatched_global = [gcid for gcid in global_dict if gcid not in matched_ids]

    print(f"\n[Client Matching] Matched: {len(mapping)}, Unmatched Local: {len(unmatched_local)}, Unmatched Global: {len(unmatched_global)}")
    print("\n🔗 Matched Pairs:")
    for detail in match_details:
        print(f"  - {detail['local_concept']} ↔ {detail['matched_global_concept']} (distance={detail['distance']})")

    return mapping, unmatched_local, unmatched_global


**FedCAPEClient**

In [3]:
class FedCAPENumpyClient(NumPyClient):
    """Flower client that extracts, scores and shares concepts from local data.

    ``fit`` embeds the client's images with DINOv2 ViT-S/14, clusters each
    class into concepts, scores them with CAVs, matches them against the
    server's global concept dictionary, and returns the concepts as JSON in
    the fit metrics. Model weights are only passed through, never trained.
    """

    def __init__(self, cid, dataset, indices, model_path):
        self.cid = f"client_{cid}"
        self.cid_int = cid
        self.dataset = Subset(dataset, indices)
        # NOTE(review): stored but unused; _load_model always loads from torch.hub.
        self.model_path = model_path
        self.model = None  # loaded lazily by _load_model()
        self.concepts = []
        self.global_dict = None
        # Filled in by fit(); initialized here so the attribute always exists.
        self.local_to_global_mapping = {}
        print(f"[{self.cid}] Client initialized with {len(self.dataset)} samples.")

    def _load_model(self):
        """Load DINOv2 ViT-S/14 from torch.hub on first use."""
        if self.model is None:
            print(f"[{self.cid}] Loading model...")
            self.model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14').to(device).eval()

    def get_parameters(self, config):
        """Return the model parameters (not buffers) as NumPy arrays."""
        self._load_model()
        return [val.cpu().detach().numpy() for val in self.model.parameters()]

    def set_parameters(self, parameters):
        """Load ``parameters``, ordered as returned by ``get_parameters``."""
        self._load_model()
        # get_parameters() yields model.parameters(), which follows
        # named_parameters() order. state_dict() keys also include buffers,
        # so zipping against them could pair arrays with the wrong names.
        param_names = [name for name, _ in self.model.named_parameters()]
        state_dict = dict(zip(param_names, map(torch.tensor, parameters)))
        self.model.load_state_dict(state_dict, strict=False)

    def _parse_global_dict(self, config):
        """Decode the JSON global concept dictionary sent in the fit config."""
        if "global_dict" not in config:
            print(f"[{self.cid}] ⚠️ Warning: Global dictionary not found in config.")
            return {}
        raw = config["global_dict"]
        global_dict = json.loads(raw) if raw else {}
        # Check the decoded dict: the JSON string "{}" has length 2, so the
        # old len(config["global_dict"]) == 0 test never fired.
        if not global_dict:
            print(f"[{self.cid}] ⚠️ Warning: Global dictionary is empty!")
        print(f"[{self.cid}] ✅ Received global_dict with {len(global_dict)} entries.")
        return global_dict

    def analyze_global_local_differences(self, local_concepts, global_dict, threshold=0.5):
        """Match local concepts to global ones and print a short summary.

        Returns:
            ``(mapping, unmatched_local, unmatched_global)`` from
            ``map_local_concepts_to_global``.
        """
        # The message previously printed threshold=1 while 0.5 was used.
        print(f"[{self.cid}] 🧩 Analyzing alignment with global dictionary (threshold={threshold})")
        mapping, unmatched_local, unmatched_global = map_local_concepts_to_global(local_concepts, global_dict, threshold=threshold)

        print(f"[{self.cid}] ✅ Alignment Summary:")
        print(f"  • Total Local Concepts: {len(local_concepts)}")
        print(f"  • Matched Concepts: {len(mapping)}")
        print(f"  • Unmatched Local: {len(unmatched_local)} → {unmatched_local[:5]}")
        print(f"  • Unmatched Global: {len(unmatched_global)} → {unmatched_global[:5]}")

        return mapping, unmatched_local, unmatched_global

    def fit(self, parameters, config):
        """Extract and score local concepts; return them in the fit metrics.

        Returns:
            ``(parameters, num_examples, {"concepts": <JSON list of payloads>})``.
            The parameters are returned unchanged.
        """
        try:
            print(f"[{self.cid}] 🔧 Starting fit.")
            self.set_parameters(parameters)  # also loads the model on first use
            self.global_dict = self._parse_global_dict(config)

            loader = DataLoader(self.dataset, batch_size=128, shuffle=False)
            features, images, labels = extract_features_dinov2(loader, self.model, device)
            # Subset.indices are exactly this client's indices into the base
            # dataset, so there is no need for a notebook-global `client_indices`.
            plot_real_client_distributions({0: list(self.dataset.indices)}, num_clients=1)
            clusters = cluster_per_class(images, labels, features, sample_fraction=1.0)
            print(f"printing Imagenet subnet classes Clustering for Client {self.cid}")
            visualize_clusters(clusters, dataset_classes=self.dataset.dataset.classes, client=self.cid, sample_per_class=500, top_k_images=5)
            feats_by_class = get_classwise_feats_from_clusters(clusters)
            self.concepts = process_clusters_for_tcav(clusters, feats_by_class, self.model, self.cid)
            if self.concepts:
                self.local_to_global_mapping, _, _ = self.analyze_global_local_differences(self.concepts, self.global_dict)
            concepts_json = json.dumps(self.concepts)
            print(f"the fetched concepts Length is {len(concepts_json)}")
            print(f"[{self.cid}] 🔧 Ending fit.")
            return self.get_parameters({}), len(self.dataset), {"concepts": concepts_json}
        except Exception as e:
            print(f"[{self.cid}] ❌ Exception in fit(): {e}")
            raise

    def evaluate(self, parameters, config):
        """Report the share of concepts whose top TCAV class is their own class.

        Concepts without TCAV scores are ignored. Loss is ``1 - accuracy``.
        With no concepts at all, returns ``(1.0, 1, {"tcav_accuracy": 0.0})``.
        """
        self.set_parameters(parameters)
        if not self.concepts and "concepts" in config:
            self.concepts = json.loads(config["concepts"])
        if not self.concepts:
            return 1.0, 1, {"tcav_accuracy": 0.0}
        matches = total = 0
        for concept in self.concepts:
            try:
                if not concept.get("tcav_scores"):
                    continue
                # concept_accuracy is 1.0 when the top TCAV class equals the
                # concept's own class, otherwise 0.0.
                if concept.get("concept_accuracy", 0.0) > 0.5:
                    matches += 1
                total += 1
            except (AttributeError, TypeError):
                # Malformed payload (not a dict, or a non-numeric accuracy).
                continue
        accuracy = matches / max(total, 1)
        loss = 1.0 - accuracy
        return loss, total, {"tcav_accuracy": accuracy}

**FedCAPEServer**

In [4]:
# FedCAPE Global Concept Matching with Ordered TCAV and Debuggable Output

from collections import defaultdict, Counter
import numpy as np
import uuid
from sklearn.preprocessing import StandardScaler
from scipy.spatial.distance import euclidean
from sklearn.metrics.pairwise import cosine_similarity
import json

class FederatedConceptServer:
    """Align concepts from two clients into a global concept dictionary.

    Concept signatures (mean, variance, skewness, kurtosis) from all clients
    are standardized together. Concepts of client 0 and client 1 whose top
    TCAV class equals their own class are then paired greedily by increasing
    Euclidean distance, using each concept at most once, as long as the
    distance is within ``similarity_threshold``. Each pair becomes one global
    concept.
    """

    _SIG_KEYS = ("mean", "variance", "skewness", "kurtosis")

    def __init__(self, similarity_threshold=1):
        # Distances are in standardized units: 1 means the signature vectors
        # are within roughly one standard deviation of each other.
        self.similarity_threshold = similarity_threshold

    def _extract_vectors(self, concept):
        """Return a concept's signature as a float32 vector (currently unused)."""
        sig = concept["signature"]
        return np.array([sig[key] for key in self._SIG_KEYS], dtype=np.float32)

    @staticmethod
    def _client_id(concept_id):
        """Return the client prefix of a ``"<client>_class<k>_cluster<j>"`` id."""
        # split('_')[0] returned just "client" for ids like
        # "client_0_class3_cluster1", which collapsed all clients into one id.
        return concept_id.split('_class')[0]

    @staticmethod
    def _class_from_id(concept_id):
        """Return the ``<k>`` part of a ``"<client>_class<k>_cluster<j>"`` id."""
        return concept_id.split('_class')[1].split('_')[0]

    @staticmethod
    def _top_class(tcav_scores):
        """Return the class key with the highest TCAV score."""
        return max(tcav_scores, key=tcav_scores.get)

    def extract_metrics(self, payloads):
        """Split concept payloads into a signature matrix and per-concept metadata.

        Returns:
            ``(metrics, metadata)``. ``metrics`` has shape ``(N, 4)``; it stays
            ``(0, 4)`` for an empty ``payloads``, so it can still be stacked
            with the other clients' metrics.
        """
        metrics, metadata = [], []
        for payload in payloads:
            sig = payload['signature']
            concept_id = payload['concept_id']
            metrics.append([sig[key] for key in self._SIG_KEYS])
            metadata.append({
                'concept_id': concept_id,
                'concept_size': payload['concept_size'],
                'tcav_scores': payload['tcav_scores'],
                'concept_accuracy': payload.get('concept_accuracy', 0.0),
                # Parse the id only when true_class is missing. The old eager
                # .get() default failed on ids without "_class" even then.
                'true_class': payload['true_class'] if 'true_class' in payload else self._class_from_id(concept_id),
                'client_id': self._client_id(concept_id),
            })
        return np.array(metrics, dtype=float).reshape(-1, 4), metadata

    def normalize_metrics(self, metrics):
        """Standardize each signature component; return ``(normalized, scaler)``."""
        scaler = StandardScaler()
        return scaler.fit_transform(metrics), scaler

    def _merge_pair(self, distance, m0, m1, meta0, meta1):
        """Merge two matched local concepts into one global-concept dict."""
        s0, s1 = meta0['tcav_scores'], meta1['tcav_scores']
        # Sorted keys make the score order and the dominant-class tie-breaking
        # deterministic; set iteration order depends on string hashing.
        classes = sorted(set(s0) | set(s1), key=str)
        avg_scores = {k: (s0.get(k, 0.0) + s1.get(k, 0.0)) / 2 for k in classes}
        dominant_class = max(avg_scores, key=avg_scores.get)

        def contributor(meta):
            return {
                'client_id': meta['client_id'],
                'concept_id': meta['concept_id'],
                'concept_size': meta['concept_size'],
                'tcav_scores': meta['tcav_scores'],
                'concept_accuracy': meta['concept_accuracy'],
            }

        return {
            'global_concept_id': str(uuid.uuid4()),
            # Mean of the two signature rows (in the caller's metric space).
            'signature': {key: float((m0[k] + m1[k]) / 2) for k, key in enumerate(self._SIG_KEYS)},
            'tcav_scores': avg_scores,
            'dominant_class': str(dominant_class),
            'contributing_concepts': [contributor(meta0), contributor(meta1)],
            'distance': float(distance),
        }

    def compute_similarities(self, metrics_0, metrics_1, meta_0, meta_1, threshold):
        """Greedily pair client-0 and client-1 concepts into global concepts.

        Only concepts whose top TCAV class equals their ``true_class`` are
        eligible. All eligible pairs within ``threshold`` (Euclidean distance
        between signature rows) are sorted by distance and accepted closest
        first, using each concept at most once.

        Returns:
            List of global-concept dicts, as built by ``_merge_pair``.
        """
        def eligible(meta):
            return str(self._top_class(meta['tcav_scores'])) == str(meta['true_class'])

        candidates_1 = [(j, m1) for j, (m1, meta1) in enumerate(zip(metrics_1, meta_1)) if eligible(meta1)]
        candidate_pairs = []
        for i, (m0, meta0) in enumerate(zip(metrics_0, meta_0)):
            if not eligible(meta0):
                continue
            for j, m1 in candidates_1:
                distance = euclidean(m0, m1)
                if distance <= threshold:
                    candidate_pairs.append((distance, i, j))

        # Stable sort on distance only, so ties keep their discovery order.
        candidate_pairs.sort(key=lambda pair: pair[0])
        used_0, used_1 = set(), set()
        global_concepts = []
        for distance, i, j in candidate_pairs:
            if i in used_0 or j in used_1:
                continue
            used_0.add(i)
            used_1.add(j)
            global_concepts.append(self._merge_pair(distance, metrics_0[i], metrics_1[j], meta_0[i], meta_1[j]))
        return global_concepts

    def evaluate_global_dictionary(self, global_dict):
        """Print how many global concepts contain only their dominant class."""
        total_groups = len(global_dict)
        pure_groups = total_members = 0
        mismatch_samples = []

        for gc_id, gc_info in global_dict.items():
            dominant = str(gc_info['dominant_class'])
            true_classes = [str(self._class_from_id(cid)) for cid in gc_info['concept_ids']]
            total_members += len(true_classes)
            if all(cls == dominant for cls in true_classes):
                pure_groups += 1
            else:
                mismatch_samples.append({"gc_id": gc_id, "distribution": dict(Counter(true_classes))})

        purity = pure_groups / total_groups if total_groups > 0 else 0
        avg_members = total_members / total_groups if total_groups > 0 else 0
        print(f"\n📊 [Global Matching Evaluation]: Total={total_groups}, Avg Members={avg_members:.2f}, Purity={purity*100:.2f}%")
        if mismatch_samples:
            print("\n⚡ Mixed-class groups:")
            for m in mismatch_samples[:5]:
                print(f"  - {m['gc_id']}: {m['distribution']}")

    def receive_concepts(self, client_concepts_list):
        """Build, report and return the global concept dictionary.

        Args:
            client_concepts_list: One list of concept payloads per client.

        Returns:
            ``{"GC_<i>": {...}}`` holding, for every matched pair, the merged
            signature (in standardized units), averaged TCAV scores, dominant
            class, contributors, concept ids and distance. Also writes
            ``class_concept_map.json``.

        Raises:
            ValueError: if fewer than two clients are given.
        """
        if len(client_concepts_list) < 2:
            raise ValueError("Need at least two clients to align concepts")
        if len(client_concepts_list) > 2:
            # All clients contribute to normalization, but only the first two
            # are paired. This was already the behavior; it is now reported.
            print(f"⚠️ [Server] Only the first 2 of {len(client_concepts_list)} clients are matched; the others only affect normalization.")

        all_metrics, all_metadata = [], []
        for concepts in client_concepts_list:
            metrics, meta = self.extract_metrics(concepts)
            all_metrics.append(metrics)
            all_metadata.append(meta)

        combined_metrics = np.vstack(all_metrics)
        # StandardScaler rejects an empty matrix; with no concepts there is
        # nothing to normalize.
        if len(combined_metrics):
            normalized_metrics, _ = self.normalize_metrics(combined_metrics)
        else:
            normalized_metrics = combined_metrics

        # Slice exactly client 1's rows. Previously split_1 also held later
        # clients' rows, which the zip with client 1's metadata silently dropped.
        n0, n1 = len(all_metrics[0]), len(all_metrics[1])
        split_0 = normalized_metrics[:n0]
        split_1 = normalized_metrics[n0:n0 + n1]

        global_concepts = self.compute_similarities(
            split_0, split_1, all_metadata[0], all_metadata[1], threshold=self.similarity_threshold
        )

        ordered_all = sorted(global_concepts, key=lambda gc: max(gc['tcav_scores'].values()), reverse=True)
        print("\n🌐 Top 5 contributing global concepts (by TCAV max score):")
        for gc in ordered_all[:5]:
            print(f"  - Global Concept: {gc['global_concept_id']} (TCAV max={max(gc['tcav_scores'].values()):.3f})")
            for contrib in gc['contributing_concepts']:
                print(f"    - {contrib['concept_id']} from {contrib['client_id']} (size={contrib['concept_size']})")
                for cls, score in contrib['tcav_scores'].items():
                    print(f"        → Class {cls} TCAV score: {score:.4f}")
                print(f"        • Concept Accuracy: {contrib['concept_accuracy']:.3f}")

        global_dict = {}
        for i, gc in enumerate(global_concepts):
            global_dict[f"GC_{i}"] = {
                "global_concept_id": gc["global_concept_id"],
                "signature": gc["signature"],
                "tcav_scores": gc["tcav_scores"],
                "dominant_class": gc["dominant_class"],
                "contributing_concepts": gc["contributing_concepts"],
                "concept_ids": [c['concept_id'] for c in gc['contributing_concepts']],
                "distance": gc["distance"],
            }

        self.evaluate_global_dictionary(global_dict)

        class_to_concepts = defaultdict(list)
        for gc in global_dict.values():
            cls = gc.get('dominant_class')
            class_to_concepts[cls].append((gc['global_concept_id'], gc['tcav_scores'].get(cls, 0.0)))

        # Look concepts up by id instead of a linear search per printed concept.
        gc_by_id = {gc['global_concept_id']: gc for gc in global_concepts}
        print("\n✅ Available Classes:", sorted(class_to_concepts.keys()))
        print("\n📘 Detailed Global Concepts by Class:")
        for cls, concept_scores in class_to_concepts.items():
            print(f"\n🔹 Class {cls}:")
            for gc_id, _ in sorted(concept_scores, key=lambda x: x[1], reverse=True):
                gc = gc_by_id.get(gc_id)
                if gc is None:
                    continue
                sig = gc['signature']
                print(f"  📦 Global Concept {gc_id}")
                print(f"     • Signature: μ={sig['mean']:.5f}, σ²={sig['variance']:.5f}, skew={sig['skewness']:.5f}, kurt={sig['kurtosis']:.5f}")
                print("     • TCAV scores:")
                for k, v in gc['tcav_scores'].items():
                    print(f"       - Class {k}: {v:.4f}")
                print("     • Contributors:")
                for contrib in gc['contributing_concepts']:
                    print(f"       - {contrib['client_id']}::{contrib['concept_id']} (size={contrib['concept_size']})")
                    for k, v in contrib['tcav_scores'].items():
                        print(f"           → Class {k}: {v:.4f}")
                    print(f"           • Concept Accuracy: {contrib['concept_accuracy']:.3f}")
                print(f"     • Distance between contributors: {gc['distance']:.4f}")

        print("\n📌 Global Concepts per Class (Top Contributors):")
        for cls, concept_scores in class_to_concepts.items():
            print(f"  → Class {cls}:")
            for gc_id, score in sorted(concept_scores, key=lambda x: x[1], reverse=True)[:5]:
                print(f"      - {gc_id}: TCAV={score:.4f}")

        with open("class_concept_map.json", "w", encoding="utf-8") as f:
            json.dump({cls: sorted(concept_scores, key=lambda x: x[1], reverse=True)
                       for cls, concept_scores in class_to_concepts.items()}, f, indent=2)

        return global_dict


In [9]:
# FedCAPE Global Concept Matching with Ordered TCAV and Debuggable Output

from collections import defaultdict, Counter
import numpy as np
import uuid
from sklearn.preprocessing import StandardScaler
from scipy.spatial.distance import euclidean
from sklearn.metrics.pairwise import cosine_similarity
import json

class FederatedConceptServer:
    """Align concepts reported by FL clients into shared "global concepts".

    Each client payload describes one concept by a 4-value statistical
    signature (mean, variance, skewness, kurtosis) plus per-class TCAV scores.
    ``receive_concepts`` z-normalizes the signatures of all clients jointly,
    pairs concepts of the first two clients whose normalized signatures are
    within ``similarity_threshold`` (Euclidean distance), prints a report,
    writes ``class_concept_map.json`` and returns the global dictionary.
    """

    # Order of the signature statistics in every metrics row.
    _SIGNATURE_KEYS = ('mean', 'variance', 'skewness', 'kurtosis')

    def __init__(self, similarity_threshold=1):   # meaning the Metrics Vectors are within ~1 standard deviation
        # Maximum Euclidean distance between two z-normalized signature
        # vectors for their concepts to be merged into one global concept.
        self.similarity_threshold = similarity_threshold

    def _extract_vectors(self, concept):
        """Return a concept's signature as a float32 vector [mean, var, skew, kurt]."""
        sig = concept["signature"]
        return np.array([sig["mean"], sig["variance"], sig["skewness"], sig["kurtosis"]], dtype=np.float32)

    def extract_metrics(self, payloads):
        """Split concept payloads into a signature matrix and row-aligned metadata.

        Args:
            payloads: iterable of concept dicts with keys ``signature``
                (``mean``/``variance``/``skewness``/``kurtosis``),
                ``concept_id``, ``concept_size`` and ``tcav_scores``; optional
                ``concept_accuracy`` (default 0.0) and ``true_class``.

        Returns:
            ``(metrics, metadata)``: ``metrics`` is a float array of shape
            ``(len(payloads), 4)`` -- ``(0, 4)`` for no payloads so it still
            stacks with other clients' matrices -- and ``metadata`` is a list
            of dicts aligned row-by-row with ``metrics``.

        Raises:
            IndexError: if a payload has no ``true_class`` and its
                ``concept_id`` contains no ``"_class"`` part to parse it from.
        """
        metrics, metadata = [], []
        for payload in payloads:
            sig = payload['signature']
            concept_id = payload['concept_id']
            metrics.append([sig[key] for key in self._SIGNATURE_KEYS])
            if 'true_class' in payload:
                true_class = payload['true_class']
            else:
                # Parsed lazily: ids look like "<client>_class<label>_...", but
                # payloads carrying an explicit true_class need not follow that.
                true_class = concept_id.split('_class')[1].split('_')[0]
            metadata.append({
                'concept_id': concept_id,
                'concept_size': payload['concept_size'],
                'tcav_scores': payload['tcav_scores'],
                'concept_accuracy': payload.get('concept_accuracy', 0.0),
                'true_class': true_class,
                'client_id': concept_id.split('_')[0],
            })
        return np.array(metrics, dtype=float).reshape(-1, len(self._SIGNATURE_KEYS)), metadata

    def normalize_metrics(self, metrics):
        """Z-normalize each signature column; return ``(normalized, fitted_scaler)``."""
        scaler = StandardScaler()
        return scaler.fit_transform(metrics), scaler

    @staticmethod
    def _predicted_class_matches(meta):
        """Return True when the concept's highest-TCAV class equals its true class.

        Concepts without any TCAV scores have no prediction and never match.
        """
        scores = meta['tcav_scores']
        if not scores:
            return False
        return str(max(scores, key=scores.get)) == str(meta['true_class'])

    @staticmethod
    def _contributor(meta):
        """Return the subset of a concept's metadata stored on a global concept."""
        return {
            'client_id': meta['client_id'],
            'concept_id': meta['concept_id'],
            'concept_size': meta['concept_size'],
            'tcav_scores': meta['tcav_scores'],
            'concept_accuracy': meta['concept_accuracy'],
        }

    @classmethod
    def _merge_pair(cls, m0, m1, meta0, meta1, distance):
        """Build one global-concept record by averaging two matched concepts."""
        scores_0, scores_1 = meta0['tcav_scores'], meta1['tcav_scores']
        # A class missing from one side's scores counts as 0.0 for that side.
        avg_scores = {
            k: (scores_0.get(k, 0.0) + scores_1.get(k, 0.0)) / 2
            for k in set(scores_0) | set(scores_1)
        }
        dominant_class = max(avg_scores, key=avg_scores.get)
        midpoint = (np.asarray(m0, dtype=float) + np.asarray(m1, dtype=float)) / 2
        return {
            'global_concept_id': str(uuid.uuid4()),
            'signature': {key: float(midpoint[i]) for i, key in enumerate(cls._SIGNATURE_KEYS)},
            'tcav_scores': avg_scores,
            'dominant_class': str(dominant_class),
            'contributing_concepts': [cls._contributor(meta0), cls._contributor(meta1)],
            'distance': float(distance),
        }

    def compute_similarities(self, metrics_0, metrics_1, meta_0, meta_1, threshold):
        """Pair client-0 concepts one-to-one with their nearest client-1 concept.

        Only concepts whose highest-TCAV class equals their true class take
        part. Client-0 concepts are visited in input order; each is paired with
        the closest still-unused eligible client-1 concept whose Euclidean
        distance is ``<= threshold``. Every concept joins at most one pair.

        Args:
            metrics_0, metrics_1: (normalized) signature rows of each client.
            meta_0, meta_1: metadata lists aligned with those rows.
            threshold: maximum Euclidean distance for a match.

        Returns:
            List of global-concept dicts (averaged signature and TCAV scores,
            dominant class, both contributors, and their distance).
        """
        global_concepts = []
        used_concepts = set()
        # Client-1 eligibility does not depend on the client-0 concept: filter once.
        candidates_1 = [
            (m1, meta1) for m1, meta1 in zip(metrics_1, meta_1)
            if self._predicted_class_matches(meta1)
        ]

        for m0, meta0 in zip(metrics_0, meta_0):
            if meta0['concept_id'] in used_concepts or not self._predicted_class_matches(meta0):
                continue

            best = None  # (distance, m1, meta1) of the nearest unused candidate
            for m1, meta1 in candidates_1:
                if meta1['concept_id'] in used_concepts:
                    continue
                distance = euclidean(m0, m1)
                if distance <= threshold and (best is None or distance < best[0]):
                    best = (distance, m1, meta1)
            if best is None:
                continue

            distance, m1, meta1 = best
            used_concepts.update([meta0['concept_id'], meta1['concept_id']])
            global_concepts.append(self._merge_pair(m0, m1, meta0, meta1, distance))
        return global_concepts

    def evaluate_global_dictionary(self, global_dict):
        """Print how many global concepts contain only members of their dominant class.

        Member classes are parsed from the ``"_class<label>_"`` part of each
        contributing concept id.
        """
        total_groups = len(global_dict)
        pure_groups = total_members = 0
        mismatch_samples = []

        for gc_id, gc_info in global_dict.items():
            dominant = str(gc_info['dominant_class'])
            true_classes = [str(cid.split("_class")[1].split("_")[0]) for cid in gc_info['concept_ids']]
            total_members += len(true_classes)
            if all(cls == dominant for cls in true_classes):
                pure_groups += 1
            else:
                mismatch_samples.append({"gc_id": gc_id, "distribution": dict(Counter(true_classes))})

        purity = pure_groups / total_groups if total_groups > 0 else 0
        avg_members = total_members / total_groups if total_groups > 0 else 0
        print(f"\n📊 [Global Matching Evaluation]: Total={total_groups}, Avg Members={avg_members:.2f}, Purity={purity*100:.2f}%")
        if mismatch_samples:
            print("\n⚡ Mixed-class groups:")
            for m in mismatch_samples[:5]:
                print(f"  - {m['gc_id']}: {m['distribution']}")

    @staticmethod
    def _score_for_class(tcav_scores, cls):
        """Return ``tcav_scores`` for the (stringified) class ``cls``, else 0.0.

        ``dominant_class`` is stored as ``str``; this also finds it when the
        TCAV score keys are not strings.
        """
        if cls in tcav_scores:
            return tcav_scores[cls]
        return next((v for k, v in tcav_scores.items() if str(k) == cls), 0.0)

    @staticmethod
    def _print_top_global_concepts(global_concepts, limit=5):
        """Print the ``limit`` global concepts with the highest max TCAV score."""
        ordered_all = sorted(global_concepts, key=lambda gc: max(gc['tcav_scores'].values()), reverse=True)
        print("\n🌐 Top 5 contributing global concepts (by TCAV max score):")
        for gc in ordered_all[:limit]:
            print(f"  - Global Concept: {gc['global_concept_id']} (TCAV max={max(gc['tcav_scores'].values()):.3f})")
            for contrib in gc['contributing_concepts']:
                print(f"    - {contrib['concept_id']} from {contrib['client_id']} (size={contrib['concept_size']})")
                for cls, score in contrib['tcav_scores'].items():
                    print(f"        → Class {cls} TCAV score: {score:.4f}")
                print(f"        • Concept Accuracy: {contrib['concept_accuracy']:.3f}")

    @staticmethod
    def _print_class_report(ranked, global_concepts):
        """Print every global concept per class, then the top 5 per class.

        ``ranked`` maps class -> [(global_concept_id, score), ...] sorted by
        descending score.
        """
        by_id = {gc['global_concept_id']: gc for gc in global_concepts}
        print("\n📘 Detailed Global Concepts by Class:")
        for cls, concept_scores in ranked.items():
            print(f"\n🔹 Class {cls}:")
            for gc_id, _ in concept_scores:
                gc = by_id.get(gc_id)
                if gc is None:
                    continue
                sig = gc['signature']
                print(f"  📦 Global Concept {gc_id}")
                print(f"     • Signature: μ={sig['mean']:.5f}, σ²={sig['variance']:.5f}, skew={sig['skewness']:.5f}, kurt={sig['kurtosis']:.5f}")
                print("     • TCAV scores:")
                for k, v in gc['tcav_scores'].items():
                    print(f"       - Class {k}: {v:.4f}")
                print("     • Contributors:")
                for contrib in gc['contributing_concepts']:
                    print(f"       - {contrib['client_id']}::{contrib['concept_id']} (size={contrib['concept_size']})")
                    for k, v in contrib['tcav_scores'].items():
                        print(f"           → Class {k}: {v:.4f}")
                    print(f"           • Concept Accuracy: {contrib['concept_accuracy']:.3f}")
                print(f"     • Distance between contributors: {gc['distance']:.4f}")

        print("\n📌 Global Concepts per Class (Top Contributors):")
        for cls, concept_scores in ranked.items():
            print(f"  → Class {cls}:")
            for gc_id, score in concept_scores[:5]:
                print(f"      - {gc_id}: TCAV={score:.4f}")

    def receive_concepts(self, client_concepts_list):
        """Align the first two clients' concepts and build the global dictionary.

        All clients' signatures are normalized together. Matching uses exactly
        the first two clients' rows; any further clients only influence the
        normalization statistics. Side effects: prints a report and writes
        ``class_concept_map.json`` (class -> [[global_concept_id, score], ...]).

        Args:
            client_concepts_list: one list of concept payloads per client.

        Returns:
            Dict ``"GC_<i>" -> global concept`` (also carries ``concept_ids``).

        Raises:
            ValueError: if fewer than two clients are given.
        """
        if len(client_concepts_list) < 2:
            raise ValueError("Need at least two clients to align concepts")
        if len(client_concepts_list) > 2:
            print(f"⚠️ [Server] {len(client_concepts_list)} clients received; only the first two are aligned.")

        all_metrics, all_metadata = [], []
        for concepts in client_concepts_list:
            metrics, meta = self.extract_metrics(concepts)
            all_metrics.append(metrics)
            all_metadata.append(meta)

        combined_metrics = np.vstack(all_metrics)
        n0, n1 = len(all_metrics[0]), len(all_metrics[1])
        if combined_metrics.shape[0] == 0:
            # Nothing to normalize (the scaler rejects zero samples).
            global_concepts = []
        else:
            normalized_metrics, _ = self.normalize_metrics(combined_metrics)
            global_concepts = self.compute_similarities(
                normalized_metrics[:n0],
                normalized_metrics[n0:n0 + n1],  # exactly client 1's rows
                all_metadata[0],
                all_metadata[1],
                threshold=self.similarity_threshold,
            )

        self._print_top_global_concepts(global_concepts)

        global_dict = {}
        for i, gc in enumerate(global_concepts):
            global_dict[f"GC_{i}"] = {
                "global_concept_id": gc["global_concept_id"],
                "signature": gc["signature"],
                "tcav_scores": gc["tcav_scores"],
                "dominant_class": gc["dominant_class"],
                "contributing_concepts": gc["contributing_concepts"],
                "concept_ids": [c['concept_id'] for c in gc['contributing_concepts']],
                "distance": gc["distance"],
            }

        self.evaluate_global_dictionary(global_dict)

        class_to_concepts = defaultdict(list)
        for gc in global_dict.values():
            cls = gc.get('dominant_class')
            score = self._score_for_class(gc['tcav_scores'], cls)
            class_to_concepts[cls].append((gc['global_concept_id'], score))
        ranked = {
            cls: sorted(concept_scores, key=lambda x: x[1], reverse=True)
            for cls, concept_scores in class_to_concepts.items()
        }

        print("\n✅ Available Classes:", sorted(class_to_concepts.keys()))
        self._print_class_report(ranked, global_concepts)

        with open("class_concept_map.json", "w", encoding="utf-8") as f:
            json.dump(ranked, f, indent=2)

        return global_dict


**FedCAPEStrategy**

In [5]:
class FedCAPEStrategy(FedAvg):
    """Flower strategy that builds a global concept dictionary each round.

    Model weights are not averaged: ``aggregate_fit`` forwards the parameters
    of the first fit result. The per-round work is aligning the JSON-encoded
    concepts that clients report under the ``"concepts"`` fit metric via
    ``FederatedConceptServer.receive_concepts``. The resulting dictionary is
    sent to the sampled clients in the next fit config, and each client's own
    concepts are sent back to it in the evaluate config.
    """

    def __init__(self, initial_parameters, **kwargs):
        super().__init__(**kwargs)
        self.initial_parameters = initial_parameters
        self.server = FederatedConceptServer(similarity_threshold=1)  # meaning the Metrics Vectors are within ~1 standard deviation 
        # Global concept dictionary from the latest successful alignment.
        self.global_dict = {}
        # client id -> concepts that client reported in the latest fit round.
        self.client_concept_map = {}

    def initialize_parameters(self, client_manager):
        """Return the initial parameters passed to the constructor."""
        return self.initial_parameters

    def aggregate_fit(self, server_round, results, failures):
        """Align the clients' concepts and forward the first result's parameters.

        Returns ``(None, {})`` when there are no fit results; per the Flower
        Strategy contract, ``None`` means no new global parameters. If fewer
        than two clients delivered parseable concepts, the previous global
        dictionary is kept instead of aborting the round.
        """
        if failures:
            print(f"⚠️ Failures in round {server_round}: {failures}")
        if not results:
            print("❌ No valid concepts received.")
            return None, {}

        client_concepts = []
        self.client_concept_map = {}
        for client, res in results:
            cid = client.cid
            raw_json = res.metrics.get("concepts", "[]")
            try:
                parsed = json.loads(raw_json)
            except (TypeError, ValueError) as err:
                # One malformed payload should not crash the whole round.
                print(f"⚠️ [Server] Ignoring unparseable concepts from client {cid}: {err}")
                continue
            client_concepts.append(parsed)
            self.client_concept_map[cid] = parsed

        if len(client_concepts) < 2:
            # receive_concepts raises for < 2 clients; keep the last dictionary.
            print(f"❌ Need concepts from at least two clients, got {len(client_concepts)}; keeping previous global dictionary.")
        else:
            self.global_dict = self.server.receive_concepts(client_concepts)
            print(f"[Server] ✅ Global dictionary built with {len(self.global_dict)} entries.")

        _, first_res = results[0]
        return first_res.parameters, {}

    def configure_fit(self, server_round, parameters, client_manager):
        """Send the current global dictionary (JSON) to the sampled fit clients.

        Client sampling mirrors ``FedAvg.configure_fit`` so ``fraction_fit``,
        ``min_fit_clients`` and ``min_available_clients`` are honored.
        """
        config = {"global_dict": json.dumps(self.global_dict)}
        fit_ins = FitIns(parameters, config)
        sample_size, min_num_clients = self.num_fit_clients(client_manager.num_available())
        clients = client_manager.sample(num_clients=sample_size, min_num_clients=min_num_clients)
        return [(client, fit_ins) for client in clients]

    def configure_evaluate(self, server_round, parameters, client_manager):
        """Send each sampled client the concepts it reported in the last fit round.

        Client sampling mirrors ``FedAvg.configure_evaluate`` (honors
        ``fraction_evaluate``, ``min_evaluate_clients``, ``min_available_clients``).
        """
        if self.fraction_evaluate == 0.0:
            return []
        sample_size, min_num_clients = self.num_evaluation_clients(client_manager.num_available())
        clients = client_manager.sample(num_clients=sample_size, min_num_clients=min_num_clients)
        evaluate_ins = []
        for client in clients:
            concept_json = json.dumps(self.client_concept_map.get(client.cid, []))
            config = {"concepts": concept_json}
            evaluate_ins.append((client, EvaluateIns(parameters, config)))
        return evaluate_ins

    def aggregate_evaluate(self, server_round, results, failures):
        """Return ``(1 - acc, {"tcav_accuracy": acc})`` with ``acc`` weighted by examples.

        ``acc`` is each client's ``tcav_accuracy`` metric weighted by its
        ``num_examples``; with no results the loss is 1.0 and the accuracy 0.0.
        """
        if not results:
            return 1.0, {"tcav_accuracy": 0.0}
        total = sum(res.num_examples for _, res in results)
        avg_acc = sum(res.num_examples * float(res.metrics.get("tcav_accuracy", 0.0)) for _, res in results)
        avg_accuracy = avg_acc / max(total, 1)
        print(f"📊 [Server] TCAV Accuracy: {avg_accuracy:.4f}")
        return 1.0 - avg_accuracy, {"tcav_accuracy": round(avg_accuracy, 4)}


**FedCAPESimulation**

In [None]:
# Number of simulated clients, and each client's list of dataset indices from
# create_dirichlet_clients (Dirichlet label split, alpha=1). client_fn below
# looks up its partition in client_indices by partition id.
# NOTE(review): create_dirichlet_clients iterates `dataset` item by item, which
# loads and transforms every image just to read its label; ImageFolder.targets
# would give the labels without decoding images.
NUM_CLIENTS = 2
client_indices = create_dirichlet_clients(imagenet_subset, num_clients=NUM_CLIENTS, alpha=1)
print("ImageNet-style dataset loaded and Assigned to client")
def client_fn(context: Context):
    """Create the Flower client for the simulated partition named in ``context``.

    The ``partition-id`` node-config value is used both as the client id and
    as the key into ``client_indices`` for that client's data subset.
    """
    partition_id = int(context.node_config['partition-id'])
    numpy_client = FedCAPENumpyClient(
        partition_id,
        imagenet_subset,
        client_indices[partition_id],
        model_path="",
    )
    return numpy_client.to_client()

# Load the 'dinov2_vits14' DINOv2 model from torch hub and use all of its
# parameters (as NumPy arrays) as the strategy's initial parameters.
# FedCAPEStrategy does not average weights: its aggregate_fit forwards the
# first fit result's parameters.
model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
initial_parameters = ndarrays_to_parameters([p.detach().cpu().numpy() for p in model.parameters()])
# Minimum client counts are passed through to FedAvg's constructor.
strategy = FedCAPEStrategy(
    initial_parameters=initial_parameters,
    min_fit_clients=NUM_CLIENTS,
    min_evaluate_clients=NUM_CLIENTS,
    min_available_clients=NUM_CLIENTS
)
# Run the Flower simulation for 3 rounds. client_resources is the per-client
# reservation (2 CPUs + 1 GPU); NOTE(review): with a single GPU the clients
# presumably run one at a time -- confirm on the cluster. The runtime_env pip
# list is installed into the Ray workers.
fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=NUM_CLIENTS,
    config=fl.server.ServerConfig(num_rounds=3),
    strategy=strategy,
    client_resources={"num_cpus": 2, "num_gpus": 1},
    ray_init_args={"runtime_env": {"pip": ["torch", "flwr", "xformers", "umap-learn", "kneed", "scikit-learn"]}}
)