In [None]:
import torch
import torch.nn as nn

class GEMUnrolledEM(torch.nn.Module):
    """GEM network that unrolls a fixed number of E-step / gradient-M-step layers.

    Each layer maps the learnable ``theta`` to non-negative graph weights ``W``,
    builds the Laplacian ``L = D - W``, computes ``z_hat`` from ``L`` and the
    input ``x`` (E-step), and takes one gradient step on ``theta`` (M-step).
    Every step is differentiable, so a loss on the outputs back-propagates into
    ``self.theta``.
    """

    def __init__(self, num_nodes, num_layers=5, eta=0.1, use_detach=False):
        """
        GEM unrolled EM network.

        Args:
            num_nodes: number of graph nodes N.
            num_layers: number of unrolled layers K.
            eta: M-step gradient step size.
            use_detach: if True, the M-step gradient is treated as a constant
                (computed without a graph), so no second-order terms are
                back-propagated through the unrolled updates.
        """
        super().__init__()
        self.num_nodes = num_nodes
        self.num_layers = num_layers
        self.eta = eta
        self.use_detach = use_detach

        # theta parameterizes the graph weights W through generate_W; shape [N, N].
        self.theta = nn.Parameter(torch.rand(num_nodes, num_nodes))

    def generate_W(self, theta):
        """Map ``theta`` to graph weights: softplus (non-negative) with a zero diagonal."""
        W = torch.nn.functional.softplus(theta)
        # Zero the diagonal to avoid self-loops.
        W = W - torch.diag(torch.diag(W))
        return W

    def compute_laplacian(self, W):
        """Return the graph Laplacian ``L = D - W``, with D the diagonal of row sums of W."""
        D = torch.diag(W.sum(dim=1))
        return D - W

    def E_step(self, L, x, eps=1e-3):
        """E-step: ``z_hat = sym((L + eps*I)^{-1}) @ x`` (simplified posterior mean).

        Args:
            L: Laplacian, [N, N]. Its rows sum to zero, so ``eps`` regularizes it
                to make it invertible.
            x: observed signal, [N, F].
            eps: diagonal regularization.
        """
        N = L.shape[0]
        L_reg = L + eps * torch.eye(N, device=L.device, dtype=L.dtype)
        cov = torch.linalg.inv(L_reg)
        # W (hence L) is not constrained to be symmetric, so symmetrize the covariance.
        cov = (cov + cov.T) / 2
        return cov @ x

    def M_step(self, theta, z_hat):
        """M-step: one gradient step on the placeholder objective ``Q = ||W(theta) - z_hat||^2``.

        As in EM, the E-step output ``z_hat`` is held fixed while differentiating
        Q. The step uses the partial derivative dQ/dtheta, not a total derivative
        that also flows through z_hat's own dependence on theta.

        Args:
            theta: current parameter, [N, N].
            z_hat: E-step output. It is broadcast against W ([N, N]), so it must be
                [N, 1] or [N, N].

        Returns:
            ``theta - eta * dQ/dtheta``.
        """
        # autograd.grad needs grad mode even if the caller runs under torch.no_grad().
        with torch.enable_grad():
            # A fresh node for theta: z_hat depends on the caller's theta but not on
            # theta_m. So the gradient w.r.t. theta_m is the partial derivative, and
            # autograd.grad does not traverse (or free) z_hat's graph.
            if theta.requires_grad:
                theta_m = theta.clone()
            else:
                theta_m = theta.detach().requires_grad_(True)
            W = self.generate_W(theta_m)
            Q = ((W - z_hat) ** 2).sum()
            # create_graph keeps the gradient differentiable w.r.t. theta and z_hat,
            # so the unrolled update is trained end to end.
            grad = torch.autograd.grad(Q, theta_m, create_graph=not self.use_detach)[0]

        if self.use_detach:
            grad = grad.detach()

        return theta - self.eta * grad

    def forward(self, x):
        """Run ``num_layers`` unrolled EM layers.

        Args:
            x: input signal, [num_nodes, feature_dim].

        Returns:
            ``(W_final, z_hat)``: the weights generated from the final theta, and the
            E-step output of the last layer (computed before that layer's M-step).
            ``z_hat`` is None when ``num_layers == 0``.
        """
        theta = self.theta
        z_hat = None
        for _ in range(self.num_layers):
            W = self.generate_W(theta)
            L = self.compute_laplacian(W)
            z_hat = self.E_step(L, x)
            theta = self.M_step(theta, z_hat)

        W_final = self.generate_W(theta)
        return W_final, z_hat

In [None]:
torch.manual_seed(0)  # reproducible theta initialization, x and y

num_nodes = 6
# M_step's placeholder loss broadcasts W [N, N] against z_hat [N, F], so F must be 1 (or N).
feature_dim = 1
x = torch.rand(num_nodes, feature_dim)

model = GEMUnrolledEM(num_nodes=num_nodes, num_layers=3, eta=0.1, use_detach=False)
# forward returns (weights generated from the final theta, last-layer z_hat).
W_final, z_hat_final = model(x)

# Final task loss against a target signal y.
y = torch.rand(num_nodes, feature_dim)
loss = ((z_hat_final - y) ** 2).sum()

# End-to-end training step: gradients flow through all unrolled layers into model.theta.
loss.backward()
assert model.theta.grad is not None

假设 M-step 通过求解下面的优化问题得到参数 $\theta^*$（下方代码对 $Q$ 做梯度下降，即 $\arg\min$；两种情形的零梯度条件和后续推导完全相同）：
$$
\theta^* = \arg\max_\theta Q(\theta, z)
$$
M-step 的最优条件（零梯度条件）：
$$
\nabla_\theta Q(\theta^*, z) = 0
$$
对零梯度条件两边关于 $z$ 求导（隐函数定理），并假设 Hessian $\nabla^2_{\theta\theta} Q(\theta^*, z)$ 可逆，可以得到 $\theta^*$ 对 $z$ 的雅可比矩阵：
$$
\frac{d \theta^*}{dz} = - (\nabla^2_{\theta\theta} Q(\theta^*, z))^{-1} \nabla^2_{\theta z} Q(\theta^*, z)
$$
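- forward：迭代求解 $\theta^*$（即 M-step 的输出），迭代过程本身不需要保留计算图。
- backward：不显式构造 Hessian 的逆。记 $H = \nabla^2_{\theta\theta} Q(\theta^*, z)$（对称），上游梯度 $v = \partial \ell / \partial \theta^*$。先求解线性方程 $H u = v$，再得到 $\partial \ell / \partial z = -\left(\nabla^2_{\theta z} Q\right)^\top u$。其中 Hessian-向量积可用 `torch.autograd.functional.hvp`（或对 $\nabla_\theta Q$ 再做一次 `autograd.grad`）计算。线性方程可用共轭梯度求解，或用 Neumann series $H^{-1} = \eta \sum_{k \ge 0} (I - \eta H)^k$ 近似（要求 $I - \eta H$ 的谱半径小于 1）。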

In [None]:
import torch

def _neumann_inverse_hvp(grad_q, theta, v, eta, K):
    """Approximate ``H^{-1} v``, where ``H = d(grad_q)/d(theta)`` is the Hessian of Q.

    Uses the Neumann series ``H^{-1} = eta * sum_{k>=0} (I - eta*H)^k``. It converges
    when the spectral radius of ``I - eta*H`` is below 1; for a symmetric
    positive-definite H that means ``0 < eta < 2 / lambda_max(H)``.

    Args:
        grad_q: ``dQ/dtheta`` built with ``create_graph=True`` so it can be
            differentiated again (Hessian-vector products).
        theta: the tensor ``grad_q`` was differentiated with respect to.
        v: right-hand side, same shape as ``theta``.
        eta: Neumann step size.
        K: number of series terms added after the zeroth one.
    """
    term = v
    acc = v.clone()
    for _ in range(K):
        # H @ term, computed as a vector-Jacobian product of grad_q.
        hv = torch.autograd.grad(
            grad_q, theta, grad_outputs=term, retain_graph=True, allow_unused=True
        )[0]
        if hv is None:
            hv = torch.zeros_like(term)
        term = term - eta * hv
        acc = acc + term
    return eta * acc


class _ImplicitMStep(torch.autograd.Function):
    """Pass a precomputed M-step solution through, with an implicit-function-theorem backward.

    forward returns a copy of ``theta_star``. backward treats ``theta_star`` as a
    stationary point of ``q_fn(theta, z_hat)`` and maps the incoming gradient v to
    ``dl/dz_hat = -(d grad_theta Q / d z)^T H^{-1} v``. No gradient is sent to
    ``theta_star``: at a stationary point theta* depends on z_hat, not on the
    initial iterate.
    """

    @staticmethod
    def forward(ctx, theta_star, z_hat, q_fn, eta, K):
        ctx.save_for_backward(theta_star, z_hat)
        ctx.q_fn = q_fn
        ctx.eta = eta
        ctx.K = K
        return theta_star.clone()

    @staticmethod
    def backward(ctx, grad_theta):
        if not ctx.needs_input_grad[1]:
            return None, None, None, None, None
        theta_star, z_hat = ctx.saved_tensors
        with torch.enable_grad():
            theta = theta_star.detach().requires_grad_(True)
            z = z_hat.detach().requires_grad_(True)
            grad_q = torch.autograd.grad(ctx.q_fn(theta, z), theta, create_graph=True)[0]
            u = _neumann_inverse_hvp(grad_q, theta, grad_theta, ctx.eta, ctx.K)
            # (d grad_q / d z)^T u is the VJP of grad_q with respect to z.
            mixed = torch.autograd.grad(grad_q, z, grad_outputs=u, allow_unused=True)[0]
        grad_z = None if mixed is None else -mixed
        return None, grad_z, None, None, None


class GEMImplicitEM(torch.nn.Module):
    """GEM network whose M-steps are differentiated implicitly instead of by unrolling.

    Each layer computes ``z_hat`` from the Laplacian of the current ``theta``
    (E-step). It then runs ``n_iter`` gradient-descent steps on
    ``Q(theta, z_hat)`` with z_hat fixed (M-step). The M-step iterations keep no
    graph. Their result is wired back into autograd through ``_ImplicitMStep``,
    so a loss on the outputs reaches ``self.W`` through the E-steps.
    """

    def __init__(self, num_nodes, num_layers=5, eta=0.1):
        """
        Args:
            num_nodes: number of graph nodes N.
            num_layers: number of E/M layers.
            eta: M-step gradient step size, also used as the Neumann step size.
        """
        super().__init__()
        self.num_nodes = num_nodes
        self.num_layers = num_layers
        self.eta = eta
        # Learnable weights fed to the first layer's Laplacian; shape [N, N].
        self.W = torch.nn.Parameter(torch.rand(num_nodes, num_nodes))

    def compute_laplacian(self, W):
        """Return ``L = D - W``, with D the diagonal matrix of row sums of W."""
        D = torch.diag(W.sum(dim=1))
        L = D - W
        return L

    def E_step(self, L, x, eps=1e-3):
        """E-step: ``z_hat = sym((L + eps*I)^{-1}) @ x``; eps regularizes the singular L."""
        N = L.shape[0]
        L_reg = L + eps * torch.eye(N, device=L.device)
        cov = torch.linalg.inv(L_reg)
        cov = (cov + cov.T) / 2
        z_hat = cov @ x
        return z_hat

    @staticmethod
    def _objective(theta, z_hat):
        """Placeholder M-step objective ``Q = ||theta - z_hat||^2`` (z_hat broadcast against theta)."""
        return ((theta - z_hat) ** 2).sum()

    def M_step(self, theta, z_hat, n_iter=10, neumann_K=50):
        """M-step with implicit differentiation.

        Args:
            theta: initial iterate, [N, N]. Only its value is used, since
                gradients flow through z_hat.
            z_hat: E-step output, [N, F]. It is broadcast against theta in the
                objective, so F must be 1 or N.
            n_iter: gradient-descent iterations used to approximate theta*.
            neumann_K: Neumann-series terms used for ``H^{-1} v`` in backward.

        Returns:
            ``(theta_star, neumann_solve)``. ``theta_star`` is differentiable with
            respect to ``z_hat`` via the implicit function theorem. The accuracy of
            that gradient depends on theta* being close to a stationary point.
            ``neumann_solve(grad_out, K=neumann_K)`` returns an approximation of
            ``H^{-1} grad_out`` at theta*.
        """
        # EM keeps the posterior fixed during the M-step.
        z_fixed = z_hat.detach()

        # Forward: plain gradient descent; no graph is kept between iterations.
        theta_k = theta.detach()
        with torch.enable_grad():
            for _ in range(n_iter):
                theta_k = theta_k.detach().requires_grad_(True)
                grad = torch.autograd.grad(self._objective(theta_k, z_fixed), theta_k)[0]
                theta_k = (theta_k - self.eta * grad).detach()

        # Backward: attach the implicit gradient d theta* / d z_hat.
        theta_star = _ImplicitMStep.apply(theta_k, z_hat, self._objective, self.eta, neumann_K)

        def neumann_solve(grad_out, K=neumann_K):
            """Approximate ``H^{-1} grad_out`` with H the Hessian of Q at theta*."""
            with torch.enable_grad():
                th = theta_k.detach().requires_grad_(True)
                grad_q = torch.autograd.grad(self._objective(th, z_fixed), th, create_graph=True)[0]
                return _neumann_inverse_hvp(grad_q, th, grad_out, self.eta, K).detach()

        return theta_star, neumann_solve

    def forward(self, x):
        """Run ``num_layers`` E-step / implicit-M-step layers starting from ``self.W``.

        Args:
            x: observed signal, [num_nodes, feature_dim].

        Returns:
            ``(theta, z_hat)``: the last M-step solution and the last E-step output.
            ``z_hat`` is None when ``num_layers == 0``.
        """
        theta = self.W
        z_hat = None
        for _ in range(self.num_layers):
            L = self.compute_laplacian(theta)
            z_hat = self.E_step(L, x)
            theta, _ = self.M_step(theta, z_hat)
        return theta, z_hat

In [None]:
import torch

class LaplacianPseudoInverse:
    """Approximate (pseudo-)inverses of a graph Laplacian L.

    'neumann' and 'cg' approximate the regularized inverse ``(L + eps*I)^{-1}``.
    'lowrank' builds the Moore-Penrose pseudo-inverse of L restricted to its
    ``rank`` largest eigenpairs. All methods assume L is symmetric with
    non-negative edge weights (an undirected weighted graph), so that
    ``L + eps*I`` is symmetric positive definite. L may be dense or sparse COO.
    """

    def __init__(self, method='neumann', eps=1e-3, K=5, rank=None, maxiter=50, tol=1e-6):
        """
        Laplacian (pseudo-)inverse calculator.

        Args:
            method: 'neumann' | 'cg' | 'lowrank'
            eps: diagonal regularization used by 'neumann' and 'cg'. A graph
                Laplacian is singular (L @ 1 = 0).
            K: number of Neumann-series terms added after the zeroth one.
            rank: number of largest eigenpairs kept by 'lowrank'. None means
                N - 1, i.e. drop the single zero eigenvalue of a connected graph.
            maxiter: maximum number of CG iterations.
            tol: CG stops once every column's residual 2-norm is <= tol.
        """
        assert method in ['neumann', 'cg', 'lowrank'], "method must be one of 'neumann','cg','lowrank'"
        self.method = method
        self.eps = eps
        self.K = K
        self.rank = rank
        self.maxiter = maxiter
        self.tol = tol

    def compute(self, L):
        """Return the approximation selected by ``self.method`` for Laplacian L."""
        solvers = {'neumann': self._neumann, 'cg': self._cg, 'lowrank': self._lowrank}
        if self.method not in solvers:
            raise ValueError(f"unknown method {self.method!r}")
        return solvers[self.method](L)

    @staticmethod
    def _matmul(L, X):
        """``L @ X`` for a dense or sparse-COO L and a dense X (result is dense)."""
        return torch.sparse.mm(L, X) if L.is_sparse else L @ X

    @staticmethod
    def _row_abs_sum(L):
        """Per-row sum of |L_ij| for a dense or sparse-COO L (dense 1-D result)."""
        if L.is_sparse:
            L = L.coalesce()
            out = torch.zeros(L.shape[0], dtype=L.dtype, device=L.device)
            return out.index_add_(0, L.indices()[0], L.values().abs())
        return L.abs().sum(dim=1)

    # ----------------------
    # Neumann series approximation
    # ----------------------
    def _neumann(self, L_sparse):
        """Truncated Neumann series for ``(L + eps*I)^{-1}``.

        Uses ``A^{-1} = alpha * sum_k (I - alpha*A)^k`` with ``A = L + eps*I`` and
        ``alpha = 1 / (rho + eps)``, where rho, the maximum absolute row sum of L,
        bounds lambda_max(L) (Gershgorin). Then the eigenvalues of ``alpha*A`` lie
        in (0, 1] and the series converges, though slowly when ``alpha*eps`` is
        small, so K must grow as eps shrinks. A sparse input gives a sparse
        (though generally fully populated) output; a dense input gives a dense output.
        """
        L = L_sparse.coalesce() if L_sparse.is_sparse else L_sparse
        N = L.shape[0]
        eye = torch.eye(N, dtype=L.dtype, device=L.device)
        row_abs = self._row_abs_sum(L)
        rho = float(row_abs.max()) if row_abs.numel() else 0.0
        alpha = 1.0 / (rho + self.eps)

        term = alpha * eye  # k = 0 term
        approx = term.clone()
        for _ in range(self.K):
            # term <- (I - alpha * (L + eps*I)) @ term, using only L @ dense products.
            term = term - alpha * (self._matmul(L, term) + self.eps * term)
            approx = approx + term
        return approx.to_sparse() if L_sparse.is_sparse else approx

    # ----------------------
    # Conjugate gradient: solve (L + eps*I) X = I column by column
    # ----------------------
    def _cg(self, L):
        """Conjugate gradient for ``(L + eps*I) X = I``, all N columns at once (dense result)."""
        N = L.shape[0]
        B = torch.eye(N, dtype=L.dtype, device=L.device)
        X = torch.zeros_like(B)
        R = B.clone()  # residual B - A @ X with X = 0
        P = R.clone()
        rs = (R * R).sum(dim=0)
        for _ in range(self.maxiter):
            if rs.numel() == 0 or rs.max().sqrt() <= self.tol:
                break
            AP = self._matmul(L, P) + self.eps * P
            pAp = (P * AP).sum(dim=0)
            # Columns that have already converged exactly (P = 0) take a zero step.
            step = torch.where(pAp > 0, rs / pAp, torch.zeros_like(rs))
            X = X + step * P
            R = R - step * AP
            rs_new = (R * R).sum(dim=0)
            beta = torch.where(rs > 0, rs_new / rs, torch.zeros_like(rs))
            P = R + beta * P
            rs = rs_new
        return X

    # ----------------------
    # Low-rank eigen-truncation
    # ----------------------
    def _lowrank(self, L):
        """Pseudo-inverse of L restricted to its ``rank`` largest eigenpairs (dense result).

        Eigenvalues whose magnitude is at round-off level are treated as zero
        (their inverse is set to 0), so disconnected graphs do not produce inf.
        """
        L_dense = L.to_dense() if L.is_sparse else L
        N = L_dense.shape[0]
        if N == 0:
            return torch.zeros_like(L_dense)
        eigvals, eigvecs = torch.linalg.eigh(L_dense)  # ascending order
        # Local variable: do not overwrite self.rank, so later calls with another N still work.
        rank = max(N - 1, 0) if self.rank is None else min(max(int(self.rank), 0), N)
        vals = eigvals[N - rank:]
        vecs = eigvecs[:, N - rank:]
        tol = torch.finfo(eigvals.dtype).eps * N * eigvals.abs().max()
        keep = vals.abs() > tol
        safe_vals = torch.where(keep, vals, torch.ones_like(vals))
        inv_vals = torch.where(keep, 1.0 / safe_vals, torch.zeros_like(vals))
        # vecs @ diag(inv_vals) @ vecs.T
        return (vecs * inv_vals) @ vecs.T

In [None]:
# Small undirected weighted graph; its Laplacian is fed to each approximation below.
num_nodes = 6
edges = torch.tensor([[0,1],[0,2],[1,2],[2,3],[3,4],[4,5],[3,5]])
weights = torch.tensor([0.6]*len(edges))
adj = torch.zeros(num_nodes, num_nodes)
adj[edges[:, 0], edges[:, 1]] = weights
adj[edges[:, 1], edges[:, 0]] = weights  # mirror every edge: the graph is undirected
D = torch.diag(adj.sum(dim=1))
L = D - adj

EPS = 1e-3
# Exact references: 'neumann' and 'cg' target (L + eps*I)^{-1}; 'lowrank' targets pinv(L).
L_reg_inv_ref = torch.linalg.inv(L + EPS * torch.eye(num_nodes))
L_pinv_ref = torch.linalg.pinv(L)


def _max_abs_err(approx, ref):
    """Largest entry-wise absolute error; accepts a dense or sparse approximation."""
    dense = approx.to_dense() if approx.is_sparse else approx
    return (dense - ref).abs().max().item()


# ----------------------
# Method 1: Neumann series (sparse input).
# Convergence slows as eps shrinks, so K=5 is only a rough approximation; check the error.
# ----------------------
L_sparse = L.to_sparse()
lp_inv_neumann = LaplacianPseudoInverse(method='neumann', eps=EPS, K=5)
L_pinv = lp_inv_neumann.compute(L_sparse)
print(L_pinv)
print('neumann: max |error| vs (L + eps*I)^-1 =', _max_abs_err(L_pinv, L_reg_inv_ref))

# ----------------------
# Method 2: conjugate gradient
# ----------------------
lp_inv_cg = LaplacianPseudoInverse(method='cg', eps=EPS, maxiter=50)
L_pinv_cg = lp_inv_cg.compute(L)
print(L_pinv_cg)
print('cg: max |error| vs (L + eps*I)^-1 =', _max_abs_err(L_pinv_cg, L_reg_inv_ref))

# ----------------------
# Method 3: low-rank eigen-truncation (rank = N - 1 drops the zero eigenvalue)
# ----------------------
lp_inv_lowrank = LaplacianPseudoInverse(method='lowrank', rank=5)
L_pinv_lowrank = lp_inv_lowrank.compute(L)
print(L_pinv_lowrank)
print('lowrank: max |error| vs pinv(L) =', _max_abs_err(L_pinv_lowrank, L_pinv_ref))