<a href="https://colab.research.google.com/github/aminKMT/Train_scheduling_Net/blob/main/Train_Scheduling.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# ============================================================
# Capacity-constrained railway scheduler + edge-movement animation
# ============================================================
#!pip -q install networkx matplotlib tqdm --upgrade

import random, itertools, networkx as nx
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from collections import deque, defaultdict

# -------------------- USER-TUNEABLE PARAMETERS --------------------
CAP_RANGE       = (5, 10)          # inclusive capacity range per station
T_STOP          = 2                # dwell-time at a station  (min)
T_TRAVEL        = 3                # time to traverse one edge (min), must be >= 1
RIGHT_TRAINS    = 4                # right-bound trains: start at node 0, goal node 5 or 6
LEFT_TRAINS     = 3                # left-bound trains: start at node 5 or 6, goal node 0
ARRIVAL_RIGHT   = [0, 3, 5, 9]     # entry minutes   for right-bound trains
ARRIVAL_LEFT    = [1, 4, 6]        # entry minutes   for left-bound trains
SIM_HORIZON     = 250              # max minutes to simulate
FPS             = 5                # animation frames per second
RANDOM_SEED     = None             # int -> reproducible capacities/destinations
# -----------------------------------------------------------------

# Validate with real exceptions: `assert` statements vanish under `python -O`.
if RIGHT_TRAINS != len(ARRIVAL_RIGHT):
    raise ValueError(f'RIGHT_TRAINS={RIGHT_TRAINS} but '
                     f'{len(ARRIVAL_RIGHT)} entry times in ARRIVAL_RIGHT')
if LEFT_TRAINS != len(ARRIVAL_LEFT):
    raise ValueError(f'LEFT_TRAINS={LEFT_TRAINS} but '
                     f'{len(ARRIVAL_LEFT)} entry times in ARRIVAL_LEFT')
if T_TRAVEL < 1:
    # the animation divides by T_TRAVEL to interpolate a train along its edge
    raise ValueError(f'T_TRAVEL must be >= 1 minute, got {T_TRAVEL}')
if FPS <= 0:
    raise ValueError(f'FPS must be positive, got {FPS}')

# Only reseed when asked, so the default (None) leaves the global RNG untouched.
if RANDOM_SEED is not None:
    random.seed(RANDOM_SEED)

# 1️⃣  Build a directed multigraph of stations (parallel arcs = extra tracks)
G = nx.MultiDiGraph()
edges = [
    (0,1), (1,2), (2,3), (3,6),        # main line 0→1→2, plus 2→3→6
    (1,4), (3,4),                      # spurs into node 4 (no arc leaves node 4)
    (2,5), (2,6),                      # branches to terminals 5 and 6
    (1,2), (1,2),                      # two more parallel 1→2 arcs (multi-track)
    (5,6), (6,0),                      # return route used by left-bound trains
]
G.add_edges_from(edges)

# Fixed layout for nice plotting
pos = {0:(0,0), 1:(2,1.2), 2:(4,0), 3:(2,-1.2), 4:(2,0), 5:(6,1.2), 6:(6,-1.2)}
# Fail early if edges were edited to add a node the plot cannot place.
if not set(G).issubset(pos):
    raise ValueError(f'nodes without a plot position: {sorted(set(G) - set(pos))}')

# Random capacity per station; random.randint includes both CAP_RANGE bounds.
station_cap = {n: random.randint(*CAP_RANGE) for n in G.nodes()}

# 2️⃣  Train object ------------------------------------------------------------
class Train:
    """One train following a fixed shortest path from ``origin`` to ``dest``.

    A new train is in the ``'station'`` state at ``origin`` with a full
    ``T_STOP`` dwell ahead of it. The simulation loop advances ``i``,
    ``state`` and ``t_left``, and sets ``delay`` when the train leaves.

    Args:
        origin: start node.
        dest: goal node.
        entry_time: scheduled simulation minute at which the train enters.
        graph: network to route on. When omitted, the module-level ``G`` is
            used, which is the original behaviour.

    Raises:
        Whatever ``nx.shortest_path`` raises when ``dest`` is unreachable or a
        node is missing (e.g. ``networkx.NetworkXNoPath``).
    """
    _ids = itertools.count()   # class-wide counter -> ids in creation order

    def __init__(self, origin, dest, entry_time, graph=None):
        self.id     = next(Train._ids)
        # node sequence; unweighted, i.e. fewest hops
        self.path   = nx.shortest_path(G if graph is None else graph, origin, dest)
        self.i      = 0               # index in path of current / last station
        self.state  = 'station'       # 'station' (dwelling or blocked) | 'travel'
        self.t_left = T_STOP          # mins remaining in current state
        self.start_edge = self.end_edge = None   # endpoints of the edge being run
        self.entry  = entry_time
        self.delay  = None            # set on exit: exit minute - entry_time

    def __repr__(self):
        return f'T{self.id}'

# 3️⃣  Create train queues ------------------------------------------------------
# Right-bound trains are built first (so they get the lowest ids), then the
# left-bound ones; each queue is ordered by scheduled entry minute.
right_q = deque(
    sorted(
        (Train(0, random.choice([5, 6]), minute) for minute in ARRIVAL_RIGHT),
        key=lambda train: train.entry,
    )
)
left_q = deque(
    sorted(
        (Train(random.choice([5, 6]), 0, minute) for minute in ARRIVAL_LEFT),
        key=lambda train: train.entry,
    )
)

active, finished = [], []               # trains in the network / already exited
occ = defaultdict(list)                 # node -> trains physically at that node
reserved_capacity = defaultdict(int)    # node -> slots held by trains en route to it

# 4️⃣  Time-stepped simulation (one tick = one minute) -------------------------
timeline = []   # one snapshot per simulated minute; frame index == minute t


def _has_room(node):
    """Return True if *node* can accept one more train.

    Counts the trains physically at the node plus the trains already en route
    that have reserved a slot there.
    """
    return len(occ[node]) + reserved_capacity[node] < station_cap[node]


for t in range(SIM_HORIZON+1):

    # ── Inject trains whose entry time has come and whose origin has room ──
    # A blocked train stays at the head of its queue and retries next minute.
    # Its delay is still measured from its scheduled entry time.
    for queue in (right_q, left_q):
        while queue and queue[0].entry <= t and _has_room(queue[0].path[0]):
            tr = queue.popleft()
            occ[tr.path[0]].append(tr)
            active.append(tr)

    # ── Advance every active train (iterate a copy: trains may leave) ──────
    for tr in list(active):
        tr.t_left -= 1

        if tr.state == 'station':
            if tr.t_left > 0:
                continue                      # still dwelling

            here = tr.path[tr.i]

            # Destination reached ➜ leave the network
            if tr.i == len(tr.path)-1:
                occ[here].remove(tr)
                tr.delay = t - tr.entry       # entry-to-exit time, incl. waiting
                active.remove(tr)
                finished.append(tr)
                continue

            nxt = tr.path[tr.i+1]

            # capacity test *including* reservations for arriving trains
            if _has_room(nxt):
                # depart: reserve a slot at nxt so nobody else can take it
                reserved_capacity[nxt] += 1
                occ[here].remove(tr)
                tr.start_edge, tr.end_edge = here, nxt
                tr.state, tr.t_left = 'travel', T_TRAVEL
            else:
                tr.t_left = 1                 # blocked: retry next minute

        elif tr.state == 'travel':
            if tr.t_left > 0:
                continue                      # en-route

            # arrive at next station: convert the reservation into occupancy
            tr.state, tr.i = 'station', tr.i+1
            tr.t_left      = T_STOP
            here = tr.path[tr.i]
            reserved_capacity[here] -= 1
            occ[here].append(tr)

    # ── store snapshot for animation ───────────────────────────────────────
    snap = {
        'node_loads': {n: len(occ[n]) for n in G.nodes()},
        'train_states': {tr.id: (
            tr.state,
            (tr.start_edge, tr.end_edge, tr.t_left) if tr.state=='travel' else
            (tr.path[tr.i],)                       # station id
        ) for tr in active}
    }
    timeline.append(snap)

    if not active and not right_q and not left_q:
        break

# Summary: avoid ZeroDivisionError when nothing finished, and don't claim
# completion when the horizon was hit (deadlock or horizon too short).
unfinished = len(active) + len(right_q) + len(left_q)
mean_delay = (sum(tr.delay for tr in finished) / len(finished)
              if finished else float('nan'))
if unfinished:
    print(f'Stopped at the {SIM_HORIZON}-min horizon with {unfinished} '
          f'train(s) unfinished (deadlock or horizon too short);  '
          f'mean delay of the {len(finished)} finished = {mean_delay:.2f} min')
else:
    print(f'Finished in {t} min;  mean delay = {mean_delay:.2f} min')

# 5️⃣  Build animation ----------------------------------------------------------
fig, ax = plt.subplots(figsize=(8, 4))

# Static background: tracks, then stations (bigger marker = more capacity),
# then station ids.
nx.draw_networkx_edges(G, pos, ax=ax, width=2, alpha=.5)
nx.draw_networkx_nodes(
    G,
    pos,
    ax=ax,
    node_size=[350 + 80 * station_cap[node] for node in G.nodes()],
    node_color='lightsteelblue',
)
nx.draw_networkx_labels(G, pos, ax=ax, font_weight='bold')
ax.set_axis_off()

# Per-node text just above each station; holds the live train count.
load_text = {
    node: ax.text(
        pos[node][0], pos[node][1] + 0.25, '',
        ha='center', va='center',
        fontdict={'size': 9, 'weight': 'bold'},
    )
    for node in G.nodes()
}

# Train marker and id-label artists, keyed by train id; filled by get_dot().
train_dots = {}
id_labels = {}

def get_dot(tr_id):
    """Return the ``(marker, id_label)`` artists for train *tr_id*.

    The artists are created on first use. The right-bound queue is built
    before the left-bound one, so ids below ``RIGHT_TRAINS`` are right-bound.
    This assumes no other ``Train`` was created earlier in the session.
    Right-bound trains are red with a '>' marker; left-bound trains are green
    with '<', so each marker points the way the train is heading.
    """
    if tr_id not in train_dots:
        right_bound = tr_id < RIGHT_TRAINS
        color  = 'tab:red' if right_bound else 'tab:green'
        marker = '>' if right_bound else '<'
        train_dots[tr_id] = ax.plot([], [], marker=marker, ms=12, color=color)[0]
        id_labels[tr_id]  = ax.text(0, 0, str(tr_id), fontsize=7,
                                    ha='center', va='center')
    return train_dots[tr_id], id_labels[tr_id]

# Simulation clock, placed near the top-left corner in axes-fraction coordinates.
time_txt = ax.text(
    0.02, 0.95, '',
    transform=ax.transAxes,
    fontsize=12,
)

def init():
    """Animation initialiser: return every animated artist that exists so far."""
    artists = [*load_text.values(), *train_dots.values(), *id_labels.values()]
    artists.append(time_txt)
    return artists

def _train_xy(info):
    """Map a snapshot entry ``(state, details)`` to plot coordinates.

    A train at a station sits on that node. Any other state is treated as
    travel along ``u -> v``. The position is interpolated linearly by elapsed
    travel time: remaining ``T_TRAVEL`` gives u, remaining 0 gives v.
    """
    state, details = info
    if state == 'station':
        return pos[details[0]]
    u, v, remaining = details
    frac = (T_TRAVEL - remaining) / T_TRAVEL
    return ((1 - frac) * pos[u][0] + frac * pos[v][0],
            (1 - frac) * pos[u][1] + frac * pos[v][1])


def update(frame):
    """Render simulated minute *frame*: node loads, train positions, clock."""
    snap = timeline[frame]

    for node, label in load_text.items():
        label.set_text(f"{snap['node_loads'][node]:d}")

    # Place every train present in this minute, creating artists on demand.
    shown = set()
    for tr_id, info in snap['train_states'].items():
        dot, lbl = get_dot(tr_id)
        x, y = _train_xy(info)
        dot.set_data([x], [y])
        lbl.set_position((x, y - 0.15))
        dot.set_visible(True)
        lbl.set_visible(True)
        shown.add(tr_id)

    # Trains not in this snapshot (not yet injected or already finished).
    for tr_id, dot in train_dots.items():
        if tr_id in shown:
            continue
        dot.set_visible(False)
        id_labels[tr_id].set_visible(False)

    time_txt.set_text(f'Time = {frame} min')
    artists = [*load_text.values(), *train_dots.values(), *id_labels.values()]
    artists.append(time_txt)
    return artists

ani = FuncAnimation(
    fig,
    update,
    frames=len(timeline),      # one frame per simulated minute
    init_func=init,
    interval=1000 / FPS,       # milliseconds between frames
    blit=True,
    repeat=False,
)

plt.close(fig)  # the animation figure; avoids a duplicate static plot in Colab
from IPython.display import HTML
HTML(ani.to_jshtml())   # last expression of the cell -> rendered inline


Finished in 25 min;  mean delay = 12.43 min


In [None]:
4