# In [None]:
# ---------------------------
# pip installs if needed
# ---------------------------
# !pip install ortools pandas matplotlib

import pandas as pd
import datetime as dt
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from ortools.sat.python import cp_model
import math
import os

# ---------------------------
# Configuration
# ---------------------------
CSV_PATH = "./station_trains_30.csv"  # Update if needed
# NOTE(review): HORIZON_MINUTES is not read anywhere in the visible code; the
# model horizon is derived from the data instead.
HORIZON_MINUTES = 180
BASE_BUFFER = 1  # flat per-train buffer (minutes) added to service time
UNCERTAINTY_MARGIN = 1  # buffer minutes added per unit of eta_uncertainty
LOOP_PENALTY = 5  # objective cost for routing a train through a loop track
MAX_SOLVE_TIME_S = 30  # CP-SAT wall-clock limit in seconds

# Track id -> accepted direction ("top2bottom", "bottom2top" or "both"),
# capacity, and an optional is_loop flag.
# NOTE(review): `capacity` is not read by the visible code; the model's
# AddNoOverlap effectively treats every track as capacity 1.
TRACKS = {
    "A": dict(dir="top2bottom", capacity=1),
    "B": dict(dir="both", capacity=1),
    "C": dict(dir="both", capacity=1),
    "D": dict(dir="both", capacity=1),
    "E": dict(dir="bottom2top", capacity=1),
    "Loop": dict(dir="both", capacity=1, is_loop=True),
}

# Platforms P1..P5: usable when status is "open" and length >= train length.
PLATFORMS = {
    f"P{number}": {"length": 1000, "status": "open"}
    for number in range(1, 6)
}

# ---------------------------
# Helper functions
# ---------------------------
def minutes_between(a, b):
    """Return the number of whole minutes from ``a`` to ``b``.

    The result is floored, so it is negative (rounded toward -inf) when
    ``b`` precedes ``a``.
    """
    whole_minutes, _leftover_seconds = divmod((b - a).total_seconds(), 60)
    return int(whole_minutes)

def compute_dwell(row):
    """Return the platform dwell time in minutes for one train record.

    Freight trains get 0. Otherwise, when both ``eta`` and ``etd`` are
    present, the dwell is the whole-minute gap between them, clamped to at
    least 1. With either value missing, the default is 3 minutes.
    """
    if str(row.get('train_type', '')).lower() == "freight":
        return 0

    arrival, departure = row.get('eta'), row.get('etd')
    if pd.isnull(departure) or pd.isnull(arrival):
        return 3
    return max(1, minutes_between(arrival, departure))

def estimate_service_time(row):
    """Return the track occupation time in whole minutes for one train.

    The total is a fixed 2 minutes, plus ``dwell_time``, plus a fixed
    1 minute, plus one minute per started 300 units of ``length``.
    """
    lead_in, lead_out = 2, 1
    length_blocks = math.ceil(row['length'] / 300)
    total = lead_in + row['dwell_time'] + lead_out + length_blocks
    return int(total)

# ---------------------------
# Load dataset
# ---------------------------
df = pd.read_csv(CSV_PATH)
print("Dataset preview:")
display(df.head())  # notebook (IPython) display helper

# An empty file would otherwise fail later with an obscure NaN-to-int error.
if df.empty:
    raise ValueError(f"No trains found in {CSV_PATH}")

# Normalise input columns; each missing optional column gets a default.
if 'train_id' not in df.columns:
    df['train_id'] = df.index.astype(str)

if 'ETA' in df.columns:
    df['eta'] = pd.to_datetime(df['ETA'], utc=True, errors='coerce')
else:
    # Synthetic ETAs one minute apart starting now (non-deterministic).
    df['eta'] = pd.Timestamp.now(tz='UTC') + pd.to_timedelta(range(len(df)), unit='m')

df['etd'] = pd.to_datetime(df['ETD'], utc=True, errors='coerce') if 'ETD' in df.columns else pd.NaT

# errors='coerce' turns unparseable ETAs into NaT. Every downstream minute
# offset needs a real ETA, so fail here with the offending train ids instead
# of an opaque integer-cast error at `.astype(int)` below.
missing_eta = df['eta'].isna()
if missing_eta.any():
    raise ValueError(
        "Missing or unparseable ETA for train(s): "
        f"{df.loc[missing_eta, 'train_id'].tolist()}"
    )

if 'direction' in df.columns:
    df['direction'] = df['direction'].astype(str).apply(
        lambda x: 'top2bottom' if str(x).upper().startswith('TOP') else 'bottom2top'
    )
else:
    df['direction'] = ['top2bottom' if i % 2 == 0 else 'bottom2top' for i in range(len(df))]

df['dwell_time'] = df.apply(compute_dwell, axis=1)
df['priority'] = df['priority'].fillna(1).astype(int) if 'priority' in df.columns else 1
df['length'] = df['length'].fillna(200).astype(int) if 'length' in df.columns else 200
df['eta_uncertainty'] = df['eta_uncertainty'].fillna(1).astype(int) if 'eta_uncertainty' in df.columns else 1

# All schedule times are whole minutes measured from 5 minutes before the earliest ETA.
min_eta = df['eta'].min()
horizon_start = min_eta - pd.Timedelta(minutes=5)
df['eta_min'] = ((df['eta'] - horizon_start).dt.total_seconds() // 60).astype(int)
df['service_time'] = df.apply(estimate_service_time, axis=1)
df['buffer'] = BASE_BUFFER + UNCERTAINTY_MARGIN * df['eta_uncertainty']
df['effective_service'] = df['service_time'] + df['buffer']
# Latest possible unconstrained finish plus 60 minutes of slack.
horizon = int((df['eta_min'] + df['effective_service']).max()) + 60
print(f"Horizon: {horizon} minutes")

# ---------------------------
# What-if simulation: maintenance closure on track D
# ---------------------------
# Despite the name, keys are track ids (keys of TRACKS). Each value lists
# (start, end) closure windows in minutes from horizon_start.
platform_closure = dict(D=[(30, 50)])  # track D closed from minute 30 to minute 50

# ---------------------------
# Build CP-SAT model
# ---------------------------
model = cp_model.CpModel()
trains = df.to_dict('records')

# Per-(train_id, track_id) decision variables.
assign_vars, start_vars, end_vars, interval_vars = {}, {}, {}, {}
# Per-train delay (minutes after ETA) and chosen start minute.
delay_vars, chosen_start = {}, {}
# Every interval competing for a track: train options plus closure windows.
track_intervals = {t: [] for t in TRACKS.keys()}

for t in trains:
    tid = t['train_id']
    eta_min = int(t['eta_min'])
    service = int(t['effective_service'])

    dvar = model.NewIntVar(0, horizon, f"delay_{tid}")
    delay_vars[tid] = dvar
    svar = model.NewIntVar(0, horizon, f"start_chosen_{tid}")
    chosen_start[tid] = svar
    track_option_binaries = []

    # Platform availability depends only on the train, so evaluate it once.
    # Freight needs no platform; others need an open platform long enough.
    has_platform = (
        str(t.get('train_type', '')).lower() == 'freight'
        or any(
            PLATFORMS[p]['status'] == 'open' and PLATFORMS[p]['length'] >= t['length']
            for p in PLATFORMS
        )
    )

    for tr_id, tr_meta in TRACKS.items():
        dir_ok = tr_meta['dir'] in ('both', t['direction'])
        if not (dir_ok and has_platform):
            continue

        binvar = model.NewBoolVar(f"assign_{tid}_{tr_id}")
        start_var = model.NewIntVar(0, horizon, f"start_{tid}_{tr_id}")
        end_var = model.NewIntVar(0, horizon, f"end_{tid}_{tr_id}")
        # The interval only exists (occupies the track) when binvar is true.
        ivar = model.NewOptionalIntervalVar(start_var, service, end_var, binvar, f"interval_{tid}_{tr_id}")

        assign_vars[(tid, tr_id)] = binvar
        start_vars[(tid, tr_id)] = start_var
        end_vars[(tid, tr_id)] = end_var
        interval_vars[(tid, tr_id)] = ivar
        track_intervals[tr_id].append(ivar)
        track_option_binaries.append(binvar)

        # On the chosen track: start no earlier than ETA, and tie the delay
        # and chosen start to this option.
        model.Add(start_var >= eta_min).OnlyEnforceIf(binvar)
        model.Add(dvar == start_var - eta_min).OnlyEnforceIf(binvar)
        model.Add(svar == start_var).OnlyEnforceIf(binvar)

    if not track_option_binaries:
        raise RuntimeError(f"No feasible track for train {tid}.")
    # Exactly one track per train.
    model.Add(sum(track_option_binaries) == 1)

# What-if track closures: each (c_start, c_end) window becomes a fixed
# interval on its track. AddNoOverlap then forces every train on that track
# to finish by c_start or start at/after c_end, matching the FCFS baseline's
# treatment. The previous (commented-out) attempt required BOTH
# start >= c_end AND end <= c_start, which can never hold.
for tr_id, windows in platform_closure.items():
    if tr_id not in track_intervals:
        continue  # unknown track ids are ignored, as in the FCFS baseline
    for k, (c_start, c_end) in enumerate(windows):
        if c_end > c_start:
            track_intervals[tr_id].append(
                model.NewIntervalVar(c_start, c_end - c_start, c_end, f"closure_{tr_id}_{k}")
            )

# Each track serves at most one interval at a time.
for tr_id, ivars in track_intervals.items():
    if ivars:
        model.AddNoOverlap(ivars)

# ---------------------------
# Objective: minimize weighted delays + loop penalties
# ---------------------------
# Objective cost per minute of delay, keyed by lower-cased train_type.
# Any other type uses DEFAULT_DELAY_WEIGHT.
DELAY_WEIGHTS = {"express": 100, "passenger": 50, "special": 20, "freight": 10}
DEFAULT_DELAY_WEIGHT = 30

delay_terms = [
    delay_vars[t['train_id']]
    * DELAY_WEIGHTS.get(str(t.get('train_type', '')).lower(), DEFAULT_DELAY_WEIGHT)
    for t in trains
]

# Routing a train onto a loop track (TRACKS entry with is_loop=True) costs a
# flat LOOP_PENALTY. The assignment literal is 0/1, so the linear term
# LOOP_PENALTY * b is exactly that cost; the former auxiliary IntVar plus two
# reified equalities added variables without changing the objective.
loop_penalty_terms = [
    LOOP_PENALTY * b
    for (tid, tr), b in assign_vars.items()
    if TRACKS[tr].get('is_loop', False)
]

model.Minimize(sum(delay_terms) + sum(loop_penalty_terms))

# ---------------------------
# Solve
# ---------------------------
solver = cp_model.CpSolver()
# Stop after MAX_SOLVE_TIME_S seconds and search with 8 parallel workers.
# NOTE(review): newer OR-Tools releases document `num_workers` as the preferred
# name for this parameter — confirm against the installed version.
solver.parameters.max_time_in_seconds = MAX_SOLVE_TIME_S
solver.parameters.num_search_workers = 8

print("Solving...")
res = solver.Solve(model)
status = solver.StatusName(res)
print("Solver status:", status)
# Only OPTIMAL or FEASIBLE results carry a schedule that can be read back.
if res not in (cp_model.OPTIMAL, cp_model.FEASIBLE):
    raise RuntimeError("No feasible solution found")

# ---------------------------
# Extract solution
# ---------------------------
def _minute_to_datetime(m):
    """Convert a minute offset from horizon_start to a datetime; missing stays None."""
    if not pd.notna(m):
        return None
    return (horizon_start + pd.Timedelta(minutes=int(m))).to_pydatetime()


schedule = []
for train in trains:
    train_id = train['train_id']
    # First track (in TRACKS order) whose assignment literal is 1, if any.
    picked = next(
        (
            tr for tr in TRACKS
            if (train_id, tr) in assign_vars and solver.Value(assign_vars[(train_id, tr)]) == 1
        ),
        None,
    )
    if picked is None:
        track_name, s_min, e_min = "UNASSIGNED", None, None
    else:
        track_name = picked
        s_min = solver.Value(start_vars[(train_id, picked)])
        e_min = solver.Value(end_vars[(train_id, picked)])

    schedule.append({
        "train_id": train_id,
        "train_type": train.get('train_type', 'Unknown'),
        "direction": train['direction'],
        "assigned_track": track_name,
        "start_min": s_min,
        "end_min": e_min,
        "eta_min": int(train['eta_min']),
        "delay_min": int(solver.Value(delay_vars[train_id])),
        "priority": int(train['priority']),
        "service_time": int(train['service_time']),
    })

sched_df = pd.DataFrame(schedule)
sched_df['planned_start_dt'] = sched_df['start_min'].apply(_minute_to_datetime)
sched_df['planned_end_dt'] = sched_df['end_min'].apply(_minute_to_datetime)

# ---------------------------
# FCFS baseline (manual operator) with what-if track closures
# ---------------------------
# Trains are served strictly in ETA order. Each train takes the eligible
# track on which it can start earliest, using the same direction/platform
# rules as the optimized model and the closure windows in `platform_closure`.
def _shift_past_closures(start, service, windows):
    """Return the earliest start >= ``start`` whose slot avoids every closure.

    ``windows`` is an iterable of ``(c_start, c_end)`` minute pairs; the slot
    ``[start, start + service)`` must not overlap any ``[c_start, c_end)``.
    Windows are visited in ascending start order, so after the train is
    pushed past one window it is still checked against every later one.
    (The previous unsorted single pass could leave a shifted train
    overlapping an earlier-listed window.)
    """
    for c_start, c_end in sorted(windows):
        if start < c_end and start + service > c_start:
            start = c_end
    return start


# Stable sort keeps file order among trains with equal ETA.
fcfs_df = df.sort_values(by='eta_min', kind='stable').copy()
fcfs_schedule = []

# Minute at which each track becomes free again.
# Closure windows come from the single `platform_closure` definition in the
# what-if section, so the baseline and the optimized model share one scenario.
track_next_free = {tr: 0 for tr in TRACKS}

for idx, row in fcfs_df.iterrows():
    # Freight needs no platform; other trains need an open platform long enough.
    has_platform = (
        str(row.get('train_type', '')).lower() == "freight"
        or any(
            PLATFORMS[p]['status'] == 'open' and PLATFORMS[p]['length'] >= row['length']
            for p in PLATFORMS
        )
    )
    eligible_tracks = [
        tr_id for tr_id, tr_meta in TRACKS.items()
        if has_platform and tr_meta['dir'] in ('both', row['direction'])
    ]
    if not eligible_tracks:
        raise RuntimeError(f"No eligible track for train {row['train_id']}")

    # Earliest closure-respecting start over all eligible tracks; ties go to
    # the first track in TRACKS order.
    service = row['effective_service']
    earliest_start = None
    assigned_tr = None
    for tr in eligible_tracks:
        candidate_start = _shift_past_closures(
            max(row['eta_min'], track_next_free[tr]),
            service,
            platform_closure.get(tr, ()),
        )
        if earliest_start is None or candidate_start < earliest_start:
            earliest_start = candidate_start
            assigned_tr = tr

    start_min = earliest_start
    end_min = start_min + service
    track_next_free[assigned_tr] = end_min
    delay_min = start_min - row['eta_min']

    fcfs_schedule.append({
        "train_id": row['train_id'],
        "assigned_track": assigned_tr,
        "start_min": start_min,
        "end_min": end_min,
        "delay_min": delay_min
    })

fcfs_df = pd.DataFrame(fcfs_schedule)
fcfs_df['planned_start_dt'] = fcfs_df['start_min'].apply(lambda m: horizon_start + pd.Timedelta(minutes=int(m)))
fcfs_df['planned_end_dt'] = fcfs_df['end_min'].apply(lambda m: horizon_start + pd.Timedelta(minutes=int(m)))
print("FCFS Schedule after full closure handling:")
display(fcfs_df.sort_values('planned_start_dt'))
print(f"FCFS Average delay: {fcfs_df['delay_min'].mean():.2f} min")


# ---------------------------
# Gantt chart: Optimized vs FCFS
# ---------------------------
def plot_gantt(df, title):
    """Draw and show a Gantt chart with one row per track.

    ``df`` must have ``assigned_track``, ``start_min``, ``end_min`` and
    ``train_id`` columns. Each row becomes a bar from ``start_min`` to
    ``end_min``, labelled with its train id. Track rows are sorted by name
    and coloured from the ``tab20`` colormap.
    """
    lanes = sorted(set(df['assigned_track']))
    row_of = {lane: idx for idx, lane in enumerate(lanes)}
    colour_of = {lane: plt.cm.tab20(idx % 20) for idx, lane in enumerate(lanes)}

    fig, ax = plt.subplots(figsize=(14, 6))
    for _, rec in df.iterrows():
        lane = rec['assigned_track']
        left = rec['start_min']
        width = rec['end_min'] - rec['start_min']
        ax.barh(row_of[lane], width, left=left, height=0.6, color=colour_of[lane])
        ax.text(left + width / 2, row_of[lane], str(rec['train_id']),
                va='center', ha='center', fontsize=8, color='white')

    ax.set_yticks([row_of[lane] for lane in lanes])
    ax.set_yticklabels(lanes)
    ax.set_xlabel("Minutes from horizon start")
    ax.set_title(title)
    ax.grid(axis='x', linestyle='--', alpha=0.4)
    plt.tight_layout()
    plt.show()

# Optimized schedule first, then the FCFS baseline; each title shows its mean delay.
for _frame, _label in ((sched_df, "Optimized Schedule"), (fcfs_df, "FCFS Baseline")):
    plot_gantt(_frame, f"{_label} — Avg delay {_frame['delay_min'].mean():.2f} min")