# Using pyAerial for PUSCH decoding on Aerial Data Lake data
This example shows how to use the pyAerial bindings to run cuPHY GPU-accelerated decoding of the 5G NR PUSCH. The PUSCH data is read from an example over-the-air PUSCH dataset that was captured and stored using Aerial Data Lake. The PUSCH receiver is built from separate Python calls to the individual pyAerial receiver components, so that intermediate results such as channel estimates can be shown.

In this setup there are two cells: 

- Cell 51 is an active cell being used with the OAI L2+ using slot pattern DDDSU and commercial UEs.

- Cell 41 is a passive "listener" cell using the same slot pattern as cell 51 but is being driven by testmac to request IQ samples for every uplink slot. 

There are multiple UEs connected to cell 51, as shown in the `nUEs` field of the fapi table.
The first plots show the power in all of the resource elements of the given slot for both cells.
The subsequent plots show only the resource elements scheduled for a given UE (RNTI) on both cells, as well as the pre- and post-equalization samples and the channel estimates for the IQ samples from each cell. These are followed by text indicating whether the transmission was decoded successfully.

**Note:** This example requires that the ClickHouse server is running and that the example data has been stored in the database. Refer to the Aerial Data Lake documentation for how to do this.
https://docs.nvidia.com/aerial/cuda-accelerated-ran/aerial_data_lake/index.html#clickhouse-client

## Imports

In [None]:
import math
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import numpy as np
import pandas as pd
from IPython.display import Markdown
from IPython.display import display

# Connecting to clickhouse on remote server
import clickhouse_connect

# Plotting with Matplotlib.
import matplotlib.pyplot as plt
from matplotlib import dates as mdates


# pyAerial imports
from aerial.phy5g.config import PuschConfig
from aerial.phy5g.config import PuschUeConfig
from aerial.phy5g.algorithms import ChannelEstimator
from aerial.phy5g.algorithms import ChannelEqualizer
from aerial.phy5g.algorithms import NoiseIntfEstimator
from aerial.phy5g.ldpc import LdpcDeRateMatch
from aerial.phy5g.ldpc import LdpcDecoder
from aerial.phy5g.ldpc import CrcChecker
from aerial.util.cuda import get_cuda_stream
from aerial.util.fapi import dmrs_fapi_to_bit_array

# Silence the divide-by-zero / invalid-value floating point warnings that
# np.log10 emits for zero-power resource elements in the power plots.
_ = np.seterr(divide='ignore', invalid='ignore')

# pandas: print wide FAPI tables without truncating rows or columns.
pd.set_option('display.max_rows', 500)
pd.set_option('display.max_columns', 500)
pd.set_option('display.width', 1000)

# numpy: summarise arrays with more than 100 elements, showing 100 items at
# each edge, on lines up to 200 characters wide.
np.set_printoptions(threshold=100, edgeitems=100, linewidth=200)

# Default matplotlib figure size (width, height).
plt.rcParams['figure.figsize'] = [10, 4]

## Create the PUSCH pipelines
This is a PUSCH receiver pipeline made up of separately called pyAerial PUSCH receiver components.

In [None]:
# Plot intermediate PUSCH pipeline results (channel estimates, equalized symbols, ...).
plot_figures = True

# Link configuration: number of UEs, UE (transmit) antennas and gNB (receive) antennas.
# NOTE(review): num_ues and num_tx_ant are not referenced by the code in this notebook.
num_ues, num_tx_ant, num_rx_ant = 1, 2, 4

# Equalizer settings.
enable_pusch_tdi = 0   # time interpolation of the equalizer coefficients
eq_coeff_algo = 1      # equalizer algorithm selector

# The PUSCH receiver chain built from separately called pyAerial Python components is defined here.
class PuschRxSeparate:
    """PUSCH receiver class.

    This class encapsulates the whole PUSCH receiver chain built using
    separately called pyAerial components: channel estimation, noise and
    interference estimation, channel equalization and soft demapping, LDPC
    de-rate matching, LDPC decoding and CRC checking. All components are
    created on the same CUDA stream.
    """

    def __init__(self,
                 num_rx_ant,
                 enable_pusch_tdi,
                 eq_coeff_algo,
                 plot_figures):
        """Initialize the PUSCH receiver.

        Args:
            num_rx_ant: Number of gNB receive antennas.
            enable_pusch_tdi: Forwarded to ChannelEqualizer (time interpolation
                of the equalizer coefficients).
            eq_coeff_algo: Equalizer algorithm selector, forwarded to both the
                equalizer and the noise/interference estimator.
            plot_figures: Whether run() plots intermediate results.
        """
        self.cuda_stream = get_cuda_stream()
        # Remembered so the diagnostic plots are sized from the configuration
        # rather than from a hard-coded antenna count.
        self.num_rx_ant = num_rx_ant

        # Build the components of the receiver.
        self.channel_estimator = ChannelEstimator(
            num_rx_ant=num_rx_ant,
            cuda_stream=self.cuda_stream)
        self.channel_equalizer = ChannelEqualizer(
            num_rx_ant=num_rx_ant,
            enable_pusch_tdi=enable_pusch_tdi,
            eq_coeff_algo=eq_coeff_algo,
            cuda_stream=self.cuda_stream)
        self.noise_intf_estimator = NoiseIntfEstimator(
            num_rx_ant=num_rx_ant,
            eq_coeff_algo=eq_coeff_algo,
            cuda_stream=self.cuda_stream)
        self.derate_match = LdpcDeRateMatch(
            enable_scrambling=True,
            cuda_stream=self.cuda_stream)
        self.decoder = LdpcDecoder(cuda_stream=self.cuda_stream)
        self.crc_checker = CrcChecker(cuda_stream=self.cuda_stream)

        # Whether to plot the intermediate results.
        self.plot_figures = plot_figures

    @staticmethod
    def _re_window(pusch_config, num_sc_per_prb=12):
        """Return the (first, end) resource-element indices of the PRB allocation.

        `end` is exclusive: the allocation covers REs first .. end - 1.
        """
        first = pusch_config.start_prb * num_sc_per_prb
        end = (pusch_config.start_prb + pusch_config.num_prbs) * num_sc_per_prb
        return first, end

    def _plot_intermediate(self, rx_slot, sym, ch_est, pusch_configs, cell_num):
        """Plot RE power, pre-/post-equalized samples and channel estimates.

        The allocation window and RNTI come from the first PUSCH config and its
        first UE, so the plots depend only on the arguments of run().
        """
        pusch_config = pusch_configs[0]
        rnti = pusch_config.ue_configs[0].rnti
        first_re, end_re = self._re_window(pusch_config)

        # squeeze=False keeps `axs` 2-D, so indexing also works with one antenna.
        fig, axs = plt.subplots(1, self.num_rx_ant, squeeze=False)
        for ant, ax in enumerate(axs[0]):
            ax.imshow(10*np.log10(np.abs(rx_slot[:, :, ant]**2)), aspect='auto')
            # Larger RE index at the bottom, matching imshow's top-left origin.
            ax.set_ylim([end_re, first_re])
            ax.set_title('Ant ' + str(ant))
            ax.set(xlabel='Symbol', ylabel='Resource Element')
            ax.label_outer()
        fig.suptitle('Power in PUSCH REs cell {} for RNTI {}'.format(cell_num, rnti))

        fig, axs = plt.subplots(1, 2)
        axs[0].scatter(rx_slot.reshape(-1).real, rx_slot.reshape(-1).imag)
        axs[0].set_title("Pre-Equalized samples")
        axs[0].set_aspect('equal')

        sym_flat = np.array(sym).reshape(-1)
        axs[1].scatter(sym_flat.real, sym_flat.imag)
        axs[1].set_title("Post-Equalized samples")
        axs[1].set_aspect('equal')

        fig, ax = plt.subplots(1)
        ax.set_title("Channel estimates from the PUSCH pipeline")
        # Presumably ch_est[0] is indexed [rx antenna, layer, subcarrier, estimate];
        # only the first layer and first estimate are plotted (TODO confirm layout).
        for ant in range(self.num_rx_ant):
            ax.plot(np.abs(ch_est[0][ant, 0, :, 0]))
        ax.legend(["Rx antenna {}, estimate".format(ant) for ant in range(self.num_rx_ant)])
        ax.grid(True)
        plt.show()

    def run(
        self,
        rx_slot,
        slot,
        pusch_configs,
        cell_num
    ):
        """Run the receiver on one slot.

        Args:
            rx_slot: Received IQ samples of the slot; the plots index it as
                [resource element, symbol, antenna].
            slot: Slot number.
            pusch_configs: List of PuschConfig objects (one per UE group); the
                plots use the first one.
            cell_num: Cell identifier, used only in plot titles.

        Returns:
            The decoded transport blocks returned by the CRC checker.
        """
        # Channel estimation.
        ch_est = self.channel_estimator.estimate(
            rx_slot=rx_slot,
            slot=slot,
            pusch_configs=pusch_configs
        )

        # Noise and interference estimation.
        lw_inv, noise_var_pre_eq = self.noise_intf_estimator.estimate(
            rx_slot=rx_slot,
            channel_est=ch_est,
            slot=slot,
            pusch_configs=pusch_configs
        )

        # Channel equalization and soft demapping. Returns the LLRs, which feed
        # the decoder, and the equalized symbols, which are only used for plotting.
        llrs, sym = self.channel_equalizer.equalize(
            rx_slot=rx_slot,
            channel_est=ch_est,
            lw_inv=lw_inv,
            noise_var_pre_eq=noise_var_pre_eq,
            pusch_configs=pusch_configs
        )

        if self.plot_figures:
            self._plot_intermediate(rx_slot, sym, ch_est, pusch_configs, cell_num)

        coded_blocks = self.derate_match.derate_match(
            input_llrs=llrs,
            pusch_configs=pusch_configs
        )

        code_blocks = self.decoder.decode(
            input_llrs=coded_blocks,
            pusch_configs=pusch_configs
        )

        decoded_tbs, _ = self.crc_checker.check_crc(
            input_bits=code_blocks,
            pusch_configs=pusch_configs
        )

        return decoded_tbs

# A single receiver instance, configured from the parameters above, is reused
# for every slot and cell processed below.
pusch_rx_separate = PuschRxSeparate(num_rx_ant=num_rx_ant, enable_pusch_tdi=enable_pusch_tdi,
                                    eq_coeff_algo=eq_coeff_algo, plot_figures=plot_figures)

## Querying the database
The cell below shows how to connect to the ClickHouse database and query data from it.

In [None]:
# Connect to the ClickHouse server on this machine. The import is repeated so
# that this cell can be re-run on its own.
import clickhouse_connect

client = clickhouse_connect.get_client(host='localhost')

# Fetch the first 10 rows of the fapi table, ordered by time and starting PRB.
pusch_records = client.query_df('select * from fapi order by TsTaiNs,rbStart limit 10')
# Uncomment to inspect the selected records:
#print(pusch_records[['TsTaiNs','SFN','Slot','nUEs','rnti','rbStart','rbSize','StartSymbolIndex','NrOfSymbols','CQI']])

## Extract the PUSCH parameters and run the pipelines
In this section we use the timestamp of the start of a slot to query the IQ sample database for both cells and demonstrate that the transmission can be decoded in the IQ samples of both cells.

In [None]:
def fh_to_rx_slot(fh_data, num_rx_ant=4, num_symbols=14, num_subcarriers=273 * 12):
    """Convert one fronthaul capture (one `fhData` value) into a receive-slot array.

    The stored int16 words are reinterpreted as float16, widened to float32
    and then paired into complex64 (I, Q) samples.

    Args:
        fh_data: int16 words of one cell's capture for one slot.
        num_rx_ant: Number of receive antennas in the capture.
        num_symbols: Number of OFDM symbols per slot.
        num_subcarriers: Number of subcarriers in the capture.

    Returns:
        complex64 array of shape (num_subcarriers, num_symbols, num_rx_ant).
    """
    fh_samp = np.array(fh_data, dtype=np.int16).view(np.float16).astype(np.float32)
    iq = fh_samp.view(np.complex64).reshape(num_rx_ant, num_symbols, num_subcarriers)
    return np.swapaxes(iq, 2, 0)


def plot_slot_power(fh):
    """Plot the power of every resource element in the slot, one subfigure per captured cell."""
    fig = plt.figure(constrained_layout=True)
    # squeeze=False always yields a 2-D array of subfigures, so a single cell
    # and several cells go through the same code path.
    cell_figs = fig.subfigures(1, fh.index.size, squeeze=False)
    for cell_idx, cell_fig in enumerate(cell_figs.flat):
        rx_slot = fh_to_rx_slot(fh['fhData'].iloc[cell_idx])
        axs = cell_fig.subplots(1, rx_slot.shape[2], squeeze=False)
        for ant, ax in enumerate(axs.flat):
            ax.imshow(10*np.log10(np.abs(rx_slot[:, :, ant]**2)), aspect='auto')
            ax.set_title('Ant ' + str(ant))
            ax.set(xlabel='Symbol', ylabel='Resource Element')
            ax.label_outer()
        cell_fig.suptitle('Power in RU Antennas Cell ' + str(fh['CellId'].iloc[cell_idx]))
    plt.show()


runReceiver = True
# Only show the full-slot power plots once per slot timestamp.
shownTsTaiNs = -1
pusch_records = client.query_df('select * from fapi order by TsTaiNs,rbStart limit 8')
print(pusch_records[['TsTaiNs','SFN','Slot','CellId','nUEs','rnti','rbStart','rbSize','StartSymbolIndex','NrOfSymbols','nrOfLayers','TBSize','CQI']])

for _, pusch_record in pusch_records.iterrows():
    slot_ts = pusch_record.TsTaiNs.timestamp()
    # IQ samples of every captured cell for the slot starting at this timestamp.
    query = f"""select TsTaiNs,CellId,fhData from fh where TsTaiNs == toDateTime64('{slot_ts}',9)"""
    fh = client.query_df(query)
    numCells = fh.index.size
    if numCells == 0:
        print("No IQ samples found for SFN.Slot {}.{} at time {}; skipping."
              .format(pusch_record.SFN, pusch_record.Slot, pusch_record.TsTaiNs))
        continue

    # Extract all the needed parameters from the PUSCH record and create the PuschConfig.
    pusch_ue_config = PuschUeConfig(
        scid=int(pusch_record.SCID),
        layers=pusch_record.nrOfLayers,
        dmrs_ports=pusch_record.dmrsPorts,
        rnti=pusch_record.rnti,
        data_scid=pusch_record.dataScramblingId,
        mcs_table=pusch_record.mcsTable,
        mcs_index=pusch_record.mcsIndex,
        code_rate=pusch_record.targetCodeRate,
        mod_order=pusch_record.qamModOrder,
        tb_size=pusch_record.TBSize
    )

    slot = int(pusch_record.Slot)
    tb_input = np.array(pusch_record.pduData)
    dmrs_symb_pos = int(pusch_record.ulDmrsSymbPos)

    # Note that this is a list. One UE group only in this case.
    pusch_configs = [PuschConfig(
        ue_configs=[pusch_ue_config],
        num_dmrs_cdm_grps_no_data=pusch_record.numDmrsCdmGrpsNoData,
        dmrs_scrm_id=pusch_record.ulDmrsScramblingId,
        start_prb=pusch_record.rbStart,
        num_prbs=pusch_record.rbSize,
        dmrs_syms=dmrs_fapi_to_bit_array(dmrs_symb_pos),
        dmrs_max_len=1,
        # ulDmrsSymbPos has one bit per DMRS symbol. With single-symbol DMRS the
        # number of additional positions is the DMRS symbol count minus one.
        # Counting every set bit avoids assuming the first DMRS is on symbol 2.
        dmrs_add_ln_pos=bin(dmrs_symb_pos).count('1') - 1,
        start_sym=pusch_record.StartSymbolIndex,
        num_symbols=pusch_record.NrOfSymbols
    )]

    if shownTsTaiNs != slot_ts:
        shownTsTaiNs = slot_ts
        display(Markdown("### SFN.Slot {}.{}, RNTI {} at time {}"
                         .format(pusch_record.SFN, pusch_record.Slot, pusch_record.rnti, pusch_record.TsTaiNs)))
        plot_slot_power(fh)

    if runReceiver:
        for cellNum in range(numCells):
            cellId = fh['CellId'].iloc[cellNum]
            tbs = pusch_rx_separate.run(
                rx_slot=fh_to_rx_slot(fh['fhData'].iloc[cellNum]),
                slot=slot,
                pusch_configs=pusch_configs,
                cell_num=cellId
            )
            decoded = tbs[0][:tb_input.size]
            outcome = "success" if np.array_equal(decoded, tb_input) else "failure"
            display(Markdown("**PUSCH decoding {}** for SFN.Slot {}.{} RNTI {} originally on cell {} using IQ data from cell {} "
                             .format(outcome, pusch_record.SFN, pusch_record.Slot, pusch_record.rnti, pusch_record.CellId, cellId)))
            if outcome == "failure":
                print("Output bytes:")
                print(decoded)
                print("Expected output:")
                print(tb_input)