In [12]:
from masslib4search.snaps.MassSearchTools.utils.similarity import ms_similarity
import torch
import dask.bag as db
from itertools import cycle

In [8]:
import torch
import ms_entropy as me
import dask
import dask.bag as db
from typing import List, Callable, Optional, Union, Literal

def ms_entropy_similarity(
    query_spec: torch.Tensor, # (n_peaks, 2)
    ref_spec: torch.Tensor, # (n_peaks, 2)
    **entropy_kwargs,
) -> torch.Tensor: # (1,1)
    """Entropy similarity of two spectra, returned as a ``(1, 1)`` tensor.

    ``ms_entropy`` works on NumPy peak arrays, and NumPy arrays cannot be built
    from CUDA tensors. Both spectra are therefore detached and brought to the
    CPU before scoring, so this operator also works when a caller has already
    moved the spectra to a CUDA ``work_device``. The score is placed back on
    ``query_spec.device`` so it can be combined with other results there.

    Args:
        query_spec: Query spectrum, assumed ``(n_peaks, 2)`` with (m/z, intensity) rows.
        ref_spec: Reference spectrum in the same layout.
        **entropy_kwargs: Extra options forwarded unchanged to
            ``ms_entropy.calculate_entropy_similarity``.

    Returns:
        A ``(1, 1)`` float32 tensor on ``query_spec.device``.
    """
    # Host copies; for CPU inputs these share memory with the tensors, as before.
    query_peaks = query_spec.detach().cpu().numpy()
    ref_peaks = ref_spec.detach().cpu().numpy()
    score = me.calculate_entropy_similarity(query_peaks, ref_peaks, **entropy_kwargs)
    return torch.tensor([[float(score)]], dtype=torch.float32, device=query_spec.device)

@torch.no_grad()
def spectrum_similarity_cpu(
    query: List[List[torch.Tensor]],  # Queue[List[(n_peaks, 2)]]
    ref: List[List[torch.Tensor]], # Queue[List[(n_peaks, 2)]]
    sim_operator: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    num_dask_workers: int = 4,
    work_device: torch.device = torch.device("cpu"),
    output_device: Optional[torch.device] = None,
    dask_mode: Optional[Literal["threads", "processes", "single-threaded"]] = None,
) -> List[torch.Tensor]: # Queue[(len(query_block), len(ref_block)]
    """Compute one pairwise similarity matrix per (query block, ref block) pair.

    For each queue position ``k``, ``sim_operator`` is applied to every pair
    ``(query[k][i], ref[k][j])``. The scores are assembled into a
    ``(len(query[k]), len(ref[k]))`` matrix. The pairs of all blocks are
    evaluated together in a single ``dask.compute`` call.

    Args:
        query: Queue of query blocks; each block is a list of spectrum tensors
            (assumed ``(n_peaks, 2)``).
        ref: Queue of reference blocks, paired position by position with ``query``.
        sim_operator: ``(query_spec, ref_spec) -> Tensor`` returning a single
            score (e.g. a ``(1, 1)`` tensor).
        num_dask_workers: Partitions per bag and number of dask workers.
        work_device: Device each spectrum is moved to before ``sim_operator`` runs.
        output_device: Device of the returned matrices; defaults to ``work_device``.
        dask_mode: dask scheduler name. ``None`` uses dask's default for bags.

    Returns:
        One matrix per block. Row ``i`` is query spectrum ``i`` and column ``j``
        is reference spectrum ``j``. Blocks with no query or no reference
        spectra yield an empty ``(n_query, n_ref)`` matrix.

    Raises:
        ValueError: If ``query`` and ``ref`` hold different numbers of blocks.
    """
    if len(query) != len(ref):
        raise ValueError(
            f"query and ref must hold the same number of blocks, got {len(query)} and {len(ref)}"
        )
    target_device = output_device or work_device

    def _score_pair(pair):
        # ``pair`` is ((i, query_spec), (j, ref_spec)). The indices travel with the
        # score because Bag.product emits pairs partition by partition. That order
        # is not row-major once a partition holds more than one spectrum.
        (q_idx, q_spec), (r_idx, r_spec) = pair
        # Grad mode is thread-local: the decorator above does not reach dask
        # worker threads, so re-enter no_grad here.
        with torch.no_grad():
            score = sim_operator(q_spec.to(work_device), r_spec.to(work_device))
        return q_idx, r_idx, score.to(target_device)

    # Build one lazy bag of (i, j, score) per non-empty block.
    pending = {}
    for block_idx, (query_block, ref_block) in enumerate(zip(query, ref)):
        if len(query_block) == 0 or len(ref_block) == 0:
            continue
        query_bag = db.from_sequence(list(enumerate(query_block)), npartitions=num_dask_workers)
        ref_bag = db.from_sequence(list(enumerate(ref_block)), npartitions=num_dask_workers)
        pending[block_idx] = query_bag.product(ref_bag).map(_score_pair)

    # Evaluate every block's pairs in one dask pass (skipped when nothing to score).
    scored = {}
    if pending:
        computed = dask.compute(
            list(pending.values()), scheduler=dask_mode, num_workers=num_dask_workers
        )[0]
        scored = dict(zip(pending.keys(), computed))

    # Place every score by its (i, j) index, then build each block's matrix.
    results = []
    for block_idx, (query_block, ref_block) in enumerate(zip(query, ref)):
        n_query, n_ref = len(query_block), len(ref_block)
        if block_idx not in scored:
            results.append(torch.zeros(n_query, n_ref, device=target_device))
            continue
        slots = [None] * (n_query * n_ref)
        for q_idx, r_idx, score in scored[block_idx]:
            slots[q_idx * n_ref + r_idx] = score.reshape(-1)
        results.append(torch.cat(slots).reshape(n_query, n_ref))
    return results

In [3]:
def sample_cpu_spectra():
    """Return two small float32 CPU spectra for tests.

    Each spectrum is a (3, 2) tensor whose rows are (m/z, intensity) peaks.
    """
    peak_tables = (
        [[100.0, 1.0], [200.0, 0.8], [300.0, 0.5]],
        [[150.0, 0.9], [250.0, 0.7], [350.0, 0.6]],
    )
    return [torch.tensor(table, dtype=torch.float32) for table in peak_tables]

In [4]:
# Two-spectrum block reused as both query and reference in the demo calls below.
spec_block = sample_cpu_spectra()

In [11]:
# Self-similarity of the demo block: one queue entry, threaded dask scheduler.
spectrum_similarity_cpu(
    query=[spec_block],
    ref=[spec_block],
    sim_operator=ms_entropy_similarity,
    dask_mode="threads",
    num_dask_workers=4,
)

[tensor([[1.0000, 0.0000],
         [0.0000, 1.0000]])]

In [13]:
@torch.no_grad()
def spectrum_similarity_cuda_block(
    query: List[torch.Tensor], # List[(n_peaks, 2)]
    ref: List[torch.Tensor], # List[(n_peaks, 2)]
    sim_operator: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    num_cuda_workers: int = 4,
    work_device: torch.device = torch.device("cuda:0"),
    output_device: Optional[torch.device] = None,
) -> torch.Tensor:
    """Score every (query, ref) pair of one block on a CUDA device.

    Pairs are assigned round-robin to ``num_cuda_workers`` workers. Each worker
    owns three streams (host-to-device copy, compute, result write-back),
    chained with events. This lets the copies for one pair overlap, in
    principle, with the compute for another. All kernels are launched from the
    calling thread.

    Args:
        query: Query spectra (assumed ``(n_peaks, 2)``), on the CPU or a GPU.
        ref: Reference spectra in the same layout.
        sim_operator: ``(query_spec, ref_spec) -> Tensor`` returning one score
            (e.g. a ``(1, 1)`` tensor); it receives tensors on ``work_device``.
        num_cuda_workers: Number of stream triples; values below 1 are treated as 1.
        work_device: CUDA device that runs the pipeline.
        output_device: Device of the returned matrix; defaults to ``work_device``.

    Returns:
        A ``(len(query), len(ref))`` tensor on ``output_device``. Entry ``[i, j]``
        is ``sim_operator(query[i], ref[j])``.

    Raises:
        ValueError: If ``work_device`` is not a CUDA device.
    """
    work_device = torch.device(work_device)
    if work_device.type != "cuda":
        raise ValueError(f"work_device must be a CUDA device, got {work_device}")
    output_device = torch.device(output_device) if output_device is not None else work_device
    num_workers = max(1, int(num_cuda_workers))
    n_query, n_ref = len(query), len(ref)

    if n_query == 0 or n_ref == 0:
        return torch.zeros(n_query, n_ref, device=output_device)

    def _to_work_device(spec: torch.Tensor) -> torch.Tensor:
        """Copy one spectrum to work_device on the current (h2d) stream."""
        if spec.device.type == "cpu":
            # Pinned staging memory lets the non_blocking copy run asynchronously.
            # pin_memory() would raise on a tensor that is already on a GPU.
            spec = spec.pin_memory()
        return spec.to(work_device, non_blocking=True)

    # Scope the device switch; torch.cuda.set_device would permanently change
    # the calling thread's current device.
    with torch.cuda.device(work_device):
        caller_stream = torch.cuda.current_stream()
        # Accumulate on work_device and move to output_device once at the end.
        # Assigning a non_blocking device-to-host copy straight into a CPU
        # tensor would read the staging buffer before the copy has finished.
        results = torch.zeros(n_query, n_ref, device=work_device)

        workers = []
        for _ in range(num_workers):
            worker = {
                "h2d_stream": torch.cuda.Stream(),      # inputs in
                "compute_stream": torch.cuda.Stream(),  # sim_operator
                "d2h_stream": torch.cuda.Stream(),      # scores out
                "h2d_event": torch.cuda.Event(),
                "compute_event": torch.cuda.Event(),
            }
            # New streams do not implicitly wait on the caller's stream. Order
            # them after the zero-fill of `results` and after any pending work
            # that produced GPU-resident inputs.
            worker["h2d_stream"].wait_stream(caller_stream)
            worker["d2h_stream"].wait_stream(caller_stream)
            workers.append(worker)

        for task_idx in range(n_query * n_ref):
            q_idx, r_idx = divmod(task_idx, n_ref)
            worker = workers[task_idx % num_workers]
            h2d_stream = worker["h2d_stream"]
            compute_stream = worker["compute_stream"]
            d2h_stream = worker["d2h_stream"]

            # Stage 1: move the query/ref pair onto work_device.
            with torch.cuda.stream(h2d_stream):
                q_tensor = _to_work_device(query[q_idx])
                r_tensor = _to_work_device(ref[r_idx])
                worker["h2d_event"].record(h2d_stream)

            # Stage 2: run the operator once both copies have landed.
            with torch.cuda.stream(compute_stream):
                compute_stream.wait_event(worker["h2d_event"])
                # The inputs were allocated on h2d_stream. Tell the caching
                # allocator they are used here too, so their memory is not
                # recycled while this stream still reads it.
                q_tensor.record_stream(compute_stream)
                r_tensor.record_stream(compute_stream)
                similarity = sim_operator(q_tensor, r_tensor)
                worker["compute_event"].record(compute_stream)

            # Stage 3: write the score into the result matrix.
            with torch.cuda.stream(d2h_stream):
                d2h_stream.wait_event(worker["compute_event"])
                if similarity.device == q_tensor.device:
                    similarity.record_stream(d2h_stream)
                results[q_idx, r_idx] = similarity

        # Block the host until every worker's last write has landed. Each d2h
        # stream transitively waits on its compute and copy stages.
        for worker in workers:
            worker["d2h_stream"].synchronize()

    return results.to(output_device)

@torch.no_grad()
def spectrum_similarity_cuda(
    query: List[List[torch.Tensor]], # Queue[List[(n_peaks, 2)]]
    ref: List[List[torch.Tensor]], # Queue[List[(n_peaks, 2)]]
    sim_operator: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    num_cuda_workers: int = 4,
    num_dask_workers: int = 4,
    work_device: torch.device = torch.device("cuda:0"),
    output_device: Optional[torch.device] = None,
) -> List[torch.Tensor]:
    """Score a queue of (query, ref) blocks on CUDA, several blocks at a time.

    Each queue position ``k`` is handed to ``spectrum_similarity_cuda_block``.
    A threaded dask bag runs up to ``num_dask_workers`` blocks concurrently.

    Args:
        query: Queue of query blocks (lists of spectrum tensors).
        ref: Queue of reference blocks, paired position by position with ``query``.
        sim_operator: ``(query_spec, ref_spec) -> Tensor`` returning one score.
        num_cuda_workers: Stream triples used inside each block.
        num_dask_workers: Bag partitions and scheduler threads.
        work_device: CUDA device that runs the computation.
        output_device: Device of the returned matrices; defaults to ``work_device``.

    Returns:
        One ``(len(query[k]), len(ref[k]))`` matrix per block, in queue order.

    Raises:
        ValueError: If ``query`` and ``ref`` hold different numbers of blocks.
    """
    if len(query) != len(ref):
        raise ValueError(
            f"query and ref must hold the same number of blocks, got {len(query)} and {len(ref)}"
        )
    if not query:
        return []

    def _score_block(block_pair):
        query_block, ref_block = block_pair
        return spectrum_similarity_cuda_block(
            query_block,
            ref_block,
            sim_operator,
            num_cuda_workers=num_cuda_workers,
            work_device=work_device,
            output_device=output_device,
        )

    # Threaded scheduler: blocks run inside this process, presumably so they
    # share its CUDA context.
    block_bag = db.from_sequence(list(zip(query, ref)), npartitions=num_dask_workers)
    return block_bag.map(_score_block).compute(scheduler="threads", num_workers=num_dask_workers)

In [16]:
# Smoke-test the CUDA pipeline with a constant operator: every pair scores 1.0.
spectrum_similarity_cuda(
    query=[spec_block],
    ref=[spec_block],
    sim_operator=lambda _query_spec, _ref_spec: torch.tensor([[1.0]]),
)

[tensor([[1., 1.],
         [1., 1.]], device='cuda:0')]