# In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import (Input, Conv1D, LSTM, Dense, Dropout, 
                                     Reshape, Concatenate, Layer)
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras.losses import MeanSquaredError

# -----------------------------------------------------------------------------
# Step 1: Complex Channel Output Layer 
# Paper Eq. 16: \(\hat{h}_i = \sum_{d=1}^D (W_{i,d}^{(r)} + jW_{i,d}^{(i)})h_d^{final} + b_i^{(r)} + jb_i^{(i)}\)
# -----------------------------------------------------------------------------
class ComplexChannelOutputLayer(Layer):
    """Linear output head producing real and imaginary channel coefficients.

    Implements Paper Eq. 16 as two independent affine maps over the last
    axis of the input: one yields the real part and one the imaginary part
    of each estimated coefficient. The two results are stacked on a new
    trailing axis, where index 0 is the real part and index 1 the imaginary
    part.

    Args:
        num_subcarriers: Number of coefficients N produced per sample
            (N = 1024 in Table 4 of the paper).
        **kwargs: Forwarded unchanged to the base ``Layer`` constructor.
    """

    def __init__(self, num_subcarriers, **kwargs):
        super().__init__(**kwargs)
        self.num_subcarriers = num_subcarriers

    def build(self, input_shape):
        """Create the real/imaginary kernels and biases.

        The feature size D is read from the last entry of ``input_shape``.
        Kernels have shape (D, N) and biases have shape (N,).
        """
        feature_dim = input_shape[-1]
        kernel_shape = (feature_dim, self.num_subcarriers)
        bias_shape = (self.num_subcarriers,)

        # Real part: W^{(r)}, b^{(r)}
        self.W_real = self.add_weight(
            shape=kernel_shape, initializer="glorot_uniform", name="W_real"
        )
        self.b_real = self.add_weight(
            shape=bias_shape, initializer="zeros", name="b_real"
        )
        # Imaginary part: W^{(i)}, b^{(i)}
        self.W_imag = self.add_weight(
            shape=kernel_shape, initializer="glorot_uniform", name="W_imag"
        )
        self.b_imag = self.add_weight(
            shape=bias_shape, initializer="zeros", name="b_imag"
        )

    def call(self, inputs):
        """Map features of shape (batch_size, D) to (batch_size, N, 2)."""
        affine_params = (
            (self.W_real, self.b_real),  # -> real part, stacked at index 0
            (self.W_imag, self.b_imag),  # -> imaginary part, stacked at index 1
        )
        parts = [tf.matmul(inputs, kernel) + bias for kernel, bias in affine_params]
        return tf.stack(parts, axis=-1)

    def get_config(self):
        """Return the base layer config extended with ``num_subcarriers``."""
        return {**super().get_config(), "num_subcarriers": self.num_subcarriers}

# -----------------------------------------------------------------------------
# Step 2: Build CRNet Model 
# -----------------------------------------------------------------------------
def build_crnet(num_subcarriers=1024,
                pilot_dim=10,
                conv_filters=64,
                conv_kernel=3,
                lstm_units=64,
                dense_units=128,
                dropout_rate=0.2):
    """Assemble and compile the CRNet channel-estimation model.

    The received signal has its (num_subcarriers, pilot_dim) axes flattened
    into one feature axis and is concatenated with the transmitted pilots.
    The result passes through two Conv1D layers, two LSTM layers, one Dense
    layer and finally ``ComplexChannelOutputLayer``. A Dropout layer follows
    every Conv1D and LSTM layer.

    Args:
        num_subcarriers: Number of OFDM subcarriers (paper Table 4: 1024).
        pilot_dim: Size of the trailing axis of the received-signal input.
        conv_filters: Number of filters in each Conv1D layer.
        conv_kernel: Kernel size of each Conv1D layer.
        lstm_units: Number of units in each LSTM layer.
        dense_units: Number of units in the Dense layer.
        dropout_rate: Rate used by every Dropout layer.

    Returns:
        A compiled Keras ``Model`` named "CRNet". Its inputs are
        ``[received_signal, pilots]`` with shapes
        (batch, time_steps, num_subcarriers, pilot_dim) and
        (batch, time_steps, num_subcarriers); its output has shape
        (batch, num_subcarriers, 2) holding [real, imaginary] parts.
    """
    # Inputs: the time axis is left as None so its length is not fixed.
    received_in = Input(
        shape=(None, num_subcarriers, pilot_dim),
        name="input_received_signal",
    )
    pilots_in = Input(
        shape=(None, num_subcarriers),
        name="input_transmitted_pilots",
    )

    # Flatten (num_subcarriers, pilot_dim) into a single feature axis so that
    # Conv1D can slide over time, then append the known pilot symbols.
    flat_received = Reshape((-1, num_subcarriers * pilot_dim))(received_in)
    features = Concatenate(axis=-1)([flat_received, pilots_in])

    # Two identical Conv1D + Dropout stages ("same" padding keeps time length).
    for conv_name in ("conv1d_1", "conv1d_2"):
        features = Conv1D(
            filters=conv_filters,
            kernel_size=conv_kernel,
            activation="relu",
            padding="same",
            name=conv_name,
        )(features)
        features = Dropout(dropout_rate)(features)

    # First LSTM keeps the full sequence; the second keeps only its last output.
    for lstm_name, keep_sequence in (("lstm_1", True), ("lstm_2", False)):
        features = LSTM(
            units=lstm_units,
            return_sequences=keep_sequence,
            name=lstm_name,
        )(features)
        features = Dropout(dropout_rate)(features)

    features = Dense(units=dense_units, activation="relu", name="dense_1")(features)

    # Complex output head: receives the Dense output and emits [real, imag].
    channel_estimate = ComplexChannelOutputLayer(
        num_subcarriers=num_subcarriers,
        name="output_complex_channel",
    )(features)

    crnet = Model(
        inputs=[received_in, pilots_in],
        outputs=channel_estimate,
        name="CRNet",
    )

    # Adam with lr = 0.001 and an MSE loss; MSE is also reported as a metric.
    crnet.compile(
        optimizer=Adam(learning_rate=1e-3),
        loss=MeanSquaredError(name="mse_loss"),
        metrics=["mse"],
    )
    return crnet

# -----------------------------------------------------------------------------
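# Step 3: Generate Synthetic Training Data (Placeholder Shaped Like the Paper's Bellhop Dataset)
# The paper's dataset is generated via Bellhop ray-tracing (6000 samples, 80%/20% split);
# here random arrays with the same shapes and split stand in for it.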
# -----------------------------------------------------------------------------
def generate_bellhop_like_data(num_samples=6000,
                               time_steps=1152,
                               num_subcarriers=1024,
                               pilot_dim=10,
                               seed=None):
    """Generate random placeholder data shaped like the paper's dataset.

    All arrays are independent random draws: the labels are NOT derived from
    the inputs, so this data only exercises the pipeline's shapes and dtypes.

    Values are sampled directly as float32 (and pilots via int8) rather than
    as float64 followed by a cast, which avoids a temporary array twice the
    size of the final one for every tensor.

    Args:
        num_samples: Total number of samples before the train/validation split.
        time_steps: Length of the time axis (N_t = 1152 in the paper).
        num_subcarriers: Number of subcarriers (paper Table 4: 1024).
        pilot_dim: Size of the trailing axis of the received signal.
        seed: Optional seed for ``np.random.default_rng``. ``None`` (default)
            gives fresh, non-reproducible data on every call. Sampling uses a
            local generator, so the global ``np.random`` state is neither
            read nor modified.

    Returns:
        ``(train_data, val_data)``. Each is ``([received, pilots], labels)``
        with float32 arrays of shapes
        (n, time_steps, num_subcarriers, pilot_dim),
        (n, time_steps, num_subcarriers) and (n, num_subcarriers, 2).
        The first ``int(num_samples * 0.8)`` samples form the training set
        and the remainder the validation set.
    """
    rng = np.random.default_rng(seed)

    # Received signal: zero-mean Gaussian with standard deviation 0.5.
    received_signal = rng.standard_normal(
        size=(num_samples, time_steps, num_subcarriers, pilot_dim),
        dtype=np.float32,
    )
    received_signal *= np.float32(0.5)  # in place: no extra copy

    # Pilot symbols: uniformly random +/-1. Bits are drawn as int8 and mapped
    # {0, 1} -> {-1, +1} before the single conversion to float32.
    pilot_bits = rng.integers(
        0, 2,
        size=(num_samples, time_steps, num_subcarriers),
        dtype=np.int8,
    )
    pilots = (2 * pilot_bits - 1).astype(np.float32)

    # Labels: zero-mean Gaussian with standard deviation 0.1; the last axis
    # holds [real, imaginary].
    true_channel = rng.standard_normal(
        size=(num_samples, num_subcarriers, 2),
        dtype=np.float32,
    )
    true_channel *= np.float32(0.1)

    # 80% / 20% split along the sample axis (slices are views, not copies).
    split_idx = int(num_samples * 0.8)
    train_data = (
        [received_signal[:split_idx], pilots[:split_idx]],
        true_channel[:split_idx],
    )
    val_data = (
        [received_signal[split_idx:], pilots[split_idx:]],
        true_channel[split_idx:],
    )

    return train_data, val_data

# -----------------------------------------------------------------------------
# Step 4: Train CRNet (Matches Paper's Training Parameters)
#  100 epochs, batch size=64, early stopping (patience=5)
# -----------------------------------------------------------------------------
def train_crnet(crnet_model, train_data, val_data, batch_size=64, epochs=100,
                checkpoint_path=None):
    """Fit the model with early stopping and optional best-model checkpointing.

    Args:
        crnet_model: Compiled Keras model to train.
        train_data: ``(inputs, labels)`` pair; ``inputs`` is passed as ``x``
            and ``labels`` as ``y`` to ``fit``.
        val_data: ``(inputs, labels)`` pair passed as ``validation_data``.
        batch_size: Mini-batch size (paper: 64).
        epochs: Maximum number of epochs (paper: 100).
        checkpoint_path: Optional file path. When given, the model with the
            lowest ``val_loss`` so far is saved there after each improving
            epoch. When ``None`` (default) nothing is written to disk.

    Returns:
        The ``History`` object returned by ``crnet_model.fit``.
    """
    # Stop after 5 epochs without val_loss improvement and roll back to the
    # best weights seen.
    callbacks = [
        EarlyStopping(
            monitor="val_loss",
            patience=5,
            restore_best_weights=True,
            verbose=1,
        ),
    ]

    # ModelCheckpoint requires `filepath`; constructing it without one raises
    # TypeError, so the callback is attached only when a path is supplied.
    if checkpoint_path is not None:
        callbacks.append(
            ModelCheckpoint(
                filepath=checkpoint_path,
                monitor="val_loss",
                save_best_only=True,
                verbose=1,
            )
        )

    train_inputs, train_labels = train_data[0], train_data[1]
    history = crnet_model.fit(
        x=train_inputs,
        y=train_labels,
        batch_size=batch_size,
        epochs=epochs,
        validation_data=val_data,
        callbacks=callbacks,
        verbose=1,
    )

    return history

# -----------------------------------------------------------------------------
# Step 5: Run the Pipeline (Initialize, Generate Data, Train, Run Example Inference)
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    # --- 1. Build the network with the paper's hyper-parameters -------------
    crnet = build_crnet(
        num_subcarriers=1024, pilot_dim=10,       # input dimensions (Table 4)
        conv_filters=64, conv_kernel=3,           # Conv1D stages
        lstm_units=64, dense_units=128,           # recurrent + dense stages
        dropout_rate=0.2,
    )
    crnet.summary()  # layer-by-layer overview for a quick sanity check

    # --- 2. Create the placeholder train/validation sets --------------------
    # NOTE(review): at these sizes the received-signal array alone holds
    # 6000 * 1152 * 1024 * 10 float32 values (about 283 GB); reduce the
    # sizes when running on an ordinary machine.
    train_data, val_data = generate_bellhop_like_data(
        num_samples=6000, time_steps=1152, num_subcarriers=1024, pilot_dim=10
    )

    # --- 3. Train ------------------------------------------------------------
    history = train_crnet(crnet, train_data, val_data, batch_size=64, epochs=100)

    # --- 4. Example inference on one freshly drawn random sample -------------
    # Received signal: (1, time_steps, num_subcarriers, pilot_dim)
    test_received = np.random.normal(
        loc=0, scale=0.5, size=(1, 1152, 1024, 10)
    ).astype(np.float32)
    # Pilots: (1, time_steps, num_subcarriers) with values in {-1, +1}
    test_pilots = np.random.choice(
        [-1.0, 1.0], size=(1, 1152, 1024)
    ).astype(np.float32)
    estimated_channel = crnet.predict([test_received, test_pilots], verbose=1)
    print(f"\nEstimated Channel Shape: {estimated_channel.shape} → (batch_size, num_subcarriers, [real, imaginary])")