In [1]:
import torch
import numpy as np
import dask
import dask.bag as db
from dask.diagnostics import ProgressBar
import pandas as pd
import ms_entropy as me
from typing import Tuple,Callable,Optional,Union,Literal,List

In [44]:
def ms_entropy_similarity(
    query_spec: torch.Tensor, # (n_peaks, 2)
    ref_spec: torch.Tensor, # (n_peaks, 2)
    **kwargs,
) -> torch.Tensor: # zero-dimensional
    sim = me.calculate_entropy_similarity(query_spec, ref_spec, **kwargs)
    return torch.tensor(sim, device=query_spec.device)

@torch.no_grad()
def spec_similarity_search_cpu(
    query: List[torch.Tensor], # List[(n_peaks, 2)]
    ref: List[torch.Tensor], # List[(n_peaks, 2)]
    sim_operator: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = ms_entropy_similarity,
    top_k: Optional[int] = None,
    num_dask_workers: int = 4,
    work_device: torch.device = torch.device("cpu"),
    output_device: Optional[torch.device] = None,
    dask_mode: Optional[Literal["threads", "processes", "single-threaded"]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    
    output_device = output_device or work_device
    top_k = top_k or len(ref)
    
    # 缓冲区模板
    scores_template = torch.full((top_k,), -float('inf'), 
                                device=work_device, dtype=torch.float32)
    indices_template = torch.full((top_k,), -1, 
                                device=work_device, dtype=torch.long)
    
    # 单query搜索闭包
    def _search_single_query(q: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        
        # 初始化缓冲区
        scores_buf = scores_template.clone()
        indices_buf = indices_template.clone()
        current_count = 0  # 有效结果计数器
        
        q_tensor = q.to(work_device)
        
        for r_idx, r_spec in enumerate(ref):
            score = sim_operator(q_tensor, r_spec.to(work_device))

            # 阶段1：缓冲区未满时的快速写入
            if current_count < top_k:
                scores_buf[current_count] = score
                indices_buf[current_count] = r_idx
                current_count += 1

            # 阶段2：缓冲区已满后的条件替换
            else:
                min_idx = torch.argmin(scores_buf)
                if score > scores_buf[min_idx]:  # 只需比较当前最小值
                    # 定点替换
                    scores_buf[min_idx] = score
                    indices_buf[min_idx] = r_idx
        
        # 后处理缓冲区 （排序）
        valid_part = scores_buf[:current_count]
        sorted_idx = torch.argsort(valid_part, descending=True)
        scores_buf[:current_count] = valid_part[sorted_idx]
        indices_buf[:current_count] = indices_buf[:current_count][sorted_idx]

        return scores_buf.to(output_device), indices_buf.to(output_device)

    # Dask并行处理
    query_bag = db.from_sequence(query, npartitions=num_dask_workers)
    query_bag = query_bag.map(_search_single_query)
    results = query_bag.compute(scheduler=dask_mode,num_workers=num_dask_workers)
    
    # 堆叠结果
    results = pd.DataFrame(results,columns=["scores", "indices"])
    scores = torch.stack(results['scores'].tolist())
    indices = torch.stack(results['indices'].tolist())
    
    return scores, indices

In [12]:
queries = [
        torch.tensor([[100.0, 1.0], [200.0, 0.8], [300.0, 0.5]], dtype=torch.float32),
        torch.tensor([[150.0, 0.9], [250.0, 0.7], [350.0, 0.6]], dtype=torch.float32),
        torch.tensor([[150.0, 0.9], [200.0, 0.7], [300.0, 0.6]], dtype=torch.float32),
]

In [13]:
with ProgressBar():
    S,I = spec_similarity_search_cpu(queries, queries, top_k=2, num_dask_workers=2, dask_mode='threads')

[########################################] | 100% Completed | 103.01 ms


In [9]:
S

tensor([[1.0000, 0.6201],
        [1.0000, 0.3722],
        [1.0000, 0.6201]])

In [10]:
I

tensor([[0, 2],
        [1, 2],
        [2, 0]])

In [14]:
@torch.no_grad()
def spec_similarity_search_cpu_by_queue(
    query: List[List[torch.Tensor]],  # Queue[List[(n_peaks, 2)]]
    ref: List[List[torch.Tensor]], # Queue[List[(n_peaks, 2)]]
    sim_operator: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = ms_entropy_similarity,
    top_k: Optional[int] = None,
    num_dask_workers: int = 4,
    work_device: torch.device = torch.device("cpu"),
    output_device: Optional[torch.device] = None,
    dask_mode: Optional[Literal["threads", "processes", "single-threaded"]] = None,
) -> List[Tuple[torch.Tensor, torch.Tensor]]:

    output_device = output_device or work_device
    top_k = top_k or len(ref)
    
    # 缓冲区模板
    scores_template = torch.full((top_k,), -float('inf'), 
                                device=work_device, dtype=torch.float32)
    indices_template = torch.full((top_k,), -1, 
                                device=work_device, dtype=torch.long)
    
    # 单query搜索闭包
    def _search_single_query(i: int, q: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        
        # 初始化缓冲区
        scores_buf = scores_template.clone()
        indices_buf = indices_template.clone()
        current_count = 0  # 有效结果计数器
        
        q_tensor = q.to(work_device)
        
        for r_idx, r_spec in enumerate(ref[i]):
            score = sim_operator(q_tensor, r_spec.to(work_device))

            # 阶段1：缓冲区未满时的快速写入
            if current_count < top_k:
                scores_buf[current_count] = score
                indices_buf[current_count] = r_idx
                current_count += 1

            # 阶段2：缓冲区已满后的条件替换
            else:
                min_idx = torch.argmin(scores_buf)
                if score > scores_buf[min_idx]:  # 只需比较当前最小值
                    # 定点替换
                    scores_buf[min_idx] = score
                    indices_buf[min_idx] = r_idx
        
        # 后处理缓冲区 （排序）
        valid_part = scores_buf[:current_count]
        sorted_idx = torch.argsort(valid_part, descending=True)
        scores_buf[:current_count] = valid_part[sorted_idx]
        indices_buf[:current_count] = indices_buf[:current_count][sorted_idx]

        return scores_buf.to(output_device), indices_buf.to(output_device)
    
    # 构建配对序列
    bag_queue = []
    for i,query_block in enumerate(query):
        query_block_bag = db.from_sequence(zip([i]*len(query_block), query_block), npartitions=num_dask_workers)
        results_bag = query_block_bag.map(lambda x: _search_single_query(x[0], x[1]))
        bag_queue.append(results_bag)
    
    # 并行搜索
    queue_results = dask.compute(bag_queue, scheduler=dask_mode, num_workers=num_dask_workers)[0]
    
    # 合并结果
    queue_results_bag = db.from_sequence(queue_results, npartitions=num_dask_workers)
    queue_results_bag = queue_results_bag.map(lambda x: pd.DataFrame(x,columns=["scores", "indices"]))
    queue_results_bag = queue_results_bag.map(lambda x: (torch.stack(x['scores'].tolist()),torch.stack(x['indices'].tolist())))
    queue_results = queue_results_bag.compute(scheduler=dask_mode, num_workers=num_dask_workers)
    
    return queue_results

In [15]:
with ProgressBar():
    queue_results = spec_similarity_search_cpu_by_queue([queries]*2, [queries]*2, top_k=2, num_dask_workers=2, dask_mode='threads')

[                                        ] | 0% Completed | 200.92 us

[########################################] | 100% Completed | 101.70 ms
[########################################] | 100% Completed | 101.73 ms


In [16]:
queue_results

[(tensor([[1.0000, 0.6201],
          [1.0000, 0.3722],
          [1.0000, 0.6201]]),
  tensor([[0, 2],
          [1, 2],
          [2, 0]])),
 (tensor([[1.0000, 0.6201],
          [1.0000, 0.3722],
          [1.0000, 0.6201]]),
  tensor([[0, 2],
          [1, 2],
          [2, 0]]))]

In [17]:
def test_cuda_sim_operator(q: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
    # 这里可以替换为实际的CUDA相似度计算函数
    return torch.sum(torch.abs(q - r))

@torch.no_grad()
def spec_similarity_search_cuda(
    query: List[torch.Tensor], # List[(n_peaks, 2)]
    ref: List[torch.Tensor], # List[(n_peaks, 2)]
    sim_operator: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    top_k: Optional[int] = None,
    num_cuda_workers: int = 4,
    work_device: torch.device = torch.device("cuda:0"),
    output_device: Optional[torch.device] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    
    output_device = output_device or work_device
    top_k = top_k or len(ref)
    
    # 初始化CUDA流组（每个worker含3个流）
    stream_groups = [(
        torch.cuda.Stream(device=work_device),  # 数据转移流
        torch.cuda.Stream(device=work_device),  # 计算流
        torch.cuda.Stream(device=work_device)    # 缓冲区流
    ) for _ in range(num_cuda_workers)]

    # 预分配显存资源
    score_buffers = [torch.full((top_k,), -float('inf'), device=work_device) for _ in range(num_cuda_workers)]
    index_buffers = [torch.full((top_k,), -1, device=work_device, dtype=torch.long) for _ in range(num_cuda_workers)]
    event_pool = [torch.cuda.Event() for _ in range(num_cuda_workers*2)]

    # 异步执行容器
    results = [None] * len(query)
    
    for query_idx, q in enumerate(query):
        worker_id = query_idx % num_cuda_workers
        data_stream, compute_stream, buffer_stream = stream_groups[worker_id]
        event_idx = worker_id * 2

        # 阶段1: 异步数据传输
        with torch.cuda.stream(data_stream):
            q_gpu = q.to(work_device, non_blocking=True)
            ref_gpu = [r.to(work_device, non_blocking=True) for r in ref]
            event_pool[event_idx].record(stream=data_stream)

        # 阶段2: 异步计算
        with torch.cuda.stream(compute_stream):
            event_pool[event_idx].wait(stream=compute_stream)  # 等待数据就绪
            scores = []
            for r_idx, r in enumerate(ref_gpu):
                scores.append(sim_operator(q_gpu, r))
            event_pool[event_idx+1].record(stream=compute_stream)

        # 阶段3: 异步缓冲区更新
        with torch.cuda.stream(buffer_stream):
            event_pool[event_idx+1].wait(stream=buffer_stream)  # 等待计算完成
            current_count = 0
            score_buf = score_buffers[worker_id].zero_()
            index_buf = index_buffers[worker_id].zero_()
            
            for r_idx, score in enumerate(scores):
                if current_count < top_k:
                    score_buf[current_count] = score
                    index_buf[current_count] = r_idx
                    current_count += 1
                else:
                    min_idx = torch.argmin(score_buf)
                    if score > score_buf[min_idx]:
                        score_buf[min_idx] = score
                        index_buf[min_idx] = r_idx
            
            # 异步排序
            sorted_idx = torch.argsort(score_buf[:current_count], descending=True)
            score_buf[:current_count] = score_buf[:current_count][sorted_idx]
            index_buf[:current_count] = index_buf[:current_count][sorted_idx]

            # 异步传回结果
            results[query_idx] = (
                score_buf.to(output_device, non_blocking=True),
                index_buf.to(output_device, non_blocking=True)
            )

    # 同步所有流
    torch.cuda.synchronize(work_device)
    
    # 组装最终结果
    return torch.stack([r[0] for r in results]), torch.stack([r[1] for r in results])

In [22]:
with ProgressBar():
    results = spec_similarity_search_cuda(queries, queries, test_cuda_sim_operator, top_k=2, num_cuda_workers=2,output_device="cpu")

In [23]:
results

(tensor([[150.3000,  50.3000],
         [150.3000, 100.0000],
         [100.0000,  50.3000]]),
 tensor([[1, 2],
         [0, 2],
         [1, 0]]))

In [24]:
@torch.no_grad()
def spec_similarity_search_cuda_by_queue(
    query: List[List[torch.Tensor]], # Queue[List[(n_peaks, 2)]]
    ref: List[List[torch.Tensor]], # Queue[List[(n_peaks, 2)]]
    sim_operator: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    top_k: Optional[int] = None,
    num_cuda_workers: int = 4,
    num_dask_workers: int = 4,
    work_device: torch.device = torch.device("cuda:0"),
    output_device: Optional[torch.device] = None,
) -> List[Tuple[torch.Tensor, torch.Tensor]]:

    block_bag = db.from_sequence(zip(query, ref), npartitions=num_dask_workers)
    block_bag = block_bag.map(lambda x: spec_similarity_search_cuda(
        x[0], x[1], sim_operator, top_k, num_cuda_workers, work_device, output_device
    ))
    results = block_bag.compute(scheduler='threads', num_workers=num_dask_workers)
    return results

In [27]:
with ProgressBar():
    queue_results = spec_similarity_search_cuda_by_queue([queries]*2, [queries]*2, test_cuda_sim_operator, top_k=2, num_cuda_workers=2, num_dask_workers=2,output_device="cpu")

[                                        ] | 0% Completed | 220.48 us

[########################################] | 100% Completed | 102.10 ms


In [28]:
queue_results

[(tensor([[150.3000,  50.3000],
          [150.3000, 100.0000],
          [100.0000,  50.3000]]),
  tensor([[1, 2],
          [0, 2],
          [1, 0]])),
 (tensor([[150.3000,  50.3000],
          [150.3000, 100.0000],
          [100.0000,  50.3000]]),
  tensor([[1, 2],
          [0, 2],
          [1, 0]]))]

In [45]:
from abc import ABC
import torch
from functools import partial
from typing import Callable, Optional

class EmbbedingSimilarityOperator(ABC):
    
    cpu_kwargs = {}
    cuda_kwargs = {}
    
    @classmethod
    def cuda_operator(
        cls,
    ) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]:
        '''
        Returns a function that computes the similarity between two batches of embeddings.
        The function takes two batches of embeddings and returns a similarity matrix.
        The similarity matrix is a tensor of shape (batch_size_va, batch_size_vb) where each element (i, j)
        represents the similarity between the i-th embedding in the first batch and the j-th embedding
        in the second batch.
        The function should be able to handle batches of different sizes.
        '''
        raise NotImplementedError(f"{cls.__name__}.cuda_operator() not implemented")
    
    @classmethod
    def cpu_operator(
        cls,
    ) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]:
        '''
        Returns a function that computes the similarity between two batches of embeddings.
        The function takes two batches of embeddings and returns a similarity matrix.
        The similarity matrix is a tensor of shape (batch_size_va, batch_size_vb) where each element (i, j)
        represents the similarity between the i-th embedding in the first batch and the j-th embedding
        in the second batch.
        The function should be able to handle batches of different sizes.
        '''
        raise NotImplementedError(f"{cls.__name__}.cpu_operator() not implemented")
    
    @classmethod
    def get_operator_kwargs(
        cls,
        device: torch.device,
        input_kwargs: Optional[dict] = None,
    ) -> dict:
        if device.type.startswith("cuda"):
            return {**cls.cuda_kwargs, **(input_kwargs or {})}
        else:
            return {**cls.cpu_kwargs, **(input_kwargs or {})}
        
    @classmethod
    def get_operator(
        cls, 
        device: torch.device,
        input_kwargs: Optional[dict] = None,
    ) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]:
        if device.type.startswith("cuda"):
            return partial(cls.cuda_operator(), **cls.get_operator_kwargs(device,input_kwargs))
        else:
            return partial(cls.cpu_operator(), **cls.get_operator_kwargs(device,input_kwargs))
        
class SpectramSimilarityOperator(EmbbedingSimilarityOperator):
    
    dask_mode = None
    
    @classmethod
    def cuda_operator(
        cls,
    ) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]:
        '''
        Returns a function that computes the similarity between two spectra.
        The function takes two batches of spectra and returns a similarity matrix.
        The similarity matrix is a zero-dimensional tensor
        '''
        raise NotImplementedError(f"{cls.__name__}.cuda_operator() not implemented")
    
    @classmethod
    def cpu_operator(
        cls,
    ) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]:
        '''
        Returns a function that computes the similarity between two spectra.
        The function takes two batches of spectra and returns a similarity matrix.
        The similarity matrix is a zero-dimensional tensor
        '''
        raise NotImplementedError(f"{cls.__name__}.cpu_operator() not implemented")
    
    @classmethod
    def get_dask_mode(
        cls,
        input_dask_mode: Optional[str] = None,
    ) -> Optional[str]:
        if input_dask_mode is not None:
            return input_dask_mode
        else:
            return cls.dask_mode

def resolve_device(
    device: Union[str, torch.device, Literal['auto']], 
    default: torch.device
) -> torch.device:
    if isinstance(device, torch.device):
        return device
    if device == 'auto':
        return default
    if device == 'cuda':
        device = 'cuda:0'
    return torch.device(device)

class MSEntropyOperator(SpectramSimilarityOperator):
    
    cpu_kwargs = {
        "ms2_tolerance_in_da":0.02, 
        "ms2_tolerance_in_ppm": -1, 
        "clean_spectra": True,
    }
    dask_mode = "threads" # me.calculate_entropy_similarity是CPU函数，因此默认使用线程池
    
    @classmethod
    def cpu_operator(cls):
        return ms_entropy_similarity
    
    @classmethod
    def cuda_operator(cls):
        raise NotImplementedError(f"{cls.__name__} not supported on CUDA")

def spec_similarity_search(
    query: List[torch.Tensor],
    ref: List[torch.Tensor],
    sim_operator: SpectramSimilarityOperator = MSEntropyOperator,
    top_k: Optional[int] = None,
    num_cuda_workers: int = 4,
    num_dask_workers: int = 4,
    work_device: Union[str, torch.device, Literal['auto']] = 'auto',
    output_device: Union[str, torch.device, Literal['auto']] = 'auto',
    dask_mode: Optional[Literal["threads", "processes", "single-threaded"]] = None,
    operator_kwargs: Optional[dict] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:

    # 设备推断
    _work_device = resolve_device(work_device, query[0].device if query else torch.device('cpu'))
    _output_device = resolve_device(output_device, _work_device)
    
    # 算子生成
    operator = sim_operator.get_operator(_work_device,operator_kwargs)

    # 分发实现
    if _work_device.type.startswith('cuda'):
        return spec_similarity_search_cuda(
            query, ref, operator,
            top_k=top_k,
            num_cuda_workers=num_cuda_workers,
            work_device=_work_device,
            output_device=_output_device
        )
    else:
        return spec_similarity_search_cpu(
            query, ref, operator,
            top_k=top_k,
            num_dask_workers=num_dask_workers,
            work_device=_work_device,
            output_device=_output_device,
            dask_mode=sim_operator.get_dask_mode(dask_mode)
        )

def spec_similarity_search_by_queue(
    query: List[List[torch.Tensor]],
    ref: List[List[torch.Tensor]],
    sim_operator: SpectramSimilarityOperator = MSEntropyOperator,
    top_k: Optional[int] = None,
    num_cuda_workers: int = 4,
    num_dask_workers: int = 4,
    work_device: Union[str, torch.device, Literal['auto']] = 'auto',
    output_device: Union[str, torch.device, Literal['auto']] = 'auto',
    dask_mode: Optional[Literal["threads", "processes", "single-threaded"]] = None,
    operator_kwargs: Optional[dict] = None,
) -> List[Tuple[torch.Tensor, torch.Tensor]]:
    
    # 设备推断
    _work_device = resolve_device(work_device, query[0][0].device if query else torch.device('cpu'))
    _output_device = resolve_device(output_device, _work_device)
    
    # 算子生成
    operator = sim_operator.get_operator(_work_device,operator_kwargs)
    
    # 分发实现
    if _work_device.type.startswith('cuda'):
        return spec_similarity_search_cuda_by_queue(
            query, ref, operator,
            top_k=top_k,
            num_cuda_workers=num_cuda_workers,
            num_dask_workers=num_dask_workers,
            work_device=_work_device,
            output_device=_output_device
        )
    else:
        return spec_similarity_search_cpu_by_queue(
            query, ref, operator,
            top_k=top_k,
            num_dask_workers=num_dask_workers,
            work_device=_work_device,
            output_device=_output_device,
            dask_mode=sim_operator.get_dask_mode(dask_mode)
        )

In [39]:
_work_device = resolve_device("auto", queries[0][0].device if queries else torch.device('cpu'))
_output_device = resolve_device("auto", _work_device)
operator = MSEntropyOperator.get_operator(_work_device,None)

In [46]:
with ProgressBar():
    results = spec_similarity_search(
        queries,queries
    )

[                                        ] | 0% Completed | 442.39 us

[########################################] | 100% Completed | 102.83 ms


In [47]:
results

(tensor([[1.0000, 0.6201, 0.0000],
         [1.0000, 0.3722, 0.0000],
         [1.0000, 0.6201, 0.3722]]),
 tensor([[0, 2, 1],
         [1, 2, 0],
         [2, 0, 1]]))

In [48]:
with ProgressBar():
    results = spec_similarity_search_by_queue(
        [queries]*2,[queries]*2
    )

[                                        ] | 0% Completed | 209.20 us

[########################################] | 100% Completed | 102.14 ms
[########################################] | 100% Completed | 102.05 ms


In [49]:
results

[(tensor([[1.0000, 0.6201],
          [1.0000, 0.3722],
          [1.0000, 0.6201]]),
  tensor([[0, 2],
          [1, 2],
          [2, 0]])),
 (tensor([[1.0000, 0.6201],
          [1.0000, 0.3722],
          [1.0000, 0.6201]]),
  tensor([[0, 2],
          [1, 2],
          [2, 0]]))]