In [6]:
import gradio as gr
import plotly.graph_objects as go
import pandas as pd
import uuid
import pickle
import numpy as np
from pycox.models import CoxPH
from scipy.stats import norm
import plotly.io as pio
import kaleido             # ensure Kaleido is importable
pio.renderers.default = 'png'  # optional, but locks in Kaleido for any `to_image()` calls


# Load your pickled DeepSurv model
with open("pickle/deepsurv_model_20250417_133745.pkl", "rb") as f:
    checkpoint = pickle.load(f)


# Reconstruct CoxPH model and load weights
model = CoxPH(checkpoint['model_structure'])
model.net.load_state_dict(checkpoint['model_state_dict'])
model.baseline_hazards_ = checkpoint['baseline_hazards']
model.baseline_cumulative_hazards_ = checkpoint['baseline_cumulative_hazards']
model.net.eval()


ALPHA = 0.1
Z95 = norm.ppf(1-ALPHA/2)

# Predictor column order must match training
PREDS = [
    "age", "size", "nodes", "prog", "oest",
    "treat_1", "men_2", "grade_2", "grade_3"
]

# Globals for save
last_figure = None
last_table = None

# 95% CI z-value
# Complementary log-log CI width factor (adjustable)



def generate_single_patient(treat, age, men, size, grade, nodes, prog, oest):
    global last_figure, last_table

    # Early-exit if no inputs
    inputs = [treat, age, men, size, grade, nodes, prog, oest]
    if all(v in (None, 0) for v in inputs):
        last_figure = None
        last_table  = None
        return None, None, "Please enter patient values before plotting."

    # Build feature row
    row = pd.DataFrame([{
        "age":     age,
        "size":    size,
        "nodes":   nodes,
        "prog":    prog,
        "oest":    oest,
        "treat_1": int(treat),
        "men_2":   int(men == 2),
        "grade_2": int(grade == 2),
        "grade_3": int(grade == 3)
    }])[PREDS].astype("float32")

    # Apply same scaling used in training
    row["age"]   /= 100.0
    row["size"]  /= 100.0
    row["nodes"] /= 10.0
    row["prog"]  /= 1000.0
    row["oest"]  /= 1000.0

    # Predict survival and cumulative hazard
    surv_df = model.predict_surv_df(row.values)
    cumhaz  = model.predict_cumulative_hazards(row.values)

    # Extract arrays
    times = surv_df.index.values
    S_hat = surv_df.values[:, 0]
    S_hat[S_hat>0.9] = 1
    
    H_hat = cumhaz.values[:, 0]
    H_hat[S_hat>0.9] = 0

    #0.05 Mock
    
    S_CL = 0.05/(S_hat * np.abs(np.log(S_hat)))

     # ────────────────────────────────────────────────────────────────────────────
     # If S_CL is infinite or unreasonably large, treat uncertainty as zero
    THRESHOLD = 100.0  # <-- tune this by trial & error
    mask = np.isinf(S_CL) | (S_CL > THRESHOLD)
    S_CL[mask] = 0.0
     # ────────────────────────────────────────────────────────────────────────────
    
    lower_S = S_hat**np.exp(-Z95 * S_CL)
    upper_S = S_hat**np.exp( Z95 * S_CL)

    #lower_S[S_CL==np.inf] = 1
    #upper_S[S_CL==np.inf] = 1
    
    lower_H = H_hat * np.exp(-Z95 * S_CL)
    upper_H = H_hat * np.exp( Z95 * S_CL)

    #lower_H[S_CL==np.inf] = 0
    #upper_H[S_CL==np.inf] = 0

    # Plot cumulative hazard + CI bands
    fig_haz = go.Figure()
    fig_haz.add_trace(go.Scatter(
        x=times, y=H_hat, mode='lines', name='Estimate', line=dict(color='black')
    ))
    fig_haz.add_trace(go.Scatter(
        x=times, y=upper_H, mode='lines', name='Upper CI', line=dict(dash='dash', color='magenta')
    ))
    fig_haz.add_trace(go.Scatter(
        x=times, y=lower_H, mode='lines', name='Lower CI', line=dict(dash='dash', color='magenta')
    ))
    fig_haz.update_layout(
        title="Cumulative Hazard with 95% CI",
        xaxis_title="Time",
        yaxis_title="Cumulative Hazard"
    )

    # Plot survival + CI bands
    fig_surv = go.Figure()
    fig_surv.add_trace(go.Scatter(
        x=times, y=S_hat, mode='lines', name='Estimate', line=dict(color='black')
    ))
    fig_surv.add_trace(go.Scatter(
        x=times, y=upper_S, mode='lines', name='Lower CI', line=dict(dash='dash', color='magenta')
    ))
    fig_surv.add_trace(go.Scatter(
        x=times, y=lower_S, mode='lines', name='Upper CI', line=dict(dash='dash', color='magenta')
    ))
    fig_surv.update_layout(
        title="Survival Curve with 95% CI",
        xaxis_title="Time",
        yaxis_title="Survival Probability"
    )

    # Keep for saving
    last_figure = fig_haz
    last_table  = None

    return fig_haz, fig_surv, "Predictions generated for single patient."


def generate_table_from_csv(csv_file):
    global last_figure, last_table
    if csv_file is None:
        return None, "No CSV file provided."
    df = pd.read_csv(csv_file.name) if hasattr(csv_file, 'name') else pd.read_csv(csv_file)
    df = pd.get_dummies(df, columns=["treat", "men", "grade"], drop_first=True)
    for col in ["treat_1", "men_2", "grade_2", "grade_3"]:
        if col not in df.columns:
            df[col] = 0
    X = df[PREDS].astype("float32").copy()
    X["age"]   /= 100.0
    X["size"]  /= 100.0
    X["nodes"] /= 10.0
    X["prog"]  /= 1000.0
    X["oest"]  /= 1000.0
    surv_df = model.predict_surv_df(X.values)
    cumhaz  = model.predict_cumulative_hazards(X.values)
    hazard_rate = cumhaz.mean(axis=0).round(3)
    times = surv_df.index.values
    medians = [times[np.abs(s - 0.5).argmin()] for s in surv_df.values.T]
    df["hazard_rate"]     = hazard_rate.values
    df["median_survival"] = np.round(medians, 1)
    df["id"]              = [str(uuid.uuid4()) for _ in range(len(df))]
    last_table  = df
    last_figure = None
    return df, "Predictions generated from CSV"


def save_new_plot():
    global last_figure
    if last_figure is None:
        return "No plot to save."
    try:
        fn = "Last_Generated_Plot.jpg"
        # Render via Kaleido (no browser ever launched)
        img_bytes = last_figure.to_image(format="jpg", engine="kaleido")
        with open(fn, "wb") as f:
            f.write(img_bytes)
        return f"Plot saved as {fn}"
    except Exception as e:
        return f"Error saving plot: {e}"


def clean_single():
    global last_figure, last_table
    last_figure = None
    last_table  = None
    return None, None, None


def clean_csv():
    global last_figure, last_table
    last_figure = None
    last_table  = None
    return None, None

# Gradio UI
def main():
    with gr.Blocks() as demo:
        gr.Markdown("# Hazard/Survivor Plot Selector")
        with gr.Tabs():
            with gr.Tab("Single Patient Input"):
                gr.Markdown("### Enter Patient Information")
                with gr.Row():
                    treat_input  = gr.Number(label="Treat (0 or 1)")
                    age_input    = gr.Number(label="Age (integer)", precision=0)
                    men_input    = gr.Number(label="Menopausal Status (1 or 2)", precision=0)
                    size_input   = gr.Number(label="Tumour Size (mm)", precision=0)
                    grade_input  = gr.Number(label="Tumour Grade (1, 2, 3)", precision=0)
                    nodes_input  = gr.Number(label="Nodes (integer)", precision=0)
                    prog_input   = gr.Number(label="Progesterone (integer)", precision=0)
                    oest_input   = gr.Number(label="Oestrogen Status (integer)", precision=0)
                with gr.Row():
                    clear_single_btn    = gr.Button("Clear")
                    generate_single_btn = gr.Button("Generate")
                    save_single_btn     = gr.Button("Save")
                single_haz    = gr.Plot(label="Hazard Plot")
                single_surv   = gr.Plot(label="Survivor Plot")
                single_status = gr.Textbox(label="Status")
                generate_single_btn.click(
                    fn=generate_single_patient,
                    inputs=[treat_input, age_input, men_input, size_input,
                            grade_input, nodes_input, prog_input, oest_input],
                    outputs=[single_haz, single_surv, single_status]
                )
                clear_single_btn.click(
                    fn=clean_single,
                    outputs=[single_haz, single_surv, single_status]
                )
                save_single_btn.click(fn=save_new_plot, outputs=single_status)
            with gr.Tab("Upload CSV"):
                gr.Markdown("### Upload a CSV File to Generate a Table")
                csv_input = gr.File(label="Upload CSV")
                with gr.Row():
                    clear_csv_btn    = gr.Button("Clear")
                    generate_csv_btn = gr.Button("Generate")
                    save_csv_btn     = gr.Button("Save")
                csv_table = gr.DataFrame(label="Generated Table",
                                         headers=["id", "hazard_rate", "median_survival"])
                csv_status = gr.Textbox(label="Status")
                generate_csv_btn.click(
                    fn=generate_table_from_csv,
                    inputs=csv_input,
                    outputs=[csv_table, csv_status]
                )
                clear_csv_btn.click(fn=clean_csv,
                                     outputs=[csv_table, csv_status])
                save_csv_btn.click(fn=save_new_plot, outputs=csv_status)
        demo.launch(share=False, inbrowser=False)

if __name__ == "__main__":
    main()

Running on local URL:  http://127.0.0.1:7866

To create a public link, set `share=True` in `launch()`.



divide by zero encountered in divide

Task exception was never retrieved
future: <Task finished name='Task-286' coro=<Browser._open_async() done, defined at /opt/miniconda3/lib/python3.9/site-packages/choreographer/browser.py:246> exception=BrowserFailedError('The browser seemed to close immediately after starting. Perhaps adding debug_browser=True will help.')>
Traceback (most recent call last):
  File "/opt/miniconda3/lib/python3.9/site-packages/choreographer/browser.py", line 271, in _open_async
    await self.populate_targets()
  File "/opt/miniconda3/lib/python3.9/site-packages/choreographer/browser.py", line 603, in populate_targets
    response = await self.browser.send_command("Target.getTargets")
choreographer.browser.BrowserClosedError: Command not completed because browser closed.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/opt/miniconda3/lib/python3.9/site-packages/choreographer/browser.py", line 274, i