In [1]:
import gradio as gr
import plotly.graph_objects as go
import pandas as pd
import uuid
import pickle
import numpy as np
from pycox.models import CoxPH
from scipy.stats import norm

# ─── Restore the pickled DeepSurv checkpoint ─────────────────────────────────
# NOTE: pickle.load can execute arbitrary code; only load checkpoint files
# produced by a trusted training run.
with open("pickle/deepsurv_model_20250429_192055.pkl", mode="rb") as f:
    checkpoint = pickle.load(f)

# ─── Rebuild the CoxPH wrapper and restore its fitted state ──────────────────
model = CoxPH(checkpoint["model_structure"])
model.net.load_state_dict(checkpoint["model_state_dict"])
# Fitted baseline hazards are restored as saved rather than recomputed.
model.baseline_hazards_, model.baseline_cumulative_hazards_ = (
    checkpoint["baseline_hazards"],
    checkpoint["baseline_cumulative_hazards"],
)
# Training inputs are kept so generate_single_patient can estimate the spread
# of predicted survival across the training set.
model.training_data = checkpoint["training_data"]
# Switch the network to inference mode before predicting.
model.net.eval()

# Two-sided confidence level for the plotted bands is 1 - ALPHA (95%).
# ALPHA = 0.05 makes Z95 the 97.5th normal percentile (~1.96), matching its name;
# 0.1 would give ~1.645, i.e. only a 90% interval.
ALPHA = 0.05
Z95 = norm.ppf(1 - ALPHA / 2)

# Predictor column order must match the order used in training.
PREDS = [
    "age", "size", "nodes", "prog", "oest",
    "treat_1", "men_2", "grade_2", "grade_3",
]

# ─── Most recent results, kept for the Save button ───────────────────────────
# NOTE(review): these are process-wide globals, so concurrent users of the app
# would overwrite each other's cached results.
last_haz_fig = None
last_surv_fig = None
last_table = None

def _median_survival_time(times, surv):
    """Return the first time at which ``surv`` is <= 0.5, or None if never.

    The median survival time is the smallest t with S(t) <= 0.5. When the curve
    stays above 0.5 over the whole time grid the median is "not reached" and
    None is returned (picking the point closest to 0.5 would misreport it).
    """
    crossed = np.flatnonzero(np.asarray(surv) <= 0.5)
    return times[crossed[0]] if crossed.size else None


def _ci_figure(times, estimate, lower, upper, title, y_title):
    """Build a Plotly figure of an estimated curve with dashed CI bounds.

    Traces are added in the order Estimate, Upper CI, Lower CI.
    """
    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=times, y=estimate, mode='lines', name='Estimate', line=dict(color='black')
    ))
    fig.add_trace(go.Scatter(
        x=times, y=upper, mode='lines', name='Upper CI', line=dict(dash='dash', color='magenta')
    ))
    fig.add_trace(go.Scatter(
        x=times, y=lower, mode='lines', name='Lower CI', line=dict(dash='dash', color='magenta')
    ))
    fig.update_layout(title=title, xaxis_title="Time", yaxis_title=y_title)
    return fig


def generate_single_patient(treat, age, men, size, grade, nodes, prog, oest):
    """Predict and plot cumulative-hazard and survival curves for one patient.

    Args:
        treat: Treatment indicator (0 or 1).
        age, size, nodes, prog, oest: Raw numeric values; scaled below exactly
            as the model inputs expect (age/100, size/100, nodes/10,
            prog/1000, oest/1000).
        men: Menopausal status (1 or 2); encoded as ``men_2``.
        grade: Tumour grade (1, 2 or 3); encoded as ``grade_2`` / ``grade_3``.

    Returns:
        ``(hazard_figure, survival_figure, status_message)``. When the form is
        empty or incomplete both figures are None and the message explains why.

    Side effects:
        Updates the module-level ``last_haz_fig`` / ``last_surv_fig`` /
        ``last_table`` used by the Save button, and on first use caches
        ``s_std`` / ``s_times`` on ``model``.
    """
    global last_haz_fig, last_surv_fig, last_table

    inputs = [treat, age, men, size, grade, nodes, prog, oest]
    # Nothing entered at all.
    if all(v in (None, 0) for v in inputs):
        last_haz_fig = last_surv_fig = None
        last_table = None
        return None, None, "Please enter patient values before plotting."
    # Partially filled form: the scaling below would raise TypeError on None.
    if any(v is None for v in inputs):
        last_haz_fig = last_surv_fig = None
        last_table = None
        return None, None, "Please fill in every patient field before plotting."

    # Build & scale features; column order must follow PREDS (training order).
    row = pd.DataFrame([{
        "age":     age / 100.0,
        "size":    size / 100.0,
        "nodes":   nodes / 10.0,
        "prog":    prog / 1000.0,
        "oest":    oest / 1000.0,
        "treat_1": int(treat),
        "men_2":   int(men == 2),
        "grade_2": int(grade == 2),
        "grade_3": int(grade == 3),
    }])[PREDS].astype("float32")
    features = row.values

    # Model predictions (one column for this patient, rows indexed by time).
    surv_df = model.predict_surv_df(features)
    cumhaz = model.predict_cumulative_hazards(features)
    times = surv_df.index.values
    S_hat = surv_df.values[:, 0]
    H_hat = cumhaz.values[:, 0]

    median_surv = _median_survival_time(times, S_hat)

    # Spread s of predicted survival across the training set at each time,
    # computed once and cached on the model object.
    if not hasattr(model, "s_std"):
        train_surv_df = model.predict_surv_df(model.training_data[0])
        model.s_std = train_surv_df.std(axis=1).values  # std across individuals per time
        model.s_times = train_surv_df.index.values
    # Align s with this patient's time grid.
    s_interp = np.interp(times, model.s_times, model.s_std)

    # Delta-method standard error on the log(-log S) scale: s / (S * |log S|).
    # Clipping keeps log() finite and the denominator non-zero.
    S_safe = np.clip(S_hat, 1e-6, 1 - 1e-6)
    S_CL = s_interp / (S_safe * np.abs(np.log(S_safe)))
    # Very large values (where S is close to 1) blow the band up; suppress them.
    S_CL[S_CL > 2.0] = 0  # threshold is tunable

    # Where predicted survival is very high, treat the cumulative hazard as 0.
    # np.where builds a new array instead of writing into the model's output.
    H_hat = np.where(S_hat > 0.98, 0.0, H_hat)

    # Log-log confidence bands.
    upper_S = S_hat ** np.exp(-Z95 * S_CL)
    lower_S = S_hat ** np.exp(Z95 * S_CL)
    lower_H = H_hat * np.exp(-Z95 * S_CL)
    upper_H = H_hat * np.exp(Z95 * S_CL)

    # Label the band with the level actually used to compute Z95.
    ci_label = f"{round(100 * (1 - ALPHA), 2):g}% CI"
    fig_haz = _ci_figure(times, H_hat, lower_H, upper_H,
                         f"Cumulative Hazard with {ci_label}", "Cumulative Hazard")
    fig_surv = _ci_figure(times, S_hat, lower_S, upper_S,
                          f"Survival Curve with {ci_label}", "Survival Probability")

    # Store for saving
    last_haz_fig = fig_haz
    last_surv_fig = fig_surv
    last_table = None

    if median_surv is None:
        median_text = f"not reached within {times[-1]:.1f} days"
    else:
        median_text = f"{median_surv:.1f} days"
    return (
        fig_haz,
        fig_surv,
        f"Predictions generated for single patient.\nMedian survival time: {median_text}",
    )

def generate_table_from_csv(csv_file):
    """Predict per-patient hazard summaries and median survival from a CSV.

    Args:
        csv_file: The uploaded file. If it has a ``.name`` attribute that path
            is read; otherwise it is passed straight to ``pd.read_csv``. The
            CSV must contain the raw columns treat, age, men, size, grade,
            nodes, prog and oest.

    Returns:
        ``(table, status_message)``. ``table`` is the input rows with
        treat/men/grade replaced by the one-hot columns the model uses, plus
        ``hazard_rate``, ``median_survival`` (NaN when S(t) never reaches 0.5)
        and a random ``id``. ``table`` is None when no file is given or the
        CSV cannot be used.
    """
    global last_haz_fig, last_surv_fig, last_table

    if csv_file is None:
        return None, "No CSV file provided."
    source = csv_file.name if hasattr(csv_file, 'name') else csv_file
    try:
        df = pd.read_csv(source)
    except (pd.errors.EmptyDataError, pd.errors.ParserError) as err:
        return None, f"Could not read CSV: {err}"

    required = ("treat", "age", "men", "size", "grade", "nodes", "prog", "oest")
    missing = [col for col in required if col not in df.columns]
    if missing:
        return None, f"CSV is missing required columns: {', '.join(missing)}"
    if df.empty:
        return None, "CSV file contains no data rows."

    # One-hot encode explicitly (reference levels treat 0, men 1, grade 1) so
    # every dummy column exists no matter which categories this CSV contains
    # or whether they parse as ints or floats. pd.get_dummies(drop_first=True)
    # would instead drop whichever level happens to come first in the file.
    encoded = df.drop(columns=["treat", "men", "grade"])
    encoded["treat_1"] = (df["treat"] == 1).astype(int)
    encoded["men_2"] = (df["men"] == 2).astype(int)
    encoded["grade_2"] = (df["grade"] == 2).astype(int)
    encoded["grade_3"] = (df["grade"] == 3).astype(int)

    # Same scaling as the single-patient path.
    X = encoded[PREDS].astype("float32").copy()
    X["age"] /= 100.0
    X["size"] /= 100.0
    X["nodes"] /= 10.0
    X["prog"] /= 1000.0
    X["oest"] /= 1000.0

    surv_df = model.predict_surv_df(X.values)
    cumhaz = model.predict_cumulative_hazards(X.values)
    times = surv_df.index.values

    # NOTE(review): this is each patient's cumulative hazard averaged over the
    # model's time grid, not an instantaneous hazard rate; confirm intent.
    hazard_rate = cumhaz.mean(axis=0).round(3)

    # Median survival = first time with S(t) <= 0.5 (rows: times, columns:
    # patients); NaN when the curve never reaches 0.5.
    crossed = surv_df.values <= 0.5
    first_idx = crossed.argmax(axis=0)
    medians = np.where(crossed.any(axis=0), times[first_idx], np.nan)

    encoded["hazard_rate"] = hazard_rate.values
    encoded["median_survival"] = np.round(medians, 1)
    encoded["id"] = [str(uuid.uuid4()) for _ in range(len(encoded))]

    last_table = encoded
    last_haz_fig = None
    last_surv_fig = None

    return encoded, "Predictions generated from CSV."

def save_new_plot():
    """Write the most recently generated plots to standalone HTML files.

    Files go to the current working directory as ``Hazard_plot.html`` and
    ``Survival_plot.html`` (same names every time) and load plotly.js from the
    CDN. Returns a status message for the UI; write errors are reported in the
    message rather than raised.
    """
    haz_fig, surv_fig = last_haz_fig, last_surv_fig
    if haz_fig is None or surv_fig is None:
        return "No plots to save. Generate first."

    outputs = (("Hazard_plot.html", haz_fig), ("Survival_plot.html", surv_fig))
    try:
        for filename, fig in outputs:
            fig.write_html(filename, include_plotlyjs="cdn")
    except Exception as err:
        return f"Error saving plots: {err}"
    return (
        "Saved hazard plot as `Hazard_plot.html`\n"
        "Saved survival plot as `Survival_plot.html`"
    )

def clean_single():
    """Reset the single-patient tab and forget any cached results.

    Returns one None per output component (hazard plot, survival plot, status).
    """
    global last_haz_fig, last_surv_fig, last_table
    last_haz_fig, last_surv_fig, last_table = None, None, None
    return (None,) * 3

def clean_csv():
    """Reset the CSV tab and forget any cached results.

    Returns one None per output component (table, status).
    """
    global last_haz_fig, last_surv_fig, last_table
    last_haz_fig, last_surv_fig, last_table = None, None, None
    return (None,) * 2

# ─── Gradio UI ────────────────────────────────────────────────────────────────
def main():
    """Build the two-tab Gradio interface and launch it locally.

    "Single Patient Input" plots cumulative-hazard and survival curves for one
    patient (Clear / Generate / Save); "Upload CSV" turns an uploaded CSV into a
    per-patient prediction table (Clear / Generate).
    """
    with gr.Blocks() as demo:
        gr.Markdown("# Hazard/Survivor Plot Selector")
        with gr.Tabs():
            with gr.Tab("Single Patient Input"):
                gr.Markdown("### Enter Patient Information")
                with gr.Row():
                    # precision=0 on every field so all values arrive as integers.
                    treat_input  = gr.Number(label="Treat (0 or 1)", precision=0)
                    age_input    = gr.Number(label="Age (integer)", precision=0)
                    men_input    = gr.Number(label="Menopausal Status (1 or 2)", precision=0)
                    size_input   = gr.Number(label="Tumour Size (mm)", precision=0)
                    grade_input  = gr.Number(label="Tumour Grade (1, 2, 3)", precision=0)
                    nodes_input  = gr.Number(label="Nodes (integer)", precision=0)
                    prog_input   = gr.Number(label="Progesterone (integer)", precision=0)
                    oest_input   = gr.Number(label="Oestrogen Status (integer)", precision=0)
                with gr.Row():
                    clear_single_btn    = gr.Button("Clear")
                    generate_single_btn = gr.Button("Generate")
                    save_single_btn     = gr.Button("Save")
                single_haz    = gr.Plot(label="Hazard Plot")
                single_surv   = gr.Plot(label="Survival Plot")
                single_status = gr.Textbox(label="Status")
                generate_single_btn.click(
                    fn=generate_single_patient,
                    inputs=[treat_input, age_input, men_input, size_input,
                            grade_input, nodes_input, prog_input, oest_input],
                    outputs=[single_haz, single_surv, single_status]
                )
                clear_single_btn.click(
                    fn=clean_single,
                    outputs=[single_haz, single_surv, single_status]
                )
                save_single_btn.click(fn=save_new_plot,
                                      outputs=[single_status])

            with gr.Tab("Upload CSV"):
                gr.Markdown("### Upload a CSV File to Generate a Table")
                # Only CSV uploads are meaningful for generate_table_from_csv.
                csv_input = gr.File(label="Upload CSV", file_types=[".csv"])
                with gr.Row():
                    clear_csv_btn    = gr.Button("Clear")
                    generate_csv_btn = gr.Button("Generate")
                csv_table  = gr.DataFrame(label="Generated Table",
                                          headers=["id", "hazard_rate", "median_survival"])
                csv_status = gr.Textbox(label="Status")
                generate_csv_btn.click(
                    fn=generate_table_from_csv,
                    inputs=[csv_input],
                    outputs=[csv_table, csv_status]
                )
                clear_csv_btn.click(fn=clean_csv,
                                     outputs=[csv_table, csv_status])

    # Launch after the Blocks context has closed, as in the Gradio Blocks docs,
    # rather than from inside the `with` body.
    demo.launch(share=False)

if __name__ == "__main__":
    main()


Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.
