In [None]:
import numpy as np
import numba as nb
import msiwarp as mx
import time
import multiprocessing as mp

from tqdm import tqdm
from msiwarp.util.warp import peak_density_mz
from bisect import bisect_left, bisect_right
from scipy.signal import savgol_filter, find_peaks
from scipy.interpolate import make_interp_spline

from src.psalign.mass_dispersion import get_mass_dispersion

def printf(x):
    """Print the three alignment-quality metrics held in ``x``.

    ``x`` is indexed at positions 0, 1 and 2: average mass dispersion,
    median mass dispersion and average cosine similarity. The first two are
    formatted with a precision of 2, the last with a precision of 4.
    """
    # (label, precision) for x[0], x[1], x[2] respectively
    layout = (
        ('Average mass dispersion [ppm]', 2),
        ('Median mass dispersion [ppm]', 2),
        ('Average cosine similarity', 4),
    )
    for position, (label, precision) in enumerate(layout):
        formatted = np.format_float_positional(x[position], precision)
        print(f'{label}: {formatted}')

def interp(X, x, y):
    """Evaluate the piecewise-linear interpolant through ``(x, y)`` at ``X``.

    A degree-1 spline is fitted with ``make_interp_spline`` and evaluated at
    ``X``; the result is cast back to the dtype of ``X``.
    """
    linear_spline = make_interp_spline(x, y, k=1)
    evaluated = linear_spline(X)
    return evaluated.astype(X.dtype)

def cossim(a, b):
    """Return the cosine similarity of vectors ``a`` and ``b``.

    Each vector is first scaled to unit Euclidean length, then the dot
    product of the two unit vectors is returned.
    """
    unit_a = a / np.linalg.norm(a)
    unit_b = b / np.linalg.norm(b)
    return np.dot(unit_a, unit_b)

def perform_pwl_warping(mz_vector, data, reference, warping_knots):
    """Warp one spectrum with a piecewise-linear function and re-bin it.

    ``warping_knots`` is a sequence of ``(m/z, shift)`` pairs. The shift is
    applied with both signs (knot + shift and knot - shift); each candidate
    axis is interpolated over ``mz_vector`` and ``data`` is re-binned from it
    onto ``mz_vector``. The candidate whose cosine similarity to
    ``reference`` is strictly larger is returned; on a tie (or NaN) the
    "minus" variant is returned.
    """
    knot_positions = np.array([knot[0] for knot in warping_knots])
    knot_shifts = np.array([knot[1] for knot in warping_knots])

    # "inverted" candidate: shifts added to the knot positions
    axis_plus = interp(mz_vector, knot_positions, knot_positions + knot_shifts)
    invert = binning(mz_vector, axis_plus, data)

    # "not inverted" candidate: shifts subtracted from the knot positions
    axis_minus = interp(mz_vector, knot_positions, knot_positions - knot_shifts)
    not_invert = binning(mz_vector, axis_minus, data)

    inverted_is_better = cossim(reference, invert) > cossim(reference, not_invert)
    return invert if inverted_is_better else not_invert

@nb.njit
def binning(x_correct, x_wrong, values):
    """Redistribute ``values`` sampled at ``x_wrong`` onto the grid ``x_correct``.

    Each sample that falls between two neighbouring grid points is split
    linearly between them; samples outside the grid are folded into the first
    or last bin with the weights computed below.

    Parameters
    ----------
    x_correct : 1-D array, target grid (assumed sorted ascending, as
        ``np.searchsorted`` requires -- TODO confirm with callers).
    x_wrong : 1-D array, positions at which ``values`` were sampled.
    values : 1-D array, one value per entry of ``x_wrong``.

    Returns
    -------
    Array shaped like ``x_correct`` with the dtype of ``values``.
    """
    
    # Sample closest to the first grid point; if it lies to the right of the
    # grid start, also include the sample before it (clamped at index 0).
    index_min = np.argmin(np.abs(x_wrong - x_correct[0]))
    if x_wrong[index_min] > x_correct[0]:
        index_min -= 1
        index_min = index_min if index_min > 0 else 0
    # Sample closest to the last grid point; if it lies to the left of the
    # grid end, also include the sample after it (clamped at the array length;
    # the slice below tolerates an end index past the array).
    index_max = np.argmin(np.abs(x_wrong - x_correct[-1]))
    if x_wrong[index_max] < x_correct[-1]:
        index_max += 1
        index_max = index_max if index_max < x_wrong.shape[0] else x_wrong.shape[0]
        
    result = np.zeros_like(x_correct, dtype=values.dtype)
    
    # Keep only the samples in [index_min, index_max] (inclusive).
    values = values[index_min: index_max + 1]
    x_wrong = x_wrong[index_min: index_max + 1]
    
    # For every sample: index of the last grid point <= sample,
    # or -1 when the sample lies below x_correct[0].
    indices = np.searchsorted(x_correct, x_wrong, side='right') - 1
    
    # Samples left of the grid: folded into the first bin, weighted by their
    # distance to x_correct[0] relative to the first grid spacing.
    # NOTE(review): the weight grows with distance from the grid -- confirm intended.
    idxs = indices < 0
    if idxs.sum() > 0:
        result[0] = np.dot(values[idxs], x_wrong[idxs] - x_correct[0]) / (x_correct[0] - x_correct[1])
    
    # Samples inside the grid: split between the left (index) and right
    # (index + 1) grid points; `factor` is the fraction going to the right one.
    idxs = np.logical_and(0 <= indices, indices < x_correct.shape[0] - 1)
    if idxs.sum() > 0:
        temp = indices[idxs]
        shifted = temp + 1
        factor = np.divide(x_wrong[idxs] - x_correct[temp], x_correct[shifted] - x_correct[temp])
        vals = values[idxs]
        # Explicit loop so repeated target indices accumulate.
        for i, index in enumerate(temp):
            result[index] += vals[i] * (1 - factor[i])
            result[index + 1] += vals[i] * factor[i]
    
    # Samples at or beyond the last grid point: folded into the last bin,
    # weighted by their distance to it relative to the last grid spacing.
    # NOTE(review): a sample exactly on x_correct[-1] gets weight 0 here -- confirm intended.
    idxs = indices >= x_correct.shape[0] - 1
    if idxs.sum() > 0:
        temp = indices[idxs]
        result[-1] += np.dot(values[idxs], np.divide(x_wrong[idxs] - x_correct[temp], x_correct[temp] - x_correct[temp - 1]))
                
    return result

def centroid_tof(spectrum, peak_threshold, window_size = 11, order = 2, distance=None):
    """Smooth a profile spectrum twice and centroid its peaks.

    Parameters
    ----------
    spectrum : indexable pair; ``spectrum[0]`` is the m/z axis and
        ``spectrum[1]`` the intensities.
    peak_threshold : minimum (smoothed) intensity for a sample to be
        accepted as a peak.
    window_size : window length of the Savitzky-Golay filter, used for both
        smoothing passes.
    order : polynomial order of the Savitzky-Golay filter.
    distance : minimum number of samples between neighbouring peaks
        (forwarded to the peak finder); ``None`` disables the constraint.

    Returns
    -------
    ``(mz_c, intensity_c)`` -- centroided m/z values and intensities.
    """
    intensity_golay = savgol_filter(np.array(spectrum[1]), window_size, order)
    # The second pass previously hard-coded a window of 11 and silently
    # ignored `window_size`; both passes now honour the parameter
    # (identical results for the default window_size=11).
    intensity_golay2 = savgol_filter(intensity_golay, window_size, order)
    # Centroid smoothed signal
    (mz_c, intensity_c) = parabolic_centroid(np.array(spectrum[0]), intensity_golay2, peak_threshold, distance)
    return (mz_c, intensity_c)

def parabolic_centroid(mzs, intensities, peak_threshold, distance=None):
    """Centroid peaks by fitting a parabola through each peak's three samples.

    Peaks are located with ``scipy.signal.find_peaks`` (``height`` and
    ``distance`` forwarded). For every peak a parabola
    ``y = a*(x - x_mid)**2 + b*(x - x_mid) + y_mid`` is fitted through the
    sample left of the peak, the peak sample and the sample right of it; the
    vertex of that parabola is reported as the centroid.

    Parameters
    ----------
    mzs : 1-D array-like of m/z values.
    intensities : 1-D array-like of intensities, same length as ``mzs``.
    peak_threshold : minimum intensity for a sample to count as a peak.
    distance : minimum number of samples between peaks, or ``None``.

    Returns
    -------
    ``(mzs_parabolic, intensities_parabolic)`` -- float64 arrays with one
    entry per detected peak (empty when no peak is found).
    """
    # float64 matches what the computation used before and lets plain
    # lists be passed as well.
    mzs = np.asarray(mzs, dtype=np.float64)
    intensities = np.asarray(intensities, dtype=np.float64)

    # find_peaks never reports the first or last sample, so the +-1
    # neighbours below always exist.
    peak_indices, _ = find_peaks(intensities, height=peak_threshold, distance=distance)
    peak_left = peak_indices - 1
    peak_right = peak_indices + 1

    x_left, x_mid, x_right = mzs[peak_left], mzs[peak_indices], mzs[peak_right]
    y_left, y_mid, y_right = (intensities[peak_left],
                              intensities[peak_indices],
                              intensities[peak_right])

    slope_left = (y_mid - y_left) / (x_mid - x_left)
    slope_right = (y_right - y_mid) / (x_right - x_mid)
    span = x_right - x_left

    # a: curvature (second divided difference); b: slope at the centre sample
    a = (slope_right - slope_left) / span
    b = (slope_right * (x_mid - x_left) + slope_left * (x_right - x_mid)) / span

    # A flat-topped peak has three equal samples, hence a == 0 and no vertex;
    # the vertex formula would yield 0/0 = NaN. Keep the centre sample there.
    curved = a != 0
    mzs_parabolic = x_mid.copy()
    mzs_parabolic[curved] = ((1/2) * (-b[curved] + 2 * a[curved] * x_mid[curved]) / a[curved])

    offset = mzs_parabolic - x_mid
    intensities_parabolic = (a * offset ** 2 + b * offset + y_mid)

    return (mzs_parabolic, intensities_parabolic)

In [None]:
# Download file from https://www.ebi.ac.uk/pride/archive/projects/PXD013069
# Base name (without extension) of the data set; the cells below load
# f"{path}/{sample}.npz".
sample = 'drugtreatedspheroids-nonormalization'

# Placeholder: replace with the directory that contains the converted .npz file.
path = '<path_to_data>/'

In [3]:
# Conversion happens when tof_spheroids.ipynb is run
# The context manager closes the .npz archive once the arrays have been read.
with np.load(f"{path}/{sample}.npz") as file:
    mz_vector = file['axis']
    row2grid = file['location']
    maldi = file['data']

del file

# ---------- restrict the data to the m/z window used for alignment ----------
nb_peaks = 50
start_maldi_alignment = 800
stop_maldi_alignment = 4500
# one sample before the window start, clamped to the valid index range
start_index_alignment = max(bisect_left(mz_vector, start_maldi_alignment) - 1, 0)
stop_index_alignment = min(bisect_right(mz_vector, stop_maldi_alignment), mz_vector.shape[0] - 1)
maldi = maldi[:, start_index_alignment: stop_index_alignment + 1]
mz_vector = mz_vector[start_index_alignment: stop_index_alignment + 1]

l2 = np.linalg.norm(maldi, axis=1)
tic = np.sum(maldi, axis=1)

distance = 5
threshold = 99.9

# Reference = the spectrum with the highest total ion count, scaled to unit sum.
# Only that one row is normalised, as a copy: `maldi` keeps its raw
# intensities, so the matrix neither has to be normalised row by row nor
# re-read from disk, and `reference` does not keep a second full matrix alive.
reference_row = maldi[np.argmax(tic)]
reference = reference_row / reference_row.sum()

# scaling to test impact of sigma on alignment performance
sigma_1 = 4e-5
epsilon = 2.55
bandwidth = 100
slack = 2.0 * epsilon * sigma_1

# ---------- centroid the reference and every spectrum ----------
# Peak threshold = the `threshold`-th percentile of each spectrum's intensities.
refe = centroid_tof(np.stack([mz_vector, reference], axis=0),
                    peak_threshold=np.percentile(reference, threshold),
                    distance=distance)

spect_matrix = [
    centroid_tof(np.stack([mz_vector, spectrum], axis=0),
                 peak_threshold=np.percentile(spectrum, threshold),
                 distance=distance)
    for spectrum in maldi
]

peak_counts = [len(centroids[0]) for centroids in spect_matrix]
print("Average number of peaks:", np.format_float_positional(np.mean(peak_counts), 2))

# keep shape and dtype, then release the profile data
size = maldi.shape
dtype = maldi.dtype
del maldi

# ---------- convert centroids to MSIWarp peaks (peak width = sigma_1 * m/z) ----------
ref = [mx.peak(i, mz_i, h_i, sigma_1 * mz_i) for i, (mz_i, h_i) in enumerate(zip(refe[0], refe[1]))]
spec_matrix = [[mx.peak(i, mz_i, h_i, sigma_1 * mz_i) for i, (mz_i, h_i) in enumerate(zip(spect_m[0], spect_m[1]))] for spect_m in spect_matrix]
del spect_matrix

# ---------- find peak dense regions across data set spectra ----------
mz_begin = start_maldi_alignment
mz_end = stop_maldi_alignment
xi = np.linspace(mz_begin, mz_end, 4000)
(yi, xp, yp) = peak_density_mz(spec_matrix, xi, bandwidth=bandwidth, stride=5)

# we're using the same warping nodes for all spectra here:
# midpoints between consecutive entries of xp, plus the two window ends
node_mzs = (xp[:-1] + xp[1:]) / 2
node_mzs = np.array([mz_begin, *node_mzs, mz_end])

# setup warping parameters
n_steps = 50 # the slack of a warping node is +- (n_steps * s * sigma @ the node's m/z)

node_slacks = slack * node_mzs
nodes = mx.initialize_nodes(node_mzs, node_slacks, n_steps)

Average number of peaks: 21.29


In [4]:
# Time the MSIWarp step: search the optimal warping per spectrum, then warp
# the centroided peaks with it.
started_at = time.time()
warping_functions = mx.find_optimal_spectra_warpings(spec_matrix, ref, nodes, epsilon)
warped_spectra = [
    mx.warp_peaks(peaks, nodes, optimal_path)
    for peaks, optimal_path in zip(spec_matrix, warping_functions)
]
elapsed = time.time() - started_at
print("time: {:0.2f}s".format(elapsed))

# The warped peak lists were only produced for the timing; the cell below
# uses `warping_functions` only.
del spec_matrix, warped_spectra

time: 1.48s


MSIWarp's computation time is very low when the number of peaks per spectrum is low.

In [5]:
# For every spectrum: one (node m/z, selected m/z shift) pair per warping node.
warping_funcs = [[(nodes[o].mz, nodes[o].mz_shifts[f[o]]) for o in range(len(nodes))] for f in warping_functions]

# Use half of the cores, but never fewer than one worker:
# cpu_count() // 2 is 0 on a single-core machine and Pool(processes=0) raises ValueError.
n_workers = max(1, mp.cpu_count() // 2)

with mp.Pool(processes=n_workers) as pool:

    # Reload the profile spectra (restricted to the alignment window);
    # the context manager closes the .npz archive after reading.
    with np.load(f"{path}/{sample}.npz") as archive:
        maldi = archive['data'][:, start_index_alignment: stop_index_alignment + 1]
    res = np.empty_like(maldi)

    print("Before alignment:")
    printf(get_mass_dispersion(maldi, mz_vector, pool=pool, nb_of_peaks=nb_peaks))

    # Apply each spectrum's piecewise-linear warping to its profile data.
    for i in tqdm(range(maldi.shape[0])):
        res[i] = perform_pwl_warping(mz_vector, maldi[i, :], reference, warping_funcs[i])

    print("\nAfter alignment:")
    printf(get_mass_dispersion(res, mz_vector, pool=pool, nb_of_peaks=nb_peaks))

Before alignment:
Average mass dispersion [ppm]: 19.53
Median mass dispersion [ppm]: 19.02
Average cosine similarity: 0.6532


100%|██████████| 1201/1201 [00:21<00:00, 57.05it/s]



After alignment:
Average mass dispersion [ppm]: 12.88
Median mass dispersion [ppm]: 12.22
Average cosine similarity: 0.7941
