In [1]:
import sionna
import os
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import pickle

gpu_num = 0 # Use "" to use the CPU
os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu_num}"
# os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

# Import Sionna
try:
    import sionna
except ImportError as e:
    # Install Sionna if package is not already installed
    import os
    os.system("pip install sionna")
    import sionna

#from sionna.channel import AWGN
from sionna.phy.channel import RayleighBlockFading
from sionna.phy.utils import ebnodb2no, log10, expand_to_rank, insert_dims
from sionna.phy.fec.ldpc import LDPC5GEncoder, LDPC5GDecoder
from sionna.phy.mapping import BinarySource, Mapper, Demapper, Constellation
from sionna.phy.channel import FlatFadingChannel, KroneckerModel
from sionna.phy.utils import sim_ber
# Configure the notebook to use only a single GPU and allocate only as much memory as needed
# For more details, see https://www.tensorflow.org/guide/gpu
import tensorflow as tf
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        tf.config.experimental.set_memory_growth(gpus[0], True)
    except RuntimeError as e:
        print(e)
# Avoid warnings from TensorFlow
tf.get_logger().setLevel('ERROR')

from tensorflow.keras import Model
from tensorflow.keras.layers import Layer, Dense
from tqdm.auto import tqdm
import numpy as np

In [2]:
###############################################
# SNR range for evaluation and training [dB]
###############################################
ebno_db_min = -2
ebno_db_max = 4

###############################################
# Modulation and coding configuration
###############################################
num_bits_per_symbol = 6 # Baseline is 64-QAM
modulation_order = 2**num_bits_per_symbol
coderate = 0.5 # Coderate for the outer code
n = 1500 # Codeword length [bit]. Must be a multiple of num_bits_per_symbol
num_symbols_per_codeword = n//num_bits_per_symbol # Number of modulated baseband symbols per codeword
k = int(n*coderate) # Number of information bits per codeword

###############################################
# Training configuration
###############################################
num_training_iterations_conventional = 1000 #10000 # Number of training iterations for conventional training
# Number of training iterations with RL-based training for the alternating training phase and fine-tuning of the receiver phase
num_training_iterations_rl_alt = 700 #7000
num_training_iterations_rl_finetuning = 300 #3000
###############################################
# Meta-RL Training configuration
###############################################
num_training_iterations_meta_rl = 500 #1000 # Number of training iterations for Meta-RL training
meta_batch_size = 8 # Number of tasks for each meta-training iteration

training_batch_size = tf.constant(32, tf.int32) # Training batch size
rl_perturbation_var = 0.01 # Variance of the perturbation used for RL-based training of the transmitter
model_weights_path_conventional_training = "awgn_autoencoder_weights_conventional_training" # Filename to save the autoencoder weights once conventional training is done
model_weights_path_rl_training = "awgn_autoencoder_weights_rl_training" # Filename to save the autoencoder weights once RL-based training is done
model_weights_path_metarl_training = "awgn_autoencoder_weights_metarl_training" # Filename to save the autoencoder weights once RL-based training is done

###############################################
# Evaluation configuration
###############################################
results_filename = "awgn_autoencoder_results" # Location to save the results
def save_weights(model, model_weights_path):
    weights = model.get_weights()
    with open(model_weights_path + "__rf", 'wb') as f:
        pickle.dump(weights, f)

# Utility function to load and set weights of a model
def load_weights(model, model_weights_path):
    model(1, tf.constant(10.0, tf.float32))
    with open(model_weights_path + "__rf", 'rb') as f:
        weights = pickle.load(f)
    model.set_weights(weights)

In [None]:
from sionna.phy.utils import ebnodb2no, compute_ser, compute_ber, PlotBER
from sionna.phy.channel import FlatFadingChannel, KroneckerModel
from sionna.phy.channel.utils import exp_corr_mat
from sionna.phy.mimo import lmmse_equalizer
from sionna.phy.mapping import BinarySource, QAMSource, SymbolDemapper, Mapper, Demapper
from sionna.phy.fec.ldpc import LDPC5GEncoder, LDPC5GDecoder
num_tx = 4
num_rx = 16

tf.config.xla_compat=True
class Baseline(tf.keras.Model):
    def __init__(self, spatial_corr=None):
        super().__init__()
        self.n = n 
        self.k = k  
        self.coderate = coderate
        self.num_bits_per_symbol = num_bits_per_symbol
        self.num_tx_ant = num_tx
        self.num_rx_ant = num_rx
        self.binary_source = BinarySource()
        self.encoder = LDPC5GEncoder(self.k, self.n)
        self.mapper = Mapper("qam", self.num_bits_per_symbol)
        self.demapper = Demapper("app", "qam", self.num_bits_per_symbol)
        self.decoder = LDPC5GDecoder(self.encoder, hard_out=True)
        self.channel = FlatFadingChannel(self.num_tx_ant,
                                         self.num_rx_ant,
                                         spatial_corr=spatial_corr,
                                         add_awgn=True,
                                         return_channel=True)
        
    @tf.function(jit_compile=True)
    def call(self, batch_size, ebno_db):
        b = self.binary_source([batch_size, self.num_tx_ant, self.k])
        c = self.encoder(b)
        
        x = self.mapper(c)
        shape = tf.shape(x)
        x = tf.reshape(x, [-1, self.num_tx_ant])
        
        no = ebnodb2no(ebno_db, self.num_bits_per_symbol, self.coderate)
        # print(ebno_db)
        # print(no, no.shape)
        no *= np.sqrt(self.num_rx_ant)

        y, h = self.channel(x, no)
        # print(no, no.shape)
        # print(self.num_rx_ant)
        s = tf.complex(no*tf.eye(self.num_rx_ant, self.num_rx_ant), 0.0)
        
        x_hat, no_eff = lmmse_equalizer(y, h, s)
        
        x_hat = tf.reshape(x_hat, shape)
        no_eff = tf.reshape(no_eff, shape)
        
        llr = self.demapper(x_hat, no_eff)
        b_hat = self.decoder(llr)
        
        return b,  b_hat

In [None]:
class NeuralDemapper(Layer):
    def __init__(self):
        super().__init__()
        
        # Deeper network for MIMO-OFDM processing
        self._dense_1 = Dense(256, 'relu')
        self._dense_2 = Dense(256, 'relu')
        self._dense_3 = Dense(128, 'relu')
        self._dense_4 = Dense(num_bits_per_symbol, None)

    @tf.function(jit_compile=True)
    def call(self, y, no):
        # Get shapes for dynamic handling
        batch_size = tf.shape(y)[0]
        total_symbols = tf.shape(y)[1]
        
        # Using log10 scale helps with the performance
        no_db = log10(no)
        
        # Stack real part, imaginary part, and noise as features
        z = tf.stack([tf.math.real(y),
                     tf.math.imag(y),
                     no_db], axis=2)
        
        # Process through neural network
        h = self._dense_1(z)
        h = self._dense_2(h)
        h = self._dense_3(h)
        llr = self._dense_4(h)
        
        return llr

In [None]:
from sionna.phy.ofdm import ResourceGrid, ResourceGridMapper, RemoveNulledSubcarriers, OFDMModulator
from sionna.phy.ofdm import OFDMDemodulator
from sionna.phy.ofdm import LSChannelEstimator, RZFPrecoder, LMMSEEqualizer
from sionna.phy.channel.tr38901 import AntennaArray, CDL
from sionna.phy.channel import cir_to_ofdm_channel, time_lag_discrete_time_channel, cir_to_time_channel, subcarrier_frequencies, ApplyOFDMChannel, ApplyTimeChannel
from sionna.phy.mimo import StreamManagement
from sionna.phy.mapping import Mapper, BinarySource, Constellation
from sionna.phy.fec.ldpc import LDPC5GEncoder, LDPC5GDecoder
from sionna.phy.utils import ebnodb2no
import tensorflow as tf

class E2ESystemCDLTraining(tf.keras.Model):
    def __init__(self, training, 
                 domain="freq", 
                 direction="uplink", 
                 cdl_model="A",
                 delay_spread=100e-9,
                 perfect_csi=True,
                 speed=0.0,
                 cyclic_prefix_length=16,
                 pilot_ofdm_symbol_indices=[2, 11]):
        super().__init__()
        self.training = training

        # CDL Parameters
        self._domain = domain
        self._direction = direction
        self._cdl_model = cdl_model
        self._delay_spread = delay_spread
        self._perfect_csi = perfect_csi
        self._speed = speed
        self._cyclic_prefix_length = cyclic_prefix_length
        self._pilot_ofdm_symbol_indices = pilot_ofdm_symbol_indices

        # Basic system parameters
        self.n = n
        self.k = k
        self.coderate = coderate
        self.num_bits_per_symbol = num_bits_per_symbol
        
        # System parameters from CDL model
        self._carrier_frequency = 2.6e9
        self._subcarrier_spacing = 15e3
        self._fft_size = 72
        self._num_ofdm_symbols = 14
        self._num_ut_ant = 4  # Must be a multiple of two as dual-polarized antennas are used
        self._num_bs_ant = 8  # Must be a multiple of two as dual-polarized antennas are used
        self._num_streams_per_tx = self._num_ut_ant
        self._dc_null = True
        self._num_guard_carriers = [5, 6]
        self._pilot_pattern = "kronecker"
        
        # Set up components
        self.binary_source = BinarySource()
        if not self.training:
            self.encoder = LDPC5GEncoder(k, self.n)
            
        # Trainable constellation for transmitter
        constellation = Constellation("qam", num_bits_per_symbol, trainable=True)
        self.constellation = constellation
        self.mapper = Mapper(constellation=constellation)
        
        # Stream management
        self._sm = StreamManagement(np.array([[1]]), self._num_streams_per_tx)
        
        # Set up resource grid
        self._rg = ResourceGrid(num_ofdm_symbols=self._num_ofdm_symbols,
                            fft_size=self._fft_size,
                            subcarrier_spacing=self._subcarrier_spacing,
                            num_tx=1,
                            num_streams_per_tx=self._num_streams_per_tx,
                            cyclic_prefix_length=self._cyclic_prefix_length,
                            num_guard_carriers=self._num_guard_carriers,
                            dc_null=self._dc_null,
                            pilot_pattern=self._pilot_pattern,
                            pilot_ofdm_symbol_indices=self._pilot_ofdm_symbol_indices)
        
        # Then calculate bit counts based on resource grid
        self.n = int(self._rg.num_data_symbols * num_bits_per_symbol)
        self.k = int(self.n * coderate)
        # Resource grid mapper
        self._rg_mapper = ResourceGridMapper(self._rg)
        
        # Antenna arrays
        self._ut_array = AntennaArray(num_rows=1,
                                      num_cols=int(self._num_ut_ant/2),
                                      polarization="dual",
                                      polarization_type="cross",
                                      antenna_pattern="38.901",
                                      carrier_frequency=self._carrier_frequency)

        self._bs_array = AntennaArray(num_rows=1,
                                      num_cols=int(self._num_bs_ant/2),
                                      polarization="dual",
                                      polarization_type="cross",
                                      antenna_pattern="38.901",
                                      carrier_frequency=self._carrier_frequency)
                                      
        # Channel model
        self._cdl = CDL(model=self._cdl_model,
                        delay_spread=self._delay_spread,
                        carrier_frequency=self._carrier_frequency,
                        ut_array=self._ut_array,
                        bs_array=self._bs_array,
                        direction=self._direction,
                        min_speed=self._speed)
        
        # Frequency domain computations
        self._frequencies = subcarrier_frequencies(self._rg.fft_size, self._rg.subcarrier_spacing)
        
        # Channel application components
        if self._domain == "freq":
            self._channel_freq = ApplyOFDMChannel(add_awgn=True)
        elif self._domain == "time":
            self._l_min, self._l_max = time_lag_discrete_time_channel(self._rg.bandwidth)
            self._l_tot = self._l_max - self._l_min + 1
            self._channel_time = ApplyTimeChannel(self._rg.num_time_samples,
                                                  l_tot=self._l_tot,
                                                  add_awgn=True)
            self._modulator = OFDMModulator(self._cyclic_prefix_length)
            self._demodulator = OFDMDemodulator(self._fft_size, self._l_min, self._cyclic_prefix_length)
        
        # Receiver components
        if self._direction == "downlink":
            self._zf_precoder = RZFPrecoder(self._rg, self._sm, return_effective_channel=True)
            
        # Channel estimation
        self._ls_est = LSChannelEstimator(self._rg, interpolation_type="nn")
        self._lmmse_equ = LMMSEEqualizer(self._rg, self._sm)
        self._remove_nulled_scs = RemoveNulledSubcarriers(self._rg)
        
        # Replace standard demapper with our neural demapper
        self.demapper = NeuralDemapper()
        
        # Add decoder if not in training mode
        if not self.training:
            self.decoder = LDPC5GDecoder(self.encoder, hard_out=True)
            
        # Loss function
        if self.training:
            self.bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    
    @tf.function
    def call(self, batch_size, ebno_db):
        # Generate bits
        if self.training:
            c = self.binary_source([batch_size, 1, self._num_streams_per_tx, self.n])
        else:
            b = self.binary_source([batch_size, 1, self._num_streams_per_tx, self.k])
            c = self.encoder(b)
        
        # Map bits to symbols
        x = self.mapper(c)
        x_rg = self._rg_mapper(x)
        
        # Calculate noise power
        no = ebnodb2no(ebno_db, self.num_bits_per_symbol, self.coderate, self._rg)
        
        # Apply channel - time or frequency domain
        if self._domain == "time":
            # Time-domain simulations
            a, tau = self._cdl(batch_size, self._rg.num_time_samples+self._l_tot-1, self._rg.bandwidth)
            h_time = cir_to_time_channel(self._rg.bandwidth, a, tau,
                                        l_min=self._l_min, l_max=self._l_max, normalize=True)

            # Downsample path gains for frequency domain
            a_freq = a[...,self._rg.cyclic_prefix_length:-1:(self._rg.fft_size+self._rg.cyclic_prefix_length)]
            a_freq = a_freq[...,:self._rg.num_ofdm_symbols]
            h_freq = cir_to_ofdm_channel(self._frequencies, a_freq, tau, normalize=True)

            if self._direction == "downlink":
                x_rg, g = self._zf_precoder(x_rg, h_freq)

            x_time = self._modulator(x_rg)
            y_time = self._channel_time(x_time, h_time, no)
            y = self._demodulator(y_time)

        else:  # Frequency domain
            cir = self._cdl(batch_size, self._rg.num_ofdm_symbols, 1/self._rg.ofdm_symbol_duration)
            h_freq = cir_to_ofdm_channel(self._frequencies, *cir, normalize=True)

            if self._direction == "downlink":
                x_rg, g = self._zf_precoder(x_rg, h_freq)

            y = self._channel_freq(x_rg, h_freq, no)
        
        # Channel estimation
        if self._perfect_csi:
            if self._direction == "uplink":
                h_hat = self._remove_nulled_scs(h_freq)
            elif self._direction == "downlink":
                h_hat = g
            err_var = 0.0
        else:
            h_hat, err_var = self._ls_est(y, no)
        
        # Equalize
        x_hat, no_eff = self._lmmse_equ(y, h_hat, err_var, no)
        print(f"Batch Size: {batch_size}, x_hat shape: {x_hat.shape}, no_eff shape: {no_eff.shape}")
        x_hat_reshaped = tf.reshape(x_hat, [batch_size, -1])  # Flatten all dimensions after batch
        no_eff_reshaped = tf.reshape(no_eff, [batch_size, -1])

        # Use neural demapper
        llr = self.demapper(x_hat_reshaped, no_eff_reshaped)
        # Reshape LLR outputs to match expected format for loss calculation or decoding
        if self.training:

            llr = tf.reshape(llr, tf.shape(c))
            loss = self.bce(c, llr)
            return loss
        else:
            llr = tf.reshape(llr, [batch_size, 1, self._num_streams_per_tx, -1])
            b_hat = self.decoder(llr)
            return b, b_hat

def conventional_training(model):
    # Optimizer used to apply gradients
    optimizer = tf.keras.optimizers.Adam()

    @tf.function
    def train_step():
        # Sampling a batch of SNRs
        ebno_db = tf.random.uniform(shape=[training_batch_size], minval=ebno_db_min, maxval=ebno_db_max)
        # Forward pass
        with tf.GradientTape() as tape:
            loss = model(training_batch_size, ebno_db)
        # Computing and applying gradients
        weights = model.trainable_variables
        grads = tape.gradient(loss, weights)
        optimizer.apply_gradients(zip(grads, weights))
        return loss

    for i in range(num_training_iterations_conventional):
        loss = train_step()
        # Printing periodically the progress
        if i % 100 == 0:
            print('Iteration {}/{}  BCE: {:.4f}'.format(i, num_training_iterations_conventional, loss.numpy()), end='\r')

In [59]:
# Fix the seed for reproducible trainings
tf.random.set_seed(1)
# Instantiate and train the end-to-end system with CDL channel
model = E2ESystemCDLTraining(training=True)
conventional_training(model)
# Save weights
save_weights(model, model_weights_path_conventional_training)

Batch Size: Tensor("batch_size:0", shape=(), dtype=int32), x_hat shape: (32, 1, 4, 720), no_eff shape: (32, 1, 4, 720)




Batch Size: Tensor("batch_size:0", shape=(), dtype=int32), x_hat shape: (32, 1, 4, 720), no_eff shape: (32, 1, 4, 720)


ValueError: in user code:

    File "C:\Users\Ahmad\AppData\Local\Temp\ipykernel_9808\3536717681.py", line 226, in train_step  *
        loss = model(training_batch_size, ebno_db)
    File "e:\Projects\Unfinished\Sionna-CDL-Channel\.env\Lib\site-packages\keras\src\utils\traceback_utils.py", line 122, in error_handler  **
        raise e.with_traceback(filtered_tb) from None
    File "C:\Users\Ahmad\AppData\Local\Temp\__autograph_generated_file3vomspah.py", line 188, in tf__call
        ag__.if_stmt(ag__.ld(self).training, if_body_7, else_body_7, get_state_7, set_state_7, ('do_return', 'retval_', 'llr'), 2)
    File "C:\Users\Ahmad\AppData\Local\Temp\__autograph_generated_file3vomspah.py", line 167, in if_body_7
        llr = ag__.converted_call(ag__.ld(tf).reshape, (ag__.ld(llr), ag__.converted_call(ag__.ld(tf).shape, (ag__.ld(c),), None, fscope)), None, fscope)

    ValueError: Exception encountered when calling E2ESystemCDLTraining.call().
    
    [1min user code:
    
        File "C:\Users\Ahmad\AppData\Local\Temp\ipykernel_9808\4048474048.py", line 208, in call  *
            llr = tf.reshape(llr, tf.shape(c))
    
        ValueError: Tried to convert 'tensor' to a tensor and failed. Error: None values not supported.
    [0m
    
    Arguments received by E2ESystemCDLTraining.call():
      • batch_size=tf.Tensor(shape=(), dtype=int32)
      • ebno_db=tf.Tensor(shape=(32,), dtype=float32)


In [9]:
class E2ESystemRLTraining(tf.keras.Model):
    def __init__(self, training, spatial_corr=None):
        super().__init__()
        self.training = training

        self.n = n 
        self.k = k  
        self.coderate = coderate
        self.num_bits_per_symbol = num_bits_per_symbol
        self.num_tx_ant = num_tx
        self.num_rx_ant = num_rx
        self.binary_source = BinarySource()
        if not self.training:
            self.encoder = LDPC5GEncoder(k, n, num_bits_per_symbol) 
        # Trainable constellation
        constellation = Constellation("qam", num_bits_per_symbol, trainable=True)
        self.constellation = constellation
        self.mapper = Mapper(constellation=constellation)
        self.channel = FlatFadingChannel(self.num_tx_ant,
                                         self.num_rx_ant,
                                         spatial_corr=spatial_corr,
                                         add_awgn=True,
                                         return_channel=True)
        self.demapper = NeuralDemapper()
        # To reduce the computational complexity of training, the outer code is not used when training,
        # as it is not required
        if not self.training:
            self.decoder = LDPC5GDecoder(self.encoder, hard_out=True)
            
        #################
        # Loss function
        #################
        if self.training:
            self.bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)
     
    def call(self, batch_size, ebno_db, perturbation_variance=tf.constant(0.0, tf.float32)):
        if self.training:
            c = self.binary_source([batch_size, self.num_tx_ant, self.n])
        else:
            b = self.binary_source([batch_size, self.num_tx_ant, self.k])
            c = self.encoder(b)
            
        x = self.mapper(c)
        shape = tf.shape(x)
        x = tf.reshape(x, [-1, self.num_tx_ant])
        
        epsilon_r = tf.random.normal(tf.shape(x))*tf.sqrt(0.5*perturbation_variance)
        epsilon_i = tf.random.normal(tf.shape(x))*tf.sqrt(0.5*perturbation_variance)
        epsilon = tf.complex(epsilon_r, epsilon_i) # [batch size, num_symbols_per_codeword]
        x_p = x + epsilon

        
        no = ebnodb2no(ebno_db, self.num_bits_per_symbol, self.coderate)
        no *= np.sqrt(self.num_rx_ant)
        
        y, h = self.channel([x_p, no])
        s = tf.complex(no*tf.eye(self.num_rx_ant, self.num_rx_ant), 0.0)
        
        x_hat, no_eff = lmmse_equalizer(y, h, s)

        x_hat = tf.reshape(x_hat, shape)
        no_eff = tf.reshape(no_eff, shape)

        llr = self.demapper([x_hat, no_eff])
        # If training, outer decoding is not performed and the BCE is returned
        if self.training:
            c = tf.reshape(c, [-1, num_symbols_per_codeword * self.num_tx_ant, num_bits_per_symbol])
            llr = tf.reshape(llr, [-1, num_symbols_per_codeword * self.num_tx_ant, num_bits_per_symbol])
            bce = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(c, llr), axis=2) # Avergare over the bits mapped to a same baseband symbol
            # The RX loss is the usual average BCE
            rx_loss = tf.reduce_mean(bce)
            # From the TX side, the BCE is seen as a feedback from the RX through which backpropagation is not possible
            bce = tf.stop_gradient(bce) # [batch size, num_symbols_per_codeword]
            x_p = tf.stop_gradient(x_p)
            p = x_p-x # [batch size, num_symbols_per_codeword] Gradient is backpropagated through `x`
            p = tf.reshape(p, bce.shape)
            tx_loss = tf.square(tf.math.real(p)) + tf.square(tf.math.imag(p)) # [batch size, num_symbols_per_codeword]
            tx_loss = -bce*tx_loss/rl_perturbation_var # [batch size, num_symbols_per_codeword]
            tx_loss = tf.reduce_mean(tx_loss)
            return tx_loss, rx_loss
        else:
            # Outer decoding
            llr = tf.reshape(llr, [batch_size, self.num_tx_ant, self.n])
            b_hat = self.decoder(llr)
            return b, b_hat # Ground truth and reconstructed information bits returned for BER/BLER computation

def rl_based_training(model):
    # Optimizers used to apply gradients
    optimizer_tx = tf.keras.optimizers.Adam() # For training the transmitter
    optimizer_rx = tf.keras.optimizers.Adam() # For training the receiver

    # Function that implements one transmitter training iteration using RL.
    @tf.function
    def train_tx():
        # Sampling a batch of SNRs
        ebno_db = tf.random.uniform(shape=[4], minval=ebno_db_min, maxval=ebno_db_max)
        # Forward pass
        for no in ebno_db:
            with tf.GradientTape() as tape:
                # Keep only the TX loss
                tx_loss, _ = model(4, no,
                                   tf.constant(rl_perturbation_var, tf.float32)) # Perturbation are added to enable RL exploration
            ## Computing and applying gradients
            weights = model.trainable_weights
            grads = tape.gradient(tx_loss, weights)
            optimizer_tx.apply_gradients(zip(grads, weights))
    
    # Function that implements one receiver training iteration
    @tf.function
    def train_rx():
        # Sampling a batch of SNRs
        ebno_db = tf.random.uniform(shape=[4], minval=ebno_db_min, maxval=ebno_db_max)
        # Forward pass
        for no in ebno_db:
            with tf.GradientTape() as tape:
                # Keep only the RX loss
                _, rx_loss = model(4, no) # No perturbation is added
            ## Computing and applying gradients
            weights = model.trainable_weights
            grads = tape.gradient(rx_loss, weights)
            optimizer_rx.apply_gradients(zip(grads, weights))
    
    # Training loop.
    for i in tqdm(pbar := tqdm(range(num_training_iterations_rl_alt))):
        # 10 steps of receiver training are performed to keep it ahead of the transmitter
        # as it is used for computing the losses when training the transmitter
        for _ in range(10):
            train_rx()
        # One step of transmitter training
        # train_tx()     
        # get progress training
        ebno_db = tf.random.uniform(shape=[3], minval=ebno_db_min, maxval=ebno_db_max)
        with tf.GradientTape() as tape:
            _, rx_loss = model(4, ebno_db[1])
        # Printing periodically the progress
        txt = f"BCE: {rx_loss.numpy()}"
        pbar.set_description(txt)
    
    # Once alternating training is done, the receiver is fine-tuned.
    print('Receiver fine-tuning... ')
    for i in tqdm(pbar := tqdm(range(num_training_iterations_rl_finetuning))):
        train_rx()

In [10]:
# Fix the seed for reproducible trainings
tf.random.set_seed(1)
# Instantiate and train the end-to-end system
model = E2ESystemRLTraining(training=True)
rl_based_training(model)
# Save weights
save_weights(model, model_weights_path_rl_training)

NameError: name 'num_tx' is not defined

In [11]:
class E2ESystemMetaRLTraining(tf.keras.Model):
    def __init__(self, training, spatial_corr=None):
        super().__init__()
        self.training = training

        self.n = n 
        self.k = k  
        self.coderate = coderate
        self.num_bits_per_symbol = num_bits_per_symbol
        self.num_tx_ant = num_tx
        self.num_rx_ant = num_rx
        self.binary_source = BinarySource()
        if not self.training:
            self.encoder = LDPC5GEncoder(k, n, num_bits_per_symbol) 
        # Trainable constellation
        constellation = Constellation("qam", num_bits_per_symbol, trainable=True)
        self.constellation = constellation
        self.mapper = Mapper(constellation=constellation)
        self.channel = FlatFadingChannel(self.num_tx_ant,
                                         self.num_rx_ant,
                                         spatial_corr=spatial_corr,
                                         add_awgn=True,
                                         return_channel=True)
        self.demapper = NeuralDemapper()
        # To reduce the computational complexity of training, the outer code is not used when training,
        # as it is not required
        if not self.training:
            self.decoder = LDPC5GDecoder(self.encoder, hard_out=True)
            
        #################
        # Loss function
        #################
        if self.training:
            self.bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)
     
    def call(self, batch_size, ebno_db, perturbation_variance=tf.constant(0.0, tf.float32)):
        if self.training:
            c = self.binary_source([batch_size, self.num_tx_ant, self.n])
        else:
            b = self.binary_source([batch_size, self.num_tx_ant, self.k])
            c = self.encoder(b)
            
        x = self.mapper(c)
        shape = tf.shape(x)
        x = tf.reshape(x, [-1, self.num_tx_ant])
        
        epsilon_r = tf.random.normal(tf.shape(x))*tf.sqrt(0.5*perturbation_variance)
        epsilon_i = tf.random.normal(tf.shape(x))*tf.sqrt(0.5*perturbation_variance)
        epsilon = tf.complex(epsilon_r, epsilon_i) # [batch size, num_symbols_per_codeword]
        x_p = x + epsilon

        
        no = ebnodb2no(ebno_db, self.num_bits_per_symbol, self.coderate)
        no *= np.sqrt(self.num_rx_ant)
        
        y, h = self.channel([x_p, no])
        s = tf.complex(no*tf.eye(self.num_rx_ant, self.num_rx_ant), 0.0)
        
        x_hat, no_eff = lmmse_equalizer(y, h, s)

        x_hat = tf.reshape(x_hat, shape)
        no_eff = tf.reshape(no_eff, shape)

        llr = self.demapper([x_hat, no_eff])
        # If training, outer decoding is not performed and the BCE is returned
        if self.training:
            c = tf.reshape(c, [-1, num_symbols_per_codeword * self.num_tx_ant, num_bits_per_symbol])
            llr = tf.reshape(llr, [-1, num_symbols_per_codeword * self.num_tx_ant, num_bits_per_symbol])
            bce = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(c, llr), axis=2) # Avergare over the bits mapped to a same baseband symbol
            # The RX loss is the usual average BCE
            rx_loss = tf.reduce_mean(bce)
            # From the TX side, the BCE is seen as a feedback from the RX through which backpropagation is not possible
            bce = tf.stop_gradient(bce) # [batch size, num_symbols_per_codeword]
            x_p = tf.stop_gradient(x_p)
            p = x_p-x # [batch size, num_symbols_per_codeword] Gradient is backpropagated through `x`
            p = tf.reshape(p, bce.shape)
            tx_loss = tf.square(tf.math.real(p)) + tf.square(tf.math.imag(p)) # [batch size, num_symbols_per_codeword]
            tx_loss = -bce*tx_loss/rl_perturbation_var # [batch size, num_symbols_per_codeword]
            tx_loss = tf.reduce_mean(tx_loss)
            return tx_loss, rx_loss
        else:
            # Outer decoding
            llr = tf.reshape(llr, [batch_size, self.num_tx_ant, self.n])
            b_hat = self.decoder(llr)
            return b, b_hat # Ground truth and reconstructed information bits returned for BER/BLER computation


def copy_model(model):
    adapted_model = tf.keras.models.clone_model(model)
    adapted_model(1, tf.constant(10.0, tf.float32))
    adapted_model.set_weights(model.get_weights())
    return adapted_model

def inner_train(model, task_snr_values, inner_interations = 3):
    # return grads & avg loss
    def get_avg_grads(grads):
        _grads = []
        for j in range(len(grads[0])):
            __grads = []
            for x in grads:
                __grads.append(x[j])
            __grads = tf.math.reduce_mean(__grads, axis=0)
            _grads.append(
                tf.Variable(__grads)
            )
        return _grads
            
    inner_loss = []
    tx_grad = []
    rx_grad = []
    
    adapted_model = copy_model(model)
    for _ in range(inner_interations): 
        for snr in task_snr_values:
            with tf.GradientTape() as tx_tape:
                _, rx_loss = adapted_model(4, snr)
            weights = adapted_model.demapper.trainable_variables
            grads = tx_tape.gradient(rx_loss, weights)
            rx_grad.append(grads)    
        
            inner_loss.append(rx_loss.numpy())
            
            with tf.GradientTape() as rx_tape:
                tx_loss, _ = adapted_model(4, snr, tf.constant(rl_perturbation_var))
            weights = adapted_model.constellation.trainable_variables
            grads = rx_tape.gradient(tx_loss, weights)
            tx_grad.append(grads)    
        
    return np.mean(inner_loss), get_avg_grads(tx_grad), get_avg_grads(rx_grad) 

def meta_rl_training(model):
    lr_outer = 0.01
    # Meta-training loop
    num_training_iterations_meta_rl = 1000
    for i in tqdm(pbar := tqdm(range(num_training_iterations_meta_rl))):
        optimizer_constellation = tf.keras.optimizers.Adam()
        optimizer_demapper = tf.keras.optimizers.Adam()
        fixed_snr = ebno_db_max
        task_snr_values = tf.random.uniform([8], minval=ebno_db_min, maxval=ebno_db_max)  
        
        inner_loss, tx_grad, rx_grads = inner_train(model, task_snr_values)
        weights = model.demapper.trainable_variables  
        optimizer_demapper.apply_gradients(zip(rx_grads, weights))
        # weights = model.constellation.trainable_variables  
        # optimizer_constellation.apply_gradients(zip(tx_grad, weights))
        
        # manually update, due to not able to use the optimizer directly
        # model.constellation.set_weights(
        #     tf.subtract(model.constellation.weights, tf.multiply(lr_outer, tx_grad[0])),
        # )
        # k = 0
        # for j in range(len(model.demapper.layers)):
        #     model.demapper.layers[j].set_weights([
        #         tf.subtract(model.demapper.layers[j].kernel, tf.multiply(lr_outer, rx_grads[k])),
        #         tf.subtract(model.demapper.layers[j].bias, tf.multiply(lr_outer, rx_grads[k+1])),
        #     ])
        #     k += 2
            
        # for the outerloop, train with variety of SNRs, but very few data
        task_snr_values = tf.random.uniform([4], minval=ebno_db_max, maxval=ebno_db_max) 
        for snr in task_snr_values:
            # with tf.GradientTape() as tx_tape:
            #     tx_loss, _ = model(4, snr, tf.constant(rl_perturbation_var))
            # weights = model.constellation.trainable_variables  
            # grads = tx_tape.gradient(tx_loss, weights)
            # optimizer_constellation.apply_gradients(zip(grads, weights))  
    
            with tf.GradientTape() as rx_tape:
                _, rx_loss = model(4, snr)
            weights = model.demapper.trainable_variables  
            grads = rx_tape.gradient(rx_loss, weights)
            optimizer_demapper.apply_gradients(zip(grads, weights))
        
        # txt = f"Inner BCE: {inner_loss} || Tx BCE: {tx_loss.numpy()} - Rx BCE: {rx_loss.numpy()}"
        txt = f"Avg BCE: {inner_loss}"

        pbar.set_description(txt)
        
        if (i%50)==0:
            print(f"Epoch {i}:\t" + txt)
            save_weights(meta_model, model_weights_path_metarl_training)
        # break
    print("Meta-RL Training complete.")

# Now, use the modified code in place of the original RL-based training
tf.random.set_seed(1)
# init model
meta_model = E2ESystemMetaRLTraining(training=True)
meta_model(1, tf.Variable(5.0))
# load pre-trained model before -- you can use this to continously train the model if the memory capacity is not enough # Previous total iterations = 1k
# load_weights(meta_model, model_weights_path_metarl_training) 
# train meta-model
meta_rl_training(meta_model)
save_weights(meta_model, model_weights_path_metarl_training)

NameError: name 'num_tx' is not defined

In [None]:
# BLER = {}
# BER = {}

# # Range of SNRs over which the systems are evaluated
# ebno_dbs = np.arange(ebno_db_min, # Min SNR for evaluation
#                      ebno_db_max, # Max SNR for evaluation
#                      0.5) # Step

# model_metarl = E2ESystemMetaRLTraining(training=False)
# load_weights(model_metarl, model_weights_path_metarl_training)
# ber,bler = sim_ber(model_metarl, ebno_dbs, batch_size=128, num_target_block_errors=1000, max_mc_iter=1000)
# BLER['autoencoder-metarl'] = bler.numpy()
# BER['autoencoder-metarl'] = ber.numpy()

In [12]:
# Dictionnary storing the results
# with open('base_result', 'rb') as f:
#     BLER = pickle.load(f)

BLER = {}
BER = {}

# Range of SNRs over which the systems are evaluated
ebno_dbs = np.arange(ebno_db_min, # Min SNR for evaluation
                     ebno_db_max, # Max SNR for evaluation
                     0.5) # Step

model_baseline = Baseline()
ber,bler = sim_ber(model_baseline, ebno_dbs, batch_size=128, num_target_block_errors=1000, max_mc_iter=1000)
BLER['baseline'] = bler.numpy()
BER['baseline'] = ber.numpy()


model_conventional = E2ESystemConventionalTraining(training=False)
load_weights(model_conventional, model_weights_path_conventional_training)
ber,bler = sim_ber(model_conventional, ebno_dbs, batch_size=128, num_target_block_errors=1000, max_mc_iter=1000)
BLER['autoencoder-conv'] = bler.numpy()
BER['autoencoder-conv'] = ber.numpy()

model_rl = E2ESystemRLTraining(training=False)
load_weights(model_rl, model_weights_path_rl_training)
ber,bler = sim_ber(model_rl, ebno_dbs, batch_size=128, num_target_block_errors=1000, max_mc_iter=1000)
BLER['autoencoder-rl'] = bler.numpy()
BER['autoencoder-rl'] = ber.numpy()

model_metarl = E2ESystemMetaRLTraining(training=False)
load_weights(model_metarl, model_weights_path_metarl_training)
ber,bler = sim_ber(model_metarl, ebno_dbs, batch_size=128, num_target_block_errors=1000, max_mc_iter=1000)
BLER['autoencoder-metarl'] = bler.numpy()
BER['autoencoder-metarl'] = ber.numpy()

with open(results_filename, 'wb') as f:
    pickle.dump((ebno_dbs, BLER), f)

NameError: name 'Baseline' is not defined

In [13]:
plt.figure(figsize=(10,8))
plt.semilogy(ebno_dbs, BLER['baseline'], 'o-', c=f'C0', label=f'Baseline')    
plt.semilogy(ebno_dbs, BLER['autoencoder-conv'], 'x-.', c=f'C1', label=f'Autoencoder - conventional training')
plt.semilogy(ebno_dbs, BLER['autoencoder-rl'], 'o-.', c=f'C2', label=f'Autoencoder - RL-based training')
plt.semilogy(ebno_dbs, BLER['autoencoder-metarl'], 'o-.', c=f'C3', label=f'Autoencoder - Meta-RL-based training')

plt.xlabel(r"$E_b/N_0$ (dB)")
plt.ylabel("BLER")
plt.grid(which="both")
plt.ylim((1e-5, 1.0))
plt.legend()
plt.tight_layout()

KeyError: 'baseline'

<Figure size 1000x800 with 0 Axes>

In [14]:
plt.figure(figsize=(10,8))
plt.semilogy(ebno_dbs, BER['baseline'], 'o-', c=f'C0', label=f'Baseline')    
plt.semilogy(ebno_dbs, BER['autoencoder-conv'], 'x-.', c=f'C1', label=f'Autoencoder - conventional training')
plt.semilogy(ebno_dbs, BER['autoencoder-rl'], 'o-.', c=f'C2', label=f'Autoencoder - RL-based training')
plt.semilogy(ebno_dbs, BER['autoencoder-metarl'], 'o-.', c=f'C3', label=f'Autoencoder - Meta-RL-based training')

plt.xlabel(r"$E_b/N_0$ (dB)")
plt.ylabel("BER")
plt.grid(which="both")
plt.ylim((1e-8, 1.0))
plt.legend()
plt.tight_layout()

KeyError: 'baseline'

<Figure size 1000x800 with 0 Axes>