# Data analysis of mass spectrometry (MS) data

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import math
import pickle
from gzip import GzipFile
from pathlib import Path

import numpy as np
import seaborn as sns
from ipywidgets import Button, Output, interact
from jax import random
from matplotlib import pyplot as plt
from numpyro.infer import SVI, Predictive, Trace_ELBO, autoguide
from numpyro.optim import Adam
from scipy.signal import find_peaks
from tqdm.auto import tqdm

# Plot styling: seaborn "talk" context, "ticks" style, Arial font.
# 'svg.fonttype': 'none' keeps text as text (not paths) in exported SVGs.
sns.set_theme(
    context='talk',
    style='ticks',
    font='Arial',
    font_scale=1.0,
    rc={'svg.fonttype': 'none'},
)

# Directory holding the model*.py modules, and the (absolute) dataset root.
MODEL_HOME = Path('./src/models')
DATA_HOME = Path('./Paper dataset').resolve()

In [None]:
# Index every *.pkgz file at least one directory below DATA_HOME by its file
# stem; insertion order follows the stems' sorted order.
all_experiments = {
    path.stem: path
    for path in sorted(DATA_HOME.glob('**/*/*.pkgz'), key=lambda path: path.stem)
}
print(list(all_experiments))

In [None]:
import importlib

# Import every MODEL_HOME/model*.py as src.models.<stem>, keyed and ordered by stem.
_model_paths = sorted(MODEL_HOME.resolve().glob('model*.py'), key=lambda path: path.stem)
models = {
    path.stem: importlib.import_module(f'src.models.{path.stem}')
    for path in tqdm(_model_paths)
}
print(list(models))

## Interactive dashboard

In [None]:
plt.ioff()
@interact(sample_id=all_experiments.keys(), yticks=False, start_pct=(0, 100.0), end_pct=(0, 100.0))
def plot_sample(sample_id, yticks, start_pct=0.0, end_pct=100.0):
    """Plot the total ion current (TIC) of one sample over a slice of its scans.

    Args:
        sample_id: key into ``all_experiments``.
        yticks: keep the y-axis ticks; when False they are removed.
        start_pct, end_pct: slice of the scan *indices*, in percent of the
            number of scans. ``end_pct`` is raised to ``start_pct`` if smaller.

    Displays the figure and a "Save" button that writes ``out/<sample_id>/TIC.svg``.
    """
    # Close the gzip handle once loaded; the previous one-liner leaked it on
    # every widget interaction. NOTE: pickle runs arbitrary code on load, so
    # only open trusted files.
    with GzipFile(all_experiments[sample_id], 'rb') as fh:
        data = pickle.load(fh)
    N = data['times'].shape[0]
    end_pct = max(start_pct, end_pct)
    start_idx = int(start_pct / 100.0 * N)
    end_idx = int(end_pct / 100.0 * N)
    # TIC: summed intensity per scan.
    tic = data['intensities'].sum(axis=1)
    fig, ax = plt.subplots(figsize=(12.3, 3))
    ax.plot(data['times'][start_idx:end_idx], tic[start_idx:end_idx])
    ax.set(ylabel='Ion current', xlabel='Time (s)')
    if not yticks:
        ax.set_yticks([])
    # save button
    fig.tight_layout()
    save_button = Button(description='Save')
    out_dir = Path(f'out/{sample_id}')
    save_button.on_click(lambda _: (
        out_dir.mkdir(parents=True, exist_ok=True),
        fig.savefig(f'{out_dir}/TIC.svg', bbox_inches='tight', transparent=True)
    ))
    display(fig)
    display(save_button)

In [None]:
plt.ioff()
@interact(
    sample_id=all_experiments.keys(),
    start_pct=(0, 100.0),
    end_pct=(0, 100.0),
    scan_pct=(0, 100, 0.1),
    n_scans=(1, 100, 1),
    peak_threshold_pct=(0, 100),
    distance=(1, 100),
    prominance=(0, 100),
    xmin_pct=(0, 100),
    xmax_pct=(0, 100),
    rotate_labels=False,
    show_tic=True,
    series=False,
    truncate_starttime=False,
)
def plot_sample(
    sample_id,
    scan_pct,
    peak_threshold_pct,
    prominance,
    rotate_labels,
    n_scans=1,
    start_pct=0,
    end_pct=100,
    distance=1,
    xmin_pct=0,
    xmax_pct=100,
    show_tic=True,
    series=False,
    truncate_starttime=False,
):
    """Plot an (optionally averaged) mass spectrum at a chosen scan and mark its peaks.

    The sample is cropped to ``[start_pct, end_pct]`` percent of its last time
    point and to ``[xmin_pct, xmax_pct]`` percent of its last m/z value. Then
    ``n_scans`` consecutive scans starting at ``scan_pct`` percent of the cropped
    run are averaged, and peaks are located with ``scipy.signal.find_peaks``.

    Args (all driven by the ``interact`` widgets):
        sample_id: key into ``all_experiments``.
        scan_pct: position of the first averaged scan, in percent of the cropped run.
        peak_threshold_pct: minimum peak height, in percent of the averaged scan's max.
        prominance: passed to ``find_peaks`` as ``prominence`` (spelling kept for the widget).
        rotate_labels: draw peak labels vertically and leave 30% headroom for them.
        n_scans: number of consecutive scans to average.
        start_pct, end_pct: time window, in percent of the last time point.
        distance: minimum peak separation in m/z units (converted to samples below).
        xmin_pct, xmax_pct: m/z window, in percent of the last m/z value.
        show_tic: also plot the TIC with the scan window highlighted.
        series: plot each averaged scan separately instead of their mean.
        truncate_starttime: shift the TIC time axis so it starts at zero.

    Displays the figure and a "Save" button writing to ``out/<sample_id>/``.
    """
    # Close the gzip handle once loaded (it used to leak on every interaction).
    # NOTE: pickle runs arbitrary code on load; only open trusted files.
    with GzipFile(all_experiments[sample_id], 'rb') as fh:
        data = pickle.load(fh)

    # Crop in time.
    times = data['times']
    tmin, tmax = (
        times[-1] * start_pct / 100,
        times[-1] * end_pct / 100,
    )
    tindices = (times >= tmin) & (times <= tmax)
    times = times[tindices]
    intensities = data['intensities'][tindices]

    # Crop in m/z; ``distance`` is converted from m/z units to sample counts
    # using the (assumed uniform) spacing of the first two m/z values.
    masses = data["masses"]
    xmin, xmax = (
        masses[-1] * xmin_pct / 100,
        masses[-1] * xmax_pct / 100,
    )
    distance = math.ceil(distance / (masses[1] - masses[0]))
    indices = (masses >= xmin) & (masses <= xmax)
    masses = masses[indices]

    # Average the selected scans and detect peaks on the mean spectrum.
    scan_index = math.floor(scan_pct / 100 * (len(intensities) - 1))
    scans = intensities[scan_index:(scan_index + n_scans)][:, indices]
    scan = scans.mean(axis=0)
    peak_threshold = peak_threshold_pct / 100 * scan.max()
    peaks = find_peaks(
        scan, height=peak_threshold, distance=distance, prominence=prominance
    )[0]

    if show_tic:
        fig, (tic_ax, scan_ax) = plt.subplots(figsize=(8, 5), nrows=2)
        if truncate_starttime:
            times = times - times[0]
        tic_ax.plot(times, intensities.sum(axis=1))
        tic_ax.set(ylabel="Ion current", xlabel="Time (s)")
        lims = tic_ax.get_ylim()
        # The closing marker sits one scan past the last averaged scan. Clamp
        # it to the last scan so windows reaching the end of the run (e.g.
        # scan_pct=100) no longer raise IndexError.
        window_end = min(scan_index + n_scans, len(times) - 1)
        tic_ax.vlines(times[scan_index], *lims, color="r", lw=0.5)
        tic_ax.vlines(times[window_end], *lims, color="r", lw=0.5)
        tic_ax.fill_between([times[scan_index], times[window_end]], 0, lims[1], color="r", alpha=0.2)
    else:
        fig, scan_ax = plt.subplots(figsize=(8, 3))

    if series and n_scans > 1:
        # One colour per requested scan, sampled from the 'crest' colormap.
        cmap = sns.color_palette('crest', as_cmap=True).resampled(n_scans).colors
        for i, s in enumerate(scans):
            scan_ax.plot(masses, s, c=cmap[i])
    else:
        scan_ax.plot(masses, scan)
    scan_ax.scatter(masses[peaks], scan[peaks], color="r", s=30, zorder=10, marker="x", lw=1)
    # Labels float 10% of the spectrum's range above each peak.
    hairline_height = (scan.max() - scan.min()) * 0.1
    for peak in peaks:
        scan_ax.text(
            masses[peak],
            scan[peak] + hairline_height,
            f"{masses[peak]:.2f}",
            color="r",
            fontsize=8,
            ha="center",
            rotation=90 if rotate_labels else 0,
        )

    scan_ax.set(
        ylabel="Intensity",
        xlabel="m/z",
        ylim=(0, (1.0 if not rotate_labels else 1.3) * scan.max()),
    )
    # save button
    save_button = Button(description="Save")
    fig.tight_layout()
    out_dir = Path(f"out/{sample_id}")
    filename = f"{out_dir}/scan{scan_index}.svg" if n_scans == 1 else f"{out_dir}/scan{scan_index}-{scan_index + n_scans}.svg"
    save_button.on_click(
        lambda _: (
            out_dir.mkdir(parents=True, exist_ok=True),
            fig.savefig(filename, bbox_inches="tight", transparent=True)
        )
    )
    display(fig)
    display(save_button)

In [None]:
def bin_masses(data, bin_width=1.0):
    """Re-bin spectra onto a regular m/z grid of width ``bin_width``.

    Bins are labelled by their *upper* edge. Output column ``i`` holds the summed
    intensity of all raw m/z values in ``[edges[i-1], edges[i])``, and column 0
    collects ``[0, edges[0])``. Edges start at ``masses.min()`` and step by
    ``bin_width``. The grid runs at least one step past ``masses.max()``, so
    every non-negative raw m/z lands in a bin. (``np.arange(min, max, 1.0)``
    used to stop below the max and silently drop the top of the spectrum.)

    Args:
        data: dict with 'masses' (1-D, length M) and 'intensities' (T x M).
            All other keys are passed through unchanged.
        bin_width: bin width in m/z units. The default 1.0 matches the
            original behaviour.

    Returns:
        A new dict: ``data`` with 'masses' replaced by the bin upper edges and
        'intensities' by a float array of shape (T, n_bins). ``data`` itself
        is not modified.

    Raises:
        ValueError: if ``bin_width`` is not positive.
    """
    if bin_width <= 0:
        raise ValueError(f"bin_width must be positive, got {bin_width!r}")
    masses = np.asarray(data['masses'])
    intensities = np.asarray(data['intensities'])

    lowest = masses.min()
    highest = masses.max()
    n_edges = int(np.floor((highest - lowest) / bin_width)) + 2
    edges = lowest + bin_width * np.arange(n_edges)
    if edges[-1] <= highest:
        # Guard against floating-point rounding in the floor above.
        edges = np.append(edges, edges[-1] + bin_width)

    binned = np.zeros((intensities.shape[0], len(edges)))
    lower = 0  # the first bin collects everything from 0 up to the first edge
    for i, upper in enumerate(edges):
        mask = (masses >= lower) & (masses < upper)
        binned[:, i] = intensities[:, mask].sum(axis=1)
        lower = upper

    return {
        **data,
        'masses': edges,
        'intensities': binned,
    }

In [None]:
# Shared state for the model-fitting dashboard below.
plt.ioff()      # figures are shown only when explicitly display()ed
out = Output()  # widget that receives fitting output, figures and save buttons
results = {}    # rebound by fit_model() to the locals of the latest fit
def fit_model(model_fn, guide_fn, n_components: int, n_steps: int, data: dict):
    """Fit ``model_fn`` with SVI, then plot the loss, component weights and component spectra.

    Runs ``n_steps`` of SVI (Adam with a constant step size of 0.005, Trace_ELBO),
    then draws 25 samples through ``Predictive``. Three stacked panels are plotted:
    the log of the loss curve; each component's weight over time, with the
    normalised TIC in black on top; and each component's log1p-scaled spectrum.
    The figure and a "Save" button are shown inside the ``out`` widget.

    Args:
        model_fn: numpyro model. SVI calls it as ``(mzs, intensities, n_components)``;
            Predictive calls it as ``(mzs, n_components=...)``.
        guide_fn: guide for ``model_fn`` (e.g. an autoguide instance).
        n_components: number of components to fit.
        n_steps: number of SVI optimisation steps.
        data: dict with 'times', 'masses', 'intensities', 'sample_id' and 'model'.

    Side effects:
        Rebinds the global ``results`` to this call's ``locals()``. The later
        peaks-plotting cell reads 'mzs', 'times', 'trace', 'data', 'n_components'
        and 'n_steps' from it, so these local names must not be renamed.
    """
    colors = sns.color_palette('deep')[:n_components]
    # NOTE(review): the same key seeds svi.run and both Predictive calls below.
    rng_key = random.PRNGKey(0)
    times = data["times"]
    mzs = data["masses"]
    intensities = data["intensities"]
    # Total ion current scaled to a maximum of 1 (in place, so this assumes a float dtype).
    tic = intensities.sum(axis=1)
    tic /= tic.max()
    fig, (losses_ax, weights_ax, components_ax) = plt.subplots(figsize=(15, 10), nrows=3)
    # Same colour per component in the weights and spectra panels.
    weights_ax.set_prop_cycle(plt.cycler(color=colors))
    components_ax.set_prop_cycle(plt.cycler(color=colors))

    # The learning-rate schedule ignores the step number: constant 0.005.
    svi = SVI(
        model=model_fn,
        guide=guide_fn,
        optim=Adam(lambda n: 0.005),
        loss=Trace_ELBO(),
    )
    # Run inside ``out`` so any progress output lands in the dashboard widget.
    # stable_update=True: see numpyro SVI.run (guards updates against non-finite losses).
    with out:
        svi_result = svi.run(rng_key, n_steps, mzs, intensities, n_components, stable_update=True)
    # NOTE(review): np.log yields NaN for any step whose loss is negative.
    losses_ax.plot(np.log(svi_result.losses))
    params = svi_result.params
    # 25 draws from the fitted guide, plus the model's sites under those draws;
    # merged into one dict of arrays keyed by site name (model sites win on clashes).
    trace = Predictive(guide_fn, params=params, num_samples=25)(rng_key, mzs)
    deterministics = Predictive(model_fn, guide=guide_fn, params=params, num_samples=25)(rng_key, mzs, n_components=n_components)
    trace = {**trace, **deterministics}
    # One faint line set per sample. Each sample is scaled by its own overall
    # maximum and the components are stacked 1.2 units apart.
    # Presumably each sample is (n_times, n_components); TODO confirm against the model.
    for c in trace['component_weights']:
        weights_ax.plot(times, c / c.max() + 1.2 * np.arange(n_components), alpha=0.1)
    weights_ax.plot(times, tic + n_components * 1.2, color='k')
    weights_ax.hlines(1.2 * np.arange(n_components + 1), *weights_ax.get_xlim(), color='k', lw=0.5, ls='--')
    weights_ax.set(
        xlabel='Time (s)',
        ylabel='Contribution',
        yticks=[],
    )

    # Transposed so each sample is (n_mz, n_components) against ``mzs``. It is
    # log1p-compressed, normalised per component (column) and stacked 1.2 apart.
    for intensity in trace['component_intensities'].transpose(0, 2, 1):
        intensity = np.log1p(intensity)
        components_ax.plot(mzs, intensity / intensity.max(axis=0) + 1.2 * np.arange(n_components), alpha=0.1)
    components_ax.set(
        xlabel='m/z',
        ylabel='Intensity',
        yticks=[],
    )
    components_ax.hlines(1.2 * np.arange(n_components), *components_ax.get_xlim(), color='k', lw=0.5, ls='--')

    save_button = Button(description='Save')
    out_dir = Path(f'out/{data["sample_id"]}/{data["model"]}')
    save_button.on_click(lambda _: (
        out_dir.mkdir(parents=True, exist_ok=True),
        fig.savefig(f'{out_dir}/inference_outcome_{n_components}_{n_steps}.svg', bbox_inches='tight', transparent=True)
    ))

    fig.tight_layout()
    with out:
        display(fig)
        display(save_button)
    # Expose every local (fig, trace, params, ...) to later cells.
    global results
    results = locals()


@interact(sample_id=all_experiments.keys(), model=models.keys(), start_pct=(0, 100.0),
    end_pct=(0, 100.0), n_steps=(1000, 100000), n_components=(2, 10), auto_guide=True)
def analyze(sample_id, model='model13', n_steps=10000, n_components=5, auto_guide=True, start_pct=0.0, end_pct=100.0):
    """Prepare one sample for fitting and show a "Fit" button that runs ``fit_model``.

    The sample is loaded, re-binned with ``bin_masses`` and cropped to
    ``[start_pct, end_pct]`` percent of its last time point.

    Args (driven by the ``interact`` widgets):
        sample_id: key into ``all_experiments``.
        model: key into ``models``. The module must define ``ms_model`` and
            may define ``ms_model_guide``.
        n_steps: SVI steps passed to ``fit_model``.
        n_components: number of components passed to ``fit_model``.
        auto_guide: use ``autoguide.AutoNormal`` even if the module defines a
            guide. AutoNormal is also used when the module defines none.
        start_pct, end_pct: time window, in percent of the last time point.
    """
    out.clear_output()
    # Close the gzip handle once loaded (it used to leak on every interaction).
    # NOTE: pickle runs arbitrary code on load; only open trusted files.
    with GzipFile(all_experiments[sample_id], 'rb') as fh:
        data = pickle.load(fh)
    data = bin_masses(data)

    # Crop in time (bin_masses returned a fresh dict, so mutating it is safe).
    times = data['times']
    tmin, tmax = (
        times[-1] * start_pct / 100,
        times[-1] * end_pct / 100,
    )
    tindices = (times >= tmin) & (times <= tmax)
    data['times'] = times[tindices]
    data['intensities'] = data['intensities'][tindices]

    module = models[model]
    model_fn = module.ms_model
    guide_fn = getattr(module, "ms_model_guide", None)
    if auto_guide or guide_fn is None:
        guide_fn = autoguide.AutoNormal(model_fn)

    fit_button = Button(description="Fit")
    fit_button.on_click(lambda _: fit_model(model_fn, guide_fn, n_components, n_steps, {**data, 'sample_id': sample_id, 'model': model}))
    display(fit_button)

display(out)

In [None]:
plt.ioff()
includes = {f'component{i}': True for i in range(results['n_components'])}
@interact(
    peak_threshold_pct=(0, 100),
    distance=(1, 100),
    xmin_pct=(0, 100),
    xmax_pct=(0, 100),
    rotate_labels=False,
    **includes
)
def plot_sample(
    peak_threshold_pct,
    rotate_labels,
    distance=1,
    xmin_pct=0,
    xmax_pct=100,
    **includes
):
    """Plot each selected fitted component as a spectrum (left) and a time profile (right).

    Reads the latest fit from the global ``results`` (the locals of ``fit_model``):
    'mzs', 'times', 'trace', 'data', 'n_components' and 'n_steps'. Each
    component's sample-mean spectrum is log-scaled (floored at 1e-3), shifted
    to start at zero and normalised to a maximum of 1. Its sample-mean weight
    trace is normalised to a maximum of 1. Selected components are stacked
    1.2 units apart, with peaks marked and labelled by nominal m/z.

    Args (driven by the ``interact`` widgets):
        peak_threshold_pct: minimum peak height, in percent of the normalised spectrum.
        rotate_labels: draw peak labels vertically.
        distance: minimum peak separation in m/z units (converted to samples below).
        xmin_pct, xmax_pct: m/z window, in percent of the last m/z value.
        **includes: ``component<i>`` -> bool, selecting which components to draw.

    Displays the figure and a "Save" button writing to ``out/<sample_id>/<model>/``.
    """
    masses = results['mzs']
    times = results['times']
    xmin, xmax = (
        masses[-1] * xmin_pct / 100,
        masses[-1] * xmax_pct / 100,
    )
    distance = math.ceil(distance / (masses[1] - masses[0]))
    indices = (masses >= xmin) & (masses <= xmax)
    masses = masses[indices]
    scans = results['trace']['component_intensities'].mean(axis=0)[:, indices]
    weights = results['trace']['component_weights'].mean(axis=0)
    peak_threshold = peak_threshold_pct / 100

    fig_height = sum(includes.values()) + 1
    fig, (ax, chromatogram) = plt.subplots(ncols=2, figsize=(10 * 2, fig_height), sharey=True)
    cmap = sns.color_palette()
    label_gap = 0.1 if rotate_labels else 0.05

    j = 0  # vertical slot of the next drawn component
    for i, scan in enumerate(scans):
        if not includes[f'component{i}']:
            continue
        # Log-scale (floor avoids log(0)), shift to zero, normalise to 1.
        # A flat spectrum/weight is left unnormalised instead of becoming NaN.
        scan = np.log(np.maximum(scan, 1e-3))
        scan = scan - scan.min()
        if scan.max() > 0:
            scan = scan / scan.max()
        # Non-mutating: ``weights`` itself is left untouched.
        weight = weights[:, i]
        weight_max = weight.max()
        if weight_max != 0:
            weight = weight / weight_max
        peaks = find_peaks(
            scan, height=peak_threshold, distance=distance
        )[0]
        offset = 1.2 * j
        ax.plot(masses, scan + offset, c=cmap[i])
        chromatogram.plot(times, weight + offset, c=cmap[i])
        ax.scatter(masses[peaks], scan[peaks] + offset, color="k", s=30, zorder=10, marker="x", lw=1)
        for peak in peaks:
            ax.text(
                masses[peak],
                scan[peak] + offset + label_gap,
                f"{masses[peak]:.0f}",
                color="k",
                fontsize=8,
                ha="center",
                rotation=90 if rotate_labels else 0,
            )
        j += 1

    ax.hlines(1.2 * np.arange(j), *ax.get_xlim(), color='k', lw=0.5, ls='--')
    chromatogram.hlines(1.2 * np.arange(j), *chromatogram.get_xlim(), color='k', lw=0.5, ls='--')

    ax.set(
        ylabel="Intensity",
        xlabel="m/z",
        yticks=[],
    )

    chromatogram.set(
        xlabel="Time (s)",
    )
    # save button
    save_button = Button(description="Save")
    fig.tight_layout()
    sample_id = results['data']['sample_id']
    model_name = results['data']['model']
    out_dir = Path(f"out/{sample_id}/{model_name}")
    # Resolve the filename now, from the same fit that was plotted. Reading
    # ``results`` at click time could pick up a newer fit's n_components/n_steps
    # while out_dir still points at this figure's sample and model.
    filename = f"{out_dir}/peaks_{results['n_components']}_{results['n_steps']}.svg"
    save_button.on_click(
        lambda _: (
            out_dir.mkdir(parents=True, exist_ok=True),
            fig.savefig(filename, bbox_inches="tight", transparent=True)
        )
    )
    display(fig)
    display(save_button)