In [25]:
import torch
import time

def loop_based_rank_by_appearance(x: torch.Tensor) -> torch.Tensor:
    """forループを使用して各行の出現順ランクを計算する関数"""
    if x.shape[0] == 0: # B=0 の場合
        # B=0 の場合、結果のテンソルも最初の次元が0になるべき
        # N が 0 でない場合、形状は (0, N)
        # N が 0 の場合、形状は (0, 0)
        # どちらの場合も、x と同じ形状で最初の次元が0の空のテンソルを返す
        return torch.empty((0, x.shape[1]), dtype=torch.int64, device=x.device)
    if x.shape[1] == 0: # N=0 の場合 (B > 0)
        # 各行は空のテンソルに対するuniqueの結果（これも空のinverse_indices）になる
        # stackすると(B,0)のテンソルになる
        return torch.empty_like(x, dtype=torch.int64)

    output_rows = []
    for row in x:
        _uniques, inverse_indices = torch.unique(row, sorted=False, return_inverse=True)
        output_rows.append(inverse_indices)
    return torch.stack(output_rows)

def vectorized_row_unique_rank_by_appearance(x: torch.Tensor) -> torch.Tensor:
    """ベクトル化された手法で各行の出現順ランクを計算する関数"""
    B, N = x.shape
    device = x.device

    if B == 0: # B=0 の場合のガード処理
        return torch.empty((0, N), dtype=torch.int64, device=device)
    if N == 0: # N=0 の場合のガード処理 (B > 0)
        return torch.empty_like(x, dtype=torch.int64)

    # ステージ1: 各要素の値がその行で最初に出現する列インデックスを計算 (mci)
    cols_broadcast = torch.arange(N, device=device).view(1, 1, N).expand(B, N, N)
    eq_mask = (x.unsqueeze(2) == x.unsqueeze(1))

    sentinel = N
    masked_cols = torch.where(eq_mask, cols_broadcast, sentinel)
    mci = masked_cols.min(dim=2).values

    # ステージ2: mci テンソルの各行に対してデンスランクを計算
    A = mci
    S_values, S_indices = A.sort(dim=1)

    R_sorted = torch.zeros_like(A)
    R_sorted[:, 1:] = (S_values[:, 1:] != S_values[:, :-1]).cumsum(dim=1)

    R_final = torch.empty_like(A)
    R_final.scatter_(dim=1, index=S_indices, src=R_sorted)

    return R_final

def run_comparison(test_cases, repetitions=3, device_str='cpu'):
    """指定されたテストケースで速度比較を実行する関数"""
    print(f"テストデバイス: {device_str}")
    print("-" * 80) # Adjusted width for clarity
    print(f"{'Shape (B, N)':<15} | {'Loop Time (ms)':<18} | {'Vectorized Time (ms)':<22} | {'Speedup':<10} | {'Outputs Equal':<15}")
    print("-" * 80) # Adjusted width

    for B, N, val_low, val_high in test_cases:
        x = torch.empty((B, N), dtype=torch.int64, device=device_str) # Initialize first
        if B > 0 and N > 0:
            x = torch.randint(val_low, val_high, (B, N), device=device_str, dtype=torch.int64)

        # For-loop版のタイミング
        start_time_loop = time.perf_counter()
        for _ in range(repetitions):
            res_loop = loop_based_rank_by_appearance(x.clone())
        end_time_loop = time.perf_counter()
        time_loop = (end_time_loop - start_time_loop) * 1000 / repetitions

        # ベクトル化版のタイミング
        start_time_vec = time.perf_counter()
        for _ in range(repetitions):
            res_vectorized = vectorized_row_unique_rank_by_appearance(x.clone())
        end_time_vec = time.perf_counter()
        time_vectorized = (end_time_vec - start_time_vec) * 1000 / repetitions

        outputs_equal = torch.equal(res_loop, res_vectorized)

        speedup_factor = time_loop / time_vectorized if time_vectorized > 0.000001 else float('inf')
        
        # if outputs_equal == False:
        #     print(res_loop)
        #     print("===============================")
        #     print(res_vectorized)

        print(f"({B:<5}, {N:<5}) | {time_loop:<18.4f} | {time_vectorized:<22.4f} | {speedup_factor:<10.2f}x | {str(outputs_equal):<15}")

if __name__ == '__main__':
    test_configurations = [
        (10, 5, 0, 5),
        (1000, 10, 0, 10),
        (10000, 10, 0, 10),
        (100, 50, 0, 25),
        (10, 200, 0, 50),
        (100, 200, 0, 50),
        (1000, 0, 0, 1), # N=0
        (0, 10, 0, 1),   # B=0
        (0, 0, 0, 1),     # B=0, N=0
        (7396, 9, 0, 10000)
    ]

    run_comparison(test_configurations, repetitions=3, device_str='cpu')

    if torch.cuda.is_available():
        print("\n" + "=" * 80) # Adjusted width
        print("GPU でのテスト (利用可能な場合)")
        print("=" * 80) # Adjusted width
        run_comparison(test_configurations, repetitions=3, device_str='cuda')
    else:
        print("\nCUDA (GPU) は利用できません。")

テストデバイス: cpu
--------------------------------------------------------------------------------
Shape (B, N)    | Loop Time (ms)     | Vectorized Time (ms)   | Speedup    | Outputs Equal  
--------------------------------------------------------------------------------
(10   , 5    ) | 0.5014             | 7.9522                 | 0.06      x | False          
(1000 , 10   ) | 10.7448            | 26.5029                | 0.41      x | False          
(10000, 10   ) | 104.1470           | 123.4626               | 0.84      x | False          
(100  , 50   ) | 0.9738             | 28.1945                | 0.03      x | False          
(10   , 200  ) | 0.1596             | 29.1388                | 0.01      x | False          
(100  , 200  ) | 1.4314             | 28.8254                | 0.05      x | False          
(1000 , 0    ) | 0.0056             | 0.0053                 | 1.05      x | True           
(0    , 10   ) | 0.0040             | 0.0031                 | 1.26      x | True