In [27]:
import torch
from typing import Callable, Literal, Optional, Union

@torch.no_grad()
def tanimoto(va: torch.Tensor, vb: torch.Tensor) -> torch.Tensor:
    """Pairwise Tanimoto coefficient between the rows of ``va`` and ``vb``.

    For every pair of rows ``a`` (from ``va``) and ``b`` (from ``vb``) this
    evaluates ``<a, b> / (|a|^2 + |b|^2 - <a, b> + 1e-6)``.

    The inner products are computed with a matrix product instead of
    broadcasting to an ``(..., n_a, n_b, dim)`` intermediate, so peak memory
    scales with the output size rather than output size times ``dim``.

    Args:
        va: Tensor whose last two dimensions are ``(n_a, dim)``.
        vb: Tensor whose last two dimensions are ``(n_b, dim)``.

    Returns:
        Tensor whose last two dimensions are ``(n_a, n_b)``.
    """
    vp = torch.matmul(va, vb.transpose(-1, -2))
    vas = torch.sum(va**2, dim=-1, keepdim=True)
    vbs = torch.sum(vb**2, dim=-1, keepdim=True)
    # The 1e-6 term keeps the division finite when both rows are all zeros.
    return vp / (vas + vbs.transpose(-1, -2) - vp + 1e-6)

@torch.no_grad()
def cosine(va: torch.Tensor, vb: torch.Tensor) -> torch.Tensor:
    """Pairwise cosine similarity between the rows of ``va`` and ``vb``.

    Args:
        va: Tensor whose last two dimensions are ``(n_a, dim)``.
        vb: Tensor whose last two dimensions are ``(n_b, dim)``.

    Returns:
        Tensor whose last two dimensions are ``(n_a, n_b)``; entry ``[i, j]``
        is ``<a_i, b_j> / (|a_i| * |b_j| + 1e-6)``.
    """
    length_a = va.norm(p=2, dim=-1, keepdim=True)
    length_b = vb.norm(p=2, dim=-1, keepdim=True)
    # Outer product of the row lengths; the epsilon avoids dividing by zero
    # when either row is the zero vector.
    scale = length_a * length_b.transpose(-1, -2) + 1e-6
    inner = torch.matmul(va, vb.transpose(-1, -2))
    return inner / scale

@torch.no_grad()
def dot(va: torch.Tensor, vb: torch.Tensor) -> torch.Tensor:
    """Pairwise dot product between the rows of ``va`` and ``vb``.

    Args:
        va: Tensor whose last two dimensions are ``(n_a, dim)``.
        vb: Tensor whose last two dimensions are ``(n_b, dim)``.

    Returns:
        Tensor whose last two dimensions are ``(n_a, n_b)``.
    """
    vb_t = vb.transpose(-1, -2)
    return torch.matmul(va, vb_t)

@torch.no_grad()
def jaccard(va: torch.Tensor, vb: torch.Tensor) -> torch.Tensor:
    """Pairwise Jaccard coefficient between the rows of ``va`` and ``vb``.

    Every element is treated as a set-membership flag: non-zero means
    "present", zero means "absent". For each pair of rows the result is
    ``|intersection| / (|union| + 1e-6)``.

    The intersection sizes are computed with a matrix product of the 0/1
    indicator matrices and the union sizes by inclusion-exclusion
    (``|A| + |B| - |A & B|``). This avoids materialising boolean tensors of
    shape ``(..., n_a, n_b, dim)``, so peak memory scales with the output
    size instead of output size times ``dim``.

    Args:
        va: Tensor whose last two dimensions are ``(n_a, dim)``.
        vb: Tensor whose last two dimensions are ``(n_b, dim)``.

    Returns:
        float32 tensor whose last two dimensions are ``(n_a, n_b)``.
    """
    member_a = (va != 0).float()
    member_b = (vb != 0).float()
    intersection = torch.matmul(member_a, member_b.transpose(-1, -2))
    size_a = member_a.sum(dim=-1, keepdim=True)
    size_b = member_b.sum(dim=-1, keepdim=True).transpose(-1, -2)
    union = size_a + size_b - intersection
    # The epsilon makes two empty sets score 0 instead of dividing by zero.
    return intersection / (union + 1e-6)

@torch.no_grad()
def pearson(va: torch.Tensor, vb: torch.Tensor) -> torch.Tensor:
    """Pairwise Pearson correlation between the rows of ``va`` and ``vb``.

    Each row is centred by subtracting its own mean along the last
    dimension; the correlation is the inner product of the centred rows
    divided by the product of their Euclidean lengths (plus ``1e-6``).

    Args:
        va: Tensor whose last two dimensions are ``(n_a, dim)``.
        vb: Tensor whose last two dimensions are ``(n_b, dim)``.

    Returns:
        Tensor whose last two dimensions are ``(n_a, n_b)``.
    """
    dev_a = va - va.mean(dim=-1, keepdim=True)
    dev_b = vb - vb.mean(dim=-1, keepdim=True)

    covariance = torch.matmul(dev_a, dev_b.transpose(-1, -2))

    sq_sum_a = (dev_a**2).sum(dim=-1, keepdim=True)
    sq_sum_b = (dev_b**2).sum(dim=-1, keepdim=True).transpose(-1, -2)
    spread = torch.sqrt(sq_sum_a * sq_sum_b)

    # A constant row has zero spread; the epsilon turns that into a 0 score.
    return covariance / (spread + 1e-6)

@torch.no_grad()
def embedding_similarity_cpu(
    query: torch.Tensor, # shape: (n_q, dim), dtype: float32
    ref: torch.Tensor, # shape: (n_r, dim), dtype: float32
    chunk_size: int = 5120,
    sim_operator: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = cosine,
    work_device: torch.device = torch.device("cpu"),
    output_device: Optional[torch.device] = None,
) -> torch.Tensor: # shape: (n_q, n_r), dtype: float32
    """Blockwise pairwise similarity of ``query`` rows against ``ref`` rows.

    Both inputs are moved to ``work_device`` and split into blocks of at most
    ``chunk_size`` rows. ``sim_operator`` is applied to every
    (query block, ref block) pair and the partial results are stitched back
    together into one ``(n_q, n_r)`` matrix.

    Args:
        query: Query embeddings, one per row.
        ref: Reference embeddings, one per row.
        chunk_size: Maximum number of rows per block on either side.
        sim_operator: Callable mapping two blocks to their similarity block.
        work_device: Device on which ``sim_operator`` is evaluated.
        output_device: Device of the returned tensor; falls back to
            ``work_device`` when not given.

    Returns:
        The assembled similarity matrix on ``output_device``.
    """
    target_device = output_device or work_device
    query_on_device = query.to(work_device)
    ref_on_device = ref.to(work_device)

    # One entry per query block: that block's similarities against all ref
    # blocks, concatenated along the column axis.
    row_blocks = [
        torch.cat(
            [
                sim_operator(query_block, ref_block).to(target_device)
                for ref_block in ref_on_device.split(chunk_size)
            ],
            dim=1,
        )
        for query_block in query_on_device.split(chunk_size)
    ]
    return torch.cat(row_blocks).to(target_device)

@torch.no_grad()
def embedding_similarity_gpu(
    query: torch.Tensor,
    ref: torch.Tensor,
    chunk_size: int = 5120,
    sim_operator: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = cosine,
    work_device: torch.device = torch.device("cuda:0"),
    output_device: Optional[torch.device] = None,
) -> torch.Tensor:
    """Blockwise pairwise similarity on a CUDA device using three streams.

    ``query`` is moved to ``work_device`` up front. ``ref`` is split into
    blocks of at most ``chunk_size`` rows that are uploaded one at a time on
    a dedicated transfer stream while the previous block is being processed
    on a compute stream; finished similarity blocks are moved to
    ``output_device`` on a third stream.

    Stream ordering is enforced explicitly:

    * the compute stream waits for the upload event of the block it consumes;
    * the output-transfer stream waits for the compute stream before reading
      a similarity block;
    * tensors that are consumed on a stream other than the one they were
      allocated on are registered with ``record_stream`` so the caching
      allocator does not hand their memory out again too early.

    The reference pipeline is restarted from the first ref block for every
    query block, so all rows of the result use the same column order.

    Note:
        ``torch.cuda.set_device(work_device)`` is called, which changes the
        process-wide current CUDA device.

    Args:
        query: Query embeddings, one per row.
        ref: Reference embeddings, one per row.
        chunk_size: Maximum number of rows per block on either side.
        sim_operator: Callable mapping two blocks to their similarity block.
        work_device: CUDA device on which ``sim_operator`` is evaluated. A
            CUDA device without an index is resolved to the current CUDA
            device.
        output_device: Device of the returned tensor; falls back to
            ``work_device`` when not given.

    Returns:
        The assembled ``(n_q, n_r)`` similarity matrix on ``output_device``.
    """
    if not isinstance(work_device, torch.device):
        work_device = torch.device(work_device)
    if work_device.type == "cuda" and work_device.index is None:
        # torch.cuda.set_device needs an explicit ordinal.
        work_device = torch.device("cuda", torch.cuda.current_device())

    output_device = output_device or work_device
    torch.cuda.set_device(work_device)

    main_stream = torch.cuda.current_stream()
    h2d_stream = torch.cuda.Stream()      # uploads of ref blocks
    compute_stream = torch.cuda.Stream()  # sim_operator evaluation
    d2h_stream = torch.cuda.Stream()      # transfer of results to output_device

    query_chunks = query.to(work_device).split(chunk_size)
    ref_chunks = list(ref.split(chunk_size))

    # Work already queued on the main stream (e.g. the query upload above)
    # must be finished before the side streams touch those tensors.
    h2d_stream.wait_stream(main_stream)
    compute_stream.wait_stream(main_stream)

    def _prefetch(chunk: torch.Tensor):
        """Start uploading ``chunk``; return the device tensor and a ready event."""
        with torch.cuda.stream(h2d_stream):
            on_device = chunk.to(work_device, non_blocking=True)
        return on_device, h2d_stream.record_event()

    results = []
    for q_chunk in query_chunks:
        chunk_results = []
        # Restart from the first ref block for every query block.
        current_ref, current_ready = _prefetch(ref_chunks[0])

        for idx in range(len(ref_chunks)):
            # Stage 1: start uploading the next ref block, if any.
            upcoming = None
            if idx + 1 < len(ref_chunks):
                upcoming = _prefetch(ref_chunks[idx + 1])

            # Stage 2: compute once the current ref block has arrived.
            compute_stream.wait_event(current_ready)
            with torch.cuda.stream(compute_stream):
                sim = sim_operator(q_chunk, current_ref)
            current_ref.record_stream(compute_stream)

            # Stage 3: move the finished block to the output device.
            d2h_stream.wait_stream(compute_stream)
            with torch.cuda.stream(d2h_stream):
                moved = sim.to(output_device, non_blocking=True)
            if sim.is_cuda:
                sim.record_stream(d2h_stream)
            if moved.is_cuda:
                # The concatenation below runs on that device's current stream.
                moved.record_stream(torch.cuda.current_stream(moved.device))
            chunk_results.append(moved)

            if upcoming is not None:
                current_ref, current_ready = upcoming

        # All asynchronous copies must have landed before the blocks are read.
        torch.cuda.synchronize()
        results.append(torch.cat(chunk_results, dim=1))

    return torch.cat(results).to(output_device)

def embedding_similarity(
    query: torch.Tensor,
    ref: torch.Tensor,
    chunk_size: int = 5120,
    sim_operator: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = cosine,
    work_device: Union[str, torch.device, Literal['auto']] = 'auto',
    output_device: Union[str, torch.device, Literal['auto']] = 'auto',
) -> torch.Tensor:
    """Compute the pairwise similarity matrix, picking the CPU or GPU backend.

    Device arguments may be a ``torch.device``, a device string, or
    ``'auto'``. For ``work_device`` ``'auto'`` means "the device ``query``
    lives on"; for ``output_device`` it means "the resolved work device".
    The bare string ``'cuda'`` is interpreted as ``'cuda:0'``.

    Args:
        query: Query embeddings, one per row.
        ref: Reference embeddings, one per row.
        chunk_size: Maximum number of rows per block on either side.
        sim_operator: Callable mapping two blocks to their similarity block.
        work_device: Where the similarity is evaluated.
        output_device: Where the returned tensor is placed.

    Returns:
        Whatever the selected backend returns for these arguments.
    """

    def _as_device(
        spec: Union[str, torch.device, Literal['auto']],
        fallback: torch.device,
    ) -> torch.device:
        """Turn a device spec into a ``torch.device``, using ``fallback`` for 'auto'."""
        if isinstance(spec, torch.device):
            return spec
        if spec == 'auto':
            return fallback
        return torch.device('cuda:0' if spec == 'cuda' else spec)

    # The output device defaults to the work device, so resolve that one first.
    _work_device = _as_device(work_device, query.device)
    _output_device = _as_device(output_device, _work_device)

    on_gpu = _work_device.type.startswith('cuda')
    backend = embedding_similarity_gpu if on_gpu else embedding_similarity_cpu
    return backend(
        query, ref, chunk_size, sim_operator,
        work_device=_work_device,
        output_device=_output_device
    )

In [29]:
import numpy as np
from IPython.display import display, Markdown

class TestRunner:
    """Small in-notebook harness that exercises ``embedding_similarity``.

    Tests are executed through :meth:`run_test`, which rebuilds the fixtures,
    runs the test and stores a ``(name, status, error)`` tuple in
    ``self.results``. :meth:`show_results` renders those tuples as Markdown.
    """

    def __init__(self):
        # (test name, status string, error message) tuples in execution order.
        self.results = []
        self.cuda_available = torch.cuda.is_available()

    def setup(self):
        """Create the fixtures shared by the tests: 2 query rows, 3 reference rows."""
        self.vec_a = torch.tensor([[1.0, 0.0], [0.5, 0.5]], dtype=torch.float32)
        self.vec_b = torch.tensor([[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]], dtype=torch.float32)

    @staticmethod
    def _needs_cuda(device) -> bool:
        """Return True when ``device`` (a string or ``torch.device``) names a CUDA device."""
        return str(device).startswith("cuda")

    def assert_device(self, tensor, expected_device):
        """Assert that ``tensor`` lives on ``expected_device``.

        A bare ``"cuda"`` expectation is compared as ``"cuda:0"``.
        """
        tensor_device = str(tensor.device)
        expected_device = str(expected_device)
        if expected_device == "cuda":
            expected_device = "cuda:0"
        assert tensor_device == expected_device, f"Tensor is on {tensor_device}, expected {expected_device}"

    def run_test(self, test_func, test_name):
        """Run ``test_func`` on fresh fixtures and record whether it raised."""
        try:
            self.setup()
            test_func()
            self.results.append((test_name, "✅ PASSED", ""))
        except Exception as e:
            self.results.append((test_name, "❌ FAILED", str(e)))

    # Test cases
    def test_device_control(self):
        """The result must land on the requested (or inferred) output device."""
        test_cases = [
            ("cpu", "cpu"),
            ("auto", "auto"),
            (torch.device("cpu"), "cuda"),
            ("cuda", "auto"),
        ]

        for work_dev, out_dev in test_cases:
            # Skip any combination that touches CUDA on a host without it,
            # regardless of which side asks for it or how it is spelled.
            needs_cuda = self._needs_cuda(work_dev) or self._needs_cuda(out_dev)
            if needs_cuda and not self.cuda_available:
                continue

            result = embedding_similarity(self.vec_a, self.vec_b,
                                        work_device=work_dev,
                                        output_device=out_dev)

            expected_work = self.vec_a.device if work_dev == "auto" else torch.device(work_dev)
            expected_out = expected_work if out_dev == "auto" else torch.device(out_dev)

            self.assert_device(result, expected_out)

    def test_calculation_accuracy(self):
        """Cosine and Tanimoto results must match hand-computed values."""
        test_cases = [
            (cosine, [
                [1.0, 0.0, 0.7071],
                [0.7071, 0.7071, 1.0]
            ]),
            (tanimoto, [
                [1.0, 0.0, 0.5],
                [0.5, 0.5, 1.0]
            ])
        ]

        for sim_func, expected in test_cases:
            result = embedding_similarity(self.vec_a, self.vec_b, sim_operator=sim_func)
            np.testing.assert_allclose(result.numpy(), expected, atol=1e-4)

    def test_chunk_handling(self):
        """Splitting into single-row chunks must not change the result."""
        full_result = embedding_similarity(self.vec_a, self.vec_b, chunk_size=1024)
        chunk_result = embedding_similarity(self.vec_a, self.vec_b, chunk_size=1)
        assert torch.allclose(full_result, chunk_result), "分块处理结果不一致"

    def test_device_transfer(self):
        """Work and output devices may differ; a no-op on hosts without CUDA."""
        if not self.cuda_available:
            return

        # Compute on the CPU, return on the GPU.
        cpu_result = embedding_similarity(self.vec_a, self.vec_b,
                                        work_device="cpu",
                                        output_device="cuda")
        self.assert_device(cpu_result, "cuda:0")

        # Compute on the GPU, return on the CPU.
        gpu_result = embedding_similarity(self.vec_a.to("cuda"), self.vec_b,
                                        output_device="cpu")
        self.assert_device(gpu_result, "cpu")

    def test_invalid_device(self):
        """Unusable device arguments must be rejected with an exception.

        The test fails when a call completes without raising. Any exception
        type is accepted: an unknown device string is rejected by
        ``torch.device`` with ``RuntimeError``, and how an integer is
        rejected presumably depends on the PyTorch build and the available
        hardware.
        """
        bad_arguments = [
            {"work_device": "invalid_device"},
            {"output_device": 123},
        ]

        for kwargs in bad_arguments:
            rejected = False
            try:
                embedding_similarity(torch.randn(2,2), torch.randn(3,2), **kwargs)
            except Exception:
                rejected = True
            assert rejected, f"embedding_similarity accepted invalid arguments {kwargs}"

    def test_edge_cases(self):
        """Empty and single-row inputs."""
        # Empty input
        empty_vec = torch.empty(0, 2)
        result = embedding_similarity(empty_vec, empty_vec)
        assert result.shape == (0, 0)

        # Single row
        single_vec = torch.tensor([[1.0, 0.0]])
        result = embedding_similarity(single_vec, single_vec)
        assert torch.allclose(result, torch.ones(1, 1))

    def show_results(self):
        """Render the recorded results as a Markdown list, with error text if any."""
        display(Markdown("## 测试结果汇总"))
        for name, status, error in self.results:
            display(Markdown(f"- {name}: {status}"))
            if error:
                display(Markdown(f"  ```\n  {error}\n  ```"))

#%% Run all tests
if __name__ == "__main__":
    runner = TestRunner()

    # (display name, bound test method) pairs, executed in this order.
    test_cases = [
        (label, getattr(runner, method_name))
        for label, method_name in (
            ("设备控制测试", "test_device_control"),
            ("计算精度测试", "test_calculation_accuracy"),
            ("分块处理测试", "test_chunk_handling"),
            ("设备转移测试", "test_device_transfer"),
            ("边界情况测试", "test_edge_cases"),
        )
    ]

    for name, test_func in test_cases:
        runner.run_test(test_func, name)

    runner.show_results()

## 测试结果汇总

- 设备控制测试: ✅ PASSED

- 计算精度测试: ✅ PASSED

- 分块处理测试: ✅ PASSED

- 设备转移测试: ✅ PASSED

- 边界情况测试: ✅ PASSED