<a href="https://colab.research.google.com/github/Drax929/FED-CL/blob/main/FED_CL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torch torchvision flwr numpy tqdm



In [None]:
import random
import math
import copy
import numpy as np
from typing import Tuple, List, Dict, Any

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset

from torchvision import datasets, transforms
from tqdm import tqdm

import flwr as fl

  return datetime.utcnow().replace(tzinfo=utc)


In [None]:
class TwoCropTransform:
    """Return two strongly/weakly augmented views of an image for contrastive learning."""
    def __init__(self, base_transform):
        self.base_transform = base_transform

    def __call__(self, x):
        return self.base_transform(x), self.base_transform(x)

In [None]:
class ContrastiveMNIST(Dataset):
    """Wrapper to provide two augmented views for MNIST."""
    def __init__(self, mnist_dataset, transform):
        self.mnist_dataset = mnist_dataset
        self.transform = transform  # TwoCropTransform

    def __len__(self):
        return len(self.mnist_dataset)

    def __getitem__(self, idx):
        img, target = self.mnist_dataset[idx]
        x1, x2 = self.transform(img)
        return x1, x2, target

In [None]:
def partition_dirichlet(dataset: datasets.MNIST, num_clients: int, alpha: float=0.5, seed: int=0):
    """
    Partition indices of dataset to num_clients with Dirichlet distribution per label.
    Returns list of index lists for each client.
    """
    np.random.seed(seed)
    labels = np.array(dataset.targets)
    num_classes = len(np.unique(labels))
    idx_by_class = [np.where(labels == c)[0] for c in range(num_classes)]

    client_indices = [[] for _ in range(num_clients)]
    for c in range(num_classes):
        idx_c = idx_by_class[c]
        # draw proportions
        proportions = np.random.dirichlet(alpha=[alpha]*num_clients)
        # split idx_c according to proportions
        proportions = (proportions / proportions.sum()) * len(idx_c)
        proportions = np.round(proportions).astype(int)

        # adjust rounding issues
        while proportions.sum() > len(idx_c):
            j = np.argmax(proportions)
            proportions[j] -= 1
        while proportions.sum() < len(idx_c):
            j = np.argmin(proportions)
            proportions[j] += 1

        start = 0
        for k in range(num_clients):
            cnt = proportions[k]
            if cnt > 0:
                client_indices[k].extend(idx_c[start:start+cnt].tolist())
                start += cnt

    # ensure non-empty
    for k in range(num_clients):
        if len(client_indices[k]) == 0:
            # assign a random sample
            client_indices[k].append(np.random.choice(len(dataset)))
    return client_indices


In [None]:
class SmallConvEncoder(nn.Module):
    def __init__(self, out_dim=128):
        super().__init__()
        # MNIST 1x28x28
        self.net = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1, 1),  # 28x28
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),            # 14x14
            nn.Conv2d(32, 64, 3, 1, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),            # 7x7
            nn.Flatten(),
            nn.Linear(64*7*7, 256),
            nn.ReLU(),
            nn.Linear(256, out_dim)
        )

    def forward(self, x):
        return self.net(x)

In [None]:
class ProjectionHead(nn.Module):
    def __init__(self, in_dim=128, hidden_dim=128, out_dim=64):
        super().__init__()
        self.head = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim)
        )

    def forward(self, x):
        return self.head(x)


In [None]:
class NTXentLoss(nn.Module):
    def __init__(self, temperature=0.5, device='cpu'):
        super().__init__()
        self.temperature = temperature
        self.criterion = nn.CrossEntropyLoss()
        self.device = device

    def forward(self, z1, z2):
        """
        z1, z2: tensors of shape (B, D) - projection outputs (not necessarily normalized)
        """
        batch_size = z1.size(0)
        z1 = F.normalize(z1, dim=1)
        z2 = F.normalize(z2, dim=1)
        representations = torch.cat([z1, z2], dim=0)  # 2B x D

        # similarity matrix
        sim = torch.matmul(representations, representations.T) / self.temperature  # 2B x 2B

        # mask to remove similarity with itself
        large_neg = -1e9
        mask = (~torch.eye(2*batch_size, dtype=torch.bool)).to(self.device)
        sim_masked = sim.masked_fill(~mask, large_neg)

        # positives: i <-> i+B
        labels = torch.arange(batch_size).to(self.device)
        positives = torch.cat([torch.diag(sim, batch_size), torch.diag(sim, -batch_size)])  # len 2B?
        # Simpler: compute logits and targets per original SimCLR implementation
        logits = torch.cat([
            torch.cat([sim[i, batch_size:batch_size+batch_size], sim[i, :batch_size]], dim=0).unsqueeze(0)
            for i in range(batch_size)
        ], dim=0)  # (B, 2B) -> but careful. For simplicity use a robust implementation below.

        # We'll implement the standard approach: for each of the 2B examples, positives index is (i+batch_size) mod (2B)
        logits_all = sim_masked  # (2B x 2B) with -inf on diag
        labels_all = (torch.arange(2*batch_size) + batch_size) % (2*batch_size)
        labels_all = labels_all.to(self.device)

        loss = self.criterion(logits_all, labels_all)
        return loss

In [None]:
def nt_xent_loss(z1, z2, temperature=0.5):
    """Simple, robust NT-Xent for two augmented batches z1,z2 (B,D)."""
    device = z1.device
    B = z1.size(0)
    z1 = F.normalize(z1, dim=1)
    z2 = F.normalize(z2, dim=1)
    z = torch.cat([z1, z2], dim=0)  # 2B x D
    sim = torch.matmul(z, z.T) / temperature  # 2B x 2B
    # mask out self-similarity
    mask = (~torch.eye(2*B, dtype=torch.bool)).to(device)
    sim_masked = sim.masked_select(mask).view(2*B, 2*B-1)

    # positives: for i in [0..B-1], positive index is i+B and vice versa
    positives = torch.cat([torch.diag(sim, B), torch.diag(sim, -B)], dim=0).unsqueeze(1)  # 2B x 1

    # logits: concatenate positives and negatives
    logits = torch.cat([positives, sim_masked], dim=1)
    labels = torch.zeros(2*B, dtype=torch.long).to(device)
    loss = F.cross_entropy(logits, labels)
    return loss


In [None]:
class PrivCLClient(fl.client.NumPyClient):
    def __init__(self, model: nn.Module, proj: nn.Module, train_loader: DataLoader,
                 device: torch.device, local_epochs: int = 1, lr=1e-3, tau=0.5):
        self.device = device
        self.model = model.to(self.device)
        self.proj = proj.to(self.device)
        self.train_loader = train_loader
        self.local_epochs = local_epochs
        self.lr = lr
        self.tau = tau

    def get_parameters(self, config):
        # Return model + proj parameters as numpy arrays
        params = []
        for _, p in list(self.model.state_dict().items()) + list(self.proj.state_dict().items()):
            params.append(p.cpu().numpy())
        return params

    def set_parameters(self, parameters):
        # Parameters is a list of numpy arrays in same order
        # Reconstruct state dicts
        model_state = self.model.state_dict()
        proj_state = self.proj.state_dict()
        # Flatten keys
        all_keys = list(model_state.keys()) + list(proj_state.keys())
        assert len(parameters) == len(all_keys)
        new_state = {}
        i = 0
        for k in model_state.keys():
            new_state[k] = torch.tensor(parameters[i])
            i += 1
        self.model.load_state_dict(new_state, strict=False)
        proj_state_new = {}
        for k in proj_state.keys():
            proj_state_new[k] = torch.tensor(parameters[i])
            i += 1
        self.proj.load_state_dict(proj_state_new, strict=False)

    def fit(self, parameters, config):
        # set params
        if parameters is not None:
            self.set_parameters(parameters)

        optimizer = optim.Adam(list(self.model.parameters()) + list(self.proj.parameters()), lr=self.lr)
        device = self.device
        self.model.train()
        self.proj.train()
        for epoch in range(self.local_epochs):
            loop = tqdm(self.train_loader, desc=f"Client local epoch {epoch+1}/{self.local_epochs}", leave=False)
            for x1, x2, _ in loop:
                x1 = x1.to(device)
                x2 = x2.to(device)
                optimizer.zero_grad()
                h1 = self.model(x1)
                h2 = self.model(x2)
                z1 = self.proj(h1)
                z2 = self.proj(h2)
                loss = nt_xent_loss(z1, z2, temperature=self.tau)
                loss.backward()
                optimizer.step()
                loop.set_postfix(loss=loss.item())

        # return updated parameters
        return self.get_parameters({}), len(self.train_loader.dataset), {}

    def evaluate(self, parameters, config):
        # no server-side evaluation in baseline; return dummy loss and metrics
        return float(0.0), len(self.train_loader.dataset), {"accuracy": 0.0}

In [None]:
def create_data_loaders(num_clients=5, alpha=0.5, batch_size=128):
    # Download MNIST
    base_transform = transforms.Compose([
        transforms.RandomResizedCrop(28, scale=(0.8, 1.0)),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    two_crop = TwoCropTransform(base_transform)
    mnist_train = datasets.MNIST(".", train=True, download=True)
    mnist_test = datasets.MNIST(".", train=False, download=True)

    # Partition indices non-iid with dirichlet
    client_idxs = partition_dirichlet(mnist_train, num_clients=num_clients, alpha=alpha, seed=42)

    client_loaders = []
    for k in range(num_clients):
        subset = Subset(mnist_train, client_idxs[k])
        wrapped = ContrastiveMNIST(subset, two_crop)
        loader = DataLoader(wrapped, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=0)
        client_loaders.append(loader)

    # a simple validation set (not used for FL training)
    test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    test_wrapped = ContrastiveMNIST(mnist_test, TwoCropTransform(test_transform))
    test_loader = DataLoader(test_wrapped, batch_size=batch_size, shuffle=False)
    return client_loaders, test_loader

In [None]:
def client_fn(cid: str, client_loaders, device, local_epochs):
    """Create a Flower client instance for the given client id (cid)."""
    idx = int(cid)
    model = SmallConvEncoder(out_dim=128)
    proj = ProjectionHead(in_dim=128, hidden_dim=128, out_dim=64)
    client = PrivCLClient(model=model, proj=proj, train_loader=client_loaders[idx],
                          device=device, local_epochs=local_epochs, lr=1e-3, tau=0.5)
    return client

In [None]:
def main_simulation(num_clients=5, rounds=10, local_epochs=1, batch_size=128, alpha=0.5):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device:", device)
    client_loaders, test_loader = create_data_loaders(num_clients=num_clients, alpha=alpha, batch_size=batch_size)

    # Create a factory for clients
    def _client_fn(cid: str):
        return client_fn(cid, client_loaders, device=device, local_epochs=local_epochs)

    # Start simulation (in-proc)
    strategy = fl.server.strategy.FedAvg(
        fraction_fit=1.0,  # use all clients each round (for simplicity)
        min_fit_clients=num_clients,
        min_available_clients=num_clients
    )

    print(f"Starting simulation: {num_clients} clients, {rounds} rounds, local_epochs={local_epochs}")
    fl.simulation.start_simulation(
        client_fn=_client_fn,
        num_clients=num_clients,
        config=fl.server.ServerConfig(num_rounds=rounds),
        strategy=strategy,
    )


In [None]:
!pip install -U "flwr[simulation]"



In [None]:
if __name__ == "__main__":
    # Example run
    main_simulation(num_clients=5, rounds=5, local_epochs=1, batch_size=128, alpha=0.5)

	Instead, use the `flwr run` CLI command to start a local simulation in your Flower app, as shown for example below:

		$ flwr new  # Create a new Flower app from a template

		$ flwr run  # Run the Flower app in Simulation Mode

	Using `start_simulation()` is deprecated.

            This is a deprecated feature. It will be removed
            entirely in future versions of Flower.
        
	Instead, use the `flwr run` CLI command to start a local simulation in your Flower app, as shown for example below:

		$ flwr new  # Create a new Flower app from a template

		$ flwr run  # Run the Flower app in Simulation Mode

	Using `start_simulation()` is deprecated.

            This is a deprecated feature. It will be removed
            entirely in future versions of Flower.
        
[92mINFO [0m:      Starting Flower simulation, config: num_rounds=5, no round_timeout


Device: cpu
Starting simulation: 5 clients, 5 rounds, local_epochs=1


  return datetime.utcnow().replace(tzinfo=utc)
2025-09-28 16:43:10,048	INFO worker.py:1771 -- Started a local Ray instance.
[92mINFO [0m:      Flower VCE: Ray initialized with resources: {'CPU': 2.0, 'memory': 7947581031.0, 'node:172.28.0.12': 1.0, 'object_store_memory': 3973790515.0, 'node:__internal_head__': 1.0}
[92mINFO [0m:      Optimize your simulation with Flower VCE: https://flower.ai/docs/framework/how-to-run-simulations.html
[92mINFO [0m:      No `client_resources` specified. Using minimal resources for clients.
[92mINFO [0m:      Flower VCE: Resources for each Virtual Client: {'num_cpus': 1, 'num_gpus': 0.0}
[92mINFO [0m:      Flower VCE: Creating VirtualClientEngineActorPool with 2 actors
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Requesting initial parameters from one random client
[36m(pid=1428)[0m 2025-09-28 16:43:29.500647: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory fo

In [None]:
!pip install torch flwr numpy tqdm networkx



In [None]:
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu

Looking in indexes: https://download.pytorch.org/whl/cpu


In [None]:
!pip install torch==2.8.0 torchvision==0.19.0 torchaudio==2.8.0 --index-url https://download.pytorch.org/whl/cpu
!pip install torch-geometric

Looking in indexes: https://download.pytorch.org/whl/cpu
Collecting torchvision==0.19.0
  Downloading https://download.pytorch.org/whl/cpu/torchvision-0.19.0%2Bcpu-cp312-cp312-linux_x86_64.whl (1.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m33.4 MB/s[0m eta [36m0:00:00[0m
INFO: pip is looking at multiple versions of torchvision to determine which version is compatible with other requirements. This could take a while.
[31mERROR: Cannot install torch==2.8.0 and torchvision==0.19.0+cpu because these package versions have conflicting dependencies.[0m[31m
[0m
The conflict is caused by:
    The user requested torch==2.8.0
    torchvision 0.19.0+cpu depends on torch==2.4.0

To fix this you could try to:
1. loosen the range of package versions you've specified
2. remove package versions to allow pip to attempt to solve the dependency conflict

[31mERROR: ResolutionImpossible: for help visit https://pip.pypa.io/en/latest/topics/dependency-res

In [None]:
!pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-2.5.1+cpu.html


Looking in links: https://data.pyg.org/whl/torch-2.5.1+cpu.html
Collecting torch-scatter
  Downloading https://data.pyg.org/whl/torch-2.5.0%2Bcpu/torch_scatter-2.1.2%2Bpt25cpu-cp312-cp312-linux_x86_64.whl (547 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m547.7/547.7 kB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torch-sparse
  Downloading https://data.pyg.org/whl/torch-2.5.0%2Bcpu/torch_sparse-0.6.18%2Bpt25cpu-cp312-cp312-linux_x86_64.whl (1.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/1.2 MB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torch-cluster
  Downloading https://data.pyg.org/whl/torch-2.5.0%2Bcpu/torch_cluster-1.6.3%2Bpt25cpu-cp312-cp312-linux_x86_64.whl (792 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m792.1/792.1 kB[0m [31m39.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torch-spline-conv
  Downloading https://data.pyg.org/whl/torch-2.5.0%2Bcpu/torch_spline_conv-1.2.2

In [None]:
import math
import random
from typing import List, Tuple

import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import flwr as fl

import networkx as nx
from torch_geometric.data import Data as PyGData
from torch_geometric.utils import from_networkx
from torch_geometric.nn import GCNConv, global_mean_pool

  return datetime.utcnow().replace(tzinfo=utc)


In [None]:
def synth_graph(num_nodes=20, p=0.2, feat_dim=16, seed=None):
    if seed is not None:
        np.random.seed(seed)
    G = nx.erdos_renyi_graph(num_nodes, p)
    # add node features: random vectors
    for n in G.nodes():
        G.nodes[n]['x'] = np.random.randn(feat_dim).astype(np.float32)
    pyg = from_networkx(G)
    # from_networkx sets 'x' as attribute list; convert to tensor if needed
    if hasattr(pyg, 'x') and pyg.x is None:
        pyg.x = torch.randn((num_nodes, feat_dim), dtype=torch.float32)
    return pyg

In [None]:
def synth_timeseries(length=200, freq=1.0, noise=0.1, phase=0.0, seed=None):
    """Generate synthetic 1D timeseries (sine + noise)."""
    if seed is not None:
        np.random.seed(seed)
    t = np.arange(length)
    series = np.sin(2 * math.pi * freq * (t / length) + phase) + noise * np.random.randn(length)
    series = series.astype(np.float32)
    return series

In [None]:
def create_client_synthetic_data(num_clients=5, graph_nodes=20, feat_dim=16, ts_len=200):
    clients = []
    for k in range(num_clients):
        seed = 100 + k
        p = 0.15 + 0.05 * (k % 3)  # vary density
        graph = synth_graph(num_nodes=graph_nodes, p=p, feat_dim=feat_dim, seed=seed)
        # timeseries parameters vary across clients
        freq = 1.0 + 0.1 * (k % 4)
        phase = 2.0 * math.pi * (k / max(1, num_clients))
        noise = 0.05 + 0.05 * (k % 3)
        ts = synth_timeseries(length=ts_len, freq=freq, noise=noise, phase=phase, seed=seed+1)
        clients.append({'graph': graph, 'ts': ts})
    return clients

In [None]:
def graph_edge_dropout(data: PyGData, p_drop=0.2):
    """
    Randomly drop edges with probability p_drop and return a new PyGData.
    """
    edge_index = data.edge_index.clone()
    E = edge_index.size(1)
    mask = torch.rand(E) > p_drop
    new_ei = edge_index[:, mask]
    new_data = PyGData(x=data.x.clone(), edge_index=new_ei)
    return new_data

In [None]:
def graph_feature_masking(data: PyGData, p_mask=0.1):
    x = data.x.clone()
    mask = (torch.rand(x.size()) > p_mask).float()
    x = x * mask
    return PyGData(x=x, edge_index=data.edge_index.clone())

In [None]:
def ts_jitter(ts: np.ndarray, sigma=0.05):
    return (ts + np.random.randn(*ts.shape) * sigma).astype(np.float32)

In [None]:
def ts_scaling(ts: np.ndarray, sigma=0.1):
    factor = np.random.normal(loc=1.0, scale=sigma)
    return (ts * factor).astype(np.float32)

In [None]:
def ts_window(ts: np.ndarray, window_size=64, shift=None):
    L = len(ts)
    if shift is None:
        start = np.random.randint(0, max(1, L - window_size + 1))
    else:
        start = shift
    return ts[start:start+window_size].copy()

In [None]:
class GraphEncoder(nn.Module):
    def __init__(self, in_dim=16, hidden=64, out_dim=128):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hidden)
        self.conv2 = GCNConv(hidden, hidden)
        self.lin = nn.Linear(hidden, out_dim)

    def forward(self, data: PyGData):
        x, edge_index, batch = data.x, data.edge_index, getattr(data, 'batch', None)
        h = F.relu(self.conv1(x, edge_index))
        h = F.relu(self.conv2(h, edge_index))
        # graph-level embedding via mean pool (if batch not set, assume single graph)
        if batch is None:
            # single graph: compute mean over nodes
            g = h.mean(dim=0, keepdim=True)  # 1 x hidden
        else:
            g = global_mean_pool(h, batch)  # B x hidden
        z = self.lin(g)  # B x out_dim (or 1 x out_dim)
        return z

In [None]:
class TemporalEncoder(nn.Module):
    def __init__(self, in_channels=1, hidden=64, out_dim=128, kernel_size=5):
        super().__init__()
        # simple 1D conv stack
        self.conv1 = nn.Conv1d(in_channels, hidden, kernel_size=kernel_size, padding=kernel_size//2)
        self.conv2 = nn.Conv1d(hidden, hidden, kernel_size=kernel_size, padding=kernel_size//2)
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.lin = nn.Linear(hidden, out_dim)

    def forward(self, ts_batch: torch.Tensor):
        # ts_batch: (B, window_len) -> (B, 1, window_len)
        x = ts_batch.unsqueeze(1)
        h = F.relu(self.conv1(x))
        h = F.relu(self.conv2(h))
        h = self.pool(h).squeeze(-1)  # B x hidden
        z = self.lin(h)  # B x out_dim
        return z

In [None]:
class FusionProjection(nn.Module):
    def __init__(self, in_dim, proj_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, in_dim//2),
            nn.ReLU(),
            nn.Linear(in_dim//2, proj_dim)
        )
    def forward(self, x):
        return self.net(x)

In [None]:
def nt_xent_loss(z1, z2, temperature=0.5):
    device = z1.device
    B = z1.size(0)
    z1 = F.normalize(z1, dim=1)
    z2 = F.normalize(z2, dim=1)
    z = torch.cat([z1, z2], dim=0)  # 2B x D
    sim = torch.matmul(z, z.T) / temperature  # 2B x 2B
    mask = (~torch.eye(2*B, dtype=torch.bool)).to(device)
    sim_masked = sim.masked_select(mask).view(2*B, 2*B-1)
    positives = torch.cat([torch.diag(sim, B), torch.diag(sim, -B)], dim=0).unsqueeze(1)
    logits = torch.cat([positives, sim_masked], dim=1)
    labels = torch.zeros(2*B, dtype=torch.long).to(device)
    loss = F.cross_entropy(logits, labels)
    return loss

In [None]:
def graph_info_loss(g1, g2, neg_g, margin=0.0):
    # g1, g2: positive pooled graph embeddings (B x D), neg_g: negative graphs (B x D)
    # maximize similarity between g1 and g2, minimize with neg_g
    pos_sim = (F.normalize(g1, dim=1) * F.normalize(g2, dim=1)).sum(dim=1)
    neg_sim = (F.normalize(g1, dim=1) * F.normalize(neg_g, dim=1)).sum(dim=1)
    loss = - (pos_sim - neg_sim).mean()  # push pos up and neg down
    return loss

In [None]:
def temporal_contrastive(z_anchor, z_pos, temperature=0.5):
    return nt_xent_loss(z_anchor, z_pos, temperature=temperature)

In [None]:
class MultimodalPrivCLClient(fl.client.NumPyClient):
    def __init__(self, graph_enc: nn.Module, temp_enc: nn.Module, fusion: nn.Module,
                 client_data: dict, device: torch.device,
                 local_steps:int=50, batch_size=16, window_size=64,
                 lr=1e-3, weights=(1.0, 0.5, 0.5), tau=0.5):
        """
        client_data: {'graph': PyGData, 'ts': numpy array}
        weights: (w_fuse, w_graph, w_temp)
        """
        self.device = device
        self.graph_enc = graph_enc.to(device)
        self.temp_enc = temp_enc.to(device)
        self.fusion = fusion.to(device)
        self.client_data = client_data
        self.local_steps = local_steps
        self.batch_size = batch_size
        self.window_size = window_size
        self.lr = lr
        self.w_fuse, self.w_graph, self.w_temp = weights
        self.tau = tau

    def get_parameters(self, config):
        params = []
        for _, p in list(self.graph_enc.state_dict().items()) + list(self.temp_enc.state_dict().items()) + list(self.fusion.state_dict().items()):
            params.append(p.cpu().numpy())
        return params

    def set_parameters(self, parameters):
        # assign parameters in the same order
        ge_state = self.graph_enc.state_dict()
        te_state = self.temp_enc.state_dict()
        fu_state = self.fusion.state_dict()
        all_keys = list(ge_state.keys()) + list(te_state.keys()) + list(fu_state.keys())
        assert len(parameters) == len(all_keys)
        i = 0
        new_ge = {}
        for k in ge_state.keys():
            new_ge[k] = torch.tensor(parameters[i])
            i += 1
        self.graph_enc.load_state_dict(new_ge, strict=False)
        new_te = {}
        for k in te_state.keys():
            new_te[k] = torch.tensor(parameters[i])
            i += 1
        self.temp_enc.load_state_dict(new_te, strict=False)
        new_fu = {}
        for k in fu_state.keys():
            new_fu[k] = torch.tensor(parameters[i])
            i += 1
        self.fusion.load_state_dict(new_fu, strict=False)

    def fit(self, parameters, config):
        if parameters is not None:
            self.set_parameters(parameters)

        opt = optim.Adam(list(self.graph_enc.parameters()) +
                         list(self.temp_enc.parameters()) +
                         list(self.fusion.parameters()), lr=self.lr)

        self.graph_enc.train(); self.temp_enc.train(); self.fusion.train()
        device = self.device

        # For simplicity: treat each local step as a self-contained minibatch (since each client has 1 graph and 1 ts)
        # We'll create B augmented pairs by repeating augmentations
        for step in range(self.local_steps):
            # Graph positive pair: two augmentations
            g_orig = self.client_data['graph']
            g_aug1 = graph_edge_dropout(g_orig, p_drop=0.2)
            g_aug2 = graph_feature_masking(g_orig, p_mask=0.1)

            # negative graph: corrupt by shuffling features
            neg_x = g_orig.x[torch.randperm(g_orig.x.size(0))]
            g_neg = PyGData(x=neg_x, edge_index=g_orig.edge_index.clone())

            g_aug1 = g_aug1.to(device)
            g_aug2 = g_aug2.to(device)
            g_neg = g_neg.to(device)

            # Temporal positive pair: two augmentations on windows
            ts = self.client_data['ts']
            w1 = ts_window(ts, self.window_size)
            w2 = ts_jitter(ts_window(ts, self.window_size), sigma=0.05)
            # cast to tensors
            w1 = torch.tensor(w1, dtype=torch.float32).to(device)
            w2 = torch.tensor(w2, dtype=torch.float32).to(device)

            # Create a "batch" by duplicating (simple trick to get B>1)
            B = self.batch_size
            g_batch1 = PyGData(x=g_aug1.x.repeat(B,1) if g_aug1.x.dim()==2 else g_aug1.x, edge_index=g_aug1.edge_index)
            g_batch2 = PyGData(x=g_aug2.x.repeat(B,1) if g_aug2.x.dim()==2 else g_aug2.x, edge_index=g_aug2.edge_index)
            g_neg_batch = PyGData(x=g_neg.x.repeat(B,1) if g_neg.x.dim()==2 else g_neg.x, edge_index=g_neg.edge_index)
            # NOTE: above repetition is a hacky way to emulate a batch of B identical graphs with same structure but it's okay for prototyping.

            # graph embeddings (outputs shape: 1 x out_dim) -> expand to B x out_dim
            ge1 = self.graph_enc(g_batch1).squeeze(0)  # out_dim
            ge2 = self.graph_enc(g_batch2).squeeze(0)
            gen = self.graph_enc(g_neg_batch).squeeze(0)
            ge1_b = ge1.unsqueeze(0).repeat(B,1)
            ge2_b = ge2.unsqueeze(0).repeat(B,1)
            gen_b = gen.unsqueeze(0).repeat(B,1)

            # temporal embeddings
            # w1, w2 are (window_len,) -> create batch (B, window_len) by stacking jittered versions
            w1_batch = torch.stack([torch.tensor(ts_jitter(w1.cpu().numpy(), sigma=0.02), dtype=torch.float32) for _ in range(B)]).to(device)
            w2_batch = torch.stack([torch.tensor(ts_jitter(w2.cpu().numpy(), sigma=0.02), dtype=torch.float32) for _ in range(B)]).to(device)
            te1 = self.temp_enc(w1_batch)  # B x out_dim
            te2 = self.temp_enc(w2_batch)

            # fuse (concatenate graph + temporal)
            fused1 = torch.cat([ge1_b, te1], dim=1)  # B x (gdim + tdim)
            fused2 = torch.cat([ge2_b, te2], dim=1)

            # project
            z1 = self.fusion(fused1)
            z2 = self.fusion(fused2)

            # losses
            loss_fuse = nt_xent_loss(z1, z2, temperature=self.tau)
            loss_graph = graph_info_loss(ge1_b, ge2_b, gen_b)
            loss_temp = temporal_contrastive(te1, te2, temperature=self.tau)

            loss = self.w_fuse * loss_fuse + self.w_graph * loss_graph + self.w_temp * loss_temp

            opt.zero_grad()
            loss.backward()
            opt.step()

            if (step+1) % (max(1, self.local_steps//5)) == 0:
                print(f"[Client] step {step+1}/{self.local_steps} loss={loss.item():.4f} (fuse={loss_fuse.item():.4f}, g={loss_graph.item():.4f}, t={loss_temp.item():.4f})")

        return self.get_parameters({}), 1, {}

    def evaluate(self, parameters, config):
        # Not implementing server-side eval here; return dummy
        return float(0.0), 1, {"accuracy": 0.0}

In [None]:
def client_factory_fn(cid: str, client_synth_data, device, local_steps, batch_size, window_size):
    idx = int(cid)
    data = client_synth_data[idx]
    # instantiate model parts
    genc = GraphEncoder(in_dim=data['graph'].x.size(1), hidden=64, out_dim=128)
    tempenc = TemporalEncoder(in_channels=1, hidden=64, out_dim=128)
    fusion = FusionProjection(in_dim=128+128, proj_dim=128)
    client = MultimodalPrivCLClient(graph_enc=genc, temp_enc=tempenc, fusion=fusion,
                                    client_data=data, device=device,
                                    local_steps=local_steps, batch_size=batch_size,
                                    window_size=window_size, lr=1e-3, weights=(1.0,0.5,0.5), tau=0.5)
    return client

In [None]:
def start_simulation(num_clients=5, rounds=5, local_steps=40):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device:", device)
    client_data = create_client_synthetic_data(num_clients=num_clients, graph_nodes=20, feat_dim=16, ts_len=200)

    def _client_fn(cid: str):
        return client_factory_fn(cid, client_data, device=device, local_steps=local_steps, batch_size=16, window_size=64)

    strategy = fl.server.strategy.FedAvg(
        fraction_fit=1.0,
        min_fit_clients=num_clients,
        min_available_clients=num_clients
    )
    print(f"Starting Phase 2 simulation: {num_clients} clients, {rounds} rounds")
    fl.simulation.start_simulation(client_fn=_client_fn, num_clients=num_clients,
                                   config=fl.server.ServerConfig(num_rounds=rounds),
                                   strategy=strategy)

In [None]:
if __name__ == "__main__":
    start_simulation(num_clients=4, rounds=3, local_steps=40)

  data_dict[key] = torch.as_tensor(value)
	Instead, use the `flwr run` CLI command to start a local simulation in your Flower app, as shown for example below:

		$ flwr new  # Create a new Flower app from a template

		$ flwr run  # Run the Flower app in Simulation Mode

	Using `start_simulation()` is deprecated.

            This is a deprecated feature. It will be removed
            entirely in future versions of Flower.
        
[92mINFO [0m:      Starting Flower simulation, config: num_rounds=3, no round_timeout


Device: cpu
Starting Phase 2 simulation: 4 clients, 3 rounds


  return datetime.utcnow().replace(tzinfo=utc)
2025-09-28 16:54:51,302	INFO worker.py:1771 -- Started a local Ray instance.
[92mINFO [0m:      Flower VCE: Ray initialized with resources: {'CPU': 2.0, 'memory': 7944292763.0, 'node:172.28.0.12': 1.0, 'object_store_memory': 3972146380.0, 'node:__internal_head__': 1.0}
[92mINFO [0m:      Optimize your simulation with Flower VCE: https://flower.ai/docs/framework/how-to-run-simulations.html
[92mINFO [0m:      No `client_resources` specified. Using minimal resources for clients.
[92mINFO [0m:      Flower VCE: Resources for each Virtual Client: {'num_cpus': 1, 'num_gpus': 0.0}
[92mINFO [0m:      Flower VCE: Creating VirtualClientEngineActorPool with 2 actors
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Requesting initial parameters from one random client
[36m(pid=4757)[0m 2025-09-28 16:55:11.994740: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory fo

[36m(ClientAppActor pid=4758)[0m [Client] step 8/40 loss=5.2021 (fuse=3.4659, g=0.0009, t=3.4715)
[36m(ClientAppActor pid=4758)[0m [Client] step 16/40 loss=5.1992 (fuse=3.4658, g=0.0009, t=3.4658)
[36m(ClientAppActor pid=4758)[0m [Client] step 24/40 loss=5.1989 (fuse=3.4657, g=0.0004, t=3.4660)
[36m(ClientAppActor pid=4758)[0m [Client] step 32/40 loss=5.1990 (fuse=3.4658, g=0.0006, t=3.4660)
[36m(ClientAppActor pid=4758)[0m [Client] step 40/40 loss=5.1987 (fuse=3.4657, g=0.0003, t=3.4658)


[36m(ClientAppActor pid=4758)[0m 
[36m(ClientAppActor pid=4758)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=4758)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=4758)[0m         


[36m(ClientAppActor pid=4758)[0m [Client] step 8/40 loss=5.1998 (fuse=3.4659, g=0.0015, t=3.4664)
[36m(ClientAppActor pid=4758)[0m [Client] step 16/40 loss=5.1990 (fuse=3.4658, g=0.0007, t=3.4658)
[36m(ClientAppActor pid=4758)[0m [Client] step 24/40 loss=5.1993 (fuse=3.4658, g=0.0008, t=3.4663)
[36m(ClientAppActor pid=4758)[0m [Client] step 32/40 loss=5.1990 (fuse=3.4657, g=0.0003, t=3.4662)
[36m(ClientAppActor pid=4758)[0m [Client] step 40/40 loss=5.1988 (fuse=3.4657, g=0.0004, t=3.4657)


[36m(ClientAppActor pid=4758)[0m 
[36m(ClientAppActor pid=4758)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=4758)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=4758)[0m         


[36m(ClientAppActor pid=4758)[0m [Client] step 8/40 loss=5.1997 (fuse=3.4658, g=0.0013, t=3.4664)
[36m(ClientAppActor pid=4758)[0m [Client] step 16/40 loss=5.1992 (fuse=3.4658, g=0.0010, t=3.4658)
[36m(ClientAppActor pid=4758)[0m [Client] step 24/40 loss=5.1996 (fuse=3.4658, g=0.0004, t=3.4673)
[36m(ClientAppActor pid=4758)[0m [Client] step 32/40 loss=5.1997 (fuse=3.4658, g=0.0010, t=3.4668)




[36m(ClientAppActor pid=4758)[0m [Client] step 40/40 loss=5.1989 (fuse=3.4657, g=0.0005, t=3.4658)


[36m(ClientAppActor pid=4757)[0m 
[36m(ClientAppActor pid=4757)[0m         
[36m(ClientAppActor pid=4757)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=4757)[0m             entirely in future versions of Flower.


[36m(ClientAppActor pid=4757)[0m [Client] step 8/40 loss=5.1999 (fuse=3.4659, g=0.0007, t=3.4674)
[36m(ClientAppActor pid=4757)[0m [Client] step 16/40 loss=5.2016 (fuse=3.4659, g=0.0015, t=3.4698)
[36m(ClientAppActor pid=4757)[0m [Client] step 24/40 loss=5.1989 (fuse=3.4658, g=0.0005, t=3.4657)


[92mINFO [0m:      aggregate_fit: received 4 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 4 clients (out of 4)


[36m(ClientAppActor pid=4757)[0m [Client] step 32/40 loss=5.1995 (fuse=3.4658, g=0.0003, t=3.4672)
[36m(ClientAppActor pid=4757)[0m [Client] step 40/40 loss=5.1992 (fuse=3.4658, g=0.0004, t=3.4664)


[36m(ClientAppActor pid=4757)[0m 
[36m(ClientAppActor pid=4757)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=4757)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=4757)[0m         
[36m(ClientAppActor pid=4757)[0m 
[36m(ClientAppActor pid=4757)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=4757)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=4757)[0m         
[36m(ClientAppActor pid=4758)[0m 
[36m(ClientAppActor pid=4758)[0m         
[36m(ClientAppActor pid=4758)[0m 
[36m(ClientAppActor pid=4758)[0m         
[92mINFO [0m:      aggregate_evaluate: received 4 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 2]
[92mINFO [0m:      configure_fit: strategy sampled 4 clients (out of 4)
[36m(ClientAppActor pid=4757)[0m 
[36m(ClientAppActor pid=4757)[0m         
[36m(ClientAppActor pid=4758)[0m

[36m(ClientAppActor pid=4757)[0m [Client] step 8/40 loss=5.1989 (fuse=3.4657, g=0.0004, t=3.4660)


[36m(ClientAppActor pid=4758)[0m 
[36m(ClientAppActor pid=4758)[0m         
[36m(ClientAppActor pid=4757)[0m 
[36m(ClientAppActor pid=4757)[0m         
[92mINFO [0m:      aggregate_fit: received 4 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 4 clients (out of 4)
[36m(ClientAppActor pid=4757)[0m 
[36m(ClientAppActor pid=4757)[0m         
[36m(ClientAppActor pid=4758)[0m 
[36m(ClientAppActor pid=4758)[0m         
[92mINFO [0m:      aggregate_evaluate: received 4 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 3]
[92mINFO [0m:      configure_fit: strategy sampled 4 clients (out of 4)
[36m(ClientAppActor pid=4757)[0m 
[36m(ClientAppActor pid=4757)[0m         
[36m(ClientAppActor pid=4758)[0m 
[36m(ClientAppActor pid=4758)[0m         
[36m(ClientAppActor pid=4758)[0m 
[36m(ClientAppActor pid=4758)[0m         
[36m(ClientAppActor pid=4757)[0m 
[36m(ClientAppActor pid=4757)[0m         


[36m(ClientAppActor pid=4757)[0m [Client] step 40/40 loss=5.1986 (fuse=3.4657, g=0.0000, t=3.4657)[32m [repeated 29x across cluster][0m


[36m(ClientAppActor pid=4757)[0m 
[36m(ClientAppActor pid=4757)[0m         
[36m(ClientAppActor pid=4758)[0m 
[36m(ClientAppActor pid=4758)[0m         
[92mINFO [0m:      aggregate_fit: received 4 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 4 clients (out of 4)
[36m(ClientAppActor pid=4757)[0m 
[36m(ClientAppActor pid=4757)[0m         
[36m(ClientAppActor pid=4757)[0m             This is a deprecated feature. It will be removed[32m [repeated 15x across cluster][0m
[36m(ClientAppActor pid=4757)[0m             entirely in future versions of Flower.[32m [repeated 15x across cluster][0m
[36m(ClientAppActor pid=4758)[0m 
[36m(ClientAppActor pid=4758)[0m         
[92mINFO [0m:      aggregate_evaluate: received 4 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [SUMMARY]
[92mINFO [0m:      Run finished 3 round(s) in 11.90s
[92mINFO [0m:      	History (loss, distributed):
[92mINFO [0m:      		round 1: 0.0


In [None]:
import torch
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures

# Load Cora dataset
dataset = Planetoid(root="data/Cora", name="Cora", transform=NormalizeFeatures())
data = dataset[0]  # Single graph object

print("Cora graph info:")
print("Nodes:", data.num_nodes)
print("Edges:", data.num_edges)
print("Features:", data.num_node_features)
print("Classes:", dataset.num_classes)

# Train/test splits
X_train = data.x[data.train_mask]
y_train = data.y[data.train_mask]

X_val = data.x[data.val_mask]
y_val = data.y[data.val_mask]

X_test = data.x[data.test_mask]
y_test = data.y[data.test_mask]

print("Train nodes:", X_train.shape, " Val nodes:", X_val.shape, " Test nodes:", X_test.shape)


[36m(ClientAppActor pid=4757)[0m 
[36m(ClientAppActor pid=4757)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=4757)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=4757)[0m         
[36m(ClientAppActor pid=4758)[0m 
[36m(ClientAppActor pid=4758)[0m         
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index
Processing...

Cora graph info:
Nodes: 2708
Edges: 10556
Features: 1433
Classes: 7
Train nodes: torch.Size([140, 1433])  Val nodes: torch.Size([500, 1433])  Test nodes: torch.Size([1000, 1433])


  return datetime.utcnow().replace(tzinfo=utc)


In [None]:
!pip install tslearn

Collecting tslearn
  Downloading tslearn-0.6.4-py3-none-any.whl.metadata (15 kB)
Downloading tslearn-0.6.4-py3-none-any.whl (389 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m389.9/389.9 kB[0m [31m8.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: tslearn
Successfully installed tslearn-0.6.4


In [None]:
from tslearn.datasets import UCR_UEA_datasets
from sklearn.model_selection import train_test_split

# Load ECG200 dataset
ucr = UCR_UEA_datasets()
loaded_data = ucr.load_dataset("ECG200") # returns arrays; X shape (200, 96, 1) typically
X, y = loaded_data[0], loaded_data[1] # Unpack only the data and labels

print("ECG200 shape:", X.shape, "Labels:", set(y))

# Train/test split
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

print("Train shape:", X_train.shape, " Test shape:", X_test.shape)

ECG200 shape: (100, 96, 1) Labels: {np.int64(1), np.int64(-1)}
Train shape: (80, 96, 1)  Test shape: (20, 96, 1)


  return datetime.utcnow().replace(tzinfo=utc)


In [None]:
import random
import math
from typing import Tuple
import numpy as np
from tqdm import trange, tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# ---- Graph (PyG) imports ----
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures
from torch_geometric.utils import dropout_adj
from torch_geometric.nn import GCNConv

In [None]:
from tslearn.datasets import UCR_UEA_datasets
from sklearn.model_selection import train_test_split

In [None]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", DEVICE)


Device: cpu


In [None]:
GRAPH_BATCH_NODES = 256   # number of node indices per contrastive batch (sampled from all nodes)
GRAPH_LOCAL_EPOCHS = 20
GRAPH_LR = 1e-3

TS_BATCH = 64
TS_EPOCHS = 30
TS_LR = 1e-3
TS_WINDOW = 96  # ECG200 length is 96

TEMPERATURE = 0.5
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

In [None]:
def nt_xent_loss(z1: torch.Tensor, z2: torch.Tensor, temperature: float = 0.5) -> torch.Tensor:
    """
    z1, z2: (B, D)
    returns scalar loss
    """
    device = z1.device
    B = z1.size(0)
    z1 = F.normalize(z1, dim=1)
    z2 = F.normalize(z2, dim=1)
    z = torch.cat([z1, z2], dim=0)  # 2B x D
    sim = torch.matmul(z, z.T) / temperature  # 2B x 2B
    # mask out self-similarity
    mask = (~torch.eye(2*B, dtype=torch.bool)).to(device)
    sim_masked = sim.masked_select(mask).view(2*B, 2*B-1)
    positives = torch.cat([torch.diag(sim, B), torch.diag(sim, -B)], dim=0).unsqueeze(1)  # 2B x 1
    logits = torch.cat([positives, sim_masked], dim=1)  # 2B x (2B)
    labels = torch.zeros(2*B, dtype=torch.long).to(device)  # positives are at index 0
    loss = F.cross_entropy(logits, labels)
    return loss

In [None]:
class GraphEncoder(nn.Module):
    def __init__(self, in_dim, hidden=128, out_dim=128):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hidden)
        self.conv2 = GCNConv(hidden, hidden)
        self.proj = nn.Sequential(nn.Linear(hidden, hidden//2), nn.ReLU(), nn.Linear(hidden//2, out_dim))

    def forward(self, x, edge_index):
        h = F.relu(self.conv1(x, edge_index))
        h = F.relu(self.conv2(h, edge_index))  # N x hidden
        z = self.proj(h)  # N x out_dim (node embeddings)
        return z


In [None]:
def graph_augment_edge_dropout(x, edge_index, drop_prob=0.2):
    # drop edges using PyG's dropout_adj
    new_ei, _ = dropout_adj(edge_index, p=drop_prob, force_undirected=True)
    return x, new_ei

In [None]:
def graph_feat_mask(x, mask_prob=0.1):
    x2 = x.clone()
    mask = (torch.rand_like(x2) > mask_prob).float()
    return x2 * mask

In [None]:
def train_graph_contrastive(epochs=GRAPH_LOCAL_EPOCHS, batch_nodes=GRAPH_BATCH_NODES):
    print("\n--- Graph contrastive training (Cora) ---")
    dataset = Planetoid(root="data/Cora", name="Cora", transform=NormalizeFeatures())
    data = dataset[0].to(DEVICE)
    N = data.num_nodes
    feat_dim = data.num_node_features

    enc = GraphEncoder(in_dim=feat_dim, hidden=256, out_dim=128).to(DEVICE)
    opt = optim.Adam(enc.parameters(), lr=GRAPH_LR, weight_decay=1e-5)

    node_indices = np.arange(N)

    pbar = trange(epochs, desc="Graph epochs")
    for ep in pbar:
        enc.train()
        epoch_loss = 0.0
        # iterate in mini-batches of node indices
        np.random.shuffle(node_indices)
        for i in range(0, N, batch_nodes):
            batch_idx = node_indices[i: i+batch_nodes]
            if len(batch_idx) == 0:
                continue
            # --- create two augmented graph views ---
            # view 1: edge dropout + feat mask
            x1 = data.x.clone()
            x1 = graph_feat_mask(x1, mask_prob=0.1)
            x1, ei1 = graph_augment_edge_dropout(x1, data.edge_index, drop_prob=0.2)
            # view 2: different augmentations
            x2 = data.x.clone()
            x2 = graph_feat_mask(x2, mask_prob=0.15)
            x2, ei2 = graph_augment_edge_dropout(x2, data.edge_index, drop_prob=0.25)

            # compute node embeddings for all nodes (N x D)
            with torch.no_grad():
                # nothing
                pass
            z1_all = enc(x1.to(DEVICE), ei1.to(DEVICE))  # N x D
            z2_all = enc(x2.to(DEVICE), ei2.to(DEVICE))  # N x D

            # select batch rows
            z1 = z1_all[batch_idx].to(DEVICE)  # b x D
            z2 = z2_all[batch_idx].to(DEVICE)

            # loss
            loss = nt_xent_loss(z1, z2, temperature=TEMPERATURE)

            opt.zero_grad()
            loss.backward()
            opt.step()

            epoch_loss += float(loss.detach().cpu().item())

        pbar.set_postfix({"loss": epoch_loss / max(1, math.ceil(N / batch_nodes))})
    print("Graph training done. Encoder params ready.")
    return enc

In [None]:
class TemporalEncoder(nn.Module):
    def __init__(self, window_len=TS_WINDOW, hidden=128, out_dim=128):
        super().__init__()
        # 1D conv stack
        self.conv1 = nn.Conv1d(1, hidden, kernel_size=9, padding=4)
        self.conv2 = nn.Conv1d(hidden, hidden, kernel_size=9, padding=4)
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.proj = nn.Sequential(nn.Linear(hidden, hidden//2), nn.ReLU(), nn.Linear(hidden//2, out_dim))

    def forward(self, x):
        # x: (B, window_len)
        x = x.unsqueeze(1)  # B x 1 x L
        h = F.relu(self.conv1(x))
        h = F.relu(self.conv2(h))
        h = self.pool(h).squeeze(-1)  # B x hidden
        z = self.proj(h)  # B x out_dim
        return z

In [None]:
def ts_jitter(series: np.ndarray, sigma=0.03):
    return (series + np.random.normal(0, sigma, size=series.shape)).astype(np.float32)

def ts_scaling(series: np.ndarray, sigma=0.1):
    factor = np.random.normal(1.0, sigma)
    return (series * factor).astype(np.float32)

def ts_window_slice(series: np.ndarray, window=TS_WINDOW):
    L = series.shape[0]
    if L == window:
        return series.copy()
    start = np.random.randint(0, L - window + 1)
    return series[start:start+window].copy()

In [None]:
class TSContrastiveDataset(torch.utils.data.Dataset):
    def __init__(self, X_raw: np.ndarray):
        # X_raw shape: (n_samples, length) or (n_samples, length, 1)
        if X_raw.ndim == 3:
            X_raw = X_raw.squeeze(-1)
        self.X = X_raw.astype(np.float32)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        s = self.X[idx]
        # create two augmented views
        v1 = ts_jitter(ts_window_slice(s, window=TS_WINDOW), sigma=0.03)
        v1 = ts_scaling(v1, sigma=0.05)
        v2 = ts_jitter(ts_window_slice(s, window=TS_WINDOW), sigma=0.05)
        v2 = ts_scaling(v2, sigma=0.08)
        return v1, v2

In [None]:
def train_ts_contrastive(epochs=TS_EPOCHS, batch_size=TS_BATCH):
    print("\n--- Time-series contrastive training (ECG200) ---")
    # Load ECG200
    ucr = UCR_UEA_datasets()
    loaded_data = ucr.load_dataset("ECG200") # returns arrays; X shape (200, 96, 1) typically
    X, y = loaded_data[0], loaded_data[1] # Unpack only the data and labels

    # Flatten/truncate/reshape as needed
    X = np.asarray(X)
    if X.ndim == 3:
        X = X.squeeze(-1)
    # simple train/test split (we train on train set)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=SEED, stratify=y)
    dataset = TSContrastiveDataset(X_train)
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=0)

    enc = TemporalEncoder(window_len=TS_WINDOW, hidden=128, out_dim=128).to(DEVICE)
    opt = optim.Adam(enc.parameters(), lr=TS_LR, weight_decay=1e-5)

    pbar = trange(epochs, desc="TS epochs")
    for ep in pbar:
        enc.train()
        epoch_loss = 0.0
        for v1_np, v2_np in loader:
            v1 = torch.tensor(v1_np).to(DEVICE)
            v2 = torch.tensor(v2_np).to(DEVICE)
            z1 = enc(v1)
            z2 = enc(v2)
            loss = nt_xent_loss(z1, z2, temperature=TEMPERATURE)
            opt.zero_grad()
            loss.backward()
            opt.step()
            epoch_loss += float(loss.detach().cpu().item())
        pbar.set_postfix({"loss": epoch_loss / max(1, len(loader))})
    print("Time-series training done.")
    return enc, (X_test, y_test)

In [None]:
if __name__ == "__main__":
    # Train graph encoder
    graph_encoder = train_graph_contrastive(epochs=GRAPH_LOCAL_EPOCHS, batch_nodes=GRAPH_BATCH_NODES)

    # Train temporal encoder
    ts_encoder, (X_test, y_test) = train_ts_contrastive(epochs=TS_EPOCHS, batch_size=TS_BATCH)

    # Simple check: output embedding shapes
    # Graph: test by running full graph through encoder
    from torch_geometric.datasets import Planetoid
    ds = Planetoid(root="data/Cora", name="Cora", transform=NormalizeFeatures())
    d0 = ds[0].to(DEVICE)
    graph_encoder.eval()
    with torch.no_grad():
        z_nodes = graph_encoder(d0.x, d0.edge_index)
    print("Graph node embedding shape:", z_nodes.shape)  # N x D

    # TS: embed a small batch from test set
    # X_test, y_test = ts_test # This unpacking is no longer needed here
    sample = torch.tensor(X_test[:8], dtype=torch.float).to(DEVICE)
    ts_encoder.eval()
    with torch.no_grad():
        z_ts = ts_encoder(sample)
    print("TS embedding shape (8 samples):", z_ts.shape)
    print("\nPhase 2 complete: trained encoders for graph and time-series.")


--- Graph contrastive training (Cora) ---


Graph epochs: 100%|██████████| 20/20 [01:05<00:00,  3.26s/it, loss=4.6]


Graph training done. Encoder params ready.

--- Time-series contrastive training (ECG200) ---


  return datetime.utcnow().replace(tzinfo=utc)
  v1 = torch.tensor(v1_np).to(DEVICE)
  v2 = torch.tensor(v2_np).to(DEVICE)
TS epochs: 100%|██████████| 30/30 [00:05<00:00,  5.81it/s, loss=3.99]


Time-series training done.
Graph node embedding shape: torch.Size([2708, 128])
TS embedding shape (8 samples): torch.Size([8, 128])

Phase 2 complete: trained encoders for graph and time-series.


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class FusionProjection(nn.Module):
    def __init__(self, in_dim, proj_dim=128):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, in_dim // 2),
            nn.ReLU(),
            nn.Linear(in_dim // 2, proj_dim)
        )
    def forward(self, x):
        return self.mlp(x)

In [None]:
class MultimodalEncoder(nn.Module):
    def __init__(self, graph_in_dim, ts_window_len,
                 hidden_dim=128, out_dim=128, fusion_dim=128):
        super().__init__()
        # Individual encoders
        self.graph_enc = GraphEncoder(in_dim=graph_in_dim,
                                      hidden=hidden_dim, out_dim=out_dim)
        self.ts_enc = TemporalEncoder(window_len=ts_window_len,
                                      hidden=hidden_dim, out_dim=out_dim)
        # Fusion projection
        self.fusion = FusionProjection(in_dim=out_dim*2, proj_dim=fusion_dim)

    def forward(self, x_graph, edge_index, x_ts):
        # Encode graph and time-series separately
        z_g = self.graph_enc(x_graph, edge_index)  # shape: N x out_dim
        z_t = self.ts_enc(x_ts)                    # shape: B x out_dim

        # NOTE: For simplicity, align batch sizes (e.g., sample B nodes = B ts samples)
        fused = torch.cat([z_g, z_t], dim=1)       # shape: B x (2*out_dim)
        z_f = self.fusion(fused)                   # shape: B x fusion_dim
        return z_f

In [None]:
def train_multimodal_contrastive(model, graph_data, ts_loader, optimizer, epochs=10, temperature=0.5):
    device = next(model.parameters()).device
    model.train()

    for ep in range(epochs):
        total_loss = 0.0
        for (ts1, ts2) in ts_loader:  # two augmented time-series views
            ts1, ts2 = ts1.to(device), ts2.to(device)
            # Graph augmentations
            x1 = graph_feat_mask(graph_data.x.clone(), mask_prob=0.1)
            x1, ei1 = graph_augment_edge_dropout(x1, graph_data.edge_index, drop_prob=0.2)
            x2 = graph_feat_mask(graph_data.x.clone(), mask_prob=0.15)
            x2, ei2 = graph_augment_edge_dropout(x2, graph_data.edge_index, drop_prob=0.25)

            # Forward pass: fused embeddings
            z1 = model(x1.to(device), ei1.to(device), ts1)  # fused view 1
            z2 = model(x2.to(device), ei2.to(device), ts2)  # fused view 2

            # Contrastive loss
            loss = nt_xent_loss(z1, z2, temperature=temperature)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        print(f"Epoch {ep+1}/{epochs}, Loss={total_loss/len(ts_loader):.4f}")

In [None]:
!pip install opacus
from opacus import PrivacyEngine
from scipy.stats import entropy

Collecting opacus
  Downloading opacus-1.5.4-py3-none-any.whl.metadata (8.7 kB)
Downloading opacus-1.5.4-py3-none-any.whl (254 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m254.4/254.4 kB[0m [31m6.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: opacus
Successfully installed opacus-1.5.4


In [None]:
import torch, torch.nn as nn, torch.optim as optim
from opacus import PrivacyEngine
import numpy as np
from scipy.stats import entropy
import flwr as fl
from torch.utils.data import DataLoader, Dataset

In [None]:
def _hist_entropy(tensor, bins=64):
    arr = tensor.detach().cpu().float().view(-1).numpy()
    if arr.size == 0:
        return 0.0
    hist, _ = np.histogram(arr, bins=bins, density=True)
    hist = hist + 1e-12
    hist = hist / hist.sum()
    return float(-np.sum(hist * np.log(hist)))

In [None]:
def adaptive_noise_from_entropy(entropy_val, base=1.0, min_noise=0.5, max_noise=3.0):
    noise = base / (1.0 + entropy_val)
    noise = max(min_noise, min(noise, max_noise))
    return float(noise)

In [None]:
def noise_for_image_batch(image_batch, base=1.0):
    return adaptive_noise_from_entropy(_hist_entropy(image_batch), base=base)

In [None]:
def noise_for_graph(graph_data, node_idx=None, base=1.0):
    x = graph_data.x if node_idx is None else graph_data.x[node_idx]
    return adaptive_noise_from_entropy(_hist_entropy(x), base=base)

In [None]:
def noise_for_ts_batch(ts_batch, base=1.0):
    return adaptive_noise_from_entropy(_hist_entropy(ts_batch), base=base)

In [None]:
def combine_noise(n_img=None, n_graph=None, n_ts=None, strategy="max", weights=(1.0,1.0,1.0)):
    vals = [v for v in (n_img, n_graph, n_ts) if v is not None]
    if not vals:
        return 1.0
    if strategy == "max":
        return float(max(vals))
    if strategy == "weighted":
        w = np.array(weights)[:len(vals)]
        w = w / w.sum()
        return float((w * np.array(vals)).sum())
    return float(max(vals))

In [None]:
class MultimodalDataset(Dataset):
    """
    Given:
      - graph_data: PyG Data (graph.x, graph.edge_index), used only for node features.
      - ts_windows: torch.Tensor shape (N_ts, L)
      - optional image tensor images: (N_img, C, H, W)
    This simple dataset samples indices 0..N-1 and yields tuples (img, node_idx, ts_window)
    For real data you should align or sample properly per-client.
    """
    def __init__(self, graph_data=None, ts_windows=None, images=None, N=None):
        self.graph = graph_data
        self.ts = ts_windows
        self.images = images
        # choose effective length
        candidates = []
        if ts_windows is not None:
            candidates.append(ts_windows.shape[0])
        if images is not None:
            candidates.append(images.shape[0])
        if graph_data is not None:
            candidates.append(graph_data.x.shape[0])
        self.N = N if N is not None else (min(candidates) if candidates else 0)

    def __len__(self):
        return self.N

    def __getitem__(self, idx):
        img = self.images[idx % self.images.shape[0]] if self.images is not None else torch.zeros(1)
        node_idx = idx % self.graph.x.size(0) if self.graph is not None else 0
        ts = self.ts[idx % self.ts.shape[0]] if self.ts is not None else torch.zeros(10)
        return img, node_idx, ts

In [None]:
class GraphEncoder(nn.Module):
    def __init__(self, in_dim, out_dim=128):
        super().__init__()
        self.lin = nn.Linear(in_dim, out_dim)
    def forward(self, x, edge_index=None):
        return self.lin(x)

In [None]:
class TemporalEncoder(nn.Module):
    def __init__(self, L, out_dim=128):
        super().__init__()
        self.conv = nn.Sequential(nn.Conv1d(1,64,9,padding=4), nn.ReLU(), nn.AdaptiveAvgPool1d(1), nn.Flatten(), nn.Linear(64, out_dim))
    def forward(self, x):
        if x.dim()==2: x = x.unsqueeze(1)  # (B,1,L)
        return self.conv(x)

class FusionProjection(nn.Module):
    def __init__(self, in_dim, proj_dim=128):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(in_dim, proj_dim), nn.ReLU(), nn.Linear(proj_dim, proj_dim))
    def forward(self, x): return self.mlp(x)

In [None]:
def nt_xent_loss(z1, z2, temperature=0.5):
    z1 = nn.functional.normalize(z1, dim=1)
    z2 = nn.functional.normalize(z2, dim=1)
    B = z1.shape[0]
    z = torch.cat([z1,z2], dim=0)
    sim = torch.matmul(z, z.T) / temperature
    mask = (~torch.eye(2*B, dtype=torch.bool, device=sim.device)).float()
    exp_sim = torch.exp(sim) * mask
    positives = torch.exp((z1*z2).sum(dim=1)/temperature)
    positives = torch.cat([positives, positives], dim=0)
    denom = exp_sim.sum(dim=1)
    loss = -torch.log(positives / (denom + 1e-12))
    return loss.mean()

In [None]:
class AdaptiveDPMultimodalClient(fl.client.NumPyClient):
    def __init__(self, graph_enc, ts_enc, fusion, client_data, device="cpu",
                 local_epochs=1, local_steps=100, batch_size=16, lr=1e-3):
        self.device = torch.device(device)
        self.graph_enc = graph_enc.to(self.device)
        self.ts_enc = ts_enc.to(self.device)
        self.fusion = fusion.to(self.device)

        self.client_data = client_data  # dict with 'graph' (pyg.Data), 'ts' (Tensor NxL), optional 'images' (Tensor)
        self.local_epochs = local_epochs
        self.local_steps = local_steps
        self.batch_size = batch_size
        self.lr = lr

    def get_parameters(self, config=None):
        params = []
        for model in (self.graph_enc, self.ts_enc, self.fusion):
            for _, v in model.state_dict().items():
                params.append(v.detach().cpu().numpy())
        return params

    def set_parameters(self, parameters):
        it = iter(parameters)
        for model in (self.graph_enc, self.ts_enc, self.fusion):
            sd = model.state_dict()
            new_sd = {}
            for k in sd.keys():
                arr = next(it)
                new_sd[k] = torch.tensor(arr, dtype=sd[k].dtype)
            model.load_state_dict(new_sd)

    def fit(self, parameters, config):
        # 1) load server params
        if parameters is not None:
            try:
                self.set_parameters(parameters)
            except Exception as e:
                print("set_parameters failed:", e)

        # 2) build combined dataset and DataLoader (yields single-sample tuples)
        graph = self.client_data.get("graph", None)
        ts = self.client_data.get("ts", None)
        images = self.client_data.get("images", None)
        # choose dataset length: min available or explicit
        N = min(
            (images.shape[0] if images is not None else float('inf'),
             ts.shape[0] if ts is not None else float('inf'),
             graph.x.shape[0] if graph is not None else float('inf'))
        )
        N = int(N if N != float('inf') else (ts.shape[0] if ts is not None else 0))
        ds = MultimodalDataset(graph, ts, images, N=N)
        loader = DataLoader(ds, batch_size=self.batch_size, shuffle=True)

        # 3) compute per-modality adaptive noise from single sample batch
        sample_img, sample_node_idx, sample_ts = next(iter(loader))
        # convert to typical tensor shapes expected by helpers
        n_img = noise_for_image_batch(sample_img) if images is not None else None
        n_graph = noise_for_graph(graph, node_idx=sample_node_idx) if graph is not None else None
        n_ts = noise_for_ts_batch(sample_ts) if ts is not None else None

        noise_multiplier = combine_noise(n_img, n_graph, n_ts, strategy="max")
        print(f"[Client] adaptive noise multipliers: img={n_img}, graph={n_graph}, ts={n_ts} -> combined={noise_multiplier}")

        # 4) Prepare a single optimizer over all params (we do DP over the whole multimodal model)
        all_params = list(self.graph_enc.parameters()) + list(self.ts_enc.parameters()) + list(self.fusion.parameters())
        optimizer = optim.Adam(all_params, lr=self.lr)

        # 5) Attach PrivacyEngine
        privacy_engine = PrivacyEngine()
        # IMPORTANT: make_private expects the dataloader to yield single samples; our loader does.
        model_wrapper = nn.Sequential()  # dummy wrapper; Opacus registers module parameters, so it's OK to pass a wrapper with submodules if needed
        # Instead of wrapper, we pass nothing special; we still call make_private with module=self.graph_enc (only for bookkeeping).
        # To be safe, pass an nn.Module that contains all params:
        class _All(nn.Module):
            def __init__(self, graph_enc, ts_enc, fusion):
                super().__init__()
                self.graph_enc = graph_enc
                self.ts_enc = ts_enc
                self.fusion = fusion
        combined_module = _All(self.graph_enc, self.ts_enc, self.fusion)

        combined_module, optimizer, private_loader = privacy_engine.make_private(
            module=combined_module,
            optimizer=optimizer,
            data_loader=loader,
            noise_multiplier=noise_multiplier,
            max_grad_norm=1.0,
        )

        # 6) local training loop (DP-SGD)
        combined_module.train()
        step = 0
        last_loss = 0.0
        for epoch in range(self.local_epochs):
            for batch in private_loader:
                img_batch, node_idx_batch, ts_batch = batch
                # move to device
                ts_batch = ts_batch.to(self.device).float()
                # graph: gather node features for batch (simple example: per-node features)
                node_idx_batch = node_idx_batch.to(self.device)
                x_nodes = graph.x[node_idx_batch].to(self.device).float() if graph is not None else torch.zeros((ts_batch.size(0),1)).to(self.device)
                # image: pass if provided
                img_tensor = img_batch.to(self.device).float() if images is not None else torch.zeros((ts_batch.size(0),1)).to(self.device)

                # create two augmentations per modality (simple jitter/mask examples)
                # For brevity, do basic augmentations:
                ts_v1 = ts_batch + 0.01*torch.randn_like(ts_batch)
                ts_v2 = ts_batch + 0.02*torch.randn_like(ts_batch)
                x1 = x_nodes * (torch.rand_like(x_nodes) > 0.1).float()
                x2 = x_nodes * (torch.rand_like(x_nodes) > 0.15).float()

                # encode
                z_g1 = self.graph_enc(x1)
                z_g2 = self.graph_enc(x2)
                z_t1 = self.ts_enc(ts_v1)
                z_t2 = self.ts_enc(ts_v2)

                # fuse (concatenate per-sample)
                fused1 = torch.cat([z_g1, z_t1], dim=1)
                fused2 = torch.cat([z_g2, z_t2], dim=1)
                z_f1 = self.fusion(fused1)
                z_f2 = self.fusion(fused2)

                loss = nt_xent_loss(z_f1, z_f2)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                last_loss = loss.item()
                step += 1
                if step >= self.local_steps:
                    break
            if step >= self.local_steps:
                break

        # 7) after training get epsilon (delta chosen as typical 1e-5)
        try:
            epsilon = privacy_engine.get_epsilon(delta=1e-5)
        except Exception as e:
            # fallback if not supported
            epsilon = None
            print("Could not compute epsilon:", e)

        # 8) return updated params + sample count + metrics
        params = []
        for model in (self.graph_enc, self.ts_enc, self.fusion):
            for _, v in model.state_dict().items():
                params.append(v.detach().cpu().numpy())
        num_examples = len(loader.dataset)
        metrics = {"loss": float(last_loss)}
        if epsilon is not None:
            metrics["epsilon"] = float(epsilon)
        return params, num_examples, metrics

    def evaluate(self, parameters, config):
        return 0.0, 0, {}

In [None]:
import math
import numpy as np
import torch
import torch.nn as nn
from opacus import PrivacyEngine
from scipy.stats import entropy

In [None]:
def compute_dataset_entropy_numpy(dataset_tensor):
    """
    Compute an entropy-like score for a dataset tensor.
    Accepts: numpy array or torch tensor of shape (N, ...).
    Returns: scalar >0. Larger => more "complex" -> adjust noise lower (example).
    """
    if isinstance(dataset_tensor, torch.Tensor):
        arr = dataset_tensor.detach().cpu().numpy().ravel()
    else:
        arr = np.asarray(dataset_tensor).ravel()
    # Normalize to histogram
    if arr.size == 0:
        return 0.0
    # Use 256 bins (clamped)
    hist, _ = np.histogram(arr, bins=256, density=True)
    # Add small eps to avoid zeros
    hist = hist + 1e-12
    return float(entropy(hist))

In [None]:
def compute_adaptive_noise_multiplier(entropy_score, base_noise=1.0, min_noise=0.3, max_noise=5.0, sensitivity=1.0):
    """
    Map entropy_score -> noise_multiplier.
    Heuristic: higher entropy -> dataset more informative -> use lower noise (improve utility).
    Lower entropy -> add more noise to preserve privacy.
    This mapping is tunable.
    """
    # Sigmoid-like mapping: invert entropy to make higher entropy -> lower noise
    # First normalize entropy_score to [0,1] using a soft scale
    scale = 1.0 / (1.0 + math.exp(- (entropy_score - 4.0)))  # center ~4.0; adjust if needed
    # inverted
    inv = 1.0 - scale
    noise = base_noise * (min_noise + (max_noise - min_noise) * inv)
    # ensure bounds
    noise = max(min_noise, min(max_noise, float(noise)))
    return noise

In [None]:
class AdaptiveDPPrivCLClient(fl.client.NumPyClient):
    def __init__(self, model: nn.Module, trainloader, device='cpu',
                 lr=0.01, epochs=1, base_noise=1.0, max_grad_norm=1.0):
        self.model = model.to(device)
        self.trainloader = trainloader
        self.device = device
        self.lr = lr
        self.epochs = epochs
        self.base_noise = base_noise
        self.max_grad_norm = max_grad_norm

        # optimizer (will be wrapped by Opacus when making private)
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.lr)
        # placeholder for privacy engine; will be attached during fit with computed noise
        self.privacy_engine = None
        self.criterion = nn.CrossEntropyLoss()  # replace with NT-Xent or contrastive loss as required

    def get_parameters(self, config):
        return [val.cpu().numpy() for _, val in self.model.state_dict().items()]

    def set_parameters(self, parameters_list):
        params = {k: torch.tensor(v) for k, v in zip(self.model.state_dict().keys(), parameters_list)}
        self.model.load_state_dict(params)

    def _compute_client_entropy(self):
        # Sample a subset (or entire) dataset to compute entropy. Keep lightweight.
        # Assumes trainloader yields (x, y) or (x1, x2, y) tuples. We'll flatten inputs.
        sample_tensors = []
        max_samples = 512  # cap for speed
        taken = 0
        for batch in self.trainloader:
            x = batch[0]
            if isinstance(x, (list, tuple)):
                x = x[0]
            sample_tensors.append(x.detach().cpu())
            taken += x.shape[0]
            if taken >= max_samples:
                break
        if len(sample_tensors) == 0:
            return 0.0
        cat = torch.cat(sample_tensors, dim=0)
        # compute entropy on flattened values
        return compute_dataset_entropy_numpy(cat)
    def fit(self, parameters, config):
        # Set model weights from server
        if parameters is not None:
            self.set_parameters(parameters)

        # Compute client-specific noise multiplier
        entropy_score = self._compute_client_entropy()
        noise_multiplier = compute_adaptive_noise_multiplier(entropy_score, base_noise=self.base_noise)

        # Recreate optimizer
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.lr)

        # Detach any previous privacy engine
        if self.privacy_engine is not None:
            try:
                self.privacy_engine.detach()
            except Exception:
                pass

        # Explicit sample rate
        batch_size = self.trainloader.batch_size
        dataset_size = len(self.trainloader.dataset)
        sample_rate = batch_size / dataset_size
        delta = 1.0 / dataset_size  # standard choice

        # Reinitialize PrivacyEngine with accountant
        self.privacy_engine = PrivacyEngine(accountant="rdp")  # force RDP accountant
        self.model, self.optimizer, self.trainloader = self.privacy_engine.make_private(
            module=self.model,
            optimizer=self.optimizer,
            data_loader=self.trainloader,
            noise_multiplier=noise_multiplier,
            max_grad_norm=self.max_grad_norm,
            poisson_sampling=False,  # we’re using uniform sampling
        )

        # Local training loop
        self.model.train()
        for epoch in range(self.epochs):
            for batch in self.trainloader:
                if isinstance(batch, (list, tuple)) and len(batch) >= 2:
                    x = batch[0].to(self.device)
                    if len(batch) >= 2 and isinstance(batch[1], torch.Tensor) and batch[1].dim() == x.dim():
                        x2 = batch[1].to(self.device)
                        z1 = self.model(x)
                        z2 = self.model(x2)
                        loss = nn.functional.mse_loss(z1, z2)
                    else:
                        inputs = x
                        targets = batch[1].to(self.device) if len(batch) > 1 else None
                        outputs = self.model(inputs)
                        loss = self.criterion(outputs, targets) if targets is not None else torch.tensor(0.0, device=self.device)
                else:
                    inputs = batch.to(self.device)
                    outputs = self.model(inputs)
                    loss = torch.tensor(0.0, device=self.device)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

        # Compute epsilon after training
        try:
            epsilon, best_alpha = self.privacy_engine.accountant.get_epsilon(delta)
        except Exception:
            epsilon, best_alpha = None, None

        # Log results
        print(f"[Client] noise_multiplier={noise_multiplier:.3f}, epsilon={epsilon}, delta={delta}")

        # Return updated params to server
        new_params = self.get_parameters({})
        return new_params, len(self.trainloader.dataset), {"noise_multiplier": noise_multiplier, "epsilon": epsilon}

    def evaluate(self, parameters, config):
        # Set model and run local eval (simple)
        if parameters is not None:
            self.set_parameters(parameters)
        self.model.eval()
        loss_total = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for batch in self.trainloader:
                if isinstance(batch, (list, tuple)):
                    x = batch[0].to(self.device)
                    y = batch[1].to(self.device) if len(batch) > 1 else None
                else:
                    x = batch.to(self.device)
                    y = None
                outputs = self.model(x)
                if y is not None:
                    loss_total += nn.functional.cross_entropy(outputs, y, reduction="sum").item()
                    preds = outputs.argmax(dim=1)
                    correct += (preds == y).sum().item()
                    total += y.size(0)
        if total == 0:
            return 0.0, 0, {}
        loss_avg = loss_total / total
        accuracy = correct / total
        return float(loss_avg), total, {"accuracy": float(accuracy)}

In [None]:
def test_adaptive_dp_client(model_class, trainloader, device="cpu"):
    """
    Quick test of AdaptiveDPPrivCLClient.
    Trains for 1 epoch with adaptive DP-SGD and prints noise multiplier & epsilon.
    Accepts model_class (e.g., SmallNet) and instantiates it inside.
    """
    # Create a new model instance for each test run
    model = model_class().to(device)

    client = AdaptiveDPPrivCLClient(
        model=model,
        trainloader=trainloader,
        device=device,
        lr=0.01,
        epochs=1,
        base_noise=1.0,
        max_grad_norm=1.0,
    )

    # Run one fit round
    params, num_examples, metrics = client.fit(parameters=None, config={})
    print("Client trained with:")
    print(f"  - Noise multiplier: {metrics.get('noise_multiplier')}")
    print(f"  - Epsilon: {metrics.get('epsilon')}")
    print(f"  - Examples seen: {num_examples}")
    return metrics

In [None]:
# Suppose you already have a simple model & DataLoader
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset

# Simple model
class SmallNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(28*28, 10)
    def forward(self, x):
        return self.fc(x.view(x.size(0), -1))

# Dataset
transform = transforms.Compose([transforms.ToTensor()])
mnist = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
# Larger subset for testing epsilon
subset = Subset(mnist, range(50000))   # instead of 200
trainloader = DataLoader(subset, batch_size=64, shuffle=True)

# Train longer
metrics = test_adaptive_dp_client(SmallNet, trainloader, device="cpu")

100%|██████████| 9.91M/9.91M [00:00<00:00, 39.8MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.00MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 9.22MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 10.4MB/s]
  loss.backward()


[Client] noise_multiplier=4.688, epsilon=None, delta=2e-05
Client trained with:
  - Noise multiplier: 4.688039162171171
  - Epsilon: None
  - Examples seen: 50000




In [None]:
!pip install tenseal

Collecting tenseal
  Downloading tenseal-0.3.16-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (8.4 kB)
Downloading tenseal-0.3.16-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (4.8 MB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/4.8 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.3/4.8 MB[0m [31m11.5 MB/s[0m eta [36m0:00:01[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━[0m [32m3.6/4.8 MB[0m [31m51.8 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.8/4.8 MB[0m [31m47.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: tenseal
Successfully installed tenseal-0.3.16


In [None]:
import threading, time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split, Subset
import flwr as fl
import tenseal as ts

In [None]:
def create_tenseal_context():
    ctx = ts.context(
        ts.SCHEME_TYPE.CKKS,
        poly_modulus_degree=8192,
        coeff_mod_bit_sizes=[60, 40, 40, 60],
    )
    ctx.global_scale = 2**40
    ctx.generate_galois_keys()
    public_bytes = ctx.serialize(save_secret_key=False)
    return ctx, public_bytes

In [None]:
def params_to_numpy_list(state_dict):
    return [v.cpu().numpy().astype(np.float64) for _, v in state_dict.items()]

In [None]:
def _params_to_flat_lists(param_list):
    arrays, shapes = [], []
    for a in param_list:
        arr = np.asarray(a, dtype=np.float64)
        shapes.append(arr.shape)
        arrays.append(arr.ravel().tolist())
    return arrays, shapes

In [None]:
def _flat_lists_to_param_arrays(flat_lists, shapes):
    out = []
    for flat, shape in zip(flat_lists, shapes):
        arr = np.asarray(flat, dtype=np.float64).reshape(shape)
        out.append(arr)
    return out


In [None]:
def encrypt_params_with_public_ctx(params_numpy_list, public_ctx_bytes):
    client_ctx = ts.context_from(public_ctx_bytes)
    encrypted_serialized, shapes = [], []
    for arr in params_numpy_list:
        flat = arr.ravel().tolist()
        ck = ts.ckks_vector(client_ctx, flat)
        encrypted_serialized.append(ck.serialize())
        shapes.append(arr.shape)
    return encrypted_serialized, shapes

In [None]:
def deserialize_and_aggregate_encrypted(all_clients_serialized, server_ctx, shapes):
    num_clients = len(all_clients_serialized)
    aggregated_flat_lists = []
    for param_idx in range(len(shapes)):
        sum_cipher = None
        for client_serialized in all_clients_serialized:
            ck = ts.ckks_vector_from(server_ctx, client_serialized[param_idx])
            if sum_cipher is None:
                sum_cipher = ck
            else:
                sum_cipher += ck
        avg_cipher = sum_cipher * (1.0 / float(num_clients))
        decrypted = avg_cipher.decrypt()
        aggregated_flat_lists.append(decrypted)
    aggregated_params = _flat_lists_to_param_arrays(aggregated_flat_lists, shapes)
    return aggregated_params

In [None]:
class ImageNet(nn.Module):  # for MNIST
    def __init__(self, out=10):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28*28, 128),
            nn.ReLU(),
            nn.Linear(128, out),
        )
    def forward(self, x):
        return self.fc(x)

In [None]:
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv


In [None]:
class GraphNet(nn.Module):  # for Cora
    def __init__(self, in_feats, hidden=64, out=7):
        super().__init__()
        self.conv1 = GCNConv(in_feats, hidden)
        self.conv2 = GCNConv(hidden, out)
    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        return self.conv2(x, edge_index)

In [None]:
class TimeNet(nn.Module):  # simple RNN for time series
    def __init__(self, in_dim=1, hidden=32, out=5):
        super().__init__()
        self.rnn = nn.GRU(in_dim, hidden, batch_first=True)
        self.fc = nn.Linear(hidden, out)
    def forward(self, x):
        _, h = self.rnn(x)
        return self.fc(h.squeeze(0))

In [None]:
class HEClient(fl.client.NumPyClient):
    def __init__(self, model, train_loader, public_ctx_bytes, is_graph=False, device="cpu"):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.device = device
        self.public_ctx_bytes = public_ctx_bytes
        self.is_graph = is_graph

    def get_parameters(self, config):
        return [val.cpu().numpy() for _, val in self.model.state_dict().items()]

    def set_parameters(self, parameters_list):
        params = {k: torch.tensor(v) for k, v in zip(self.model.state_dict().keys(), parameters_list)}
        self.model.load_state_dict(params)

    def fit(self, parameters, config):
        if parameters is not None:
            self.set_parameters(parameters)
        opt = torch.optim.SGD(self.model.parameters(), lr=0.01)
        loss_fn = nn.CrossEntropyLoss()

        self.model.train()
        for batch in self.train_loader:
            opt.zero_grad()
            if self.is_graph:
                data = batch.to(self.device)
                out = self.model(data)
                loss = loss_fn(out[data.train_mask], data.y[data.train_mask])
            else:
                x, y = batch
                x, y = x.to(self.device), y.to(self.device)
                out = self.model(x)
                loss = loss_fn(out, y)
            loss.backward()
            opt.step()

        params_numpy = params_to_numpy_list(self.model.state_dict())
        encrypted_serialized, shapes = encrypt_params_with_public_ctx(params_numpy, self.public_ctx_bytes)
        return None, len(self.train_loader.dataset), {"enc_params": encrypted_serialized, "shapes": shapes}

    def evaluate(self, parameters, config):
        if parameters is not None:
            self.set_parameters(parameters)
        return 0.0, len(self.train_loader.dataset), {}

In [None]:
class SecureHEFedAvg(fl.server.strategy.FedAvg):
    def __init__(self, server_tenseal_ctx, param_shapes, **kwargs):
        super().__init__(**kwargs)
        self.server_ctx = server_tenseal_ctx
        self.param_shapes = param_shapes
        self.parameters = None

    def aggregate_fit(self, server_round, results, failures):
        if not results:
            return None, {}
        all_clients_serialized = []
        for _, fit_res in results:
            client_return = fit_res[1]
            if isinstance(client_return, dict):
                enc_list = client_return["enc_params"]
                shapes = client_return.get("shapes", self.param_shapes)
            else:
                enc_list, shapes = client_return
            all_clients_serialized.append(enc_list)
        averaged_params = deserialize_and_aggregate_encrypted(all_clients_serialized, self.server_ctx, self.param_shapes)
        self.parameters = averaged_params
        return averaged_params, {}