# In [12]:
import simpy, math, random, heapq
from collections import defaultdict

BITS_PER_BYTE = 8

class Link:
    """One direction of a point-to-point link with a finite FIFO transmit queue.

    A message occupies the transmitter only for its serialization time
    (``size_bytes * BITS_PER_BYTE / bandwidth_bps`` seconds). The propagation
    delay is applied in a separate process, so back-to-back messages pipeline
    on the wire instead of each waiting for the previous one to arrive.
    Delivery order is preserved because the propagation delay is constant per
    link.

    Args:
        env: simpy environment driving the simulation.
        u, v: endpoint ids (informational; delivery uses ``payload["dst"]``).
        bandwidth_bps: transmit rate in bits per second.
        prop_delay_ms: one-way propagation delay in milliseconds.
        capacity: maximum number of messages held in the transmit queue.
    """

    def __init__(self, env, u, v, bandwidth_bps, prop_delay_ms, capacity=64):
        self.env = env
        self.u, self.v = u, v
        self.bps = bandwidth_bps
        self.prop = prop_delay_ms / 1000.0  # seconds
        self.q = simpy.Store(env, capacity=capacity)  # finite queue
        self.bytes_sent = 0      # counted when a message is enqueued
        self.messages_sent = 0   # counted when a message is enqueued
        self.max_q_depth = 0     # max backlog (stored + blocked puts) seen at enqueue
        env.process(self.run())

    def send(self, size_bytes, payload):
        """Enqueue ``payload`` for transmission and return the simpy put event.

        If the queue is full, the put stays pending until space frees up. The
        caller only experiences backpressure if it yields the returned event.
        """
        self.bytes_sent += size_bytes
        self.messages_sent += 1
        event = self.q.put((size_bytes, payload))
        # Store.put stores the item immediately when there is room; otherwise
        # the request waits in put_queue. Count both, so a full queue is not
        # reported as a flat ``capacity + 1``.
        backlog = len(self.q.items) + len(self.q.put_queue)
        self.max_q_depth = max(self.max_q_depth, backlog)
        return event

    def run(self):
        """Transmitter process: dequeue, serialize, then hand off to propagation."""
        while True:
            size_bytes, payload = yield self.q.get()
            ser = (size_bytes * BITS_PER_BYTE) / self.bps
            # The transmitter is busy only while serializing; propagation
            # overlaps with sending the next message.
            yield self.env.timeout(ser)
            self.env.process(self._propagate(payload))

    def _propagate(self, payload):
        """Process: deliver ``payload`` to its destination inbox after the propagation delay."""
        yield self.env.timeout(self.prop)
        payload["dst"].inbox.put(payload)


class RulePolicy:
    """Threshold rules deciding how a node forwards a distance update.

    ``decide`` returns ``(action, agg_ms, size_bytes)``:

    * ``action`` is one of ``"SKIP"``, ``"AGGREGATE"``, ``"COMPRESS"`` or ``"ALLOW"``.
    * ``agg_ms`` is set only for AGGREGATE.
    * ``size_bytes`` is set only for COMPRESS/ALLOW. ``None`` lets the caller
      pick the size.

    Args:
        eps_small: changes below this are skipped when the link has a backlog.
        eps_med: changes below this are sent compressed.
        q_hi: queue depth at or above which updates are aggregated.
        agg_ms: aggregation window in milliseconds.
        compressed_bytes: message size for COMPRESS (default 48).
        normal_bytes: message size for ALLOW (default 96).
    """

    def __init__(self, eps_small=0.01, eps_med=0.05, q_hi=8, agg_ms=10,
                 compressed_bytes=48, normal_bytes=96):
        self.eps_small = eps_small
        self.eps_med = eps_med
        self.q_hi = q_hi
        self.agg_ms = agg_ms
        self.compressed_bytes = compressed_bytes
        self.normal_bytes = normal_bytes

    def decide(self, *, delta, queue_depth, util=None):
        """Pick a forwarding action for an update.

        Args:
            delta: absolute change versus the value last sent on this link.
            queue_depth: number of messages waiting in the link's queue.
            util: accepted for future rules; currently ignored.
        """
        # Tiny change and the link already has a backlog: drop the update.
        if delta < self.eps_small and queue_depth > 0:
            return "SKIP", None, None
        # Congested link: coalesce updates for agg_ms before sending.
        if queue_depth >= self.q_hi:
            return "AGGREGATE", self.agg_ms, None
        # Moderate change: send a smaller message.
        if delta < self.eps_med:
            return "COMPRESS", None, self.compressed_bytes
        return "ALLOW", None, self.normal_bytes

    

class Node:
    """One router in the simulated network, running distributed Bellman-Ford.

    For each source ``s`` the node keeps ``dist[s]``. Whenever it improves, the
    node offers ``dist[s] + w`` to every neighbour over the outgoing Link.
    ``policy.decide`` chooses whether that offer is skipped, aggregated,
    compressed or sent as-is. ``policy=None`` is the baseline: every offer is
    sent at the controller's default message size.
    """

    def __init__(self, env, node_id, controller, policy):
        self.env = env
        self.id = node_id
        self.ctrl = controller
        self.policy = policy
        self.neighbors = {}               # v -> (link, weight)
        self.inbox = simpy.Store(env)     # Links put delivered messages here
        self.dist = defaultdict(lambda: math.inf)   # s -> best known distance
        self.action_counts = defaultdict(int)       # action name -> times chosen
        # best value we last sent per (neighbor v, source s)
        self.last_sent = defaultdict(lambda: math.inf)
        # aggregation state: (v, s) -> latest buffered value, and
        # ("timer", (v, s)) -> True while a flush timer is scheduled
        self.pending = {}
        env.process(self.recv_loop())

    def add_neighbor(self, v, link, w):
        """Register outgoing ``link`` to neighbour ``v`` with edge weight ``w``."""
        self.neighbors[v] = (link, w)

    def init_source(self, s):
        """If this node is ``s``, set its distance to 0 and offer it to all neighbours."""
        if self.id == s:
            self.dist[s] = 0.0
            for v, (link, w) in self.neighbors.items():
                self._maybe_send(v, s, self.dist[s] + w)

    def _enqueue_send(self, link, payload, size_bytes):
        """Count the message as in flight and hand it to ``link``."""
        self.ctrl.inflight += 1
        link.send(size_bytes, payload)

    def _send_relax(self, v, s, d, link, size_bytes):
        """Send RELAX(s, d) to neighbour ``v`` and remember ``d`` as the last value sent."""
        msg = {"kind": "RELAX", "src": self, "dst": self.ctrl.nodes[v], "s": s, "d": d}
        self.last_sent[(v, s)] = d
        self._enqueue_send(link, msg, size_bytes)

    def _maybe_send(self, v, s, d):
        """Offer distance ``d`` for source ``s`` to neighbour ``v``, subject to the policy."""
        link, _ = self.neighbors[v]
        key = (v, s)

        if self.policy is None:
            # Baseline (policy=None): always send at the default size. Calling
            # decide() on None would raise AttributeError.
            action, param, size_bytes = "ALLOW", None, None
        else:
            delta = abs(d - self.last_sent[key])
            action, param, size_bytes = self.policy.decide(
                delta=delta, queue_depth=len(link.q.items), util=None
            )
        self.action_counts[action] += 1

        if action == "SKIP":
            return

        if action == "AGGREGATE":
            # Keep only the latest value; at most one flush timer per (v, s).
            self.pending[key] = d
            if ("timer", key) not in self.pending:
                self.pending[("timer", key)] = True
                self.env.process(self._aggregate_after(param, v, s, link))
            return

        # ALLOW / COMPRESS: send now, at the policy's size or the default size.
        send_size = size_bytes if size_bytes is not None else self.ctrl.size_model("RELAX")
        self._send_relax(v, s, d, link, send_size)

    def _aggregate_after(self, delay_ms, v, s, link):
        """Process: after ``delay_ms`` ms, flush the buffered value for (v, s) if still useful."""
        yield self.env.timeout(delay_ms / 1000.0)
        key = (v, s)
        self.pending.pop(("timer", key), None)
        d = self.pending.pop(key, None)
        # dist[s] only decreases (see recv_loop), so offers to v only decrease.
        # A buffered value that is not below the last value sent (e.g. an
        # immediate ALLOW overtook it) is redundant, and sending it would
        # regress last_sent.
        if d is None or d >= self.last_sent[key]:
            return
        self._send_relax(v, s, d, link, self.ctrl.size_model("RELAX"))

    def recv_loop(self):
        """Process: consume delivered messages and propagate distance improvements."""
        while True:
            msg = yield self.inbox.get()
            self.ctrl.inflight -= 1
            if msg["kind"] == "RELAX":
                s, d = msg["s"], msg["d"]
                if d < self.dist[s]:
                    self.dist[s] = d
                    for v, (link, w) in self.neighbors.items():
                        self._maybe_send(v, s, self.dist[s] + w)

class Controller:
    """Builds the simulated network and holds run-wide state.

    One Node is created per id in ``graph["nodes"]``. Each undirected edge
    ``(u, v, w)`` in ``graph["edges"]`` becomes two directed Links, with
    bandwidth and propagation delay drawn independently per direction from the
    module-level ``random`` generator.

    Args:
        env: simpy environment.
        graph: ``{"nodes": [...], "edges": [(u, v, w), ...]}``.
        base_bps: bandwidth scale; each direction gets 0.5-1.0x of it.
        base_prop_ms: propagation-delay scale in ms; each direction gets 0.2-2.0x of it.
        policy: forwarding policy shared by all nodes; ``None`` selects the baseline.
        link_capacity: transmit-queue capacity of every Link (default 64).
    """

    def __init__(self, env, graph, base_bps=8e5, base_prop_ms=1.0, policy=None,
                 link_capacity=64):
        self.env = env
        self.links = {}          # (u, v) -> Link carrying u -> v traffic
        self.inflight = 0        # messages enqueued by nodes but not yet received
        self.policy = policy     # None => baseline
        self.nodes = {u: Node(env, u, self, self.policy) for u in graph["nodes"]}

        # Heterogeneous links (bandwidth/latency vary per direction)
        for (u, v, w) in graph["edges"]:
            # Draw order is fixed (bps_uv, bps_vu, prop_uv, prop_vu), so a given
            # seed reproduces the same network.
            bps_uv = random.uniform(0.5, 1.0) * base_bps
            bps_vu = random.uniform(0.5, 1.0) * base_bps
            prop_uv = random.uniform(0.2, 2.0) * base_prop_ms
            prop_vu = random.uniform(0.2, 2.0) * base_prop_ms

            l_uv = Link(self.env, u, v, bps_uv, prop_uv, capacity=link_capacity)
            l_vu = Link(self.env, v, u, bps_vu, prop_vu, capacity=link_capacity)
            self.links[(u, v)] = l_uv
            self.links[(v, u)] = l_vu
            self.nodes[u].add_neighbor(v, l_uv, w)
            self.nodes[v].add_neighbor(u, l_vu, w)

    def size_model(self, kind):
        """Return the wire size, in bytes, of one message of type ``kind``.

        Every kind is currently 96 bytes per message.
        """
        return 96

def make_grid(n_side=16, w=1.0, jitter=0.15):
    """Return an ``n_side`` x ``n_side`` 4-neighbour grid graph with jittered weights.

    Cell ``(row, col)`` gets id ``row * n_side + col``. Cells are visited in
    row-major order. For each cell, the edge to the cell below is emitted
    before the edge to the right. Each weight is drawn from
    ``random.uniform(w - jitter, w + jitter)``.

    Returns:
        ``{"nodes": [0, ..., n_side**2 - 1], "edges": [(u, v, weight), ...]}``
    """
    side = range(n_side)
    lo, hi = w - jitter, w + jitter
    edges = []
    for row in side:
        for col in side:
            here = row * n_side + col
            if row + 1 < n_side:
                edges.append((here, here + n_side, random.uniform(lo, hi)))
            if col + 1 < n_side:
                edges.append((here, here + 1, random.uniform(lo, hi)))
    return {"nodes": list(range(len(side) ** 2)), "edges": edges}

def dijkstra_cpu(graph, source):
    """Exact single-source shortest paths on undirected ``graph`` (reference solution).

    Args:
        graph: ``{"nodes": [...], "edges": [(u, v, w), ...]}``; every edge is
            traversable in both directions. Dijkstra assumes non-negative weights.
        source: node id to measure distances from.

    Returns:
        dict mapping every node in ``graph["nodes"]`` to its distance from
        ``source`` (``math.inf`` when unreachable).
    """
    neighbours = {node: [] for node in graph["nodes"]}
    for a, b, weight in graph["edges"]:
        neighbours[a].append((b, weight))
        neighbours[b].append((a, weight))

    best = dict.fromkeys(graph["nodes"], math.inf)
    best[source] = 0.0
    frontier = [(0.0, source)]
    while frontier:
        cost, node = heapq.heappop(frontier)
        if cost > best[node]:
            continue  # stale heap entry; a shorter path was already found
        for nxt, weight in neighbours[node]:
            candidate = cost + weight
            if candidate < best[nxt]:
                best[nxt] = candidate
                heapq.heappush(frontier, (candidate, nxt))
    return best

def run_once(seed=0, use_policy=True, use_dynamic_jitter=False):
    """Simulate distributed shortest-path relaxation once and report metrics.

    Builds a 16x16 jittered grid and starts three sources staggered by 25 ms.
    The simulation stops once no message has been in flight for 80 ms, or once
    5 s of simulated time have elapsed.

    Args:
        seed: seed for the module-level ``random`` generator (graph and link draws).
        use_policy: use ``RulePolicy`` when True; baseline (no policy) when False.
        use_dynamic_jitter: also start periodic edge jitter. This relies on
            ``jitter_edges_periodically``, which is not defined in this cell.

    Returns:
        dict with:

        * ``bytes`` / ``msgs``: totals enqueued on all links.
        * ``time``: simulated stop time in seconds (includes the quiet window).
        * ``avg_mae``: mean absolute error versus exact Dijkstra, averaged over sources.
        * ``max_q``: the largest queue depth observed on any link.
        * ``actions``: the policy action histogram (empty for the baseline).
    """
    random.seed(seed)
    env = simpy.Environment()
    G = make_grid(n_side=16, w=1.0, jitter=0.15)

    policy = RulePolicy(eps_small=0.05, eps_med=0.20, q_hi=1, agg_ms=20) if use_policy else None
    ctrl = Controller(env, G, base_bps=8e5, base_prop_ms=1.0, policy=policy)

    # Multiple staggered sources -> overlapping waves
    N = len(G["nodes"])
    sources = [0, N // 3, (2 * N) // 3]

    def start_source_later(s, delay_ms):
        yield env.timeout(delay_ms / 1000.0)
        ctrl.nodes[s].init_source(s)

    for idx, s in enumerate(sources):
        env.process(start_source_later(s, delay_ms=idx * 25))  # 0, 25, 50 ms

    if use_dynamic_jitter:
        # NOTE(review): not defined in this cell; presumably defined elsewhere in the notebook.
        jitter_edges_periodically(ctrl, G, period_ms=120, scale=0.02)

    def stopper(timeout_s=5.0, quiet_ms=80):
        # Poll every 5 ms. Finish after quiet_ms of continuous quiescence, or at timeout_s.
        quiet = 0.0
        while env.now < timeout_s:
            yield env.timeout(0.005)
            if ctrl.inflight == 0:
                quiet += 0.005
                if quiet >= (quiet_ms / 1000.0):
                    break
            else:
                quiet = 0.0

    # Run until the stopper finishes. A bare env.run() ignores the stopper and
    # keeps going until no events remain, so the time cap was never enforced.
    env.run(until=env.process(stopper()))

    # Metrics
    total_bytes = sum(l.bytes_sent for l in ctrl.links.values())
    total_msgs = sum(l.messages_sent for l in ctrl.links.values())
    max_q = max((l.max_q_depth for l in ctrl.links.values()), default=0)

    # Policy action summary (if enabled)
    action_totals = {}
    if ctrl.policy is not None:
        agg = defaultdict(int)
        for node in ctrl.nodes.values():
            for k, v in getattr(node, "action_counts", {}).items():
                agg[k] += v
        action_totals = dict(agg)
        print("Action summary:", action_totals)
        print("Max queue depth observed:", max_q)

    # Accuracy vs exact (per source), report average MAE
    maes = []
    for s in sources:
        exact = dijkstra_cpu(G, s)
        approx = {u: ctrl.nodes[u].dist.get(s, math.inf) for u in G["nodes"]}
        # Equal values (including inf == inf for unreachable nodes) count as zero
        # error; abs(inf - inf) would be NaN and poison the average.
        mae = sum(
            0.0 if approx[u] == exact[u] else abs(approx[u] - exact[u])
            for u in G["nodes"]
        ) / len(G["nodes"])
        maes.append(mae)
    avg_mae = sum(maes) / len(maes)

    return {
        "bytes": total_bytes,
        "msgs": total_msgs,
        "time": env.now,
        "avg_mae": avg_mae,
        "max_q": max_q,
        "actions": action_totals
    }

if __name__ == "__main__":
    # Both runs use run_once's default seed, so the graph and link parameters
    # are identical; only the forwarding policy differs.
    for label, enable_policy in (("Baseline (no policy):", False), ("RulePolicy:", True)):
        outcome = run_once(use_policy=enable_policy)
        print(label, {key: value for key, value in outcome.items() if key != 'dists'})




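# Recorded cell output. It is stale: it came from an earlier version of this
# code. The current run_once prints "Action summary:" / "Max queue depth
# observed:" and also returns "avg_mae", "max_q" and "actions".
# Action counts: {'ALLOW': 360}
# Baseline (no policy): {'bytes': 34560, 'msgs': 360, 'time': 0.05999999999999999}
# Action counts: {'ALLOW': 360}
# RulePolicy: {'bytes': 34560, 'msgs': 360, 'time': 0.05999999999999999}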
