In [1]:
import numpy as np

def _qpsk_gray_bits(symbols):
    """
    Hard-decide complex symbols onto Gray-coded QPSK and return their bits.

    Labelling: 1+1j → 00, -1+1j → 01, -1-1j → 11, 1-1j → 10.

    The nearest point s of {±1±1j} to y maximises Re(y)·Re(s) + Im(y)·Im(s). The decision
    therefore depends only on the signs of Re(y) and Im(y), and positive scaling of y cannot
    change it. With this labelling the first bit is 1 iff Im(y) < 0 and the second bit is
    1 iff Re(y) < 0. Values lying exactly on an axis are decided as the positive side.

    Args:
        symbols: complex (or real) array of any shape.
    Returns:
        int8 array of shape symbols.shape + (2,) holding the two bits of each symbol.
    """
    symbols = np.asarray(symbols)
    return np.stack([symbols.imag < 0, symbols.real < 0], axis=-1).astype(np.int8)


def calculate_ber_per_channel(estimated_channel, rx_signal, tx_pilots, channel_type, mod_scheme="QPSK"):
    """
    Calculate BER for shallow coastal or continental shelf channel (per frontiers.pdf).

    Only the pilot symbols of the first time step are used (paper's receiver logic).
    Each received pilot is equalised with the estimated channel. The equalised pilot and
    the transmitted pilot are then both hard-decided onto Gray-coded QPSK
    (1+1j → 00, -1+1j → 01, -1-1j → 11, 1-1j → 10; per paper: 1-285), and the differing
    bits are counted. The computation is vectorised over batch and subcarriers.

    Args:
        estimated_channel: (batch_size, num_subcarriers, 2) → [real, imaginary]
        rx_signal: Received signal (batch_size, time_steps, num_subcarriers, pilot_dim);
            time step 0 is used, with pilot_dim indices 0 / 1 read as real / imaginary.
        tx_pilots: Transmitted pilots (batch_size, time_steps, num_subcarriers);
            time step 0 is used.
        channel_type: "shallow_coastal" or "continental_shelf" (used only in the printed label)
        mod_scheme: Modulation scheme. Only "QPSK" (the paper's primary modulation) is
            implemented; the comparison is case-insensitive.
    Returns:
        ber: Average BER (Python float) over all batch entries and subcarriers.
    Raises:
        ValueError: if mod_scheme is not QPSK, or if there are no pilot symbols to evaluate.
    """
    if str(mod_scheme).upper() != "QPSK":
        raise ValueError(f"Unsupported modulation scheme {mod_scheme!r}: only 'QPSK' is implemented")

    # Convert estimated channel, received pilots and transmitted pilots to complex (batch, N)
    h_est = estimated_channel[..., 0] + 1j * estimated_channel[..., 1]
    rx_complex = rx_signal[:, 0, :, 0] + 1j * rx_signal[:, 0, :, 1]
    tx_complex = np.asarray(tx_pilots[:, 0, :], dtype=np.complex128)

    if tx_complex.size == 0:
        raise ValueError("No pilot symbols to evaluate: BER is undefined for empty input")

    # Zero-forcing equalisation rx / h_est equals rx * conj(h_est) / |h_est|^2. The positive
    # factor 1/|h_est|^2 cannot move a symbol into another quadrant, so deciding on
    # rx * conj(h_est) yields the same QPSK decision. It also avoids dividing by zero (and
    # producing NaNs) where the estimated channel is exactly 0.
    y_eq = rx_complex * np.conj(h_est)

    est_bits = _qpsk_gray_bits(y_eq)
    true_bits = _qpsk_gray_bits(tx_complex)

    error_bits = int(np.count_nonzero(est_bits != true_bits))
    total_bits = est_bits.size  # = batch * N * bits_per_symbol (2)

    ber = error_bits / total_bits
    print(f"BER for {channel_type} channel: {ber:.2e}")
    return ber

In [2]:
import numpy as np

def calculate_amplitude_error(true_channel, estimated_channel):
    """
    Per-subcarrier amplitude error | |h| - |ĥ| | (per frontiers.pdf's channel error analysis).

    Args:
        true_channel: (batch_size, num_subcarriers, 2) → [real, imaginary]
        estimated_channel: (batch_size, num_subcarriers, 2) → [real, imaginary]
    Returns:
        amp_error: (batch_size, num_subcarriers) absolute difference of the two magnitudes
        avg_amp_error: mean of amp_error over every batch entry and subcarrier
    """
    def magnitude(coeffs):
        # |c| = sqrt(re² + im²); real part at index 0, imaginary part at index 1 of the last axis
        re_part = coeffs[..., 0]
        im_part = coeffs[..., 1]
        return np.sqrt(re_part ** 2 + im_part ** 2)

    amp_error = np.abs(magnitude(true_channel) - magnitude(estimated_channel))
    avg_amp_error = np.mean(amp_error)

    print(f"Average amplitude error: {avg_amp_error:.4f}")
    return amp_error, avg_amp_error

In [None]:
import numpy as np

def calculate_phase_error(true_channel, estimated_channel):
    """
    Per-subcarrier phase error | ∠h - ∠ĥ | (per frontiers.pdf's channel error analysis).

    The raw angle difference is wrapped into [-π, π) before taking its magnitude, so the
    reported error always lies in [0, π] and angles near +π / -π are treated as close.

    Args:
        true_channel: (batch_size, num_subcarriers, 2) → [real, imaginary]
        estimated_channel: (batch_size, num_subcarriers, 2) → [real, imaginary]
    Returns:
        phase_error: (batch_size, num_subcarriers) phase error in radians
        avg_phase_error: mean of phase_error over every batch entry and subcarrier (radians)
    """
    def phase(coeffs):
        # ∠c = arctan2(imag, real), with real/imag at indices 0/1 of the last axis
        return np.arctan2(coeffs[..., 1], coeffs[..., 0])

    raw_diff = phase(true_channel) - phase(estimated_channel)
    two_pi = 2 * np.pi
    # Shift by π, reduce modulo 2π, shift back: maps raw_diff into [-π, π)
    wrapped_diff = np.mod(raw_diff + np.pi, two_pi) - np.pi
    phase_error = np.abs(wrapped_diff)
    avg_phase_error = np.mean(phase_error)

    print(f"Average phase error: {avg_phase_error:.4f} radians")
    return phase_error, avg_phase_error

In [None]:
import numpy as np

def calculate_mse(true_channel, estimated_channel):
    """
    MSE between true and estimated channel coefficients (per frontiers.pdf Eq. 17).

    The squared error is averaged separately over the real parts and over the imaginary
    parts, and those two means are then averaged. The result is therefore the
    per-real-dimension MSE: mean((h_re - ĥ_re)² + (h_im - ĥ_im)²) / 2, i.e. half of
    mean(|h - ĥ|²).

    Args:
        true_channel: (batch_size, num_subcarriers, 2) → [real, imaginary]
        estimated_channel: (batch_size, num_subcarriers, 2) → [real, imaginary]
    Returns:
        mse: Average MSE over batch
    """
    per_part_mse = [
        np.mean((true_channel[..., part] - estimated_channel[..., part]) ** 2)
        for part in (0, 1)  # 0 → real, 1 → imaginary
    ]
    mse = (per_part_mse[0] + per_part_mse[1]) / 2

    print(f"Average MSE: {mse:.6f}")
    return mse