In [5]:
import torch

# ---------------------------
# 1. Fisher情報行列の近似計算
# ---------------------------
def compute_fisher_info(dist_constructor, params, num_samples=1000):
    """
    dist_constructor : 確率分布を生成する関数 (例: lambda θ1, θ2: Normal(mu, sigma))
    params : パラメータ (list of Tensors) [θ1, θ2, ...]
    num_samples : サンプル数
    
    戻り値: (n x n) の Fisher 情報行列 (torch.Tensor)
    """
    n_params = len(params)
    fisher = torch.zeros(n_params, n_params, dtype=params[0].dtype)
    
    for _ in range(num_samples):
        dist = dist_constructor(*params)
        x = dist.sample()
        logp = dist.log_prob(x)
        
        # パラメータに関する勾配(スコア)を計算
        grads = torch.autograd.grad(logp, params, create_graph=True)
        grads_vec = torch.stack([g.reshape(-1)[0] for g in grads])
        
        fisher += grads_vec.unsqueeze(1) @ grads_vec.unsqueeze(0)
    
    fisher /= num_samples
    return fisher

# ---------------------------
# 2. 三次形式 T_{i,j,k} の計算
# ---------------------------
def compute_third_order_tensor(dist_constructor, params, num_samples=1000):
    """
    T_{i,j,k} = E[ (∂_i log p)(∂_j log p)(∂_k log p) ]
    を Monte Carlo で近似する。
    
    戻り値: (n x n x n) テンソル
    """
    n_params = len(params)
    T = torch.zeros(n_params, n_params, n_params, dtype=params[0].dtype)
    
    for _ in range(num_samples):
        dist = dist_constructor(*params)
        x = dist.sample()
        logp = dist.log_prob(x)
        
        grads = torch.autograd.grad(logp, params, create_graph=True)
        grads_vec = [g.reshape(-1)[0] for g in grads]  # list of scalars
        
        # (∂_i log p)(∂_j log p)(∂_k log p) の期待値をサンプル平均で近似
        for i in range(n_params):
            for j in range(n_params):
                for k in range(n_params):
                    T[i, j, k] += grads_vec[i] * grads_vec[j] * grads_vec[k]
    
    T /= num_samples
    return T

# ---------------------------
# 3. Levi-Civita クリストフェル記号 (α=0)
# ---------------------------
def compute_christoffel_levi_civita(fisher, params):
    """
    fisher: (n x n) Fisher 計量
    params: パラメータ (list of Tensors)
    
    戻り値: Γ^k_{ij} (shape: (n, n, n))
    """
    n = fisher.shape[0]
    fisher_inv = torch.inverse(fisher)
    
    # g_{ij} の各パラメータ微分 (d_g[k,i,j] = ∂_{theta^k} g_{ij})
    d_g = torch.zeros(n, n, n, dtype=fisher.dtype)
    for i in range(n):
        for j in range(n):
            grad_list = torch.autograd.grad(fisher[i, j], params, retain_graph=True, allow_unused=True)
            for k, grad in enumerate(grad_list):
                if grad is not None:
                    d_g[k, i, j] = grad.reshape(-1)[0]
                else:
                    d_g[k, i, j] = 0.0
    
    # Γ^k_{ij} = 1/2 * Σ_l g^{kl} (∂_i g_{jl} + ∂_j g_{il} - ∂_l g_{ij})
    Gamma = torch.zeros(n, n, n, dtype=fisher.dtype)
    for k in range(n):
        for i in range(n):
            for j in range(n):
                tmp = 0.0
                for l in range(n):
                    tmp += fisher_inv[k, l] * (d_g[i, j, l] + d_g[j, i, l] - d_g[l, i, j])
                Gamma[k, i, j] = 0.5 * tmp
    
    return Gamma

# ---------------------------
# 4. α-接続のクリストフェル記号
# ---------------------------
def compute_christoffel_alpha(Gamma_lc, T, fisher_inv, alpha=1.0):
    """
    Gamma_lc : Levi-Civita クリストフェル記号 Γ^{(0)k}_{ij} (shape: (n, n, n))
    T        : 三次形式 T_{i,j,k} (shape: (n, n, n)) ※下付き3つ
    fisher_inv : g^{kl} (逆行列)
    alpha   : α値 (float)
    
    戻り値: Γ^{(α)k}_{ij} (shape: (n, n, n))
    """
    n = Gamma_lc.shape[0]
    
    # T^k_{ij} = g^{k,l} T_{l,i,j}
    T_up = torch.zeros_like(T)
    for l in range(n):
        for i in range(n):
            for j in range(n):
                for k in range(n):
                    T_up[k, i, j] += fisher_inv[k, l] * T[l, i, j]
    
    # Γ^{(α)k}_{ij} = Γ^{(0)k}_{ij} + (α/2) T^k_{ij}
    Gamma_alpha = Gamma_lc.clone()
    Gamma_alpha += (alpha / 2.0) * T_up
    
    return Gamma_alpha

# ---------------------------
# 5. リーマン曲率テンソル・リッチテンソル・スカラー曲率の計算
# ---------------------------
def compute_riemann_curvature_alpha(Gamma_alpha, params):
    """
    R^{(α)l}_{ijk} = ∂_j Γ^{(α)l}_{ik} - ∂_k Γ^{(α)l}_{ij}
                     + Γ^{(α)l}_{jm}Γ^{(α)m}_{ik} - Γ^{(α)l}_{km}Γ^{(α)m}_{ij}
    戻り値: Riemann^{l,i,j,k} (shape: (n, n, n, n))
    """
    n = Gamma_alpha.shape[0]
    Riemann = torch.zeros(n, n, n, n, dtype=Gamma_alpha.dtype)
    
    for l in range(n):
        for i in range(n):
            for j in range(n):
                for k in range(n):
                    # ∂_j Γ^{(α)l}_{i,k}
                    d1 = torch.autograd.grad(
                        Gamma_alpha[l, i, k], params[j],
                        retain_graph=True, allow_unused=True
                    )[0]
                    val_d1 = d1.reshape(-1)[0] if d1 is not None else 0.0
                    
                    # ∂_k Γ^{(α)l}_{i,j}
                    d2 = torch.autograd.grad(
                        Gamma_alpha[l, i, j], params[k],
                        retain_graph=True, allow_unused=True
                    )[0]
                    val_d2 = d2.reshape(-1)[0] if d2 is not None else 0.0
                    
                    sum_prod1 = 0.0
                    sum_prod2 = 0.0
                    for m in range(n):
                        sum_prod1 += Gamma_alpha[l, j, m] * Gamma_alpha[m, i, k]
                        sum_prod2 += Gamma_alpha[l, k, m] * Gamma_alpha[m, i, j]
                        
                    Riemann[l, i, j, k] = val_d1 - val_d2 + sum_prod1 - sum_prod2
    
    return Riemann

def compute_ricci_tensor_alpha(Riemann):
    """
    Ricci テンソル: R_{ij} = R^k_{ikj} 
    """
    n = Riemann.shape[0]
    Ricci = torch.zeros(n, n, dtype=Riemann.dtype)
    for i in range(n):
        for j in range(n):
            for k in range(n):
                Ricci[i, j] += Riemann[k, i, k, j]
    return Ricci

def compute_scalar_curvature_alpha(Riemann, fisher):
    """
    スカラー曲率: R = g^{ij} R_{ij}
    """
    Ricci = compute_ricci_tensor_alpha(Riemann)
    fisher_inv = torch.inverse(fisher)
    
    R_scalar = 0.0
    n = fisher.shape[0]
    for i in range(n):
        for j in range(n):
            R_scalar += fisher_inv[i, j] * Ricci[i, j]
    return R_scalar

# ---------------------------
#  使用例：自然パラメータ (θ1, θ2) による1次元正規分布
# ---------------------------
if __name__ == '__main__':
    # 1) θ2 < 0 となるよう適当な初期値を設定
    #    例: mu=0, sigma=1 となるようにすると
    #       σ^2 = 1 => -1/(2 θ2) = 1 => θ2 = -1/2
    #       μ = 0 => θ1 = 0
    theta1 = torch.tensor(0.0, requires_grad=True)   # => mu=0
    theta2 = torch.tensor(-0.5, requires_grad=True)  # => sigma=1

    # 2) 分布コンストラクタ: θ1, θ2 -> (mu, sigma)
    #    sigma^2 = -1/(2θ2), mu = θ1 * sigma^2
    def dist_constructor(t1, t2):
        sigma_sq = -1.0 / (2.0 * t2)      # > 0 になるように t2<0 が必要
        sigma = sigma_sq.sqrt()
        mu = t1 * sigma_sq
        return torch.distributions.Normal(mu, sigma)
    
    # 3) フィッシャー情報行列と三次形式
    num_samples = 1000
    fisher = compute_fisher_info(dist_constructor, [theta1, theta2], num_samples=num_samples)
    T_3 = compute_third_order_tensor(dist_constructor, [theta1, theta2], num_samples=num_samples)
    
    print("Fisher Information Matrix:")
    print(fisher)
    
    # 4) Levi-Civita (α=0) クリストフェル記号
    Gamma_lc = compute_christoffel_levi_civita(fisher, [theta1, theta2])
    
    # 5) α=1 (指数接続) のクリストフェル記号
    fisher_inv = torch.inverse(fisher)
    Gamma_alpha = compute_christoffel_alpha(Gamma_lc, T_3, fisher_inv, alpha=1.0)
    
    # 6) 曲率 (リーマン曲率テンソル, スカラー曲率)
    Riemann_alpha = compute_riemann_curvature_alpha(Gamma_alpha, [theta1, theta2])
    R_scalar_alpha = compute_scalar_curvature_alpha(Riemann_alpha, fisher)
    
    print("\n(α=1) Riemann Curvature Tensor:")
    print(Riemann_alpha)
    print("\n(α=1) Scalar Curvature:", R_scalar_alpha)


Fisher Information Matrix:
tensor([[ 1.1310, -0.0842],
        [-0.0842,  2.6330]], grad_fn=<DivBackward0>)

(α=1) Riemann Curvature Tensor:
tensor([[[[ 0.0000,  0.0252],
          [-0.0252,  0.0000]],

         [[ 0.0000,  0.4545],
          [-0.4545,  0.0000]]],


        [[[ 0.0000, -0.1389],
          [ 0.1389,  0.0000]],

         [[ 0.0000, -0.0160],
          [ 0.0160,  0.0000]]]], grad_fn=<CopySlices>)

(α=1) Scalar Curvature: tensor(0.2972, grad_fn=<AddBackward0>)
