# In [15]:
# # receiver.py -- Rectified receiver for FSO-MDM OAM system
# # Requirements: numpy, scipy, matplotlib, encoding.py, turbulence.py, lgBeam.py (optional)
# import os
# import sys
# import warnings
# import argparse
# from typing import Dict, Tuple, Any, Optional

# import numpy as np
# import matplotlib.pyplot as plt
# from scipy.linalg import inv, pinv
# from scipy.fft import fft2, ifft2

# # script dir resolution (allow running in notebooks)
# try:
#     SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
# except NameError:
#     SCRIPT_DIR = os.getcwd()
# sys.path.insert(0, SCRIPT_DIR)

# # imports from user modules (encoding provides framing, pilots, LDPC wrapper)
# try:
#     from encoding import QPSKModulator, PilotHandler, PyLDPCWrapper, FSO_MDM_Frame
# except Exception as e:
#     raise ImportError(f"receiver.py requires encoding.py in the same directory: {e}")

# # optional turbulence angular prop
# try:
#     from turbulence import angular_spectrum_propagation
# except Exception as e:
#     angular_spectrum_propagation = None
#     warnings.warn(f"turbulence.angular_spectrum_propagation not available: {e}")

# # optional lgBeam
# try:
#     from lgBeam import LaguerreGaussianBeam
# except Exception:
#     LaguerreGaussianBeam = None

# warnings.filterwarnings("ignore")


# # ---------------------------
# # Utilities
# # ---------------------------
# def reconstruct_grid_from_gridinfo(grid_info: Dict[str, Any]):
#     if grid_info is None:
#         raise ValueError("grid_info is required to reconstruct spatial grid.")
#     x = np.asarray(grid_info.get("x"))
#     y = np.asarray(grid_info.get("y"))
#     if x.size == 0 or y.size == 0:
#         raise ValueError("grid_info.x/y empty or missing.")
#     X, Y = np.meshgrid(x, y, indexing="ij")
#     delta_x = np.mean(np.diff(x))
#     delta_y = np.mean(np.diff(y))
#     if not np.isclose(delta_x, delta_y, rtol=1e-6, atol=0.0):
#         warnings.warn("Non-square sampling intervals detected; using delta_x as delta.")
#     delta = float(delta_x)
#     return X, Y, delta, x, y


# def energy_normalize_field(field: np.ndarray, delta: float):
#     p = np.sum(np.abs(field) ** 2) * (delta ** 2)
#     if p > 0:
#         return field / np.sqrt(p)
#     return field


# # ---------------------------
# # OAM Demultiplexer
# # ---------------------------
# class OAMDemultiplexer:
#     def __init__(self, spatial_modes, wavelength, w0, z_distance, angular_prop_func=angular_spectrum_propagation):
#         self.spatial_modes = list(spatial_modes)
#         self.n_modes = len(self.spatial_modes)
#         self.wavelength = wavelength
#         self.w0 = w0
#         self.z_distance = z_distance
#         self.angular_prop = angular_prop_func
#         self._ref_cache = {}

#     def _make_ref_key(self, mode_key, N, delta, X_shape):
#         return (mode_key, int(N), float(delta), tuple(X_shape))

#     def reference_field(self, mode_key: Tuple[int, int], X, Y, delta, grid_z, tx_beam_obj=None):
#         N = X.shape[0]
#         key = self._make_ref_key(mode_key, N, delta, X.shape)
#         if key in self._ref_cache:
#             return self._ref_cache[key].copy()

#         R = np.sqrt(X ** 2 + Y ** 2)
#         PHI = np.arctan2(Y, X)

#         beam = tx_beam_obj
#         if beam is None:
#             p, l = mode_key
#             if LaguerreGaussianBeam is None:
#                 # fallback analytic simple helix*gaussian (unnormalized)
#                 w = self.w0
#                 amp = (R / w) ** (2 * p) * np.exp(-(R ** 2) / (w ** 2))
#                 ref_z0 = amp * np.exp(1j * l * PHI)
#             else:
#                 beam = LaguerreGaussianBeam(p, l, self.wavelength, self.w0)
#                 ref_z0 = beam.generate_beam_field(R, PHI, 0.0)
#         else:
#             ref_z0 = beam.generate_beam_field(R, PHI, 0.0)

#         # if grid_z>0 and angular_prop provided, propagate numerically
#         if self.angular_prop is None or grid_z == 0.0:
#             ref = ref_z0
#         else:
#             ref = self.angular_prop(ref_z0.copy(), delta, self.wavelength, grid_z)

#         self._ref_cache[key] = ref.copy()
#         return ref

#     def project_field(self, E_rx, grid_info, receiver_radius=None, tx_frame=None):
#         X, Y, delta, x, y = reconstruct_grid_from_gridinfo(grid_info)
#         R = np.sqrt(X**2 + Y**2)

#         if not np.iscomplexobj(E_rx):
#             warnings.warn("E_rx appears to be real (intensity). Assuming sqrt(I) zero-phase field for projection.")
#             E_rx = np.sqrt(np.abs(E_rx)).astype(np.complex128)

#         dA = float(delta ** 2)
#         if receiver_radius is not None:
#             aperture_mask = (R <= receiver_radius).astype(float)
#         else:
#             aperture_mask = np.ones_like(R, dtype=float)

#         symbols = {}
#         N = X.shape[0]

#         for mode_key in self.spatial_modes:
#             tx_beam_obj = None
#             if tx_frame is not None:
#                 sig = tx_frame.tx_signals.get(mode_key)
#                 if sig is not None:
#                     tx_beam_obj = sig.get("beam", None)

#             ref = self.reference_field(mode_key, X, Y, delta, grid_z=self.z_distance, tx_beam_obj=tx_beam_obj)
#             ref_ap = ref * aperture_mask
#             ref_energy = np.sum(np.abs(ref_ap) ** 2) * dA
#             projection = np.sum(E_rx * np.conj(ref_ap)) * dA
#             if ref_energy > 1e-20:
#                 symbols[mode_key] = projection / ref_energy
#             else:
#                 symbols[mode_key] = 0.0 + 0.0j
#         return symbols

#     def extract_symbols_sequence(self, E_rx_sequence, grid_info, receiver_radius=None, tx_frame=None):
#         seq = np.asarray(E_rx_sequence)
#         if seq.ndim == 2:
#             seq = seq[np.newaxis, ...]
#         n_frames = seq.shape[0]
#         symbols_per_mode = {mode: np.zeros(n_frames, dtype=complex) for mode in self.spatial_modes}
#         for i in range(n_frames):
#             snapshot = self.project_field(seq[i], grid_info, receiver_radius, tx_frame=tx_frame)
#             for mode in self.spatial_modes:
#                 symbols_per_mode[mode][i] = snapshot.get(mode, 0.0 + 0.0j)
#         return symbols_per_mode


# # ---------------------------
# # Channel estimator
# # ---------------------------
# class ChannelEstimator:
#     def __init__(self, pilot_handler: PilotHandler, spatial_modes):
#         self.pilot_handler = pilot_handler
#         self.spatial_modes = list(spatial_modes)
#         self.M = len(self.spatial_modes)
#         self.H_est = None
#         self.noise_var_est = None

#     def _gather_pilots(self, rx_symbols_per_mode: Dict[Tuple[int,int], np.ndarray], tx_frame: FSO_MDM_Frame):
#         tx_signals = tx_frame.tx_signals if tx_frame is not None else {}

#         pilot_positions = None
#         for mode_key in self.spatial_modes:
#             sig = tx_signals.get(mode_key)
#             if sig is not None and "pilot_positions" in sig:
#                 pilot_positions = np.asarray(sig["pilot_positions"], dtype=int)
#                 break
#         if pilot_positions is None:
#             pilot_positions = np.asarray(self.pilot_handler.pilot_positions, dtype=int) if (self.pilot_handler and getattr(self.pilot_handler, 'pilot_positions', None) is not None) else np.array([], dtype=int)

#         if pilot_positions is None or len(pilot_positions) == 0:
#             return None, None, np.array([], dtype=int)

#         min_len = min([len(rx_symbols_per_mode[mk]) for mk in self.spatial_modes])
#         valid_pos = pilot_positions[pilot_positions < min_len]
#         if len(valid_pos) == 0:
#             return None, None, np.array([], dtype=int)

#         n_p = len(valid_pos)
#         Y_p = np.zeros((self.M, n_p), dtype=complex)
#         P_p = np.zeros((self.M, n_p), dtype=complex)

#         for idx, mk in enumerate(self.spatial_modes):
#             Y_p[idx, :] = rx_symbols_per_mode[mk][valid_pos]
#             if tx_frame is None or mk not in tx_frame.tx_signals:
#                 raise ValueError("tx_frame with tx_signals required for LS channel estimation (to provide pilot symbols).")
#             tx_syms = np.asarray(tx_frame.tx_signals[mk]["symbols"])
#             P_p[idx, :] = tx_syms[valid_pos]

#         return Y_p, P_p, valid_pos

#     def estimate_channel_ls(self, rx_symbols_per_mode: Dict[Tuple[int, int], np.ndarray], tx_frame: FSO_MDM_Frame):
#         Y_p, P_p, pilot_pos = self._gather_pilots(rx_symbols_per_mode, tx_frame)
#         if Y_p is None or P_p is None or P_p.size == 0:
#             warnings.warn("No valid pilots found for LS channel estimation. Returning identity H.")
#             self.H_est = np.eye(self.M, dtype=complex)
#             return self.H_est

#         try:
#             PPH = P_p @ P_p.conj().T
#             cond = np.linalg.cond(PPH)
#             if cond > 1e6:
#                 warnings.warn(f"Pilot Gram matrix ill-conditioned (cond={cond:.2e}), using pseudo-inverse.")
#                 H = Y_p @ pinv(P_p)
#             else:
#                 H = Y_p @ P_p.conj().T @ inv(PPH)
#         except np.linalg.LinAlgError:
#             warnings.warn("Matrix inversion failed; using pseudo-inverse for channel estimate.")
#             H = Y_p @ pinv(P_p)

#         if np.linalg.cond(H) > 1e8:
#             reg = 1e-6
#             H = H @ inv(H + reg * np.eye(self.M))

#         self.H_est = H
#         return H

#     def estimate_noise_variance(self, rx_symbols_per_mode: Dict[Tuple[int,int], np.ndarray], tx_frame: FSO_MDM_Frame, H_est: np.ndarray):
#         Y_p, P_p, pilot_pos = self._gather_pilots(rx_symbols_per_mode, tx_frame)
#         if Y_p is None or P_p is None or P_p.size == 0:
#             self.noise_var_est = 1e-6
#             return self.noise_var_est
#         residual = Y_p - H_est @ P_p
#         noise_var = np.mean(np.abs(residual) ** 2)
#         self.noise_var_est = max(noise_var, 1e-12)
#         return self.noise_var_est


# # ---------------------------
# # FSORx: Full receiver pipeline
# # ---------------------------
# class FSORx:
#     def __init__(self, spatial_modes, wavelength, w0, z_distance, pilot_handler: PilotHandler, ldpc_instance: Optional[PyLDPCWrapper] = None, eq_method: str = "zf", receiver_radius: Optional[float] = None):
#         self.spatial_modes = list(spatial_modes)
#         self.n_modes = len(self.spatial_modes)
#         self.wavelength = wavelength
#         self.w0 = w0
#         self.z_distance = z_distance
#         self.pilot_handler = pilot_handler
#         self.eq_method = eq_method.lower()
#         self.receiver_radius = receiver_radius

#         self.qpsk = QPSKModulator(symbol_energy=1.0)

#         if ldpc_instance is not None:
#             self.ldpc = ldpc_instance
#         else:
#             try:
#                 self.ldpc = PyLDPCWrapper(n=2048, rate=0.8, dv=2, dc=8, seed=42)
#                 warnings.warn("No LDPC instance provided; receiver created local PyLDPCWrapper that may not match TX.")
#             except Exception as e:
#                 self.ldpc = None
#                 warnings.warn(f"Cannot construct LDPC wrapper locally: {e}; LDPC decode disabled for demo.")

#         self.demux = OAMDemultiplexer(self.spatial_modes, self.wavelength, self.w0, self.z_distance)
#         self.chan_est = ChannelEstimator(self.pilot_handler, self.spatial_modes)
#         self.metrics = {}

#     def receive_frame(self, rx_field_sequence, tx_frame: FSO_MDM_Frame, original_data_bits: np.ndarray, verbose: bool = True, bypass_ldpc: bool = True):
#         if verbose:
#             print("\n" + "=" * 72)
#             print("FSO-OAM Receiver: Start")
#             print("=" * 72)

#         grid_info = tx_frame.grid_info
#         if grid_info is None:
#             raise ValueError("tx_frame.grid_info required for demux/projection.")

#         if verbose:
#             print("1) OAM demultiplexing (projection)...")
#         rx_symbols_per_mode = self.demux.extract_symbols_sequence(rx_field_sequence, grid_info, receiver_radius=self.receiver_radius, tx_frame=tx_frame)
#         if verbose:
#             first_mode = self.spatial_modes[0]
#             print(f"   Extracted {len(rx_symbols_per_mode[first_mode])} symbols per mode (incl. pilots).")

#         if verbose:
#             print("2) Channel estimation (LS using pilots)...")
#         H_est = self.chan_est.estimate_channel_ls(rx_symbols_per_mode, tx_frame)
#         if verbose:
#             print("   H_est magnitude (rows):")
#             for row in np.abs(H_est):
#                 print("     [" + " ".join(f"{v:.3f}" for v in row) + "]")
#             print(f"   cond(H_est) = {np.linalg.cond(H_est):.2e}")

#         if verbose:
#             print("3) Noise variance estimation...")
#         noise_var = self.chan_est.estimate_noise_variance(rx_symbols_per_mode, tx_frame, H_est)
#         if verbose:
#             print(f"   Estimated noise variance σ² = {noise_var:.3e}")

#         if verbose:
#             print("4) Separate pilots and data")
#         pilot_positions = None
#         for mk in self.spatial_modes:
#             sig = tx_frame.tx_signals.get(mk)
#             if sig is not None and "pilot_positions" in sig:
#                 pilot_positions = np.asarray(sig["pilot_positions"], dtype=int)
#                 break
#         if pilot_positions is None:
#             pilot_positions = np.asarray(self.pilot_handler.pilot_positions, dtype=int) if (self.pilot_handler and getattr(self.pilot_handler, 'pilot_positions', None) is not None) else np.array([], dtype=int)

#         first_mode = self.spatial_modes[0]
#         total_rx_symbols = len(rx_symbols_per_mode[first_mode])
#         data_mask = np.ones(total_rx_symbols, dtype=bool)
#         if pilot_positions is not None and pilot_positions.size > 0:
#             valid_pilots = pilot_positions[pilot_positions < total_rx_symbols]
#             data_mask[valid_pilots] = False

#         rx_data_per_mode = {mk: rx_symbols_per_mode[mk][data_mask] for mk in self.spatial_modes}
#         data_lengths = [len(v) for v in rx_data_per_mode.values()]
#         if len(set(data_lengths)) > 1:
#             warnings.warn("Uneven data counts across modes; truncating to minimum length.")
#             min_len = min(data_lengths)
#             for mk in self.spatial_modes:
#                 rx_data_per_mode[mk] = rx_data_per_mode[mk][:min_len]
#         if data_lengths and data_lengths[0] == 0:
#             raise ValueError("No data symbols available after removing pilots.")

#         Y_data = np.vstack([rx_data_per_mode[mk] for mk in self.spatial_modes])
#         N_data = Y_data.shape[1]
#         if verbose:
#             print(f"   Data symbols per mode: {N_data}")

#         if verbose:
#             print("5) Equalization")
#         H = H_est.copy()
#         try:
#             cond_H = np.linalg.cond(H)
#         except Exception:
#             cond_H = np.inf
#         if self.eq_method == "auto":
#             use_mmse = cond_H > 1e4
#         else:
#             use_mmse = (self.eq_method == "mmse")

#         if not use_mmse:
#             try:
#                 W_zf = inv(H)
#                 S_est = W_zf @ Y_data
#             except np.linalg.LinAlgError:
#                 warnings.warn("ZF inversion failed; switching to pseudo-inverse.")
#                 W_zf = pinv(H)
#                 S_est = W_zf @ Y_data
#         else:
#             sigma2 = max(noise_var, 1e-12)
#             try:
#                 W_mmse = inv(H.conj().T @ H + sigma2 * np.eye(self.n_modes)) @ H.conj().T
#                 S_est = W_mmse @ Y_data
#             except np.linalg.LinAlgError:
#                 warnings.warn("MMSE matrix inversion failed; fallback to pinv(H).")
#                 W_mmse = pinv(H).conj().T
#                 S_est = W_mmse @ Y_data

#         if verbose:
#             print(f"   Equalized symbols shape: {S_est.shape} (modes x symbols)")
#             print(f"   Sample post-eq symbol (mode 0, first 5): {S_est[0, :5]}")

#         if verbose:
#             print("6) Demodulation (QPSK)")
#         s_est_flat = S_est.flatten()
#         IDEAL_THRESHOLD = 1e-4
#         if noise_var < IDEAL_THRESHOLD:
#             if verbose:
#                 print("   Low noise: hard decisions.")
#             received_bits = self.qpsk.demodulate_hard(s_est_flat)
#             llrs = None
#         else:
#             if verbose:
#                 print("   Using soft LLRs for demodulation.")
#             llrs = self.qpsk.demodulate_soft(s_est_flat, noise_var)
#             received_bits = (llrs < 0).astype(int)

#         if verbose:
#             print(f"   Demodulated coded bits: {len(received_bits)}")

#         if verbose:
#             print("7) LDPC decoding")
#         decoded_info_bits = np.array([], dtype=int)
#         if bypass_ldpc or (self.ldpc is None):
#             # bypass or ldpc not available: return coded bits directly (useful for demo)
#             decoded_info_bits = received_bits.copy()
#             if verbose:
#                 print(f"   LDPC bypass: returning {len(decoded_info_bits)} bits (no decoding)")
#         else:
#             try:
#                 if llrs is not None:
#                     decoded_info_bits = self.ldpc.decode_bp(llrs)
#                     if verbose:
#                         print(f"   Decoded info bits (BP): {len(decoded_info_bits)}")
#                 else:
#                     decoded_info_bits = self.ldpc.decode_hard(received_bits)
#                     if verbose:
#                         print(f"   Decoded info bits (hard): {len(decoded_info_bits)}")
#             except Exception as e:
#                 warnings.warn(f"LDPC decode failed: {e}; falling back to hard bits.")
#                 decoded_info_bits = received_bits

#         if verbose:
#             print("8) Performance metrics (BER)")
#         orig = np.asarray(original_data_bits, dtype=int)
#         L_orig = len(orig)
#         L_rec = len(decoded_info_bits)
#         compare_len = min(L_orig, L_rec)
#         if compare_len == 0 and L_orig > 0:
#             bit_errors = L_orig
#             ber = 1.0
#         else:
#             trimmed_orig = orig[:compare_len]
#             trimmed_rec = decoded_info_bits[:compare_len]
#             bit_errors_common = np.sum(trimmed_orig != trimmed_rec) if compare_len > 0 else 0
#             len_mismatch = abs(L_orig - L_rec)
#             bit_errors = int(bit_errors_common + len_mismatch)
#             ber = bit_errors / L_orig if L_orig > 0 else 0.0

#         self.metrics = {
#             "H_est": H_est,
#             "noise_var": noise_var,
#             "bit_errors": int(bit_errors),
#             "total_bits": int(L_orig),
#             "ber": float(ber),
#             "n_data_symbols": int(N_data),
#             "n_modes": int(self.n_modes),
#             "cond_H": float(np.linalg.cond(H_est))
#         }

#         if verbose:
#             print(f"   Original bits: {L_orig}, Decoded bits: {L_rec}, Errors: {bit_errors}, BER={ber:.3e}")
#             print("=" * 72)

#         return decoded_info_bits, self.metrics


# # ---------------------------
# # Plot helper
# # ---------------------------
# def plot_constellation(rx_symbols, title="Received Constellation"):
#     plt.figure(figsize=(5, 5))
#     plt.plot(np.real(rx_symbols), np.imag(rx_symbols), ".", alpha=0.6)
#     plt.axhline(0, color="grey")
#     plt.axvline(0, color="grey")
#     plt.title(title)
#     plt.xlabel("I")
#     plt.ylabel("Q")
#     plt.axis("equal")
#     plt.grid(True)
#     plt.show()


# # ---------------------------
# # Demo main (rectified)
# # ---------------------------
# if __name__ == "__main__":
#     parser = argparse.ArgumentParser(description="Receiver demo for FSO-MDM OAM (lightweight).")
#     parser.add_argument("--realistic", action="store_true", help="Do NOT orthonormalize spatial refs (use realistic, possibly ill-conditioned refs).")
#     parser.add_argument("--debug", action="store_true", help="Enable a few debug prints.")
#     parser.add_argument("--use-ldpc", action="store_true", help="Try to run LDPC decoding (requires transmitter LDPC match).")
#     parser.add_argument("--snr", type=float, default=25.0, help="SNR in dB for synthesized AWGN.")
#     parser.add_argument("--nframes", type=int, default=256, help="Number of spatial frames to synthesize.")
#     args, unknown = parser.parse_known_args()

#     # lightweight synth function used by demo (vectorized)
#     def synthesize_spatial_sequence(tx_frame, spatial_modes, demuxer, n_samples=256, snr_db=30, apply_turb=False, angular_prop=None, orthonormalize_refs=True):
#         grid_info = tx_frame.grid_info
#         X, Y, delta, x, y = reconstruct_grid_from_gridinfo(grid_info)
#         N = X.shape[0]
#         M = len(spatial_modes)

#         lengths = [len(tx_frame.tx_signals[mk]['symbols']) for mk in spatial_modes]
#         total_len = min(lengths)
#         T = min(total_len, n_samples)

#         # fetch reference fields
#         refs_list = []
#         for mk in spatial_modes:
#             sig = tx_frame.tx_signals.get(mk, {})
#             beam_obj = sig.get("beam", None)
#             rf = demuxer.reference_field(mk, X, Y, delta, grid_z=demuxer.z_distance, tx_beam_obj=beam_obj)
#             refs_list.append(rf)
#         refs_stack = np.stack(refs_list, axis=0)  # (M,N,N)

#         # Optionally orthonormalize w.r.t area-weighted inner product
#         if orthonormalize_refs:
#             V = refs_stack.reshape(M, -1).T  # (N*N, M)
#             Q, Rq = np.linalg.qr(V, mode='reduced')
#             ortho_stack = Q.T.reshape(M, N, N).copy()
#             # area normalize
#             for m in range(M):
#                 col = ortho_stack[m]
#                 norm = np.sqrt(np.sum(np.abs(col)**2) * (delta**2))
#                 if norm == 0:
#                     raise RuntimeError("Zero-energy reference after orthonormalization")
#                 ortho_stack[m] = col / norm
#             refs_stack = ortho_stack
#             # refill demux cache with orthonormal refs
#             for mi, mk in enumerate(spatial_modes):
#                 key = demuxer._make_ref_key(mk, N, delta, X.shape)
#                 demuxer._ref_cache[key] = refs_stack[mi].copy()

#         # prepare tx symbol matrix
#         tx_sym_matrix = np.zeros((M, T), dtype=np.complex128)
#         for mi, mk in enumerate(spatial_modes):
#             tx_sym_matrix[mi, :] = np.asarray(tx_frame.tx_signals[mk]['symbols'])[:T]

#         rx_fields = np.zeros((T, N, N), dtype=np.complex128)
#         snr_lin = 10.0 ** (snr_db / 10.0)

#         for t in range(T):
#             s_vec = tx_sym_matrix[:, t]
#             E = np.tensordot(s_vec, refs_stack, axes=(0, 0))
#             if apply_turb and (angular_prop is not None):
#                 try:
#                     E = angular_prop(E, delta, demuxer.wavelength, demuxer.z_distance)
#                 except Exception as e:
#                     print("  Warning: angular_prop failed; continuing: ", e)

#             total_power = np.sum(np.abs(E)**2) * (delta**2)
#             if total_power <= 0:
#                 total_power = 1.0
#             E = E / np.sqrt(total_power)

#             noise_power_total = 1.0 / max(snr_lin, 1e-12)
#             noise_var_per_sample = noise_power_total / (N * N)
#             sigma = np.sqrt(noise_var_per_sample / 2.0)
#             noise = sigma * (np.random.randn(N, N) + 1j * np.random.randn(N, N))

#             E_noisy = E + noise
#             rx_fields[t] = E_noisy

#         return rx_fields, tx_sym_matrix, refs_stack

#     # ---- Build conservative demo tx_frame ----
#     from encoding import FSO_MDM_Frame

#     N_demo = 256
#     D_demo = 0.45
#     x = np.linspace(-D_demo/2, D_demo/2, N_demo)
#     y = x.copy()
#     grid_info = {"x": x, "y": y, "N": N_demo, "D": D_demo}

#     spatial_modes = [(0,1),(0,-1),(0,2),(0,-2)]
#     frame_len = 512
#     pilot_count = 64
#     pilot_positions = np.arange(pilot_count)
#     rng = np.random.RandomState(42)
#     const = np.array([1+1j, -1+1j, -1-1j, 1-1j], dtype=complex)/np.sqrt(2)

#     tx_signals = {}
#     M = len(spatial_modes)
#     # create orthogonal pilots (DFT rows across pilot_count) to give well-conditioned P_p
#     pilots_matrix = np.exp(1j * 2.0 * np.pi * (np.arange(M)[:, None] * np.arange(pilot_count)[None, :]) / float(pilot_count))

#     for mi, mk in enumerate(spatial_modes):
#         syms = const[rng.randint(0,4, frame_len)].astype(complex)
#         syms[pilot_positions] = pilots_matrix[mi, :]
#         tx_signals[mk] = {"symbols": syms, "pilot_positions": list(pilot_positions), "n_symbols": int(len(syms)), "beam": None}

#     tx_frame = FSO_MDM_Frame(tx_signals, multiplexed_field=None, grid_info=grid_info, metadata={})

#     spatial_modes = list(tx_frame.metadata["spatial_modes"])
#     wavelength = getattr(tx_frame, "wavelength", 1550e-9)
#     w0 = getattr(tx_frame, "w0", 25e-3)
#     z_distance = getattr(tx_frame, "z_distance", 0.0)

#     demux = OAMDemultiplexer(spatial_modes, wavelength, w0, z_distance, angular_prop_func=angular_spectrum_propagation)
#     class _StubPilot:
#         def __init__(self, pos):
#             self.pilot_positions = np.asarray(pos, dtype=int)
#     pilot_handler = _StubPilot(tx_frame.metadata.get("pilot_positions", list(pilot_positions)))

#     ldpc_inst = getattr(tx_frame, "ldpc", None) if args.use_ldpc else None
#     fsorx = FSORx(spatial_modes, wavelength, w0, z_distance, pilot_handler, ldpc_instance=ldpc_inst, eq_method="auto")

#     print(f"\nSynthesizing {args.nframes} spatial frames at SNR={args.snr} dB ...")
#     rx_fields, tx_sym_matrix, refs_stack = synthesize_spatial_sequence(tx_frame, spatial_modes, demux, n_samples=args.nframes, snr_db=args.snr, apply_turb=False, angular_prop=angular_spectrum_propagation, orthonormalize_refs=not args.realistic)

#     # quick Gram check
#     M = len(spatial_modes)
#     V_ortho = refs_stack.reshape(M, -1)
#     G = V_ortho @ V_ortho.conj().T * ((x[1]-x[0])**2)
#     print("\nOrthonormalized spatial Gram (real-rounded):")
#     print(np.round(G.real, 6))

#     # noiseless sanity test
#     try:
#         s0 = tx_sym_matrix[:, 0]
#         E0 = np.tensordot(s0, refs_stack, axes=(0, 0))
#         proj0 = demux.project_field(E0, tx_frame.grid_info, tx_frame=tx_frame)
#         proj_vec = np.array([proj0[mk] for mk in spatial_modes])
#         print("\nNoiseless test: TX s0:", s0)
#         print("Projected s0:", proj_vec)
#     except Exception as e:
#         print("Noiseless sanity check failed:", e)

#     # Build original coded bits from TX symbols via QPSK hard demap (option A)
#     qpsk = QPSKModulator(symbol_energy=1.0)

#     # Determine pilot positions (same as receiver)
#     pilot_positions = np.asarray(tx_frame.metadata.get("pilot_positions", list(pilot_positions)), dtype=int)
#     T_sent = tx_sym_matrix.shape[1]

#     # Build a data mask: True for data, False for pilot (same logic used by FSORx)
#     data_mask = np.ones(T_sent, dtype=bool)
#     if pilot_positions.size > 0:
#         valid_pilots = pilot_positions[pilot_positions < T_sent]
#         data_mask[valid_pilots] = False

#     # Extract tx data symbols (mode-major order) and flatten
#     tx_data_symbols = tx_sym_matrix[:, data_mask]   # shape (M, N_data)
#     tx_data_flat = tx_data_symbols.flatten(order='C')  # mode-major flatten

#     # Demap to coded bits (these are the "original_data_bits" for receiver BER calc)
#     original_coded_bits = qpsk.demodulate_hard(tx_data_flat)

#     if args.debug:
#         print("\nDEBUG: tx_data_symbols_flat (first 8):", tx_symbols_flat[:8])
#         print("DEBUG: original_coded_bits (len):", len(original_coded_bits))

#     # Run the receiver (bypass LDPC by default in demo)
#     print("Running FSORx.receive_frame() ...")
#     decoded_bits, metrics = fsorx.receive_frame(rx_fields, tx_frame, original_coded_bits, verbose=not args.debug, bypass_ldpc=True)

#     # diagnostics plots (safe-guarded)
#     try:
#         plt.figure(figsize=(6,5))
#         plt.title("Example received intensity (frame 0)")
#         plt.imshow(np.abs(rx_fields[0])**2, origin="lower")
#         plt.colorbar(label="Intensity [a.u.]")
#         plt.show()
#     except Exception:
#         pass

#     try:
#         rx_symbols = demux.extract_symbols_sequence(rx_fields, tx_frame.grid_info, tx_frame=tx_frame)
#         first_mode = spatial_modes[0]
#         sample_syms = rx_symbols[first_mode][:128]
#         plot_constellation(sample_syms, title=f"Projected symbols (mode {first_mode})")
#     except Exception:
#         pass

#     print("\nSanity check complete. Receiver metrics (if available):")
#     for k, v in metrics.items():
#         if hasattr(v, "shape"):
#             print(f"  {k}: shape {getattr(v,'shape', None)}")
#         else:
#             print(f"  {k}: {v}")

# In [16]:
# receiver.py -- Receiver for FSO-MDM OAM system (rectified to match rectified encoder)
# Requirements: numpy, scipy, matplotlib, encoding.py, turbulence.py, lgBeam.py (optional but preferred)
import os
import sys
import warnings
import numpy as np
import matplotlib.pyplot as plt
from scipy.linalg import inv, pinv
from scipy.fft import fft2, ifft2
from typing import Dict, Tuple, Any, Optional

# script dir resolution (same pattern used in other modules)
try:
    SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
except NameError:
    SCRIPT_DIR = os.getcwd()
sys.path.insert(0, SCRIPT_DIR)

# imports from your rectified modules
try:
    from encoding import QPSKModulator, PilotHandler, PyLDPCWrapper, FSO_MDM_Frame
except Exception as e:
    raise ImportError(f"receiver.py requires encoding.py in the same directory: {e}")

try:
    from turbulence import angular_spectrum_propagation
except Exception as e:
    angular_spectrum_propagation = None
    warnings.warn(f"turbulence.angular_spectrum_propagation not available: {e}")

# lgBeam may or may not be needed (we prefer using beam instances attached to frame)
try:
    from lgBeam import LaguerreGaussianBeam
except Exception:
    LaguerreGaussianBeam = None

# NOTE(review): filterwarnings("ignore") with the default category (Warning,
# the base of every warning class) inserts a process-wide "ignore" filter at
# import time. It affects all code running in this interpreter, not only this
# module. It also silences every warnings.warn(...) this module issues after
# this point, e.g. the non-square-sampling warning in
# reconstruct_grid_from_gridinfo. Consider removing it or scoping it with
# warnings.catch_warnings() at the call sites that are actually noisy.
warnings.filterwarnings("ignore")


# -------------------------------------------------------
# Utilities: grid reconstruction, normalization helpers
# -------------------------------------------------------
def reconstruct_grid_from_gridinfo(grid_info: Dict[str, Any]):
    """
    Rebuild the 2-D sampling grid described by ``grid_info``.

    grid_info expected keys (from encoding._generate_spatial_field):
      - x: 1D array of sample coordinates along the first grid axis
      - y: 1D array of sample coordinates along the second grid axis
      - grid_size or extent_m (not mandatory; not read here)

    Returns:
      X, Y  : coordinate arrays from np.meshgrid(x, y, indexing="ij"),
              so X varies along axis 0 and Y along axis 1
      delta : float sample spacing, the mean spacing of ``x``
      x, y  : the coordinate arrays converted with np.asarray

    Raises:
      ValueError: if grid_info is None, if x/y are missing or empty, if either
        axis has fewer than two samples, or if the x spacing is zero or
        non-finite (no usable area element delta**2 can be formed).
    """
    if grid_info is None:
        raise ValueError("grid_info is required to reconstruct spatial grid.")
    raw_x = grid_info.get("x")
    raw_y = grid_info.get("y")
    # np.asarray(None) is a 0-d object array of size 1, so a missing key would
    # slip past a plain size check and fail later inside np.diff; test for None.
    if raw_x is None or raw_y is None:
        raise ValueError("grid_info.x/y empty or missing.")
    x = np.asarray(raw_x)
    y = np.asarray(raw_y)
    if x.size == 0 or y.size == 0:
        raise ValueError("grid_info.x/y empty or missing.")
    # A single sample gives an empty np.diff, whose mean is NaN; reject it up
    # front instead of returning delta=nan.
    if x.size < 2 or y.size < 2:
        raise ValueError("grid_info.x/y must each contain at least two samples.")
    X, Y = np.meshgrid(x, y, indexing="ij")
    # sampling interval (assume uniform)
    delta_x = np.mean(np.diff(x))
    delta_y = np.mean(np.diff(y))
    if not np.isfinite(delta_x) or delta_x == 0.0:
        raise ValueError("grid_info.x has zero or non-finite sample spacing.")
    if not np.isclose(delta_x, delta_y, rtol=1e-6, atol=0.0):
        warnings.warn("Non-square sampling intervals detected; using delta_x as delta.")
    delta = float(delta_x)
    return X, Y, delta, x, y


def energy_normalize_field(field: np.ndarray, delta: float):
    """
    Scale ``field`` to unit area-weighted power.

    Power is sum(|field|^2) * delta^2. When that power is not strictly
    positive (an all-zero field, delta == 0, or NaN entries), the input
    object is returned unchanged; otherwise a new scaled array is returned.
    """
    total_power = np.sum(np.abs(field) ** 2) * delta ** 2
    if not total_power > 0:
        # Nothing sensible to normalize by: hand back the original object.
        return field
    return field / np.sqrt(total_power)


# -------------------------------------------------------
# OAM Demultiplexer (mode projection)
# -------------------------------------------------------
class OAMDemultiplexer:
    """
    Project received complex fields onto reference LG modes.

    For every mode in ``spatial_modes`` a reference field is generated at z=0, either from the
    transmitter's beam object (``tx_frame.tx_signals[mode]["beam"]``) when present or from
    ``LaguerreGaussianBeam(p, l, wavelength, w0)``. It is then propagated to ``z_distance``
    with ``angular_prop_func`` (skipped when that is None or z == 0) and cached. The received
    field is projected onto each (optionally aperture-masked) reference, and the projection is
    normalized by the reference energy.
    """

    def __init__(self, spatial_modes, wavelength, w0, z_distance, angular_prop_func=angular_spectrum_propagation):
        self.spatial_modes = list(spatial_modes)
        self.n_modes = len(self.spatial_modes)
        self.wavelength = wavelength
        self.w0 = w0
        self.z_distance = z_distance
        self.angular_prop = angular_prop_func
        # cache of propagated, aperture-unmasked reference fields (key built by _make_ref_key)
        self._ref_cache = {}

    def _make_ref_key(self, mode_key, N, delta, X_shape, grid_z=None, origin=None):
        """
        Build a hashable cache key for a reference field.

        ``grid_z`` (propagation distance) and ``origin`` (first grid sample, (X[0,0], Y[0,0]))
        are part of the key, so a reference propagated to another distance, or sampled on a
        shifted grid of identical size and spacing, is not served from the cache. Both default
        to None for callers using the original four-argument form.
        """
        z_part = None if grid_z is None else float(grid_z)
        origin_part = None if origin is None else (float(origin[0]), float(origin[1]))
        return (mode_key, int(N), float(delta), int(X_shape[0]), int(X_shape[1]), z_part, origin_part)

    def reference_field(self, mode_key: Tuple[int, int], X, Y, delta, grid_z, tx_beam_obj=None):
        """
        Return the reference field for ``mode_key`` on grid (X, Y), propagated to ``grid_z``.

        Args:
            mode_key: (p, l) indices; unpacked only when no ``tx_beam_obj`` is given.
            X, Y: 2D coordinate grids.
            delta: grid sampling interval, forwarded to the propagation function.
            grid_z: propagation distance. No propagation when 0.0 or when no propagation
                function is configured.
            tx_beam_obj: optional beam object exposing ``generate_beam_field(R, PHI, z)``.

        Returns:
            The reference field. The cache keeps its own copy, so callers may modify the result.

        Raises:
            RuntimeError: if no beam object is given and ``LaguerreGaussianBeam`` is unavailable.

        NOTE: the cache key does not identify the beam object. The first reference built for a
        given mode/grid/z is reused on later calls regardless of ``tx_beam_obj``.
        """
        origin = (X.flat[0], Y.flat[0]) if np.size(X) else None
        key = self._make_ref_key(mode_key, X.shape[0], delta, X.shape, grid_z=grid_z, origin=origin)
        cached = self._ref_cache.get(key)
        if cached is not None:
            return cached.copy()

        # polar coordinates for the beam generator
        R = np.sqrt(X ** 2 + Y ** 2)
        PHI = np.arctan2(Y, X)

        beam = tx_beam_obj
        if beam is None:
            # fallback: build an LG beam locally when the transmitter did not provide one
            p, l = mode_key
            if LaguerreGaussianBeam is None:
                raise RuntimeError("No beam instance available and lgBeam import missing.")
            beam = LaguerreGaussianBeam(p, l, self.wavelength, self.w0)

        ref_z0 = beam.generate_beam_field(R, PHI, 0.0)
        if self.angular_prop is None or grid_z == 0.0:
            ref = ref_z0
        else:
            ref = self.angular_prop(ref_z0.copy(), delta, self.wavelength, grid_z)

        self._ref_cache[key] = ref.copy()
        return ref

    @staticmethod
    def _beam_from_frame(tx_frame, mode_key):
        """Return ``tx_frame.tx_signals[mode_key]["beam"]`` if present, else None."""
        tx_signals = getattr(tx_frame, "tx_signals", None)
        if tx_signals is None:
            return None
        sig = tx_signals.get(mode_key)
        return None if sig is None else sig.get("beam", None)

    @staticmethod
    def _as_complex_field(E_rx):
        """Return E_rx unchanged if complex; otherwise treat it as intensity and use sqrt(|I|) with zero phase."""
        if not np.iscomplexobj(E_rx):
            warnings.warn("E_rx appears to be real (intensity). Assuming zero-phase sqrt(I) field for projection.")
            E_rx = np.sqrt(np.abs(E_rx)).astype(np.complex128)
        return E_rx

    def _prepare_references(self, X, Y, delta, receiver_radius, tx_frame):
        """
        Build per-mode projection data for one grid.

        Returns a dict mapping each mode to ``(conj(aperture-masked reference), reference energy)``,
        or to None when the reference cannot be built or has (near-)zero energy inside the aperture.
        """
        R = np.sqrt(X ** 2 + Y ** 2)
        dA = float(delta ** 2)
        if receiver_radius is not None:
            aperture_mask = (R <= receiver_radius).astype(float)
        else:
            aperture_mask = np.ones_like(R, dtype=float)

        refs = {}
        for mode_key in self.spatial_modes:
            tx_beam_obj = self._beam_from_frame(tx_frame, mode_key)
            try:
                ref = self.reference_field(mode_key, X, Y, delta, grid_z=self.z_distance, tx_beam_obj=tx_beam_obj)
            except Exception as e:
                # best effort: a mode without a usable reference projects to 0 instead of aborting the frame
                warnings.warn(f"Cannot build reference for mode {mode_key}: {e}; returning 0 projection.")
                refs[mode_key] = None
                continue
            ref_ap = ref * aperture_mask
            ref_energy = np.sum(np.abs(ref_ap) ** 2) * dA
            refs[mode_key] = (np.conj(ref_ap), ref_energy) if ref_energy > 1e-20 else None
        return refs

    @staticmethod
    def _project_with_refs(E_rx, refs, dA):
        """Project complex field E_rx onto prepared references; modes with no reference give 0."""
        symbols = {}
        for mode_key, entry in refs.items():
            if entry is None:
                symbols[mode_key] = 0.0 + 0.0j
            else:
                conj_ref, ref_energy = entry
                symbols[mode_key] = np.sum(E_rx * conj_ref) * dA / ref_energy
        return symbols

    def project_field(self, E_rx, grid_info, receiver_radius=None, tx_frame=None):
        """
        Project a single field onto every mode in ``self.spatial_modes``.

        Args:
            E_rx: 2D received field. A real array is treated as intensity (see _as_complex_field).
            grid_info: grid description consumed by ``reconstruct_grid_from_gridinfo``.
            receiver_radius: optional circular aperture radius (same units as the grid).
            tx_frame: optional transmitter frame; its ``tx_signals[mode]["beam"]`` is used when present.

        Returns:
            dict mode_key -> complex projection coefficient (inner product divided by reference energy).
        """
        X, Y, delta, _, _ = reconstruct_grid_from_gridinfo(grid_info)
        E_rx = self._as_complex_field(E_rx)
        refs = self._prepare_references(X, Y, delta, receiver_radius, tx_frame)
        return self._project_with_refs(E_rx, refs, float(delta ** 2))

    def extract_symbols_sequence(self, E_rx_sequence, grid_info, receiver_radius=None, tx_frame=None):
        """
        Project a sequence of fields onto every mode.

        Accepts E_rx_sequence as:
          - list/np.ndarray of 2D complex fields (n_frames x N x N)
          - or a single 2D field (treated as one frame)
        Returns symbols_per_mode: dict mode_key -> complex array of length n_frames.

        The grid, aperture mask and reference fields are built once for the whole sequence
        rather than once per frame.
        """
        seq = np.asarray(E_rx_sequence)
        if seq.ndim == 2:
            seq = seq[np.newaxis, ...]  # shape (1, N, N)
        n_frames = seq.shape[0]
        symbols_per_mode = {mode: np.zeros(n_frames, dtype=complex) for mode in self.spatial_modes}
        if n_frames == 0:
            return symbols_per_mode

        X, Y, delta, _, _ = reconstruct_grid_from_gridinfo(grid_info)
        refs = self._prepare_references(X, Y, delta, receiver_radius, tx_frame)
        dA = float(delta ** 2)
        for i, frame in enumerate(seq):
            snapshot = self._project_with_refs(self._as_complex_field(frame), refs, dA)
            for mode in self.spatial_modes:
                symbols_per_mode[mode][i] = snapshot.get(mode, 0.0 + 0.0j)
        return symbols_per_mode


# -------------------------------------------------------
# Channel estimator (LS + optional MMSE fallback)
# -------------------------------------------------------
class ChannelEstimator:
    """
    Pilot-based least-squares MIMO channel estimator over the M spatial modes.

    Pilot matrices are built from ``tx_frame.tx_signals[mode]`` ("pilot_positions",
    "pilot_sequence", "symbols") and the per-mode received symbols. The model is
    Y_p ≈ H @ P_p, with Y_p and P_p both M x n_p.
    """

    def __init__(self, pilot_handler: PilotHandler, spatial_modes):
        self.pilot_handler = pilot_handler
        self.spatial_modes = list(spatial_modes)
        self.M = len(self.spatial_modes)
        self.H_est = None
        self.noise_var_est = None

    @staticmethod
    def _as_positions(values):
        """Convert a pilot-position container (or None) to a 1D int array."""
        return np.asarray([] if values is None else values, dtype=int).ravel()

    @staticmethod
    def _tx_pilot_values(sig, positions):
        """
        Return the transmitted pilot values of one mode at global frame indices ``positions``.

        A position listed in ``sig["pilot_positions"]`` takes the matching ``sig["pilot_sequence"]``
        entry. Otherwise it falls back to ``sig["symbols"][position]``, or to 0 when out of range.
        """
        tx_syms = sig.get("symbols", np.array([], dtype=complex))
        if tx_syms is None:
            tx_syms = np.array([], dtype=complex)
        tx_pseq = sig.get("pilot_sequence", None)
        tx_ppos = ChannelEstimator._as_positions(sig.get("pilot_positions", []))

        pos_to_idx = {}
        if tx_pseq is not None and tx_ppos.size > 0:
            n_seq = len(tx_pseq)
            # ignore positions without a matching sequence entry (avoids IndexError on short sequences)
            pos_to_idx = {int(p): i for i, p in enumerate(tx_ppos.tolist()) if i < n_seq}

        vals = np.zeros(len(positions), dtype=complex)
        for j, p in enumerate(positions):
            p = int(p)
            if p in pos_to_idx:
                vals[j] = tx_pseq[pos_to_idx[p]]
            elif p < len(tx_syms):
                vals[j] = tx_syms[p]
        return vals

    def _gather_pilots(self, rx_symbols_per_mode: Dict[Tuple[int, int], np.ndarray],
                       tx_frame: FSO_MDM_Frame):
        """
        Build Y_pilot (M x n_p) and P_pilot (M x n_p) on the pilot indices shared by all modes.

        The columns are the intersection of the per-mode pilot positions. If that is empty, the
        first mode's positions are used with a warning. Positions outside [0, shortest received
        sequence) are dropped.

        Returns:
            (Y_p, P_p, valid_positions), or (None, None, empty int array) when no pilot matrix can
            be formed (no tx_signals, a mode without TX info, or no in-range positions).
        """
        tx_signals = getattr(tx_frame, "tx_signals", None)
        if tx_signals is None:
            return None, None, np.array([], dtype=int)

        per_mode_positions = {}
        for mk in self.spatial_modes:
            sig = tx_signals.get(mk)
            if sig is None:
                # one mode without TX info makes a consistent pilot matrix impossible
                return None, None, np.array([], dtype=int)
            per_mode_positions[mk] = self._as_positions(sig.get("pilot_positions", []))
        if not per_mode_positions:
            return None, None, np.array([], dtype=int)

        common = set.intersection(*(set(pos.tolist()) for pos in per_mode_positions.values()))
        if common:
            valid_positions = np.asarray(sorted(common), dtype=int)
        else:
            warnings.warn("No common pilot positions across modes. Falling back to first mode pilot positions (may be invalid).")
            first = self.spatial_modes[0]
            valid_positions = np.asarray(sorted(per_mode_positions[first].tolist()), dtype=int)

        # keep only indices valid for every received sequence (negative indices would wrap around)
        min_rx_len = min(len(rx_symbols_per_mode[mk]) for mk in self.spatial_modes)
        valid_positions = valid_positions[(valid_positions >= 0) & (valid_positions < min_rx_len)]
        if valid_positions.size == 0:
            return None, None, np.array([], dtype=int)

        n_p = len(valid_positions)
        Y_p = np.zeros((self.M, n_p), dtype=complex)
        P_p = np.zeros((self.M, n_p), dtype=complex)
        for idx, mk in enumerate(self.spatial_modes):
            Y_p[idx, :] = np.asarray(rx_symbols_per_mode[mk])[valid_positions]
            P_p[idx, :] = self._tx_pilot_values(tx_signals[mk], valid_positions)
        return Y_p, P_p, valid_positions

    def _identity_estimate(self):
        """Store and return the identity channel (used when no usable pilots exist)."""
        self.H_est = np.eye(self.M, dtype=complex)
        return self.H_est

    def estimate_channel_ls(self, rx_symbols_per_mode: Dict[Tuple[int, int], np.ndarray], tx_frame: FSO_MDM_Frame):
        """
        Least-squares estimate H = Y_p P_p^H (P_p P_p^H)^{-1}, with a small ridge when P_p P_p^H is
        ill-conditioned (cond > 1e6).

        Returns the M x M estimate, also stored in ``self.H_est``. Returns the identity when no
        pilots are available or when the pilot matrix is all zeros or non-finite.
        """
        Y_p, P_p, _ = self._gather_pilots(rx_symbols_per_mode, tx_frame)
        if Y_p is None or P_p is None or P_p.size == 0:
            warnings.warn("No valid pilots found for LS channel estimation. Returning identity H.")
            return self._identity_estimate()

        pilot_power = float(np.sum(np.abs(P_p) ** 2))
        if not np.isfinite(pilot_power) or pilot_power <= 0.0:
            # zero/non-finite pilots carry no channel information; an LS "solution" would be ~0
            warnings.warn("Pilot matrix is all zeros or non-finite; returning identity H.")
            return self._identity_estimate()

        try:
            PPH = P_p @ P_p.conj().T  # M x M, Hermitian PSD
            cond_pph = np.linalg.cond(PPH) if np.all(np.isfinite(PPH)) else np.inf
            if not np.isfinite(cond_pph) or cond_pph > 1e6:
                # ridge scaled by the mean diagonal; the trace of a Hermitian matrix is real
                reg = 1e-6 * np.trace(PPH).real / max(1.0, PPH.shape[0])
                PPH_reg = PPH + reg * np.eye(PPH.shape[0])
                try:
                    invPPH = inv(PPH_reg)
                except (np.linalg.LinAlgError, ValueError):
                    invPPH = pinv(PPH_reg)
            else:
                invPPH = inv(PPH)
            H = Y_p @ P_p.conj().T @ invPPH
        except (np.linalg.LinAlgError, ValueError) as e:
            warnings.warn(f"Pilot LS solve failed ({e}); using pseudo-inverse fallback.")
            try:
                H = Y_p @ pinv(P_p)
            except (np.linalg.LinAlgError, ValueError) as e2:
                warnings.warn(f"Pseudo-inverse fallback failed ({e2}); returning identity H.")
                H = np.eye(self.M, dtype=complex)

        H = np.asarray(H, dtype=complex)
        nan_mask = ~np.isfinite(H)
        if np.any(nan_mask):
            warnings.warn("Channel estimate contains non-finite entries; sanitizing to zeros.")
            H[nan_mask] = 0.0 + 0.0j

        # if H is extremely ill-conditioned, nudge it slightly towards identity
        try:
            cond_H = np.linalg.cond(H)
            if not np.isfinite(cond_H) or cond_H > 1e12:
                warnings.warn(f"H estimate extremely ill-conditioned (cond={cond_H:.2e}). Regularizing H slightly.")
                alpha = 1e-6
                H = (1.0 - alpha) * H + alpha * np.eye(self.M, dtype=complex)
        except np.linalg.LinAlgError:
            pass  # conditioning is only a diagnostic here; keep the estimate as-is

        self.H_est = H
        return H

    def estimate_noise_variance(self, rx_symbols_per_mode: Dict[Tuple[int, int], np.ndarray],
                                tx_frame: FSO_MDM_Frame, H_est: np.ndarray,
                                dof_correction: bool = True):
        """
        Estimate the per-symbol noise variance from pilot residuals Y_p - H_est P_p.

        When ``H_est`` was fitted by LS on these same pilots, each of the M residual rows lives
        in an (n_p - M)-dimensional subspace, so the plain mean is biased low by (n_p - M)/n_p.
        It is exactly zero when n_p == M. With ``dof_correction`` (default), the residual energy
        is divided by M*(n_p - M) when n_p > M. Set it to False if ``H_est`` was not fitted on
        these pilots.

        Returns the estimate (floored at 1e-12), also stored in ``self.noise_var_est``. Returns
        1e-6 when no pilots are available.
        """
        Y_p, P_p, _ = self._gather_pilots(rx_symbols_per_mode, tx_frame)
        if Y_p is None or P_p is None or P_p.size == 0:
            # fallback to tiny noise var (quiet)
            self.noise_var_est = 1e-6
            return self.noise_var_est

        residual = Y_p - H_est @ P_p
        sq_err = np.abs(residual) ** 2
        n_rows, n_p = P_p.shape
        if dof_correction and n_p > n_rows:
            noise_var = np.sum(sq_err) / (n_rows * (n_p - n_rows))
        else:
            if dof_correction:
                warnings.warn("Too few pilots (n_p <= M) to estimate noise variance from LS residuals.")
            noise_var = np.mean(sq_err)
        noise_var = max(float(noise_var), 1e-12)  # floor
        self.noise_var_est = noise_var
        return self.noise_var_est


# -------------------------------------------------------
# FSORx: Full receiver pipeline (ZF/MMSE + LDPC decode)
# -------------------------------------------------------
class FSORx:
    """
    Full receiver pipeline:
    OAM projection -> LS channel estimate -> noise estimate -> pilot removal ->
    ZF/MMSE equalization -> QPSK demodulation -> LDPC decoding -> BER.

    eq_method: "zf" (default), "mmse", or "auto" (MMSE when cond(H_est) exceeds
    AUTO_MMSE_COND_THRESHOLD). Any other value is treated as ZF.
    """

    # noise variance at/above which soft LLRs (and BP decoding) are used instead of hard decisions
    SOFT_DEMOD_NOISE_THRESHOLD = 1e-4
    # condition number above which eq_method="auto" switches from ZF to MMSE
    AUTO_MMSE_COND_THRESHOLD = 1e4

    def __init__(self, spatial_modes, wavelength, w0, z_distance,
                 pilot_handler: PilotHandler,
                 ldpc_instance: Optional[PyLDPCWrapper] = None,
                 eq_method: str = "zf", receiver_radius: Optional[float] = None):
        self.spatial_modes = list(spatial_modes)
        self.n_modes = len(self.spatial_modes)
        self.wavelength = wavelength
        self.w0 = w0
        self.z_distance = z_distance
        self.pilot_handler = pilot_handler
        self.eq_method = eq_method.lower()
        self.receiver_radius = receiver_radius

        # QPSK mapping must match encoding.QPSKModulator
        self.qpsk = QPSKModulator(symbol_energy=1.0)

        # LDPC: prefer the transmitter's instance so the parity matrix matches exactly
        if ldpc_instance is not None:
            self.ldpc = ldpc_instance
        else:
            try:
                self.ldpc = PyLDPCWrapper(n=2048, rate=0.8, dv=2, dc=8, seed=42)
                warnings.warn("No LDPC instance provided; receiver created local PyLDPCWrapper that may not match TX.")
            except Exception as e:
                raise RuntimeError(f"Cannot construct LDPC wrapper. Provide ldpc_instance from transmitter. Error: {e}")

        self.demux = OAMDemultiplexer(self.spatial_modes, self.wavelength, self.w0, self.z_distance)
        self.chan_est = ChannelEstimator(self.pilot_handler, self.spatial_modes)
        self.metrics = {}

    # ---------------- helpers ----------------
    @staticmethod
    def _print_channel(H_est):
        """Pretty-print |H_est| row by row and its condition number."""
        try:
            H_abs = np.abs(H_est)
            print("   H_est magnitude (rows):")
            for row in H_abs:
                print("     [" + " ".join(f"{v:.5f}" for v in row) + "]")
            condH = np.linalg.cond(H_est)
            print(f"   cond(H_est) = {condH:.2e}")
        except Exception:
            print("   H_est unavailable for pretty print.")

    def _extract_data_matrix(self, rx_symbols_per_mode, tx_frame):
        """
        Remove pilot positions and stack the remaining symbols into Y_data (M x N_data).

        The same pilot positions used for channel estimation are removed. All modes are first
        truncated to the shortest received sequence so that one data mask fits every mode.

        Raises:
            ValueError: if there are no modes, or no data symbols remain after removing pilots.
        """
        if not self.spatial_modes:
            raise ValueError("No modes present in rx_symbols_per_mode.")

        _, _, pilot_positions_used = self.chan_est._gather_pilots(rx_symbols_per_mode, tx_frame)
        if pilot_positions_used is None:
            pilot_positions_used = np.array([], dtype=int)
        pilot_positions_used = np.asarray(pilot_positions_used, dtype=int)

        lengths = [len(rx_symbols_per_mode[mk]) for mk in self.spatial_modes]
        total_rx_symbols = int(min(lengths))
        if len(set(lengths)) > 1:
            warnings.warn("Uneven data counts across modes; truncating to minimum length.")

        data_mask = np.ones(total_rx_symbols, dtype=bool)
        in_range = (pilot_positions_used >= 0) & (pilot_positions_used < total_rx_symbols)
        data_mask[pilot_positions_used[in_range]] = False

        # mode-major stacking (row i = spatial_modes[i]), matching the transmitter's flattening
        Y_data = np.vstack([np.asarray(rx_symbols_per_mode[mk])[:total_rx_symbols][data_mask]
                            for mk in self.spatial_modes])
        if Y_data.shape[1] == 0:
            raise ValueError("No data symbols available after removing pilots.")
        return Y_data

    def _equalize(self, H_est, Y_data, noise_var):
        """
        Apply ZF (inv(H)) or MMSE ((H^H H + σ² I)^{-1} H^H) to Y_data and return S_est.

        A non-finite H_est is replaced by identity before the condition number (used by "auto")
        is computed. If the inversion fails, both paths fall back to pinv(H).
        """
        H = np.array(H_est, dtype=complex)  # copy
        if not np.all(np.isfinite(H)):
            warnings.warn("H_est contains non-finite entries; replacing with identity fallback.")
            H = np.eye(self.n_modes, dtype=complex)

        try:
            cond_H = np.linalg.cond(H)
        except np.linalg.LinAlgError:
            cond_H = np.inf
        if self.eq_method == "auto":
            use_mmse = cond_H > self.AUTO_MMSE_COND_THRESHOLD
        else:
            use_mmse = (self.eq_method == "mmse")

        if not use_mmse:
            try:
                W = inv(H)
            except (np.linalg.LinAlgError, ValueError):
                warnings.warn("ZF inversion failed; switching to pseudo-inverse.")
                W = pinv(H)
        else:
            sigma2 = max(noise_var, 1e-12)
            try:
                W = inv(H.conj().T @ H + sigma2 * np.eye(self.n_modes)) @ H.conj().T
            except (np.linalg.LinAlgError, ValueError):
                warnings.warn("MMSE matrix inversion failed; fallback to pinv(H).")
                # pinv(H) is the σ²->0 limit of the MMSE filter (do not conjugate-transpose it)
                W = pinv(H)
        return W @ Y_data

    def _demodulate(self, s_est_flat, noise_var, verbose):
        """Return (use_soft, llrs or None, hard coded bits) for the flattened equalized symbols."""
        use_soft = (noise_var >= self.SOFT_DEMOD_NOISE_THRESHOLD)
        if not use_soft:
            if verbose: print("   Low noise: hard decisions.")
            return False, None, self.qpsk.demodulate_hard(s_est_flat)
        if verbose: print("   Using soft LLRs for demodulation.")
        llrs = self.qpsk.demodulate_soft(s_est_flat, noise_var)
        # hard bits from the LLR sign (negative LLR -> 1)
        return True, llrs, (llrs < 0).astype(int)

    def _decode(self, use_soft, llrs, received_bits, verbose):
        """LDPC-decode with BP on LLRs when available, else hard decoding; empty array on failure."""
        try:
            if use_soft and llrs is not None:
                decoded = self.ldpc.decode_bp(llrs)
                if verbose:
                    print(f"   Decoded info bits (BP): {len(decoded)}")
            else:
                decoded = self.ldpc.decode_hard(received_bits)
                if verbose:
                    print(f"   Decoded info bits (hard): {len(decoded)}")
            return decoded
        except Exception as e:
            # the decoder is an external wrapper; report the failure and let BER reflect it
            warnings.warn(f"LDPC decode failed ({e}); returning empty decoded bits.")
            return np.array([], dtype=int)

    @staticmethod
    def _compute_ber(original_data_bits, decoded_info_bits):
        """
        Compare decoded bits with the ground truth.

        Original bits with no decoded counterpart count as errors. Surplus decoded bits (e.g.
        codeword padding) have no original counterpart and are not counted, which keeps BER
        within [0, 1].

        Returns (L_orig, L_rec, bit_errors, ber).
        """
        orig = np.asarray(original_data_bits, dtype=int)
        decoded = np.asarray(decoded_info_bits)
        L_orig = int(len(orig))
        L_rec = int(len(decoded))
        compare_len = min(L_orig, L_rec)
        common_errors = int(np.sum(orig[:compare_len] != decoded[:compare_len])) if compare_len > 0 else 0
        missing = max(0, L_orig - L_rec)
        bit_errors = common_errors + missing
        ber = float(bit_errors) / float(L_orig) if L_orig > 0 else 0.0
        return L_orig, L_rec, bit_errors, ber

    # ---------------- main entry point ----------------
    def receive_frame(self, rx_field_sequence, tx_frame: FSO_MDM_Frame,
                      original_data_bits: np.ndarray, verbose: bool = True):
        """
        Run the full receiver on one transmitted frame.

        Args:
            rx_field_sequence: array-like of complex fields (n_frames x N x N) or a single 2D field.
            tx_frame: FSO_MDM_Frame produced by encodingRunner.transmit(...); must carry grid_info.
            original_data_bits: ground-truth info bits for the BER calculation.
            verbose: print per-stage progress.

        Returns:
            (decoded_bits, metrics dict). The metrics are also stored in ``self.metrics``.

        Raises:
            ValueError: if tx_frame.grid_info is None or no data symbols remain after pilot removal.
        """
        if verbose:
            print("\n" + "=" * 72)
            print("FSO-OAM Receiver: Start")
            print("=" * 72)

        grid_info = tx_frame.grid_info
        if grid_info is None:
            raise ValueError("tx_frame.grid_info required for demux/projection.")

        # 1) Demultiplex: projection onto reference modes
        if verbose: print("1) OAM demultiplexing (projection)...")
        rx_symbols_per_mode = self.demux.extract_symbols_sequence(rx_field_sequence, grid_info,
                                                                   receiver_radius=self.receiver_radius,
                                                                   tx_frame=tx_frame)
        if verbose:
            first_mode = self.spatial_modes[0]
            print(f"   Extracted {len(rx_symbols_per_mode[first_mode])} symbols per mode (incl. pilots).")

        # 2) Channel estimation (LS using pilots)
        if verbose: print("2) Channel estimation (LS using pilots)...")
        H_est = self.chan_est.estimate_channel_ls(rx_symbols_per_mode, tx_frame)
        if verbose:
            self._print_channel(H_est)

        # 3) Noise estimate from pilot residuals
        if verbose: print("3) Noise variance estimation...")
        noise_var = self.chan_est.estimate_noise_variance(rx_symbols_per_mode, tx_frame, H_est)
        if verbose:
            print(f"   Estimated noise variance σ² = {noise_var:.3e}")

        # 4) Separate pilots and data
        if verbose: print("4) Separate pilots and data")
        Y_data = self._extract_data_matrix(rx_symbols_per_mode, tx_frame)
        N_data = Y_data.shape[1]
        if verbose:
            print(f"   Data symbols per mode: {N_data}")

        # 5) Equalize
        if verbose: print("5) Equalization")
        S_est = self._equalize(H_est, Y_data, noise_var)
        if verbose:
            print(f"   Equalized symbols shape: {S_est.shape} (modes x symbols)")
            if S_est.size > 0:
                sample = S_est[0, :min(5, S_est.shape[1])]
                print(f"   Sample post-eq symbol (mode 0, first {len(sample)}): {sample}")

        # 6) Demodulate; row-major flatten = mode0 symbols first, matching the transmitter
        if verbose: print("6) Demodulation (QPSK)")
        use_soft, llrs, received_bits = self._demodulate(S_est.flatten(), noise_var, verbose)
        if verbose:
            print(f"   Demodulated coded bits: {len(received_bits)}")

        # 7) LDPC decode
        if verbose: print("7) LDPC decoding")
        decoded_info_bits = self._decode(use_soft, llrs, received_bits, verbose)

        # 8) BER and metrics
        if verbose: print("8) Performance metrics (BER)")
        L_orig, L_rec, bit_errors, ber = self._compute_ber(original_data_bits, decoded_info_bits)

        try:
            cond_H_metric = float(np.linalg.cond(H_est)) if H_est is not None else np.nan
            if not np.isfinite(cond_H_metric):
                cond_H_metric = float(np.inf)
        except Exception:
            cond_H_metric = float(np.inf)

        self.metrics = {
            "H_est": H_est,
            "noise_var": float(noise_var),
            "bit_errors": int(bit_errors),
            "total_bits": int(L_orig),
            "ber": float(ber),
            "n_data_symbols": int(N_data),
            "n_modes": int(self.n_modes),
            "cond_H": cond_H_metric
        }

        if verbose:
            print(f"   Original bits: {L_orig}, Decoded bits: {L_rec}, Errors: {bit_errors}, BER={ber:.3e}")
            print("=" * 72)

        return decoded_info_bits, self.metrics


# -------------------------------------------------------
# Example helper (diagnostic plotting)
# -------------------------------------------------------
def plot_constellation(rx_symbols, title="Received Constellation"):
    """
    Scatter-plot symbols on the I/Q plane and show the figure.

    rx_symbols: complex symbols (ndarray or any array-like such as a list; real values
                plot on the I axis).
    title: figure title.
    """
    # np.asarray lets plain lists work (they have no .real/.imag attributes)
    symbols = np.asarray(rx_symbols)
    plt.figure(figsize=(5, 5))
    plt.plot(symbols.real, symbols.imag, ".", alpha=0.6)
    plt.axhline(0, color="grey")
    plt.axvline(0, color="grey")
    plt.title(title)
    plt.xlabel("I")
    plt.ylabel("Q")
    plt.axis("equal")
    plt.grid(True)
    plt.show()


# ---------------------------
# If module run directly, provide a small sanity check sketch (no I/O)
# ---------------------------
if __name__ == "__main__":
    # Nothing runs standalone: FSORx is meant to be driven from the TX/channel pipeline.
    _usage_hint = "receiver.py module - run FSORx from your pipeline to test."
    print(_usage_hint)

receiver.py module - run FSORx from your pipeline to test.
