In [1]:
import torch
import numpy as np
from affine_loss import AffineLoss

def get_homography(tx, ty, scale=1.0):
    """Build a scale-plus-translation homography and return its inverse.

    The forward model is ``P_large = scale * P_img + (tx, ty)``, i.e. the
    image is assumed to be a scaled and shifted copy of the large canvas
    (Img -> Large). The value handed back is the *inverse* of that matrix
    (Large -> Img, the ``Hs`` convention), with a leading batch axis of
    size 1, so the result has shape (1, 3, 3).
    """
    img_to_large = torch.eye(3)
    for axis, offset in enumerate((tx, ty)):
        img_to_large[axis, axis] = scale   # isotropic scale on the diagonal
        img_to_large[axis, 2] = offset     # translation lives in the last column
    # The consumer expects Large -> Img, so invert before adding the batch axis.
    large_to_img = torch.inverse(img_to_large)
    return large_to_img[None]

def test_affine_loss_v2():
    """Smoke-test AffineLoss (v2: grid correspondence) on five hand-built cases.

    Each case calls the loss as ``loss_fn(delta, Hs_a, Hs_b, M_a_b)``, prints the
    value and asserts it against the value expected by hand analysis:

    1. Identity everywhere            -> loss ~ 0
    2. GT translation, zero delta     -> loss ~ sqrt(10^2 + 10^2)
    3. GT translation, correct delta  -> loss ~ 0
    4. No overlap between A and B     -> loss ~ 0 (everything masked out)
    5. Regularisation enabled         -> loss strictly larger than the same
       input evaluated with ``reg_weight=0.0``

    Raises:
        AssertionError: if any case deviates from its expected value.
    """
    print("=== 开始测试 AffineLoss (v2: Grid Correspondence) ===")

    # Basic configuration.
    H, W = 64, 64
    steps = 3
    # Regularisation is switched off here so that cases 1-4 observe the pure distance term.
    loss_fn = AffineLoss(img_size=(H, W), grid_stride=16, reg_weight=0.0)

    # ----------------------------------------------------------------
    # Case 1: perfect match (identity transformation)
    # Img A and Img B sit at exactly the same place in the large image,
    # M_a_b is the identity and the predicted delta is zero.
    # ----------------------------------------------------------------
    print("\n[Case 1] 完美重合测试:")
    Hs_a = torch.eye(3).unsqueeze(0)  # Large -> Img (identity)
    Hs_b = torch.eye(3).unsqueeze(0)
    M_a_b = torch.eye(2, 3).unsqueeze(0)  # A -> B (identity)

    delta_zeros = torch.zeros(1, steps, 2, 3)

    loss_1 = loss_fn(delta_zeros, Hs_a, Hs_b, M_a_b)
    print(f"Loss (Expect 0.0): {loss_1.item():.6f}")
    assert loss_1.item() < 1e-6

    # ----------------------------------------------------------------
    # Case 2: pure translation check
    # Setup: Img A / Img B are crops of Large A / Large B, and
    # Large B = M_a_b(Large A), with M_a_b a translation by (+10, +10).
    #
    # Expected reasoning:
    # 1. coords_a (in Large A) are the grid points, e.g. (0, 0).
    # 2. Target coords_b (in Large B) = M_a_b @ coords_a, e.g. (10, 10).
    # 3. The accumulated prediction should map coords_a onto coords_b, i.e. it
    #    should be "translate by +10". With delta = 0 (identity) the prediction
    #    stays at coords_a, so every point is off by (10, 10):
    #    distance = ||(0,0) - (10,10)|| = sqrt(200) ~= 14.14
    # ----------------------------------------------------------------
    print("\n[Case 2] 平移目标验证:")
    Hs_a = torch.eye(3).unsqueeze(0)  # A at the origin of Large A
    Hs_b = torch.eye(3).unsqueeze(0)  # B at the origin of Large B

    # Ground-truth M_a_b: translation by (+10, +10).
    M_gt = torch.tensor([[1.0, 0.0, 10.0],
                         [0.0, 1.0, 10.0]]).unsqueeze(0)

    # Predicted delta is all zeros (identity), so prediction and target differ by (10, 10).
    loss_2 = loss_fn(delta_zeros, Hs_a, Hs_b, M_gt)

    expected_dist = np.sqrt(10**2 + 10**2)
    print(f"Loss (Expect ~{expected_dist:.4f}): {loss_2.item():.4f}")
    assert abs(loss_2.item() - expected_dist) < 1e-3

    # ----------------------------------------------------------------
    # Case 3: perfect prediction
    # Same as case 2, but the network now predicts delta = (+10, +10).
    # ----------------------------------------------------------------
    print("\n[Case 3] 预测修正验证:")
    delta_correct = torch.zeros(1, steps, 2, 3)
    # The whole correction is applied in the very first step; later steps add nothing.
    delta_correct[0, 0, 0, 2] = 10.0
    delta_correct[0, 0, 1, 2] = 10.0

    loss_3 = loss_fn(delta_correct, Hs_a, Hs_b, M_gt)
    print(f"Loss (Expect 0.0): {loss_3.item():.6f}")
    assert loss_3.item() < 1e-5

    # ----------------------------------------------------------------
    # Case 4: mask check (no overlap)
    # A and B do not overlap at all. After mapping Img A into Large B we get
    # coords_b; projecting those back into Img B lands outside the frame, so
    # the mask is expected to be 0 everywhere.
    # ----------------------------------------------------------------
    print("\n[Case 4] Mask 遮挡验证:")
    # Img B sits at (1000, 1000) in Large B. Hs_b is built directly rather than
    # through get_homography to keep the direction explicit:
    #   Hs_b @ P_large = P_img,  P_img = P_large - offset
    #   => Hs_b = [[1, 0, -1000], [0, 1, -1000], [0, 0, 1]]
    Hs_b_offset = torch.tensor([[1.0, 0.0, -1000.0],
                                [0.0, 1.0, -1000.0],
                                [0.0, 0.0, 1.0]]).unsqueeze(0)

    # coords_a lie in roughly (0..64, 0..64); M_gt shifts them by +10, so
    # coords_b lie in roughly (10..74, 10..74). Projected into Img B they end up
    # around (-990..-926), entirely outside (0..64): the mask should be all
    # zeros and the loss should be 0 (or a tiny epsilon-guarded value).
    loss_4 = loss_fn(delta_zeros, Hs_a, Hs_b_offset, M_gt)
    print(f"Loss (Expect 0.0 due to full masking): {loss_4.item():.6f}")
    assert loss_4.item() < 1e-6

    # ----------------------------------------------------------------
    # Case 5: regularisation check
    # ----------------------------------------------------------------
    print("\n[Case 5] 正则化验证:")
    loss_fn_reg = AffineLoss(img_size=(H, W), grid_stride=16, reg_weight=1.0)

    # Identity translation but a strong linear distortion: +1 on both diagonal
    # entries of the delta.
    # NOTE(review): the delta is written on *every* step. Case 3 suggests the
    # per-step deltas accumulate, so the effective scale is probably larger than
    # 2x on later steps -- confirm against AffineLoss before relying on an exact
    # regularisation value.
    delta_distort = torch.zeros(1, steps, 2, 3)
    delta_distort[:, :, 0, 0] = 1.0
    delta_distort[:, :, 1, 1] = 1.0

    # With M_gt = identity the distance term is > 0 (points move) and the
    # regularisation term is expected to add to it.
    loss_5 = loss_fn_reg(delta_distort, Hs_a, Hs_b, M_a_b)
    print(f"Loss with distortion (Should be large): {loss_5.item():.4f}")

    assert loss_5.item() > 1.4

    # The threshold above is satisfied by the distance term alone, so on its own
    # it cannot detect a missing regulariser. Evaluate the identical input with
    # reg_weight=0.0 and require the regularised loss to be strictly larger.
    loss_5_dist_only = loss_fn(delta_distort, Hs_a, Hs_b, M_a_b)
    assert loss_5.item() > loss_5_dist_only.item(), (
        "reg_weight=1.0 did not increase the loss over reg_weight=0.0 "
        f"({loss_5.item():.4f} vs {loss_5_dist_only.item():.4f})"
    )

    print("\n>> 所有测试通过！")

# Run the AffineLoss checks when this cell/script is executed directly.
if __name__ == "__main__":
    test_affine_loss_v2()

=== 开始测试 AffineLoss (v2: Grid Correspondence) ===

[Case 1] 完美重合测试:
Loss (Expect 0.0): 0.000000

[Case 2] 平移目标验证:
Loss (Expect ~14.1421): 14.1421

[Case 3] 预测修正验证:
Loss (Expect 0.0): 0.000000

[Case 4] Mask 遮挡验证:
Loss (Expect 0.0 due to full masking): 0.000000

[Case 5] 正则化验证:
Loss with distortion (Should be large): 85.7025

>> 所有测试通过！


In [2]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from consist_loss import ConsistLoss

def visualize_cycle(loss_fn, delta_a, delta_b, title="Cycle Consistency Visualization",
                    save_path='cycle_vis.png'):
    """Plot how well the A->B and B->A predictions cancel each other out.

    The per-step deltas are summed onto an identity 2x3 affine to obtain the
    final A->B and B->A matrices. The loss module's base grid is then pushed
    through A->B (red) and back through B->A (blue) and drawn next to the
    original grid (green). Only batch element 0 is drawn.

    Args:
        loss_fn: loss module; this function uses its ``eval()``, ``H``, ``W``,
            ``base_grid`` and ``_to_homogeneous_matrix``.
        delta_a: tensor indexed as ``[batch, step, 2, 3]`` with A->B updates.
        delta_b: tensor of the same layout with B->A updates.
        title: heading for the left panel.
        save_path: file the figure is written to (defaults to the historical
            ``'cycle_vis.png'``).

    Side effects:
        Puts ``loss_fn`` into eval mode, writes an image file and prints a
        confirmation line.
    """
    loss_fn.eval()
    H, W = loss_fn.H, loss_fn.W

    def _final_affine(delta):
        # Identity plus the sum of all step deltas, created on the delta's own
        # device/dtype so non-CPU or non-float32 inputs do not trip a mismatch.
        eye = torch.eye(2, 3, device=delta.device, dtype=delta.dtype)
        return eye + delta.sum(dim=1)

    # Pure visualisation: no autograd graph is needed.
    with torch.no_grad():
        M_ab_3x3 = loss_fn._to_homogeneous_matrix(_final_affine(delta_a))
        M_ba_3x3 = loss_fn._to_homogeneous_matrix(_final_affine(delta_b))

        # Base grid points; presumably (1, 3, N) homogeneous coordinates -- verify
        # against the loss module.
        points_orig = loss_fn.base_grid.to(device=M_ab_3x3.device, dtype=M_ab_3x3.dtype)

        # matmul broadcasts a batch-1 grid against B matrices; bmm would reject
        # that combination whenever B > 1.
        points_mid = torch.matmul(M_ab_3x3, points_orig)    # A -> B
        points_cycle = torch.matmul(M_ba_3x3, points_mid)   # A -> B -> A

    # Batch element 0 only; move to CPU first so accelerator tensors convert too.
    p_orig = points_orig[0].detach().cpu().numpy()
    p_mid = points_mid[0].detach().cpu().numpy()
    p_cycle = points_cycle[0].detach().cpu().numpy()

    fig, (ax_all, ax_err) = plt.subplots(1, 2, figsize=(12, 6))

    # Panel 1: overall view.
    ax_all.set_title(f"{title}\nGreen: Origin, Red: A->B, Blue: A->B->A")
    ax_all.scatter(p_orig[0], p_orig[1], c='g', marker='o', label='Origin (A)', s=40, alpha=0.6)
    ax_all.scatter(p_mid[0], p_mid[1], c='r', marker='x', label='Transformed (B)', s=30, alpha=0.4)
    ax_all.scatter(p_cycle[0], p_cycle[1], c='b', marker='+', label='Reconstructed (A)', s=60)
    ax_all.legend()
    ax_all.set_xlim(-W * 0.5, W * 1.5)
    ax_all.set_ylim(-H * 0.5, H * 1.5)
    ax_all.grid(True, linestyle='--', alpha=0.3)
    ax_all.invert_yaxis()  # image coordinates: y grows downwards

    # Panel 2: reconstruction error detail (origin vs. cycle result only).
    ax_err.set_title("Reconstruction Error Detail\n(Green vs Blue should align)")
    ax_err.scatter(p_orig[0], p_orig[1], c='g', marker='o', s=80, alpha=0.5, label='Origin')
    ax_err.scatter(p_cycle[0], p_cycle[1], c='b', marker='.', s=50, label='Cycle Result')

    # One faint segment per point, joining its origin to its cycle result.
    for i in range(p_orig.shape[1]):
        ax_err.plot([p_orig[0, i], p_cycle[0, i]], [p_orig[1, i], p_cycle[1, i]], 'k-', alpha=0.2)

    ax_err.legend()
    ax_err.grid(True)
    ax_err.invert_yaxis()

    fig.tight_layout()
    fig.savefig(save_path)
    print(f"可视化图像已保存为 '{save_path}'")
    plt.close(fig)  # release this figure explicitly

def test_consist_loss():
    """Exercise ConsistLoss (cycle consistency) on four hand-built delta pairs.

    Every case puts the whole transform into step 0 and leaves the remaining
    steps at zero, calls ``loss_fn(delta_a, delta_b)``, prints the value and
    asserts it:

    1. translation +10 / -10          -> ~0 (also rendered via visualize_cycle)
    2. scale 2.0 / 0.5                -> ~0
    3. translation +10 / +10 (wrong)  -> ~20
    4. rotation +90 deg / -90 deg     -> ~0
    """
    print("=== 开始测试 ConsistLoss (循环一致性) ===")

    height, width = 128, 128
    num_steps = 3
    loss_fn = ConsistLoss(img_size=(height, width), grid_stride=32)

    def first_step_only(rows):
        """Return a (1, num_steps, 2, 3) delta equal to `rows` at step 0, zero elsewhere."""
        delta = torch.zeros(1, num_steps, 2, 3)
        delta[0, 0] = torch.tensor(rows)
        return delta

    # -----------------------------------------------------------
    # Case 1: perfect inverse translation (A->B: x+10, B->A: x-10)
    # -----------------------------------------------------------
    print("\n[Case 1] 完美逆平移测试:")
    shift_fwd = first_step_only([[0., 0., 10.], [0., 0., 0.]])
    shift_bwd = first_step_only([[0., 0., -10.], [0., 0., 0.]])

    loss_1 = loss_fn(shift_fwd, shift_bwd)
    print(f"Loss (Expect 0.0): {loss_1.item():.6f}")
    assert loss_1.item() < 1e-5

    # Render case 1.
    visualize_cycle(loss_fn, shift_fwd, shift_bwd, title="Case 1: Perfect Inverse (+10 / -10)")

    # -----------------------------------------------------------
    # Case 2: perfect inverse scaling (A->B: x2.0, B->A: x0.5)
    # Deltas are relative to identity:
    #   M_ab = diag(2, 2)     => delta = diag(+1, +1)
    #   M_ba = diag(0.5, 0.5) => delta = diag(-0.5, -0.5)
    # -----------------------------------------------------------
    print("\n[Case 2] 完美逆缩放测试:")
    scale_fwd = first_step_only([[1., 0., 0.], [0., 1., 0.]])
    scale_bwd = first_step_only([[-0.5, 0., 0.], [0., -0.5, 0.]])

    loss_2 = loss_fn(scale_fwd, scale_bwd)
    print(f"Loss (Expect 0.0): {loss_2.item():.6f}")
    assert loss_2.item() < 1e-5

    # -----------------------------------------------------------
    # Case 3: mismatched direction (A->B: x+10, B->A: x+10)
    # The backward shift should have been -10, so the cycle ends at x+20:
    # a point starting at 0 returns at 20, i.e. distance 20.
    # -----------------------------------------------------------
    print("\n[Case 3] 错误方向测试 (累积误差):")
    wrong_fwd = first_step_only([[0., 0., 10.], [0., 0., 0.]])
    wrong_bwd = first_step_only([[0., 0., 10.], [0., 0., 0.]])

    loss_3 = loss_fn(wrong_fwd, wrong_bwd)
    print(f"Loss (Expect ~20.0): {loss_3.item():.4f}")
    assert abs(loss_3.item() - 20.0) < 1e-3

    # -----------------------------------------------------------
    # Case 4: rotation consistency (A->B: +90 deg, B->A: -90 deg)
    #   R(+90) = [[0, -1], [1, 0]]  => delta = [[-1, -1], [ 1, -1]]
    #   R(-90) = [[0, 1], [-1, 0]]  => delta = [[-1,  1], [-1, -1]]
    # -----------------------------------------------------------
    print("\n[Case 4] 旋转一致性测试:")
    rot_fwd = first_step_only([[-1., -1., 0.], [1., -1., 0.]])
    rot_bwd = first_step_only([[-1., 1., 0.], [-1., -1., 0.]])

    loss_4 = loss_fn(rot_fwd, rot_bwd)
    print(f"Loss (Expect ~0.0): {loss_4.item():.6f}")
    assert loss_4.item() < 1e-4

    print("\n>> 所有测试通过！查看 'cycle_vis.png' 确认可视化效果。")

# Run the ConsistLoss checks (and write the visualisation) when executed directly.
if __name__ == "__main__":
    test_consist_loss()

=== 开始测试 ConsistLoss (循环一致性) ===

[Case 1] 完美逆平移测试:
Loss (Expect 0.0): 0.000000
可视化图像已保存为 'cycle_vis.png'

[Case 2] 完美逆缩放测试:
Loss (Expect 0.0): 0.000000

[Case 3] 错误方向测试 (累积误差):
Loss (Expect ~20.0): 20.0000

[Case 4] 旋转一致性测试:
Loss (Expect ~0.0): 0.000000

>> 所有测试通过！查看 'cycle_vis.png' 确认可视化效果。


In [3]:
import torch

def invert_affine_matrix(M):
    """Invert a 2x3 affine transform (A -> B) to obtain the B -> A transform.

    The matrix is lifted to homogeneous 3x3 form by appending the row
    ``[0, 0, 1]``, inverted, and the top two rows of the inverse are returned.

    Args:
        M (torch.Tensor): affine matrix of shape (2, 3), (B, 2, 3), or more
            generally (..., 2, 3) with any number of leading batch dimensions.

    Returns:
        torch.Tensor: inverse affine matrix with the same shape as ``M``.

    Raises:
        ValueError: if the last two dimensions of ``M`` are not (2, 3).
        RuntimeError: propagated from ``torch.linalg.inv`` when a matrix is
            not invertible (a diagnostic line is printed first).
    """
    if M.dim() < 2 or tuple(M.shape[-2:]) != (2, 3):
        raise ValueError(
            f"expected an affine matrix of shape (..., 2, 3), got {tuple(M.shape)}"
        )

    batch_shape = M.shape[:-2]

    # 1. Lift to homogeneous form: append [0, 0, 1] under every 2x3 matrix.
    bottom_row = torch.tensor([0, 0, 1], device=M.device, dtype=M.dtype)
    bottom_row = bottom_row.expand(*batch_shape, 1, 3)
    M_homogeneous = torch.cat([M, bottom_row], dim=-2)  # (..., 3, 3)

    # 2. Invert; singular inputs surface as a RuntimeError subclass.
    try:
        M_inv_homogeneous = torch.linalg.inv(M_homogeneous)
    except RuntimeError:
        print("错误: 矩阵不可逆 (可能存在奇异矩阵)。")
        raise  # bare raise keeps the original traceback

    # 3. The top two rows are the inverse affine; batch dims are carried through.
    return M_inv_homogeneous[..., :2, :]

# --- 测试代码 ---
if __name__ == "__main__":
    # Example 1: a single matrix (90-degree rotation plus translation).
    # Rotation by 90 degrees: [[0, -1], [1, 0]]; translation: [10, 20].
    M = torch.tensor([
        [0.0, -1.0, 10.0],
        [1.0,  0.0, 20.0]
    ])

    print(f"原矩阵 M (2,3):\n{M}")

    M_inv = invert_affine_matrix(M)
    print(f"\n逆矩阵 M_inv (2,3):\n{M_inv}")

    # Verification: lift both matrices to 3x3 and multiply; the product should
    # be the identity. The padding row is created with M's dtype so the
    # concatenation does not depend on implicit int -> float promotion.
    pad = torch.tensor([[0.0, 0.0, 1.0]], dtype=M.dtype)
    M_3x3 = torch.cat([M, pad], dim=0)
    M_inv_3x3 = torch.cat([M_inv, pad], dim=0)
    product = torch.mm(M_3x3, M_inv_3x3)
    print(f"\n验证 (M * M_inv):\n{product}")
    # Check the result programmatically instead of relying on visual inspection.
    assert torch.allclose(product, torch.eye(3, dtype=M.dtype), atol=1e-5)

    # Example 2: batched input (N, 2, 3).
    M_batch = torch.stack([M, M])
    M_inv_batch = invert_affine_matrix(M_batch)
    print(f"\nBatch 输出形状: {M_inv_batch.shape}")
    assert M_inv_batch.shape == M_batch.shape

原矩阵 M (2,3):
tensor([[ 0., -1., 10.],
        [ 1.,  0., 20.]])

逆矩阵 M_inv (2,3):
tensor([[ -0.,   1., -20.],
        [ -1.,   0.,  10.]])

验证 (M * M_inv):
tensor([[1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.]])

Batch 输出形状: torch.Size([2, 2, 3])
