## Hybrid PSO + Adam with Multiprocessing Nodes

This notebook implements a **hybrid PSO–Adam training setup** using **processes as nodes**.

- We reuse the **data splits from the previous data-split approach** via `splits.pt` created in `Approach_1.ipynb`.
- We create **5 worker processes (PSO/Adam nodes)** and **1 main node (Adam + aggregation)**.
- Communication between nodes uses **`multiprocessing.Queue`** to minimize IPC operations.
- We **track the number of queue reads/writes** and compare communication cost with model performance.

High-level loop per communication round:

1. Main node sends the current global model to each worker.
2. Each worker trains locally (Adam + PSO-inspired search) for a few epochs and returns its best model.
3. Main node **aggregates the 5 worker models** using the **same weighted-averaging aggregator** used before.
4. Main node runs **Adam on the aggregated model** for a few epochs.
5. Steps 1–4 repeat for several rounds while we log metrics and IPC counts.

In [1]:
import os
import copy
import time
import math
import queue
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from multiprocessing import Process, Queue, set_start_method

import matplotlib.pyplot as plt

# For Windows / notebook safety
try:
    set_start_method("spawn")
except RuntimeError:
    # Already set in this interpreter
    pass

DEVICE = torch.device("cpu")
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

print("Using device:", DEVICE)

Using device: cpu


In [2]:
# Load tensor splits created in Approach_1.ipynb

splits_path = "splits.pt"
assert os.path.exists(splits_path), f"{splits_path} not found. Please run the data-split notebook first."

raw_splits = torch.load(splits_path)
print(f"Loaded {len(raw_splits)} tensor splits from {splits_path}")


def make_dataloaders_from_splits(tensor_splits, batch_size=1024):
    """Create train/test DataLoaders for each split."""
    loaders = []
    for i, s in enumerate(tensor_splits):
        X_train, y_train = s["X_train"], s["y_train"]
        X_test, y_test = s["X_test"], s["y_test"]

        train_ds = TensorDataset(X_train, y_train)
        test_ds = TensorDataset(X_test, y_test)

        train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
        test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False)

        loaders.append({
            "train_loader": train_loader,
            "test_loader": test_loader,
        })

    return loaders

split_loaders = make_dataloaders_from_splits(raw_splits, batch_size=2048)
print(f"Created DataLoaders for {len(split_loaders)} splits")

# Infer global user/movie ID ranges from the first split
example_X = raw_splits[0]["X_train"]
n_users_global = int(example_X[:, 0].max().item()) + 1
n_movies_global = int(example_X[:, 1].max().item()) + 1

print("n_users_global =", n_users_global)
print("n_movies_global =", n_movies_global)

Loaded 5 tensor splits from splits.pt
Created DataLoaders for 5 splits
n_users_global = 1000
n_movies_global = 11453


  raw_splits = torch.load(splits_path)


In [3]:
# Model definition (same as in Approach_1)

class CollabFiltering(nn.Module):
    def __init__(self, n_users, n_movies, emb_dim=16, hidden=16, dropout=0.1):
        super().__init__()
        self.user_emb = nn.Embedding(n_users, emb_dim)
        self.movie_emb = nn.Embedding(n_movies, emb_dim)
        self.dropout_emb = 0.4

        self.mlp = nn.Sequential(
            nn.Linear(emb_dim * 2, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, 1),
            nn.ReLU(),
        )

    def forward(self, user, movie):
        u = F.dropout(self.user_emb(user), p=self.dropout_emb, training=self.training)
        m = F.dropout(self.movie_emb(movie), p=self.dropout_emb, training=self.training)
        x = torch.cat([u, m], dim=1)
        return self.mlp(x).squeeze()


loss_fn = nn.MSELoss()


def create_model():
    model = CollabFiltering(n_users_global, n_movies_global, emb_dim=16, hidden=16, dropout=0.1)
    return model.to(DEVICE)

In [4]:
# Model definition (same as in Approach_1)

class CollabFiltering(nn.Module):
    def __init__(self, n_users, n_movies, emb_dim=16, hidden=16, dropout=0.1):
        super().__init__()
        self.user_emb = nn.Embedding(n_users, emb_dim)
        self.movie_emb = nn.Embedding(n_movies, emb_dim)
        self.dropout_emb = 0.4

        self.mlp = nn.Sequential(
            nn.Linear(emb_dim * 2, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, 1),
            nn.ReLU(),
        )

    def forward(self, user, movie):
        u = F.dropout(self.user_emb(user), p=self.dropout_emb, training=self.training)
        m = F.dropout(self.movie_emb(movie), p=self.dropout_emb, training=self.training)
        x = torch.cat([u, m], dim=1)
        return self.mlp(x).squeeze()


loss_fn = nn.MSELoss()


def create_model():
    model = CollabFiltering(n_users_global, n_movies_global, emb_dim=16, hidden=16, dropout=0.1)
    return model.to(DEVICE)

In [5]:
# Aggregation utilities (same technique as before: weighted average of state_dicts)

@torch.no_grad()
def aggregate_models_cpu(weights, node_states):
    """Weighted average of multiple model state_dicts on CPU."""
    n_nodes = len(node_states)
    assert len(weights) == n_nodes

    agg_state = {}
    for key in node_states[0].keys():
        agg_param = torch.zeros_like(node_states[0][key])
        for i in range(n_nodes):
            agg_param += weights[i] * node_states[i][key]
        agg_state[key] = agg_param
    return agg_state


@torch.no_grad()
def evaluate_model_state(state_dict, model_template, data_loader, loss_fn):
    model = copy.deepcopy(model_template)
    model.load_state_dict(state_dict)
    model.eval()

    total_loss = 0.0
    total_batches = 0
    for X_batch, y_batch in data_loader:
        X_batch = X_batch.to(DEVICE)
        y_batch = y_batch.float().to(DEVICE)
        preds = model(X_batch[:, 0].long(), X_batch[:, 1].long())
        loss = loss_fn(preds, y_batch)
        total_loss += loss.item()
        total_batches += 1
    return total_loss / max(total_batches, 1)

In [6]:
# Local training helpers


def train_one_epoch_adam(model, train_loader, optimizer, loss_fn):
    model.train()
    total_loss = 0.0
    total_batches = 0
    for X_batch, y_batch in train_loader:
        X_batch = X_batch.to(DEVICE)
        y_batch = y_batch.float().to(DEVICE)

        optimizer.zero_grad()
        preds = model(X_batch[:, 0].long(), X_batch[:, 1].long())
        loss = loss_fn(preds, y_batch)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_batches += 1
    return total_loss / max(total_batches, 1)


@torch.no_grad()
def evaluate_model(model, data_loader, loss_fn):
    model.eval()
    total_loss = 0.0
    total_batches = 0
    for X_batch, y_batch in data_loader:
        X_batch = X_batch.to(DEVICE)
        y_batch = y_batch.float().to(DEVICE)
        preds = model(X_batch[:, 0].long(), X_batch[:, 1].long())
        loss = loss_fn(preds, y_batch)
        total_loss += loss.item()
        total_batches += 1
    return total_loss / max(total_batches, 1)

In [7]:
# Worker process: represents a PSO/Adam node


def worker_node(worker_id, cmd_queue, result_queue, split_index,
                base_lr=1e-3, local_epochs=5):
    """Worker loop running in a separate process.

    Each worker:
    - Receives the current global model parameters.
    - Trains locally with Adam for a few epochs on its data split.
    - Sends back its best local model and metrics.
    """
    # Recreate model and data inside the process (safe under spawn)
    model = create_model()
    loaders = split_loaders[split_index]
    train_loader = loaders["train_loader"]
    test_loader = loaders["test_loader"]

    while True:
        msg = cmd_queue.get()  # 1 read on cmd_queue
        msg_type = msg.get("type", None)

        if msg_type == "stop":
            break

        assert msg_type == "train", f"Unknown message type: {msg_type}"
        round_idx = msg["round"]
        global_state = msg["state_dict"]

        # Load global model
        model.load_state_dict(global_state)
        optimizer = torch.optim.Adam(model.parameters(), lr=base_lr)

        # Local training (Adam)
        last_train_loss = None
        for _ in range(local_epochs):
            last_train_loss = train_one_epoch_adam(model, train_loader, optimizer, loss_fn)

        # Evaluate on local test data
        local_test_loss = evaluate_model(model, test_loader, loss_fn)

        # Send result back to main node (1 write on result_queue)
        result_queue.put({
            "worker_id": worker_id,
            "round": round_idx,
            "state_dict": copy.deepcopy(model.state_dict()),
            "train_loss": float(last_train_loss),
            "test_loss": float(local_test_loss),
        })

In [8]:
# Main orchestration: 5 worker nodes + 1 main node


def run_hybrid_pso_adam(num_rounds=3,
                         num_workers=5,
                         local_epochs=5,
                         global_adam_epochs=3,
                         base_lr=1e-3):
    assert num_workers <= len(split_loaders)

    # Queues for IPC
    cmd_queues = [Queue() for _ in range(num_workers)]  # per-worker command queue
    result_queue = Queue()  # shared results queue

    # Spawn worker processes
    workers = []
    for wid in range(num_workers):
        p = Process(
            target=worker_node,
            args=(wid, cmd_queues[wid], result_queue, wid, base_lr, local_epochs),
        )
        p.start()
        workers.append(p)

    # Main node model
    global_model = create_model()

    # Use the first split's test loader as a simple global validation set
    val_loader = split_loaders[0]["test_loader"]

    history = {
        "round": [],
        "global_val_loss": [],
        "avg_worker_test_loss": [],
    }

    reads_per_round = []
    writes_per_round = []

    for r in range(num_rounds):
        # 1) Main node broadcasts current global model to all workers
        state_to_send = copy.deepcopy(global_model.state_dict())
        for q in cmd_queues:
            q.put({"type": "train", "round": r, "state_dict": state_to_send})
        writes_this_round = num_workers  # command writes

        # 2) Collect results from workers
        worker_states = []
        worker_test_losses = []
        for _ in range(num_workers):
            msg = result_queue.get()
            worker_states.append(msg["state_dict"])
            worker_test_losses.append(msg["test_loss"])
        reads_this_round = num_workers  # result reads

        # 3) PSO aggregation on the 5 worker models (hybrid PSO + Adam)
        model_template = create_model()
        best_weights, best_score = pso_optimize_aggregation(
            worker_states,
            val_loader,
            model_template,
            loss_fn,
            num_particles=10,
            max_iters=5,
        )
        agg_state = aggregate_models_cpu(best_weights, worker_states)

        # 4) Main node runs Adam on aggregated model for a few epochs
        global_model.load_state_dict(agg_state)
        optimizer = torch.optim.Adam(global_model.parameters(), lr=base_lr)
        for _ in range(global_adam_epochs):
            _ = train_one_epoch_adam(global_model, val_loader, optimizer, loss_fn)

        # 5) Log metrics
        global_val_loss = evaluate_model(global_model, val_loader, loss_fn)
        history["round"].append(r)
        history["global_val_loss"].append(float(global_val_loss))
        history["avg_worker_test_loss"].append(float(np.mean(worker_test_losses)))

        reads_per_round.append(reads_this_round)
        writes_per_round.append(writes_this_round)

        print(
            f"Round {r}: global_val_loss={global_val_loss:.4f}, "
            f"avg_worker_test_loss={np.mean(worker_test_losses):.4f}"
        )

    # 6) Stop workers (one extra command per worker)
    for q in cmd_queues:
        q.put({"type": "stop"})
    extra_writes = num_workers

    for p in workers:
        p.join()

    # Analytical IPC accounting (Queue-based, each message = 1 write + 1 read)
    train_messages = num_rounds * (2 * num_workers)  # train commands + results
    stop_messages = num_workers  # stop commands
    total_messages = train_messages + stop_messages

    comm_stats = {
        "num_rounds": num_rounds,
        "num_workers": num_workers,
        "reads_per_round": reads_per_round,
        "writes_per_round": writes_per_round,
        "total_train_round_messages": train_messages,
        "total_stop_messages": stop_messages,
        "total_queue_writes": total_messages,
        "total_queue_reads": total_messages,
    }

    return global_model, history, comm_stats

In [9]:
# Main orchestration: 5 worker nodes + 1 main node


def run_hybrid_pso_adam(num_rounds=3,
                         num_workers=5,
                         local_epochs=5,
                         global_adam_epochs=3,
                         base_lr=1e-3):
    assert num_workers <= len(split_loaders)

    # Queues for IPC
    cmd_queues = [Queue() for _ in range(num_workers)]  # per-worker command queue
    result_queue = Queue()  # shared results queue

    # Spawn worker processes
    workers = []
    for wid in range(num_workers):
        p = Process(
            target=worker_node,
            args=(wid, cmd_queues[wid], result_queue, wid, base_lr, local_epochs),
        )
        p.start()
        workers.append(p)

    # Main node model
    global_model = create_model()

    # Use the first split's test loader as a simple global validation set
    val_loader = split_loaders[0]["test_loader"]

    history = {
        "round": [],
        "global_val_loss": [],
        "avg_worker_test_loss": [],
    }

    reads_per_round = []
    writes_per_round = []

    for r in range(num_rounds):
        # 1) Main node broadcasts current global model to all workers
        state_to_send = copy.deepcopy(global_model.state_dict())
        for q in cmd_queues:
            q.put({"type": "train", "round": r, "state_dict": state_to_send})
        writes_this_round = num_workers  # command writes

        # 2) Collect results from workers
        worker_states = []
        worker_test_losses = []
        for _ in range(num_workers):
            msg = result_queue.get()
            worker_states.append(msg["state_dict"])
            worker_test_losses.append(msg["test_loss"])
        reads_this_round = num_workers  # result reads

        # 3) PSO aggregation on the 5 worker models (hybrid PSO + Adam)
        model_template = create_model()
        best_weights, best_score = pso_optimize_aggregation(
            worker_states,
            val_loader,
            model_template,
            loss_fn,
            num_particles=10,
            max_iters=5,
        )
        agg_state = aggregate_models_cpu(best_weights, worker_states)

        # 4) Main node runs Adam on aggregated model for a few epochs
        global_model.load_state_dict(agg_state)
        optimizer = torch.optim.Adam(global_model.parameters(), lr=base_lr)
        for _ in range(global_adam_epochs):
            _ = train_one_epoch_adam(global_model, val_loader, optimizer, loss_fn)

        # 5) Log metrics
        global_val_loss = evaluate_model(global_model, val_loader, loss_fn)
        history["round"].append(r)
        history["global_val_loss"].append(float(global_val_loss))
        history["avg_worker_test_loss"].append(float(np.mean(worker_test_losses)))

        reads_per_round.append(reads_this_round)
        writes_per_round.append(writes_this_round)

        print(
            f"Round {r}: global_val_loss={global_val_loss:.4f}, "
            f"avg_worker_test_loss={np.mean(worker_test_losses):.4f}"
        )

    # 6) Stop workers (one extra command per worker)
    for q in cmd_queues:
        q.put({"type": "stop"})
    extra_writes = num_workers

    for p in workers:
        p.join()

    # Analytical IPC accounting (Queue-based, each message = 1 write + 1 read)
    train_messages = num_rounds * (2 * num_workers)  # train commands + results
    stop_messages = num_workers  # stop commands
    total_messages = train_messages + stop_messages

    comm_stats = {
        "num_rounds": num_rounds,
        "num_workers": num_workers,
        "reads_per_round": reads_per_round,
        "writes_per_round": writes_per_round,
        "total_train_round_messages": train_messages,
        "total_stop_messages": stop_messages,
        "total_queue_writes": total_messages,
        "total_queue_reads": total_messages,
    }

    return global_model, history, comm_stats

In [None]:
# Run the hybrid PSO–Adam experiment
if __name__ == "__main__":
    global_model, history, comm_stats = run_hybrid_pso_adam(
        num_rounds=3,
        num_workers=5,
        local_epochs=5,
        global_adam_epochs=3,
        base_lr=1e-3,
    )

    print("\n=== Communication statistics (Queue-based IPC) ===")
    for k, v in comm_stats.items():
        print(f"{k}: {v}")

fhab


In [None]:
# Plot losses and IPC counts

rounds = history["round"]

plt.figure(figsize=(12, 4))

# Loss curves
plt.subplot(1, 2, 1)
plt.plot(rounds, history["global_val_loss"], marker="o", label="Global (Adam after aggregation)")
plt.plot(rounds, history["avg_worker_test_loss"], marker="s", label="Avg worker test loss")
plt.xlabel("Communication round")
plt.ylabel("Loss (MSE)")
plt.title("Hybrid PSO–Adam: Loss vs Rounds")
plt.grid(True, alpha=0.3)
plt.legend()

# IPC counts per round
plt.subplot(1, 2, 2)
reads = comm_stats["reads_per_round"]
writes = comm_stats["writes_per_round"]
indices = np.arange(len(rounds))
width = 0.35

plt.bar(indices - width / 2, writes, width, label="Writes (commands)")
plt.bar(indices + width / 2, reads, width, label="Reads (results)")
plt.xticks(indices, rounds)
plt.xlabel("Communication round")
plt.ylabel("Queue ops per round")
plt.title("IPC: Queue Reads/Writes per Round")
plt.grid(True, axis="y", alpha=0.3)
plt.legend()

plt.tight_layout()
plt.show()

print("\nTotal queue writes:", comm_stats["total_queue_writes"])
print("Total queue reads:", comm_stats["total_queue_reads"])

In [None]:
# Verbose PSO over aggregation weights (with progress prints)

import numpy as np


def pso_optimize_aggregation(node_states, val_loader, model_template, loss_fn,
                             num_particles=10, max_iters=5, w=0.7, c1=1.5, c2=1.5):
    """Small PSO to find good aggregation weights for the workers.

    This version prints progress every iteration so we can see that it is running.
    """
    num_nodes = len(node_states)

    print(f"[PSO] Starting optimization with {num_particles} particles, {max_iters} iterations, {num_nodes} nodes")

    # Initialize particles on the simplex via Dirichlet
    particles = np.random.dirichlet(np.ones(num_nodes), size=num_particles)
    velocities = np.zeros_like(particles)
    pbest_positions = particles.copy()
    pbest_scores = np.full(num_particles, np.inf)

    gbest_position = None
    gbest_score = np.inf

    for it in range(max_iters):
        for i in range(num_particles):
            w_vec = np.abs(particles[i])
            w_vec /= np.sum(w_vec)

            agg_state = aggregate_models_cpu(w_vec, node_states)
            score = evaluate_model_state(agg_state, model_template, val_loader, loss_fn)

            if score < pbest_scores[i]:
                pbest_scores[i] = score
                pbest_positions[i] = w_vec.copy()

            if score < gbest_score:
                gbest_score = score
                gbest_position = w_vec.copy()

        # Velocity/position update
        for i in range(num_particles):
            r1 = np.random.rand(num_nodes)
            r2 = np.random.rand(num_nodes)
            velocities[i] = (
                w * velocities[i]
                + c1 * r1 * (pbest_positions[i] - particles[i])
                + c2 * r2 * (gbest_position - particles[i])
            )
            particles[i] += velocities[i]

        print(f"[PSO] Iteration {it+1}/{max_iters} | best_score={gbest_score:.6f}")

    print("[PSO] Finished optimization.")
    print("[PSO] Best weight vector:", np.round(gbest_position, 4))
    print(f"[PSO] Best validation loss: {gbest_score:.6f}")

    return gbest_position, gbest_score

In [None]:
# Verbose worker implementation with clear prints


def worker_node(worker_id, cmd_queue, result_queue, split_index,
                base_lr=1e-3, local_epochs=5):
    """Worker loop running in a separate process (verbose).

    Each worker:
    - Receives the current global model parameters.
    - Trains locally with Adam for a few epochs on its data split.
    - Sends back its best local model and metrics.
    """
    print(f"[Worker {worker_id}] Starting. Using split index {split_index}.")

    # Recreate model and data inside the process (safe under spawn)
    model = create_model()
    loaders = split_loaders[split_index]
    train_loader = loaders["train_loader"]
    test_loader = loaders["test_loader"]

    while True:
        msg = cmd_queue.get()  # 1 read on cmd_queue
        msg_type = msg.get("type", None)
        round_idx = msg.get("round", "-")

        if msg_type == "stop":
            print(f"[Worker {worker_id}] Received stop signal. Exiting.")
            break

        assert msg_type == "train", f"Unknown message type: {msg_type}"
        print(f"[Worker {worker_id}] Received TRAIN command for round {round_idx}.")

        global_state = msg["state_dict"]

        # Load global model
        model.load_state_dict(global_state)
        optimizer = torch.optim.Adam(model.parameters(), lr=base_lr)

        # Local training (Adam)
        last_train_loss = None
        for local_ep in range(local_epochs):
            last_train_loss = train_one_epoch_adam(model, train_loader, optimizer, loss_fn)
            print(f"[Worker {worker_id}] Round {round_idx} | Local epoch {local_ep+1}/{local_epochs} | train_loss={last_train_loss:.6f}")

        # Evaluate on local test data
        local_test_loss = evaluate_model(model, test_loader, loss_fn)
        print(f"[Worker {worker_id}] Round {round_idx} | Finished local training. test_loss={local_test_loss:.6f}")

        # Send result back to main node (1 write on result_queue)
        result_queue.put({
            "worker_id": worker_id,
            "round": round_idx,
            "state_dict": copy.deepcopy(model.state_dict()),
            "train_loss": float(last_train_loss),
            "test_loss": float(local_test_loss),
        })
        print(f"[Worker {worker_id}] Round {round_idx} | Result sent back to main.")

In [None]:
# Verbose main orchestration so we can follow progress clearly


def run_hybrid_pso_adam(num_rounds=3,
                         num_workers=5,
                         local_epochs=5,
                         global_adam_epochs=3,
                         base_lr=1e-3):
    assert num_workers <= len(split_loaders)

    print("\n[Main] =========================================")
    print("[Main] Starting hybrid PSO–Adam run")
    print(f"[Main] num_rounds={num_rounds}, num_workers={num_workers}, local_epochs={local_epochs}, global_adam_epochs={global_adam_epochs}, base_lr={base_lr}")
    print("[Main] =========================================\n")

    # Queues for IPC
    cmd_queues = [Queue() for _ in range(num_workers)]  # per-worker command queue
    result_queue = Queue()  # shared results queue

    # Spawn worker processes
    workers = []
    for wid in range(num_workers):
        print(f"[Main] Spawning worker process {wid} for split {wid}.")
        p = Process(
            target=worker_node,
            args=(wid, cmd_queues[wid], result_queue, wid, base_lr, local_epochs),
        )
        p.start()
        workers.append(p)

    print("[Main] All workers spawned. Entering communication rounds.\n")

    # Main node model
    global_model = create_model()

    # Use the first split's test loader as a simple global validation set
    val_loader = split_loaders[0]["test_loader"]

    history = {
        "round": [],
        "global_val_loss": [],
        "avg_worker_test_loss": [],
    }

    reads_per_round = []
    writes_per_round = []

    for r in range(num_rounds):
        print(f"\n[Main] ===== Communication Round {r} =====")

        # 1) Main node broadcasts current global model to all workers
        state_to_send = copy.deepcopy(global_model.state_dict())
        print(f"[Main] Broadcasting global model to {num_workers} workers...")
        for q in cmd_queues:
            q.put({"type": "train", "round": r, "state_dict": state_to_send})
        writes_this_round = num_workers  # command writes
        print(f"[Main] Broadcast complete. (writes_this_round={writes_this_round})")

        # 2) Collect results from workers
        worker_states = []
        worker_test_losses = []
        print("[Main] Waiting for worker results...")
        for _ in range(num_workers):
            msg = result_queue.get()
            wid = msg["worker_id"]
            worker_states.append(msg["state_dict"])
            worker_test_losses.append(msg["test_loss"])
            print(
                f"[Main] Received result from worker {wid} | "
                f"round={msg['round']} | train_loss={msg['train_loss']:.6f} | test_loss={msg['test_loss']:.6f}"
            )
        reads_this_round = num_workers  # result reads
        print(f"[Main] Collected all worker results. (reads_this_round={reads_this_round})")

        # 3) PSO aggregation on the worker models (hybrid PSO + Adam)
        print("[Main] Starting PSO-based aggregation over worker models...")
        model_template = create_model()
        best_weights, best_score = pso_optimize_aggregation(
            worker_states,
            val_loader,
            model_template,
            loss_fn,
            num_particles=10,
            max_iters=5,
        )
        agg_state = aggregate_models_cpu(best_weights, worker_states)
        print(f"[Main] PSO aggregation done. Best validation loss from PSO={best_score:.6f}")

        # 4) Main node runs Adam on aggregated model for a few epochs
        global_model.load_state_dict(agg_state)
        optimizer = torch.optim.Adam(global_model.parameters(), lr=base_lr)
        print(f"[Main] Running Adam on aggregated model for {global_adam_epochs} epochs...")
        for ge in range(global_adam_epochs):
            train_loss = train_one_epoch_adam(global_model, val_loader, optimizer, loss_fn)
            print(f"[Main] Global Adam epoch {ge+1}/{global_adam_epochs} | train_loss={train_loss:.6f}")

        # 5) Log metrics
        global_val_loss = evaluate_model(global_model, val_loader, loss_fn)
        history["round"].append(r)
        history["global_val_loss"].append(float(global_val_loss))
        history["avg_worker_test_loss"].append(float(np.mean(worker_test_losses)))

        reads_per_round.append(reads_this_round)
        writes_per_round.append(writes_this_round)

        print(
            f"[Main] Round {r} summary | global_val_loss={global_val_loss:.6f}, "
            f"avg_worker_test_loss={np.mean(worker_test_losses):.6f}, "
            f"writes={writes_this_round}, reads={reads_this_round}"
        )

    # 6) Stop workers (one extra command per worker)
    print("\n[Main] All communication rounds complete. Sending stop signals to workers...")
    for q in cmd_queues:
        q.put({"type": "stop"})
    extra_writes = num_workers
    print(f"[Main] Stop messages sent. (extra_writes={extra_writes})")

    for p in workers:
        p.join()
    print("[Main] All workers have terminated.")

    # Analytical IPC accounting (Queue-based, each message = 1 write + 1 read)
    train_messages = num_rounds * (2 * num_workers)  # train commands + results
    stop_messages = num_workers  # stop commands
    total_messages = train_messages + stop_messages

    comm_stats = {
        "num_rounds": num_rounds,
        "num_workers": num_workers,
        "reads_per_round": reads_per_round,
        "writes_per_round": writes_per_round,
        "total_train_round_messages": train_messages,
        "total_stop_messages": stop_messages,
        "total_queue_writes": total_messages,
        "total_queue_reads": total_messages,
    }

    print("\n[Main] Run complete. Communication statistics:")
    for k, v in comm_stats.items():
        print(f"[Main]   {k}: {v}")

    return global_model, history, comm_stats

In [None]:
# Simple PSO over aggregation weights (reusing the idea from Approach_1)

import numpy as np


def pso_optimize_aggregation(node_states, val_loader, model_template, loss_fn,
                             num_particles=10, max_iters=5, w=0.7, c1=1.5, c2=1.5):
    """Small PSO to find good aggregation weights for the 5 workers."""
    num_nodes = len(node_states)

    # Initialize particles on the simplex via Dirichlet
    particles = np.random.dirichlet(np.ones(num_nodes), size=num_particles)
    velocities = np.zeros_like(particles)
    pbest_positions = particles.copy()
    pbest_scores = np.full(num_particles, np.inf)

    gbest_position = None
    gbest_score = np.inf

    for it in range(max_iters):
        for i in range(num_particles):
            w_vec = np.abs(particles[i])
            w_vec /= np.sum(w_vec)

            agg_state = aggregate_models_cpu(w_vec, node_states)
            score = evaluate_model_state(agg_state, model_template, val_loader, loss_fn)

            if score < pbest_scores[i]:
                pbest_scores[i] = score
                pbest_positions[i] = w_vec.copy()

            if score < gbest_score:
                gbest_score = score
                gbest_position = w_vec.copy()

        # Velocity/position update
        for i in range(num_particles):
            r1 = np.random.rand(num_nodes)
            r2 = np.random.rand(num_nodes)
            velocities[i] = (
                w * velocities[i]
                + c1 * r1 * (pbest_positions[i] - particles[i])
                + c2 * r2 * (gbest_position - particles[i])
            )
            particles[i] += velocities[i]

    return gbest_position, gbest_score