In [None]:
import multiprocessing as mp
import os
import time

import msiwarp as mx
import numba as nb
import numpy as np
import pandas as pd
from msiwarp.util.warp import to_mx_peaks, to_height, generate_mean_spectrum, to_mz, dispersion_triplets
from pyimzml.ImzMLParser import ImzMLParser
from scipy.interpolate import make_interp_spline
from tqdm import tqdm

from src.psalign.mass_dispersion import get_mass_dispersion

# ---------- input and output file paths ----------
# Download file from https://www.omicsdi.org/dataset/metabolights_dataset/MTBLS289
#
# `fdir` is the data directory. Every path in this notebook is built as
# `fdir + <name>`, so make sure it ends with a path separator: joining with an
# empty component appends one only when it is missing. This keeps the paths
# correct whether or not the directory is given with a trailing slash.
fdir = os.path.join('<path_to_data>', '')
sample = 'A52 CT S3-centroid.imzML'
imzml_path = fdir + sample
fpath_triplets_raw = fdir + 'triplets_raw.dat'
fpath_triplets_warped = fdir + 'triplets_warped.dat'
fpath_dispersion_csv = fdir + 'results/dispersion_100.csv'
fpath_scatter = fdir + 'results/scatter'

# ---------- peak model / warping constants ----------
# sigma_1 is handed to to_mx_peaks and generate_mean_spectrum; epsilon is handed
# to mx.find_optimal_warpings_uni; slack (derived from both) goes to
# mx.params_uniform.
sigma_1 = 3.0e-7
epsilon = 1.0
slack = 2.0 * epsilon * sigma_1
instrument_type = 'orbitrap'

# m/z range used for the warping parameters and the mean spectrum
mz_begin = 200
mz_end = 1000

In [2]:
# ---------- load the centroided spectra ----------
p = ImzMLParser(imzml_path)
# Indices of the spectra to read from the imzML file.
idxs_centroid = np.load(fdir + 'idxs_centroid.npy')

# Convert every selected spectrum into an MSIWarp peak list.
spectra = []
for idx in idxs_centroid:
    mzs, hs = p.getspectrum(idx)
    spectra.append(to_mx_peaks(mzs, hs,
                               sigma_1, id=idx,
                               instrument_type=instrument_type))

# Number of peaks in each spectrum. Deliberately not named `n_peaks`: the
# warping-parameter cell assigns the scalar `n_peaks = 30`, which would
# silently replace this array with an unrelated value.
n_peaks_per_spectrum = np.array([len(s) for s in spectra])
# Summed peak height of each spectrum; used to pick the reference spectrum and
# passed to generate_mean_spectrum.
tic = np.array([np.sum(to_height(s)) for s in spectra])

  warn(
  warn(
  warn(
  warn(
  warn(
  warn(


In [3]:
# ---------- warping parameters ----------
n_steps, n_peaks, max_n_nodes = 33, 30, 8

# Parameter object consumed by mx.find_optimal_warpings_uni.
params = mx.params_uniform(
    mx.Instrument.Orbitrap, n_steps, n_peaks, max_n_nodes,
    mz_begin, mz_end, slack,
)

# --------- set reference spectrum ----------
# The spectrum with the largest summed peak height (`tic`) is the reference.
i_r = tic.argmax()
s_ref = spectra[i_r]
print(len(s_ref), "peaks in reference spectrum")

1058 peaks in reference spectrum


In [4]:
# ---------- find and apply the warping functions ----------
# `time` is imported in the import cell at the top of the notebook.
n_cores = 8

# perf_counter is a monotonic clock meant for measuring intervals; only the
# differences between the time stamps are used.
t0 = time.perf_counter()
warping_funcs = mx.find_optimal_warpings_uni(spectra, s_ref, params, epsilon, n_cores)
t1 = time.perf_counter()
print("found optimal warpings in {:0.2f} seconds".format(t1 - t0))

# Pair every spectrum with its warping function and warp its peaks.
t2 = time.perf_counter()
warped_spectra = [mx.warp_peaks_unique(s_i, r_i) for (s_i, r_i) in zip(spectra, warping_funcs)]
t3 = time.perf_counter()
print("warped spectra in {:0.2f}s".format(t3 - t2))

found optimal warpings in 343.38 seconds
warped spectra in 14.78s


In [5]:
# ---------- mean spectrum ----------
n_points = 2_000_000
s_m = generate_mean_spectrum(
    warped_spectra, n_points, sigma_1, mz_begin, mz_end, tic, instrument_type
)

# peaks_top_n returns its peak list sorted by intensity, not m/z, so the
# 1000 strongest peaks are re-ordered by m/z afterwards.
s_m_1000 = mx.peaks_top_n(s_m, 1000)
s_r = list(s_m_1000)
s_r.sort(key=lambda pk: pk.mz)

# Reference masses: m/z of the 100 strongest mean-spectrum peaks, ascending.
s_m_100 = mx.peaks_top_n(s_m, 100)
mz_ref = np.sort(to_mz(s_m_100))

  func = lambda x: x ** (3/2)


generating mean spectrum with 1042346 sampling points...
generated mean spectrum


In [6]:
# ---------- compute mass dispersions around mean spectrum ----------
# NOTE(review): the two triplet files read below are not written by any cell
# of this notebook; presumably they must be generated from `spectra` and
# `warped_spectra` beforehand -- confirm, otherwise stale files are analysed.
dispersion_raw = np.zeros(len(mz_ref))
dispersion_warped = np.zeros(len(mz_ref))

mass_tolerance = 4 # ppm
q = 0.0 # second argument of dispersion_triplets; constant, so set once outside the loop

for i, mz_i in enumerate(mz_ref):
    # half-width of the search window: -+ mass_tolerance ppm around reference mass
    half_window = mass_tolerance * mz_i / 1e6
    mz0 = mz_i - half_window
    mz1 = mz_i + half_window

    ts_raw = mx.get_triplets_range(fpath_triplets_raw, mz0, mz1)
    ts_warped = mx.get_triplets_range(fpath_triplets_warped, mz0, mz1)

    # Reference masses without any triplet in the window keep a dispersion of
    # 0 and are still included in the median / mean printed below.
    if len(ts_raw) > 0:
        dispersion_raw[i] = dispersion_triplets(ts_raw, q)
    if len(ts_warped) > 0:
        dispersion_warped[i] = dispersion_triplets(ts_warped, q)


d = {'mz': mz_ref,
     'dispersion raw [ppm]': dispersion_raw,
     'dispersion warped [ppm]': dispersion_warped}

df = pd.DataFrame(d)

# to_csv does not create missing directories, so make sure the target exists.
out_dir = os.path.dirname(fpath_dispersion_csv)
if out_dir:
    os.makedirs(out_dir, exist_ok=True)
df.round(4).to_csv(fpath_dispersion_csv, index=False)

print('median mass dispersion raw: {:0.4f}'.format(np.median(dispersion_raw)))
print('median mass dispersion warped: {:0.4f}'.format(np.median(dispersion_warped)))
print('mean mass dispersion raw: {:0.4f}'.format(np.mean(dispersion_raw)))
print('mean mass dispersion warped: {:0.4f}'.format(np.mean(dispersion_warped)))

median mass dispersion raw: 0.6134
median mass dispersion warped: 0.2883
mean mass dispersion raw: 0.6428
mean mass dispersion warped: 0.3083


In [7]:
@nb.njit
def binning(x_correct, x_wrong, values):
    """Linearly re-bin `values`, sampled at `x_wrong`, onto the grid `x_correct`.

    Every sample is split between the two grid points that enclose it, with
    weights proportional to its proximity to each of them. A sample lying
    outside the grid only has one neighbour inside it (the first or the last
    grid point); that neighbour receives ``1 - distance / spacing`` of the
    value (never less than 0) and the remainder is dropped, where `spacing`
    is the width of the first / last grid interval.

    Parameters
    ----------
    x_correct : 1-D array
        Target grid. Must be sorted ascending (it is searched with
        ``np.searchsorted``) and contain at least two points.
    x_wrong : 1-D array
        Positions of the samples; assumed sorted ascending, since a contiguous
        index range is selected from it.
    values : 1-D array
        Sample values, one per entry of `x_wrong`.

    Returns
    -------
    1-D array shaped like `x_correct` with the dtype of `values`.
    """
    n_grid = x_correct.shape[0]

    # Restrict the work to the contiguous run of samples that can reach the
    # grid: from the sample closest to the first grid point (plus one more to
    # the left if that sample lies inside the grid) up to the sample closest
    # to the last grid point (plus one more to the right, likewise).
    index_min = np.argmin(np.abs(x_wrong - x_correct[0]))
    if x_wrong[index_min] > x_correct[0]:
        index_min = max(index_min - 1, 0)
    index_max = np.argmin(np.abs(x_wrong - x_correct[-1]))
    if x_wrong[index_max] < x_correct[-1]:
        # may equal len(x_wrong); the slice below clips it
        index_max = min(index_max + 1, x_wrong.shape[0])

    result = np.zeros_like(x_correct, dtype=values.dtype)

    values = values[index_min: index_max + 1]
    x_wrong = x_wrong[index_min: index_max + 1]

    # Index of the grid point at or directly below each sample:
    # -1 for samples below the grid, n_grid - 1 for samples at or above its end.
    indices = np.searchsorted(x_correct, x_wrong, side='right') - 1

    first_step = x_correct[1] - x_correct[0]
    last_step = x_correct[-1] - x_correct[-2]

    for i in range(x_wrong.shape[0]):
        index = indices[i]
        if index < 0:
            # Below the grid: the closer to the first grid point, the larger
            # the share. This tends to 1 as the sample approaches the grid, so
            # it is continuous with the interior rule.
            weight = 1 - (x_correct[0] - x_wrong[i]) / first_step
            if weight > 0:
                result[0] += values[i] * weight
        elif index < n_grid - 1:
            # Inside the grid: split between the two enclosing grid points.
            factor = (x_wrong[i] - x_correct[index]) / (x_correct[index + 1] - x_correct[index])
            result[index] += values[i] * (1 - factor)
            result[index + 1] += values[i] * factor
        else:
            # At or above the last grid point; a sample exactly on it keeps
            # its full value.
            weight = 1 - (x_wrong[i] - x_correct[-1]) / last_step
            if weight > 0:
                result[-1] += values[i] * weight

    return result

def interp(X, x, y):
    """Evaluate the piecewise-linear interpolant through (x, y) at X.

    A degree-1 spline is fitted through the points and evaluated at every
    entry of `X`; the result is cast to the dtype of `X`.
    """
    linear_spline = make_interp_spline(x, y, k=1)
    evaluated = linear_spline(X)
    return evaluated.astype(X.dtype)

def perform_pwl_warping(mz_vector, data, warping_knots):
    """Apply a piecewise-linear warping to one spectrum on a fixed m/z axis.

    `warping_knots` is a sequence of pairs: element 0 is used as the knot
    position and element 1 is added to it to obtain the knot's target
    position. The m/z axis is mapped through the linear interpolant of these
    knots and `data` is then re-binned from the mapped axis back onto
    `mz_vector`.
    """
    knot_positions = np.array([knot[0] for knot in warping_knots])
    knot_targets = np.array([knot[0] + knot[1] for knot in warping_knots])
    warped_axis = interp(mz_vector, knot_positions, knot_targets)
    return binning(mz_vector, warped_axis, data)

def printf(x):
    """Print a three-element result as labelled, rounded lines.

    `x[0]` and `x[1]` are printed as average and median mass dispersion with
    2 decimals; `x[2]` is printed as average cosine similarity with 4 decimals.
    """
    average_dispersion = np.format_float_positional(x[0], precision=2)
    median_dispersion = np.format_float_positional(x[1], precision=2)
    average_cosine = np.format_float_positional(x[2], precision=4)

    print(f'Average mass dispersion [ppm]: {average_dispersion}')
    print(f'Median mass dispersion [ppm]: {median_dispersion}')
    print(f'Average cosine similarity: {average_cosine}')

In [None]:
# Apply warping functions to profile data and perform mass dispersion analysis

# Derive the profile file name from `sample` only, so that a "centroid" or
# ".imzML" substring in the data directory path is never rewritten.
profile_path = fdir + sample.replace("centroid", "profile").replace(".imzML", ".npz")
idxs_profile = np.load(fdir + 'idxs_profile.npy')

# An .npz archive keeps its file handle open until closed; read what is needed
# inside a `with` block so the handle is released right away.
with np.load(profile_path) as file:
    maldi = file['data'][idxs_profile]
    mz_vector = file['axis']

res = np.empty_like(maldi)

# NOTE(review): row i of `maldi` is warped with warping_funcs[i], which was
# computed from the i-th centroided spectrum -- this assumes idxs_profile and
# idxs_centroid select the same pixels in the same order; confirm.
with mp.Pool(processes=mp.cpu_count()) as pool:

    # dispersion of the profile data before warping
    printf(get_mass_dispersion(maldi, mz_vector, pool=pool))

    for i in tqdm(range(maldi.shape[0])):
        res[i] = perform_pwl_warping(mz_vector, maldi[i], warping_funcs[i])

    # dispersion of the profile data after warping
    printf(get_mass_dispersion(res, mz_vector, pool=pool))

Average mass dispersion [ppm]: 0.85
Median mass dispersion [ppm]: 0.65
Average cosine similarity: 0.9103


100%|██████████| 18270/18270 [00:12<00:00, 1431.98it/s]


Average mass dispersion [ppm]: 0.52
Median mass dispersion [ppm]: 0.33
Average cosine similarity: 0.9127
