In [None]:
import gradio as gr
import plotly.graph_objects as go
import pandas as pd
import uuid
import pickle
import numpy as np
from pycox.models import CoxPH

# Load the pickled DeepSurv checkpoint. It is read as a dict; the keys used
# below are 'model_structure', 'model_state_dict', 'baseline_hazards' and
# 'baseline_cumulative_hazards'. The path is relative to the current working
# directory.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file —
# only load checkpoints from a trusted source.
with open("pickle/deepsurv_model_20250417_133745.pkl", "rb") as f:
    checkpoint = pickle.load(f)

# Reconstruct the pycox CoxPH wrapper around the saved network and restore
# its trained weights.
model = CoxPH(checkpoint['model_structure'])
model.net.load_state_dict(checkpoint['model_state_dict'])
# Re-attach the stored baseline hazards; presumably computed at training time
# (e.g. by compute_baseline_hazards) — TODO confirm. The prediction calls
# (predict_cumulative_hazards / predict_surv_df) rely on them being set.
model.baseline_hazards_ = checkpoint['baseline_hazards']
model.baseline_cumulative_hazards_ = checkpoint['baseline_cumulative_hazards']
# Put the network in evaluation (inference) mode.
model.net.eval()

# Predictor column order must match training; used to select and order the
# feature columns before every prediction.
PREDS = [
    "age", "size", "nodes", "prog", "oest",
    "treat_1", "men_2", "grade_2", "grade_3"
]

# Module-level state shared by the Gradio callbacks: the most recently
# generated figure (single-patient tab) and table (CSV tab), used by Save.
last_figure = None
last_table = None

def generate_single_patient(treat, age, men, size, grade, nodes, prog, oest):
    """Predict and plot the cumulative-hazard and survival curves for one patient.

    Parameters are the raw values from the single-patient form: ``treat``
    (0 or 1), ``age``, ``men`` (menopausal status, 1 or 2), ``size``,
    ``grade`` (1, 2 or 3), ``nodes``, ``prog`` and ``oest``. They are one-hot
    encoded into the ``PREDS`` columns and scaled with the same constants used
    during training before being passed to ``model``.

    Returns:
        ``(hazard_figure, survival_figure, status_message)``. Both figures are
        ``None`` when the form is blank or any field is missing; the status
        message then says what to fix.

    Side effects:
        On success sets the module-level ``last_figure`` to the hazard figure
        (so Save can export it); always clears ``last_table``.
    """
    global last_figure, last_table

    names = ("treat", "age", "men", "size", "grade", "nodes", "prog", "oest")
    inputs = (treat, age, men, size, grade, nodes, prog, oest)

    # Early exit if all inputs are blank or zero.
    if all(v in (None, 0) for v in inputs):
        last_figure = None
        last_table = None
        return None, None, "Please enter patient values before plotting."

    # A partially filled form would crash on int(...)/comparisons below or
    # feed NaN to the model, so report the empty fields instead.
    missing = [name for name, value in zip(names, inputs) if value is None]
    if missing:
        last_figure = None
        last_table = None
        return None, None, f"Please enter a value for: {', '.join(missing)}."

    # Build a one-row feature frame in training column order; categorical
    # inputs become 0/1 indicator columns (reference levels: treat 0, men 1,
    # grade 1).
    row = pd.DataFrame([{
        "age":     age,
        "size":    size,
        "nodes":   nodes,
        "prog":    prog,
        "oest":    oest,
        "treat_1": int(treat == 1),
        "men_2":   int(men == 2),
        "grade_2": int(grade == 2),
        "grade_3": int(grade == 3),
    }])[PREDS].astype("float32")

    # Apply the same scaling used during training.
    row["age"]   /= 100.0
    row["size"]  /= 100.0
    row["nodes"] /= 10.0
    row["prog"]  /= 1000.0
    row["oest"]  /= 1000.0

    # Survival is S(t) = exp(-H(t)) (this is how pycox's predict_surv_df is
    # defined), so derive it from a single cumulative-hazard prediction
    # instead of running the network twice.
    cumhaz = model.predict_cumulative_hazards(row.values)
    times = cumhaz.index.values
    cumhaz_vals = cumhaz.values[:, 0]
    surv_vals = np.exp(-cumhaz_vals)

    # Plot cumulative hazard (monotonically non-decreasing).
    fig_haz = go.Figure()
    fig_haz.add_trace(go.Scatter(
        x=times,
        y=cumhaz_vals,
        mode='lines',
        name='Cumulative Hazard'
    ))
    fig_haz.update_layout(
        title="Cumulative Hazard Curve",
        xaxis_title="Time",
        yaxis_title="Cumulative Hazard"
    )

    # Plot survival curve.
    fig_surv = go.Figure()
    fig_surv.add_trace(go.Scatter(
        x=times,
        y=surv_vals,
        mode='lines',
        name='Survival Probability'
    ))
    fig_surv.update_layout(
        title="Survival Curve",
        xaxis_title="Time",
        yaxis_title="Survival Probability"
    )

    # Remember the hazard figure so the Save button can export it.
    last_figure = fig_haz
    last_table = None

    return fig_haz, fig_surv, "Predictions generated for single patient."


def generate_table_from_csv(csv_file):
    global last_figure, last_table

    if csv_file is None:
        return None, "No CSV file provided."

    # Read CSV into DataFrame
    df = pd.read_csv(csv_file.name) if hasattr(csv_file, 'name') else pd.read_csv(csv_file)

    # One-hot encode categorical vars (must match training)
    df = pd.get_dummies(df, columns=["treat", "men", "grade"], drop_first=True)
    for col in ["treat_1", "men_2", "grade_2", "grade_3"]:
        if col not in df.columns:
            df[col] = 0

    # Ensure predictor order and types
    X = df[PREDS].astype("float32").copy()

    # Apply same scaling
    X["age"]   /= 100.0
    X["size"]  /= 100.0
    X["nodes"] /= 10.0
    X["prog"]  /= 1000.0
    X["oest"]  /= 1000.0

    # Predict
    surv_df = model.predict_surv_df(X.values)
    cumhaz  = model.predict_cumulative_hazards(X.values)

    # Cumulative hazard is monotonic
    hazard_rates = cumhaz.mean(axis=0).round(3)
    times = surv_df.index.values
    # median survival per row:
    median_surv = [
        times[np.abs(s - 0.5).argmin()] for s in surv_df.values.T
    ]

    # Append results
    df["hazard_rate"]    = hazard_rates.values
    df["median_survival"] = np.round(median_surv, 1)
    df["id"]              = [str(uuid.uuid4()) for _ in range(len(df))]

    last_table  = df
    last_figure = None
    return df, "Predictions generated from CSV"


def save_new_plot():
    """Export the most recently generated result to the working directory.

    A figure remembered in ``last_figure`` is written as
    ``Last_Generated_Plot.jpg`` via the figure's ``write_image``. Otherwise a
    table remembered in ``last_table`` is written as
    ``Last_Generated_Table.csv``. The table fallback is what lets the CSV
    tab's Save button work, since generating a table clears ``last_figure``.

    Returns:
        A status message for the UI. Export errors (for example a missing
        static-image backend for ``write_image``) are reported in the message
        rather than raised, because this runs as a UI callback.
    """
    if last_figure is not None:
        fn = "Last_Generated_Plot.jpg"
        try:
            last_figure.write_image(fn)
        except Exception as e:  # UI boundary: surface the error to the user
            return f"Error saving plot: {e}"
        return f"Plot saved as {fn}"

    if last_table is not None:
        fn = "Last_Generated_Table.csv"
        try:
            last_table.to_csv(fn, index=False)
        except Exception as e:  # UI boundary: surface the error to the user
            return f"Error saving table: {e}"
        return f"Table saved as {fn}"

    return "No plot or table to save."


def clean_single():
    """Reset the single-patient tab.

    Forgets any remembered figure or table and returns blank values for the
    hazard plot, the survival plot and the status box.
    """
    global last_figure, last_table
    last_figure = last_table = None
    return (None,) * 3


def clean_csv():
    """Reset the CSV tab.

    Forgets any remembered figure or table and returns blank values for the
    results table and the status box.
    """
    global last_figure, last_table
    blank = None
    last_table = blank
    last_figure = blank
    return blank, blank


# --- Gradio Interface ---
# Two tabs: a single-patient form that plots cumulative-hazard and survival
# curves, and a CSV upload that produces a per-patient results table. Both
# tabs' Save buttons call save_new_plot, which exports the module-level
# state set by the generate callbacks.
with gr.Blocks() as demo:
    gr.Markdown("# Hazard/Survivor Plot Selector")

    with gr.Tabs():
        with gr.Tab("Single Patient Input"):
            gr.Markdown("### Enter Patient Information")
            with gr.Row():
                # precision=0 on every field: treat/men/grade are category
                # codes and the rest are entered as integers.
                treat_input = gr.Number(label="Treat (0 or 1)", precision=0)
                age_input   = gr.Number(label="Age (integer)", precision=0)
                men_input   = gr.Number(label="Menopausal Status (1 or 2)", precision=0)
                size_input  = gr.Number(label="Tumour Size (mm)", precision=0)
                grade_input = gr.Number(label="Tumour Grade (1, 2, 3)", precision=0)
                nodes_input = gr.Number(label="Nodes (integer)", precision=0)
                prog_input  = gr.Number(label="Progesterone (integer)", precision=0)
                oest_input  = gr.Number(label="Oestrogen Status (integer)", precision=0)

            with gr.Row():
                clear_single_btn    = gr.Button("Clear")
                generate_single_btn = gr.Button("Generate")
                save_single_btn     = gr.Button("Save")

            single_haz    = gr.Plot(label="Hazard Plot")
            single_surv   = gr.Plot(label="Survivor Plot")
            single_status = gr.Textbox(label="Status")

            generate_single_btn.click(
                fn=generate_single_patient,
                inputs=[
                    treat_input, age_input, men_input, size_input,
                    grade_input, nodes_input, prog_input, oest_input
                ],
                outputs=[single_haz, single_surv, single_status]
            )
            clear_single_btn.click(
                fn=clean_single,
                outputs=[single_haz, single_surv, single_status]
            )
            save_single_btn.click(fn=save_new_plot, outputs=single_status)

        with gr.Tab("Upload CSV"):
            gr.Markdown("### Upload a CSV File to Generate a Table")
            csv_file_input = gr.File(label="Upload CSV", file_types=[".csv"])

            with gr.Row():
                clear_csv_btn    = gr.Button("Clear")
                generate_csv_btn = gr.Button("Generate")
                save_csv_btn     = gr.Button("Save")

            csv_table_output = gr.DataFrame(
                label="Generated Table",
                headers=["id", "hazard_rate", "median_survival"]
            )
            csv_status = gr.Textbox(label="Status")

            generate_csv_btn.click(
                fn=generate_table_from_csv,
                inputs=csv_file_input,
                outputs=[csv_table_output, csv_status]
            )
            clear_csv_btn.click(
                fn=clean_csv,
                outputs=[csv_table_output, csv_status]
            )
            save_csv_btn.click(fn=save_new_plot, outputs=csv_status)

# Launch only after the Blocks context has closed, so the layout and event
# wiring are finalized before the server starts.
demo.launch(share=False)


  from .autonotebook import tqdm as notebook_tqdm


ModuleNotFoundError: No module named 'torch'