# Unit Test of the TRT Engine Using the Sionna Encoder

In [3]:
import itertools

import numpy as np
import tensorflow as tf

from sionna.phy.channel import AWGN
from sionna.phy.mapping import Demapper
from sionna.phy.nr import PUSCHConfig, PUSCHTransmitter, PUSCHReceiver
from sionna.phy.utils import hard_decisions


2025-04-10 17:07:47.241290: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9373] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-04-10 17:07:47.241359: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-04-10 17:07:47.242790: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1534] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [24]:
# Shared AWGN layer; generate_test_data() uses it both for the received
# signal and for the noisy pilot observations.
channel = AWGN()

def generate_test_data(num_ofdm_symbols, num_prbs, num_rx_ant,
                       num_add_pilot_positions, scaling_factor):
    """Generate one noisy PUSCH slot plus the matching reference data.

    A single-user PUSCH slot is produced with Sionna's ``PUSCHTransmitter``,
    replicated across the receive antennas, passed through the module-level
    AWGN ``channel`` and finally multiplied by ``scaling_factor`` and a
    random common phase.

    Parameters
    ----------
    num_ofdm_symbols : int
        Number of allocated OFDM symbols (second entry of
        ``symbol_allocation``; the allocation always starts at symbol 0).
    num_prbs : int
        Bandwidth part size in PRBs (``n_size_bwp``).
    num_rx_ant : int
        Number of receive antennas; the transmitted slot is tiled this many
        times before noise is added.
    num_add_pilot_positions : int
        Value assigned to ``dmrs.additional_position``.
    scaling_factor : float
        Real gain applied (together with the random phase) to both outputs.

    Returns
    -------
    y : tf.Tensor, complex
        Noisy, scaled receive signal of shape
        ``[num_rx_ant, num_ofdm_symbols, num_subcarriers]``.
    ch_est_n : tf.Tensor, complex
        1-D tensor with the noisy, scaled values taken at the pilot
        positions (see the review note in the body).
    bits_rg : tf.Tensor
        Hard-decided bits of the noise-free resource grid, shape
        ``[num_ofdm_symbols, num_subcarriers, num_bits_per_symbol]``.
    dmrs_ofdm_pos : np.ndarray
        Sorted indices of the OFDM symbols that carry DMRS.
    prb_pilot_pos : np.ndarray
        Sorted pilot subcarrier indices inside the first PRB (values < 12).
    """
    # Set the PUSCH configuration parameters
    pusch_config = PUSCHConfig()
    pusch_config.mapping_type = "B" # allows DMRS at pos 0,5,10
    pusch_config.dmrs.num_cdm_groups_without_data = 1 # no multi-user
    pusch_config.dmrs.dmrs_port_set = [0] # we use only first port (no multi-user)
    pusch_config.dmrs.additional_position = num_add_pilot_positions
    pusch_config.symbol_allocation = [0, num_ofdm_symbols]
    pusch_config.n_size_bwp = num_prbs
    pusch_config.n_rnti = 1
    pusch_config.carrier.subcarrier_spacing = 30
    pusch_config.carrier.slot_number = 0
    pusch_config.dmrs.n_id = 0
    pusch_config.dmrs.n_scid = 0
    pusch_config.tb.n_id = 0

    # Only the transmitter is required to build the test vectors; no receiver
    # is instantiated because nothing in this function would use it.
    pusch_transmitter = PUSCHTransmitter(pusch_config)

    # (ofdm symbol, subcarrier) index pairs of all pilots of the first
    # transmitter / first stream
    pilot_pos = tf.where(pusch_transmitter.pilot_pattern.mask[0,0,:]==1).numpy()

    # np.unique already returns its result sorted, so no extra sort is needed
    dmrs_ofdm_pos = np.unique(pilot_pos[:,0])

    # pilot subcarriers restricted to the first PRB (12 subcarriers)
    prb_pilot_pos = np.unique(pilot_pos[:,1])
    prb_pilot_pos = prb_pilot_pos[prb_pilot_pos < 12]

    # batch size 1; the information bits returned alongside are not needed
    x, _ = pusch_transmitter(1)

    slot = tf.squeeze(x, axis=(0,1)) # remove unused dimensions

    # ground truth bits in resourcegrid; the modulation order is taken from
    # the transport block config so that it always matches the transmitter
    demapper = Demapper("maxlog", "qam", pusch_config.tb.num_bits_per_symbol)
    bits_rg = hard_decisions(demapper(slot[...,tf.newaxis], 0.01))
    bits_rg = tf.squeeze(bits_rg, axis=0)

    # duplicated antennas (for testing only)
    slot = tf.tile(slot, [num_rx_ant,1,1])

    # add noise
    no = 0.1
    y = channel(slot, no)
    # y has shape [num_rx_ant, num_ofdm_symbols, num_subcarriers]

    # generate noisy channel estimates
    # The values are gathered from the noise-free `slot`, noise is added
    # afterwards with its own variance.
    # NOTE(review): the leading index produced by tf.where is the stream
    # index of the mask and is used here to index the antenna axis of `slot`,
    # so with a single stream only antenna 0 is gathered - confirm intended.
    p_idx = tf.where(pusch_transmitter.pilot_pattern.mask[0,:]==1)
    ch_est = tf.gather_nd(slot, p_idx)
    no_chest = 0.1
    ch_est_n = channel(ch_est, no_chest)

    # draw random complex phase rotation in tensorflow
    # Generate random phase between 0 and 2pi
    phase = tf.random.uniform([1,], minval=0.0,
                              maxval=2*np.pi, dtype=tf.float32)
    phase = tf.exp(tf.complex(0., 1.) * tf.cast(phase, tf.complex64))

    # apply the same scaling and phase rotation to signal and pilots
    gain = scaling_factor * phase
    y *= gain
    ch_est_n *= gain

    return y, ch_est_n, bits_rg, dmrs_ofdm_pos, prb_pilot_pos


In [25]:
# Parameter grid for the sweep; every combination is generated once.
num_ofdm_symbols = [3, 5, 13] # between 1 and 14
num_prbs = [1, 5, 106] # between 1 and 273
num_rx_ant = [1, 2, 4]
scaling_factor = [5.0, 10.0, 20.0] # scale outputs
num_add_pilot_positions = [0, 1, 2, 3]

# A PRB spans 12 subcarriers
SUBCARRIERS_PER_PRB = 12

# itertools.product iterates with the last argument varying fastest, i.e. in
# the same order as the equivalent five nested for-loops.
for (num_ofdm_symbols_, num_prbs_, num_rx_ant_,
     scaling_factor_, num_add_pilot_positions_) in itertools.product(
         num_ofdm_symbols, num_prbs, num_rx_ant,
         scaling_factor, num_add_pilot_positions):

    y, ch_est_n, bits_rg, dmrs_ofdm_pos, prb_pilot_pos = generate_test_data(
        num_ofdm_symbols_, num_prbs_, num_rx_ant_,
        num_add_pilot_positions_, scaling_factor_)

    # Sanity checks so that a broken configuration fails loudly instead of
    # only producing unexpected printouts.
    num_subcarriers_ = SUBCARRIERS_PER_PRB * num_prbs_
    assert tuple(y.shape) == (num_rx_ant_, num_ofdm_symbols_, num_subcarriers_), \
        f"unexpected shape of y: {y.shape}"
    assert tuple(bits_rg.shape[:2]) == (num_ofdm_symbols_, num_subcarriers_), \
        f"unexpected shape of bits_rg: {bits_rg.shape}"
    assert len(ch_est_n.shape) == 1, \
        f"ch_est_n is expected to be 1-D, got shape {ch_est_n.shape}"

    print("--------------------------------")
    print("y:", y.shape)
    print("ch_est:", ch_est_n.shape)
    print("bits_rg:", bits_rg.shape)
    print("dmrs_ofdm_pos:", dmrs_ofdm_pos)
    print("prb_pilot_pos:", prb_pilot_pos)


--------------------------------
y: (1, 3, 12)
ch_est: (6,)
bits_rg: (3, 12, 4)
dmrs_ofdm_pos: [0]
prb_pilot_pos: [ 0  2  4  6  8 10]
--------------------------------
y: (1, 3, 12)
ch_est: (6,)
bits_rg: (3, 12, 4)
dmrs_ofdm_pos: [0]
prb_pilot_pos: [ 0  2  4  6  8 10]
--------------------------------
y: (1, 3, 12)
ch_est: (6,)
bits_rg: (3, 12, 4)
dmrs_ofdm_pos: [0]
prb_pilot_pos: [ 0  2  4  6  8 10]
--------------------------------
y: (1, 3, 12)
ch_est: (6,)
bits_rg: (3, 12, 4)
dmrs_ofdm_pos: [0]
prb_pilot_pos: [ 0  2  4  6  8 10]
--------------------------------
y: (1, 3, 12)
ch_est: (6,)
bits_rg: (3, 12, 4)
dmrs_ofdm_pos: [0]
prb_pilot_pos: [ 0  2  4  6  8 10]
--------------------------------
y: (1, 3, 12)
ch_est: (6,)
bits_rg: (3, 12, 4)
dmrs_ofdm_pos: [0]
prb_pilot_pos: [ 0  2  4  6  8 10]
--------------------------------
y: (1, 3, 12)
ch_est: (6,)
bits_rg: (3, 12, 4)
dmrs_ofdm_pos: [0]
prb_pilot_pos: [ 0  2  4  6  8 10]
--------------------------------
y: (1, 3, 12)
ch_est: (6,)
bi

KeyboardInterrupt: 