In [1]:
import numpy as np

In [None]:
# =========================
# Utilities
# =========================
def db2lin(x_db): return 10.0**(x_db/10.0)
def lin2db(x):    return 10.0*np.log10(x+1e-30)

def steering_ula(phi, Nr, d_over_lambda=0.5):
    m = np.arange(Nr)[:, None]
    return np.exp(-1j*2*np.pi*d_over_lambda*m*np.sin(phi))  # (Nr,1)

def delay_atom(k_idx, tau, f0):
    k = k_idx[:, None]  # (Kp,1)
    return np.exp(-1j*2*np.pi*f0*k*tau)  # (Kp,1)

def doppler_atom(s_idx, nu, Tsym):
    s = s_idx[:, None]  # (S,1)
    return np.exp(1j*2*np.pi*nu*Tsym*s)  # (S,1)

def complex_gaussian(shape, sigma2):
    sigma = np.sqrt(max(sigma2, 0.0)/2.0)
    return sigma*(np.random.randn(*shape)+1j*np.random.randn(*shape))

def make_atom(phi, tau, nu, Nr, Kset, S, f0, Tsym, d_over_lambda=0.5):
    """Safe 3D rank-1 atom (Nr,Kp,S)."""
    aphi = steering_ula(phi, Nr, d_over_lambda)          # (Nr,1)
    atau_col = delay_atom(Kset, tau, f0)                 # (Kp,1)
    anu_col  = doppler_atom(np.arange(S), nu, Tsym)      # (S,1)
    M_phi_tau = aphi @ atau_col.T                        # (Nr,Kp)
    atom = M_phi_tau[:, :, None] * anu_col[None, None, :, 0]  # (Nr,Kp,S)
    return atom

def nmse_db(H_true_list, H_est_list):
    num = 0.0
    den = 0.0
    for Ht, He in zip(H_true_list, H_est_list):
        num += np.sum(np.abs(Ht - He)**2)
        den += np.sum(np.abs(Ht)**2) + 1e-30
    return 10*np.log10(num/den + 1e-30)



def gen_perRB_gains(
    N_RB, L,
    base_mag=None,       # (L,)，路径的基准幅度；None 则随机递减
    sigma_dB=3.0,        # 对数组阴影标准差[dB]
    rho=0.9,             # RB 相关性（AR(1) 系数），越大越平滑
    K_dB=None,           # Rician K 因子[dB]，None 表示纯 Rayleigh
    seed=None
):
    rng = np.random.default_rng(seed)
    if base_mag is None:
        base_mag = np.sort(0.7+0.6*rng.random(L))[::-1]
    base_mag = np.asarray(base_mag).reshape(L)

    # 阴影：对数正态（每路径一条 AR(1) 轨迹）
    sigma = sigma_dB/20*np.log(10)   # dB->对数域
    log_s = np.zeros((N_RB, L))
    eps = rng.standard_normal((N_RB, L))
    for l in range(L):
        for r in range(1, N_RB):
            log_s[r, l] = rho*log_s[r-1, l] + np.sqrt(1-rho**2)*eps[r, l]
    shadow = np.exp(sigma*log_s)  # (N_RB,L)

    # 小尺度：Rayleigh 或 Rician
    if K_dB is None:
        # 纯 Rayleigh：相位均匀
        phase = rng.uniform(-np.pi, np.pi, size=(N_RB, L))
        mag   = base_mag[None, :]*shadow
        g_rL  = mag * np.exp(1j*phase)
    else:
        K = 10**(np.asarray(K_dB).reshape(1, -1)/10.0) if np.ndim(K_dB) else 10**(K_dB/10.0)
        if np.ndim(K_dB)==0: K = np.full((1, L), K)
        # Rician: 直达分量 + 零均值复高斯分量
        # 归一到 E|g|^2 ∝ base_mag^2；再乘上阴影
        phase_LOS = rng.uniform(-np.pi, np.pi, size=(N_RB, L))
        los = np.sqrt(K/(K+1)) * base_mag[None, :] * np.exp(1j*phase_LOS)
        nlos_sigma = (base_mag[None, :]/np.sqrt(2*(K+1)))
        nlos = nlos_sigma*(rng.standard_normal((N_RB, L)) + 1j*rng.standard_normal((N_RB, L)))
        g_rL = (los + nlos) * shadow

    return g_rL  # 形状 (N_RB, L)





# =========================
# Channel generation (Nt=1)  — supports manual truth
# =========================

def db2lin(x_db): return 10.0**(x_db/10.0)


def steering_ula(phi, Nr, d_over_lambda=0.5):
    m = np.arange(Nr)[:, None]
    return np.exp(-1j*2*np.pi*d_over_lambda*m*np.sin(phi))  # (Nr,1)


def make_atom(phi, tau, nu, Nr, Kset, S, f0, Tsym, d_over_lambda=0.5):
    """Safe 3D rank-1 atom (Nr,Kp,S)."""
    aphi = steering_ula(phi, Nr, d_over_lambda)          # (Nr,1)
    atau_col = delay_atom(Kset, tau, f0)                 # (Kp,1)
    anu_col  = doppler_atom(np.arange(S), nu, Tsym)      # (S,1)
    M_phi_tau = aphi @ atau_col.T                        # (Nr,Kp)
    atom = M_phi_tau[:, :, None] * anu_col[None, None, :, 0]  # (Nr,Kp,S)
    return atom


def generate_multiRB_channel(
    Nr=16, L=4, S=14, N_RB=128,
    K_total=12, Kp=6,
    f0=120e3, fc=10e9, c=3e8,
    Tsym=None, SNR_dB=20.0,
    d_over_lambda=0.5,
    pilot_design="diverse",
    # ---- manual truth ----
    phi_true=None, tau_true=None, nu_true=None,  # if None -> random
    g_rL=None,       # shape (N_RB, L); if None -> synthesize from g_mag + random phases
    g_mag=None,
    g_model=None
):
    if Tsym is None:
        Fs = f0*1024; cp_len = 6
        Tsym = 1.0/f0 + cp_len/Fs

    # large-scale truth
    if phi_true is None:
        phi_true = (np.random.uniform(-60, 60, size=L)*np.pi/180.0)
    else:
        phi_true = np.asarray(phi_true).reshape(L)
    if tau_true is None:
        tau_true = np.random.uniform(0.0, 1.0/f0, size=L)
    else:
        tau_true = np.asarray(tau_true).reshape(L)
    if nu_true is None:
        nu_true  = np.random.uniform(-0.45/Tsym, +0.45/Tsym, size=L)
    else:
        nu_true = np.asarray(nu_true).reshape(L)



    if g_rL is not None:
        g_rL = np.asarray(g_rL); assert g_rL.shape==(N_RB, L)
        g_mag_out   = np.abs(g_rL).mean(axis=0)
        g_phase_rL  = np.angle(g_rL)
    else:
        if g_mag is None:
            # 使用统计模型自动生成（推荐默认）
            gm = dict(sigma_dB=3.0, rho=0.9, K_dB=None, seed=None, base_mag=None)
            if g_model is not None: gm.update(g_model)
            g_rL = gen_perRB_gains(N_RB, L, **gm)
            g_mag_out   = np.abs(g_rL).mean(axis=0)
            g_phase_rL  = np.angle(g_rL)
        else:
            g_mag = np.asarray(g_mag)
            if g_mag.ndim==1:
                # All RB same amplitudde
                phase = np.random.uniform(-np.pi, np.pi, size=(N_RB, L))
                g_rL  = g_mag[None,:] * np.exp(1j*phase)
                g_mag_out  = g_mag
                g_phase_rL = np.angle(g_rL)
            elif g_mag.ndim==2 and g_mag.shape==(N_RB, L):
                phase = np.random.uniform(-np.pi, np.pi, size=(N_RB, L))
                g_rL  = g_mag * np.exp(1j*phase)
                g_mag_out  = g_mag.mean(axis=0)
                g_phase_rL = np.angle(g_rL)
            else:
                raise ValueError("g_mag must be (L,) or (N_RB,L)")
            
            
            
    # subcarrier sets
    all_k = np.arange(K_total)
    Ksets = []
    if pilot_design == "fixed":
        base = np.linspace(0, K_total-1, Kp, dtype=int)
        for _ in range(N_RB): Ksets.append(base.copy())
    else:
        for _ in range(N_RB):
            Ksets.append(np.sort(np.random.choice(all_k, size=Kp, replace=False)))

    Y_list = []
    H_clean_list = []
    SNR = db2lin(SNR_dB)

    for r in range(N_RB):
        Kset = Ksets[r]
        H_r = np.zeros((Nr, Kp, S), dtype=complex)
        for l in range(L):
            atom = make_atom(phi_true[l], tau_true[l], nu_true[l],
                             Nr, Kset, S, f0, Tsym, d_over_lambda)
            H_r = H_r + g_rL[r, l] * atom
        H_clean_list.append(H_r.copy())
        sig_pow = np.mean(np.abs(H_r)**2)
        N0 = sig_pow / SNR
        Y_list.append(H_r + complex_gaussian(H_r.shape, N0))

    true = dict(phi=phi_true, tau=tau_true, nu=nu_true,
                g_mag=g_mag_out, g_phase_rL=g_phase_rL, g_rL=g_rL,
                f0=f0, Tsym=Tsym, H_clean_list=H_clean_list)
    return Y_list, Ksets, true

In [3]:


# ---------- small helpers ----------
def parabolic_peak_1d(vals, i):
    N = len(vals)
    if i<=0 or i>=N-1: return 0.0
    y1,y2,y3 = vals[i-1], vals[i], vals[i+1]
    denom = (y1 - 2*y2 + y3)
    if np.abs(denom) < 1e-12: return 0.0
    delta = 0.5*(y1 - y3)/denom
    return float(np.clip(delta, -0.5, 0.5))


# ---------- Non-Maximum Suppression, Suppress nearby secondary peaks ----------
def nms3d(agg, w_phi=2, w_tau=2, w_nu=2, topK=32):
    P, T, U = agg.shape
    mask = np.ones_like(agg, dtype=bool)
    peaks = []
    flat_idx = np.argsort(agg.ravel())[::-1]
    for idx in flat_idx:
        if len(peaks) >= topK: break
        p = idx // (T*U); t = (idx // U) % T; u = idx % U
        if not mask[p,t,u]: continue
        peaks.append((p,t,u, agg[p,t,u]))
        p0,p1 = max(0, p-w_phi), min(P, p+w_phi+1)
        t0,t1 = max(0, t-w_tau), min(T, t+w_tau+1)
        u0,u1 = max(0, u-w_nu),  min(U, u+w_nu+1)
        mask[p0:p1, t0:t1, u0:u1] = False
    return peaks  # list of (i_phi,i_tau,i_nu,val)

def grid_min_separation_ok(phi, tau, nu, S_list, mins):
    for (p,t,n) in S_list:
        if (abs(phi-p) < mins[0]) and (abs(tau-t) < mins[1]) and (abs(nu-n) < mins[2]):
            return False
    return True


# ======================================
# Joint Multi-RB 3D-NOMP (tightened) + two-stage (global refine + final LS debias)
# ======================================
def joint_3d_nomp_multi_rb(
    Y_list, Ksets, Nr, S, f0, Tsym,
    K_max=12,
    ov_phi=4, ov_tau=12, ov_nu=12,
    d_over_lambda=0.5,
    newton_iters=8,
    nms_win=(2,2,2),
    min_sep=None,            # if None, auto from f0/S/Tsym
    stop_resid_drop=1e-2,
    mdl_penalize_gains=True,
    cfar_alpha=4.0,
    rel_peak_floor=0.10,
    do_final_debias=True,
    do_global_refine=True,           # First phase of the two-phase approach: global loop refinement
    global_refine_sweeps=1,          #  1~2
    verbose=True
):
    # default min separation (rad, s, Hz)
    if min_sep is None:
        min_sep = (8*np.pi/180, 1.0/(8*f0), 1.0/(S*Tsym))

    N_RB = len(Y_list)

    # ===== Grids =====
    Nphi = Nr*ov_phi
    u = np.linspace(-1, 1, Nphi)
    phi_grid = np.arcsin(np.clip(u, -1, 1))

    Ntau = int(12*ov_tau)
    tau_grid = np.linspace(0.0, 1.0/f0, Ntau, endpoint=False)

    Nnu = S*ov_nu
    nu_grid = np.linspace(-0.5/Tsym, 0.5/Tsym, Nnu, endpoint=False)

    # ===== Dictionaries (AoA / Doppler are common to all RBs) =====
    n_vec = np.arange(Nr)[:, None]
    A_phi = np.exp(-1j*2*np.pi*d_over_lambda*n_vec*np.sin(phi_grid)[None, :])  # (Nr x Nphi)
    colnorm_phi = np.sqrt((np.abs(A_phi)**2).sum(axis=0))
    s_idx = np.arange(S)[:, None]
    A_nu = np.exp(1j*2*np.pi*Tsym*s_idx*nu_grid[None, :])  # (S x Nnu)
    colnorm_nu = np.sqrt((np.abs(A_nu)**2).sum(axis=0))

    # Residuals
    R_list = [Y.copy() for Y in Y_list]
    N_obs = sum([np.prod(Y.shape) for Y in Y_list])  # total complex samples

    phi_est, tau_est, nu_est = [], [], []
    G_arr = None
    residual_energy_trace = [np.sum([np.vdot(R, R).real for R in R_list])]
    first_peak_val = None

    # ---- per-RB delay dictionary cache ----
    tau_cache = {}
    def get_tau_dict(Kset):
        key = tuple(Kset.tolist())
        if key not in tau_cache:
            k = Kset[:, None]
            D = np.exp(-1j*2*np.pi*f0*k*tau_grid[None, :])  # (Kp,Ntau)
            norms = np.sqrt((np.abs(D)**2).sum(axis=0))
            tau_cache[key] = (D, norms)
        return tau_cache[key]

    # ---- aggregated correlation ----
    def aggregated_correlation_cube():
        agg = np.zeros((Nphi, Ntau, Nnu), dtype=float)
        for r in range(N_RB):
            R = R_list[r]
            Kset = Ksets[r]
            D_tau, norm_tau = get_tau_dict(Kset)
            R2 = R.reshape(Nr, -1)                           # (Nr, Kp*S)
            Z1 = A_phi.conj().T @ R2                         # (Nphi, Kp*S)
            Z1 = Z1.reshape(Nphi, len(Kset), S)              # (Nphi, Kp, S)
            Z2 = np.einsum('pks,kn->pns', Z1, D_tau.conj())  # (Nphi, Ntau, S)
            C_r = np.einsum('pns,sq->pnq', Z2, A_nu.conj())  # (Nphi, Ntau, Nnu)
            denom = (colnorm_phi[:,None,None] * norm_tau[None,:,None] * colnorm_nu[None,None,:] + 1e-12)
            Cn = C_r / denom
            agg += np.abs(Cn)**2
        return agg

    # ---- single-path per-RB LS gain ----
    def ls_gain_per_RB(phi, tau, nu):
        g_r = np.zeros(N_RB, dtype=complex)
        for r in range(N_RB):
            Kset = Ksets[r]
            atom = make_atom(phi, tau, nu, Nr, Kset, S, f0, Tsym, d_over_lambda)
            num  = np.vdot(atom, R_list[r])
            den  = np.vdot(atom, atom).real + 1e-12
            g_r[r] = num/den
        return g_r

    # ---- residual update ----
    def subtract_path_from_residual(phi, tau, nu, g_r, sign=1.0):
        for r in range(N_RB):
            Kset = Ksets[r]
            atom = make_atom(phi, tau, nu, Nr, Kset, S, f0, Tsym, d_over_lambda)
            R_list[r] = R_list[r] - sign*g_r[r]*atom

    # ---- coordinate-cycling Newton refinement (lightweight) ----
    def refine_newton(phi, tau, nu, iters=newton_iters):
        m = np.arange(Nr)[:, None]
        s = np.arange(S)[:, None]
        for _ in range(iters):
            aphi = np.exp(-1j*2*np.pi*d_over_lambda*m*np.sin(phi))
            dphi = aphi * (-1j*2*np.pi*d_over_lambda*m*np.cos(phi))
            anu  = np.exp( 1j*2*np.pi*nu*Tsym*s)
            dnu  = anu * ( 1j*2*np.pi*Tsym*s)

            z_phi = 0+0j; z_nu = 0+0j
            H_11 = 0.0;   H_22 = 0.0
            for r in range(N_RB):
                Kset = Ksets[r]
                atau_col = np.exp(-1j*2*np.pi*f0*Kset[:, None]*tau)

                v_m = np.einsum('mks,ks->m', R_list[r], (atau_col.conj() @ anu.conj().T))
                z_phi += np.vdot(dphi[:,0], v_m)

                w_s = np.einsum('mks,mk->s', R_list[r], (aphi.conj() @ atau_col.conj().T))
                z_nu  += np.vdot(dnu[:,0], w_s)

                H_11 += np.vdot(dphi[:,0], dphi[:,0]).real * (atau_col.size * S)
                H_22 += np.vdot(dnu[:,0],  dnu[:,0]).real  * (Nr * len(Kset))

            eps = 1e-6
            damp = 0.35
            phi = np.clip(phi + damp*(z_phi.conjugate()).real/(H_11+eps), -np.pi/2+1e-6, np.pi/2-1e-6)
            nu  = np.clip(nu  + damp*(z_nu.conjugate()).real /(H_22+eps), -0.5/Tsym, +0.5/Tsym)

            # tau step
            num = 0+0j; den = 0.0
            for r in range(N_RB):
                Kset = Ksets[r]
                atau_col = np.exp(-1j*2*np.pi*f0*Kset[:, None]*tau)
                d_at_col = atau_col * (-1j*2*np.pi*f0*Kset[:, None])

                A1 = aphi; AN = anu.T
                M0 = (A1 @ atau_col.T);  A0 = M0[:, :, None] * AN[None, None, :]
                M1 = (A1 @ d_at_col.T);  A1_tau = M1[:, :, None] * AN[None, None, :]

                q   = np.vdot(A0,     R_list[r])
                q_p = np.vdot(A1_tau, R_list[r])

                num += (q.conjugate()*q_p)
                den += (np.abs(q_p)**2 + 1e-12)

            tau = np.clip(tau + 0.5*(2.0*num.real)/(2.0*den + 1e-12), 0.0, 1.0/f0 - 1e-12)
        return phi, tau, nu

    # ---- sweep refinement over all detected paths ----
    def cyclic_refine_sweep():
        K = len(phi_est)
        if K==0: return
        for i in range(K):
            gi = G_arr[:, i].copy()
            subtract_path_from_residual(phi_est[i], tau_est[i], nu_est[i], gi, sign=-1.0)  # add back
            pr, tr, nr = refine_newton(phi_est[i], tau_est[i], nu_est[i])
            gi_new = ls_gain_per_RB(pr, tr, nr)
            subtract_path_from_residual(pr, tr, nr, gi_new, sign=+1.0)
            phi_est[i], tau_est[i], nu_est[i] = pr, tr, nr
            G_arr[:, i] = gi_new

    # ---- merge near-duplicates ----
    def merge_nearby_paths():
        nonlocal G_arr, phi_est, tau_est, nu_est
        if G_arr is None or G_arr.shape[1] <= 1: return
        K = G_arr.shape[1]
        keep = np.ones(K, dtype=bool)
        power = np.sum(np.abs(G_arr)**2, axis=0)
        for i in range(K):
            if not keep[i]: continue
            for j in range(i+1, K):
                if not keep[j]: continue
                if (abs(phi_est[i]-phi_est[j]) < min_sep[0] and
                    abs(tau_est[i]-tau_est[j]) < min_sep[1] and
                    abs(nu_est [i]-nu_est [j]) < min_sep[2]):
                    if power[i] >= power[j]:
                        keep[j] = False
                    else:
                        keep[i] = False
                        break
        phi_est = list(np.array(phi_est)[keep])
        tau_est = list(np.array(tau_est)[keep])
        nu_est  = list(np.array(nu_est )[keep])
        G_arr   = G_arr[:, keep]

    # ---- FINAL joint LS debias per RB ----
    def final_joint_ls_debias(phi_arr, tau_arr, nu_arr):
        """Return G_ls (N_RB x K) and final residual list."""
        K = len(phi_arr)
        if K == 0:
            return np.zeros((N_RB, 0), dtype=complex), [Y.copy() for Y in Y_list]
        G_ls = np.zeros((N_RB, K), dtype=complex)
        R_final = []
        for r in range(N_RB):
            Kset = Ksets[r]
            # Build design matrix A_r: (Nr*Kp*S) x K
            Acols = []
            for k in range(K):
                atom = make_atom(phi_arr[k], tau_arr[k], nu_arr[k], Nr, Kset, S, f0, Tsym, d_over_lambda)
                Acols.append(atom.reshape(-1))
            A_r = np.stack(Acols, axis=1)  # (N, K)
            y_r = Y_list[r].reshape(-1)
            # Least squares solve
            g_r, *_ = np.linalg.lstsq(A_r, y_r, rcond=None)
            G_ls[r, :] = g_r
            # residual
            y_hat = A_r @ g_r
            R_final.append((y_r - y_hat).reshape(Nr, len(Kset), S))
        return G_ls, R_final

    # ---------- MAIN detection ----------
    last_E = residual_energy_trace[-1]
    for k in range(K_max):
        
        # ---------- coarse search ----------
        agg = aggregated_correlation_cube()

        # CFAR/GLRT gate
        med = np.median(agg)
        std = np.std(agg)
        th  = med + cfar_alpha*std
        ip, it, iu = np.unravel_index(np.argmax(agg), agg.shape)
        peak_now = agg[ip,it,iu]
        if first_peak_val is None:
            first_peak_val = peak_now
        if (peak_now < th) or (peak_now < rel_peak_floor*first_peak_val):
            if verbose: print(f"Stop by CFAR: peak={peak_now:.2e} < max({th:.2e}, {rel_peak_floor:.2f}*first)")
            break

        cand = nms3d(agg, w_phi=nms_win[0], w_tau=nms_win[1], w_nu=nms_win[2], topK=16)
        chosen = None
        for (ip,it,iu,val) in cand:
            if (val < th) or (val < rel_peak_floor*first_peak_val):
                continue
            dphi = parabolic_peak_1d(agg[:,it,iu], ip)
            dtau = parabolic_peak_1d(agg[ip,:,iu], it)
            dnu  = parabolic_peak_1d(agg[ip,it,:], iu)
            phi0 = phi_grid[np.clip(ip,1,Nphi-2)] + dphi*(phi_grid[1]-phi_grid[0])
            tau0 = tau_grid[np.clip(it,1,Ntau-2)] + dtau*(tau_grid[1]-tau_grid[0])
            nu0  = nu_grid [np.clip(iu,1,Nnu -2)] + dnu *(nu_grid[1] -nu_grid[0] )
            if grid_min_separation_ok(phi0, tau0, nu0,
                                      list(zip(phi_est,tau_est,nu_est)),
                                      mins=min_sep):
                chosen = (phi0, tau0, nu0, val); break
        if chosen is None:
            if verbose: print("Stop: no candidate passes NMS+CFAR+min-sep.")
            break

        phi0,tau0,nu0,_ = chosen
        if verbose: print(f"[Search] k={k+1} init: phi={phi0*180/np.pi:.2f}deg")


        # ---------- refine_newton ----------
        pr, tr, nr = refine_newton(phi0, tau0, nu0)
        gr = ls_gain_per_RB(pr, tr, nr)
        subtract_path_from_residual(pr, tr, nr, gr)

        phi_est.append(pr); tau_est.append(tr); nu_est.append(nr)
        G_arr = gr[:,None] if (G_arr is None) else np.concatenate([G_arr, gr[:,None]], axis=1)

        merge_nearby_paths()

        Eres = np.sum([np.vdot(R, R).real for R in R_list])
        rel_drop = (last_E - Eres) / (last_E + 1e-12)
        residual_energy_trace.append(Eres)
        last_E = Eres
        if rel_drop < stop_resid_drop:
            if verbose: print(f"Stop by residual-drop: ΔE/E={rel_drop:.2e}")
            break

        if (len(phi_est)) % 2 == 0:
            cyclic_refine_sweep()
            merge_nearby_paths()

    # ---- MDL Minimum Description Length (count per-RB complex gains) ----
    if len(residual_energy_trace) <= 1 or G_arr is None:
        return dict(phi=np.array([]), tau=np.array([]), nu=np.array([]),
                    G=np.zeros((len(Y_list),0), dtype=complex),
                    R_list=[Y.copy() for Y in Y_list])

    mdl = []
    for kk in range(1, G_arr.shape[1]+1):
        sigma2 = residual_energy_trace[min(kk, len(residual_energy_trace)-1)]/N_obs
        if mdl_penalize_gains:
            p_k = kk*(3 + 2*len(Y_list))   # 3 shared + 2*N_RB (complex gain)
        else:
            p_k = kk*3
        mdl_k = N_obs*np.log(max(sigma2,1e-20)) + 0.5*p_k*np.log(N_obs)
        mdl.append(mdl_k)
    k_opt = int(np.argmin(mdl)) + 1

    # truncate to k_opt
    phi_est = np.array(phi_est[:k_opt])
    tau_est = np.array(tau_est[:k_opt])
    nu_est  = np.array(nu_est [:k_opt])
    G_arr   = G_arr[:, :k_opt]

    # ===== 二段式：阶段A —— 全局循环重细化（可选）=====
    if do_global_refine and k_opt > 0:
        if verbose: print(f"[Global-Refine] {global_refine_sweeps} sweep(s) over {k_opt} paths...")
        for _ in range(global_refine_sweeps):
            for i in range(k_opt):
                gi = G_arr[:, i].copy()
                # add back this path
                subtract_path_from_residual(phi_est[i], tau_est[i], nu_est[i], gi, sign=-1.0)
                # refine parameters under current interference model
                pr, tr, nr = refine_newton(phi_est[i], tau_est[i], nu_est[i])
                # LS gains per RB and subtract again
                gi_new = ls_gain_per_RB(pr, tr, nr)
                subtract_path_from_residual(pr, tr, nr, gi_new, sign=+1.0)
                # commit
                phi_est[i], tau_est[i], nu_est[i] = pr, tr, nr
                G_arr[:, i] = gi_new

    # ===== 二段式：阶段B —— 最终 per-RB 联合 LS 去偏 =====
    if do_final_debias:
        if verbose: print("[Debias] Running per-RB joint LS debias...")
        G_ls, R_final = final_joint_ls_debias(phi_est, tau_est, nu_est)
        return dict(phi=phi_est, tau=tau_est, nu=nu_est, G=G_ls, R_list=R_final)
    else:
        # reconstruct residuals using current G_arr
        R_final = [Y.copy() for Y in Y_list]
        for i in range(k_opt):
            subtract_path_from_residual(phi_est[i], tau_est[i], nu_est[i], G_arr[:, i], sign=+1.0)
        return dict(phi=phi_est, tau=tau_est, nu=nu_est, G=G_arr, R_list=R_final)

# =========================
# Reconstruction for NMSE
# =========================
def reconstruct_channel_list(phi, tau, nu, G, Ksets, Nr, S, f0, Tsym, d_over_lambda=0.5):
    """
    Build H_est per RB from estimated params and per-RB gains.
    phi/tau/nu: (K,)
    G: (N_RB x K)
    returns: list of (Nr x Kp x S)
    """
    K = len(phi)
    N_RB = G.shape[0] if G.ndim==2 else 0
    H_est_list = []
    for r in range(N_RB):
        Kset = Ksets[r]
        Nr_local = Nr
        H_r = np.zeros((Nr_local, len(Kset), S), dtype=complex)
        for k in range(K):
            if G[r, k] == 0: continue
            atom = make_atom(phi[k], tau[k], nu[k], Nr, Kset, S, f0, Tsym, d_over_lambda)
            H_r += G[r, k] * atom
        H_est_list.append(H_r)
    return H_est_list






In [None]:
np.random.seed(0)

Nr = 16
L_true = 4
S = 14
N_RB = 64        
K_total = 12 # scs number in one RB
Kp = 12         
f0 = 120e3
Fs = f0*1024 # sample rate
cp_len = 6
Tsym = 1.0/f0 + cp_len/Fs
SNR_dB = 20.0

# ===== manual truth =====
phi_deg_man = np.array([6.0, 26.0, 12.0, 33.0])
tau_us_man  = np.array([3.53, 5.382, 3.65, 7.43])
nu_hz_man   = np.array([5.0e4, -1.25e4, 3.13e4, 3.1e3])

phi_true_in = phi_deg_man*np.pi/180.0
tau_true_in = tau_us_man*1e-6
nu_true_in  = nu_hz_man

# g_mag_in = np.clip(np.array([1.2,0.9,0.6,0.4])[None,:] * (1 + 0.2*np.random.randn(8,4)), 1e-5, None) # RB-specific magnitudes, random phase:
# g_mag_in = np.array([1.2, 1, 0.8, 0.6])  # (L,) # Shared magnitudes, random phase:

g_rL_in = gen_perRB_gains(N_RB=8, L=4, base_mag=[1.2,0.9,0.6,0.4], sigma_dB=2.5, rho=0.8, seed=1)
g_model=dict(base_mag=[1.2,0.9,0.6,0.4], sigma_dB=3, rho=0.9, K_dB=None, seed=0)  # Fully specified complex gains:

Y_list, Ksets, truth = generate_multiRB_channel(
    Nr=Nr, L=L_true, S=S, N_RB=N_RB,
    K_total=K_total, Kp=Kp, f0=f0, Tsym=Tsym,
    SNR_dB=SNR_dB, pilot_design="fixed",
    phi_true=phi_true_in, tau_true=tau_true_in, nu_true=nu_true_in,
    g_model=g_model,
    # g_rL=g_rL_in, 
    # g_mag=g_mag_in
)

In [6]:

est = joint_3d_nomp_multi_rb(
    Y_list, Ksets, Nr=Nr, S=S, f0=f0, Tsym=Tsym,
    K_max=8, ov_phi=4, ov_tau=16, ov_nu=16,
    newton_iters=8, verbose=True,
    nms_win=(2,2,2),
    min_sep=(8*np.pi/180, 1.0/(12*f0), 1.0/(S*Tsym)),   # tighter delay sep for Kp=12
    stop_resid_drop=1e-2,
    mdl_penalize_gains=True,
    cfar_alpha=4.0,
    rel_peak_floor=0.10,
    do_final_debias=True,
    do_global_refine=True,        # two-stage "Phase A"
    global_refine_sweeps=2
)

def deg(x): return x*180/np.pi
print("\n=== Ground Truth vs Estimate (unordered) ===")
print("phi_true [deg]:", np.round(deg(truth["phi"]), 2))
print("tau_true [us] :", np.round(1e6*truth["tau"], 3))
print("nu_true  [Hz] :", np.round(truth["nu"], 3))

print("\nphi_est  [deg]:", np.round(deg(est["phi"]), 2))
print("tau_est  [us] :", np.round(1e6*est["tau"], 3))
print("nu_est   [Hz] :", np.round(est["nu"], 3))

# ========== Per-RB gains ==========
if est["G"].shape[1] > 0:
    print("\n--- Per-RB gains (magnitudes, debiased) ---")
    G_true_abs = np.abs(truth["g_rL"])   # (N_RB, L_true)
    G_est_abs  = np.abs(est["G"])        # (N_RB, K_est)
    for r in range(N_RB):
        print(f"RB #{r:02d} | |g_true_r|:, {np.round(np.abs(truth['g_rL'][r]), 3)}  | |g_est|: {np.round(G_est_abs[r],3)}")

    # for r in range(N_RB):
        # print(f"RB #{r:02d} | |g_true|: {np.round(G_true_abs[r],3)} | |g_est|: {np.round(G_est_abs[r],3)}")
    # for r in range(N_RB):
    #     print(f"RB #{r:02d} | g_true: {np.round(truth['g_rL'][r],3)} | g_est: {np.round(est['G'][r],3)}")

# ========== NMSE(dB)（relative to no noise H_clean）==========
H_est_list = reconstruct_channel_list(est["phi"], est["tau"], est["nu"], est["G"],
                                        Ksets, Nr, S, f0, Tsym)
nmse_val_db = nmse_db(truth["H_clean_list"], H_est_list)
print(f"\nNMSE (dB) vs clean channel: {nmse_val_db:.2f} dB")


[Search] k=1 init: phi=3.50deg
[Search] k=2 init: phi=29.60deg
[Search] k=3 init: phi=12.67deg
[Search] k=4 init: phi=29.00deg
[Search] k=5 init: phi=20.87deg
Stop: no candidate passes NMS+CFAR+min-sep.
[Global-Refine] 2 sweep(s) over 5 paths...
[Debias] Running per-RB joint LS debias...

=== Ground Truth vs Estimate (unordered) ===
phi_true [deg]: [ 6. 26. 12. 33.]
tau_true [us] : [3.53  5.382 3.65  7.43 ]
nu_true  [Hz] : [ 50000. -12500.  31300.   3100.]

phi_est  [deg]: [ 4.51 28.82 12.65 29.92 20.46]
tau_est  [us] : [3.53  5.382 3.65  7.43  5.382]
nu_est   [Hz] : [ 49721.937 -12382.944  31412.904   3137.325  -9408.658]

--- Per-RB gains (magnitudes, debiased) ---
RB #00 | |g_true_r|:, [1.2 0.9 0.6 0.4]  | |g_est|: [1.112 0.744 0.576 0.317 0.286]
RB #01 | |g_true_r|:, [1.107 0.95  0.73  0.461]  | |g_est|: [1.029 0.783 0.723 0.371 0.3  ]
RB #02 | |g_true_r|:, [1.004 0.781 0.652 0.458]  | |g_est|: [0.928 0.641 0.632 0.367 0.255]
RB #03 | |g_true_r|:, [0.72  0.767 0.536 0.404]  | |g_es