# In [7]:
import gradio as gr
import plotly.graph_objects as go
import pandas as pd
import uuid
import random
import pickle
import torch
from pycox.models import CoxPH
import torchtuples as tt
import numpy as np

# Load the pickled DeepSurv checkpoint. The dict read below must provide the
# keys 'model_structure', 'model_state_dict', 'baseline_hazards' and
# 'baseline_cumulative_hazards'. pickle.load can execute arbitrary code, so
# only point this at a checkpoint file you produced yourself.
with open("pickle/deepsurv_model_20250410_163846.pkl", "rb") as f:
    checkpoint = pickle.load(f)

# Rebuild the pycox CoxPH wrapper around the saved network and restore its
# trained weights. The baseline hazards are assigned directly rather than
# recomputed; presumably because the training data is not available here and
# the survival/cumulative-hazard predictions depend on them — confirm.
model = CoxPH(checkpoint['model_structure'])
model.net.load_state_dict(checkpoint['model_state_dict'])
model.baseline_hazards_ = checkpoint['baseline_hazards']
model.baseline_cumulative_hazards_ = checkpoint['baseline_cumulative_hazards']

# Put the torch network in evaluation (inference) mode.
model.net.eval()


# Model input columns, in the exact order the network was trained on; every
# feature matrix must be reordered to this list before prediction.
PREDS = [
    # continuous covariates
    "age", "size", "nodes", "prog", "oest",
    # indicator columns for the categorical inputs treat / men / grade
    "treat_1", "men_2", "grade_2", "grade_3",
]

# Most recent figure / table produced by the generate callbacks; the clear
# callbacks reset both back to None.
last_figure = None
last_table = None

# Original plotting functions
def _plot_interval_data(data, title, kind, y_label):
    """Plot a mean curve together with its 5.5% / 94.5% HDI bounds.

    Shared implementation behind plot_hazard_data and plot_survivor_data.

    Args:
        data: DataFrame with 'time', 'mean', 'hdi_5.5%' and 'hdi_94.5%' columns.
        title: Figure title.
        kind: Capitalised name of the plotted quantity ("Hazard" / "Survivor");
            only used to build error messages.
        y_label: Y-axis title.

    Returns:
        A plotly Figure with three line traces: the mean ("Best Guess") and the
        upper and lower HDI bounds.

    Raises:
        ValueError: if data is None or lacks any of the required columns.
    """
    if data is None:
        raise ValueError(f"{kind} data not loaded")

    required_columns = ['time', 'mean', 'hdi_5.5%', 'hdi_94.5%']
    missing = [col for col in required_columns if col not in data.columns]
    if missing:
        # Name the missing columns so a malformed input is easy to diagnose.
        raise ValueError(
            f"Required columns not found in {kind.lower()} data: {missing}"
        )

    time = data['time']
    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=time, y=data['mean'], mode='lines', name='Best Guess',
        line=dict(color='black'),
        hovertemplate="Time: %{x}<br>Best Guess: %{y}<extra></extra>"
    ))
    # Both interval bounds share one dashed style so they read as a band.
    for column, label in (('hdi_94.5%', 'HDI 94.5%'), ('hdi_5.5%', 'HDI 5.5%')):
        fig.add_trace(go.Scatter(
            x=time, y=data[column], mode='lines', name=label,
            line=dict(color='magenta', dash='dash'),
            hovertemplate=f"Time: %{{x}}<br>{label}: %{{y}}<extra></extra>"
        ))

    fig.update_layout(
        title=title,
        xaxis_title='Time',
        yaxis_title=y_label,
        hovermode='x'
    )
    return fig


def plot_hazard_data(hazard_data, title):
    """Plot a hazard curve (mean plus 5.5% / 94.5% HDI bounds).

    Args:
        hazard_data: DataFrame with 'time', 'mean', 'hdi_5.5%', 'hdi_94.5%'.
        title: Figure title.

    Returns:
        A plotly Figure with the y-axis titled 'Hazard Rate'.

    Raises:
        ValueError: if hazard_data is None or lacks a required column.
    """
    return _plot_interval_data(hazard_data, title, "Hazard", "Hazard Rate")


def plot_survivor_data(survivor_data, title):
    """Plot a survivor curve (mean plus 5.5% / 94.5% HDI bounds).

    Args:
        survivor_data: DataFrame with 'time', 'mean', 'hdi_5.5%', 'hdi_94.5%'.
        title: Figure title.

    Returns:
        A plotly Figure with the y-axis titled 'Survivor Rate'.

    Raises:
        ValueError: if survivor_data is None or lacks a required column.
    """
    return _plot_interval_data(survivor_data, title, "Survivor", "Survivor Rate")

# Gradio callbacks: prediction, saving and clearing

def generate_single_patient(treat, age, men, size, grade, nodes, prog, oest):
    """Predict and plot survival and hazard curves for a single patient.

    Gradio callback for the "Single Patient Input" tab. Categorical inputs are
    turned into the indicator columns listed in PREDS (treat_1, men_2,
    grade_2, grade_3) before being passed to the model.

    Args:
        treat: Treatment code, 0 or 1.
        age, size, nodes, prog, oest: Continuous covariates.
        men: Menopausal status code, 1 or 2.
        grade: Tumour grade code, 1, 2 or 3.

    Returns:
        (hazard_figure, survival_figure, status_message). When an input is
        missing or a categorical code is out of range, both figures are None
        and the message describes the problem.

    Side effects:
        On success stores the survival figure in ``last_figure`` (used by the
        Save button) and clears ``last_table``.
    """
    global last_figure
    global last_table

    inputs = {
        "treat": treat, "age": age, "men": men, "size": size,
        "grade": grade, "nodes": nodes, "prog": prog, "oest": oest,
    }
    # Empty gr.Number fields arrive as None; int(None) would crash the callback.
    missing = [name for name, value in inputs.items() if value is None]
    if missing:
        return None, None, f"Missing input(s): {', '.join(missing)}."

    # Codes outside these sets would otherwise be silently encoded as the
    # reference level (all indicators 0).
    allowed_levels = {"treat": (0, 1), "men": (1, 2), "grade": (1, 2, 3)}
    invalid = [name for name, levels in allowed_levels.items()
               if inputs[name] not in levels]
    if invalid:
        return None, None, f"Invalid value for: {', '.join(invalid)}."

    # One-row feature matrix in the column order the network was trained on.
    row = pd.DataFrame([{
        "age": age,
        "size": size,
        "nodes": nodes,
        "prog": prog,
        "oest": oest,
        "treat_1": int(treat == 1),
        "men_2": int(men == 2),
        "grade_2": int(grade == 2),
        "grade_3": int(grade == 3),
    }])[PREDS].astype("float32")
    x = row.values

    # Survival curve: time-indexed DataFrame with one column per input row.
    surv = model.predict_surv_df(x)
    # pycox's CoxPH has no predict_haz(); derive the per-interval hazard from
    # the cumulative hazard. The first interval starts from H = 0, so its
    # increment is the first cumulative value itself (not 0).
    cumulative_haz = model.predict_cumulative_hazards(x)
    haz = cumulative_haz.diff().fillna(cumulative_haz)

    fig_surv = go.Figure()
    fig_surv.add_trace(go.Scatter(
        x=surv.index.values, y=surv.values[:, 0],
        mode='lines', name='Survival Probability'
    ))
    fig_surv.update_layout(title="Survival Curve", xaxis_title="Time", yaxis_title="Survival")

    fig_haz = go.Figure()
    fig_haz.add_trace(go.Scatter(
        x=haz.index.values, y=haz.values[:, 0],
        mode='lines', name='Hazard Rate'
    ))
    fig_haz.update_layout(title="Hazard Curve", xaxis_title="Time", yaxis_title="Hazard")

    last_figure = fig_surv
    last_table = None

    return fig_haz, fig_surv, "Predictions generated for single patient."


# Raw columns an uploaded CSV must provide (before indicator encoding).
_CSV_REQUIRED_COLUMNS = ["age", "size", "nodes", "prog", "oest", "treat", "men", "grade"]
# Accepted codes for the categorical inputs; anything else is rejected.
_CSV_CATEGORY_LEVELS = {"treat": (0, 1), "men": (1, 2), "grade": (1, 2, 3)}


def _encode_categoricals(df):
    """Return a copy of df with treat/men/grade replaced by indicator columns.

    Encodes against fixed reference levels (treat 0, men 1, grade 1), matching
    the single-patient encoding, so the result does not depend on which
    categories happen to appear in the upload. (pd.get_dummies with
    drop_first=True drops whichever level is first *in the data*, which
    mis-encodes uploads that lack the reference level.)
    """
    indicators = pd.DataFrame({
        "treat_1": (df["treat"] == 1).astype(int),
        "men_2": (df["men"] == 2).astype(int),
        "grade_2": (df["grade"] == 2).astype(int),
        "grade_3": (df["grade"] == 3).astype(int),
    }, index=df.index)
    return pd.concat([df.drop(columns=["treat", "men", "grade"]), indicators], axis=1)


def _median_survival_times(surv):
    """Return the median survival time for each column of a survival DataFrame.

    surv is indexed by time with one survival curve per column. The median is
    the first time at which S(t) <= 0.5; curves that never reach 0.5 within
    the predicted horizon get NaN (median not reached).
    """
    times = surv.index.values
    reached = surv.values <= 0.5            # (n_times, n_patients)
    first = reached.argmax(axis=0)          # first True per column (0 if none)
    medians = times[first].astype(float)
    medians[~reached.any(axis=0)] = np.nan
    return medians


def generate_table_from_csv(csv_file):
    """Predict a hazard summary and median survival for every row of a CSV.

    Gradio callback for the "Upload CSV" tab.

    Args:
        csv_file: The uploaded file, either an object with a ``.name`` path
            attribute or a plain path string (depends on the Gradio version).
            The CSV must contain the columns in _CSV_REQUIRED_COLUMNS.

    Returns:
        (table, status_message). table is the input with treat/men/grade
        replaced by indicator columns plus 'hazard_rate', 'median_survival'
        and a random 'id'; it is None when the input is missing or invalid, in
        which case the message explains why.

    Side effects:
        On success stores the table in ``last_table`` and clears
        ``last_figure``.
    """
    global last_figure
    global last_table

    if csv_file is None:
        return None, "No CSV file provided."

    path = getattr(csv_file, "name", csv_file)
    try:
        df = pd.read_csv(path)
    except (OSError, ValueError) as err:  # unreadable file / malformed CSV
        return None, f"Could not read CSV: {err}"

    missing = [col for col in _CSV_REQUIRED_COLUMNS if col not in df.columns]
    if missing:
        return None, f"CSV is missing required column(s): {', '.join(missing)}"
    if df.empty:
        return None, "CSV contains no rows."
    incomplete = [col for col in _CSV_REQUIRED_COLUMNS if df[col].isna().any()]
    if incomplete:
        return None, f"CSV has missing values in column(s): {', '.join(incomplete)}"
    invalid = [col for col, levels in _CSV_CATEGORY_LEVELS.items()
               if not df[col].isin(levels).all()]
    if invalid:
        return None, f"CSV has invalid codes in column(s): {', '.join(invalid)}"

    df = _encode_categoricals(df)
    # Reorder columns to match the model input.
    x = df[PREDS].astype("float32").values

    # Survival curves: time-indexed DataFrame, one column per patient.
    surv = model.predict_surv_df(x)

    # Per-interval hazard from the cumulative hazard; the first interval
    # starts from H = 0, so its increment is the first cumulative value.
    cumulative_haz = model.predict_cumulative_hazards(x)
    hazard = cumulative_haz.diff().fillna(cumulative_haz)

    # Mean hazard increment per patient across the time grid.
    df["hazard_rate"] = hazard.mean(axis=0).round(3).values
    df["median_survival"] = np.round(_median_survival_times(surv), 1)

    df["id"] = [str(uuid.uuid4()) for _ in range(len(df))]
    last_table = df
    last_figure = None

    return df, "Predictions generated from CSV"

def save_new_plot():
    """Save the most recently generated result to the working directory.

    Both tabs' Save buttons call this. The single-patient tab leaves a figure
    in ``last_figure`` and the CSV tab leaves a table in ``last_table`` (each
    generate callback clears the other). The figure is written as JPEG via
    plotly's ``write_image``, which needs plotly's static-image export support
    (e.g. the kaleido package). The table is written as CSV.

    Returns:
        A status message for the UI. Export errors are reported in the
        message rather than raised.
    """
    if last_figure is not None:
        filename = "Last_Generated_Plot.jpg"
        try:
            last_figure.write_image(filename, format="jpg")
        except Exception as e:  # UI boundary: show any export failure to the user
            return f"Error saving plot: {e}"
        return f"Plot saved as {filename}"

    if last_table is not None:
        filename = "Last_Generated_Table.csv"
        try:
            last_table.to_csv(filename, index=False)
        except OSError as e:
            return f"Error saving table: {e}"
        return f"Table saved as {filename}"

    return "No plot to save."

# Clear functions: one for single patient and one for CSV
def _reset_last_outputs():
    """Drop the cached last_figure / last_table module-level values."""
    global last_figure, last_table
    last_figure = last_table = None


def clean_single():
    """Reset cached results and blank the three single-patient outputs."""
    _reset_last_outputs()
    return (None,) * 3


def clean_csv():
    """Reset cached results and blank the two CSV-tab outputs."""
    _reset_last_outputs()
    return (None,) * 2


# Two-tab Gradio UI: per-patient prediction with plots, and batch prediction
# from an uploaded CSV shown as a table. Launched at the end of this cell.
with gr.Blocks() as demo:
    gr.Markdown("# Hazard/Survivor Plot Selector")

    with gr.Tabs():
        # 1) Single Patient Input Tab (table removed)
        with gr.Tab("Single Patient Input"):
            gr.Markdown("### Enter Patient Information")
            with gr.Row():
                # precision=0 keeps every field integer-valued, including the
                # categorical codes (treat, menopausal status, grade).
                treat_input = gr.Number(label="Treat (0 or 1)", precision=0)
                age_input = gr.Number(label="Age (integer)", precision=0)
                men_input = gr.Number(label="Menopausal Status (1 or 2)", precision=0)
                size_input = gr.Number(label="Tumour Size (mm)", precision=0)
                grade_input = gr.Number(label="Tumour Grade (1,2,3)", precision=0)
                nodes_input = gr.Number(label="Nodes (integer)", precision=0)
                prog_input = gr.Number(label="Progesterone (integer)", precision=0)
                oest_input = gr.Number(label="Oestrogen Status (integer)", precision=0)

            # Button row: Clear (left), Generate (middle), Save (right)
            with gr.Row():
                clear_single_btn = gr.Button("Clear")
                generate_single_btn = gr.Button("Generate")
                save_single_btn = gr.Button("Save")

            # Output areas (table removed)
            single_hazard_plot_area = gr.Plot(label="Hazard Plot")
            single_survivor_plot_area = gr.Plot(label="Survivor Plot")
            single_save_status = gr.Textbox(label="Save Status")

            # Link the functions for the single patient tab. The input order
            # must match generate_single_patient's positional parameters, and
            # its return tuple is (hazard figure, survival figure, message).
            generate_single_btn.click(
                fn=generate_single_patient,
                inputs=[
                    treat_input, age_input, men_input, size_input,
                    grade_input, nodes_input, prog_input, oest_input
                ],
                outputs=[
                    single_hazard_plot_area,
                    single_survivor_plot_area,
                    single_save_status
                ]
            )

            clear_single_btn.click(
                fn=clean_single,
                outputs=[
                    single_hazard_plot_area,
                    single_survivor_plot_area,
                    single_save_status
                ]
            )

            save_single_btn.click(
                fn=save_new_plot,
                outputs=single_save_status
            )

        # 2) Upload CSV Tab (graph outputs removed)
        with gr.Tab("Upload CSV"):
            gr.Markdown("### Upload a CSV File to Generate a Table")
            csv_file_input = gr.File(label="Upload CSV", file_types=[".csv"])

            # Button row: Clear (left), Generate (middle), Save (right)
            with gr.Row():
                clear_csv_btn = gr.Button("Clear")
                generate_csv_btn = gr.Button("Generate")
                save_csv_btn = gr.Button("Save")

            # Output areas (graphs removed). The headers are placeholders shown
            # before the first run; they name columns that
            # generate_table_from_csv adds to its result table.
            csv_table_output = gr.DataFrame(
                label="Generated Table",
                headers=["id", "hazard_rate", "median_survival"]
            )
            csv_save_status = gr.Textbox(label="Save Status")

            # Link the functions for the CSV tab
            generate_csv_btn.click(
                fn=generate_table_from_csv,
                inputs=csv_file_input,
                outputs=[csv_table_output, csv_save_status]
            )

            clear_csv_btn.click(
                fn=clean_csv,
                outputs=[csv_table_output, csv_save_status]
            )

            save_csv_btn.click(
                fn=save_new_plot,
                outputs=csv_save_status
            )

demo.launch(share=False)

# Output:
# Running on local URL:  http://127.0.0.1:7867
#
# To create a public link, set `share=True` in `launch()`.


