In [None]:
# Mount Google Drive at /content/drive. The google.colab package is imported
# here, so this cell presumably only runs inside a Colab runtime.
# NOTE(review): DATASETS_DIR (defined in a later cell) is a relative path and is
# not derived from this mount point -- confirm where the dataset actually lives.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import os
import json
import random
import glob
from collections import deque
from typing import List, Tuple

import numpy as np
import networkx as nx
import torch
import torch.nn as nn
import torch.optim as optim


In [None]:
# Root folder scanned by load_all_snapshots(): it must contain load_* subfolders,
# each holding snapshot_*.json files.
# NOTE(review): this is a relative path, so it is resolved against the current
# working directory rather than the mounted Drive -- confirm this is intended.
DATASETS_DIR = "datasets"   # folder containing load_* subfolders with snapshot_*.json
# Device used for every tensor and module created in this notebook.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

Using device: cpu


In [None]:
# RL hyperparams
# Discount factor applied to the bootstrapped next-state estimate.
GAMMA = 0.99
# Learning rate handed to the Adam optimizers.
LR = 1e-3
# Number of transitions sampled from the replay buffer per training step.
BATCH_SIZE = 64
# Default maximum number of transitions kept by ReplayBuffer (oldest are dropped).
BUFFER_CAPACITY = 20000
# Epsilon schedule: starts at EPS_START, is multiplied by EPS_DECAY after every
# episode, and is never allowed below EPS_END.
EPS_START = 1.0
EPS_END = 0.05
EPS_DECAY = 0.995
# The target network is re-synchronised with the online network at this period.
TARGET_UPDATE_EVERY = 5   # episodes
# Total number of training episodes.
MAX_EPISODES = 800
# Upper bound on agent steps within one episode.
MAX_STEPS_PER_EP = 30

# Environment / network params
# Capacity assigned to a topology edge whose metadata has no "capacity" entry
# (a fallback only; capacities present in a snapshot take precedence).
DEFAULT_CAPACITY = 100.0
CONGESTION_THRESHOLD = 0.9  # link load/capacity > threshold considered congested

In [None]:
# Reward weights (combined in RerouteEnv.step)
# Multiplies the mean normalized link load, which the environment uses as a delay proxy.
W_DELAY = -1.0
# Multiplies the number of links whose load/capacity exceeds CONGESTION_THRESHOLD.
W_CONGEST = -5.0
# Bonus added when the agent's current node equals the flow's destination.
W_SUCCESS = +50.0
# Reward returned (and the episode ended) for a non-neighbor hop, an unreachable
# destination, or a path that would exceed link capacity.
W_INVALID = -50.0
# Multiplies the network-wide sum of load/capacity over all links.
W_BW      = -2.0


In [None]:
# ---------------------------
# Utilities: load snapshots
# ---------------------------
def load_all_snapshots(dataset_root: str) -> List[dict]:
    """
    Read every snapshot JSON file found below ``dataset_root``.

    The function looks for ``load_*`` entries directly under ``dataset_root`` and,
    inside each of them, for ``snapshot_*.json`` files. Both levels are visited
    in sorted order, and the decoded JSON documents are returned in that order.

    Raises:
        FileNotFoundError: ``dataset_root`` is not an existing directory.
        RuntimeError: no snapshot file was found.
    """
    if not os.path.isdir(dataset_root):
        raise FileNotFoundError(f"Dataset dir not found: {dataset_root}")

    # Collect the snapshot paths first: load_* folders in sorted order, and the
    # snapshot files of each folder in sorted order.
    load_dirs = sorted(glob.glob(os.path.join(dataset_root, "load_*")))
    snapshot_paths = [
        path
        for load_dir in load_dirs
        for path in sorted(glob.glob(os.path.join(load_dir, "snapshot_*.json")))
    ]

    def _read_json(path):
        # Edge keys inside a snapshot may be strings; they are normalized by the consumer.
        with open(path, "r", encoding="utf-8") as handle:
            return json.load(handle)

    loaded = [_read_json(path) for path in snapshot_paths]
    if not loaded:
        raise RuntimeError("No snapshots found in dataset path.")
    return loaded


In [None]:
# ---------------------------
# Environment
# ---------------------------
class RerouteEnv:
    """
    Flow-rerouting environment that uses recorded snapshots as starting states.

    Each episode handles one flow (picked from a random snapshot).
    Action space: choose the next hop among the successors of ``current_node``.
    State vector: [normalized link loads] + [one-hot current node] + [one-hot dest node]

    NOTE(review): reset() rebuilds the graph from the chosen snapshot, so the
    state length only stays constant if every snapshot describes the same
    topology -- confirm against the dataset.
    """

    def __init__(self, snapshots: List[dict]):
        """Build the canonical graph from the first snapshot and zero all runtime state."""
        self.snapshots = snapshots
        # Use first snapshot to build canonical graph (node set)
        base = snapshots[0]
        self.G = self._build_graph_from_snapshot(base)
        self.num_nodes = self.G.number_of_nodes()
        self.edges_list = list(self.G.edges())
        self.num_links = len(self.edges_list)

        # Runtime state; everything below is re-initialised by reset().
        self.link_load = {e: 0.0 for e in self.edges_list}
        # Link loads WITHOUT the tentative allocation of the flow being routed.
        self._background_load = dict(self.link_load)
        self.current_flow = None  # dict with id, src, dst, bw, time_left
        self.partial_path = []
        self.current_node = None
        self.destination = None
        self.max_steps = MAX_STEPS_PER_EP
        self.step_count = 0

    @staticmethod
    def _parse_edge_key(edge_key):
        """
        Convert an edge key into an ``(u, v)`` tuple of ints.

        Accepted spellings include "(u, v)", "(u,v)", "u,v", "u v", "[u, v]"
        and two-element tuples/lists. Returns None when the key cannot be read
        as exactly two integers. The key is parsed textually and never
        evaluated, because snapshot files are external data.
        """
        if isinstance(edge_key, (tuple, list)):
            parts = [str(p) for p in edge_key]
        else:
            text = str(edge_key)
            for ch in "()[],":
                text = text.replace(ch, " ")
            parts = text.split()
        if len(parts) != 2:
            return None
        try:
            return int(parts[0]), int(parts[1])
        except ValueError:
            return None

    def _build_graph_from_snapshot(self, snap: dict) -> nx.DiGraph:
        """
        Build a directed graph from ``snap["topology"]``.

        The topology is read as a mapping from an edge key to a metadata dict.
        Keys that cannot be parsed are skipped. Nodes ``0..max_node`` are always
        created so node ids can index the one-hot parts of the state vector.
        Edges get a ``capacity`` attribute (DEFAULT_CAPACITY when absent).
        Without a topology, an edgeless graph with one node per row of
        ``snap["traffic_matrix"]`` is returned.
        """
        G = nx.DiGraph()
        topo = snap.get("topology", None)
        if topo is None:
            # Fallback: nodes only, no edge information available.
            n = len(snap.get("traffic_matrix", []))
            G.add_nodes_from(range(n))
            return G

        # Parse every key once and keep the topology's own ordering.
        parsed_edges = []
        for edge_key, meta in topo.items():
            edge = self._parse_edge_key(edge_key)
            if edge is not None:
                parsed_edges.append((edge, meta))

        node_ids = {node for edge, _ in parsed_edges for node in edge}
        max_node = max(node_ids) if node_ids else 0
        G.add_nodes_from(range(max_node + 1))

        for (u, v), meta in parsed_edges:
            cap = float(meta.get("capacity", DEFAULT_CAPACITY))
            G.add_edge(u, v, capacity=cap)
        return G

    def _flow_candidates(self, active_conns) -> List[dict]:
        """
        Return the flows of a snapshot as dicts with src/dst/bw entries.

        Dict entries are used as they are. When there are none, entries are
        read as ``(id, path, bw, flag)`` tuples and converted, with src/dst
        taken from the ends of ``path``. Malformed entries are skipped.
        """
        candidates = [c for c in active_conns if isinstance(c, dict)]
        if candidates:
            return candidates
        for item in active_conns:
            try:
                cid, path, bw, _unused_flag = item
                src = path[0]
                dst = path[-1]
            except (TypeError, ValueError, IndexError, KeyError):
                continue
            candidates.append({"id": cid, "src": src, "dst": dst, "bw": bw, "time_left": 10.0})
        return candidates

    def _synthetic_flow(self) -> dict:
        """Create a random flow between two distinct nodes of the current graph."""
        nodes = list(self.G.nodes())
        if len(nodes) < 2:
            raise RuntimeError("Cannot create a synthetic flow: the graph needs at least two nodes.")
        src = random.choice(nodes)
        dst = random.choice([n for n in nodes if n != src])
        bw = max(1.0, np.random.rand() * 10.0)
        return {"id": 0, "src": src, "dst": dst, "bw": bw, "time_left": 10.0}

    def reset(self):
        """
        Choose a random snapshot and pick a random active flow from it.
        Initialize graph link loads and flow parameters.
        Return initial state vector.
        """
        snap = random.choice(self.snapshots)
        # Rebuild the graph so capacities come from this snapshot.
        self.G = self._build_graph_from_snapshot(snap)
        self.num_nodes = self.G.number_of_nodes()
        self.edges_list = list(self.G.edges())
        self.num_links = len(self.edges_list)

        # Initial link loads come from the snapshot when present. The exact
        # "(u, v)" spelling is tried first; any other spelling accepted by
        # _parse_edge_key is used as a fallback.
        raw_link_load = snap.get("link_load", {})
        load_by_edge = {}
        if isinstance(raw_link_load, dict):
            for raw_key, raw_val in raw_link_load.items():
                edge = self._parse_edge_key(raw_key)
                if edge is not None:
                    load_by_edge.setdefault(edge, raw_val)
        self.link_load = {}
        for e in self.edges_list:
            canonical_key = str(e)
            if canonical_key in raw_link_load:
                value = raw_link_load[canonical_key]
            else:
                value = load_by_edge.get(e, 0.0)
            self.link_load[e] = float(value)
        # Remember the loads before this episode's flow is placed; step() always
        # allocates on top of this baseline.
        self._background_load = dict(self.link_load)

        # Choose an active connection to reroute. Fall back to a synthetic flow
        # when the snapshot lists none, or none of its entries is usable.
        active_conns = snap.get("active_connections", [])
        candidates = self._flow_candidates(active_conns) if active_conns else []
        if candidates:
            self.current_flow = random.choice(candidates)
        else:
            self.current_flow = self._synthetic_flow()

        self.partial_path = [self.current_flow["src"]]
        self.current_node = self.current_flow["src"]
        self.destination = self.current_flow["dst"]
        # episode step counter
        self.step_count = 0

        return self._build_state()

    def _normalized_loads(self) -> np.ndarray:
        """Return load/capacity for every link, in ``edges_list`` order (float32)."""
        return np.array(
            [self.link_load[e] / (self.G[e[0]][e[1]]["capacity"] + 1e-9) for e in self.edges_list],
            dtype=np.float32,
        )

    def _build_state(self) -> np.ndarray:
        """
        Build state vector:
            - normalized link loads, clipped to [0, 1] (len = num_links)
            - one-hot current node (num_nodes)
            - one-hot destination node (num_nodes)
        """
        loads = np.clip(self._normalized_loads(), 0.0, 1.0)
        cur_onehot = np.zeros(self.num_nodes, dtype=np.float32)
        dst_onehot = np.zeros(self.num_nodes, dtype=np.float32)
        cur_onehot[self.current_node] = 1.0
        dst_onehot[self.destination] = 1.0
        return np.concatenate([loads, cur_onehot, dst_onehot]).astype(np.float32)

    def action_space(self) -> List[int]:
        """Return list of neighbor nodes (possible next-hops) from current_node"""
        return list(self.G.successors(self.current_node))

    def compute_total_bandwidth_utilization(self):
        """
        Return the network-wide bandwidth utilisation, i.e. the sum of
        load/capacity over all links.
        """
        total = 0.0
        for (u, v) in self.edges_list:
            load = self.link_load.get((u, v), 0.0)
            cap = float(self.G[u][v].get("capacity", 0.0))
            # Same epsilon as the state vector; also guards against a zero capacity.
            total += load / (cap + 1e-9)
        return total

    def step(self, action_next_node: int) -> Tuple[np.ndarray, float, bool, dict]:
        """
        Move the flow one hop and re-evaluate its complete path.

        action_next_node: node id chosen as next hop (must be a successor of
        the current node).

        Procedure:
          - extend the partial path with the chosen node
          - complete it with a fewest-hop path from that node to the destination
          - check that every link of the full path can carry the flow's bw on
            top of the background load
          - if ok: place the flow on the full path (replacing the allocation of
            the previous step, so the flow is counted once) and compute the
            reward from mean load, congested-link count and total utilisation;
            reaching the destination adds W_SUCCESS and ends the episode
          - if not ok: W_INVALID and the episode ends; ``info["reason"]`` is one
            of "not_neighbor", "no_path", "capacity"

        Returns ``(next_state, reward, done, info)``.
        """
        self.step_count += 1
        info = {}
        bw = float(self.current_flow.get("bw", 1.0))

        # The chosen hop must be an outgoing neighbor of the current node.
        if not self.G.has_edge(self.current_node, action_next_node):
            return self._build_state(), W_INVALID, True, {"reason": "not_neighbor"}

        # Complete the route. NodeNotFound covers a destination that is not part
        # of this snapshot's graph.
        try:
            rest = nx.shortest_path(self.G, action_next_node, self.destination)
        except (nx.NetworkXNoPath, nx.NodeNotFound):
            return self._build_state(), W_INVALID, True, {"reason": "no_path"}

        candidate_path = self.partial_path + [action_next_node]
        full_path = candidate_path + rest[1:]

        # Count how often each link is traversed (a looping path can reuse one).
        traversals = {}
        for link in zip(full_path, full_path[1:]):
            traversals[link] = traversals.get(link, 0) + 1

        # Capacity check against the background load only: the allocation made
        # by an earlier step belongs to this same flow and is about to be replaced.
        for link, times in traversals.items():
            cap = float(self.G[link[0]][link[1]]["capacity"])
            background = float(self._background_load.get(link, 0.0))
            if background + bw * times > cap + 1e-9:
                return self._build_state(), W_INVALID, True, {"reason": "capacity"}

        # Place the flow exactly once: start from the baseline and add the
        # current full path.
        self.link_load = dict(self._background_load)
        for link, times in traversals.items():
            self.link_load[link] = self.link_load.get(link, 0.0) + bw * times

        # The move is valid: commit it.
        self.partial_path = candidate_path
        self.current_node = action_next_node

        # Metrics (simple approximations).
        loads = self._normalized_loads()
        avg_delay = float(np.mean(loads))  # delay proxy: mean normalized load
        congest_count = int(np.sum(loads > CONGESTION_THRESHOLD))
        total_bw_util = self.compute_total_bandwidth_utilization()

        reward = (
            W_DELAY * avg_delay +
            W_CONGEST * congest_count +
            W_BW * total_bw_util
        )

        if self.current_node == self.destination:
            reward += W_SUCCESS
            done = True
            info["reached"] = True
        elif self.step_count >= self.max_steps:
            done = True
            info["timeout"] = True
        else:
            done = False

        return self._build_state(), float(reward), done, info


In [None]:
# ---------------------------
# Replay Buffer
# ---------------------------
class ReplayBuffer:
    """Bounded FIFO store of (state, action, reward, next_state, done) transitions."""

    def __init__(self, capacity=BUFFER_CAPACITY):
        # A deque with maxlen silently drops the oldest transition once full.
        self.buffer = deque(maxlen=capacity)

    def add(self, s, a, r, s2, done):
        """Append one transition."""
        transition = (s, a, r, s2, done)
        self.buffer.append(transition)

    def sample(self, batch_size):
        """
        Draw up to ``batch_size`` transitions without replacement.

        Returns five arrays: stacked states, actions, float32 rewards,
        stacked next states and float32 done flags.
        """
        draw_count = min(len(self.buffer), batch_size)
        picked = random.sample(self.buffer, draw_count)
        states, actions, rewards, next_states, dones = zip(*picked)
        return (
            np.stack(states),
            np.array(actions),
            np.array(rewards, dtype=np.float32),
            np.stack(next_states),
            np.array(dones, dtype=np.float32),
        )

    def size(self):
        """Number of transitions currently stored."""
        return len(self.buffer)


In [None]:
# ---------------------------
# Q-Network (MLP)
# ---------------------------
class QNet(nn.Module):
    """
    MLP trunk mapping a state vector to an embedding.

    Layer widths are state_dim -> hidden_dim -> hidden_dim//2 -> hidden_dim//4,
    with a ReLU after each of the first two linear layers. No action head is
    defined here: forward() returns the hidden_dim//4 embedding.
    """

    def __init__(self, state_dim: int, hidden_dim: int = 256):
        super().__init__()
        widths = (state_dim, hidden_dim, hidden_dim // 2, hidden_dim // 4)
        layers = []
        last_index = len(widths) - 2
        for index, (fan_in, fan_out) in enumerate(zip(widths, widths[1:])):
            layers.append(nn.Linear(fan_in, fan_out))
            # No activation after the final projection.
            if index != last_index:
                layers.append(nn.ReLU())
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the embedding of ``x``."""
        embedding = self.net(x)
        return embedding

In [None]:
# ---------------------------
# Training helpers
# ---------------------------
def train_q_network(qnet_base: QNet, head_out_dim_func, target_base: QNet, optimizer, buffer: ReplayBuffer, batch_size=BATCH_SIZE):
    """
    Run one TD(0) update on a batch sampled from ``buffer``.

    The base networks only produce an embedding, so a small value head
    (Linear -> ReLU -> Linear(…, 1)) maps it to one scalar per state. The head
    is created on the first call and stored on the networks as the submodule
    ``value_head``. Consequences:
      * its weights persist across calls and are actually trained (previously a
        freshly initialised head and a fresh optimizer were created on every
        call, so nothing learned by the head survived the call);
      * ``qnet_base.state_dict()`` contains ``value_head.*`` entries after the
        first update, and copying that state dict into ``target_base`` also
        synchronises the target head.

    Limitation: the estimate depends on the state only. The stored action is
    not used, because the number of possible next hops differs per state.

    Args:
        qnet_base: online network being trained.
        head_out_dim_func: unused; kept for interface compatibility.
        target_base: target network used for the bootstrapped estimate.
        optimizer: optimizer stepping ``qnet_base``. Parameters of ``qnet_base``
            it does not track yet (the value head) are added as a new parameter
            group. If None, an Adam optimizer is created once and cached on
            ``qnet_base``.
        buffer: replay buffer to sample from.
        batch_size: number of transitions per update.

    Returns:
        The loss as a float, or None when the buffer holds fewer than
        ``batch_size`` transitions.
    """
    if buffer.size() < batch_size:
        return None
    s_batch, _actions, r_batch, s2_batch, d_batch = buffer.sample(batch_size)

    s_batch_t = torch.tensor(s_batch, dtype=torch.float32, device=DEVICE)
    s2_batch_t = torch.tensor(s2_batch, dtype=torch.float32, device=DEVICE)
    r_batch_t = torch.tensor(r_batch, dtype=torch.float32, device=DEVICE)
    d_batch_t = torch.tensor(d_batch, dtype=torch.float32, device=DEVICE)

    emb_s = qnet_base(s_batch_t)            # (B, emb_dim)

    def _new_value_head():
        return nn.Sequential(
            nn.Linear(emb_s.shape[1], 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        ).to(DEVICE)

    # Create the persistent heads on first use. The online network is handled
    # first so that passing the same object as online and target also works.
    value_head = getattr(qnet_base, "value_head", None)
    if value_head is None:
        value_head = _new_value_head()
        qnet_base.value_head = value_head
    target_head = getattr(target_base, "value_head", None)
    if target_head is None:
        target_head = _new_value_head()
        target_head.load_state_dict(value_head.state_dict())
        target_base.value_head = target_head

    if optimizer is None:
        optimizer = getattr(qnet_base, "_fallback_optimizer", None)
        if optimizer is None:
            optimizer = optim.Adam(qnet_base.parameters(), lr=LR)
            qnet_base._fallback_optimizer = optimizer

    # Make sure the optimizer also steps parameters created after it was built.
    tracked = {id(p) for group in optimizer.param_groups for p in group["params"]}
    untracked = [p for p in qnet_base.parameters() if id(p) not in tracked]
    if untracked:
        optimizer.add_param_group({"params": untracked})

    # squeeze(-1) keeps the batch dimension even when B == 1.
    q_pred = value_head(emb_s).squeeze(-1)  # (B,)

    # Bootstrapped target: r + gamma * V_target(s2), zeroed for terminal transitions.
    with torch.no_grad():
        q_next = target_head(target_base(s2_batch_t)).squeeze(-1)
        q_target = r_batch_t + GAMMA * q_next * (1.0 - d_batch_t)

    loss = nn.MSELoss()(q_pred, q_target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


In [None]:
# ---------------------------
# Main training loop
# ---------------------------
def train(seed=None):
    """
    Train on the snapshots under DATASETS_DIR and save the network weights.

    Steps: load the snapshots, build the environment and the online/target
    networks, then run MAX_EPISODES episodes. Every transition is stored in a
    replay buffer and followed by one call to train_q_network. The target
    network is synchronised every TARGET_UPDATE_EVERY episodes and epsilon is
    decayed after every episode. The online network's state dict is written to
    "qnet_base.pth" at the end.

    Args:
        seed: optional integer used to seed ``random``, NumPy and PyTorch so a
            run can be repeated. None (the default) leaves the generators
            untouched.

    Raises:
        ValueError: a reset produced a state whose length differs from the
            network input size (snapshots with different topologies).

    NOTE: the network produces no per-action scores, so the "exploit" branch
    below cannot rank neighbors and also samples uniformly; the policy is
    therefore random regardless of epsilon.
    """
    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    # load snapshots
    snapshots = load_all_snapshots(DATASETS_DIR)
    env = RerouteEnv(snapshots)

    state_dim = env.num_links + env.num_nodes * 2
    qnet_base = QNet(state_dim).to(DEVICE)
    target_base = QNet(state_dim).to(DEVICE)
    target_base.load_state_dict(qnet_base.state_dict())

    buffer = ReplayBuffer()
    optimizer = optim.Adam(qnet_base.parameters(), lr=LR)

    epsilon = EPS_START

    for ep in range(1, MAX_EPISODES + 1):
        state = env.reset()
        # The networks were sized from the first snapshot; fail with a clear
        # message instead of an obscure shape error further down.
        if state.shape[0] != state_dim:
            raise ValueError(
                f"State length {state.shape[0]} does not match network input {state_dim}; "
                "all snapshots must share the same topology."
            )
        ep_reward = 0.0

        for _ in range(MAX_STEPS_PER_EP):
            neighbors = env.action_space()
            if not neighbors:
                # no outgoing edges -> done
                break

            # epsilon-greedy over neighbors
            if random.random() < epsilon:
                chosen = random.choice(neighbors)
            else:
                # TODO: rank neighbors with learned action values once the
                # network exposes them. Until then this branch samples uniformly
                # too (the previous code scored the state with throwaway,
                # randomly initialised layers and ignored the result).
                chosen = random.choice(neighbors)

            s2, r, done, info = env.step(chosen)
            buffer.add(state, chosen, r, s2, done)
            state = s2
            ep_reward += r

            # one gradient step per environment step
            train_q_network(qnet_base, None, target_base, optimizer, buffer, BATCH_SIZE)

            if done:
                break

        # update target network periodically
        if ep % TARGET_UPDATE_EVERY == 0:
            target_base.load_state_dict(qnet_base.state_dict())

        # decay epsilon
        epsilon = max(EPS_END, epsilon * EPS_DECAY)

        if ep % 10 == 0 or ep == 1:
            print(f"Episode {ep:4d}  reward={ep_reward:.2f}  eps={epsilon:.3f} buffer={buffer.size()}")

    # save model
    torch.save(qnet_base.state_dict(), "qnet_base.pth")
    print("Training finished. Model saved to qnet_base.pth")

In [None]:

# Entry point: start training when this code runs as the main module.
if __name__ == "__main__":
    train()