In [None]:
import numpy as np
import matplotlib.pyplot as plt

# -----------------------------------------------------------------------------
# Step 1: Initialize Parameters (Strictly from Table 4 of frontiers.pdf)
# -----------------------------------------------------------------------------
N_total = 1024               # Total OFDM subcarriers (Table 4)
N_pilots = 256               # Number of pilot subcarriers (Table 4)
fft_size = 1024              # FFT/IFFT size (matches total subcarriers; Table 4)
cp_length = 128              # Cyclic prefix length (Table 4)
modulation = "QPSK"          # Modulation (paper evaluates QPSK/QAM; QPSK as default)
symbols_per_frame = 50       # Number of OFDM symbols per frame (practical for simulation)
fs = 10000                   # Sampling frequency (10 kHz; consistent with UWA acoustic band in paper)

# Exclude DC subcarrier (index 0) from data/pilot allocation (per paper's UWA-OFDM design)
dc_index = 0
valid_indices = np.arange(1, N_total)  # Valid subcarriers: 1–1023 (1023 total)

# Comb-type pilot placement (evenly spaced in valid indices; Table 4: comb-type).
# Ceiling division is required here: floor(1023 / 256) = 3 would pack all 256
# pilots into indices 1..766 and leave 767..1023 with no pilots at all.
# ceil(1023 / 256) = 4 spreads them across the band as 1, 5, 9, ..., 1021.
pilot_step = -(-len(valid_indices) // N_pilots)  # ceil(1023 / 256) = 4
pilot_indices = valid_indices[::pilot_step][:N_pilots]  # 256 pilots (matches Table 4)

# Data subcarriers: valid indices excluding pilots (1023 - 256 = 767).
# np.setdiff1d gives the sorted set difference in one vectorized pass.
data_indices = np.setdiff1d(valid_indices, pilot_indices)
# Verify alignment with Table 4 parameters
assert len(pilot_indices) == N_pilots, "Pilot count mismatch with Table 4"
assert len(data_indices) == len(valid_indices) - N_pilots, "Data subcarrier count mismatch with Table 4"

# QPSK pilot symbols (gray-coded 4-symbol pattern; fixed reference for the receiver).
# np.resize repeats the pattern to exactly N_pilots entries, even if N_pilots % 4 != 0.
pilot_symbols = np.resize(np.array([1 + 1j, -1 + 1j, -1 - 1j, 1 - 1j]), N_pilots)


# -----------------------------------------------------------------------------
# Step 2: Modulation Function (QPSK/16QAM; Matches Paper's Evaluation)
# -----------------------------------------------------------------------------
def modulate_bits(bits, mod_scheme="QPSK"):
    """Map a flat sequence of 0/1 bits onto complex constellation symbols.

    Args:
        bits: Sequence (list or ndarray) of 0/1 values. Its total size must be a
            multiple of the bits per symbol: 2 for QPSK, 4 for 16QAM.
        mod_scheme: "QPSK" or "16QAM".

    Returns:
        1-D complex128 ndarray with one symbol per bit group. Empty input gives
        an empty complex128 array. Symbols are unnormalized: average energy is
        2 for QPSK and 10 for 16QAM.

    Raises:
        ValueError: if mod_scheme is not "QPSK" or "16QAM".
        AssertionError: if the number of bits is not a multiple of the bits per symbol.
        KeyError: if a bit group contains values other than 0/1.
    """
    if mod_scheme == "QPSK":
        bits_per_symbol = 2
        constellation = {(0,0):1+1j, (0,1):-1+1j, (1,1):-1-1j, (1,0):1-1j}
    elif mod_scheme == "16QAM":
        bits_per_symbol = 4
        # NOTE(review): bits 3-4 label the in-phase levels 3, 1, -1, -3 as
        # 00, 01, 10, 11. Neighbours 01/10 differ in two bits, so this axis is
        # not Gray-coded. It is left unchanged because a matching demapper may
        # rely on these labels. Confirm before switching to 00, 01, 11, 10.
        constellation = {
            (0,0,0,0):3+3j, (0,0,0,1):1+3j, (0,0,1,0):-1+3j, (0,0,1,1):-3+3j,
            (0,1,0,0):3+1j, (0,1,0,1):1+1j, (0,1,1,0):-1+1j, (0,1,1,1):-3+1j,
            (1,1,0,0):3-1j, (1,1,0,1):1-1j, (1,1,1,0):-1-1j, (1,1,1,1):-3-1j,
            (1,0,0,0):3-3j, (1,0,0,1):1-3j, (1,0,1,0):-1-3j, (1,0,1,1):-3-3j
        }
    else:
        raise ValueError("Modulation must be 'QPSK' or '16QAM'")

    # Accept plain lists as well as ndarrays (lists have no .reshape).
    bits = np.asarray(bits)
    assert bits.size % bits_per_symbol == 0, f"Bits length must be multiple of {bits_per_symbol}"
    bit_groups = bits.reshape(-1, bits_per_symbol)
    # An explicit dtype keeps the result complex even when there are no groups.
    # Without it, np.array([]) would be float64.
    return np.array([constellation[tuple(group)] for group in bit_groups], dtype=np.complex128)


# -----------------------------------------------------------------------------
# Step 3: Generate Single OFDM Symbol (Pilot Insertion + CP; Table 4)
# -----------------------------------------------------------------------------
def generate_ofdm_symbol(data_symbols):
    """Build one time-domain OFDM symbol with comb-type pilots and a cyclic prefix.

    Subcarrier indices are in natural FFT order (index 0 = DC). This matches the
    module-level allocation, which never assigns index 0 to data or pilots.

    Args:
        data_symbols: Complex symbols for the data subcarriers. The length must
            equal ``len(data_indices)``.

    Returns:
        Complex ndarray of length ``cp_length + fft_size``: the last
        ``cp_length`` IFFT samples followed by the full IFFT output.
    """
    # Frequency-domain grid. Index 0 (DC) is never written, so it stays zero.
    freq_symbol = np.zeros(fft_size, dtype=np.complex128)
    freq_symbol[data_indices] = np.asarray(data_symbols, dtype=np.complex128)
    freq_symbol[pilot_indices] = np.asarray(pilot_symbols, dtype=np.complex128)

    # IFFT directly in natural order. An fftshift here would move bin 512
    # (a data subcarrier) onto DC and push the intentionally empty bin 0 to
    # Nyquist, defeating the DC exclusion.
    # Multiplying by sqrt(N) makes the transform unitary, so time-domain energy
    # equals frequency-domain energy.
    time_symbol = np.fft.ifft(freq_symbol) * np.sqrt(fft_size)

    # Cyclic prefix: copy the last cp_length samples. The slice is indexed from
    # the front so that cp_length == 0 yields an empty prefix; time_symbol[-0:]
    # would instead copy the entire symbol.
    cp = time_symbol[fft_size - cp_length:]
    return np.concatenate([cp, time_symbol])


# -----------------------------------------------------------------------------
# Step 4: Full OFDM Transmitter Pipeline (Per Paper's UWA-OFDM System)
# -----------------------------------------------------------------------------
def ofdm_transmitter(mod_scheme="QPSK", *, real_output=True, rng=None):
    """Full OFDM transmitter: Bits → Modulation → OFDM Frame (per frontiers.pdf).

    Args:
        mod_scheme: "QPSK" or "16QAM".
        real_output: If True (default, the original behavior), return only the
            real part of the complex baseband frame. This is lossy. The real
            part's spectrum at bin k is (X[k] + conj(X[N-k])) / 2, so
            independent data on subcarriers k and N-k cannot be separated from
            the real signal alone. Pass False to get the complex baseband frame.
        rng: Optional ``numpy.random.Generator`` for reproducible bits. When
            None, the legacy global ``np.random`` state is used, as before.

    Returns:
        (tx_frame, tx_bits): the concatenated OFDM frame (symbols_per_frame
        symbols, each of cyclic prefix plus fft_size samples) and the 0/1 bits
        that were modulated.

    Raises:
        ValueError: if mod_scheme is not "QPSK" or "16QAM".
    """
    # Bits per constellation point. Looking it up also validates mod_scheme up
    # front, with the same message as modulate_bits, instead of silently
    # assuming 4 bits for any unknown scheme.
    bits_per_symbol = {"QPSK": 2, "16QAM": 4}.get(mod_scheme)
    if bits_per_symbol is None:
        raise ValueError("Modulation must be 'QPSK' or '16QAM'")

    # 1. Random payload: exactly enough bits to fill every data subcarrier of every symbol.
    total_bits = len(data_indices) * bits_per_symbol * symbols_per_frame
    if rng is None:
        tx_bits = np.random.randint(0, 2, size=total_bits)  # legacy global RNG (original behavior)
    else:
        tx_bits = rng.integers(0, 2, size=total_bits)

    # 2. Modulate, then split into one row of data symbols per OFDM symbol.
    data_symbols = modulate_bits(tx_bits, mod_scheme=mod_scheme)
    data_symbols_per_symbol = data_symbols.reshape(symbols_per_frame, len(data_indices))

    # 3. Build each OFDM symbol and concatenate them into one frame.
    ofdm_symbols = [generate_ofdm_symbol(syms) for syms in data_symbols_per_symbol]
    tx_frame = np.concatenate(ofdm_symbols)
    if real_output:
        # Original behavior: keep only the in-phase component. This is lossy
        # for a non-Hermitian spectrum (see the docstring).
        tx_frame = np.real(tx_frame)

    return tx_frame, tx_bits


# -----------------------------------------------------------------------------
# Step 5: Example Usage (Verify Alignment with frontiers.pdf)
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    tx_frame, tx_bits = ofdm_transmitter(mod_scheme="QPSK")

    # Echo the configured dimensions next to their Table 4 reference values.
    summary_lines = (
        f"Total subcarriers: {N_total} (Table 4)",
        f"Pilot subcarriers: {len(pilot_indices)} (Table 4: 256)",
        f"Data subcarriers: {len(data_indices)} (1023 - 256 = 767)",
        f"FFT size: {fft_size} (Table 4)",
        f"CP length: {cp_length} (Table 4)",
    )
    for line in summary_lines:
        print(line)

    # Plot the first two symbols (each one is CP followed by the IFFT payload).
    # The dashed red line marks where the first symbol ends.
    symbol_len = cp_length + fft_size
    fig, ax = plt.subplots(figsize=(12, 4))
    ax.plot(tx_frame[:symbol_len * 2], label="OFDM Frame (First 2 Symbols)")
    ax.axvline(x=symbol_len, color="r", linestyle="--", label="End of 1st Symbol (CP + Data)")

    ax.set_xlabel("Sample Index")
    ax.set_ylabel("Amplitude (Real Baseband)")
    ax.legend()
    ax.grid(True)
    plt.show()