In [2]:
# Block 1: Import Libraries and Set Random Seed
import tensorflow as tf
import numpy as np
import sionna
import matplotlib.pyplot as plt
from sionna.ofdm import ResourceGrid, ResourceGridMapper
from sionna.mimo import StreamManagement
from sionna.channel.tr38901 import CDL, AntennaArray
from sionna.channel import subcarrier_frequencies, cir_to_ofdm_channel, ApplyOFDMChannel
from sionna.utils import ebnodb2no
from sionna.fec.ldpc.encoding import LDPC5GEncoder
from sionna.fec.ldpc.decoding import LDPC5GDecoder
from sionna.mapping import Mapper, Demapper
from sionna.utils import BinarySource

# Set random seeds for reproducibility
sionna.config.seed = 42
tf.random.set_seed(42)
np.random.seed(42)

print("Block 1 Complete: Libraries imported, seeds set.")

Block 1 Complete: Libraries imported, seeds set.


In [6]:
# Block 2: Revised System Parameters and MIMO-OFDM Setup
import numpy as np
import sionna
from sionna.ofdm import ResourceGrid, ResourceGridMapper
from sionna.mimo import StreamManagement
from sionna.channel.tr38901 import CDL, AntennaArray
from sionna.channel import subcarrier_frequencies
from sionna.utils import BinarySource

num_ut = 1
num_bs = 1
num_ut_ant = 4  # N_t
num_bs_ant = 4  # N_r
num_streams_per_tx = num_ut_ant
rx_tx_association = np.array([[1]])
sm = StreamManagement(rx_tx_association, num_streams_per_tx)

# Resource Grid: 14 OFDM symbols, 256 subcarriers, 16-QAM
rg = ResourceGrid(num_ofdm_symbols=14,
                              fft_size=256,
                              subcarrier_spacing=15e3,
                              num_tx=1,
                              num_streams_per_tx=num_streams_per_tx,
                              cyclic_prefix_length=40,
                              num_guard_carriers=[5,6],
                              dc_null=True,
                              pilot_pattern="kronecker",
                              pilot_ofdm_symbol_indices=[2,11])

# Carrier & Antenna setup
carrier_frequency = 2.6e9
ut_array = AntennaArray(num_rows=1, num_cols=int(num_ut_ant/2),
                        polarization="dual", polarization_type="cross",
                        antenna_pattern="38.901", carrier_frequency=carrier_frequency)
bs_array = AntennaArray(num_rows=1, num_cols=int(num_bs_ant/2),
                        polarization="dual", polarization_type="cross",
                        antenna_pattern="38.901", carrier_frequency=carrier_frequency)

# CDL Channel Model
delay_spread = 300e-9
cdl_model = "B"
speed = 10  # m/s
direction = "uplink"
cdl = CDL(model=cdl_model,
          delay_spread=delay_spread,
          carrier_frequency=carrier_frequency,
          ut_array=ut_array,
          bs_array=bs_array,
          direction=direction,
          min_speed=speed)

print("Block 2 Complete: MIMO-OFDM system and CDL channel set up.")
print("FFT size:", rg.fft_size)
print("Cyclic Prefix:", rg.cyclic_prefix_length)
print("Tx Antennas:", num_ut_ant, "| Rx Antennas:", num_bs_ant)

Block 2 Complete: MIMO-OFDM system and CDL channel set up.
FFT size: 256
Cyclic Prefix: 40
Tx Antennas: 4 | Rx Antennas: 4


In [7]:
# Block 3: Signal Generation and Encoding
num_bits_per_symbol = 4  # 16-QAM
coderate = 0.5
batch_size = 4  # Q = 4 as per paper

# Calculate number of bits
n = int(rg.num_data_symbols * num_bits_per_symbol)
k = int(n * coderate)

# Initialize modules
binary_source = BinarySource()
encoder = LDPC5GEncoder(k, n)
mapper = Mapper("qam", num_bits_per_symbol)
rg_mapper = ResourceGridMapper(rg)

# Generate bits, encode, and map to symbols
b = binary_source([batch_size, 1, rg.num_streams_per_tx, k])
c = encoder(b)
x = mapper(c)
x_rg = rg_mapper(x)

print("Block 3 Complete: Bits → LDPC → 16-QAM → ResourceGrid")
print("Shape of bits (b):", b.shape)
print("Shape of encoded bits (c):", c.shape)
print("Shape of mapped symbols (x):", x.shape)
print("Shape of OFDM-mapped grid (x_rg):", x_rg.shape)

Block 3 Complete: Bits → LDPC → 16-QAM → ResourceGrid
Shape of bits (b): (4, 1, 4, 5856)
Shape of encoded bits (c): (4, 1, 4, 11712)
Shape of mapped symbols (x): (4, 1, 4, 2928)
Shape of OFDM-mapped grid (x_rg): (4, 1, 4, 14, 256)


In [8]:
# Block 4: Nonlinear PA and Channel Application
import tensorflow as tf
from sionna.utils import ebnodb2no
from sionna.channel import ApplyOFDMChannel

def rapp_pa(signal, A=1.0, p=3.0):
    abs_signal = tf.abs(signal)
    gain = A / tf.pow(1 + tf.pow(abs_signal / A, 2 * p), 1 / (2 * p))
    angle = tf.math.angle(signal)
    zero = tf.zeros_like(angle)
    phase = tf.exp(tf.complex(zero, angle))
    return tf.cast(gain, tf.complex64) * phase

# IFFT to time domain
x_time_pa = tf.signal.ifft(tf.cast(x_rg, tf.complex64))  # [4, 1, 4, 14, 256]
print("IFFT complete: x_time_pa shape:", x_time_pa.shape)

# Apply Rapp PA
x_pa = rapp_pa(x_time_pa, A=1.0, p=3.0)
print("PA applied: x_pa shape:", x_pa.shape)

# Channel preparation
frequencies = subcarrier_frequencies(rg.fft_size, rg.subcarrier_spacing)
cir = cdl(batch_size, rg.num_ofdm_symbols, 1 / rg.ofdm_symbol_duration)
h_freq = cir_to_ofdm_channel(frequencies, *cir, normalize=True)

# Apply channel with AWGN
ebno_db = 15.0
no = ebnodb2no(ebno_db, num_bits_per_symbol, coderate, rg)
channel = ApplyOFDMChannel(add_awgn=True)
y_rc = channel([x_pa, h_freq, no])  # [4, 1, 4, 14, 256]

print("Block 4 Complete: Channel output shape y_rc:", y_rc.shape)

IFFT complete: x_time_pa shape: (4, 1, 4, 14, 256)
PA applied: x_pa shape: (4, 1, 4, 14, 256)
Block 4 Complete: Channel output shape y_rc: (4, 1, 4, 14, 256)


In [13]:
# Block 5 (Revised): Prepare RC Inputs and Targets
import tensorflow as tf

# Time-domain received signal
y_td_rc = tf.signal.ifft(tf.cast(y_rc, tf.complex64))  # [4, 1, 4, 14, 256]

# Select stream 0, antenna 0
y_td_flat = tf.reshape(y_td_rc[:, 0, 0, :, :], [batch_size, -1])  # [4, 3584]
window_size = 128
y_td_windowed = tf.signal.frame(y_td_flat, window_size, 1, axis=1)  # [4, 3457, 128]

# Stack real and imaginary parts for input
x_input_rc1 = tf.stack([tf.math.real(y_td_windowed), tf.math.imag(y_td_windowed)], axis=-1)  # [4, 3457, 128, 2]

# Normalize input
mean_rc1 = tf.reduce_mean(x_input_rc1)
std_rc1 = tf.math.reduce_std(x_input_rc1)
x_input_rc1 = (x_input_rc1 - mean_rc1) / std_rc1

# Prepare ground truth
x_time_target = tf.signal.ifft(tf.cast(x_rg, tf.complex64))  # [4, 1, 4, 14, 256]
x_time_target_flat = tf.reshape(x_time_target[:, 0, 0, :, :], [batch_size, -1])  # [4, 3584]
x_time_target_ri = tf.stack([tf.math.real(x_time_target_flat), tf.math.imag(x_time_target_flat)], axis=-1)  # [4, 3584, 2]

z_target = tf.reshape(x[:, 0, 0, :], [batch_size, -1])  # [4, 2928]

print("Block 5 Complete: RC1 input and targets prepared")
print("x_input_rc1 shape:", x_input_rc1.shape)
print("x_time_target_ri shape:", x_time_target_ri.shape)
print("z_target shape:", z_target.shape)
print("Normalization parameters: mean_rc1 =", mean_rc1.numpy(), "std_rc1 =", std_rc1.numpy())

Block 5 Complete: RC1 input and targets prepared
x_input_rc1 shape: (4, 3457, 128, 2)
x_time_target_ri shape: (4, 3584, 2)
z_target shape: (4, 2928)
Normalization parameters: mean_rc1 = 0.00029930193 std_rc1 = 0.10280468


In [19]:
# Block 6 (Revised): Define Time-Frequency RC Model
import tensorflow as tf

class TimeFreqRC(tf.keras.Model):
    def __init__(self, input_dim, reservoir_size, output_length, sequence_length, leak_rate=1.0, dropout_rate=0.0):
        """Initialize Time-Frequency RC model."""
        super().__init__()
        self.reservoir_size = reservoir_size
        self.output_length = output_length  # Frequency-domain symbols (2928)
        self.sequence_length = sequence_length  # Time-domain samples (3584)
        self.leak_rate = leak_rate
        self.dropout_rate = dropout_rate

        # Fixed weights
        self.Win = tf.Variable(tf.random.uniform([input_dim, reservoir_size]) * 0.1, trainable=False)
        Wres_init = tf.random.normal([reservoir_size, reservoir_size]) * 0.05
        eigvals = tf.abs(tf.linalg.eigvals(Wres_init))
        self.Wres = tf.Variable(Wres_init / tf.reduce_max(eigvals), trainable=False)

        # Trainable weights
        self.W_out = tf.Variable(tf.random.normal([reservoir_size, 2], stddev=0.01))
        self.phase_angles = tf.Variable(tf.zeros([output_length]), trainable=True)

    def call(self, x, window_size=128, training=True):
        """Compute reservoir states and time-domain output."""
        B, T, W, C = x.shape  # [4, 3457, 128, 2]
        h = tf.zeros([B, self.reservoir_size], dtype=tf.float32)
        states = []
        
        # Process windowed input for full sequence length
        for t in range(self.sequence_length):
            if t < T:
                xt = x[:, min(t, T-1), :, :]  # [B, 128, 2]
            else:
                xt = tf.zeros([B, window_size, C], dtype=tf.float32)
            xt_flat = tf.reshape(xt, [B, -1])  # [B, 128*2]
            preact = tf.matmul(xt_flat, self.Win) + tf.matmul(h, self.Wres)
            h = (1 - self.leak_rate) * h + self.leak_rate * tf.math.tanh(preact)
            if training and self.dropout_rate > 0:
                h = tf.nn.dropout(h, rate=self.dropout_rate)
            states.append(h)
        
        S = tf.stack(states, axis=1)  # [B, 3584, R]
        y_time = tf.matmul(S, self.W_out)  # [B, 3584, 2]
        return y_time, S

# ALS optimization
def als_optimization(model, S, target_time, target_freq, num_iterations=5):
    """Optimize W_out and phase_angles using ALS."""
    # S: [B, T, R], target_time: [B, T, 2], target_freq: [B, F]
    W_tout = tf.linalg.pinv(S) @ target_time  # [B, R, 2]
    w_fout = tf.ones([model.output_length], dtype=tf.float32)
    
    for _ in range(num_iterations):
        y_time = tf.matmul(S, W_tout)  # [B, T, 2]
        y_complex = tf.complex(y_time[..., 0], y_time[..., 1])  # [B, T]
        z_pred = tf.signal.fft(y_complex)[:, :model.output_length]  # [B, F]
        phase_angles = -tf.math.angle(tf.reduce_sum(tf.math.conj(target_freq) * z_pred, axis=0))
        w_fout = tf.complex(tf.cos(phase_angles), tf.sin(phase_angles))
        z_corr = z_pred * w_fout
        W_tout = tf.linalg.pinv(S) @ target_time
    
    model.W_out.assign(W_tout[0])  # Assign for first batch
    model.phase_angles.assign(phase_angles)
    return W_tout, w_fout

print("Block 6 Complete: TimeFreqRC model and ALS optimization defined.")

Block 6 Complete: TimeFreqRC model and ALS optimization defined.


In [20]:
# Block 7 (Revised): Train RC1
import tensorflow as tf

# Instantiate RC1
rc1 = TimeFreqRC(input_dim=128*2, reservoir_size=128, output_length=2928, sequence_length=3584, leak_rate=0.3, dropout_rate=0.1)

# Compute reservoir states
y_time_rc1, S_rc1 = rc1(x_input_rc1, window_size=128, training=True)  # y_time_rc1: [4, 3584, 2], S_rc1: [4, 3584, 128]

# Reshape states for ALS
S_rc1_flat = tf.reshape(S_rc1, [batch_size, -1, 128])  # [4, 3584, 128]

# Apply ALS optimization
W_tout_rc1, w_fout_rc1 = als_optimization(rc1, S_rc1_flat, x_time_target_ri, z_target, num_iterations=5)

# Compute losses
y_ri_rc1 = tf.matmul(S_rc1_flat, rc1.W_out)  # [4, 3584, 2]
waveform_loss = tf.reduce_mean(tf.square(y_ri_rc1 - x_time_target_ri))
y_complex_rc1 = tf.complex(y_ri_rc1[..., 0], y_ri_rc1[..., 1])  # [4, 3584]
z_pred_rc1 = tf.signal.fft(y_complex_rc1)[:, :2928]  # [4, 2928]
phase_rc1 = tf.complex(tf.cos(rc1.phase_angles), tf.sin(rc1.phase_angles))  # [2928]
z_corr_rc1 = z_pred_rc1 * phase_rc1  # [4, 2928]
qam_loss = tf.reduce_mean(tf.square(tf.math.real(z_corr_rc1 - z_target)) + tf.square(tf.math.imag(z_corr_rc1 - z_target)))
total_loss = waveform_loss + 0.5 * qam_loss

print("Block 7 Complete: RC1 trained")
print(f"Total Loss: {total_loss:.4f} | Waveform Loss: {waveform_loss:.4f} | QAM Loss: {qam_loss:.4f}")

Block 7 Complete: RC1 trained
Total Loss: 0.5838 | Waveform Loss: 0.0019 | QAM Loss: 1.1638


In [24]:
# Block 8 (Revised): Prepare RC2 Input
import tensorflow as tf

# Compute frequency-domain residual
residual_z1 = z_target - z_corr_rc1  # [4, 2928]

# Pad residual to match time-domain length for IFFT
residual_z1_padded = tf.pad(residual_z1, paddings=[[0, 0], [0, 656]])  # [4, 3584]

# Transform to time domain
x_time_residual2 = tf.signal.ifft(residual_z1_padded)  # [4, 3584]

# Apply 128-sample window
x_time_residual2_windowed = tf.signal.frame(x_time_residual2, frame_length=128, frame_step=1, axis=1)  # [4, 3457, 128]

# Stack real and imaginary parts
x_input_rc2 = tf.stack([tf.math.real(x_time_residual2_windowed), tf.math.imag(x_time_residual2_windowed)], axis=-1)  # [4, 3457, 128, 2]

# Normalize input
mean_rc2 = tf.reduce_mean(x_input_rc2)
std_rc2 = tf.math.reduce_std(x_input_rc2)
x_input_rc2 = (x_input_rc2 - mean_rc2) / std_rc2

print("Block 8 Complete: RC2 input prepared")
print("x_input_rc2 shape:", x_input_rc2.shape)
print("x_time_target_ri shape:", x_time_target_ri.shape)
print("z_target shape:", z_target.shape)
print("Normalization parameters: mean_rc2 =", mean_rc2.numpy(), "std_rc2 =", std_rc2.numpy())

Block 8 Complete: RC2 input prepared
x_input_rc2 shape: (4, 3457, 128, 2)
x_time_target_ri shape: (4, 3584, 2)
z_target shape: (4, 2928)
Normalization parameters: mean_rc2 = -8.409324e-05 std_rc2 = 0.011528453


In [25]:
# Block 9: Train RC2
import tensorflow as tf

# Instantiate RC2
rc2 = TimeFreqRC(input_dim=128*2, reservoir_size=128, output_length=2928, sequence_length=3584, leak_rate=0.3, dropout_rate=0.1)

# Compute reservoir states
y_time_rc2, S_rc2 = rc2(x_input_rc2, window_size=128, training=True)  # y_time_rc2: [4, 3584, 2], S_rc2: [4, 3584, 128]

# Reshape states for ALS
S_rc2_flat = tf.reshape(S_rc2, [batch_size, -1, 128])  # [4, 3584, 128]

# Apply ALS optimization
W_tout_rc2, w_fout_rc2 = als_optimization(rc2, S_rc2_flat, x_time_target_ri, z_target, num_iterations=5)

# Compute losses
y_ri_rc2 = tf.matmul(S_rc2_flat, rc2.W_out)  # [4, 3584, 2]
waveform_loss = tf.reduce_mean(tf.square(y_ri_rc2 - x_time_target_ri))
y_complex_rc2 = tf.complex(y_ri_rc2[..., 0], y_ri_rc2[..., 1])  # [4, 3584]
z_pred_rc2 = tf.signal.fft(y_complex_rc2)[:, :2928]  # [4, 2928]
phase_rc2 = tf.complex(tf.cos(rc2.phase_angles), tf.sin(rc2.phase_angles))  # [2928]
z_corr_rc2 = z_pred_rc2 * phase_rc2  # [4, 2928]
qam_loss = tf.reduce_mean(tf.square(tf.math.real(z_corr_rc2 - z_target)) + tf.square(tf.math.imag(z_corr_rc2 - z_target)))
total_loss = waveform_loss + 0.5 * qam_loss

print("Block 9 Complete: RC2 trained")
print(f"Total Loss: {total_loss:.4f} | Waveform Loss: {waveform_loss:.4f} | QAM Loss: {qam_loss:.4f}")

Block 9 Complete: RC2 trained
Total Loss: 0.6096 | Waveform Loss: 0.0019 | QAM Loss: 1.2153


In [26]:
# Block 10: Prepare RC3 Input
import tensorflow as tf

# Compute frequency-domain residual
residual_z2 = z_target - z_corr_rc2  # [4, 2928]

# Pad residual to match time-domain length for IFFT
residual_z2_padded = tf.pad(residual_z2, paddings=[[0, 0], [0, 656]])  # [4, 3584]

# Transform to time domain
x_time_residual3 = tf.signal.ifft(residual_z2_padded)  # [4, 3584]

# Apply 128-sample window
x_time_residual3_windowed = tf.signal.frame(x_time_residual3, frame_length=128, frame_step=1, axis=1)  # [4, 3457, 128]

# Stack real and imaginary parts
x_input_rc3 = tf.stack([tf.math.real(x_time_residual3_windowed), tf.math.imag(x_time_residual3_windowed)], axis=-1)  # [4, 3457, 128, 2]

# Normalize input
mean_rc3 = tf.reduce_mean(x_input_rc3)
std_rc3 = tf.math.reduce_std(x_input_rc3)
x_input_rc3 = (x_input_rc3 - mean_rc3) / std_rc3

print("Block 10 Complete: RC3 input prepared")
print("x_input_rc3 shape:", x_input_rc3.shape)
print("x_time_target_ri shape:", x_time_target_ri.shape)
print("z_target shape:", z_target.shape)
print("Normalization parameters: mean_rc3 =", mean_rc3.numpy(), "std_rc3 =", std_rc3.numpy())

Block 10 Complete: RC3 input prepared
x_input_rc3 shape: (4, 3457, 128, 2)
x_time_target_ri shape: (4, 3584, 2)
z_target shape: (4, 2928)
Normalization parameters: mean_rc3 = 5.200019e-05 std_rc3 = 0.01177334


In [27]:
# Block 11: Train RC3
import tensorflow as tf

# Instantiate RC3
rc3 = TimeFreqRC(input_dim=128*2, reservoir_size=128, output_length=2928, sequence_length=3584, leak_rate=0.3, dropout_rate=0.1)

# Compute reservoir states
y_time_rc3, S_rc3 = rc3(x_input_rc3, window_size=128, training=True)  # y_time_rc3: [4, 3584, 2], S_rc3: [4, 3584, 128]

# Reshape states for ALS
S_rc3_flat = tf.reshape(S_rc3, [batch_size, -1, 128])  # [4, 3584, 128]

# Apply ALS optimization
W_tout_rc3, w_fout_rc3 = als_optimization(rc3, S_rc3_flat, x_time_target_ri, z_target, num_iterations=5)

# Compute losses
y_ri_rc3 = tf.matmul(S_rc3_flat, rc3.W_out)  # [4, 3584, 2]
waveform_loss = tf.reduce_mean(tf.square(y_ri_rc3 - x_time_target_ri))
y_complex_rc3 = tf.complex(y_ri_rc3[..., 0], y_ri_rc3[..., 1])  # [4, 3584]
z_pred_rc3 = tf.signal.fft(y_complex_rc3)[:, :2928]  # [4, 2928]
phase_rc3 = tf.complex(tf.cos(rc3.phase_angles), tf.sin(rc3.phase_angles))  # [2928]
z_corr_rc3 = z_pred_rc3 * phase_rc3  # [4, 2928]
qam_loss = tf.reduce_mean(tf.square(tf.math.real(z_corr_rc3 - z_target)) + tf.square(tf.math.imag(z_corr_rc3 - z_target)))
total_loss = waveform_loss + 0.5 * qam_loss

print("Block 11 Complete: RC3 trained")
print(f"Total Loss: {total_loss:.4f} | Waveform Loss: {waveform_loss:.4f} | QAM Loss: {qam_loss:.4f}")

Block 11 Complete: RC3 trained
Total Loss: 0.6686 | Waveform Loss: 0.0019 | QAM Loss: 1.3333
