In [None]:
import gradio as gr
import numpy as np
import pandas as pd
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score, calinski_harabasz_score, davies_bouldin_score
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
from math import ceil, floor
import random, math
import nest_asyncio
nest_asyncio.apply()

# ============================================================
# 0) GLOBAL CONSTANTS
# ============================================================

# Node-to-node matrices, filled in by run_clustering() from the uploaded CSV.
# Both are indexed as matrix[from_node][to_node].
distance_matrix = None
time_matrix = None

# Routing parameters. Times are in seconds (run_routing divides by 60 to show
# minutes); distances are shown by run_routing as value / 1000 "km".
SCHOOL_NODE = 0                 # node id / matrix row of the school
STOP_TIME = 60                  # dwell time charged for every student stop
MAX_STUDENT_TIME = 90 * 60      # cap on the longest student ride (90 min)
PENALTY_FACTOR = 10000.0        # base penalty once the ride cap is exceeded
ALPHA = 0.3                     # weight of time in the hybrid cost
BETA = 0.7                      # weight of distance in the hybrid cost

# ============================================================
# 1) COST FUNCTIONS
# ============================================================

def route_travel_time(route):
    """Return the summed leg travel time of ``route`` (no stop/dwell time).

    Reads the module-level ``time_matrix``; an empty or single-node route
    yields 0.
    """
    total = 0
    for origin, dest in zip(route, route[1:]):
        total += time_matrix[origin][dest]
    return total

def hybrid_cost(route):
    """Weighted blend of a route's time and distance.

    cost = ALPHA * (travel_time + STOP_TIME * n_students) + BETA * travel_distance

    where ``n_students`` is the number of nodes between the two endpoints and
    the travel terms come from ``time_matrix`` / ``distance_matrix``.
    """
    legs = list(zip(route, route[1:]))
    travel_time = sum(time_matrix[a][b] for a, b in legs)
    travel_dist = sum(distance_matrix[a][b] for a, b in legs)
    dwell = (len(route) - 2) * STOP_TIME
    return ALPHA * (travel_time + dwell) + BETA * travel_dist

def route_total_distance(route):
    """Return the summed leg distance of ``route`` from ``distance_matrix``."""
    total = 0
    for origin, dest in zip(route, route[1:]):
        total += distance_matrix[origin][dest]
    return total

# ============================================================
# 2) 2-OPT
# ============================================================

def two_opt(route, cost_func):
    """First-improvement 2-opt local search with fixed endpoints.

    Scans segment reversals ``route[i..j]`` with 1 <= i < j <= len(route) - 2,
    so the first and last nodes never move. It applies the first reversal that
    strictly lowers ``cost_func`` and rescans from the beginning. It stops when
    no reversal helps. The input route is not modified.
    """
    tour = route.copy()
    tour_cost = cost_func(tour)

    def first_improvement():
        # Returns (candidate, cost) for the first strictly better reversal,
        # or None when the current tour is 2-opt optimal.
        size = len(tour)
        for i in range(1, size - 2):
            for j in range(i + 1, size - 1):
                candidate = tour[:i] + tour[i:j + 1][::-1] + tour[j + 1:]
                candidate_cost = cost_func(candidate)
                if candidate_cost < tour_cost:
                    return candidate, candidate_cost
        return None

    while True:
        move = first_improvement()
        if move is None:
            return tour
        tour, tour_cost = move

# ============================================================
# 3) INITIAL ROUTES
# ============================================================

def initial_route_morning(student_nodes):
    """Build a greedy nearest-neighbour tour that starts and ends at the school.

    Starting at SCHOOL_NODE, it repeatedly moves to the closest unvisited
    student, measured with ``distance_matrix``. Ids are cast to ``int`` and
    duplicates collapse into a single visit. With no students the tour is
    ``[SCHOOL_NODE, SCHOOL_NODE]``.
    """
    remaining = set(int(node) for node in student_nodes)
    tour = [SCHOOL_NODE]
    while remaining:
        origin = tour[-1]
        nearest = min(remaining, key=lambda cand: distance_matrix[origin][cand])
        remaining.remove(nearest)
        tour.append(nearest)
    tour.append(SCHOOL_NODE)
    return tour

def initial_route_afternoon(student_nodes):
    """Seed tour for the afternoon run: identical to the morning greedy tour."""
    return initial_route_morning(student_nodes=student_nodes)

# ============================================================
# 4) REMOVE OPERATORS
# ============================================================

def random_remove(route, percent=0.15):
    """Destroy operator: remove a random ``percent`` share of the students.

    Samples only the interior positions of ``route``. At least one node is
    removed whenever any exist. Returns ``(partial_route, removed)``; the input
    list is not mutated.
    """
    partial = route.copy()
    students = partial[1:-1]
    if not students:
        return partial, []
    count = max(1, int(len(students) * percent))
    removed = random.sample(students, count)
    for node in removed:
        partial.remove(node)
    return partial, removed

def worst_remove(route, k=7):
    """Destroy operator: remove the ``k`` students with the costliest detours.

    Each interior node is scored by the ALPHA/BETA-weighted extra time and
    distance it adds compared with linking its two neighbours directly. The
    highest scores are removed first, with ties going to the larger node id.
    Routes with fewer than two students come back unchanged with nothing
    removed. Returns ``(partial_route, removed)``.
    """
    partial = route.copy()
    students = partial[1:-1]
    if len(students) <= 1:
        return partial, []
    limit = min(k, len(students))

    def detour(pos):
        prev, node, nxt = partial[pos - 1], partial[pos], partial[pos + 1]
        extra_time = time_matrix[prev][node] + time_matrix[node][nxt] - time_matrix[prev][nxt]
        extra_dist = distance_matrix[prev][node] + distance_matrix[node][nxt] - distance_matrix[prev][nxt]
        return ALPHA * extra_time + BETA * extra_dist

    # Score every position before any removal so neighbours stay as in `route`.
    ranked = sorted(
        [(detour(pos), partial[pos]) for pos in range(1, len(partial) - 1)],
        reverse=True,
    )
    removed = [node for _, node in ranked[:limit]]
    for node in removed:
        partial.remove(node)
    return partial, removed

def shaw_remove(route, k=3, gamma=0.3):
    """Destroy operator: remove a random seed student plus its closest peers.

    Relatedness to the seed is ``gamma * travel_time + (1 - gamma) * distance``
    (lower means more related). Up to ``k`` students are removed in total,
    seed included. Returns ``(partial_route, removed)``, with ``removed`` in
    set-iteration order.
    """
    partial = route.copy()
    students = partial[1:-1]
    if not students:
        return partial, []
    quota = min(k, len(students))

    seed = random.choice(students)
    removed = {seed}

    related = sorted(
        [
            (gamma * time_matrix[seed][node] + (1 - gamma) * distance_matrix[seed][node], node)
            for node in students
            if node != seed
        ],
        key=lambda pair: pair[0],
    )
    for _, node in related:
        if len(removed) >= quota:
            break
        removed.add(node)

    for node in removed:
        partial.remove(node)
    return partial, list(removed)

# ============================================================
# 5) REPAIR — REGRET-2
# ============================================================

def regret_2_insert(route, removed):
    """Repair operator: reinsert ``removed`` nodes with the regret-2 rule.

    Each pending node has an insertion cost for every gap in the tour (the
    ALPHA/BETA-weighted extra time and distance). Its regret is the second
    cheapest cost minus the cheapest. On each round the node with the largest
    regret is placed at its cheapest gap; ties go to the earlier node and the
    earlier gap. Neither argument is mutated; the repaired tour is returned.
    """
    tour = route.copy()
    pending = removed.copy()

    def insertion_costs(node):
        costs = []
        for pos in range(1, len(tour)):
            prev, nxt = tour[pos - 1], tour[pos]
            d_time = time_matrix[prev][node] + time_matrix[node][nxt] - time_matrix[prev][nxt]
            d_dist = distance_matrix[prev][node] + distance_matrix[node][nxt] - distance_matrix[prev][nxt]
            costs.append((ALPHA * d_time + BETA * d_dist, pos))
        costs.sort(key=lambda c: c[0])
        return costs

    while pending:
        chosen, chosen_pos, top_regret = None, None, -1
        for node in pending:
            costs = insertion_costs(node)
            cheapest, pos = costs[0]
            runner_up = costs[1][0] if len(costs) > 1 else cheapest
            regret = runner_up - cheapest
            if regret > top_regret:
                chosen, chosen_pos, top_regret = node, pos, regret
        tour.insert(chosen_pos, chosen)
        pending.remove(chosen)

    return tour

# ============================================================
# 6) CONSTRAINTS
# ============================================================

def route_time(route):
    """Total bus time: leg travel time plus STOP_TIME for every student stop."""
    student_stops = len(route) - 2
    return route_travel_time(route) + student_stops * STOP_TIME

def student_trip_time_morning(route):
    """Longest student ride on a morning route.

    The student at ``route[1]`` rides every leg from there to ``route[-1]``.
    STOP_TIME is charged once per student on the route. Returns 0 when the
    route has no students.
    """
    students = route[1:-1]
    if not students:
        return 0
    riding = sum(time_matrix[a][b] for a, b in zip(route[1:], route[2:]))
    return riding + len(students) * STOP_TIME

def student_trip_time_afternoon(route):
    """Longest student ride on an afternoon route.

    The student dropped at ``route[-2]`` rides every leg from ``route[0]`` up
    to their stop. STOP_TIME is charged once per student on the route.
    Returns 0 when the route has no students.
    """
    students = route[1:-1]
    if not students:
        return 0
    riding = sum(time_matrix[a][b] for a, b in zip(route[:-2], route[1:-1]))
    return riding + len(students) * STOP_TIME

def constrained_cost_morning(route):
    """hybrid_cost plus a penalty when the morning ride cap is exceeded.

    If the longest ride exceeds MAX_STUDENT_TIME, the penalty is
    ``PENALTY_FACTOR * (1 + overrun_in_minutes)``. The penalty therefore jumps
    on the first violating second and then grows linearly.
    """
    cost = hybrid_cost(route)
    longest_ride = student_trip_time_morning(route)
    if longest_ride <= MAX_STUDENT_TIME:
        return cost
    overrun_minutes = (longest_ride - MAX_STUDENT_TIME) / 60.0
    return cost + PENALTY_FACTOR * (1.0 + overrun_minutes)

def constrained_cost_afternoon(route):
    """hybrid_cost plus a penalty when the afternoon ride cap is exceeded.

    If the longest ride exceeds MAX_STUDENT_TIME, the penalty is
    ``PENALTY_FACTOR * (1 + overrun_in_minutes)``. The penalty therefore jumps
    on the first violating second and then grows linearly.
    """
    cost = hybrid_cost(route)
    longest_ride = student_trip_time_afternoon(route)
    if longest_ride <= MAX_STUDENT_TIME:
        return cost
    overrun_minutes = (longest_ride - MAX_STUDENT_TIME) / 60.0
    return cost + PENALTY_FACTOR * (1.0 + overrun_minutes)

# ============================================================
# 7) ACCEPTANCE
# ============================================================

def accept(new, old, T):
    """Simulated-annealing (Metropolis) acceptance test.

    Strict improvements (``new < old``) are always accepted. Otherwise the move
    is accepted with probability ``exp(-(new - old) / T)``.

    A non-positive temperature would divide by zero (T == 0), or accept
    worsening moves and possibly overflow ``math.exp`` (T < 0). In that case
    the rule uses its T -> 0+ limit instead: accept only if ``new <= old``.
    """
    if new < old:
        return True
    if T <= 0:
        return new <= old
    return random.random() < math.exp(-(new - old) / T)

# ============================================================
# 8) ALNS (Morning / Afternoon)
# ============================================================

def alns_for_cluster_morning(student_nodes,
                             iterations=12000, T0=850,
                             cooling=0.9999):
    """
    Adaptive Large Neighborhood Search (ALNS) for a single morning cluster.

    Starts from initial_route_morning. Each iteration:
      1. Picks a destroy operator (random_remove / worst_remove / shaw_remove)
         by roulette over adaptive weights.
      2. Repairs the route with regret_2_insert.
      3. Applies the simulated-annealing test ``accept`` on
         constrained_cost_morning.
    Operator scores per segment: 6 = new global best, 3 = better than the
    current route, 1 = equal, 0 = worse. Every 200 iterations each weight
    becomes ``(1 - rho) * weight + rho * average_score``. Finally the best
    route is polished with two_opt on route_time; the polished route is kept
    only if it lowers constrained_cost_morning.

    Args:
        student_nodes: node ids (matrix indices) of the students to route.
        iterations: number of destroy/repair iterations.
        T0: initial annealing temperature.
        cooling: geometric cooling factor applied once per iteration.

    Returns:
        (best_route, best_cost); best_route starts and ends at SCHOOL_NODE.
    """

    route = initial_route_morning(student_nodes)
    current_cost = constrained_cost_morning(route)
    best = route.copy()
    best_cost = current_cost

    ops = [random_remove, worst_remove, shaw_remove]
    weights = [1.0] * len(ops)

    op_scores = [0.0] * len(ops)
    op_usage = [0.0] * len(ops)

    rho = 0.8            # reaction factor of the weight update
    update_every = 200   # iterations per weight-update segment

    T = T0

    def select_op():
        """Roulette-wheel pick of an operator index, proportional to weight."""
        tot = sum(weights)
        if tot <= 0:
            # Every weight has decayed to zero: choose uniformly instead of
            # dividing by zero.
            return random.randrange(len(weights))
        r = random.random()
        c = 0.0
        for i, w in enumerate(weights):
            c += w / tot
            if r <= c:
                return i
        return len(weights) - 1

    for it in range(1, iterations + 1):
        idx = select_op()
        partial, rem = ops[idx](route)

        # An operator can remove nothing (e.g. worst_remove on a one-student
        # route). Skip the repair, but still reach the periodic weight update
        # and the cooling step below so neither schedule drifts.
        if rem:
            new = regret_2_insert(partial, rem)

            nc = constrained_cost_morning(new)
            oc = current_cost  # cached cost of `route`

            op_usage[idx] += 1
            if nc < best_cost:
                op_scores[idx] += 6
            elif nc < oc:
                op_scores[idx] += 3
            elif nc == oc:
                op_scores[idx] += 1

            if accept(nc, oc, T):
                route = new
                current_cost = nc
                if nc < best_cost:
                    best = new.copy()
                    best_cost = nc

        if it % update_every == 0:
            for k in range(len(weights)):
                avg = op_scores[k] / op_usage[k] if op_usage[k] > 0 else 0.0
                weights[k] = (1 - rho) * weights[k] + rho * avg

            op_scores = [0.0] * len(ops)
            op_usage = [0.0] * len(ops)

        T = T * cooling

    best_2 = two_opt(best, route_time)
    best_2_cost = constrained_cost_morning(best_2)

    if best_2_cost < best_cost:
        best, best_cost = best_2, best_2_cost

    return best, best_cost


def alns_for_cluster_afternoon(student_nodes,
                               iterations=12000, T0=850,
                               cooling=0.9999):
    """
    Adaptive Large Neighborhood Search (ALNS) for a single afternoon cluster.

    Same structure as the morning search, but it starts from
    initial_route_afternoon and optimizes constrained_cost_afternoon. Each
    iteration:
      1. Picks a destroy operator by roulette over adaptive weights.
      2. Repairs the route with regret_2_insert.
      3. Applies the simulated-annealing test ``accept``.
    Operator scores per segment: 6 = new global best, 3 = better than the
    current route, 1 = equal, 0 = worse. Weights are refreshed every 200
    iterations. The best route is polished with two_opt on route_time and
    kept only if that lowers constrained_cost_afternoon.

    Args:
        student_nodes: node ids (matrix indices) of the students to route.
        iterations: number of destroy/repair iterations.
        T0: initial annealing temperature.
        cooling: geometric cooling factor applied once per iteration.

    Returns:
        (best_route, best_cost); best_route starts and ends at SCHOOL_NODE.
    """

    route = initial_route_afternoon(student_nodes)
    current_cost = constrained_cost_afternoon(route)
    best = route.copy()
    best_cost = current_cost

    ops = [random_remove, worst_remove, shaw_remove]
    weights = [1.0] * len(ops)

    op_scores = [0.0] * len(ops)
    op_usage = [0.0] * len(ops)

    rho = 0.8            # reaction factor of the weight update
    update_every = 200   # iterations per weight-update segment

    T = T0

    def select_op():
        """Roulette-wheel pick of an operator index, proportional to weight."""
        tot = sum(weights)
        if tot <= 0:
            # Every weight has decayed to zero: choose uniformly instead of
            # dividing by zero.
            return random.randrange(len(weights))
        r = random.random()
        c = 0.0
        for i, w in enumerate(weights):
            c += w / tot
            if r <= c:
                return i
        return len(weights) - 1

    for it in range(1, iterations + 1):
        idx = select_op()
        partial, rem = ops[idx](route)

        # An operator can remove nothing (e.g. worst_remove on a one-student
        # route). Skip the repair, but still reach the periodic weight update
        # and the cooling step below so neither schedule drifts.
        if rem:
            new = regret_2_insert(partial, rem)

            nc = constrained_cost_afternoon(new)
            oc = current_cost  # cached cost of `route`

            op_usage[idx] += 1
            if nc < best_cost:
                op_scores[idx] += 6
            elif nc < oc:
                op_scores[idx] += 3
            elif nc == oc:
                op_scores[idx] += 1

            if accept(nc, oc, T):
                route = new
                current_cost = nc
                if nc < best_cost:
                    best = new.copy()
                    best_cost = nc

        if it % update_every == 0:
            for k in range(len(weights)):
                avg = op_scores[k] / op_usage[k] if op_usage[k] > 0 else 0.0
                weights[k] = (1 - rho) * weights[k] + rho * avg

            op_scores = [0.0] * len(ops)
            op_usage = [0.0] * len(ops)

        T = T * cooling

    best_2 = two_opt(best, route_time)
    best_2_cost = constrained_cost_afternoon(best_2)

    if best_2_cost < best_cost:
        best, best_cost = best_2, best_2_cost

    return best, best_cost


# ============================================================
# 9) CLUSTERING
# ============================================================

def run_clustering(merged_file):
    """
    Load the merged student CSV, publish its matrices and choose a bus split.

    Side effects: rebinds the module-level ``distance_matrix`` /
    ``time_matrix``. They are built from every column whose name contains
    "distance" / "time" (case-insensitive).

    Row 0 of the CSV is the school. Rows 1..N are students and keep their row
    number as ``node_id``. Students are clustered on standardized ``lat`` /
    ``lon`` with K-Means, using only K values compatible with 14-44 students
    per bus. Feasible K values are ranked by the mean of their silhouette,
    Calinski-Harabasz and Davies-Bouldin ranks, and the best one is kept.

    Returns:
        (df, df_students, rank_df, msg, summary_df). On failure the frames
        are None and ``msg`` explains why.
    """
    global distance_matrix, time_matrix

    if merged_file is None:
        return None, None, None, "⚠️ Please upload the merged student CSV file first.", None

    df = pd.read_csv(merged_file)
    df = df.reset_index(drop=True)

    dist_cols = [c for c in df.columns if "distance" in c.lower()]
    time_cols = [c for c in df.columns if "time" in c.lower()]

    # NOTE(review): column j of each matrix is assumed to be node j, i.e. the
    # matching columns appear in node order. Confirm against the CSV producer.
    distance_matrix = df[dist_cols].to_numpy(dtype=float)
    time_matrix     = df[time_cols].to_numpy(dtype=float)

    df["node_id"] = df.index

    df_students = df.iloc[1:].reset_index(drop=True).copy()
    df_students["node_id"] = df_students.index + 1

    N = len(df_students)
    coords = df_students[["lat", "lon"]].values

    scaler = StandardScaler()
    coords_norm = scaler.fit_transform(coords)

    min_cap = 14
    max_cap = 44

    k_min = ceil(N / max_cap)
    k_max = floor(N / min_cap)
    possible_K = list(range(k_min, k_max + 1))

    if N < 28:
        possible_K = [1]
    elif N == 28:
        possible_K = [2]
    else:
        possible_K = [k for k in possible_K if k >= 2]

    results = {"K": [], "sil": [], "chi": [], "dbi": [], "labels": []}

    for k in possible_K:
        if k == 1:
            # One bus: every student gets label 0; K-Means adds nothing.
            labels = np.zeros(N, dtype=int)
        else:
            km = KMeans(
                n_clusters=k,
                init='k-means++',
                n_init=300,
                max_iter=700,
                random_state=42
            )
            labels = km.fit_predict(coords_norm)
        sizes = np.bincount(labels, minlength=k)

        if k == 2:
            ratio = min(sizes) / max(sizes)
            if ratio < 0.6:
                continue
        else:
            if (sizes < min_cap).any() or (sizes > max_cap).any():
                continue

        if k == 1:
            # silhouette / Davies-Bouldin / Calinski-Harabasz all raise
            # ValueError for a single label. K=1 is only ever the sole
            # candidate (N < 28), so it needs no score to win the ranking.
            sil = dbi = chi = np.nan
        else:
            sil = silhouette_score(coords_norm, labels)
            dbi = davies_bouldin_score(coords_norm, labels)
            chi = calinski_harabasz_score(coords_norm, labels)

        results["K"].append(k)
        results["sil"].append(sil)
        results["chi"].append(chi)
        results["dbi"].append(dbi)
        results["labels"].append(labels)

    if len(results["K"]) == 0:
        return None, None, None, "❌ No feasible bus configuration found.", None

    rank_df = pd.DataFrame(results)
    rank_df["Silhouette Rank"] = rank_df["sil"].rank(ascending=False)
    rank_df["CH Rank"] = rank_df["chi"].rank(ascending=False)
    rank_df["DBI Rank"] = rank_df["dbi"].rank(ascending=True)
    rank_df["Internal Avg Rank"] = (rank_df["Silhouette Rank"] + rank_df["CH Rank"] + rank_df["DBI Rank"]) / 3
    rank_df["Final Rank"] = rank_df["Internal Avg Rank"].rank(ascending=True)
    rank_df = rank_df.sort_values("Final Rank")

    best_row = rank_df.iloc[0]
    best_labels = best_row["labels"]
    best_k = int(best_row["K"])

    df_students["cluster_ai"] = best_labels.astype(int)

    summary_rows = []
    for cid in sorted(df_students["cluster_ai"].unique()):
        part = df_students[df_students["cluster_ai"] == cid]
        summary_rows.append({
            "Bus / Cluster ID": int(cid),
            "Students Count": len(part)
        })
    summary_df = pd.DataFrame(summary_rows)

    msg = f"✅ Clustering completed successfully.\nBest K = {best_k}"

    return df, df_students, rank_df, msg, summary_df

# ============================================================
# 10) HELPERS
# ============================================================

def get_students_for_bus(bus_id, df_students):
    """Fill the attendance checklist with every student on ``bus_id``.

    Choices are formatted "<node_id> - <student_name>"; parse_selected_nodes
    reads the id back. Every choice starts selected. If no bus is chosen,
    clustering has not run, or the id or columns are unusable, the checklist
    is cleared instead.
    """
    if bus_id is None or df_students is None:
        return gr.update(choices=[], value=[])
    try:
        cid = int(bus_id)
        subset = df_students[df_students["cluster_ai"] == cid]
        student_list = [
            f"{row['node_id']} - {row['student_name']}"
            for _, row in subset.iterrows()
        ]
    except (TypeError, ValueError, KeyError):
        # Non-numeric bus id or missing columns: clear rather than crash the UI.
        # (A bare `except:` here would also swallow KeyboardInterrupt.)
        return gr.update(choices=[], value=[])
    return gr.update(choices=student_list, value=student_list)

def parse_selected_nodes(selected_list):
    """Extract node ids from checklist labels of the form "<node_id> - <name>".

    Only the text before the first "-" is parsed, so hyphenated names are
    safe. Entries whose prefix is not an integer are skipped. Returns [] for
    None or an empty selection.
    """
    if not selected_list:
        return []
    ids = []
    for item in selected_list:
        head = str(item).split("-", 1)[0].strip()
        try:
            ids.append(int(head))
        except ValueError:
            # Not an "<id> - <name>" label; ignore it.
            continue
    return ids

def convert_route_to_df_indices(route, df_full):
    """Translate route node ids into ``df_full`` index labels for plotting.

    SCHOOL_NODE always maps to 0. Any other id maps to the index of the first
    row whose ``node_id`` equals it, or to 0 when no row matches.
    """
    def to_index(nid):
        if nid == SCHOOL_NODE:
            return 0
        hits = df_full.index[df_full["node_id"] == nid]
        return int(hits[0]) if len(hits) > 0 else 0

    return [to_index(nid) for nid in route]

def plot_route(df_full, route, title="Route"):
    """Draw ``route`` over the ``lat``/``lon`` columns of ``df_full``.

    Student stops are teal dots, the path is an orange line and the school
    (row SCHOOL_NODE) is a yellow star. Returns the matplotlib Figure.
    """
    route_idx = convert_route_to_df_indices(route, df_full)

    fig, ax = plt.subplots(figsize=(7, 7))

    student_nodes = [i for i in route_idx if i != SCHOOL_NODE]
    lats = [df_full.loc[i, "lat"] for i in student_nodes]
    lons = [df_full.loc[i, "lon"] for i in student_nodes]

    ax.scatter(lons, lats, s=60, linewidths=1, edgecolors="#1F2933", color="#2A8C82")

    line_lats = [df_full.loc[i, "lat"] for i in route_idx]
    line_lons = [df_full.loc[i, "lon"] for i in route_idx]
    ax.plot(line_lons, line_lats, marker="o", linewidth=2, color="#F2994A")

    school_lat = df_full.loc[SCHOOL_NODE, "lat"]
    school_lon = df_full.loc[SCHOOL_NODE, "lon"]
    ax.scatter([school_lon], [school_lat], s=150, marker="*", linewidths=1, edgecolors="#1F2933", color="#F2C94C")

    ax.set_title(title, fontsize=14, fontweight="bold")
    ax.set_xlabel("Longitude")
    ax.set_ylabel("Latitude")
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    # Detach the figure from pyplot's global registry. The caller renders the
    # Figure object directly (fig.savefig still works after close). Every
    # routing run creates two figures, so leaving them registered leaks memory
    # and eventually triggers matplotlib's "too many open figures" warning.
    plt.close(fig)
    return fig

# ============================================================
# 11) ROUTING
# ============================================================

def run_routing(bus_id, selected_students, df_full, df_students):
    """
    Optimize and plot the morning and afternoon routes for one bus.

    Args:
        bus_id: cluster id chosen in the UI (anything ``int()`` accepts).
        selected_students: checklist labels "<node_id> - <name>". An empty
            selection applies no attendance filter, so the whole cluster is
            routed.
        df_full: full node table from run_clustering (row 0 = school).
        df_students: student table with ``cluster_ai`` and ``node_id`` columns.

    Returns:
        (morning_fig, morning_text, afternoon_fig, afternoon_text). When a
        precondition fails, both figures are None and the second item holds
        the warning.
    """
    if df_full is None or df_students is None:
        return None, "⚠️ Please run clustering first.", None, ""

    if bus_id is None:
        return None, "⚠️ Please select a bus first.", None, ""

    cid = int(bus_id)
    sub = df_students[df_students["cluster_ai"] == cid]

    local_node_ids = sub["node_id"].to_numpy().astype(int)

    # A set gives O(1) membership tests (numpy ints hash like Python ints).
    chosen_nodes = set(parse_selected_nodes(selected_students))
    if chosen_nodes:
        student_nodes = [n for n in local_node_ids if n in chosen_nodes]
    else:
        student_nodes = list(local_node_ids)

    if len(student_nodes) == 0:
        return None, "⚠️ No selected students.", None, ""

    best_morning_route, _ = alns_for_cluster_morning(student_nodes)
    best_afternoon_route, _ = alns_for_cluster_afternoon(student_nodes)

    morning_total = route_time(best_morning_route)
    morning_students_time = student_trip_time_morning(best_morning_route)
    morning_distance = route_total_distance(best_morning_route)

    afternoon_total = route_time(best_afternoon_route)
    afternoon_students_time = student_trip_time_afternoon(best_afternoon_route)
    afternoon_distance = route_total_distance(best_afternoon_route)

    # Derive the verdict text from the configured cap instead of hard-coding 90.
    limit_min = MAX_STUDENT_TIME / 60
    within_msg = f"✔ Within {limit_min:g} minutes.\n"
    exceeds_msg = f"✘ Exceeds {limit_min:g} minutes!\n"

    morning_text_out = (
        f"🚌 Bus / Cluster {cid}\n"
        f"☀️ MORNING ROUTE\n"
        f"Route: {best_morning_route}\n"
        f"Bus Total Time (min): {morning_total / 60:.2f}\n"
        f"Max Student Trip Time (min): {morning_students_time / 60:.2f}\n"
        f"Bus Total Distance (km): {morning_distance / 1000:.3f}\n"
        + (within_msg if morning_students_time <= MAX_STUDENT_TIME else exceeds_msg)
    )

    afternoon_text_out = (
        f"🌙 AFTERNOON ROUTE\n"
        f"Route: {best_afternoon_route}\n"
        f"Bus Total Time (min): {afternoon_total / 60:.2f}\n"
        f"Max Student Trip Time (min): {afternoon_students_time / 60:.2f}\n"
        f"Bus Total Distance (km): {afternoon_distance / 1000:.3f}\n"
        + (within_msg if afternoon_students_time <= MAX_STUDENT_TIME else exceeds_msg)
    )

    fig_m = plot_route(df_full, best_morning_route, "Morning Route")
    fig_a = plot_route(df_full, best_afternoon_route, "Afternoon Route")

    return fig_m, morning_text_out, fig_a, afternoon_text_out


# ============================================================
# 12) MULTI-PAGE UI
# ============================================================

BUS_IMAGE_PATH = "/content/school_bus.jpg"


def show_page(which):
    """Return visibility updates for pages 1-4, revealing only page ``which``."""
    return [gr.update(visible=(which == page)) for page in (1, 2, 3, 4)]


# ============================
# DARK PURPLE THEME
# ============================

# Colour overrides layered on the Soft theme: near-black surfaces, light
# lavender text and purple accents for buttons, inputs and links.
_DARK_THEME_OVERRIDES = dict(
    body_background_fill="#0D0B12",
    body_text_color="#ECE6F5",
    block_background_fill="#15121C",
    block_border_width="0px",
    block_shadow="0px 0px 18px rgba(132, 56, 255, 0.18)",
    block_title_text_color="#D4B8FF",
    button_primary_background_fill="#7D22D8",
    button_primary_background_fill_hover="#9B3CFA",
    button_primary_text_color="#FFFFFF",
    button_secondary_background_fill="#231E2E",
    button_secondary_text_color="#E4D8FF",
    input_background_fill="#1B1722",
    input_border_color="#7D22D8",
    input_shadow="0px 0px 10px rgba(132, 56, 255, 0.35)",
    link_text_color="#C79BFF",
)

dark_theme = gr.themes.Soft(
    primary_hue="purple",
    secondary_hue="slate",
    neutral_hue="gray",
).set(**_DARK_THEME_OVERRIDES)

with gr.Blocks(theme=dark_theme) as demo:

    # Per-session copies of the clustering output, shared across pages.
    df_full_state = gr.State()
    df_students_state = gr.State()

    # PAGE 1
    with gr.Group(visible=True) as page1:
        with gr.Group(elem_classes=["welcome-card"]):
            gr.Image(BUS_IMAGE_PATH, show_label=False, elem_classes=["welcome-image"])
            gr.HTML("""
                <div class='welcome-header'>
                    <div class='welcome-title'>Welcome to the Smart School Bus Routing System</div>
                    <div class='welcome-subtitle'>
                        A modern AI-powered tool designed to cluster students and generate optimized bus routes
                    </div>
                </div>
            """)
        btn_start = gr.Button("Get started", elem_classes=["primary-btn"])

    # PAGE 2
    with gr.Group(visible=False) as page2:
        btn_back_to_page1 = gr.Button("⬅ Back")
        with gr.Group(elem_classes=["step-card"]):
            gr.Markdown("###  Upload student data")
            merged_file = gr.File(label="Student merged CSV file", file_types=[".csv"])
            btn_run_clustering = gr.Button("Run clustering")
        clustering_status = gr.Markdown()
        rank_table = gr.Dataframe(label="Ranking table", visible=False)
        summary_table = gr.Dataframe(label="Buses summary")
        btn_to_page3 = gr.Button("Next ➜ Select bus & attendance")

    # PAGE 3
    with gr.Group(visible=False) as page3:
        btn_back_to_page2 = gr.Button("⬅ Back")
        with gr.Group(elem_classes=["step-card"]):
            gr.Markdown("###  Select bus and attendance")
            bus_id_dropdown = gr.Dropdown(label="Bus ID", choices=[], value=None)
            students_check = gr.CheckboxGroup(label="Select present students", choices=[], value=[])
        btn_to_page4 = gr.Button("Next ➜ Routing page")

    # PAGE 4
    with gr.Group(visible=False) as page4:
        btn_back_to_page3 = gr.Button("⬅ Back")
        with gr.Group(elem_classes=["step-card"]):
            gr.Markdown("###  Routing")
            gr.Markdown("#### ☀️ Morning Route")
            morning_plot = gr.Plot()
            morning_text = gr.Markdown()
            gr.Markdown("#### 🌙 Afternoon Route")
            afternoon_plot = gr.Plot()
            afternoon_text = gr.Markdown()
            btn_rerun_route = gr.Button("Generate / Re-run routing")

    debug_text = gr.Markdown("", visible=True)
    EMPTY_DF = pd.DataFrame()

    def on_run_clustering(file):
        """Run clustering and fan its results out to the page-2/3 widgets.

        Returns (df_full, df_students, rank table, status, summary table,
        bus dropdown update, debug text). The dropdown selection is reset
        because bus ids from an earlier run may no longer exist.
        """
        try:
            df_full, df_students, rank_df, msg, summary_df = run_clustering(file)

            if df_students is None:
                # run_clustering reported a handled failure (no file uploaded /
                # no feasible K). Surface its own message instead of crashing
                # on the None frames below.
                return (
                    None,
                    None,
                    EMPTY_DF,
                    msg,
                    EMPTY_DF,
                    gr.update(choices=[], value=None),
                    msg
                )

            df_students["cluster_ai"] = df_students["cluster_ai"].astype(int)
            bus_ids = [str(int(x)) for x in sorted(df_students["cluster_ai"].unique())]

            dbg = "OK ✔ " + msg
            return (
                df_full,
                df_students,
                rank_df if rank_df is not None else EMPTY_DF,
                msg,
                summary_df if summary_df is not None else EMPTY_DF,
                gr.update(choices=bus_ids, value=None),
                dbg
            )
        except Exception as e:
            # UI boundary: report any unexpected failure in the page instead
            # of killing the event handler.
            return (
                None,
                None,
                EMPTY_DF,
                f"❌ Error: {str(e)}",
                EMPTY_DF,
                gr.update(choices=[], value=None),
                f"❌ ERROR\n{str(e)}"
            )

    btn_run_clustering.click(
        fn=on_run_clustering,
        inputs=[merged_file],
        outputs=[
            df_full_state,
            df_students_state,
            rank_table,
            clustering_status,
            summary_table,
            bus_id_dropdown,
            debug_text
        ]
    )

    bus_id_dropdown.change(
        fn=get_students_for_bus,
        inputs=[bus_id_dropdown, df_students_state],
        outputs=[students_check]
    )

    def on_run_routing(bus_id, selected_students, df_full, df_students):
        """Thin event-handler wrapper around run_routing."""
        return run_routing(bus_id, selected_students, df_full, df_students)

    btn_rerun_route.click(
        fn=on_run_routing,
        inputs=[bus_id_dropdown, students_check, df_full_state, df_students_state],
        outputs=[morning_plot, morning_text, afternoon_plot, afternoon_text],
        queue=True
    )

    # FORWARD
    btn_start.click(fn=lambda: show_page(2), outputs=[page1, page2, page3, page4])
    btn_to_page3.click(fn=lambda: show_page(3), outputs=[page1, page2, page3, page4])
    btn_to_page4.click(fn=lambda: show_page(4), outputs=[page1, page2, page3, page4])

    # BACKWARD
    btn_back_to_page1.click(fn=lambda: show_page(1), outputs=[page1, page2, page3, page4])
    btn_back_to_page2.click(fn=lambda: show_page(2), outputs=[page1, page2, page3, page4])
    btn_back_to_page3.click(fn=lambda: show_page(3), outputs=[page1, page2, page3, page4])

demo.launch(debug=True)


  with gr.Blocks(theme=dark_theme) as demo:


It looks like you are running Gradio on a hosted Jupyter notebook, which requires `share=True`. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://9c8575a60644587aea.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)
