In [3]:
import torch

# ——————————————————————————————————————————————————————————
# Configuration
# Seed the global RNG so the random tensors below are reproducible.
torch.manual_seed(0)

# Problem dimensions
B = 5        # batch size
N = 20       # samples per batch element
P = 15       # features
K = 4        # top-k features kept per batch element
alpha = 0.5  # exponent for penalty

# Random inputs; drawn in the order X, Y, scores so the seeded stream is
# consumed the same way on every run.
X = torch.randn((B, N, P))     # [B, N, P]
Y = torch.randn((B, N, 1))     # [B, N, 1]
scores = torch.randn((B, P))   # [B, P]

# ——————————————————————————————————————————————————————————
def batched_beta_full(X, Y, scores, K, alpha):
    """Batched penalized least squares on each row's top-K scored features.

    For every batch element, the K features with the largest ``scores`` are
    selected and the linear system

        (X_k^T X_k + diag(1 / s_k**alpha)) beta = X_k^T Y

    is solved, where ``X_k`` holds the selected columns and ``s_k`` their
    scores. The K coefficients are written back into a length-P vector at the
    positions of the selected features; all other entries are zero.

    Args:
        X: design tensor of shape [B, N, P].
        Y: target tensor of shape [B, N, 1].
        scores: per-feature scores of shape [B, P].
        K: number of features to keep per batch element.
        alpha: exponent applied to the selected scores in the penalty.

    Returns:
        Tensor of shape [B, P] with the same dtype and device as ``X``.

    Note:
        The penalty is ``1 / score**alpha``. With a fractional ``alpha`` a
        negative selected score yields NaN, and a zero score yields an
        infinite penalty.
    """
    # Take the sizes from the input itself instead of notebook-level globals.
    B, _, P = X.shape

    topk_vals, topk_idx = scores.topk(K, dim=1)              # [B, K]
    # Broadcast the selected column indices over the sample axis for gather.
    idx_exp = topk_idx.unsqueeze(1).expand(-1, X.size(1), -1)# [B, N, K]
    X_topk = torch.gather(X, 2, idx_exp)                     # [B, N, K]
    Xt = X_topk.transpose(1, 2)                              # [B, K, N]
    penalty = torch.diag_embed(1.0 / (topk_vals**alpha))     # [B, K, K]
    XtX = torch.bmm(Xt, X_topk) + penalty                    # [B, K, K]
    Xty = torch.bmm(Xt, Y)                                   # [B, K, 1]
    beta = torch.linalg.solve(XtX, Xty).squeeze(-1)          # [B, K]

    # new_zeros inherits X's dtype/device, which scatter_ requires to match beta.
    beta_full = X.new_zeros(B, P)
    beta_full.scatter_(1, topk_idx, beta)                    # [B, P]
    return beta_full

def loop_beta_full(X, Y, scores, K, alpha):
    """Reference (per-sample loop) penalized least squares on top-K features.

    For each batch element ``b`` the K highest-scoring features are selected
    and the system

        (X_k^T X_k + diag(1 / s_k**alpha)) beta = X_k^T Y

    is solved, where ``X_k`` holds the selected columns of ``X[b]`` and
    ``s_k`` their scores. The coefficients are stored at the selected feature
    positions of row ``b``; every other entry of the result is zero.

    Args:
        X: design tensor of shape [B, N, P].
        Y: target tensor of shape [B, N, 1].
        scores: per-feature scores of shape [B, P].
        K: number of features to keep per batch element.
        alpha: exponent applied to the selected scores in the penalty.

    Returns:
        Tensor of shape [B, P] with the same dtype and device as ``X``.

    Note:
        The penalty is ``1 / score**alpha``. With a fractional ``alpha`` a
        negative selected score yields NaN, and a zero score yields an
        infinite penalty.
    """
    B, _, P = X.shape
    # Allocate with X's dtype/device so the indexed assignment below accepts
    # the solver output for non-default dtypes and devices.
    beta_full = X.new_zeros(B, P)
    for b in range(B):
        vals_b, idx_b = scores[b].topk(K)
        Xb, Yb = X[b], Y[b]
        Xb_topk = Xb[:, idx_b]                                # [N, K]
        penalty_b = torch.diag(1.0 / (vals_b**alpha))         # [K, K]
        XtXb = Xb_topk.t() @ Xb_topk + penalty_b              # [K, K]
        Xtyb = Xb_topk.t() @ Yb                               # [K, 1]
        betab = torch.linalg.solve(XtXb, Xtyb).squeeze(-1)    # [K]
        beta_full[b, idx_b] = betab
    return beta_full

# ——————————————————————————————————————————————————————————
# Compare reconstructions
# Run both implementations on the same inputs.
beta_b = batched_beta_full(X, Y, scores, K, alpha)
beta_l = loop_beta_full(X, Y, scores, K, alpha)

# Largest element-wise discrepancy between the two coefficient tensors.
max_diff = torch.max(torch.abs(beta_b - beta_l)).item()
print(f"Max absolute difference between batched vs loop: {max_diff:.3e}")

Max absolute difference between batched vs loop: 5.588e-08


In [4]:
# Overall norm of each result (batched first, then loop) as a quick sanity check.
tuple(beta.norm() for beta in (beta_b, beta_l))

(tensor(0.8832), tensor(0.8832))