## GPU/CPU optimization of the search step

In [1]:
from alphapept.search import get_psms

import alphapept.io

# Location of the test data; adjust DATA_DIR to your local setup.
DATA_DIR = 'E:/test_temp'

ms_file = f'{DATA_DIR}/thermo_HeLa.ms_data.hdf'
ms_file_ = alphapept.io.MS_Data_File(ms_file)

db_data_path = f'{DATA_DIR}/database.hdf'

# TODO: calibrated_fragments should be included in settings
query_data = ms_file_.read_DDA_query_data(
    calibrated_fragments=True,
    database_file_name=db_data_path,
)

features = ms_file_.read(dataset_name="features")

In [2]:
%%time

psms, num_specs_compared = get_psms(query_data, db_data_path, features, parallel = True, m_tol = 20, m_offset=20, ppm=True, min_frag_hits = 6)

Wall time: 24.1 s


In [3]:
from alphapept.fasta import read_database
import numpy as np
from alphapept.search import get_idxs

In [4]:
from alphapept.search import compare_specs_single

In [6]:
#export 

# A decorator for writing GPU/CPU agnostic code
import multiprocessing
import threading
import functools
import math
import numba as numba_
numba = numba_
import numpy as np
from numba import cuda as cuda_
cuda = cuda_
from numba import njit

# Choose the decorator for helpers that are called from inside kernels
# (e.g. `compare_frags`): a CUDA device function when cupy imports, plain
# `njit` otherwise. ImportError is caught, not only ModuleNotFoundError,
# because cupy may be installed yet still fail to import (e.g. missing CUDA
# libraries); in that case the CPU fallback is used as well.
try:
    import cupy
    jit_fun = cuda.jit(device=True)  # Device function
except ImportError:
    # numpy stands in for cupy so that `cupy.<name>` calls resolve on the CPU.
    import numpy as cupy
    jit_fun = njit
    
@numba.njit
def grid_1d(x):
    """CPU stand-in for ``cuda.grid(1)``; ignores ``x`` and returns the sentinel -1."""
    return -1
@numba.njit
def grid_2d(x):
    """CPU stand-in for ``cuda.grid(2)``; ignores ``x`` and returns the sentinel pair (-1, -1)."""
    return (-1, -1)


def set_cuda_grid(dimensions=0):
    """
    Rebind the module-level name ``cuda`` for GPU (0) or CPU (1 or 2) kernels.

    * ``dimensions=0``: ``cuda`` becomes ``numba.cuda`` and its ``grid``
      attribute is re-assigned to the real ``numba.cuda.grid``.
    * ``dimensions=1`` / ``2``: ``cuda`` becomes the ``numba`` module itself and
      ``numba.grid`` is set to the ``grid_1d`` / ``grid_2d`` stub, presumably so
      that kernel code calling ``cuda.grid(...)`` still compiles under ``njit``.
    * Any other value leaves everything unchanged.

    Note: this sets an attribute on the imported ``numba`` module object, so the
    change is visible to every user of that module in the process.
    """
    global cuda
    if dimensions == 0:
        module, grid_fn = cuda_, cuda_.grid
    elif dimensions == 1:
        module, grid_fn = numba_, grid_1d
    elif dimensions == 2:
        module, grid_fn = numba_, grid_2d
    else:
        return
    cuda = module
    cuda.grid = grid_fn
      
def parallel_compiled_func(
    _func=None,
    *,
    cpu_threads=None,
    dimensions=1,
):
    """
    Decorator that runs a kernel over an array on the GPU or on CPU threads.

    The decorated function receives the loop index/indices first, then the
    array being iterated over, then any further positional arguments:

    * ``dimensions=1``: ``func(idx, iterable_1d, *args)``
    * ``dimensions=2``: ``func(i, j, iterable_2d, *args)``

    On the GPU the index arguments are passed as ``-1`` and the kernel must
    obtain its position from ``cuda.grid(...)`` instead. On the CPU the
    function is compiled with ``numba.njit(nogil=True)`` and the rows of the
    iterable are distributed round-robin over ``cpu_threads`` Python threads.

    Usable bare (``@parallel_compiled_func``) or with keyword arguments
    (``@parallel_compiled_func(cpu_threads=4)``).

    Args:
        _func: The function being decorated when used without parentheses.
        cpu_threads: Force CPU execution with this many threads; a value
            <= 0 means ``multiprocessing.cpu_count()``. When None, the GPU is
            used if a CUDA device is available and cupy can be imported.
        dimensions: Rank of the iteration space, 1 or 2.

    Returns:
        The wrapped function, or a decorator when ``_func`` is None. The
        wrapper itself returns None; results must be written to output arrays.

    Raises:
        ValueError: If ``dimensions`` is not 1 or 2.
    """
    set_cuda_grid()
    if dimensions not in (1, 2):
        raise ValueError("Only 1D and 2D are supported")
    if cpu_threads is not None:
        use_gpu = False
    else:
        try:
            cuda.get_current_device()
        except cuda.CudaSupportError:
            use_gpu = False
            cpu_threads = 0
        else:
            use_gpu = True
        try:
            import cupy  # availability check only
        except ImportError:
            # cupy may be installed yet fail to import (e.g. missing CUDA
            # libraries); treat that like "not installed" and use the CPU.
            use_gpu = False
            cpu_threads = 0

    if use_gpu:
        # `cuda` was already reset to numba.cuda by set_cuda_grid() above.
        def parallel_compiled_func_inner(func):
            cuda_func = cuda.jit(func)
            if dimensions == 1:
                def wrapper(iterable_1d, *args):
                    # NOTE(review): tpb=1 presumably launches exactly shape[0]
                    # threads; kernels such as `search` do not bounds-check
                    # their index, so raising tpb needs a guard there first.
                    cuda_func.forall(iterable_1d.shape[0], 1)(
                        -1,
                        iterable_1d,
                        *args
                    )
            elif dimensions == 2:
                def wrapper(iterable_2d, *args):
                    rows = iterable_2d.shape[0]
                    cols = iterable_2d.shape[1]
                    # Block size per axis: at most 16, and at most that axis' extent.
                    threadsperblock = (min(rows, 16), min(cols, 16))
                    blockspergrid = (
                        math.ceil(rows / threadsperblock[0]),
                        math.ceil(cols / threadsperblock[1]),
                    )
                    cuda_func[blockspergrid, threadsperblock](
                        -1,
                        -1,
                        iterable_2d,
                        *args
                    )
            return functools.wraps(func)(wrapper)
    else:
        set_cuda_grid(dimensions)
        if cpu_threads <= 0:
            cpu_threads = multiprocessing.cpu_count()

        def parallel_compiled_func_inner(func):
            numba_func = numba.njit(nogil=True)(func)
            if dimensions == 1:
                def numba_func_parallel(thread, iterable_1d, *args):
                    # Round-robin split: thread t handles rows t, t + n, t + 2n, ...
                    for i in range(thread, len(iterable_1d), cpu_threads):
                        numba_func(i, iterable_1d, *args)
            elif dimensions == 2:
                def numba_func_parallel(thread, iterable_2d, *args):
                    for i in range(thread, iterable_2d.shape[0], cpu_threads):
                        for j in range(iterable_2d.shape[1]):
                            numba_func(i, j, iterable_2d, *args)
            # nogil=True lets the Python threads below run the compiled loops concurrently.
            numba_func_parallel = numba.njit(nogil=True)(numba_func_parallel)

            def wrapper(iterable, *args):
                threads = []
                for thread_id in range(cpu_threads):
                    t = threading.Thread(
                        target=numba_func_parallel,
                        args=(thread_id, iterable, *args)
                    )
                    t.start()
                    threads.append(t)
                for t in threads:
                    t.join()
            return functools.wraps(func)(wrapper)
    if _func is None:
        return parallel_compiled_func_inner
    else:
        return parallel_compiled_func_inner(_func)

In [7]:
@jit_fun
def compare_frags(query_frag, db_frag, mtol, ppm=False):
    """
    Compare query and database frags and find hits.

    Both arrays are walked with a two-pointer merge, which assumes each is
    sorted by ascending mass (TODO confirm for all callers).

    Args:
        query_frag: Query (measured) fragment masses.
        db_frag: Database (theoretical) fragment masses.
        mtol: Mass tolerance; in ppm when ``ppm`` is True, otherwise in the
            same unit as the masses.
        ppm: If True, compare ``2 * (m1 - m2) / (m1 + m2) * 1e6`` against ``mtol``.

    Returns:
        int16 array of length ``len(db_frag)``. Entry ``d`` is ``q + 1`` when
        database fragment ``d`` matched query fragment ``q``, and 0 otherwise.
    """
    q_max = len(query_frag)
    d_max = len(db_frag)
    # NOTE(review): int16 cannot hold q + 1 beyond 32767, so very large spectra
    # would corrupt hit positions. Also, np.zeros is presumably unavailable when
    # this is compiled as a CUDA device function -- confirm before GPU use.
    hits = np.zeros(d_max, dtype=np.int16)
    q, d = 0, 0  # q -> query position, d -> database position
    while q < q_max and d < d_max:
        mass1 = query_frag[q]
        mass2 = db_frag[d]
        delta_mass = mass1 - mass2

        if ppm:
            sum_mass = mass1 + mass2
            mass_difference = 2 * delta_mass / sum_mass * 1e6
        else:
            mass_difference = delta_mass

        if abs(mass_difference) <= mtol:
            hits[d] = q + 1  # Save query position +1 (0 means "no hit")
            d += 1
            q += 1  # Only one query for each db element
        elif delta_mass < 0:
            q += 1
        else:
            # delta_mass > 0, or a NaN / zero delta that failed the tolerance
            # test (e.g. negative mtol): always advance so the loop terminates.
            d += 1

    return hits

In [8]:
@parallel_compiled_func
def search(idx, search_idx, query_indices, query_frags, db_frags, db_bounds, query_masses, frag_hits, mtol = 20, ppm=True):
    """
    Count fragment hits for one (query, database) candidate pair.

    Invoked once per row of ``search_idx`` through ``parallel_compiled_func``.
    Row ``x`` holds ``(query_idx, db_idx)``; the number of matched fragments is
    written to ``frag_hits[query_idx, db_idx - idxs_lower[query_idx]]``.

    Args:
        idx: Row of ``search_idx`` to process on the CPU; -1 on the GPU, where
            the row is taken from ``cuda.grid(1)`` instead.
        search_idx: Array of (query_idx, db_idx) rows.
        query_indices: Offsets into ``query_frags``; spectrum i spans
            ``query_indices[i]:query_indices[i + 1]``.
        query_frags: Concatenated query fragment masses.
        db_frags: Database fragment masses, one column per database entry.
        db_bounds: Number of leading entries of each ``db_frags`` column to use.
        query_masses: Query precursor masses (currently unused).
        frag_hits: 2D output array, written in place.
        mtol: Fragment mass tolerance passed to ``compare_frags``.
        ppm: Whether ``mtol`` is in ppm.
    """
    if idx == -1:
        x = cuda.grid(1)
    else:
        x = idx

    query_idx, db_idx = search_idx[x]

    query_idx_start = query_indices[query_idx]
    query_idx_end = query_indices[query_idx + 1]
    query_frag = query_frags[query_idx_start:query_idx_end]
    db_frag = db_frags[:, db_idx][: db_bounds[db_idx]]
    hits = compare_frags(query_frag, db_frag, mtol, ppm)
    # FIXME: `idxs_lower` is not a parameter -- it is resolved from the notebook's
    # global namespace, not from the caller's data (numba also freezes global
    # arrays at compile time). It should be passed in explicitly together with
    # a matching change at the call site.
    frag_hits[query_idx, db_idx - idxs_lower[query_idx]] = cupy.sum(hits > 0)

In [9]:
@njit
def get_search_idx(idxs_higher, idxs_lower, query_masses):
    """
    Enumerate every (query, database) candidate pair to be scored.

    This is host-side preprocessing that allocates and returns a new array,
    so it is compiled with ``njit`` rather than ``jit_fun``. ``jit_fun`` is a
    CUDA device function when cupy is available, and such functions cannot
    be called from Python.

    Args:
        idxs_higher: Per-query exclusive upper bound of candidate database indices.
        idxs_lower: Per-query inclusive lower bound of candidate database indices.
        query_masses: Query precursor masses; only its length is used.

    Returns:
        int64 array of shape (sum(idxs_higher - idxs_lower), 2) whose rows are
        (query_idx, db_idx), ordered by query and then by database index.
    """
    n_pairs = np.sum(idxs_higher - idxs_lower)
    pairs = np.zeros((n_pairs, 2), dtype=np.int64)
    row = 0
    for query_idx in range(len(query_masses)):
        for db_idx in range(idxs_lower[query_idx], idxs_higher[query_idx]):
            pairs[row, 0] = query_idx
            pairs[row, 1] = db_idx
            row += 1
    return pairs
        
        

In [10]:
# Notebook-level copy of the data preparation performed inside `get_psms_`.
# It defines module-level names read by other cells: `idxs_lower` (used by
# `search`) and the settings `db_data`, `m_tol`, `m_offset`, `ppm`.
db_data = db_data_path
m_offset_calibrated = False

m_offset = 20
m_tol = 20
ppm = True

# Database arrays come either from the HDF file or from an in-memory mapping.
if isinstance(db_data, str):
    db_masses, db_frags, db_bounds = (
        read_database(db_data, array_name=key)
        for key in ('precursors', 'fragmasses', 'bounds')
    )
else:
    db_masses, db_frags, db_bounds = (
        db_data[key] for key in ('precursors', 'fragmasses', 'bounds')
    )

query_indices = query_data["indices_ms2"]
query_bounds = query_data['bounds']
query_frags = query_data['mass_list_ms2']

if m_offset_calibrated:
    m_offset = m_offset_calibrated

if features is None:
    query_masses = query_data['prec_mass_list2']
    query_mz = query_data['mono_mzs2']
    query_rt = query_data['rt_list_ms2']
else:
    query_masses = features[
        'corrected_mass' if m_offset_calibrated else 'mass_matched'
    ].values
    query_mz = features['mz_matched'].values
    query_rt = features['rt_matched'].values
    query_selection = features['query_idx'].values
    query_bounds = query_bounds[query_selection]
    # Re-pack the fragments of the selected spectra contiguously and rebuild
    # the offsets so that spectrum i spans indices[i]:indices[i + 1].
    indices = np.zeros(len(query_selection) + 1, np.int64)
    indices[1:] = np.diff(query_indices)[query_selection]
    indices = np.cumsum(indices)
    query_frags = np.concatenate([
        query_frags[start:end]
        for start, end in zip(query_indices[query_selection],
                              query_indices[query_selection + 1])
    ])
    query_indices = indices

idxs_lower, idxs_higher = get_idxs(db_masses, query_masses, m_offset, ppm)
frag_hits = np.zeros(
    (len(query_masses), np.max(idxs_higher - idxs_lower)), dtype=int
)


In [11]:
def _read_db_arrays(db_data):
    """Return (precursor masses, fragment masses, bounds) from a database path or mapping."""
    if isinstance(db_data, str):
        return tuple(
            read_database(db_data, array_name=name)
            for name in ('precursors', 'fragmasses', 'bounds')
        )
    return db_data['precursors'], db_data['fragmasses'], db_data['bounds']


def _select_query_arrays(query_data, features, m_offset_calibrated):
    """Return (query_masses, query_frags, query_indices) for the spectra to be searched."""
    query_indices = query_data["indices_ms2"]
    query_frags = query_data['mass_list_ms2']
    if features is None:
        return query_data['prec_mass_list2'], query_frags, query_indices

    mass_column = 'corrected_mass' if m_offset_calibrated else 'mass_matched'
    query_masses = features[mass_column].values
    query_selection = features['query_idx'].values
    # Re-pack the fragments of the selected spectra contiguously and rebuild
    # the offsets so that spectrum i spans indices[i]:indices[i + 1].
    indices = np.zeros(len(query_selection) + 1, np.int64)
    indices[1:] = np.diff(query_indices)[query_selection]
    indices = np.cumsum(indices)
    pieces = [
        query_frags[start:end]
        for start, end in zip(
            query_indices[query_selection], query_indices[query_selection + 1]
        )
    ]
    # np.concatenate rejects an empty list; keep an empty slice of the input instead.
    query_frags = np.concatenate(pieces) if pieces else query_frags[:0]
    return query_masses, query_frags, indices


def get_psms_(
    query_data,
    db_data,
    features,
    m_tol,
    m_offset,
    ppm,
    min_frag_hits,
    callback = None,
    m_offset_calibrated = None,
    **kwargs
):
    """
    Search query spectra against a database and return peptide-spectrum matches.

    Args:
        query_data: Mapping with the MS2 query arrays (``indices_ms2``,
            ``mass_list_ms2``, and ``prec_mass_list2`` when ``features`` is None).
        db_data: Path to the database file, or a mapping with ``precursors``,
            ``fragmasses`` and ``bounds``.
        features: Feature table (``query_idx`` plus ``mass_matched`` /
            ``corrected_mass`` columns) selecting the spectra to search, or None.
        m_tol: Fragment mass tolerance, forwarded to ``search``.
        m_offset: Precursor mass offset used by ``get_idxs`` to find candidates.
        ppm: Flag for ppm (True) or Dalton (False) tolerances.
        min_frag_hits: Minimum number of fragment hits for a PSM to be kept.
        callback: Unused; kept for interface compatibility.
        m_offset_calibrated: If truthy, replaces ``m_offset`` and, with
            ``features``, selects the ``corrected_mass`` column.
        **kwargs: Ignored.

    Returns:
        Structured array with integer fields ``query_idx``, ``db_idx`` and
        ``hits``: one row per candidate pair with at least ``min_frag_hits`` hits.
    """
    db_masses, db_frags, db_bounds = _read_db_arrays(db_data)
    query_masses, query_frags, query_indices = _select_query_arrays(
        query_data, features, m_offset_calibrated
    )
    if m_offset_calibrated:
        m_offset = m_offset_calibrated

    idxs_lower, idxs_higher = get_idxs(db_masses, query_masses, m_offset, ppm)
    frag_hits = np.zeros(
        (len(query_masses), np.max(idxs_higher - idxs_lower)), dtype=int
    )

    # np.int64 instead of the removed `np.int` alias (cupy.int on the CPU fallback).
    search_idx = get_search_idx(idxs_higher, idxs_lower, query_masses).astype(np.int64)

    # Forward m_tol and ppm explicitly; otherwise `search` silently uses its defaults.
    search(
        search_idx, query_indices, query_frags, db_frags, db_bounds,
        query_masses, frag_hits, m_tol, ppm,
    )

    # frag_hits is a host numpy array, so it is scanned with numpy, not cupy.
    hit_query, hit_db = np.where(frag_hits >= min_frag_hits)
    hits = frag_hits[hit_query, hit_db]
    # Convert column positions back into absolute database indices.
    hit_db += idxs_lower[hit_query]

    psms = np.empty(
        len(hits), dtype=[("query_idx", int), ("db_idx", int), ("hits", int)]
    )
    psms["query_idx"] = hit_query
    psms["db_idx"] = hit_db
    psms["hits"] = hits
    return psms

In [13]:
# Minimum number of matched fragments for a candidate to be reported
# (the same threshold as the reference `get_psms` run above).
min_frag_hits = 6

In [14]:
%%time
get_psms_(
    query_data,
    db_data,
    features,
    m_tol,
    m_offset,
    ppm,
    min_frag_hits)

Wall time: 34.8 s


array([(  2753,      14, 6), (  2753,      15, 6), (  2840,      49, 6),
       ..., (113350, 8958238, 6), (113350, 8958239, 6),
       (113411, 8959373, 7)],
      dtype=[('query_idx', '<i4'), ('db_idx', '<i4'), ('hits', '<i4')])