In [1]:
import torch as th
import torch.nn as nn
from scipy.special import beta
from geoopt.manifolds import PoincareBall

# --- 1. 粘贴您提供的基础工具函数 (保持不变) ---

# Alias so the inverse hyperbolic sine can be written as `arsinh` in the MLR formula.
arsinh = th.asinh

@th.jit.script
def hemisphere_to_poincare(x):
    """Project points from the hemisphere model onto the Poincare ball.

    The last coordinate of ``x`` is treated as the "height" coordinate; the
    remaining coordinates are divided by ``1 + height``.

    Args:
        x: Tensor whose last dimension holds the hemisphere coordinates.

    Returns:
        Tensor with the same leading dimensions and a last dimension that is
        one element shorter than that of ``x``.
    """
    # Keep a trailing axis on the height so it broadcasts over the other coordinates.
    height = x[..., -1].unsqueeze(-1)
    return x[..., :-1] / (1 + height)

@th.jit.script
def unidirectional_poincare_mlr(x, z_norm, z_unit, r, c):
    """Evaluate the unidirectional Poincare multinomial-logistic-regression score.

    With ``rc = sqrt(c)`` the result is::

        2 * z_norm / rc * asinh(
            (2 * <rc*x, z_unit> * cosh(2*rc*r) - (1 + |rc*x|^2) * sinh(2*rc*r))
            / max(1 - |rc*x|^2, 1e-15))

    Args:
        x: Input points; the last dimension is the feature dimension.
        z_norm: Scale applied to the output; cast to the dtype of ``x``.
        z_unit: Direction(s); cast to the dtype of ``x``. A 1-D tensor is used
            as a single direction and yields a trailing output dimension of
            size 1. Otherwise the tensor is transposed with ``.t()`` before the
            matrix product, i.e. each row is one direction.
        r: Offset entering the ``cosh``/``sinh`` terms; cast to the dtype of ``x``.
        c: Value whose square root scales ``x``; converted to a tensor on the
            device and dtype of ``x``.

    Returns:
        Tensor of scores with one trailing entry per direction.
    """
    dtype = x.dtype
    sqrt_c = th.sqrt(th.as_tensor(c, dtype=dtype, device=x.device))
    z_norm = z_norm.to(dtype)
    z_unit = z_unit.to(dtype)
    r = r.to(dtype)

    scaled_x = sqrt_c * x
    # Squared norm of the scaled input, kept with a trailing axis for broadcasting.
    sq_norm = scaled_x.pow(2).sum(dim=-1, keepdim=True)

    if z_unit.dim() == 1:
        # Single direction vector: restore the trailing axis dropped by matmul.
        inner_prod = th.matmul(scaled_x, z_unit).unsqueeze(-1)
    else:
        # One direction per row of z_unit.
        inner_prod = th.matmul(scaled_x, z_unit.t())

    two_sqrt_c_r = 2. * sqrt_c * r
    numerator = 2. * inner_prod * th.cosh(two_sqrt_c_r) - (1. + sq_norm) * th.sinh(two_sqrt_c_r)
    # Clamp keeps the division finite when the scaled input reaches the ball boundary.
    denominator = th.clamp_min(1. - sq_norm, 1e-15)
    return 2 * z_norm / sqrt_c * arsinh(numerator / denominator)

# --- 2. 定义修正后的类 ---

class CorPolyHyperbolicCholeskyMetric(nn.Module):
    """Embed correlation matrices in a Poincare ball and score them with a hyperbolic MLR.

    Each matrix is Cholesky-factorised; every row of the factor from the
    second one on is projected with ``hemisphere_to_poincare`` and the
    resulting per-row points are merged into a single ball point by rescaling
    them in the tangent space at the origin and concatenating.

    Args:
        n: Side length of the correlation matrices. Stored on the module; the
            methods read the actual size from their input.
        jitter: Value added to the diagonal before the Cholesky factorisation.
    """

    def __init__(self, n, jitter: float = 1e-5):
        super().__init__()
        self.n = n
        self.pball = PoincareBall(c=1.0, learnable=False)
        # Same curvature value as the manifold above, kept as a buffer so it
        # follows the module across devices and is passed on to the MLR.
        self.register_buffer('c', th.tensor(1.0))
        self.jitter = jitter

    def correlation_to_poincare_concate(self, C):
        """Map a batch of correlation matrices to points of one Poincare ball.

        Args:
            C: Tensor of shape ``[batch, ..., n, n]`` with ``n >= 2``. It must
                be positive definite once ``jitter`` is added to its diagonal.

        Returns:
            Tensor of shape ``[batch, D]`` where ``D`` is the product of the
            extra (``...``) dimensions times ``n * (n - 1) / 2``.

        Raises:
            ValueError: If ``C`` has fewer than three dimensions or ``n < 2``.
        """
        # Without a leading batch axis the final reshape would silently
        # reinterpret the matrix rows as batch entries, so reject it up front.
        if C.dim() < 3:
            raise ValueError(
                "C must be batched with shape [batch, ..., n, n]; "
                f"got shape {tuple(C.shape)}"
            )
        if C.shape[-1] < 2:
            raise ValueError(
                f"correlation matrices must be at least 2x2; got shape {tuple(C.shape)}"
            )

        # Step 1: Cholesky factor of the jittered matrix.
        eye = th.eye(C.shape[-1], device=C.device, dtype=C.dtype)
        L = th.linalg.cholesky(C + self.jitter * eye)

        # Step 2: dimension of the concatenated embedding, in plain integer
        # arithmetic (n * (n - 1) is always even, so the division is exact).
        size = L.size()
        extra_dims = 1
        for d in size[1:-2]:
            extra_dims *= d
        dim_in = extra_dims * size[-1] * (size[-1] - 1) // 2
        beta_n = beta(dim_in / 2, 1 / 2)

        mapped_rows = []
        for i in range(1, L.shape[-2]):
            # Row i of the lower-triangular factor has i + 1 leading entries
            # that can be non-zero; they become a point of an i-dimensional ball.
            hs_r = L[..., i, :i+1]
            pball_r = hemisphere_to_poincare(hs_r)
            beta_ni = beta(i / 2, 1 / 2)
            # logmap0 takes the point to the tangent space at the origin,
            # where it is rescaled by the ratio of beta factors.
            v_r = self.pball.logmap0(pball_r) * beta_n / beta_ni
            mapped_rows.append(v_r)

        # Concatenate the tangent vectors and map the result back onto the ball.
        flat_vec = th.cat(mapped_rows, dim=-1).contiguous().view(size[0], -1)
        return self.pball.expmap0(flat_vec)

    def undirectional_RMLR(self, C, weight_g, weight_v, gamma):
        """Score correlation matrices with ``unidirectional_poincare_mlr``.

        Args:
            C: Batch of correlation matrices, as accepted by
                ``correlation_to_poincare_concate``.
            weight_g: Output scale, forwarded as ``z_norm``.
            weight_v: Direction vector(s); normalised to unit L2 norm along
                the last dimension before being forwarded as ``z_unit``.
            gamma: Offset, forwarded as ``r``.

        Returns:
            The tensor returned by ``unidirectional_poincare_mlr`` for the
            embedded inputs.
        """
        # Embed the matrices as ball points of shape [batch, D].
        C_phi = self.correlation_to_poincare_concate(C)

        # Normalise over the feature axis; keepdim keeps the division
        # broadcastable and the clamp guards against a zero vector.
        weight_v_unit = weight_v / weight_v.norm(p=2, dim=-1, keepdim=True).clamp_min(1e-15)

        return unidirectional_poincare_mlr(C_phi, weight_g, weight_v_unit, gamma, c=self.c)

# --- 3. 验证脚本 ---

def test_normalization_logic():
    """Print a small demonstration of normalising a 1-D weight vector.

    Shows that, for a 1-D tensor, the norm over ``dim=0`` and over ``dim=-1``
    is the same value, and that dividing by the ``keepdim=True`` norm over the
    last dimension yields a unit-length vector of unchanged shape.
    """
    print("=== 测试 1: 验证归一化逻辑 ===")

    # A random direction vector with five features.
    feature_dim = 5
    weight_v = th.randn(feature_dim)
    print(f"原始 weight_v 形状: {weight_v.shape}")

    # Case A: for a 1-D tensor the first and the last axis are the same axis,
    # so both reductions give the vector's length. They differ only once the
    # weights carry an extra leading axis (e.g. [out, features]).
    norm_0 = th.norm(weight_v, dim=0)
    norm_last = th.norm(weight_v, dim=-1)
    print(f"1D向量下: dim=0={norm_0:.4f}, dim=-1={norm_last:.4f} (应该相等)")

    # Case B: normalise over the feature axis; keepdim=True keeps the
    # division broadcastable and the clamp avoids dividing by zero.
    safe_norm = weight_v.norm(dim=-1, keepdim=True).clamp_min(1e-15)
    weight_v_unit_fixed = weight_v / safe_norm
    print(f"修正后单位向量模长: {weight_v_unit_fixed.norm().item():.4f} (应为 1.0)")
    print(f"修正后形状: {weight_v_unit_fixed.shape}")

def test_forward_backward():
    """Smoke-test one forward and one backward pass through ``undirectional_RMLR``.

    Builds a random batch of correlation matrices, runs the model, and prints
    the output shape together with the gradient norms of the input and of
    ``weight_v``. Exceptions raised during the forward or the backward step
    are caught and printed instead of propagating.
    """
    print("\n=== 测试 2: 完整的前向和反向传播 ===")

    n, batch_size = 4, 2  # 4x4 correlation matrices, two per batch
    # Size of the embedded feature vector: n * (n - 1) / 2 = 6 for n = 4.
    feature_dim = n * (n - 1) // 2

    model = CorPolyHyperbolicCholeskyMetric(n=n)

    # Random correlation matrices: form A @ A^T, then divide by the outer
    # product of the square roots of its diagonal so the diagonal becomes 1.
    A = th.randn(batch_size, n, n)
    cov = A @ A.transpose(1, 2)
    std = cov.diagonal(dim1=-2, dim2=-1).sqrt()
    C = cov / (std.unsqueeze(2) @ std.unsqueeze(1))
    C.requires_grad_(True)  # lets us check that gradients reach the input

    # MLR parameters: scalar scale, direction of size feature_dim, scalar offset.
    weight_g = nn.Parameter(th.tensor(5.0))
    weight_v = nn.Parameter(th.randn(feature_dim))
    gamma = nn.Parameter(th.tensor(0.5))

    print(f"输入 C 形状: {C.shape}")
    print(f"参数 weight_v 形状: {weight_v.shape}")

    # --- forward pass ---
    try:
        output = model.undirectional_RMLR(C, weight_g, weight_v, gamma)
        print(f"前向传播成功。输出形状: {output.shape}")
    except Exception as e:
        import traceback
        print(f"前向传播失败: {e}")
        traceback.print_exc()
        return

    # --- backward pass ---
    try:
        output.mean().backward()
        print("反向传播成功。")

        # Report whether a gradient arrived at the input and at the direction.
        if C.grad is None:
            print("错误: C 没有梯度")
        else:
            print(f"输入 C 的梯度范数: {C.grad.norm().item():.4f}")

        if weight_v.grad is None:
            print("错误: weight_v 没有梯度")
        else:
            print(f"参数 weight_v 的梯度范数: {weight_v.grad.norm().item():.4f}")
    except Exception as e:
        print(f"反向传播失败: {e}")

if __name__ == "__main__":
    # Run the demo checks in order when the file is executed as a script.
    for check in (test_normalization_logic, test_forward_backward):
        check()


=== 测试 1: 验证归一化逻辑 ===
原始 weight_v 形状: torch.Size([5])
1D向量下: dim=0=2.8411, dim=-1=2.8411 (应该相等)
修正后单位向量模长: 1.0000 (应为 1.0)
修正后形状: torch.Size([5])

=== 测试 2: 完整的前向和反向传播 ===
输入 C 形状: torch.Size([2, 4, 4])
参数 weight_v 形状: torch.Size([6])
前向传播成功。输出形状: torch.Size([2, 1])
反向传播成功。
输入 C 的梯度范数: 27.0566
参数 weight_v 的梯度范数: 2.6430
