In [4]:
# Run the helper notebooks so their functions exist in the current namespace (not as modules)
%run adapters.ipynb
%run providers.ipynb

# Now make_call_model_fn(...) can be called directly, since running the notebook defines it here


In [5]:
# message_opt_sim.py
import json
import os
import simpy, math, random, heapq, hashlib
import ollama
from adapters import make_call_model_fn
from providers import openai_json, anthropic_json, vllm_http_json
from dataclasses import dataclass
from typing import Optional, Callable
from collections import defaultdict
from math import isfinite

# Conversion factor used when turning a message size in bytes into bits on the wire.
BITS_PER_BYTE = 8

# Model name exported through the OLLAMA_MODEL environment variable.
# History: "llama3.1" was used previously but ran too slowly; "qwen2.5:7b-instruct" etc. are alternatives.
# NOTE(review): presumably read by the Ollama provider code, which is not shown here - confirm.
os.environ["OLLAMA_MODEL"] = "phi3:mini"

# -------------------- Link (network channel) --------------------
class Link:
    """
    One-directional network channel from node `u` to node `v`.

    Messages are queued in a bounded `simpy.Store` and drained one at a time by
    the `run` process, which charges serialization time (size / bandwidth) plus
    a fixed propagation delay before handing the payload to the receiver.

    Attributes:
      bps            bandwidth in bits per second
      prop           propagation delay in seconds (converted from `prop_delay_ms`)
      q              bounded store holding `(size_bytes, payload)` tuples
      bytes_sent     total bytes handed to `send`
      messages_sent  total messages handed to `send`
      max_q_depth    largest `len(q.items) + 1` seen at `send` time
    """

    def __init__(self, env, u, v, bandwidth_bps, prop_delay_ms, capacity=32):
        self.env = env
        self.u = u
        self.v = v
        self.bps = bandwidth_bps
        self.prop = prop_delay_ms / 1000.0          # ms -> s
        self.q = simpy.Store(env, capacity=capacity)  # finite queue
        # Instrumentation counters, updated in send().
        self.bytes_sent = 0
        self.messages_sent = 0
        self.max_q_depth = 0
        # Start the drain loop as soon as the link exists.
        env.process(self.run())

    def send(self, size_bytes, payload):
        """
        Record the message in the counters and submit it to the queue.

        Returns the event produced by `self.q.put(...)`. The counters are
        updated at call time, before that event has necessarily completed, so
        `max_q_depth` can exceed the store capacity when callers do not wait
        on the returned event.
        """
        depth_with_this = len(self.q.items) + 1
        if depth_with_this > self.max_q_depth:
            self.max_q_depth = depth_with_this
        self.messages_sent += 1
        self.bytes_sent += size_bytes
        return self.q.put((size_bytes, payload))

    def run(self):
        """Drain loop: take one queued message, wait out its delay, deliver it."""
        while True:
            item = yield self.q.get()
            size_bytes, payload = item
            transmit_s = (size_bytes * BITS_PER_BYTE) / self.bps
            yield self.env.timeout(transmit_s + self.prop)
            receiver = payload["dst"]
            receiver.inbox.put(payload)

# -------------------- Queue-aware Rule Policy --------------------
class RulePolicy:
    """
    Hand-written, queue-aware send policy. Rules are tried in this order:

      1. improvement below `eps_small` while the link queue is non-empty -> SKIP
      2. queue depth at or above `q_hi`                                  -> AGGREGATE for `agg_ms`
      3. improvement below `eps_med`                                     -> COMPRESS (48 bytes)
      4. anything else                                                   -> ALLOW (96 bytes)

    `decide` returns a `(action, param, size_bytes)` tuple; `param` is only set
    for AGGREGATE (delay in ms) and `size_bytes` only for COMPRESS / ALLOW.
    """

    def __init__(self, eps_small=0.05, eps_med=0.20, q_hi=1, agg_ms=20):
        self.eps_small, self.eps_med = eps_small, eps_med
        self.q_hi, self.agg_ms = q_hi, agg_ms

    def decide(self, *, delta, queue_depth, util=None):
        """Pick an action for one candidate message. `util` is accepted but unused."""
        if delta < self.eps_small:
            # Low-value update: shed it, but only when something is already queued.
            if queue_depth > 0:
                return ("SKIP", None, None)

        if queue_depth >= self.q_hi:
            verdict = ("AGGREGATE", self.agg_ms, None)
        elif delta < self.eps_med:
            verdict = ("COMPRESS", None, 48)   # bytes
        else:
            verdict = ("ALLOW", None, 96)      # bytes
        return verdict

# -------------------- Node (distributed process) --------------------
class Node:
    """
    One participant in the distributed shortest-path relaxation.

    A node keeps its best known distance per source in `dist`, receives RELAX
    messages through `inbox`, and forwards improved distances to its
    neighbours. When a policy is attached, every outgoing message is first
    passed to `policy.decide(...)`, which may skip, delay (aggregate), shrink
    (compress) or allow it.
    """

    def __init__(self, env, node_id, controller, policy):
        self.env = env
        self.id = node_id
        self.ctrl = controller
        self.policy = policy                              # None -> baseline, no filtering
        self.neighbors = {}                               # neighbour id -> (link, weight)
        self.inbox = simpy.Store(env)
        self.dist = defaultdict(lambda: math.inf)         # source -> best distance so far
        self.last_sent = defaultdict(lambda: math.inf)    # (neighbour, source) -> last value sent
        self.pending = {}                                 # aggregation values and timer flags
        self.action_counts = defaultdict(int)             # policy action -> number of times chosen
        env.process(self.recv_loop())

    def add_neighbor(self, v, link, w):
        """Register the outgoing `link` to neighbour `v` with edge weight `w`."""
        self.neighbors[v] = (link, w)

    def init_source(self, s):
        """If this node is source `s`, set its distance to 0 and notify every neighbour."""
        if self.id != s:
            return
        self.dist[s] = 0.0
        for v, (_link, w) in self.neighbors.items():
            self._maybe_send(v, s, self.dist[s] + w)

    def _enqueue_send(self, link, payload, size_bytes):
        """Count the message as in flight and hand it to the link."""
        self.ctrl.inflight += 1
        link.send(size_bytes, payload)

    def _emit_relax(self, link, v, s, d, size_bytes):
        """Build a RELAX message for neighbour `v`, remember `d` as last sent, and send it."""
        msg = {"kind": "RELAX", "src": self, "dst": self.ctrl.nodes[v], "s": s, "d": d}
        self.last_sent[(v, s)] = d
        self._enqueue_send(link, msg, size_bytes)

    def _maybe_send(self, v, s, d):
        """Offer distance `d` (for source `s`) to neighbour `v`, subject to the policy."""
        link = self.neighbors[v][0]

        if self.policy is None:
            # Baseline: every update goes out at full size.
            self._emit_relax(link, v, s, d, self.ctrl.size_model("RELAX"))
            return

        previous = self.last_sent[(v, s)]
        improvement = abs(d - previous)
        backlog = len(link.q.items)
        action, param, size_bytes = self.policy.decide(
            delta=improvement, queue_depth=backlog, util=None
        )
        self.action_counts[action] += 1

        if action == "SKIP":
            return

        if action == "AGGREGATE":
            key = (v, s)
            timer_key = ("timer", key)
            # Keep only the most recent value; one timer per (neighbour, source).
            self.pending[key] = d
            if timer_key not in self.pending:
                self.pending[timer_key] = True
                self.env.process(self._aggregate_after(param, v, s, link))
            return

        # ALLOW / COMPRESS: send right away, using the policy's size when it gave one.
        if size_bytes is not None:
            send_size = size_bytes
        else:
            send_size = self.ctrl.size_model("RELAX")
        self._emit_relax(link, v, s, d, send_size)

    def _aggregate_after(self, delay_ms, v, s, link):
        """After `delay_ms`, flush the buffered value for `(v, s)` (if any) at full size."""
        yield self.env.timeout(delay_ms / 1000.0)
        key = (v, s)
        have_value = key in self.pending
        if have_value:
            d = self.pending.pop(key)
        # The timer flag is cleared whether or not a value was buffered.
        self.pending.pop(("timer", key), None)
        if have_value:
            self._emit_relax(link, v, s, d, self.ctrl.size_model("RELAX"))

    def recv_loop(self):
        """Consume inbox messages forever; relax and re-advertise on improvement."""
        while True:
            msg = yield self.inbox.get()
            self.ctrl.inflight -= 1
            if msg["kind"] != "RELAX":
                continue
            s, d = msg["s"], msg["d"]
            if d < self.dist[s]:
                self.dist[s] = d
                for v, (_link, w) in self.neighbors.items():
                    self._maybe_send(v, s, self.dist[s] + w)



#-------------------------------------------------------------------------------------------------------------------------------------------------#


class LLMPolicy:
    """
    Policy that delegates each decision to a language model.

    - Bins the features (delta, queue_depth) to reduce the number of distinct
      model calls; decisions are cached per (delta_bin, queue_bin).
    - Calls `call_model_fn(prompt: str) -> str`, which is expected to return a
      JSON object describing the decision.
    - Returns `(action, param, size_bytes)` like the other policies.

    Any reply that cannot be interpreted falls back to ALLOW at full size, so a
    misbehaving model can only make the simulation send more, never crash it.
    """
    def __init__(self,
                 call_model_fn,
                 bins_delta=(0.02, 0.1, 0.5, 1.0),
                 bins_q=(0, 1, 2, 4, 8),
                 agg_ms_choices=(5, 10, 20),
                 compress_size=48,
                 full_size=96,
                 use_cache=True):
        self.call_model_fn = call_model_fn
        self.bins_delta = tuple(sorted(bins_delta))
        self.bins_q = tuple(sorted(bins_q))
        self.agg_ms_choices = agg_ms_choices
        self.compress_size = compress_size
        self.full_size = full_size
        self.use_cache = use_cache
        self._cache = {}  # (delta_bin, q_bin) -> (action, param, size_bytes)

    # ---- helpers ----
    def _bin(self, x, edges):
        """Return the index (0..len(edges)) of the first edge strictly greater than `x`."""
        for i, e in enumerate(edges):
            if x < e:
                return i
        return len(edges)

    def _mk_prompt(self, delta, qdepth, dbin, qbin):
        """Render the decision prompt for one (delta, queue depth) observation."""
        # Keep the prompt small and deterministic so caching works well.
        return f"""
You are controlling message sending in a distributed shortest-path simulation.

Inputs:
- improvement_delta: {delta:.6f} (bin={dbin})
- link_queue_depth: {qdepth} (bin={qbin})

Goal:
- Minimize total bytes and messages.
- Keep accuracy high (prefer small error).
- Avoid adding much latency; only aggregate briefly if helpful.

Action space (choose exactly one):
- "ALLOW": send now with full size {self.full_size}.
- "COMPRESS": send now with smaller size {self.compress_size}.
- "AGGREGATE": delay briefly to coalesce; choose param in ms from {list(self.agg_ms_choices)}.
- "SKIP": do not send.

Respond with a SINGLE JSON object only:
{{"action": "ALLOW"|"COMPRESS"|"AGGREGATE"|"SKIP", "param": number|null, "size_bytes": number|null}}
Rules:
- If action is "AGGREGATE", "param" must be one of {list(self.agg_ms_choices)} and "size_bytes" must be null.
- If action is "COMPRESS", "size_bytes" must be {self.compress_size} and "param" must be null.
- If action is "ALLOW", "size_bytes" must be {self.full_size} and "param" must be null.
- If action is "SKIP", both "param" and "size_bytes" must be null.
"""

    def decide(self, *, delta, queue_depth, util=None):
        """
        Choose an action for one candidate message.

        Args:
          delta:        absolute change versus the last value sent; non-finite or
                        negative values are mapped into the highest delta bin.
          queue_depth:  current link queue length; None or negative is treated as 0.
          util:         accepted for interface compatibility, unused.

        Returns:
          (action, param, size_bytes) - validated by `_parse_and_validate`.
        """
        # sanitize inputs
        if not isfinite(delta) or delta < 0:
            # push into the highest bin so the first send isn't skipped
            delta = (self.bins_delta[-1] if self.bins_delta else 1.0) * 2.0
        if queue_depth is None or queue_depth < 0:
            queue_depth = 0

        # bin features to limit prompt variety
        dbin = self._bin(delta, self.bins_delta)
        qbin = self._bin(queue_depth, self.bins_q)
        cache_key = (dbin, qbin)

        if self.use_cache and cache_key in self._cache:
            return self._cache[cache_key]

        prompt = self._mk_prompt(delta, queue_depth, dbin, qbin)

        # ---- call the (real or dummy) model ----
        raw = self.call_model_fn(prompt)  # expected to return a JSON string

        # ---- parse & validate ----
        decision = self._parse_and_validate(raw)

        # cache
        if self.use_cache:
            self._cache[cache_key] = decision
        return decision

    def _load_json_object(self, raw):
        """
        Best-effort extraction of a JSON object from a model reply.

        Tries the whole (stripped) reply first; if that is not valid JSON, tries
        the span from the first '{' to the last '}' so replies wrapped in a
        markdown code fence or surrounded by prose are still usable.

        Returns the decoded dict, or None when no JSON object can be recovered.
        """
        try:
            text = raw.strip()
            try:
                obj = json.loads(text)
            except ValueError:
                start, end = text.find("{"), text.rfind("}")
                if start < 0 or end <= start:
                    return None
                obj = json.loads(text[start:end + 1])
        except Exception:
            # Non-string replies, undecodable bytes, malformed JSON, ...
            return None
        # Valid JSON that is not an object (list, number, string) is unusable.
        return obj if isinstance(obj, dict) else None

    def _parse_and_validate(self, raw):
        """
        Turn the raw model reply into a schema-conforming decision tuple.

        Sizes are always taken from this policy's configuration, never from the
        reply. For AGGREGATE, a numeric `param` outside `agg_ms_choices` is
        snapped to the nearest allowed choice.

        Returns ("ALLOW", None, full_size) for anything that cannot be used.
        """
        # default safe fallback: ALLOW full size
        fallback = ("ALLOW", None, self.full_size)

        obj = self._load_json_object(raw)
        if obj is None:
            return fallback

        action = obj.get("action")
        param = obj.get("param", None)

        # enforce schema
        if action == "ALLOW":
            return ("ALLOW", None, self.full_size)
        if action == "COMPRESS":
            return ("COMPRESS", None, self.compress_size)
        if action == "SKIP":
            return ("SKIP", None, None)
        if action == "AGGREGATE":
            # bool is a subclass of int; true/false is not a delay.
            is_number = isinstance(param, (int, float)) and not isinstance(param, bool)
            if not is_number or not isfinite(param) or not self.agg_ms_choices:
                return fallback
            if param in self.agg_ms_choices:
                return ("AGGREGATE", param, None)
            # snap to nearest allowed choice if model gave some other number
            closest = min(self.agg_ms_choices, key=lambda x: abs(x - param))
            return ("AGGREGATE", closest, None)

        return fallback



#-------------------------------------------------------------------------------------------------------------------------------------------------#



class MockLLMPolicy:
    """
    Deterministic stand-in for an LLM-backed policy.

    Exposes the same `decide(...)` interface as RulePolicy and returns an
    `(action, param, size_bytes)` tuple. `bins_delta` must contain exactly
    three thresholds; they are stored as `b0`, `b1`, `b2`.
    """

    def __init__(self,
                 bins_delta=(0.02, 0.1, 0.5),
                 q_hi=1,
                 agg_ms_choices=(10,15,20)):
        (self.b0, self.b1, self.b2) = bins_delta
        self.q_hi = q_hi
        self.agg_ms_choices = agg_ms_choices

    def decide(self, *, delta, queue_depth, util=None):
        """Congestion is checked first; otherwise the action depends on the size of `delta`."""
        if queue_depth >= self.q_hi:
            # Busy link: batch using the middle delay choice.
            verdict = ("AGGREGATE", self.agg_ms_choices[1], None)
        elif delta < self.b0:
            verdict = ("SKIP", None, None)
        elif delta < self.b1:
            verdict = ("COMPRESS", None, 48)
        elif delta < self.b2:
            # Moderate improvement: batch using the shortest delay choice.
            verdict = ("AGGREGATE", self.agg_ms_choices[0], None)
        else:
            verdict = ("ALLOW", None, 96)
        return verdict

#-------------------------------------------------------------------------------------------------------------------------------------------------#

#This is only to be able to use the LLM policy to test its connection before using the API

def dummy_llm_call(prompt: str) -> str:
    """
    Offline substitute for a model call, used to exercise LLMPolicy end to end.

    Reads the `(bin=N)` markers that follow the `improvement_delta:` and
    `link_queue_depth:` lines of the prompt (a missing marker counts as bin 0)
    and answers with a fixed JSON decision string:

      queue bin >= 2 -> AGGREGATE (15 ms)
      delta bin <= 0 -> SKIP
      delta bin == 1 -> COMPRESS
      otherwise      -> ALLOW
    """
    import re

    def _bin_after(pattern):
        # Crude extraction: first capture group as an int, 0 when absent.
        hit = re.search(pattern, prompt)
        return int(hit.group(1)) if hit else 0

    dbin = _bin_after(r"improvement_delta: .* \(bin=(\d+)\)")
    qbin = _bin_after(r"link_queue_depth: .* \(bin=(\d+)\)")

    if qbin >= 2:
        reply = '{"action":"AGGREGATE","param":15,"size_bytes":null}'
    elif dbin <= 0:
        reply = '{"action":"SKIP","param":null,"size_bytes":null}'
    elif dbin == 1:
        reply = '{"action":"COMPRESS","param":null,"size_bytes":48}'
    else:
        reply = '{"action":"ALLOW","param":null,"size_bytes":96}'
    return reply


# -------------------- Controller (orchestrator) --------------------
class Controller:
    """
    Builds and owns the simulated network: one Node per graph vertex and two
    directed Links per graph edge.

    Attributes:
      links     (u, v) -> Link for that direction
      nodes     node id -> Node
      inflight  number of messages currently counted as in transit
      policy    send policy shared by all nodes (None means baseline)
    """

    def __init__(self, env, graph, base_bps=3e5, base_prop_ms=2.0, policy=None):
        self.env = env
        self.links = {}
        self.inflight = 0
        self.policy = policy  # None => baseline
        self.nodes = {}
        for node_id in graph["nodes"]:
            self.nodes[node_id] = Node(env, node_id, self, self.policy)

        # Each direction of an edge gets its own randomly scaled bandwidth and
        # propagation delay. The four draws happen in a fixed order (both
        # bandwidths, then both delays) so a seeded run stays reproducible.
        for u, v, w in graph["edges"]:
            bps_uv, bps_vu = [random.uniform(0.5, 1.0) * base_bps for _ in range(2)]
            prop_uv, prop_vu = [random.uniform(0.2, 2.0) * base_prop_ms for _ in range(2)]

            forward = Link(self.env, u, v, bps_uv, prop_uv, capacity=64)
            backward = Link(self.env, v, u, bps_vu, prop_vu, capacity=64)
            for a, b, link in ((u, v, forward), (v, u, backward)):
                self.links[(a, b)] = link
            self.nodes[u].add_neighbor(v, forward, w)
            self.nodes[v].add_neighbor(u, backward, w)

    def size_model(self, kind):
        """Size in bytes of a message of the given kind; currently 96 for every kind."""
        return 96

# -------------------- Graph + exact Dijkstra (for MAE) --------------------
def make_grid(n_side=25, w=1.0, jitter=0.1):
    """
    Build an `n_side` x `n_side` grid graph with randomly jittered edge weights.

    Cell (i, j) gets the integer id `i * n_side + j`. Each cell is joined to the
    cell below it (i + 1) and the cell to its right (j + 1) when those exist,
    with a weight drawn uniformly from [w - jitter, w + jitter].

    Returns a dict with "nodes" (list of ids) and "edges" (list of (u, v, weight)).
    """
    edges = []
    for row in range(n_side):
        for col in range(n_side):
            here = row * n_side + col
            if row + 1 < n_side:
                below = here + n_side
                edges.append((here, below, random.uniform(w - jitter, w + jitter)))
            if col + 1 < n_side:
                right = here + 1
                edges.append((here, right, random.uniform(w - jitter, w + jitter)))
    return {"nodes": list(range(n_side * n_side)), "edges": edges}

def dijkstra_cpu(graph, source):
    """
    Exact single-source shortest paths on an undirected weighted graph.

    `graph` is a dict with "nodes" and "edges" ((u, v, weight) tuples); every
    edge is usable in both directions. Returns a dict mapping each node to its
    distance from `source`, with `math.inf` for nodes that cannot be reached.
    """
    neighbours = {node: [] for node in graph["nodes"]}
    for a, b, cost in graph["edges"]:
        neighbours[a].append((b, cost))
        neighbours[b].append((a, cost))

    best = dict.fromkeys(graph["nodes"], math.inf)
    best[source] = 0.0

    frontier = [(0.0, source)]
    while frontier:
        reached_at, here = heapq.heappop(frontier)
        if reached_at != best[here]:
            # Stale heap entry: a shorter distance was recorded after this push.
            continue
        for there, cost in neighbours[here]:
            candidate = reached_at + cost
            if candidate < best[there]:
                best[there] = candidate
                heapq.heappush(frontier, (candidate, there))
    return best

# -------------------- Optional: dynamic edge jitter during run --------------------
def jitter_edges_periodically(ctrl, graph, period_ms=120, scale=0.02):
    """
    Register a background process that perturbs edge weights while the run is live.

    Every `period_ms`, ten edges are drawn at random (with replacement) from
    `graph["edges"]`. Each gets a new weight equal to its weight in `graph`
    times a factor from [1 - scale, 1 + scale], floored at 0.05, written into
    the neighbour tables of both endpoints. `graph` itself is not modified.
    """
    def _store_weight(owner, peer, weight):
        # Replace the weight for owner -> peer, keeping the existing link object.
        table = ctrl.nodes[owner].neighbors
        if peer in table:
            link, _previous = table[peer]
            table[peer] = (link, weight)

    def loop():
        while True:
            for _ in range(10):  # edges touched per tick
                u, v, w = random.choice(graph["edges"])
                new_w = max(0.05, w * random.uniform(1 - scale, 1 + scale))
                _store_weight(u, v, new_w)
                _store_weight(v, u, new_w)
            yield ctrl.env.timeout(period_ms / 1000.0)

    ctrl.env.process(loop())

# -------------------- Run once (baseline or policy) --------------------

def run_once(seed=0, policy_kind="rule"):
    """
    Run one simulation on a 25x25 grid and report traffic and accuracy metrics.

    Args:
      seed:         seed for the `random` module (grid weights and link parameters).
      policy_kind:  "none"    - baseline, every update is sent at full size
                    "rule"    - RulePolicy
                    "mockllm" - MockLLMPolicy (deterministic LLM stand-in)
                    "llm"     - LLMPolicy backed by a real model call

    Returns:
      dict with "bytes", "msgs", "time" (simulated seconds), "avg_mae" (mean
      absolute error versus exact Dijkstra, averaged over sources), "max_q"
      and "actions" (per-action counts; empty for the baseline).

    Raises:
      ValueError: if `policy_kind` is not one of the values above.
    """
    random.seed(seed)
    env = simpy.Environment()
    G = make_grid(n_side=25, w=1.0, jitter=0.1)

    if policy_kind == "none":
        policy = None
    elif policy_kind == "rule":
        policy = RulePolicy(eps_small=0.05, eps_med=0.20, q_hi=1, agg_ms=20)
    elif policy_kind == "mockllm":
        # MockLLMPolicy, not LLMPolicy: LLMPolicy requires call_model_fn and
        # has no q_hi parameter, so the previous call raised TypeError.
        policy = MockLLMPolicy(bins_delta=(0.02, 0.1, 0.5), q_hi=1, agg_ms_choices=(10,15,20))
    elif policy_kind == "llm":
        # NOTE(review): `ollama_json` is not imported in this cell; presumably it
        # is defined by `%run providers.ipynb` - confirm.
        policy = LLMPolicy(
            call_model_fn = make_call_model_fn(
                llm_call = ollama_json,               # This is where the LLM is chosen
                full_size = 96,
                compress_size = 48,
                agg_ms_choices = (5,10,20),
                rps = 5.0, burst = 10,
            ),
            bins_delta = (0.01, 0.05, 0.2, 0.8),
            bins_q = (0, 1, 2, 4, 8),
            agg_ms_choices = (5,10,20),
        )
    else:
        raise ValueError("unknown policy_kind")
    ctrl = Controller(env, G, base_bps=8e5, base_prop_ms=1.0, policy=policy)

    # Multiple staggered sources -> overlapping waves
    sources = [0, 150, 300, 450, 600]
    stagger_ms = 25
    last_start_s = (len(sources) - 1) * stagger_ms / 1000.0

    def start_source_later(s, delay_ms):
        yield env.timeout(delay_ms / 1000.0)
        ctrl.nodes[s].init_source(s)
    for idx, s in enumerate(sources):
        env.process(start_source_later(s, delay_ms=idx * stagger_ms))  # 0, 25, 50, 75, 100 ms

    # if use_dynamic_jitter:
    #     jitter_edges_periodically(ctrl, G, period_ms=120, scale=0.02)

    # stop when quiescent or time cap reached
    def stopper(timeout_s=15.0, quiet_ms=80):
        quiet = 0.0
        while env.now < timeout_s:
            yield env.timeout(0.005)
            # Quiet time only counts once every source has been started, so an
            # idle gap before a late source cannot end the run prematurely.
            if ctrl.inflight == 0 and env.now >= last_start_s:
                quiet += 0.005
                if quiet >= (quiet_ms / 1000.0):
                    break
            else:
                quiet = 0.0

    # Run until the stopper finishes. A bare env.run() keeps going until no
    # events remain, which silently ignored the stopper's time cap.
    env.run(until=env.process(stopper()))

    # Metrics
    all_links = list(ctrl.links.values())
    total_bytes = sum(l.bytes_sent for l in all_links)
    total_msgs  = sum(l.messages_sent for l in all_links)
    max_q = max((l.max_q_depth for l in all_links), default=0)

    # Policy action summary (if enabled)
    action_totals = {}
    if ctrl.policy is not None:
        agg = defaultdict(int)
        for node in ctrl.nodes.values():
            for k, v in getattr(node, "action_counts", {}).items():
                agg[k] += v
        action_totals = dict(agg)
        print("Action summary:", action_totals)
        print("Max queue depth observed:", max_q)

    # Accuracy vs exact (per source), report average MAE
    maes = []
    for s in sources:
        exact = dijkstra_cpu(G, s)
        # .get() avoids inserting defaults into the node's defaultdict.
        approx = {u: ctrl.nodes[u].dist.get(s, math.inf) for u in G["nodes"]}
        mae = sum(abs(approx[u] - exact[u]) for u in G["nodes"]) / len(G["nodes"])
        maes.append(mae)
    avg_mae = sum(maes) / len(maes)

    return {
        "bytes": total_bytes,
        "msgs": total_msgs,
        "time": env.now,
        "avg_mae": avg_mae,
        "max_q": max_q,
        "actions": action_totals
    }


#ADDED STUFF TO CHECK


#### ADDED MORE STUFF TO CHECK
# -------------------- Main: compare baseline vs policy --------------------
# Driver: run the baseline, smoke-test the local Ollama model, then run the
# rule-based and LLM-backed policies and print each result without its
# per-action breakdown.
if __name__ == "__main__":
    # 1) Baseline: no policy, every update is sent.
    baseline = run_once(policy_kind="none")
    print("Baseline (no policy):", {k:v for k,v in baseline.items() if k!='actions'})

    # 2) Connectivity / latency check against Ollama before the LLM run.
    #    The model name is hard-coded here rather than read from OLLAMA_MODEL.
    import time
    t0=time.perf_counter()
    r = ollama.chat(model="phi3:mini", messages=[{"role":"user","content":'return only {"ok":true}'}],
                    options={"temperature":0.0, "num_predict":16})
    print("latency:", round(time.perf_counter()-t0,2), "s")
    print("reply:", r["message"]["content"])
    print(ollama.list())
    
    # 3) Hand-written rule policy.
    ruled = run_once(policy_kind="rule")
    print("RulePolicy:", {k:v for k,v in ruled.items() if k!='actions'})

    # 4) Real LLM policy. Despite its name, `mocked` holds the "llm" result.
    mocked   = run_once(policy_kind="llm")
    print("LLMPolicy:", {k:v for k,v in mocked.items() if k!='actions'})


Baseline (no policy): {'bytes': 29750496, 'msgs': 309901, 'time': 1.5149999999999897, 'avg_mae': 0.0, 'max_q': 65}
latency: 13.93 s
reply: ```json
{
  "ok": true
}
```

models=[Model(model='phi3:mini', modified_at=datetime.datetime(2025, 10, 31, 16, 53, 29, 524095, tzinfo=TzInfo(-25200)), digest='4f222292793889a9a40a020799cfd28d53f3e01af25d48e06c5e708610fc47e9', size=2176178913, details=ModelDetails(parent_model='', format='gguf', family='phi3', families=['phi3'], parameter_size='3.8B', quantization_level='Q4_0'))]
Action summary: {'ALLOW': 30674, 'COMPRESS': 59225, 'AGGREGATE': 15033, 'SKIP': 5621}
Max queue depth observed: 3
RulePolicy: {'bytes': 6539280, 'msgs': 97730, 'time': 0.3500000000000002, 'avg_mae': 0.009605144268427145, 'max_q': 3}


KeyboardInterrupt: 