In [3]:
import gradio as gr
import plotly.graph_objects as go
import pandas as pd
import uuid
import random

# Keep track of the last generated figure and table globally
last_figure = None
last_table = None

# Original plotting functions
def plot_hazard_data(hazard_data, title):
    """Build a Plotly line chart of a hazard curve and its interval bounds.

    Args:
        hazard_data: DataFrame that must contain the columns ``time``,
            ``mean``, ``hdi_5.5%`` and ``hdi_94.5%``.
        title: Title shown above the figure.

    Returns:
        A ``go.Figure`` with three line traces, in this order: the mean
        ("Best Guess", solid black), then the 94.5% and 5.5% bounds
        (dashed magenta).

    Raises:
        ValueError: If ``hazard_data`` is None or a required column is missing.
    """
    if hazard_data is None:
        raise ValueError("Hazard data not loaded")

    missing = {'time', 'mean', 'hdi_5.5%', 'hdi_94.5%'} - set(hazard_data.columns)
    if missing:
        raise ValueError("Required columns not found in hazard data")

    # (source column, legend label, line style, hover template); list order
    # is the order the traces are added to the figure.
    trace_specs = [
        ('mean', 'Best Guess', {'color': 'black'},
         "Time: %{x}<br>Best Guess: %{y}<extra></extra>"),
        ('hdi_94.5%', 'HDI 94.5%', {'color': 'magenta', 'dash': 'dash'},
         "Time: %{x}<br>HDI 94.5%: %{y}<extra></extra>"),
        ('hdi_5.5%', 'HDI 5.5%', {'color': 'magenta', 'dash': 'dash'},
         "Time: %{x}<br>HDI 5.5%: %{y}<extra></extra>"),
    ]

    times = hazard_data['time']
    figure = go.Figure()
    for column, label, line_style, hover in trace_specs:
        figure.add_trace(go.Scatter(
            x=times,
            y=hazard_data[column],
            mode='lines',
            name=label,
            line=line_style,
            hovertemplate=hover,
        ))

    figure.update_layout(
        title=title,
        xaxis_title='Time',
        yaxis_title='Hazard Rate',
        hovermode='x',
    )
    return figure

def plot_survivor_data(survivor_data, title):
    """Build a Plotly line chart of a survivor curve and its interval bounds.

    Args:
        survivor_data: DataFrame that must contain the columns ``time``,
            ``mean``, ``hdi_5.5%`` and ``hdi_94.5%``.
        title: Title shown above the figure.

    Returns:
        A ``go.Figure`` with three line traces, in this order: the mean
        ("Best Guess", solid black), then the 94.5% and 5.5% bounds
        (dashed magenta).

    Raises:
        ValueError: If ``survivor_data`` is None or a required column is missing.
    """
    if survivor_data is None:
        raise ValueError("Survivor data not loaded")

    needed = {'time', 'mean', 'hdi_5.5%', 'hdi_94.5%'}
    if needed - set(survivor_data.columns):
        raise ValueError("Required columns not found in survivor data")

    # (source column, legend label, line style, hover template); list order
    # is the order the traces are added to the figure.
    trace_specs = [
        ('mean', 'Best Guess', {'color': 'black'},
         "Time: %{x}<br>Best Guess: %{y}<extra></extra>"),
        ('hdi_94.5%', 'HDI 94.5%', {'color': 'magenta', 'dash': 'dash'},
         "Time: %{x}<br>HDI 94.5%: %{y}<extra></extra>"),
        ('hdi_5.5%', 'HDI 5.5%', {'color': 'magenta', 'dash': 'dash'},
         "Time: %{x}<br>HDI 5.5%: %{y}<extra></extra>"),
    ]

    times = survivor_data['time']
    figure = go.Figure()
    for column, label, line_style, hover in trace_specs:
        figure.add_trace(go.Scatter(
            x=times,
            y=survivor_data[column],
            mode='lines',
            name=label,
            line=line_style,
            hovertemplate=hover,
        ))

    figure.update_layout(
        title=title,
        xaxis_title='Time',
        yaxis_title='Survivor Rate',
        hovermode='x',
    )
    return figure

# Gradio event handlers: patient prediction, CSV table generation, plot saving and clearing

def generate_single_patient(treat, age, men, size, grade, nodes, prog, oest):
    """Predict and plot survival and cumulative-hazard curves for one patient.

    The eight covariates are assembled into a one-row DataFrame, scaled and
    one-hot encoded, aligned to the model's ``PREDS`` feature order, and
    passed to ``model.predict_surv_df``. ``model``, ``PREDS``,
    ``CONTINUOUS_VARIABLES``, ``CATEGORICAL_VARIABLES``, ``MinMaxScaler`` and
    ``np`` must already be defined elsewhere in the notebook.

    Args:
        treat, age, men, size, grade, nodes, prog, oest: Raw patient
            covariates from the "Single Patient Input" tab.

    Returns:
        ``(hazard_fig, survivor_fig, status_message)``.

    Side effects:
        Sets the module-level ``last_figure`` to the hazard figure (the one
        ``save_new_plot`` exports) and resets ``last_table`` to ``None``.
    """
    global last_figure, last_table

    df = pd.DataFrame({
        "treat": [treat],
        "age": [age],
        "men": [men],
        "size": [size],
        "grade": [grade],
        "nodes": [nodes],
        "prog": [prog],
        "oest": [oest],
    })

    # NOTE(review): a MinMaxScaler fitted on this single row sees min == max
    # for every column, so every continuous feature is mapped to 0.0 no
    # matter what was entered. The scaler fitted on the training data should
    # be reused here via ``.transform`` -- TODO once it is exposed.
    scaler = MinMaxScaler()
    df[CONTINUOUS_VARIABLES] = scaler.fit_transform(df[CONTINUOUS_VARIABLES])

    # No drop_first here: a single row has exactly one level per categorical
    # column, so drop_first=True removed every dummy column and the patient's
    # categories were silently zeroed. Keeping all dummies and then aligning
    # to PREDS discards whichever baseline columns the model does not use.
    # A dummy only lines up when its name (e.g. "grade_2") matches the
    # training encoding exactly -- TODO confirm for float-valued inputs such
    # as "treat", whose gr.Number has no precision set.
    df = pd.get_dummies(df, columns=CATEGORICAL_VARIABLES)

    # Align to the model's feature order; features absent from this row
    # (e.g. dummies for levels the patient does not have) become 0.0.
    X_input = df.reindex(columns=PREDS, fill_value=0.0).astype("float32").values

    # The returned frame's index is used as the time axis and its first
    # (only) column is this patient's survival curve S(t).
    surv_df = model.predict_surv_df(X_input)
    surv_curve = surv_df.iloc[:, 0]
    times = surv_df.index

    # Cumulative hazard H(t) = -log S(t); S is clipped away from 0 so the
    # log stays finite.
    cumhaz = -np.log(surv_curve.clip(lower=1e-6))

    # The interval columns are mock +/-20% bands, not real credible intervals.
    hazard_df = pd.DataFrame({
        "time": times,
        "mean": cumhaz,
        "hdi_5.5%": cumhaz * 0.8,
        "hdi_94.5%": cumhaz * 1.2,
    })
    survivor_df = pd.DataFrame({
        "time": times,
        "mean": surv_curve,
        "hdi_5.5%": surv_curve * 0.8,
        # A survival probability cannot exceed 1.
        "hdi_94.5%": (surv_curve * 1.2).clip(upper=1.0),
    })

    hazard_fig = plot_hazard_data(hazard_df, title="Hazard_0 Data")
    survivor_fig = plot_survivor_data(survivor_df, title="Survivor_0 Data")

    last_figure = hazard_fig
    last_table = None

    return hazard_fig, survivor_fig, "Generated survival and hazard curves for patient"


def generate_table_from_csv(csv_file):
    """Load an uploaded CSV and append placeholder prediction columns.

    Args:
        csv_file: Value of the ``gr.File`` component. This is either a path
            (``str`` / ``os.PathLike``) or a file wrapper exposing its path
            as ``.name``; which one depends on the Gradio version. It is
            ``None`` when nothing was uploaded.

    Returns:
        ``(table, status_message)``. On success, ``table`` is the CSV
        contents plus ``id``, ``hazard_rate`` and ``survivor_time`` columns.
        ``table`` is ``None`` when no file was given or the file could not
        be read or parsed.

    Side effects:
        On success sets ``last_table`` to the table and clears
        ``last_figure``.
    """
    import os

    global last_figure
    global last_table

    if csv_file is None:
        return None, "No CSV file provided."

    # Plain paths are read directly; wrapper objects carry the path in
    # ``.name``. (A pathlib.Path also has ``.name``, but that is only the
    # basename, so paths must be checked first.)
    if isinstance(csv_file, (str, os.PathLike)):
        source = csv_file
    else:
        source = getattr(csv_file, "name", csv_file)

    try:
        df = pd.read_csv(source)
    except (OSError, ValueError) as err:
        # ValueError covers pandas' ParserError/EmptyDataError and
        # UnicodeDecodeError; report it in the status box instead of
        # crashing the event handler.
        return None, f"Could not read CSV: {err}"

    # Placeholder predictions: random values, not model output.
    # NOTE(review): an existing "id" column in the upload is overwritten.
    n_rows = len(df)
    df["id"] = [str(uuid.uuid4()) for _ in range(n_rows)]
    df["hazard_rate"] = [round(random.uniform(0.0, 1.0), 3) for _ in range(n_rows)]
    df["survivor_time"] = [random.randint(1, 100) for _ in range(n_rows)]

    # A table replaces any previously shown figure.
    last_figure = None
    last_table = df

    return df, "Table generated from CSV"

def save_new_plot():
    """Export the most recently generated figure to ``Last_Generated_Plot.jpg``.

    Uses the figure's ``write_image`` (Plotly static export; presumably the
    kaleido package must be installed -- confirm in the environment).

    Returns:
        A status message for the UI. Failures are reported in the message,
        not raised.
    """
    figure = last_figure  # only read here, so no ``global`` is needed
    if figure is None:
        return "No plot to save."

    filename = "Last_Generated_Plot.jpg"
    try:
        figure.write_image(filename, format="jpg")
    except Exception as e:
        return f"Error saving plot: {e}"
    return f"Plot saved as {filename}"

# Clear functions: one for single patient and one for CSV
def clean_single():
    """Reset the single-patient tab.

    Forgets the stored figure and table, then returns ``None`` for each of
    the tab's three outputs (hazard plot, survivor plot, status box) so
    Gradio blanks them.
    """
    global last_figure, last_table
    last_figure = last_table = None
    return (None,) * 3

def clean_csv():
    """Reset the CSV tab.

    Forgets the stored figure and table, then returns ``None`` for the tab's
    two outputs (generated table, status box) so Gradio blanks them.
    """
    global last_figure, last_table
    last_figure, last_table = None, None
    cleared_outputs = (None, None)  # table, status box
    return cleared_outputs


def _save_last_table():
    """Write the most recent CSV-tab table to ``Last_Generated_Table.csv``.

    Returns:
        A status message for the CSV tab's "Save Status" textbox.
    """
    if last_table is None:
        return "No table to save."
    filename = "Last_Generated_Table.csv"
    try:
        last_table.to_csv(filename, index=False)
    except OSError as e:
        return f"Error saving table: {e}"
    return f"Table saved as {filename}"


# Build and launch the two-tab Gradio app. Components are laid out in the
# order they are created inside each context manager.
with gr.Blocks() as demo:
    gr.Markdown("# Hazard/Survivor Plot Selector")

    with gr.Tabs():
        # 1) Single Patient Input Tab (table removed)
        with gr.Tab("Single Patient Input"):
            gr.Markdown("### Enter Patient Information")
            with gr.Row():
                treat_input = gr.Number(label="Treat (0 or 1)")
                age_input = gr.Number(label="Age (integer)", precision=0)
                men_input = gr.Number(label="Menopausal Status (1 or 2)", precision=0)
                size_input = gr.Number(label="Tumour Size (mm)", precision=0)
                grade_input = gr.Number(label="Tumour Grade (1,2,3)", precision=0)
                nodes_input = gr.Number(label="Nodes (integer)", precision=0)
                prog_input = gr.Number(label="Progesterone (integer)", precision=0)
                oest_input = gr.Number(label="Oestrogen Status (integer)", precision=0)

            # Button row: Clear (left), Generate (middle), Save (right)
            with gr.Row():
                clear_single_btn = gr.Button("Clear")
                generate_single_btn = gr.Button("Generate")
                save_single_btn = gr.Button("Save")

            # Output areas (table removed)
            single_hazard_plot_area = gr.Plot(label="Hazard Plot")
            single_survivor_plot_area = gr.Plot(label="Survivor Plot")
            single_save_status = gr.Textbox(label="Save Status")

            # Link the functions for the single patient tab. The input order
            # must match generate_single_patient's parameter order.
            generate_single_btn.click(
                fn=generate_single_patient,
                inputs=[
                    treat_input, age_input, men_input, size_input,
                    grade_input, nodes_input, prog_input, oest_input
                ],
                outputs=[
                    single_hazard_plot_area, 
                    single_survivor_plot_area, 
                    single_save_status
                ]
            )

            clear_single_btn.click(
                fn=clean_single,
                outputs=[
                    single_hazard_plot_area,
                    single_survivor_plot_area,
                    single_save_status
                ]
            )

            # Saves last_figure, which generate_single_patient sets to the
            # hazard figure only.
            save_single_btn.click(
                fn=save_new_plot,
                outputs=single_save_status
            )

        # 2) Upload CSV Tab (graph outputs removed)
        with gr.Tab("Upload CSV"):
            gr.Markdown("### Upload a CSV File to Generate a Table")
            csv_file_input = gr.File(label="Upload CSV")

            # Button row: Clear (left), Generate (middle), Save (right)
            with gr.Row():
                clear_csv_btn = gr.Button("Clear")
                generate_csv_btn = gr.Button("Generate")
                save_csv_btn = gr.Button("Save")

            # Output areas (graphs removed). The returned table keeps the
            # CSV's own columns as well; these headers presumably only shape
            # the empty placeholder -- confirm against the Gradio version.
            csv_table_output = gr.DataFrame(
                label="Generated Table", 
                headers=["id", "hazard_rate", "survivor_time"]
            )
            csv_save_status = gr.Textbox(label="Save Status")

            # Link the functions for the CSV tab
            generate_csv_btn.click(
                fn=generate_table_from_csv,
                inputs=csv_file_input,
                outputs=[csv_table_output, csv_save_status]
            )

            clear_csv_btn.click(
                fn=clean_csv,
                outputs=[csv_table_output, csv_save_status]
            )

            # This tab has no plot, and generate_table_from_csv clears
            # last_figure, so wiring this button to save_new_plot could only
            # ever report "No plot to save." Save the generated table instead.
            save_csv_btn.click(
                fn=_save_last_table,
                outputs=csv_save_status
            )

demo.launch(share=False)

Running on local URL:  http://127.0.0.1:7891

To create a public link, set `share=True` in `launch()`.


