In [None]:
import os
import gradio as gr
import pandas as pd
import numpy as np
import joblib
import pickle
import matplotlib.pyplot as plt
from tensorflow.keras.models import load_model


# Artifacts are resolved relative to the current working directory; the
# dedicated "model_files" sub-directory is searched before the directory itself.
BASE_DIR = os.getcwd()
MODEL_DIR = os.path.join(BASE_DIR, "model_files")


def load_artifact(filename):
    """Locate a model artifact on disk and return its path.

    Despite the name, nothing is deserialized here: the caller opens or loads
    the returned path itself.

    Args:
        filename: Name (or relative path) of the artifact to find.

    Returns:
        ``os.path.join(MODEL_DIR, filename)`` if that exists, otherwise
        ``os.path.join(BASE_DIR, filename)`` if that exists.

    Raises:
        FileNotFoundError: If the file exists in neither directory.
    """
    for directory in (MODEL_DIR, BASE_DIR):
        candidate = os.path.join(directory, filename)
        if os.path.exists(candidate):
            return candidate
    raise FileNotFoundError(
        f"File '{filename}' not found in either '{MODEL_DIR}' or '{BASE_DIR}'"
    )

# Load the preprocessing pipeline, the Keras model and the expected feature
# list once, when this cell/module runs, so every prediction request reuses them.
# NOTE: pickle/joblib must only ever be pointed at these trusted local artifacts.
try:
    scaler = joblib.load(load_artifact("scaler.pkl"))
    # Use a context manager so the file handle is closed after unpickling.
    with open(load_artifact("pca.pkl"), "rb") as f:
        pca = pickle.load(f)
    model = load_model(load_artifact("tremor_model.h5"))
    with open(load_artifact("feature_columns.pkl"), "rb") as f:
        feature_columns = pickle.load(f)
except Exception as e:
    # Chain the original exception so its traceback is preserved for debugging.
    raise RuntimeError(f"Failed to load model artifacts. Error: {e}") from e

# Columns dropped from uploaded CSVs before the required features are selected.
non_numeric_cols = ['start_timestamp', 'end_timestamp']


# PREDICTION FUNCTION

def _resolve_upload_path(file_obj):
    """Return the filesystem path of an upload, or None when nothing was uploaded.

    Accepts a plain path (``str`` / ``os.PathLike``, as delivered by
    ``gr.File(type="filepath")``) as well as file-wrapper objects that expose
    the path through a ``.name`` attribute.
    """
    if file_obj is None:
        return None
    if isinstance(file_obj, (str, os.PathLike)):
        return os.fspath(file_obj)
    return getattr(file_obj, "name", None)


def _empty_figure():
    """Return a blank figure that is no longer tracked by pyplot.

    Closing it immediately stops repeated requests from accumulating open
    figures in pyplot's global registry; the Figure object itself can still
    be rendered by the caller.
    """
    fig = plt.figure()
    plt.close(fig)
    return fig


def _build_probability_plot(y_prob, y_pred):
    """Draw one bar per sample (red = possible tremor, green = no tremor)."""
    fig, ax = plt.subplots(figsize=(8, 4))
    samples = np.arange(1, len(y_prob) + 1)
    colors = ["red" if pred == 1 else "green" for pred in y_pred]
    ax.bar(samples, y_prob, color=colors)
    ax.set_ylim(0, 1)
    ax.set_xlabel("Sample")
    ax.set_ylabel("Probability")
    ax.set_title("Rest Tremor Probability per Sample")
    # Detach from pyplot so figures don't pile up across requests.
    plt.close(fig)
    return fig


def predict_tremor_from_csv(file_obj):
    """Run the tremor model on every row of an uploaded CSV.

    Args:
        file_obj: The uploaded file. This is a path string (``str`` or
            ``os.PathLike``), an object exposing the path as ``.name``, or
            ``None`` when nothing was uploaded.

    Returns:
        A ``(text, figure)`` tuple.
        - On success, the text holds one line per CSV row with the predicted
          status and confidence, and the figure shows the per-row probability.
        - On any validation or processing problem, the text explains the
          issue and the figure is blank.
        Exceptions are caught and reported in the text, never propagated.
    """
    try:
        csv_path = _resolve_upload_path(file_obj)
        if not csv_path:
            return "❗ No file uploaded. Please upload a CSV file.", _empty_figure()

        # ---- Check file type (case-insensitive so "DATA.CSV" is accepted) ----
        if not csv_path.lower().endswith('.csv'):
            return "❗ Unsupported file type. Please upload a CSV file.", _empty_figure()

        # ---- Load CSV and drop the timestamp columns if present ----
        df = pd.read_csv(csv_path)
        df = df.drop(columns=non_numeric_cols, errors='ignore')

        # ---- Check missing columns ----
        missing_cols = [col for col in feature_columns if col not in df.columns]
        if missing_cols:
            return (
                f"❗ Missing required columns: {', '.join(missing_cols)}.\n"
                "Please ensure your CSV matches the model's expected feature set.",
                _empty_figure(),
            )

        # ---- Select expected features, in feature_columns order ----
        X = df[feature_columns].copy()
        if X.shape[0] == 0:
            # A header-only CSV would otherwise fail inside the scaler with an obscure error.
            return "❗ The uploaded CSV contains no data rows.", _empty_figure()

        # ---- Handle missing values ----
        warning_msg = ""
        if X.isnull().values.any():
            warning_msg = "⚠️ Missing values were detected and filled with zeros.\n\n"
        X = X.fillna(0)

        # ---- Preprocessing: the loaded scaler, then the loaded PCA ----
        X_pca = pca.transform(scaler.transform(X))

        # ---- Prediction: a probability above 0.5 counts as possible tremor ----
        y_prob = model.predict(X_pca).flatten()
        y_pred = (y_prob > 0.5).astype(int)

        # ---- Build results ----
        results_text = []
        for i, (prob, pred) in enumerate(zip(y_prob, y_pred), start=1):
            if pred == 1:
                status_msg = "⚠️ Possible rest tremor detected"
                confidence = prob
            else:
                status_msg = "🟢 No rest tremor detected in this sample"
                confidence = 1 - prob
            results_text.append(
                f"Sample {i}: {status_msg} (Confidence: {confidence:.1%})"
            )

        full_text = warning_msg + "\n".join(results_text)
        return full_text, _build_probability_plot(y_prob, y_pred)

    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing the app.
        error_text = (
            "❗ An error occurred during processing.\n"
            "Please verify that your file:\n"
            "• Is a valid CSV file\n"
            "• Contains all required sensor features\n"
            "• Has no formatting issues\n\n"
            f"Technical details: {str(e)}"
        )
        return error_text, _empty_figure()



# EXAMPLES
# Example CSVs are resolved relative to the working directory. Any that are
# not present on disk are skipped rather than handed to Gradio.
example1_path = "example1.csv"
example2_path = "example2.csv"
example3_path = "example3.csv"
available_examples = [
    [path]
    for path in (example1_path, example2_path, example3_path)
    if os.path.exists(path)
]


# GRADIO INTERFACE

iface = gr.Interface(
    fn=predict_tremor_from_csv,
    inputs=gr.File(label="Upload Sensor Data (CSV)", type="filepath"),
    outputs=[
        gr.Textbox(label="Analysis Results", lines=8, max_lines=15),
        gr.Plot(label="📈 Rest Tremor Probability per Sample"),
    ],
    title="🩺 Parkinson’s Rest Tremor Screening Assistant",
    description="""
    This tool analyzes inertial sensor data to estimate the likelihood of **rest tremor**,
    an early indicator of Parkinson’s disease.

    🔹 Not a medical diagnosis — consult a neurologist for confirmation.
    🔹 Accuracy depends on data quality and correct CSV feature structure.
    """,
    # None (no examples panel) when none of the example files exist.
    examples=available_examples or None,
    cache_examples=False,
)


# LAUNCH APP
# Bind to the loopback interface only, with no public share link. Open a
# browser tab automatically, and pass show_error=True so Gradio displays
# errors in the UI.

if __name__ == "__main__":
    launch_kwargs = {
        "server_name": "127.0.0.1",
        "share": False,
        "inbrowser": True,
        "show_error": True,
    }
    iface.launch(**launch_kwargs)



Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
Note: opening Chrome Inspector may crash demo inside Colab notebooks.
* To create a public link, set `share=True` in `launch()`.


<IPython.core.display.Javascript object>