In [3]:
import gradio as gr
import plotly.graph_objects as go
import pandas as pd
import uuid
import random
import torch
import torchtuples as tt
import numpy as np
from pycox.models import CoxPH
from sklearn.preprocessing import MinMaxScaler

# Session state shared by the Gradio callbacks: the most recently generated
# plot and table. Both stay None until a callback stores something.
last_figure = None
last_table = None

# DeepSurv feature configuration.
# Raw inputs that are min-max scaled before prediction.
CONTINUOUS_VARIABLES = [
    "age",
    "size",
    "nodes",
    "prog",
    "oest",
]
# Raw inputs that are turned into indicator (dummy) columns.
CATEGORICAL_VARIABLES = [
    "treat",
    "men",
    "grade",
]
# Column order of the feature matrix handed to the network.
PREDS = [
    "age",
    "size",
    "nodes",
    "prog",
    "oest",
    "treat_1",
    "men_2",
    "grade_2",
    "grade_3",
]
# Time at which the baseline hazard is evaluated for the CSV table
# (presumably in the same unit as the training durations - TODO confirm).
t_point = 365

# DeepSurv risk network: one hidden layer mapping the len(PREDS) input
# features to a single output per patient.
net = torch.nn.Sequential(
    torch.nn.Linear(in_features=len(PREDS), out_features=32),
    torch.nn.ReLU(),
    torch.nn.BatchNorm1d(num_features=32),
    torch.nn.Dropout(p=0.1),
    torch.nn.Linear(in_features=32, out_features=1),
)

# Wrap the network in the pycox CoxPH model, optimised with Adam at lr=0.01.
model = CoxPH(net, tt.optim.Adam)
model.optimizer.set_lr(0.01)

# Placeholder fit on random data so the model can be used for prediction
# (replace this with real training if needed).
X_dummy = np.random.rand(200, len(PREDS)).astype(np.float32)
y_dummy = (
    np.random.exponential(scale=1000, size=200),
    np.random.randint(low=0, high=2, size=200).astype(bool),
)
model.fit(X_dummy, y_dummy, batch_size=64, epochs=5, verbose=False)

# Compute the baseline hazards once, up front; the callbacks below read
# model.baseline_hazards_ and call model.predict_surv_df.
model.compute_baseline_hazards()


# Plotting Functions
def plot_hazard_data(hazard_data, title):
    """Draw the hazard curve together with its two interval bounds.

    Args:
        hazard_data: Table providing the columns ``time``, ``mean``,
            ``hdi_5.5%`` and ``hdi_94.5%``.
        title: Text used as the figure title.

    Returns:
        A Plotly figure holding three line traces: the mean in black and the
        two bounds as dashed magenta lines.

    Raises:
        ValueError: If ``hazard_data`` is None or a required column is absent.
    """
    if hazard_data is None:
        raise ValueError("Hazard data not loaded")
    if not {'time', 'mean', 'hdi_5.5%', 'hdi_94.5%'} <= set(hazard_data.columns):
        raise ValueError("Required columns not found in hazard data")

    # (source column, legend label, line style), in drawing order.
    trace_specs = (
        ('mean', 'Best Guess', dict(color='black')),
        ('hdi_94.5%', 'HDI 94.5%', dict(color='magenta', dash='dash')),
        ('hdi_5.5%', 'HDI 5.5%', dict(color='magenta', dash='dash')),
    )

    fig = go.Figure()
    for column, label, line_style in trace_specs:
        curve = go.Scatter(
            x=hazard_data['time'],
            y=hazard_data[column],
            mode='lines',
            name=label,
            line=line_style,
        )
        fig.add_trace(curve)

    fig.update_layout(
        title=title,
        xaxis_title='Time',
        yaxis_title='Hazard Rate',
        hovermode='x',
    )
    return fig


def plot_survivor_data(survivor_data, title):
    """Draw the survival curve together with its two interval bounds.

    Args:
        survivor_data: Table providing the columns ``time``, ``mean``,
            ``hdi_5.5%`` and ``hdi_94.5%``.
        title: Text used as the figure title.

    Returns:
        A Plotly figure holding three line traces: the mean in black and the
        two bounds as dashed magenta lines.

    Raises:
        ValueError: If ``survivor_data`` is None or a required column is
            absent.
    """
    if survivor_data is None:
        raise ValueError("Survivor data not loaded")
    if not {'time', 'mean', 'hdi_5.5%', 'hdi_94.5%'} <= set(survivor_data.columns):
        raise ValueError("Required columns not found in survivor data")

    # (source column, legend label, line style), in drawing order.
    trace_specs = (
        ('mean', 'Best Guess', dict(color='black')),
        ('hdi_94.5%', 'HDI 94.5%', dict(color='magenta', dash='dash')),
        ('hdi_5.5%', 'HDI 5.5%', dict(color='magenta', dash='dash')),
    )

    fig = go.Figure()
    for column, label, line_style in trace_specs:
        curve = go.Scatter(
            x=survivor_data['time'],
            y=survivor_data[column],
            mode='lines',
            name=label,
            line=line_style,
        )
        fig.add_trace(curve)

    fig.update_layout(
        title=title,
        xaxis_title='Time',
        yaxis_title='Survivor Rate',
        hovermode='x',
    )
    return fig


def generate_single_patient(treat, age, men, size, grade, nodes, prog, oest):
    """Predict and plot the hazard and survival curves for one patient.

    Args:
        treat, men, grade: Categorical inputs; each is turned into the
            indicator columns listed in ``PREDS`` (e.g. ``grade_3``).
        age, size, nodes, prog, oest: Continuous inputs.
        Any argument may be None when its form field was left empty.

    Returns:
        ``(hazard_plot, survivor_plot, status)``. When a value is missing the
        two plots are None and ``status`` names the missing fields.

    Side effects:
        Stores the hazard plot in ``last_figure`` and resets ``last_table``.
    """
    global last_figure, last_table

    raw_inputs = {
        "treat": treat,
        "age": age,
        "men": men,
        "size": size,
        "grade": grade,
        "nodes": nodes,
        "prog": prog,
        "oest": oest,
    }

    # An empty number field arrives as None; report it instead of letting the
    # scaler fail with an opaque error. The plots are cleared, so the saved
    # state is cleared with them.
    missing = [name for name, value in raw_inputs.items() if value is None]
    if missing:
        last_figure = None
        last_table = None
        return None, None, f"Missing patient values: {', '.join(missing)}"

    df = pd.DataFrame({name: [value] for name, value in raw_inputs.items()})

    # NOTE(review): a MinMaxScaler fitted on this single row has min == max
    # for every column, so all continuous features become 0. A scaler fitted
    # on the training data is needed here; none exists in this file.
    scaler = MinMaxScaler()
    df[CONTINUOUS_VARIABLES] = scaler.fit_transform(df[CONTINUOUS_VARIABLES])

    # Build the indicator columns explicitly from PREDS ("<variable>_<level>").
    # pd.get_dummies(..., drop_first=True) cannot be used on a one-row frame:
    # the only level present is the "first" one and gets dropped, so every
    # indicator would silently end up 0.
    for col in PREDS:
        variable, sep, level = col.rpartition("_")
        if sep and variable in CATEGORICAL_VARIABLES:
            df[col] = float(raw_inputs[variable] == float(level))

    # Feature matrix in the column order the network expects.
    X_input = df[PREDS].astype("float32").values

    # Predicted survival curve: rows are time points, one column per patient.
    surv_df = model.predict_surv_df(X_input)
    surv = surv_df.iloc[:, 0]

    # Cumulative hazard as -log(survival); the clip avoids log(0).
    hazard = -np.log(surv.clip(lower=1e-6))

    hazard_df = pd.DataFrame({
        "time": surv_df.index,
        "mean": hazard,
        "hdi_5.5%": hazard * 0.8,      # mock CI
        "hdi_94.5%": hazard * 1.2
    })

    survivor_df = pd.DataFrame({
        "time": surv_df.index,
        "mean": surv,
        "hdi_5.5%": surv * 0.8,        # mock CI
        # A survival probability cannot exceed 1, so cap the mock upper band.
        "hdi_94.5%": (surv * 1.2).clip(upper=1.0)
    })

    hazard_plot = plot_hazard_data(hazard_df, title="Hazard Plot")
    survivor_plot = plot_survivor_data(survivor_df, title="Survivor Plot")

    last_figure = hazard_plot
    last_table = None

    return hazard_plot, survivor_plot, "Generated DeepSurv predictions for patient"



def generate_table_from_csv(csv_file):
    """Score every patient in an uploaded CSV with the DeepSurv model.

    Args:
        csv_file: The upload from the file component: either an object that
            exposes its path as ``.name`` or something ``pd.read_csv`` accepts
            directly. None when nothing was uploaded.

    Returns:
        ``(table, status)``. ``table`` has one row per patient with the
        columns ``id`` (a fresh UUID), ``hazard_rate`` (baseline hazard at
        ``t_point`` times the patient's exponentiated network output) and
        ``survivor_time`` (first time the predicted survival is below 0.5,
        NaN if it never is). ``table`` is None when the input cannot be
        processed, and ``status`` then says why.

    Side effects:
        On success stores the table in ``last_table`` and resets
        ``last_figure``.
    """
    global last_figure, last_table
    if csv_file is None:
        return None, "No CSV file provided."

    # Use the path from ``.name`` when there is one, otherwise pass the value
    # through. This replaces a bare ``except:`` that also hid real read errors.
    df = pd.read_csv(getattr(csv_file, "name", csv_file))

    required = CONTINUOUS_VARIABLES + CATEGORICAL_VARIABLES
    missing = [col for col in required if col not in df.columns]
    if missing:
        return None, f"CSV is missing required columns: {', '.join(missing)}"
    if df.empty:
        return None, "CSV file contains no rows."

    df["id"] = [str(uuid.uuid4()) for _ in range(len(df))]

    # NOTE(review): the scaler is fitted on the uploaded batch, so the scaled
    # values depend on which other patients are in the file. A scaler fitted
    # on the training data is needed here; none exists in this file.
    scaler = MinMaxScaler()
    df[CONTINUOUS_VARIABLES] = scaler.fit_transform(df[CONTINUOUS_VARIABLES])

    # Build the indicator columns explicitly from PREDS ("<variable>_<level>").
    # pd.get_dummies(..., drop_first=True) names and drops columns according
    # to the levels and dtypes present in the upload, so e.g. a file without
    # grade 1 would lose ``grade_2`` and a float column would yield
    # ``treat_1.0`` instead of ``treat_1``.
    for col in PREDS:
        variable, sep, level = col.rpartition("_")
        if sep and variable in CATEGORICAL_VARIABLES:
            df[col] = (df[variable] == float(level)).astype("float32")

    X_input = df[PREDS].astype("float32").values
    X_tensor = torch.tensor(X_input, dtype=torch.float32)

    baseline_haz = model.baseline_hazards_
    baseline_times = baseline_haz.index.values.astype(float)
    baseline_values = baseline_haz.values.astype(float)
    # np.interp already returns the last value for t_point beyond the final
    # baseline time, so no separate branch is needed for that case.
    h0_t = np.interp(t_point, baseline_times, baseline_values)

    model.net.eval()
    with torch.no_grad():
        # reshape(-1) keeps one entry per patient even for a one-row CSV.
        log_risk = model.net(X_tensor).reshape(-1).numpy()
    hazard_rates = h0_t * np.exp(log_risk)

    # Predicted survival curves: rows are time points, one column per patient.
    surv_df = model.predict_surv_df(X_input)

    def first_time_below_half(curve):
        """Return the first index where the curve is below 0.5, else NaN."""
        below_half = curve[curve < 0.5]
        return np.nan if below_half.empty else below_half.index[0]

    predicted_medians = np.array(
        [first_time_below_half(curve) for _, curve in surv_df.items()]
    )

    df_out = pd.DataFrame({
        "id": df["id"],
        "hazard_rate": hazard_rates,
        "survivor_time": predicted_medians
    })

    last_figure = None
    last_table = df_out
    return df_out, "DeepSurv predictions applied"


def save_new_plot():
    """Write the most recently generated figure to a JPEG file.

    Returns:
        A status message: the file name on success, the error text when the
        export fails, or a notice when no figure has been generated yet.
    """
    global last_figure
    if last_figure is None:
        return "No plot to save."

    target = "Last_Generated_Plot.jpg"
    try:
        last_figure.write_image(target, format="jpg")
    except Exception as err:
        # Report the failure in the status box instead of raising.
        return f"Error saving plot: {err}"
    return f"Plot saved as {target}"


def clean_single():
    """Forget the stored plot and table and blank the single-patient outputs.

    Returns:
        Three None values, one for each output wired to the Clear button
        (hazard plot, survivor plot, status box).
    """
    global last_figure, last_table
    last_figure = last_table = None
    return (None,) * 3


def clean_csv():
    """Forget the stored plot and table and blank the CSV tab outputs.

    Returns:
        Two None values, one for each output wired to the Clear button
        (table, status box).
    """
    global last_figure, last_table
    last_figure = last_table = None
    return (None,) * 2


def save_last_table():
    """Write the most recently generated prediction table to a CSV file.

    Returns:
        A status message: the file name on success, the error text when the
        file cannot be written, or a notice when no table has been generated.
    """
    if last_table is None:
        return "No table to save."

    filename = "Last_Generated_Table.csv"
    try:
        last_table.to_csv(filename, index=False)
    except OSError as e:
        return f"Error saving table: {e}"
    return f"Table saved as {filename}"


# Two-tab UI: one form for a single patient, one upload for a CSV batch.
with gr.Blocks() as demo:
    gr.Markdown("# Hazard/Survivor Plot Selector")

    with gr.Tabs():
        with gr.Tab("Single Patient Input"):
            gr.Markdown("### Enter Patient Information")
            with gr.Row():
                # precision=0 on every field so all inputs arrive as integers.
                treat_input = gr.Number(label="Treat (0 or 1)", precision=0)
                age_input = gr.Number(label="Age (integer)", precision=0)
                men_input = gr.Number(label="Menopausal Status (1 or 2)", precision=0)
                size_input = gr.Number(label="Tumour Size (mm)", precision=0)
                grade_input = gr.Number(label="Tumour Grade (1,2,3)", precision=0)
                nodes_input = gr.Number(label="Nodes (integer)", precision=0)
                prog_input = gr.Number(label="Progesterone (integer)", precision=0)
                oest_input = gr.Number(label="Oestrogen Status (integer)", precision=0)

            with gr.Row():
                clear_single_btn = gr.Button("Clear")
                generate_single_btn = gr.Button("Generate")
                save_single_btn = gr.Button("Save")

            single_hazard_plot_area = gr.Plot(label="Hazard Plot")
            single_survivor_plot_area = gr.Plot(label="Survivor Plot")
            single_save_status = gr.Textbox(label="Save Status")

            generate_single_btn.click(
                fn=generate_single_patient,
                inputs=[treat_input, age_input, men_input, size_input, grade_input, nodes_input, prog_input, oest_input],
                outputs=[single_hazard_plot_area, single_survivor_plot_area, single_save_status]
            )

            clear_single_btn.click(fn=clean_single, outputs=[single_hazard_plot_area, single_survivor_plot_area, single_save_status])
            save_single_btn.click(fn=save_new_plot, outputs=single_save_status)

        with gr.Tab("Upload CSV"):
            gr.Markdown("### Upload a CSV File to Generate a Table")
            csv_file_input = gr.File(label="Upload CSV")

            with gr.Row():
                clear_csv_btn = gr.Button("Clear")
                generate_csv_btn = gr.Button("Generate")
                save_csv_btn = gr.Button("Save")

            csv_table_output = gr.DataFrame(label="Generated Table", headers=["id", "hazard_rate", "survivor_time"])
            csv_save_status = gr.Textbox(label="Save Status")

            generate_csv_btn.click(fn=generate_table_from_csv, inputs=csv_file_input, outputs=[csv_table_output, csv_save_status])
            clear_csv_btn.click(fn=clean_csv, outputs=[csv_table_output, csv_save_status])
            # Save the table here, not the plot: generating a table resets
            # last_figure, so save_new_plot could only answer "No plot to save."
            save_csv_btn.click(fn=save_last_table, outputs=csv_save_status)

demo.launch(share=False)


Running on local URL:  http://127.0.0.1:7901

To create a public link, set `share=True` in `launch()`.


