# Imports & Config

In [None]:

!pip install torch torchvision pandas --quiet

import os
import random
import math
import numpy as np
import pandas as pd
from collections import defaultdict

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
import torchvision
import torchvision.transforms as T

# Mount Google Drive so the CSV logs outlive the Colab session.
from google.colab import drive
drive.mount('/content/drive')

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)


# High-level configuration
SEED            = 42      # seed for the python / numpy / torch global RNGs
NUM_CLIENTS     = 10
NUM_ROUNDS      = 100
BATCH_SIZE      = 64
LR_LOCAL        = 0.01    # client SGD learning rate
LR_RL           = 0.001   # Adam learning rate of the RL agent
RL_EPSILON      = 1.0     # initial epsilon for epsilon-greedy exploration
# NOTE: the Dirichlet split uses a concentration of 1 / ALPHA_DIRICHLET, so
# LARGER values give MORE non-IID clients (the opposite of the usual Dirichlet
# alpha convention). A value of 0 selects a plain IID split instead.
ALPHA_DIRICHLET = 1.0
DATASET_NAME    = "mnist"  # "mnist" or "cifar10"

# Epoch action: training work per round, as a multiple of the client's local
# data size (0.5 = half a pass, 2.0 = two passes).
EPOCH_ACTION_VALUES = [0.01, 0.1, 0.5, 1.0, 2.0, 5.0]
NUM_EPOCH_ACTIONS   = len(EPOCH_ACTION_VALUES)

# Layer action: number of leading state_dict entries exchanged with the server.
LAYER_ACTION_VALUES = [1, 2, 3, 4, 5, 6]
NUM_LAYER_ACTIONS   = len(LAYER_ACTION_VALUES)


# Per-client compute capability in epochs per round; requesting more epochs
# than this costs COMP_PENALTY_LAMBDA per excess epoch in the reward.
CLIENT_CAPABILITIES = np.array(
    [0.5, 0.01, 1.0, 2.0, 2.0, 0.5, 1.0, 2.0, 0.1, 5.0],
    dtype=float
)

COMP_PENALTY_LAMBDA = 1.0

# Per-client communication budget in shared state_dict entries; sharing more
# costs COMM_PENALTY_LAMBDA per excess entry in the reward.
CLIENT_COMM_BUDGETS = np.array(
    [3, 5, 1, 2, 4, 6, 1, 6, 2, 5],
    dtype=float
)

COMM_PENALTY_LAMBDA = 2.0

# Seed every global RNG the notebook draws from (client index permutations use
# np.random, model init / DataLoader shuffling / exploration use torch) so runs
# are repeatable. GPU kernels may still be nondeterministic.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # also seeds all CUDA devices


LOG_DIR = "/content/drive/MyDrive/FL_59500_RL_Logs"
os.makedirs(LOG_DIR, exist_ok=True)
print("Logs will be saved to:", LOG_DIR)


Mounted at /content/drive
Using device: cuda
Logs will be saved to: /content/drive/MyDrive/FL_59500_RL_Logs


# Dataset & non-IID split

In [None]:
def get_transforms(dataset_name):
    """Return ``(transform_train, transform_test)`` for a supported dataset.

    Both pipelines end with ``ToTensor`` + per-channel ``Normalize``; the train
    pipeline prepends light augmentation (random crop + horizontal flip for
    CIFAR-10, a small random rotation for MNIST). Matching is case-insensitive.

    Raises:
        ValueError: if ``dataset_name`` is neither "cifar10" nor "mnist".
    """
    name = dataset_name.lower()
    if name == "cifar10":
        stats = ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        augmentations = [T.RandomCrop(32, padding=4), T.RandomHorizontalFlip()]
    elif name == "mnist":
        stats = ((0.1307,), (0.3081,))
        augmentations = [T.RandomRotation(10)]
    else:
        raise ValueError("Unknown dataset: " + dataset_name)

    def tensor_and_normalize():
        # Fresh transform objects for each pipeline.
        return [T.ToTensor(), T.Normalize(*stats)]

    return (
        T.Compose(augmentations + tensor_and_normalize()),
        T.Compose(tensor_and_normalize()),
    )

transform_train, transform_test = get_transforms(DATASET_NAME)

# Both supported datasets share the same constructor signature, so pick the
# class once and build the train/test splits from it.
DATASET_CLASSES = {
    "cifar10": torchvision.datasets.CIFAR10,
    "mnist": torchvision.datasets.MNIST,
}
dataset_cls = DATASET_CLASSES.get(DATASET_NAME.lower())
if dataset_cls is None:
    raise ValueError("Unknown dataset")

train_dataset = dataset_cls(root="./data", train=True, download=True, transform=transform_train)
test_dataset = dataset_cls(root="./data", train=False, download=True, transform=transform_test)
NUM_CLASSES = 10  # CIFAR-10 and MNIST both have 10 classes

# Label array used for the non-IID split: prefer the dataset's label attribute,
# otherwise read the label out of every sample (slow path).
for _label_attr in ("targets", "labels"):
    if hasattr(train_dataset, _label_attr):
        all_targets = np.array(getattr(train_dataset, _label_attr))
        break
else:
    all_targets = np.array([train_dataset[idx][1] for idx in range(len(train_dataset))])

def iid_split(num_clients, n_samples, seed=42):
    """Randomly partition ``range(n_samples)`` into ``num_clients`` IID shards.

    The indices are shuffled with a ``RandomState(seed)`` and cut into
    near-equal contiguous chunks (sizes differ by at most one, earlier chunks
    get the extra element, as in ``np.array_split``).

    Returns:
        list of ``num_clients`` lists of Python ints.
    """
    shuffled = np.random.RandomState(seed).permutation(n_samples)
    return [chunk.tolist() for chunk in np.array_split(shuffled, num_clients)]

def dirichlet_split(num_clients, targets, alpha, seed=42):
    """Label-skewed (non-IID) partition of sample indices across clients.

    For every class, that class's indices are shuffled and divided among the
    clients according to proportions drawn from a symmetric Dirichlet.

    NOTE: the Dirichlet concentration is ``1 / alpha`` (clipped to
    [1e-3, 1e3]), so a LARGER ``alpha`` here means a MORE skewed split -- the
    inverse of the usual Dirichlet-alpha convention. ``alpha == 0`` raises
    ZeroDivisionError; callers use an IID split for that case.

    Args:
        num_clients: number of shards to produce.
        targets: 1-D integer label array; classes are assumed to be 0..max.
        alpha: skew knob (see note above).
        seed: seed for the private ``RandomState``.

    Returns:
        list of ``num_clients`` shuffled lists of Python int indices.
    """
    rng = np.random.RandomState(seed)
    num_classes = int(targets.max()) + 1
    concentration = float(np.clip(1.0 / alpha, 1e-3, 1e3))
    per_client = [[] for _ in range(num_clients)]

    for cls in range(num_classes):
        members = np.where(targets == cls)[0]
        rng.shuffle(members)

        shares = rng.dirichlet(np.full(num_clients, concentration))
        cut_at = (np.cumsum(shares) * len(members)).astype(int)[:-1]
        for bucket, chunk in zip(per_client, np.split(members, cut_at)):
            bucket.extend(chunk.tolist())

    # Mix classes within each client's list.
    for bucket in per_client:
        rng.shuffle(bucket)
    return per_client

if ALPHA_DIRICHLET != 0:
    print(f"Using Dirichlet non-IID split, alpha={ALPHA_DIRICHLET} (larger = more non-IID)")
    client_indices = dirichlet_split(NUM_CLIENTS, all_targets, ALPHA_DIRICHLET)
else:
    print("Using IID split")
    client_indices = iid_split(NUM_CLIENTS, len(train_dataset))


# Shared, un-shuffled loader over the held-out test split.
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)


for cid in range(len(client_indices)):
    print(f"Client {cid}: {len(client_indices[cid])} samples")


100%|██████████| 9.91M/9.91M [00:01<00:00, 5.05MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 134kB/s]
100%|██████████| 1.65M/1.65M [00:01<00:00, 1.27MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 9.11MB/s]

Using Dirichlet non-IID split, alpha=1.0 (larger = more non-IID)
Client 0: 6917 samples
Client 1: 7463 samples
Client 2: 6669 samples
Client 3: 4864 samples
Client 4: 6447 samples
Client 5: 2824 samples
Client 6: 5674 samples
Client 7: 6792 samples
Client 8: 5194 samples
Client 9: 7156 samples





# Model definition & helpers



In [None]:
def get_model(dataset_name, num_classes):
    """Build a randomly initialised torchvision ResNet-18 for small images.

    The stem is replaced by a 3x3 / stride-1 / padding-1 convolution and the
    max-pool is removed, so the first stage keeps the input resolution. The
    stem takes 1 input channel for MNIST and 3 otherwise; the classifier head
    outputs ``num_classes`` logits. The model is moved to the global ``device``.
    """
    net = torchvision.models.resnet18(weights=None)
    in_channels = 1 if dataset_name.lower() == "mnist" else 3

    net.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, stride=1,
                          padding=1, bias=False)
    net.maxpool = nn.Identity()
    net.fc = nn.Linear(net.fc.in_features, num_classes)
    return net.to(device)

def get_model_params_vector(model):
    """Concatenate all of ``model``'s parameters (not buffers) into one flat CPU tensor.

    The result is built under ``torch.no_grad()`` so it carries no autograd history.
    """
    pieces = []
    with torch.no_grad():
        for param in model.parameters():
            pieces.append(param.view(-1).cpu())
        flat = torch.cat(pieces)
    return flat

def set_model_params(model, global_model):
    """Copy every parameter and buffer of ``global_model`` into ``model`` (in place)."""
    source_state = global_model.state_dict()
    model.load_state_dict(source_state)


# Client class

In [None]:
class FLClient:
    """A federated-learning client owning a fixed index subset of a shared dataset.

    Training work per round is given as an ``action_value`` in epochs:
    ``round(action_value * data_size)`` samples are drawn from a rolling,
    periodically reshuffled permutation of the client's indices. Fractional
    values therefore train on part of the data, and values > 1 make several passes.
    """

    def __init__(self, client_id, dataset, indices, capability, batch_size=64, lr=0.01):
        self.client_id = client_id
        self.dataset = dataset
        self.indices = indices
        self.capability = float(capability)  # stored for callers; not used inside this class
        self.batch_size = batch_size
        self.lr = lr

        self.data_size = len(indices)

        # Rolling pointer over shuffled local indices (global NumPy RNG).
        self.current_perm = np.array(self.indices)
        np.random.shuffle(self.current_perm)
        self.ptr = 0

    def _get_train_indices_for_action(self, action_value):
        """Return ``max(1, round(action_value * data_size))`` dataset indices.

        Indices are consumed sequentially from ``current_perm`` starting at
        ``ptr``. When the permutation runs out before the request is met, it
        is reshuffled and consumption restarts at 0, so large requests repeat
        samples across passes. Returns ``[]`` when the client holds no data.
        """
        total_required = int(round(action_value * self.data_size))
        total_required = max(1, total_required)

        chosen = []

        while len(chosen) < total_required and self.data_size > 0:
            remaining_in_cycle = self.data_size - self.ptr
            need = total_required - len(chosen)
            take = min(remaining_in_cycle, need)

            chosen.extend(self.current_perm[self.ptr: self.ptr + take].tolist())
            self.ptr += take

            # Pass exhausted but more samples still needed: start a new pass.
            if self.ptr >= self.data_size and len(chosen) < total_required:
                np.random.shuffle(self.current_perm)
                self.ptr = 0

        return chosen

    def get_full_train_loader(self):
        """Un-shuffled DataLoader over all of this client's local indices."""
        subset = Subset(self.dataset, self.indices)
        loader = DataLoader(subset, batch_size=self.batch_size,
                            shuffle=False, num_workers=2)
        return loader

    def get_action_train_loader(self, action_value):
        """
        BN-safe DataLoader for the chosen action_value.

        Returns ``(loader, n)``, where ``n`` is the number of indices drawn for
        this action. For ``n < 2`` the loader is a plain batch-size-1 loader,
        and callers are expected not to train on it.
        """
        subset_indices = self._get_train_indices_for_action(action_value)
        subset = Subset(self.dataset, subset_indices)
        n = len(subset_indices)

        if n < 2:
            loader = DataLoader(subset, batch_size=1, shuffle=False, num_workers=0)
            return loader, n

        curr_bs = min(self.batch_size, n)
        # BatchNorm cannot train on a single-sample batch, so drop the final
        # batch only when it would hold exactly one sample. Dropping it
        # unconditionally discarded up to curr_bs - 1 requested samples every
        # round, even though they were still counted as used.
        drop_tail = (n % curr_bs) == 1
        loader = DataLoader(
            subset,
            batch_size=curr_bs,
            shuffle=True,
            num_workers=2,
            drop_last=drop_tail
        )
        return loader, n

    def local_train(self, model, action_value, device):
        """Train ``model`` for ``action_value`` epochs' worth of local samples, then evaluate it.

        Uses SGD (momentum 0.9, weight decay 5e-4) with a fresh optimizer each
        call. Afterwards the model is put in eval mode and scored on the full
        local subset through ``self.dataset``, with whatever transform that
        dataset carries.

        Returns:
            (model, local_loss, local_acc, num_samples_used, effective_epochs),
            where ``effective_epochs = num_samples_used / data_size`` (0.0 when
            the client has no data).
        """
        model = model.to(device)

        train_loader, num_samples_used = self.get_action_train_loader(action_value)

        if num_samples_used >= 2:
            model.train()
            optimizer = optim.SGD(model.parameters(), lr=self.lr,
                                  momentum=0.9, weight_decay=5e-4)

            for x, y in train_loader:
                if x.size(0) < 2:
                    continue  # BN safety
                x, y = x.to(device), y.to(device)
                optimizer.zero_grad()
                out = model(x)
                loss = F.cross_entropy(out, y)
                loss.backward()
                optimizer.step()

        model.eval()
        full_loader = self.get_full_train_loader()
        correct = 0
        total = 0
        total_loss = 0.0

        with torch.no_grad():
            for x, y in full_loader:
                x, y = x.to(device), y.to(device)
                out = model(x)
                loss = F.cross_entropy(out, y)
                total_loss += loss.item() * x.size(0)
                preds = out.argmax(dim=1)
                correct += (preds == y).sum().item()
                total += x.size(0)

        local_loss = total_loss / max(1, total)
        local_acc = correct / max(1, total)

        if self.data_size > 0:
            effective_epochs = num_samples_used / float(self.data_size)
        else:
            effective_epochs = 0.0

        return model, local_loss, local_acc, num_samples_used, effective_epochs


# Server class (partial parameter sharing + distance over shared layers only, used in the reward)

In [None]:
class FLServer:
    """Federated server for partial (prefix) parameter sharing.

    Parameters are addressed by the ordered keys of the global model's
    ``state_dict``, which also includes buffers such as BatchNorm running
    statistics. A client that shares ``L`` entries contributes the first ``L``
    keys of that list.
    """

    def __init__(self, global_model, test_loader, device):
        self.global_model = global_model.to(device)
        self.test_loader = test_loader
        self.device = device

        # Ordered state_dict keys; "sharing L layers" means sharing param_keys[:L].
        self.param_keys = list(self.global_model.state_dict().keys())
        self.num_param_keys = len(self.param_keys)

    def aggregate(self, client_models, num_param_to_share_per_client):
        """Average each state_dict entry over the clients that shared it.

        Client ``i`` contributes ``param_keys[:L_i]``, where ``L_i`` is its
        share count truncated to int and clamped to ``[0, num_param_keys]``.
        Entries that no client shared keep their current global value.
        ``self.global_model`` is updated in place. For integer buffers the
        float mean is copied back into the integer tensor by ``load_state_dict``.
        """
        assert len(client_models) == len(num_param_to_share_per_client)
        global_dict = self.global_model.state_dict()

        sum_dict = {k: torch.zeros_like(global_dict[k]) for k in self.param_keys}
        cnt_dict = {k: 0 for k in self.param_keys}

        for cid, model in enumerate(client_models):
            sd = model.state_dict()
            L_i = int(num_param_to_share_per_client[cid])
            L_i = max(0, min(L_i, self.num_param_keys))

            for k in self.param_keys[:L_i]:
                sum_dict[k] += sd[k].to(sum_dict[k].device)
                cnt_dict[k] += 1

        for k in self.param_keys:
            if cnt_dict[k] > 0:
                global_dict[k] = sum_dict[k] / float(cnt_dict[k])

        self.global_model.load_state_dict(global_dict)

    def evaluate_global(self):
        """Return ``(mean cross-entropy loss, accuracy)`` of the global model on ``test_loader``."""
        self.global_model.eval()
        self.global_model.to(self.device)

        total_loss = 0.0
        total_correct = 0
        total_samples = 0

        with torch.no_grad():
            for x, y in self.test_loader:
                x, y = x.to(self.device), y.to(self.device)
                out = self.global_model(x)
                loss = F.cross_entropy(out, y)
                total_loss += loss.item() * x.size(0)
                preds = out.argmax(dim=1)
                total_correct += (preds == y).sum().item()
                total_samples += x.size(0)

        avg_loss = total_loss / max(1, total_samples)
        avg_acc = total_correct / max(1, total_samples)
        return avg_loss, avg_acc

    def compute_layerwise_avg_distances_from_shared(self, client_models, num_param_to_share_per_client):
        """Per state_dict key, the mean pairwise L2 distance among clients sharing it.

        Returns a float32 array of length ``num_param_keys``. Entry ``j`` is
        0.0 when fewer than two clients shared key ``j``.
        """
        N = len(client_models)
        assert N == len(num_param_to_share_per_client)

        sds = [m.state_dict() for m in client_models]
        Ls  = [int(max(0, min(num_param_to_share_per_client[i], self.num_param_keys))) for i in range(N)]

        layer_avg = np.zeros(self.num_param_keys, dtype=np.float32)

        for j in range(self.num_param_keys):
            shared_ids = [i for i in range(N) if Ls[i] > j]
            m = len(shared_ids)
            if m < 2:
                layer_avg[j] = 0.0
                continue

            k = self.param_keys[j]
            # Flatten each sharing client's tensor once (not once per pair);
            # reshape also tolerates non-contiguous tensors, unlike view.
            flats = [sds[i][k].detach().reshape(-1).float().cpu() for i in shared_ids]

            total = 0.0
            cnt = 0
            for a in range(m):
                for b in range(a + 1, m):
                    total += float(torch.norm(flats[a] - flats[b], p=2).item())
                    cnt += 1

            layer_avg[j] = float(total / max(1, cnt))

        return layer_avg

    def compute_client_shared_distance_for_reward(self, layer_avg, num_param_to_share_per_client):
        """Mean of ``layer_avg`` over the keys each client shared (0.0 if it shared none)."""
        N = len(num_param_to_share_per_client)
        client_dist = np.zeros(N, dtype=np.float32)

        for i in range(N):
            L_i = int(max(0, min(num_param_to_share_per_client[i], self.num_param_keys)))
            if L_i <= 0:
                client_dist[i] = 0.0
            else:
                client_dist[i] = float(np.mean(layer_avg[:L_i]))
        return client_dist

    def compute_distance_matrix(self, client_models):
        """Pairwise L2 distances between full flattened parameter vectors.

        Returns ``(D, sum_dist)``, where ``D`` is a symmetric float32
        ``[n, n]`` matrix with a zero diagonal and ``sum_dist = D.sum(axis=1)``.
        """
        num_clients = len(client_models)
        vectors = [get_model_params_vector(m) for m in client_models]
        D = np.zeros((num_clients, num_clients), dtype=np.float32)

        # Distance is symmetric: compute each unordered pair once and mirror it.
        for i in range(num_clients):
            for j in range(i + 1, num_clients):
                dist = float(torch.norm(vectors[i] - vectors[j], p=2).item())
                D[i, j] = dist
                D[j, i] = dist

        sum_dist = D.sum(axis=1)
        return D, sum_dist


# RL Agent

In [None]:
class RLBanditAgent(nn.Module):
    """Epsilon-greedy contextual bandit choosing two actions per client.

    Context is a learned client-id embedding concatenated with the client's
    previous local accuracy and its capability. A shared MLP trunk feeds two
    Q heads, one over epoch actions and one over layer-sharing actions. Both
    heads are regressed (smooth L1) onto the same normalized scalar reward.
    Parameters live on the module-level ``device``.
    """

    def __init__(self, num_clients,
                 num_epoch_actions,
                 num_layer_actions,
                 hidden_dim=64,
                 lr=1e-3,
                 epsilon=1.0,
                 eps_min=0.05,
                 eps_decay=0.995,
                 dropout=0.1):
        super().__init__()
        self.num_clients        = num_clients
        self.num_epoch_actions  = num_epoch_actions
        self.num_layer_actions  = num_layer_actions

        # Epsilon-greedy (shared for both heads); decayed once per update().
        self.epsilon   = epsilon
        self.eps_min   = eps_min
        self.eps_decay = eps_decay

        # Embedding for client IDs
        self.emb_dim = min(32, hidden_dim)
        self.embed = nn.Embedding(num_clients, self.emb_dim)

        in_dim = self.emb_dim + 2  # embedding + (accuracy, capability)

        # Shared trunk
        self.trunk = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        )

        # Two heads: epochs and layers
        self.head_epoch = nn.Linear(hidden_dim, num_epoch_actions)
        self.head_layer = nn.Linear(hidden_dim, num_layer_actions)

        self.to(device)
        self.optimizer = optim.Adam(self.parameters(), lr=lr)

        # Exponential moving averages used to normalize rewards.
        self.r_mean = 0.0
        self.r_std  = 1.0
        self.r_beta = 0.01

    def forward(self, client_ids_t, accs_t, caps_t):
        """Q-values for a batch: ids are a 1-D long tensor, accs/caps 1-D float tensors.

        Returns ``(q_epoch [B, num_epoch_actions], q_layer [B, num_layer_actions])``.
        """
        emb = self.embed(client_ids_t)                 # [B, emb_dim]
        ctx = torch.stack([accs_t, caps_t], dim=1)     # [B, 2]
        x = torch.cat([emb, ctx], dim=1)               # [B, emb_dim+2]
        h = self.trunk(x)
        q_epoch = self.head_epoch(h)
        q_layer = self.head_layer(h)
        return q_epoch, q_layer

    def select_actions(self, client_ids, accs, caps):
        """Epsilon-greedy joint action selection for a batch of clients.

        With probability ``epsilon`` (drawn once per client, shared by both
        heads) both actions are uniformly random; otherwise both are the argmax
        of their Q head. Q-values are computed with dropout disabled and
        without autograd, so greedy choices and the returned Q-values are
        deterministic for a given state. The module's train/eval mode is
        restored afterwards.

        Returns:
            (epoch_actions, layer_actions) as long tensors on ``device`` plus
            (q_epoch, q_layer) as NumPy arrays.
        """
        client_ids_t = torch.tensor(client_ids, dtype=torch.long, device=device)
        accs_t       = torch.tensor(accs,      dtype=torch.float32, device=device)
        caps_t       = torch.tensor(caps,      dtype=torch.float32, device=device)

        was_training = self.training
        self.eval()  # disable dropout while scoring actions
        try:
            with torch.no_grad():
                q_epoch, q_layer = self.forward(client_ids_t, accs_t, caps_t)
                B = q_epoch.size(0)

                # Greedy actions
                greedy_epoch = torch.argmax(q_epoch, dim=1)   # [B]
                greedy_layer = torch.argmax(q_layer, dim=1)   # [B]

                # Random actions
                random_epoch = torch.randint(0, self.num_epoch_actions, (B,), device=device)
                random_layer = torch.randint(0, self.num_layer_actions, (B,), device=device)

                explore_mask = (torch.rand(B, device=device) < self.epsilon)

                epoch_actions = torch.where(explore_mask, random_epoch, greedy_epoch)
                layer_actions = torch.where(explore_mask, random_layer, greedy_layer)
        finally:
            self.train(was_training)

        return (
            epoch_actions,
            layer_actions,
            q_epoch.detach().cpu().numpy(),
            q_layer.detach().cpu().numpy()
        )

    def _update_reward_stats(self, rewards_t):
        """Fold the batch mean/std of rewards into the running EMA statistics."""
        batch_mean = rewards_t.mean().item()
        batch_std  = rewards_t.std(unbiased=False).item()
        if batch_std < 1e-6:
            batch_std = 1.0  # avoid collapsing the scale on a constant batch

        self.r_mean = (1 - self.r_beta) * self.r_mean + self.r_beta * batch_mean
        self.r_std  = (1 - self.r_beta) * self.r_std  + self.r_beta * batch_std

    def update(self, client_ids, accs, caps,
               epoch_actions, layer_actions, rewards):
        """
        One-step Q regression for both heads:
            Q_epoch(s,a_e) ≈ normalized reward
            Q_layer(s,a_l) ≈ normalized reward

        Runs in the module's current mode (dropout active when training),
        clips the gradient norm to 5.0, decays epsilon, and returns the loss as
        a Python float.
        """
        self.optimizer.zero_grad()

        client_ids_t = torch.tensor(client_ids,    dtype=torch.long,   device=device)
        accs_t       = torch.tensor(accs,          dtype=torch.float32, device=device)
        caps_t       = torch.tensor(caps,          dtype=torch.float32, device=device)
        epoch_act_t  = torch.tensor(epoch_actions, dtype=torch.long,   device=device)
        layer_act_t  = torch.tensor(layer_actions, dtype=torch.long,   device=device)
        rewards_t    = torch.tensor(rewards,       dtype=torch.float32, device=device)

        # Normalize rewards
        self._update_reward_stats(rewards_t)
        norm_rewards = (rewards_t - self.r_mean) / (self.r_std + 1e-6)

        q_epoch, q_layer = self.forward(client_ids_t, accs_t, caps_t)

        q_epoch_sel = q_epoch[torch.arange(len(epoch_actions), device=device), epoch_act_t]
        q_layer_sel = q_layer[torch.arange(len(layer_actions), device=device), layer_act_t]

        loss_epoch = F.smooth_l1_loss(q_epoch_sel, norm_rewards)
        loss_layer = F.smooth_l1_loss(q_layer_sel, norm_rewards)
        loss = 0.5 * (loss_epoch + loss_layer)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=5.0)
        self.optimizer.step()

        # decay epsilon
        self.epsilon = max(self.eps_min, self.epsilon * self.eps_decay)

        return loss.item()


# Main training loop

In [None]:
def compute_comp_penalty(capability, requested_epochs):
    """
    Compute-overuse penalty for a single client.

    Only epochs requested beyond the client's capability are charged:
      penalty = COMP_PENALTY_LAMBDA * max(0, requested_epochs - capability)
    Staying within capability costs nothing. Returns a Python float.
    """
    overuse = float(requested_epochs) - float(capability)
    exceed = overuse if overuse > 0.0 else 0.0
    return float(COMP_PENALTY_LAMBDA) * float(exceed)

def compute_comm_penalty(requested_layers, comm_budget_layers):
    """
    Communication-overuse penalty for a single client.

    Both inputs are truncated to int. Only layers beyond the budget are charged:
      exceed  = max(0, requested_layers - comm_budget_layers)
      penalty = COMM_PENALTY_LAMBDA * exceed
    Returns ``(penalty as float, exceed as int)``.
    """
    over = int(requested_layers) - int(comm_budget_layers)
    exceed = over if over > 0 else 0
    return float(COMM_PENALTY_LAMBDA) * float(exceed), int(exceed)

def evaluate_model_on_test(model, test_loader, device):
    """
    Score ``model`` on the global TEST loader.

    Switches the model to eval mode, moves it to ``device`` and returns
    ``(test_loss, test_accuracy)`` as Python floats. Both are averaged per
    sample, and an empty loader yields ``(0.0, 0.0)``.
    """
    model.eval()
    model.to(device)

    loss_sum, hits, seen = 0.0, 0, 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            logits = model(inputs)
            batch = inputs.size(0)
            loss_sum += F.cross_entropy(logits, labels).item() * batch
            hits += (logits.argmax(dim=1) == labels).sum().item()
            seen += batch

    denom = max(1, seen)
    return float(loss_sum / denom), float(hits / denom)

# 1) Global model held by the server.
global_model = get_model(DATASET_NAME, NUM_CLASSES).to(device)

# 2) One client per data shard, built in id order. Each FLClient draws its
#    initial local permutation from np.random in its constructor.
clients = [
    FLClient(
        client_id=cid,
        dataset=train_dataset,
        indices=client_indices[cid],
        capability=CLIENT_CAPABILITIES[cid],
        batch_size=BATCH_SIZE,
        lr=LR_LOCAL,
    )
    for cid in range(NUM_CLIENTS)
]

# 3) Server wrapping the global model and the shared test loader.
server = FLServer(global_model=global_model, test_loader=test_loader, device=device)

# Ordered state_dict keys; a layer action shares a prefix of this list.
PARAM_KEYS     = server.param_keys
NUM_PARAM_KEYS = server.num_param_keys

def apply_partial_global_to_client(client_model, global_model, num_param_keys, param_keys):
    """Overwrite the first ``num_param_keys`` entries of ``param_keys`` in the client with the global values.

    The count is clamped to ``[0, len(param_keys)]`` and truncated to int.
    Global tensors are cloned before loading, and all other client entries
    are left untouched.
    """
    count = int(max(0, min(num_param_keys, len(param_keys))))
    source = global_model.state_dict()
    merged = client_model.state_dict()
    merged.update({key: source[key].clone() for key in param_keys[:count]})
    client_model.load_state_dict(merged)

# RL agent choosing an (epoch action, layer action) pair for every client each round.
agent = RLBanditAgent(
    num_clients=NUM_CLIENTS,
    num_epoch_actions=NUM_EPOCH_ACTIONS,
    num_layer_actions=NUM_LAYER_ACTIONS,
    hidden_dim=32,
    lr=LR_RL,
    epsilon=RL_EPSILON,
)


# Row buffers turned into DataFrames / CSV files after training.
global_log, client_log, distance_log = [], [], []

# Every client starts from an exact copy of the initial global model.
client_models = []
for _ in range(NUM_CLIENTS):
    replica = get_model(DATASET_NAME, NUM_CLASSES)
    set_model_params(replica, server.global_model)
    client_models.append(replica)

# RL state: last round's local accuracy per client (zeros before round 1)
# plus the static per-client capabilities.
prev_local_accs  = np.zeros(NUM_CLIENTS, dtype=np.float32)
client_caps_list = [float(c) for c in CLIENT_CAPABILITIES.astype(np.float32)]

for rnd in range(1, NUM_ROUNDS + 1):
    print(f"\n===== Round {rnd} =====")

    # RL state for this round (the previous round's local accuracies). Actions
    # are selected from it, the agent is updated on it, and it is what gets
    # logged as "last_acc_state". prev_local_accs is advanced to this round's
    # accuracies further down, so keep a separate snapshot.
    state_accs = prev_local_accs.copy()

    local_losses         = []
    local_accs           = []
    num_used_list        = []
    eff_epochs_list      = []
    num_params_shared    = []
    epoch_action_indices = []
    layer_action_indices = []

    # ------- RL: select 2 actions per client -------
    client_ids_list = list(range(NUM_CLIENTS))
    epoch_actions_t, layer_actions_t, q_epoch_np, q_layer_np = agent.select_actions(
        client_ids=client_ids_list,
        accs=state_accs,
        caps=client_caps_list
    )

    epoch_actions = epoch_actions_t.cpu().numpy()
    layer_actions = layer_actions_t.cpu().numpy()

    models_for_agg = []

    # ------- Local training -------
    for cid, client in enumerate(clients):
        local_model = client_models[cid]

        # Epoch action -> training work this round, in local epochs.
        epoch_idx   = int(epoch_actions[cid])
        epoch_value = float(EPOCH_ACTION_VALUES[epoch_idx])

        # Layer action -> number of leading state_dict keys exchanged with the server.
        layer_idx   = int(layer_actions[cid])
        layer_val   = int(LAYER_ACTION_VALUES[layer_idx])
        num_share   = int(max(1, min(layer_val, NUM_PARAM_KEYS)))

        # Pull the shared prefix of the current global model into the client.
        apply_partial_global_to_client(
            client_model=local_model,
            global_model=server.global_model,
            num_param_keys=num_share,
            param_keys=PARAM_KEYS
        )

        updated_model, local_loss, local_acc, num_used, eff_epochs = client.local_train(
            model=local_model,
            action_value=epoch_value,
            device=device
        )

        client_models[cid] = updated_model
        models_for_agg.append(updated_model)

        local_losses.append(float(local_loss))
        local_accs.append(float(local_acc))
        num_used_list.append(int(num_used))
        eff_epochs_list.append(float(eff_epochs))
        num_params_shared.append(int(num_share))
        epoch_action_indices.append(int(epoch_idx))
        layer_action_indices.append(int(layer_idx))

        print(
            f"Client {cid}: epoch_action={epoch_value}, "
            f"layer_action_value={layer_val} -> shared_params={num_share}, "
            f"loss={local_loss:.4f}, acc={local_acc:.4f}, "
            f"used={num_used}, eff_epochs={eff_epochs:.3f}"
        )

    # ------- Aggregation & global evaluation -------
    server.aggregate(models_for_agg, num_param_to_share_per_client=num_params_shared)

    global_test_loss, global_test_acc = server.evaluate_global()
    print(f"Global-like model test accuracy: {global_test_acc:.4f}, loss: {global_test_loss:.4f}")

    # ------- Shared-only distances (reward term) -------
    layer_avg = server.compute_layerwise_avg_distances_from_shared(
        client_models=models_for_agg,
        num_param_to_share_per_client=num_params_shared
    )
    client_shared_dist = server.compute_client_shared_distance_for_reward(
        layer_avg=layer_avg,
        num_param_to_share_per_client=num_params_shared
    )

    # distance_log is exported to CSV after training. Record the average
    # pairwise distance for every key that at least one client shared this round.
    for j in range(max(num_params_shared, default=0)):
        distance_log.append({
            "round": rnd,
            "param_index": j,
            "param_key": PARAM_KEYS[j],
            "num_sharing_clients": sum(1 for s in num_params_shared if s > j),
            "avg_pairwise_dist": float(layer_avg[j]),
        })

    # ------- Per-client evaluation on the global test set -------
    client_test_losses = []
    client_test_accs   = []
    for cid in range(NUM_CLIENTS):
        tl, ta = evaluate_model_on_test(client_models[cid], test_loader, device)
        client_test_losses.append(float(tl))
        client_test_accs.append(float(ta))
        print(f"Client {cid}: test_acc={ta:.4f}, test_loss={tl:.4f}")

    avg_client_test_acc  = float(np.mean(client_test_accs))
    avg_client_test_loss = float(np.mean(client_test_losses))
    print(f"Average client test accuracy: {avg_client_test_acc:.4f}, "
          f"Average client test loss: {avg_client_test_loss:.4f}")

    # ------- Rewards -------
    rewards            = []
    comp_penalties     = []
    comm_penalties     = []
    exceed_layer_counts = []

    for cid in range(NUM_CLIENTS):
        acc_i   = float(local_accs[cid])
        cap_i   = float(CLIENT_CAPABILITIES[cid])

        # requested action values
        epoch_idx   = int(epoch_action_indices[cid])
        epoch_value = float(EPOCH_ACTION_VALUES[epoch_idx])

        shared_layers = int(num_params_shared[cid])
        pref_layers   = int(CLIENT_COMM_BUDGETS[cid])

        # penalties
        comp_penalty = compute_comp_penalty(cap_i, epoch_value)
        comm_penalty, exceed_layers = compute_comm_penalty(shared_layers, pref_layers)

        # shared-only distance term (already averaged by layers shared)
        dist_i = float(client_shared_dist[cid])

        # reward
        # NOTE(review): cap_i is constant per client, so it shifts that client's
        # reward level but does not favor any particular action.
        reward_i = acc_i + cap_i - dist_i - comp_penalty - comm_penalty

        rewards.append(float(reward_i))
        comp_penalties.append(float(comp_penalty))
        comm_penalties.append(float(comm_penalty))
        exceed_layer_counts.append(int(exceed_layers))

        print(
            f"Client {cid}: reward={reward_i:.4f} "
            f"(local_acc={acc_i:.4f}, cap={cap_i:.2f}, "
            f"shared_dist={dist_i:.4f}, "
            f"epoch_req={epoch_value:.3f}, comp_penalty={comp_penalty:.4f}, "
            f"shared_layers={shared_layers}, pref_layers={pref_layers}, "
            f"exceed_layers={exceed_layers}, comm_penalty={comm_penalty:.4f})"
        )

    # ------- RL update (on the state the actions were selected from) -------
    rl_loss = agent.update(
        client_ids=client_ids_list,
        accs=state_accs,
        caps=client_caps_list,
        epoch_actions=epoch_actions,
        layer_actions=layer_actions,
        rewards=rewards
    )
    print(f"RL loss: {rl_loss:.4f}")

    # Advance the RL state for the next round
    prev_local_accs = np.array(local_accs, dtype=np.float32)

    # ------- Logging (per client) -------
    for cid in range(NUM_CLIENTS):
        epoch_idx = int(epoch_action_indices[cid])
        layer_idx = int(layer_action_indices[cid])

        layer_val   = int(LAYER_ACTION_VALUES[layer_idx])
        num_share   = int(num_params_shared[cid])
        pref_layers = int(CLIENT_COMM_BUDGETS[cid])

        row = {
            "round": rnd,
            "client_id": cid,

            # Local train
            "local_loss": float(local_losses[cid]),
            "local_acc": float(local_accs[cid]),

            # Client global test
            "client_test_loss": float(client_test_losses[cid]),
            "client_test_acc": float(client_test_accs[cid]),

            # Budgets
            "capability": float(CLIENT_CAPABILITIES[cid]),
            "comm_budget_layers": float(CLIENT_COMM_BUDGETS[cid]),

            # State the actions / Q-values below were computed from
            "last_acc_state": float(state_accs[cid]),

            # Actions
            "epoch_action_index": int(epoch_idx),
            "epoch_action_value": float(EPOCH_ACTION_VALUES[epoch_idx]),
            "layer_action_index": int(layer_idx),
            "layer_action_value": int(layer_val),

            # Sharing stats
            "num_param_keys_shared": int(num_share),
            "preferred_layers": int(pref_layers),
            "exceed_layers": int(exceed_layer_counts[cid]),

            # Shared-only distance used for reward
            "shared_dist_reward": float(client_shared_dist[cid]),

            # Penalties + reward
            "comp_penalty": float(comp_penalties[cid]),
            "comm_penalty": float(comm_penalties[cid]),
            "reward": float(rewards[cid]),
        }

        for a_idx, q in enumerate(q_epoch_np[cid]):
            row[f"Q_epoch_action_{a_idx}"] = float(q)

        for a_idx, q in enumerate(q_layer_np[cid]):
            row[f"Q_layer_action_{a_idx}"] = float(q)

        client_log.append(row)

    # ------- Logging (per round) -------
    global_log.append({
        "round": rnd,
        "global_test_loss": float(global_test_loss),
        "global_test_acc": float(global_test_acc),
        "avg_client_test_loss": float(avg_client_test_loss),
        "avg_client_test_acc": float(avg_client_test_acc),
    })



===== Round 1 =====
Client 0: epoch_action=0.1, layer_action_value=4 -> shared_params=4, loss=1.6766, acc=0.4204, used=692, eff_epochs=0.100
Client 1: epoch_action=0.1, layer_action_value=1 -> shared_params=1, loss=2.0107, acc=0.2642, used=746, eff_epochs=0.100
Client 2: epoch_action=0.1, layer_action_value=6 -> shared_params=6, loss=1.8860, acc=0.3890, used=667, eff_epochs=0.100
Client 3: epoch_action=2.0, layer_action_value=5 -> shared_params=5, loss=0.1694, acc=0.9453, used=9728, eff_epochs=2.000
Client 4: epoch_action=0.1, layer_action_value=6 -> shared_params=6, loss=1.9592, acc=0.3371, used=645, eff_epochs=0.100
Client 5: epoch_action=0.01, layer_action_value=1 -> shared_params=1, loss=2.1635, acc=0.2836, used=28, eff_epochs=0.010
Client 6: epoch_action=2.0, layer_action_value=5 -> shared_params=5, loss=0.1582, acc=0.9544, used=11348, eff_epochs=2.000
Client 7: epoch_action=2.0, layer_action_value=2 -> shared_params=2, loss=0.0972, acc=0.9738, used=13584, eff_epochs=2.000
Client

# Save logs to CSV

In [None]:
global_df   = pd.DataFrame(global_log)
client_df   = pd.DataFrame(client_log)
distance_df = pd.DataFrame(distance_log)

# Derive the file prefix from the run configuration so that changing
# DATASET_NAME or ALPHA_DIRICHLET cannot silently mislabel or overwrite another
# run's logs. mnist / 1.0 gives "NEW_2_2D_MNIST_ALPHA_1_RESNET18_PFL_RL".
prefix = f"NEW_2_2D_{DATASET_NAME.upper()}_ALPHA_{ALPHA_DIRICHLET:g}_RESNET18_PFL_RL"

global_path   = os.path.join(LOG_DIR, f"{prefix}_global_log.csv")
client_path   = os.path.join(LOG_DIR, f"{prefix}_client_log.csv")
distance_path = os.path.join(LOG_DIR, f"{prefix}_distance_log.csv")

global_df.to_csv(global_path, index=False)
client_df.to_csv(client_path, index=False)
distance_df.to_csv(distance_path, index=False)

print("Saved:")
print("  ", global_path)
print("  ", client_path)
print("  ", distance_path)


Saved:
   /content/drive/MyDrive/FL_59500_RL_Logs/NEW_2_2D_MNIST_ALPHA_1_RESNET18_PFL_RL_global_log.csv
   /content/drive/MyDrive/FL_59500_RL_Logs/NEW_2_2D_MNIST_ALPHA_1_RESNET18_PFL_RL_client_log.csv
   /content/drive/MyDrive/FL_59500_RL_Logs/NEW_2_2D_MNIST_ALPHA_1_RESNET18_PFL_RL_distance_log.csv
