# Welcome to the MALDI-UI session

In this notebook, we will work with **MALDI-TOF mass spectrometry data**, specifically the **DRIAMS B database**, which contains routine MALDI-TOF MS data from the Canton Hospital Basel-Land.

Our objective for this practical session is to create a simple UI, where the user can:
- Load in dataset of spectra
- Perform basic preprocessing of the MALDI-TOF MS data
- Show a selected spectra
- Check for quality metrics
- See model prediction
- Show important features for the prediction

Let's get started!


# Set-up

In [1]:
!pip install gradio plotly maldi-nn shap



In [2]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [11]:
from tensorflow.keras.models import load_model
import numpy as np
import joblib
import json
import shap

mlp_model = load_model("'/content/drive/MyDrive/Models/mlp_keras_model.h5")
xgb_model = joblib.load("/content/drive/MyDrive/Models/xgboost_top5_species_model.pkl")

# Do we implement shapley values as well?

# SHAPLEY values
# Load saved background
background = np.load("/content/drive/MyDrive/Models/mlp_shap_background.npy")
feature_names = [f"{2000 + i*3}-{2000 + (i+1)*3} Da" for i in range(6000)]

# For XGBoost (TreeExplainer is fast and efficient)
xgb_explainer = shap.Explainer(xgb_model)

# For MLP
mlp_explainer = shap.DeepExplainer(mlp_model, background)


Your TensorFlow version is newer than 2.4.0 and so graph support has been removed in eager mode and some static graphs may not be supported. See PR #1483 for discussion.


The structure of `inputs` doesn't match the expected structure.
Expected: input_layer
Received: inputs=['Tensor(shape=(100, 6000))']



# Gradio app

In [14]:
# --- Set-up
import gradio as gr
import pandas as pd
import plotly.graph_objs as go
import maldi_nn.spectrum as maldi_spectrum
from maldi_nn.spectrum import SpectrumObject
import os
import numpy as np

# Path to spectra folder
DATA_DIR = '/content/drive/MyDrive/DRIAMS/DRIAMS-B/DRIAMS-B/raw/2018'

In [15]:
# --- Read Spectrum ---
def read_spectrum_file(filepath):
    spec_obj = SpectrumObject.from_tsv(filepath)
    spec_obj.intensity = np.array(spec_obj.intensity[1:], dtype=float)
    spec_obj.mz = np.array(spec_obj.mz[1:], dtype=float)
    return spec_obj


In [16]:
# --- Preprocessing Function ---
def preprocess_spectra(filenames, var_stabilizer, smoother, baseline, normalizer, binner):
    if not filenames:
        return "No spectra selected", [], []

    file_paths = [os.path.join(DATA_DIR, f) for f in filenames]
    spectra = [read_spectrum_file(path) for path in file_paths]
    steps = []

    if var_stabilizer:
        steps.append(maldi_spectrum.VarStabilizer(method="sqrt"))
    if smoother:
        steps.append(maldi_spectrum.Smoother(halfwindow=10))
    if baseline:
        steps.append(maldi_spectrum.BaselineCorrecter(method="SNIP", snip_n_iter=20))
    if normalizer:
        steps.append(maldi_spectrum.Normalizer(sum=1))
    if binner:
        steps.append(maldi_spectrum.Binner(step=3))

    preprocessor = maldi_spectrum.SequentialPreprocessor(*steps)
    spectra_preprocessed = [preprocessor(spectrum) for spectrum in spectra]

    return "Preprocessing complete", spectra_preprocessed, file_paths


In [17]:
# --- Plotting ---
def plot_preprocessed_spectra(preprocessed_spectra, file_paths, selected_filenames):
    if not preprocessed_spectra or not selected_filenames:
        return go.Figure()

    file_name_to_index = {os.path.basename(p): idx for idx, p in enumerate(file_paths)}
    fig = go.Figure()

    for fname in selected_filenames:
        i = file_name_to_index.get(fname)
        if i is None:
            continue
        spectrum = preprocessed_spectra[i]
        fig.add_trace(go.Scatter(x=spectrum.mz, y=spectrum.intensity, mode="lines", name=fname))

    fig.update_layout(title="Preprocessed Spectra", xaxis_title="m/z", yaxis_title="Intensity")
    return fig

In [18]:
# --- Fancy Prediction Output ---
def predict_species_from_spectra(preprocessed_spectra, selected_filenames, model_choice, file_paths):
    if not preprocessed_spectra or not selected_filenames:
        return pd.DataFrame(columns=["Filename", "Predicted Class", "Probabilities"])

    predictions = []
    file_name_to_index = {os.path.basename(p): idx for idx, p in enumerate(file_paths)}

    for fname in selected_filenames:
        i = file_name_to_index.get(fname)
        if i is None:
            continue
        spectrum = preprocessed_spectra[i]
        features = spectrum.intensity.reshape(1, -1)

        # Dummy logic – replace with real model predictions
        if model_choice == "MLP":
            pred_probs = mlp_model.predict(features)[0]
        else:
            pred_probs = xgb_model.predict_proba(features)[0]

        pred_class = int(np.argmax(pred_probs))
        prob_str = ", ".join(f"{p:.4f}" for p in pred_probs)

        predictions.append({
            "Filename": fname,
            "Predicted Class": pred_class,
            "Probabilities": prob_str
        })

    return pd.DataFrame(predictions)

In [None]:
# --- Gradio UI ---
with gr.Blocks(css="#scrollable-checkbox {max-height: 300px; overflow-y: auto; border: 1px solid #ccc; padding: 10px;}") as demo:
    gr.Markdown("## MALDI-TOF MS Tool: Preprocessing, Visualization, and Species Prediction")

    def list_txt_files():
        try:
            return [f for f in os.listdir(DATA_DIR) if f.endswith(".txt")]
        except FileNotFoundError:
            return []

    file_selector = gr.CheckboxGroup(
        choices=list_txt_files(),
        label="Select spectra files from Google Drive folder",
        elem_id="scrollable-checkbox"
    )

    with gr.Row():
        var_stabilizer = gr.Checkbox(label="Variance Stabilizer (sqrt)", value=True)
        smoother = gr.Checkbox(label="Smoother (halfwindow=10)", value=True)
        baseline = gr.Checkbox(label="Baseline Correction (SNIP)", value=True)
        normalizer = gr.Checkbox(label="Normalization (sum=1)", value=True)
        binner = gr.Checkbox(label="Binner (step=3)", value=True)

    selected_files = gr.CheckboxGroup(choices=[], label="Select spectra to plot/predict (by filename)")

    def update_file_choices(filenames):
        return gr.update(choices=filenames, value=[])

    file_selector.change(update_file_choices, inputs=file_selector, outputs=selected_files)

    preprocessed_spectra_state = gr.State()
    file_paths_state = gr.State()

    preprocess_button = gr.Button("Run Preprocessing")
    preprocessing_output = gr.Textbox(label="Preprocessing Status")

    preprocess_button.click(
        fn=preprocess_spectra,
        inputs=[file_selector, var_stabilizer, smoother, baseline, normalizer, binner],
        outputs=[preprocessing_output, preprocessed_spectra_state, file_paths_state]
    )

    plot_button = gr.Button("Plot Selected Spectra")
    plot_output = gr.Plot(label="Processed Spectra")

    plot_button.click(
        fn=plot_preprocessed_spectra,
        inputs=[preprocessed_spectra_state, file_paths_state, selected_files],
        outputs=plot_output
    )

    model_selector = gr.Dropdown(choices=["MLP", "XGBoost"], label="Select ML Model", value="MLP")
    predict_button = gr.Button("Run Prediction")

    prediction_output = gr.Dataframe(
        label="Predicted Species",
        headers=["Filename", "Predicted Class", "Probabilities"],
        wrap=True
    )

    predict_button.click(
        fn=predict_species_from_spectra,
        inputs=[preprocessed_spectra_state, selected_files, model_selector, file_paths_state],
        outputs=prediction_output
    )

demo.launch(share=True, debug=True)

Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://8ddf5feb4e9c8ce218.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 39ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 38ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 37ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 39ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 37ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 38ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 38ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 37ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 35ms/step
