In [1]:
from masslib4search.snaps.MassSearchTools.utils.toolbox import EmbeddingSimilaritySearch
import torch

In [None]:
import torch
import dask.bag as db
from functools import partial
from typing import Tuple,Callable,Optional,Union,Literal,List

@torch.no_grad()
def cosine(va: torch.Tensor, vb: torch.Tensor) -> torch.Tensor:
    """Pairwise cosine similarity between the rows of ``va`` and the rows of ``vb``.

    Args:
        va: Tensor whose last dimension is the embedding dimension, e.g. ``(n_a, dim)``.
        vb: Tensor with the same last dimension, e.g. ``(n_b, dim)``; any leading
            (batch) dimensions follow ``torch.matmul`` broadcasting rules.

    Returns:
        Tensor of shape ``(..., n_a, n_b)`` whose entry ``[i, j]`` is
        ``<va_i, vb_j> / (|va_i| * |vb_j| + 1e-6)``. The ``1e-6`` keeps the division
        finite when a row has zero norm (such pairs score 0).
    """
    vb_t = vb.transpose(-1, -2)                                      # (..., dim, n_b)
    dots = va @ vb_t                                                 # (..., n_a, n_b)
    len_a = va.norm(p=2, dim=-1, keepdim=True)                       # (..., n_a, 1)
    len_b_t = vb.norm(p=2, dim=-1, keepdim=True).transpose(-1, -2)   # (..., 1, n_b)
    return dots / (len_a * len_b_t + 1e-6)

@torch.no_grad()
def emb_similarity_search_cpu(
    query: torch.Tensor, # shape: (n_q, dim), dtype: float32
    ref: torch.Tensor, # shape: (n_r, dim), dtype: float32
    sim_operator: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = cosine,
    top_k: Optional[int] = None,
    chunk_size: int = 5120,
    work_device: torch.device = torch.device("cpu"),
    output_device: Optional[torch.device] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Chunked top-k similarity search of ``query`` rows against ``ref`` rows.

    Both sets are split into blocks of ``chunk_size`` rows, so at most a
    ``(chunk_size, chunk_size)`` score matrix plus the running top-k buffer is
    materialised at once. Despite the name, every computation runs on
    ``work_device``, so the function works on any device.

    Args:
        query: Query embeddings, expected shape ``(n_q, dim)``.
        ref: Reference embeddings, expected shape ``(n_r, dim)``.
        sim_operator: Maps ``(q_block, r_block)`` to a ``(len(q_block), len(r_block))``
            score matrix where higher means more similar. Defaults to ``cosine``.
        top_k: Hits kept per query. ``None`` keeps all ``n_r`` references. If it
            exceeds ``n_r``, the surplus slots are padded with index ``-1`` and
            score ``-inf``.
        chunk_size: Rows per query block and per reference block.
        work_device: Device on which similarities and top-k selection are computed.
        output_device: Device of the returned tensors; defaults to ``work_device``.

    Returns:
        ``(indices, scores)``, both of shape ``(n_q, top_k)``. ``indices`` holds row
        positions into ``ref`` (``-1`` for padding); ``scores`` holds the matching
        similarities, descending within each row (``-inf`` for padding).
    """
    output_device = output_device or work_device
    ref_num = ref.size(0)
    top_k = ref_num if top_k is None else top_k  # None -> keep every reference

    # Empty query set: nothing to search, return (0, top_k) tensors.
    if len(query) == 0:
        return (
            torch.tensor([], device=output_device, dtype=torch.long).reshape(0, top_k),
            torch.tensor([], device=output_device, dtype=torch.float32).reshape(0, top_k)
        )

    # Empty reference set: every slot of every query is padding.
    if len(ref) == 0:
        return (
            torch.full((len(query), top_k), -1, device=output_device, dtype=torch.long),
            torch.full((len(query), top_k), -float('inf'), device=output_device, dtype=torch.float32),
        )

    # Padding templates. Any finite score ranks above -inf, so padding only survives
    # when fewer than top_k references exist.
    scores_template = torch.full((top_k,), -float('inf'),
                                 device=work_device, dtype=torch.float32)
    indices_template = torch.full((top_k,), -1,
                                  device=work_device, dtype=torch.long)

    results = []
    indices_list = []

    for q_chunk in query.split(chunk_size):
        q_work = q_chunk.to(work_device)
        batch_size = q_work.size(0)

        # Running top-k buffers for this query block, shape (batch_size, top_k).
        scores_buf = scores_template[None, :].expand(batch_size, -1).clone()
        indices_buf = indices_template[None, :].expand(batch_size, -1).clone()

        for r_idx, r_chunk in enumerate(ref.split(chunk_size)):
            r_work = r_chunk.to(work_device)
            sim = sim_operator(q_work, r_work)  # (batch_size, len(r_chunk))

            # Global row positions of this reference block within ``ref``.
            start_idx = r_idx * chunk_size
            indices = torch.arange(start_idx, start_idx + r_work.size(0),
                                   device=work_device)

            # Merge the running buffer with the new block and keep the best top_k.
            combined_scores = torch.cat([scores_buf, sim], dim=1)
            combined_indices = torch.cat([
                indices_buf,
                indices.expand(batch_size, -1)
            ], dim=1)

            # torch.topk returns its values in descending order (sorted=True by
            # default), so the buffer is always sorted and needs no extra sort pass.
            scores_buf, top_pos = torch.topk(combined_scores, top_k, dim=1)
            indices_buf = torch.gather(combined_indices, 1, top_pos)

        # Move this block's results to the requested output device.
        results.append(scores_buf.to(output_device))
        indices_list.append(indices_buf.to(output_device))

    return torch.cat(indices_list, dim=0), torch.cat(results, dim=0)

@torch.no_grad()
def emb_similarity_search_cuda(
    query: torch.Tensor, # shape: (n_q, dim), dtype: float32
    ref: torch.Tensor, # shape: (n_r, dim), dtype: float32
    sim_operator: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = cosine,
    top_k: Optional[int] = None,
    chunk_size: int = 5120,
    work_device: torch.device = torch.device("cuda:0"),
    output_device: Optional[torch.device] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Chunked top-k similarity search on a CUDA device.

    Uploads ``query`` and ``ref`` to ``work_device`` once, then runs the same chunked
    top-k merge as ``emb_similarity_search_cpu`` (which is device-generic via its
    ``work_device`` argument). Results are identical in shape and semantics.

    Args:
        query: Query embeddings, expected shape ``(n_q, dim)``.
        ref: Reference embeddings, expected shape ``(n_r, dim)``.
        sim_operator: Maps ``(q_block, r_block)`` to a score matrix; higher is more similar.
        top_k: Hits kept per query. ``None`` keeps all ``n_r`` references. Surplus
            slots beyond ``n_r`` are padded with index ``-1`` and score ``-inf``.
        chunk_size: Rows per query block and per reference block.
        work_device: CUDA device that performs the computation.
        output_device: Device of the returned tensors; defaults to ``work_device``.

    Returns:
        ``(indices, scores)``, both of shape ``(n_q, top_k)``, sorted by descending
        score within each row.
    """
    # Upload both sets in full, once. Afterwards every per-block ``.to(work_device)``
    # in the shared search loop is a no-op, so ``ref`` is not re-uploaded for each
    # query block.
    query = query.to(work_device)
    ref = ref.to(work_device)

    # All kernels are issued on work_device's *current* stream, so they are ordered
    # with each other and with whatever produced query/ref. With ref fully resident
    # there is no host->device transfer left to overlap with compute. Extra streams
    # would only add cross-stream synchronisation hazards (missing events,
    # record_stream). A CPU output_device gets a blocking copy; a CUDA output is
    # ordered on the caller's current stream.
    return emb_similarity_search_cpu(
        query,
        ref,
        sim_operator=sim_operator,
        top_k=top_k,
        chunk_size=chunk_size,
        work_device=work_device,
        output_device=output_device,
    )


In [51]:
# top_k=6 exceeds the 5 reference rows, so the last column of each result row is
# padding (index -1, score -inf). No seed is set here, so the values differ per run.
emb_similarity_search_cpu(torch.rand(2, 128), torch.rand(5, 128),top_k=6)

(tensor([[ 4,  1,  2,  0,  3, -1],
         [ 0,  1,  3,  4,  2, -1]]),
 tensor([[0.7795, 0.7756, 0.7675, 0.7581, 0.7553,   -inf],
         [0.7952, 0.7727, 0.7592, 0.7451, 0.7396,   -inf]]))

In [52]:
# Same sizes on the CUDA implementation (default work_device is cuda:0, so a GPU is
# required); the -1 / -inf padding in the last column matches the CPU version.
emb_similarity_search_cuda(torch.rand(2, 128), torch.rand(5, 128),top_k=6)

(tensor([[ 4,  1,  0,  3,  2, -1],
         [ 3,  1,  0,  2,  4, -1]], device='cuda:0'),
 tensor([[0.7663, 0.7601, 0.7549, 0.7436, 0.7232,   -inf],
         [0.8000, 0.7898, 0.7637, 0.7529, 0.7491,   -inf]], device='cuda:0'))

In [62]:
# Empty query set: both returned tensors have shape (0, top_k).
emb_similarity_search_cpu(torch.tensor([]), torch.rand(5, 128),top_k=6)

(tensor([], size=(0, 6), dtype=torch.int64), tensor([], size=(0, 6)))

In [63]:
# Empty reference set: every slot of every query is padding (index -1, score -inf).
emb_similarity_search_cpu(torch.rand(3, 128), torch.tensor([]),top_k=6)

(tensor([[-1, -1, -1, -1, -1, -1],
         [-1, -1, -1, -1, -1, -1],
         [-1, -1, -1, -1, -1, -1]]),
 tensor([[-inf, -inf, -inf, -inf, -inf, -inf],
         [-inf, -inf, -inf, -inf, -inf, -inf],
         [-inf, -inf, -inf, -inf, -inf, -inf]]))

In [65]:
def convert_hits(indices: torch.Tensor, scores: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Flatten top-k search results into a flat list of valid hits.

    Args:
        indices: ``(num_qry, top_k)`` long tensor of reference indices; ``-1`` marks an empty slot.
        scores: ``(num_qry, top_k)`` float tensor of the matching scores; ``-inf`` marks an empty slot.

    Returns:
        new_indices: ``(num_hitted, 3)`` tensor with one row ``[qry_index, ref_index, top_k_pos]``
            per valid hit, ordered by query and then by rank position.
        new_scores: ``(num_hitted,)`` tensor with the scores of those hits.

    A slot is a hit only when its index is not ``-1`` AND its score is not ``-inf``.
    """
    keep = (indices != -1) & (scores != -float('inf'))
    # nonzero() enumerates the kept (row, col) positions in row-major order, the
    # same order boolean-mask indexing produces.
    qry_pos, rank_pos = keep.nonzero(as_tuple=True)
    hits = torch.stack([qry_pos, indices[qry_pos, rank_pos], rank_pos], dim=1)
    return hits, scores[qry_pos, rank_pos]

In [70]:
# Flatten the top-k results into [qry_index, ref_index, top_k_pos] rows; the padded
# 6th slot of each query (index -1 / score -inf) is dropped, leaving 2 x 5 = 10 hits.
convert_hits(*emb_similarity_search_cpu(torch.rand(2, 128), torch.rand(5, 128),top_k=6))

(tensor([[0, 0, 0],
         [0, 4, 1],
         [0, 2, 2],
         [0, 3, 3],
         [0, 1, 4],
         [1, 4, 0],
         [1, 3, 1],
         [1, 0, 2],
         [1, 2, 3],
         [1, 1, 4]]),
 tensor([0.7784, 0.7614, 0.7449, 0.7421, 0.7079, 0.7711, 0.7639, 0.7565, 0.7339,
         0.6741]))