In [1]:
"""Run two below lines to get my_abc module"""
# !git clone https://github.com/thanhttttt/thanh.git
# !pip install -r /content/thanh/requirements.txt
"""Run two below lines to drive"""
# from google.colab import drive
# drive.mount('/content/drive')

'Run two below lines to drive'

# Import

In [2]:
import os
if os.getenv("CUDA_VISIBLE_DEVICES") is None:
    gpu_num = 0 # Use "" to use the CPU
    os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu_num}"
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'


import sys
sys.path.append('../')
sys.path.append('/content/thanh/')
sys.path.append('../thanh/')

import sionna

import tensorflow as tf
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        tf.config.experimental.set_memory_growth(gpus[0], True)
    except RuntimeError as e:
        print(e)
# Avoid warnings from TensorFlow
tf.get_logger().setLevel('ERROR')

sionna.config.seed = 42 # Set seed for reproducible random number generation

# Load the required Sionna components
from sionna.nr.my_abc import *

# Load model weight

In [3]:
_model = CustomNeuralReceiver(training = False)
inputs = tf.zeros([1,48,14,18])
_model(inputs)
_model.summary()

#load_weights(_model, '/content/drive/MyDrive/Pusch_data/Model_weights/model_weight_FULL_RB_epoch_40.pkl')
# load_weights(_model, '../model_weight_FULL_RB_epoch_40.pkl')
load_weights(_model, '../weight_4RB_UMI_dynamic_config.pkl')

Model: "custom_neural_receiver"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d (Conv2D)             multiple                  20864     
                                                                 
 residual_block (ResidualBl  multiple                  639232    
 ock)                                                            
                                                                 
 residual_block_1 (Residual  multiple                  639232    
 Block)                                                          
                                                                 
 residual_block_2 (Residual  multiple                  639232    
 Block)                                                          
                                                                 
 residual_block_3 (Residual  multiple                  639232    
 Block)                                     

# Set up config

In [4]:
"""test setup"""
batch_size = 8

"""nrb config setup"""
RB_start = 0
NRB = 162
PCI = 443
RNTI = 40035
MCS = 8

"""channel setup"""
no = 2.
CDL_model = 'A'
delay_spread = 50
speed = 1

## create channel

In [5]:
channel_model = CDL(model = CDL_model,
                            delay_spread = delay_spread*1e-9,
                            carrier_frequency = CARRIER_FREQUENCY,
                            ut_array = Ue_Antenna,
                            bs_array = Gnb_AntennaArray,
                            direction = 'uplink',
                            min_speed = speed,
                            max_speed = speed)

## create 4RB samples

In [6]:
"""default config is 4RB"""
sysCfg = SystemConfig()
ueCfg = UeConfig()
myCfg = MyConfig(sysCfg, [ueCfg])
puschCfg = MyPUSCHConfig(myCfg)
# puschCfg.show() # uncomment for detail

In [7]:
simulator = MySimulator(puschCfg)
channel = OFDMChannel(channel_model=channel_model, resource_grid=simulator.resource_grid,
                                    add_awgn=False, normalize_channel=True, return_channel=True)

In [8]:
b, c, y, x ,h = simulator.sim(batch_size, channel, no, return_tx_iq=True, return_channel=True)
r = simulator.ref(batch_size)

## create N-RB samples

In [32]:
"""test setup"""
batch_size = 1

"""nrb config setup"""
RB_start = 22
NRB = 140
PCI = 442
RNTI = 30025
MCS = 0

"""channel setup"""
no = 2.
CDL_model = 'A'
delay_spread = 50
speed = 1
sysCfg = SystemConfig(
                    NCellId = PCI,
                    FrequencyRange = 1,
                    BandWidth = 60,
                    Numerology = 1,
                    CpType = 0,
                    NTxAnt = 1,
                    NRxAnt = 8,
                    BwpNRb = 162,
                    BwpRbOffset = 0,
                    harqProcFlag = 0,
                    nHarqProc = 1,
                    rvSeq = 0
                )
ueCfg = UeConfig(
                TransformPrecoding = 0,
                Rnti = RNTI,
                nId = PCI,
                CodeBookBased = 0,
                DmrsPortSetIdx = [0],
                NLayers = 1,
                NumDmrsCdmGroupsWithoutData = 2,
                Tpmi = 0,
                FirstSymb = 0,
                NPuschSymbAll = 14,
                RaType = 1,
                FirstPrb = RB_start,
                NPrb = NRB,
                FrequencyHoppingMode = 0,
                McsTable = 0,
                Mcs = MCS,
                ILbrm = 0,
                nScId = 0,
                NnScIdId = PCI,
                DmrsConfigurationType = 0,
                DmrsDuration = 1,
                DmrsAdditionalPosition = 1,
                PuschMappingType = 0,
                DmrsTypeAPosition = 3,
                HoppingMode = 0,
                NRsId = 0,
                Ptrs = 0,
                ScalingFactor = 0,
                OAck = 0,
                IHarqAckOffset = 11,
                OCsi1 = 0,
                ICsi1Offset = 7,
                OCsi2 = 0,
                ICsi2Offset = 0,
                NPrbOh = 0,
                nCw = 1,
                TpPi2Bpsk = 0
            )
myCfg = MyConfig(sysCfg, [ueCfg])
puschCfg = MyPUSCHConfig(myCfg, 4)
# puschCfg.show() # uncomment for detail

In [33]:
#
# SPDX-FileCopyrightText: Copyright (c) 2021-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Layers for LDPC channel encoding and utility functions."""

import tensorflow as tf
import numpy as np
import scipy as sp
from tensorflow.keras.layers import Layer
from importlib_resources import files, as_file
from sionna.fec.ldpc import codes # pylint: disable=relative-beyond-top-level
import numbers # to check if n, k are numbers

from sionna.fec.linear import AllZeroEncoder as AllZeroEncoder_new

class LDPC5GEncoder(Layer):
    # pylint: disable=line-too-long
    """LDPC5GEncoder(k, n, num_bits_per_symbol=None, dtype=tf.float32, **kwargs)

    5G NR LDPC Encoder following the 3GPP NR Initiative [3GPPTS38212_LDPC]_
    including rate-matching.

    The class inherits from the Keras layer class and can be used as layer in a
    Keras model.

    Parameters
    ----------
        k: int
            Defining the number of information bit per codeword.

        n: int
            Defining the desired codeword length.

        num_bits_per_symbol: int or None
            Defining the number of bits per QAM symbol. If this parameter is
            explicitly provided, the codeword will be interleaved after
            rate-matching as specified in Sec. 5.4.2.2 in [3GPPTS38212_LDPC]_.

        dtype: tf.DType
            Defaults to `tf.float32`. Defines the output datatype of the layer
            (internal precision remains `tf.uint8`).

    Input
    -----
        inputs: [...,k], tf.float32
            2+D tensor containing the information bits to be
            encoded.

    Output
    ------
        : [...,n], tf.float32
            2+D tensor of same shape as inputs besides last dimension has
            changed to `n` containing the encoded codeword bits.

    Attributes
    ----------
        k: int
            Defining the number of information bit per codeword.

        n: int
            Defining the desired codeword length.

        coderate: float
            Defining the coderate r= ``k`` / ``n``.

        n_ldpc: int
            An integer defining the total codeword length (before
            punturing) of the lifted parity-check matrix.

        k_ldpc: int
            An integer defining the total information bit length
            (before zero removal) of the lifted parity-check matrix. Gap to
            ``k`` must be filled with so-called filler bits.

        num_bits_per_symbol: int or None.
            Defining the number of bits per QAM symbol. If this parameter is
            explicitly provided, the codeword will be interleaved after
            rate-matching as specified in Sec. 5.4.2.2 in [3GPPTS38212_LDPC]_.

        out_int: [n], ndarray of int
            Defining the rate-matching output interleaver sequence.

        out_int_inv: [n], ndarray of int
            Defining the inverse rate-matching output interleaver sequence.

        _check_input: bool
            A boolean that indicates whether the input vector
            during call of the layer should be checked for consistency (i.e.,
            binary).

        _bg: str
            Denoting the selected basegraph (either `bg1` or `bg2`).

        _z: int
            Denoting the lifting factor.

        _i_ls: int
            Defining which version of the basegraph to load.
            Can take values between 0 and 7.

        _k_b: int
            Defining the number of `information bit columns` in the
            basegraph. Determined by the code design procedure in
            [3GPPTS38212_LDPC]_.

        _bm: ndarray
            An ndarray defining the basegraph.

        _pcm: sp.sparse.csr_matrix
            A sparse matrix of shape `[k_ldpc-n_ldpc, n_ldpc]`
            containing the sparse parity-check matrix.

    Raises
    ------
        AssertionError
            If ``k`` is not `int`.

        AssertionError
            If ``n`` is not `int`.

        ValueError
            If ``code_length`` is not supported.

        ValueError
            If `dtype` is not supported.

        ValueError
            If ``inputs`` contains other values than `0` or `1`.

        InvalidArgumentError
            When rank(``inputs``)<2.

        InvalidArgumentError
            When shape of last dim is not ``k``.

    Note
    ----
        As specified in [3GPPTS38212_LDPC]_, the encoder also performs
        puncturing and shortening. Thus, the corresponding decoder needs to
        `invert` these operations, i.e., must be compatible with the 5G
        encoding scheme.
    """

    def __init__(self,
                 k,
                 n,
                 num_bits_per_symbol=None,
                 dtype=tf.float32,
                 **kwargs):

        super().__init__(dtype=dtype, **kwargs)

        assert isinstance(k, numbers.Number), "k must be a number."
        assert isinstance(n, numbers.Number), "n must be a number."
        k = int(k) # k or n can be float (e.g. as result of n=k*r)
        n = int(n) # k or n can be float (e.g. as result of n=k*r)

        if dtype is not tf.float32:
            print("Note: decoder uses tf.float32 for internal calculations.")

        if dtype not in (tf.float16, tf.float32, tf.float64, tf.int8,
            tf.int32, tf.int64, tf.uint8, tf.uint16, tf.uint32):
            raise ValueError("Unsupported dtype.")
        self._dtype = dtype

        if k>8448:
            raise ValueError("Unsupported code length (k too large).")
        if k<12:
            raise ValueError("Unsupported code length (k too small).")

        if n>(316*384):
            raise ValueError("Unsupported code length (n too large).")
        if n<0:
            raise ValueError("Unsupported code length (n negative).")

        # init encoder parameters
        self._k = k # number of input bits (= input shape)
        self._n = n # the desired length (= output shape)
        self._coderate = k / n
        self._check_input = True # check input for consistency (i.e., binary)

        # allow actual code rates slightly larger than 948/1024
        # to account for the quantization procedure in 38.214 5.1.3.1
        if self._coderate>(948/1024): # as specified in 38.212 5.4.2.1
            print(f"Warning: effective coderate r>948/1024 for n={n}, k={k}.")
        if self._coderate>(0.95): # as specified in 38.212 5.4.2.1
            raise ValueError(f"Unsupported coderate (r>0.95) for n={n}, k={k}.")
        # if self._coderate<(1/5):
        #     # outer rep. coding currently not supported
        #     raise ValueError("Unsupported coderate (r<1/5).")

        # construct the basegraph according to 38.212
        self._bg = self._sel_basegraph(self._k, self._coderate)
        self._z, self._i_ls, self._k_b = self._sel_lifting(self._k, self._bg)
        self._bm = self._load_basegraph(self._i_ls, self._bg)

        # total number of codeword bits
        self._n_ldpc = self._bm.shape[1] * self._z
        # if K_real < K _target puncturing must be applied earlier
        self._k_ldpc = self._k_b * self._z

        # construct explicit graph via lifting
        pcm = self._lift_basegraph(self._bm, self._z)

        pcm_a, pcm_b_inv, pcm_c1, pcm_c2 = self._gen_submat(self._bm,
                                                            self._k_b,
                                                            self._z,
                                                            self._bg)

        # init sub-matrices for fast encoding ("RU"-method)
        # note: dtype is tf.float32;
        self._pcm = pcm # store the sparse parity-check matrix (for decoding)

        # store indices for fast gathering (instead of explicit matmul)
        self._pcm_a_ind = self._mat_to_ind(pcm_a)
        self._pcm_b_inv_ind = self._mat_to_ind(pcm_b_inv)
        self._pcm_c1_ind = self._mat_to_ind(pcm_c1)
        self._pcm_c2_ind = self._mat_to_ind(pcm_c2)

        self._num_bits_per_symbol = num_bits_per_symbol
        if num_bits_per_symbol is not None:
            self._out_int, self._out_int_inv  = self.generate_out_int(self._n,
                                                    self._num_bits_per_symbol)

    #########################################
    # Public methods and properties
    #########################################

    @property
    def k(self):
        """Number of input information bits."""
        return self._k

    @property
    def n(self):
        "Number of output codeword bits."
        return self._n

    @property
    def coderate(self):
        """Coderate of the LDPC code after rate-matching."""
        return self._coderate

    @property
    def k_ldpc(self):
        """Number of LDPC information bits after rate-matching."""
        return self._k_ldpc

    @property
    def n_ldpc(self):
        """Number of LDPC codeword bits before rate-matching."""
        return self._n_ldpc

    @property
    def pcm(self):
        """Parity-check matrix for given code parameters."""
        return self._pcm

    @property
    def z(self):
        """Lifting factor of the basegraph."""
        return self._z

    @property
    def num_bits_per_symbol(self):
        """Modulation order used for the rate-matching output interleaver."""
        return self._num_bits_per_symbol

    @property
    def out_int(self):
        """Output interleaver sequence as defined in 5.4.2.2."""
        return self._out_int
    @property
    def out_int_inv(self):
        """Inverse output interleaver sequence as defined in 5.4.2.2."""
        return self._out_int_inv

    #########################
    # Utility methods
    #########################

    def generate_out_int(self, n, num_bits_per_symbol):
        """"Generates LDPC output interleaver sequence as defined in
        Sec 5.4.2.2 in [3GPPTS38212_LDPC]_.

        Parameters
        ----------
        n: int
            Desired output sequence length.

        num_bits_per_symbol: int
            Number of symbols per QAM symbol, i.e., the modulation order.

        Output
        ------
        (perm_seq, perm_seq_inv):
            Tuple:

        perm_seq: ndarray of length n
            Containing the permuted indices.

        perm_seq_inv: ndarray of length n
            Containing the inverse permuted indices.

        Note
        ----
        The interleaver pattern depends on the modulation order and helps to
        reduce dependencies in bit-interleaved coded modulation (BICM) schemes.
        """
        # allow float inputs, but verify that they represent integer
        assert(n%1==0), "n must be int."
        assert(num_bits_per_symbol%1==0), "num_bits_per_symbol must be int."
        n = int(n)
        assert(n>0), "n must be a positive integer."
        assert(num_bits_per_symbol>0), \
                    "num_bits_per_symbol must be a positive integer."
        num_bits_per_symbol = int(num_bits_per_symbol)

        assert(n%num_bits_per_symbol==0),\
            "n must be a multiple of num_bits_per_symbol."

        # pattern as defined in Sec 5.4.2.2
        perm_seq = np.zeros(n, dtype=int)
        for j in range(int(n/num_bits_per_symbol)):
            for i in range(num_bits_per_symbol):
                perm_seq[i + j*num_bits_per_symbol] \
                    = int(i * int(n/num_bits_per_symbol) + j)

        perm_seq_inv = np.argsort(perm_seq)

        return perm_seq, perm_seq_inv

    def _sel_basegraph(self, k, r):
        """Select basegraph according to [3GPPTS38212_LDPC]_."""

        if k <= 292:
            bg = "bg2"
        elif k <= 3824 and r <= 0.67:
            bg = "bg2"
        elif r <= 0.25:
            bg = "bg2"
        else:
            bg = "bg1"

        # add for consistency
        if bg=="bg1" and k>8448:
            raise ValueError("K is not supported by BG1 (too large).")

        if bg=="bg2" and k>3840:
            raise ValueError(
                f"K is not supported by BG2 (too large) k ={k}.")

        # if bg=="bg1" and r<1/3:
        #     raise ValueError("Only coderate>1/3 supported for BG1. \
        #     Remark: Repetition coding is currently not supported.")

        # if bg=="bg2" and r<1/5:
        #     raise ValueError("Only coderate>1/5 supported for BG2. \
        #     Remark: Repetition coding is currently not supported.")

        return bg

    def _load_basegraph(self, i_ls, bg):
        """Helper to load basegraph from csv files.

        ``i_ls`` is sub_index of the basegraph and fixed during lifting
        selection.
        """

        if i_ls > 7:
            raise ValueError("i_ls too large.")

        if i_ls < 0:
            raise ValueError("i_ls cannot be negative.")

        # csv files are taken from 38.212 and dimension is explicitly given
        if bg=="bg1":
            bm = np.zeros([46, 68]) - 1 # init matrix with -1 (None positions)
        elif bg=="bg2":
            bm = np.zeros([42, 52]) - 1 # init matrix with -1 (None positions)
        else:
            raise ValueError("Basegraph not supported.")

        # and load the basegraph from csv format in folder "codes"
        source = files(codes).joinpath(f"5G_{bg}.csv")
        with as_file(source) as codes.csv:
            bg_csv = np.genfromtxt(codes.csv, delimiter=";")

        # reconstruct BG for given i_ls
        r_ind = 0
        for r in np.arange(2, bg_csv.shape[0]):
            # check for next row index
            if not np.isnan(bg_csv[r, 0]):
                r_ind = int(bg_csv[r, 0])
            c_ind = int(bg_csv[r, 1]) # second column in csv is column index
            value = bg_csv[r, i_ls + 2] # i_ls entries start at offset 2
            bm[r_ind, c_ind] = value

        return bm

    def _lift_basegraph(self, bm, z):
        """Lift basegraph with lifting factor ``z`` and shifted identities as
        defined by the entries of ``bm``."""

        num_nonzero = np.sum(bm>=0) # num of non-neg elements in bm

        # init all non-zero row/column indices
        r_idx = np.zeros(z*num_nonzero)
        c_idx = np.zeros(z*num_nonzero)
        data = np.ones(z*num_nonzero)

        # row/column indices of identity matrix for lifting
        im = np.arange(z)

        idx = 0
        for r in range(bm.shape[0]):
            for c in range(bm.shape[1]):
                if bm[r,c]==-1: # -1 is used as all-zero matrix placeholder
                    pass #do nothing (sparse)
                else:
                    # roll matrix by bm[r,c]
                    c_roll = np.mod(im+bm[r,c], z)
                    # append rolled identity matrix to pcm
                    r_idx[idx*z:(idx+1)*z] = r*z + im
                    c_idx[idx*z:(idx+1)*z] = c*z + c_roll
                    idx += 1

        # generate lifted sparse matrix from indices
        pcm = sp.sparse.csr_matrix((data,(r_idx, c_idx)),
                                   shape=(z*bm.shape[0], z*bm.shape[1]))
        return pcm

    def _sel_lifting(self, k, bg):
        """Select lifting as defined in Sec. 5.2.2 in [3GPPTS38212_LDPC]_.

        We assume B < K_cb, thus B'= B and C = 1, i.e., no
        additional CRC is appended. Thus, K' = B'/C = B and B is our K.

        Z is the lifting factor.
        i_ls is the set index ranging from 0...7 (specifying the exact bg
        selection).
        k_b is the number of information bit columns in the basegraph.
        """
        # lifting set according to 38.212 Tab 5.3.2-1
        s_val = [[2, 4, 8, 16, 32, 64, 128, 256],
                [3, 6, 12, 24, 48, 96, 192, 384],
                [5, 10, 20, 40, 80, 160, 320],
                [7, 14, 28, 56, 112, 224],
                [9, 18, 36, 72, 144, 288],
                [11, 22, 44, 88, 176, 352],
                [13, 26, 52, 104, 208],
                [15, 30, 60, 120, 240]]

        if bg == "bg1":
            k_b = 22
        else:
            if k > 640:
                k_b = 10
            elif k > 560:
                k_b = 9
            elif k > 192:
                k_b = 8
            else:
                k_b = 6

        # find the min of Z from Tab. 5.3.2-1 s.t. k_b*Z>=K'
        min_val = 100000
        z = 0
        i_ls = 0
        i = -1
        for s in s_val:
            i += 1
            for s1 in s:
                x = k_b *s1
                if  x >= k:
                    # valid solution
                    if x < min_val:
                        min_val = x
                        z = s1
                        i_ls = i

        # and set K=22*Z for bg1 and K=10Z for bg2
        if bg == "bg1":
            k_b = 22
        else:
            k_b = 10
        # print(z, i_ls, k_b)
        return z, i_ls, k_b

    def _gen_submat(self, bm, k_b, z, bg):
        """Split the basegraph into multiple sub-matrices such that efficient
        encoding is possible.
        """
        g = 4 # code property (always fixed for 5G)
        mb = bm.shape[0] # number of CN rows in basegraph (BG property)

        bm_a = bm[0:g, 0:k_b]
        bm_b = bm[0:g, k_b:(k_b+g)]
        bm_c1 = bm[g:mb, 0:k_b]
        bm_c2 = bm[g:mb, k_b:(k_b+g)]
        # print(bm_a, bm_b, bm_c1, bm_c2)
        # print(np.linalg.norm(bm_a, 'fro'),
        #       np.linalg.norm(bm_b, 'fro'),
        #       np.linalg.norm(bm_c1, 'fro'),
        #       np.linalg.norm(bm_c2, 'fro'))
        # H could be sliced immediately (but easier to implement if based on B)
        hm_a = self._lift_basegraph(bm_a, z)

        # not required for encoding, but helpful for debugging
        #hm_b = self._lift_basegraph(bm_b, z)

        hm_c1 = self._lift_basegraph(bm_c1, z)
        hm_c2 = self._lift_basegraph(bm_c2, z)

        hm_b_inv = self._find_hm_b_inv(bm_b, z, bg)
        # print(hm_a, hm_b_inv, hm_c1, hm_c2)
        # print(np.linalg.norm(hm_a.toarray(), 'fro'),
        #       np.linalg.norm(hm_b_inv.toarray(), 'fro'),
        #       np.linalg.norm(hm_c1.toarray(), 'fro'),
        #       np.linalg.norm(hm_c2.toarray(), 'fro'))
        return hm_a, hm_b_inv, hm_c1, hm_c2

    def _find_hm_b_inv(self, bm_b, z, bg):
        """ For encoding we need to find the inverse of `hm_b` such that
        `hm_b^-1 * hm_b = I`.

        Could be done sparse
        For BG1 the structure of hm_b is given as (for all values of i_ls)
        hm_b =
        [P_A I 0 0
         P_B I I 0
         0 0 I I
         P_A 0 0 I]
        where P_B and P_A are Shifted identities.

        The inverse can be found by solving a linear system of equations
        hm_b_inv =
        [P_B^-1, P_B^-1, P_B^-1, P_B^-1,
         I + P_A*P_B^-1, P_A*P_B^-1, P_A*P_B^-1, P_A*P_B^-1,
         P_A*P_B^-1, P_A*P_B^-1, I+P_A*P_B^-1, I+P_A*P_B^-1,
         P_A*P_B^-1, P_A*P_B^-1, P_A*P_B^-1, I+P_A*P_B^-1].


        For bg2 the structure of hm_b is given as (for all values of i_ls)
        hm_b =
        [P_A I 0 0
         0 I I 0
         P_B 0 I I
         P_A 0 0 I]
        where P_B and P_A are Shifted identities

        The inverse can be found by solving a linear system of equations
        hm_b_inv =
        [P_B^-1, P_B^-1, P_B^-1, P_B^-1,
         I + P_A*P_B^-1, P_A*P_B^-1, P_A*P_B^-1, P_A*P_B^-1,
         I+P_A*P_B^-1, I+P_A*P_B^-1, P_A*P_B^-1, P_A*P_B^-1,
         P_A*P_B^-1, P_A*P_B^-1, P_A*P_B^-1, I+P_A*P_B^-1]

        Note: the inverse of B is simply a shifted identity matrix with
        negative shift direction.
        """

        # permutation indices
        pm_a= int(bm_b[0,0])
        if bg=="bg1":
            pm_b_inv = int(-bm_b[1, 0])
        else: # structure of B is slightly different for bg2
            pm_b_inv = int(-bm_b[2, 0])

        hm_b_inv = np.zeros([4*z, 4*z])

        im = np.eye(z)

        am = np.roll(im, pm_a, axis=1)
        b_inv = np.roll(im, pm_b_inv, axis=1)
        ab_inv = np.matmul(am, b_inv)

        # row 0
        hm_b_inv[0:z, 0:z] = b_inv
        hm_b_inv[0:z, z:2*z] = b_inv
        hm_b_inv[0:z, 2*z:3*z] = b_inv
        hm_b_inv[0:z, 3*z:4*z] = b_inv

        # row 1
        hm_b_inv[z:2*z, 0:z] = im + ab_inv
        hm_b_inv[z:2*z, z:2*z] = ab_inv
        hm_b_inv[z:2*z, 2*z:3*z] = ab_inv
        hm_b_inv[z:2*z, 3*z:4*z] = ab_inv

        # row 2
        if bg=="bg1":
            hm_b_inv[2*z:3*z, 0:z] = ab_inv
            hm_b_inv[2*z:3*z, z:2*z] = ab_inv
            hm_b_inv[2*z:3*z, 2*z:3*z] = im + ab_inv
            hm_b_inv[2*z:3*z, 3*z:4*z] = im + ab_inv
        else: # for bg2 the structure is slightly different
            hm_b_inv[2*z:3*z, 0:z] = im + ab_inv
            hm_b_inv[2*z:3*z, z:2*z] = im + ab_inv
            hm_b_inv[2*z:3*z, 2*z:3*z] = ab_inv
            hm_b_inv[2*z:3*z, 3*z:4*z] = ab_inv

        # row 3
        hm_b_inv[3*z:4*z, 0:z] = ab_inv
        hm_b_inv[3*z:4*z, z:2*z] = ab_inv
        hm_b_inv[3*z:4*z, 2*z:3*z] = ab_inv
        hm_b_inv[3*z:4*z, 3*z:4*z] = im + ab_inv

        # return results as sparse matrix
        return sp.sparse.csr_matrix(hm_b_inv)

    def _mat_to_ind(self, mat):
        """Helper to transform matrix into index representation for
        tf.gather. An index pointing to the `last_ind+1` is used for non-existing edges due to irregular degrees."""
        m = mat.shape[0]
        n = mat.shape[1]

        # transpose mat for sorted column format
        c_idx, r_idx, _ = sp.sparse.find(mat.transpose())

        # sort indices explicitly, as scipy.sparse.find changed from column to
        # row sorting in scipy>=1.11
        idx = np.argsort(r_idx)
        c_idx = c_idx[idx]
        r_idx = r_idx[idx]

        # find max number of no-zero entries
        n_max = np.max(mat.getnnz(axis=1))

        # init index array with n (pointer to last_ind+1, will be a default
        # value)
        gat_idx = np.zeros([m, n_max]) + n

        r_val = -1
        c_val = 0
        for idx in range(len(c_idx)):
            # check if same row or if a new row starts
            if r_idx[idx] != r_val:
                r_val = r_idx[idx]
                c_val = 0
            gat_idx[r_val, c_val] = c_idx[idx]
            c_val += 1

        gat_idx = tf.cast(tf.constant(gat_idx), tf.int32)
        return gat_idx

    def _matmul_gather(self, mat, vec):
        """Implements a fast sparse matmul via gather function."""

        # add 0 entry for gather-reduce_sum operation
        # (otherwise ragged Tensors are required)
        bs = tf.shape(vec)[0]
        vec = tf.concat([vec, tf.zeros([bs, 1], dtype=self.dtype)], 1)

        retval = tf.gather(vec, mat, batch_dims=0, axis=1)
        retval = tf.reduce_sum(retval, axis=-1)

        return retval

    def _encode_fast(self, s):
        """Main encoding function based on gathering function."""
        s_2 = s
        print('s_2[:, 2*self._z:]', poly_hash(s_2[:, 2*self._z:]), s_2[:, 2*self._z:].shape)

        s = tf.where(s == -1, tf.zeros_like(s), s)

        p_a = self._matmul_gather(self._pcm_a_ind, s)
        p_a = self._matmul_gather(self._pcm_b_inv_ind, p_a)
        # calc second part of parity bits p_b
        # second parities are given by C_1*s' + C_2*p_a' + p_b' = 0
        p_b_1 = self._matmul_gather(self._pcm_c1_ind, s)
        p_b_2 = self._matmul_gather(self._pcm_c2_ind, p_a)
        p_b = p_b_1 + p_b_2

        w = tf.concat([p_a, p_b], 1)
        w = tf.cast(w, tf.uint8)
        w = tf.bitwise.bitwise_and(w, tf.constant(1, tf.uint8))
        w = tf.cast(w, self.dtype)
        print('w', poly_hash(w), w.shape)

        c = tf.concat([s_2, w], 1)
        print('c[:, 2*self._z:]', poly_hash(c[:, 2*self._z:]), c[:, 2*self._z:].shape)
        return c
        c = tf.concat([s_2, p_a, p_b], 1)
        # faster implementation of mod-2 operation c = tf.math.mod(c, 2)
        c_uint8 = tf.cast(c, tf.uint8)
        c_bin = tf.bitwise.bitwise_and(c_uint8, tf.constant(1, tf.uint8))
        c = tf.cast(c_bin, self.dtype)
        c = tf.expand_dims(c, axis=-1) # returns nx1 vector
        return c

    #########################
    # Keras layer functions
    #########################

    def build(self, input_shape):
        """"Build layer."""
        # check if k and input shape match
        assert (input_shape[-1]==self._k), "Last dimension must be of length k."
        assert (len(input_shape)>=2), "Rank of input must be at least 2."

    def call(self, inputs):
        """5G LDPC encoding function including rate-matching.

        This function returns the encoded codewords as specified by the 3GPP NR Initiative [3GPPTS38212_LDPC]_ including puncturing and shortening.

        Args:
            inputs (tf.float32): Tensor of shape `[...,k]` containing the
                information bits to be encoded.

        Returns:
            `tf.float32`: Tensor of shape `[...,n]`.

        Raises:
            ValueError: If ``inputs`` contains other values than `0` or `1`.

            InvalidArgumentError: When rank(``inputs``)<2.

            InvalidArgumentError: When shape of last dim is not ``k``.
        """

        tf.debugging.assert_type(inputs, self.dtype, "Invalid input dtype.")

        # Reshape inputs to [...,k]
        input_shape = inputs.get_shape().as_list()
        new_shape = [-1, input_shape[-1]]
        u = tf.reshape(inputs, new_shape)
        print('u', poly_hash(u), u.shape)
        # assert if u is non binary
        if self._check_input:
            tf.debugging.assert_equal(
                tf.reduce_min(
                    tf.cast(
                        tf.logical_or(
                            tf.equal(u, tf.constant(0, self.dtype)),
                            tf.equal(u, tf.constant(1, self.dtype)),
                            ),
                        self.dtype)),
                tf.constant(1, self.dtype),
                "Input must be binary.")
            # input datatype consistency should be only evaluated once
            self._check_input = False

        batch_size = tf.shape(u)[0]

        # add "filler" bits to last positions to match info bit length k_ldpc
        u_fill = tf.concat([u,
                    -tf.ones([batch_size, self._k_ldpc-self._k], self.dtype)],
                            1)
        print('u_fill', poly_hash(u_fill), u_fill.shape)
        # use optimized encoding based on tf.gather
        c = self._encode_fast(u_fill)
        c = tf.reshape(c, [batch_size, self._n_ldpc]) # remove last dim
        print('c', poly_hash(c), c.shape)
        print('c[:, 2*self._z:]', poly_hash(c[:, 2*self._z:]), c[:, 2*self._z:].shape)

      
        print('z', self._z, 'n', self.n, 'n_lpdc', self._n_ldpc, 'k', self._k, 'k_lpdc', self._k_ldpc)

        shorten_size = self._n+self._k_ldpc-self._k
        print('shorten_size', shorten_size, 'shorten_size-self._k_ldpc+2*self._z', shorten_size-self._k_ldpc+2*self._z, 'self._k-2*self._z', self._k-2*self._z)
        c_short = tf.slice(c, [0, 2*self._z], [batch_size,  shorten_size])
        print('c_short', poly_hash(c_short), c_short.shape)
        c_no_filler1 = tf.slice(c_short, [0, 0], [batch_size, self._k-2*self._z])
        c_no_filler2 = tf.slice(c_short,
                               [0, self._k_ldpc-2*self._z],
                               [batch_size, shorten_size - self._k_ldpc+2*self._z])
        
        c_no_filler = tf.concat([c_no_filler1, c_no_filler2], 1)

        print('c_no_filler', poly_hash(c_no_filler), c_no_filler.shape)

        # remove filler bits at pos (k, k_ldpc)
        c_no_filler1 = tf.slice(c, [0, 0], [batch_size, self._k])
        c_no_filler2 = tf.slice(c,
                               [0, self._k_ldpc],
                               [batch_size, self._n_ldpc-self._k_ldpc])
        
        print('c_no_filler1', poly_hash(c_no_filler1), c_no_filler1.shape)
        print('c_no_filler2', poly_hash(c_no_filler2), c_no_filler2.shape)
        c_no_filler = tf.concat([c_no_filler1, c_no_filler2], 1)
        print('c_no_filler', poly_hash(c_no_filler), c_no_filler.shape)
        # shorten the first 2*Z positions and end after n bits
        # (remaining parity bits can be used for IR-HARQ)
        
        c_short = tf.slice(c_no_filler, [0, 2*self._z], [batch_size, self.n])
        print('c_short', poly_hash(c_short), c_short.shape)
        # incremental redundancy could be generated by accessing the last bits
        # print('c_short',c_short.shape, poly_hash(c_short, [1]))
        # if num_bits_per_symbol is provided, apply output interleaver as
        # specified in Sec. 5.4.2.2 in 38.212
        if self._num_bits_per_symbol is not None:
            c_short = tf.gather(c_short, self._out_int, axis=-1)
        print('c_short', poly_hash(c_short), c_short.shape)
        # Reshape c_short so that it matches the original input dimensions
        output_shape = input_shape[0:-1] + [self.n]
        output_shape[0] = -1
        c_reshaped = tf.reshape(c_short, output_shape)

        return tf.cast(c_reshaped, self._dtype)


###########################################################
# Deprecated aliases that will not be included in the next
# major release
###########################################################

def AllZeroEncoder(k,
                   n,
                   dtype=tf.float32,
                   **kwargs):
    print("Warning: The alias fec.ldpc.AllZeroEncoder will not be included in "\
          "Sionna 1.0. Please use sionna.fec.linear.AllZeroEncoder instead.")
    return AllZeroEncoder_new(k=k,
                              n=n,
                              dtype=dtype,
                              **kwargs)


In [34]:
from sionna.fec.crc import CRCEncoder
from sionna.fec.scrambling import TB5GScrambler
from sionna.nr.utils import calculate_tb_size

class TBEncoder(Layer):
    # pylint: disable=line-too-long
    r"""TBEncoder(target_tb_size,num_coded_bits,target_coderate,num_bits_per_symbol,num_layers=1,n_rnti=1,n_id=1,channel_type="PUSCH",codeword_index=0,use_scrambler=True,verbose=False,output_dtype=tf.float32,, **kwargs)
    5G NR transport block (TB) encoder as defined in TS 38.214
    [3GPP38214]_ and TS 38.211 [3GPP38211]_

    The transport block (TB) encoder takes as input a `transport block` of
    information bits and generates a sequence of codewords for transmission.
    For this, the information bit sequence is segmented into multiple codewords,
    protected by additional CRC checks and FEC encoded. Further, interleaving
    and scrambling is applied before a codeword concatenation generates the
    final bit sequence. Fig. 1 provides an overview of the TB encoding
    procedure and we refer the interested reader to [3GPP38214]_ and
    [3GPP38211]_ for further details.

    ..  figure:: ../figures/tb_encoding.png

        Fig. 1: Overview TB encoding (CB CRC does not always apply).

    If ``n_rnti`` and ``n_id`` are given as list, the TBEncoder encodes
    `num_tx = len(` ``n_rnti`` `)` parallel input streams with different
    scrambling sequences per user.

    The class inherits from the Keras layer class and can be used as layer in a
    Keras model.

    Parameters
    ----------
        target_tb_size: int
            Target transport block size, i.e., how many information bits are
            encoded into the TB. Note that the effective TB size can be
            slightly different due to quantization. If required, zero padding
            is internally applied.

        num_coded_bits: int
            Number of coded bits after TB encoding.

        target_coderate : float
            Target coderate.

        num_bits_per_symbol: int
            Modulation order, i.e., number of bits per QAM symbol.

        num_layers: int, 1 (default) | [1,...,8]
            Number of transmission layers.

        n_rnti: int or list of ints, 1 (default) | [0,...,65335]
            RNTI identifier provided by higher layer. Defaults to 1 and must be
            in range `[0, 65335]`. Defines a part of the random seed of the
            scrambler. If provided as list, every list entry defines the RNTI
            of an independent input stream.

        n_id: int or list of ints, 1 (default) | [0,...,1023]
            Data scrambling ID :math:`n_\text{ID}` related to cell id and
            provided by higher layer.
            Defaults to 1 and must be in range `[0, 1023]`. If provided as
            list, every list entry defines the scrambling id of an independent
            input stream.

        channel_type: str, "PUSCH" (default) | "PDSCH"
            Can be either "PUSCH" or "PDSCH".

        codeword_index: int, 0 (default) | 1
            Scrambler can be configured for two codeword transmission.
            ``codeword_index`` can be either 0 or 1. Must be 0 for
            ``channel_type`` = "PUSCH".

        use_scrambler: bool, True (default)
            If False, no data scrambling is applied (non standard-compliant).

        verbose: bool, False (default)
            If `True`, additional parameters are printed during initialization.

        dtype: tf.float32 (default)
            Defines the datatype for internal calculations and the output dtype.

    Input
    -----
        inputs: [...,target_tb_size] or [...,num_tx,target_tb_size], tf.float
            2+D tensor containing the information bits to be encoded. If
            ``n_rnti`` and ``n_id`` are a list of size `num_tx`, the input must
            be of shape `[...,num_tx,target_tb_size]`.

    Output
    ------
        : [...,num_coded_bits], tf.float
            2+D tensor containing the sequence of the encoded codeword bits of
            the transport block.

    Note
    ----
    The parameters ``tb_size`` and ``num_coded_bits`` can be derived by the
    :meth:`~sionna.nr.calculate_tb_size` function or
    by accessing the corresponding :class:`~sionna.nr.PUSCHConfig` attributes.
    """

    def __init__(self,
                 target_tb_size,
                 num_coded_bits,
                 target_coderate,
                 num_bits_per_symbol,
                 num_layers=1,
                 n_rnti=1,
                 n_id=1,
                 channel_type="PUSCH",
                 codeword_index=0,
                 use_scrambler=True,
                 verbose=False,
                 output_dtype=tf.float32,
                 **kwargs):

        super().__init__(dtype=output_dtype, **kwargs)

        assert isinstance(use_scrambler, bool), \
                                "use_scrambler must be bool."
        self._use_scrambler = use_scrambler
        assert isinstance(verbose, bool), \
                                "verbose must be bool."
        self._verbose = verbose

        # check input for consistency
        assert channel_type in ("PDSCH", "PUSCH"), \
                                "Unsupported channel_type."
        self._channel_type = channel_type

        assert(target_tb_size%1==0), "target_tb_size must be int."
        self._target_tb_size = int(target_tb_size)

        assert(num_coded_bits%1==0), "num_coded_bits must be int."
        self._num_coded_bits = int(num_coded_bits)

        assert(0.<target_coderate <= 948/1024), \
                    "target_coderate must be in range(0,0.925)."
        self._target_coderate = target_coderate

        assert(num_bits_per_symbol%1==0), "num_bits_per_symbol must be int."
        self._num_bits_per_symbol = int(num_bits_per_symbol)

        assert(num_layers%1==0), "num_layers must be int."
        self._num_layers = int(num_layers)

        if channel_type=="PDSCH":
            assert(codeword_index in (0,1)), "codeword_index must be 0 or 1."
        else:
            assert codeword_index==0, 'codeword_index must be 0 for "PUSCH".'
        self._codeword_index = int(codeword_index)

        if isinstance(n_rnti, (list, tuple)):
            assert isinstance(n_id, (list, tuple)), "n_id must be also a list."
            assert (len(n_rnti)==len(n_id)), \
                                "n_id and n_rnti must be of same length."
            self._n_rnti = n_rnti
            self._n_id = n_id
        else:
            self._n_rnti = [n_rnti]
            self._n_id = [n_id]

        for idx, n in enumerate(self._n_rnti):
            assert(n%1==0), "n_rnti must be int."
            self._n_rnti[idx] = int(n)
        for idx, n in enumerate(self._n_id):
            assert(n%1==0), "n_id must be int."
            self._n_id[idx] = int(n)

        self._num_tx = len(self._n_id)

        tbconfig = calculate_tb_size(target_tb_size=self._target_tb_size,
                                     num_coded_bits=self._num_coded_bits,
                                     target_coderate=self._target_coderate,
                                     modulation_order=self._num_bits_per_symbol,
                                     num_layers=self._num_layers,
                                     verbose=verbose)
        self._tb_size = tbconfig[0]
        self._cb_size = tbconfig[1]
        self._num_cbs = tbconfig[2]
        self._cw_lengths = tbconfig[3]
        self._tb_crc_length = tbconfig[4]
        self._cb_crc_length = tbconfig[5]

        assert self._tb_size <= self._tb_crc_length + np.sum(self._cw_lengths),\
            "Invalid TB parameters."

        # due to quantization, the tb_size can slightly differ from the
        # target tb_size.
        self._k_padding = self._tb_size - self._target_tb_size
        if self._tb_size != self._target_tb_size:
            print(f"Note: actual tb_size={self._tb_size} is slightly "\
                  f"different than requested " \
                  f"target_tb_size={self._target_tb_size} due to "\
                  f"quantization. Internal zero padding will be applied.")

        # calculate effective coderate (incl. CRC)
        self._coderate = self._tb_size / self._num_coded_bits

        # Remark: CRC16 is only used for k<3824 (otherwise CRC24)
        if self._tb_crc_length==16:
            self._tb_crc_encoder = CRCEncoder("CRC16")
        else:
            # CRC24A as defined in 7.2.1
            self._tb_crc_encoder = CRCEncoder("CRC24A")

        # CB CRC only if more than one CB is used
        if self._cb_crc_length==24:
            self._cb_crc_encoder = CRCEncoder("CRC24B")
        else:
            self._cb_crc_encoder = None

        # scrambler can be deactivated (non-standard compliant)
        if self._use_scrambler:
            self._scrambler = TB5GScrambler(n_rnti=self._n_rnti,
                                            n_id=self._n_id,
                                            binary=True,
                                            channel_type=channel_type,
                                            codeword_index=codeword_index,
                                            dtype=tf.float32,)
        else: # required for TBDecoder
            self._scrambler = None

        # ---- Init LDPC encoder ----
        # remark: as the codeword length can be (slightly) different
        # within a TB due to rounding, we initialize the encoder
        # with the max length and apply puncturing if required.
        # Thus, also the output interleaver cannot be applied in the encoder.
        # The procedure is defined in in 5.4.2.1 38.212
        self._encoder = LDPC5GEncoder(self._cb_size,
                                      np.max(self._cw_lengths),
                                      num_bits_per_symbol=1) #deact. interleaver

        # ---- Init interleaver ----
        # remark: explicit interleaver is required as the rate matching from
        # Sec. 5.4.2.1 38.212 could otherwise not be applied here
        perm_seq_short, _ = self._encoder.generate_out_int(
                                            np.min(self._cw_lengths),
                                            num_bits_per_symbol)
        perm_seq_long, _ = self._encoder.generate_out_int(
                                            np.max(self._cw_lengths),
                                            num_bits_per_symbol)

        perm_seq = []
        perm_seq_punc = []

        # define one big interleaver that moves the punctured positions to the
        # end of the TB
        payload_bit_pos = 0 # points to current pos of payload bits

        for l in self._cw_lengths:
            if np.min(self._cw_lengths)==l:
                perm_seq = np.concatenate([perm_seq,
                                           perm_seq_short + payload_bit_pos])
                # move unused bit positions to the end of TB
                # this simplifies the inverse permutation
                r = np.arange(payload_bit_pos+np.min(self._cw_lengths),
                              payload_bit_pos+np.max(self._cw_lengths))
                perm_seq_punc = np.concatenate([perm_seq_punc, r])

                # update pointer
                payload_bit_pos += np.max(self._cw_lengths)
            elif np.max(self._cw_lengths)==l:
                perm_seq = np.concatenate([perm_seq,
                                           perm_seq_long + payload_bit_pos])
                # update pointer
                payload_bit_pos += l
            else:
                raise ValueError("Invalid cw_lengths.")

        # add punctured positions to end of sequence (only relevant for
        # deinterleaving)
        perm_seq = np.concatenate([perm_seq, perm_seq_punc])

        self._output_perm = tf.constant(perm_seq, tf.int32)
        self._output_perm_inv = tf.argsort(perm_seq, axis=-1)

    #########################################
    # Public methods and properties
    #########################################


    @property
    def tb_size(self):
        r"""Effective number of information bits per TB.
        Note that (if required) internal zero padding can be applied to match
        the request exact ``target_tb_size``."""
        return self._tb_size

    @property
    def k(self):
        r"""Number of input information bits. Equals `tb_size` except for zero
        padding of the last positions if the ``target_tb_size`` is quantized."""
        return self._target_tb_size

    @property
    def k_padding(self):
        """Number of zero padded bits at the end of the TB."""
        return self._k_padding

    @property
    def n(self):
        "Total number of output bits."
        return self._num_coded_bits

    @property
    def num_cbs(self):
        "Number code blocks."
        return self._num_cbs

    @property
    def coderate(self):
        """Effective coderate of the TB after rate-matching including overhead
        for the CRC."""
        return self._coderate

    @property
    def ldpc_encoder(self):
        """LDPC encoder used for TB encoding."""
        return self._encoder

    @property
    def scrambler(self):
        """Scrambler used for TB scrambling. `None` if no scrambler is used."""
        return self._scrambler

    @property
    def tb_crc_encoder(self):
        """TB CRC encoder"""
        return self._tb_crc_encoder

    @property
    def cb_crc_encoder(self):
        """CB CRC encoder. `None` if no CB CRC is applied."""
        return self._cb_crc_encoder

    @property
    def num_tx(self):
        """Number of independent streams"""
        return self._num_tx

    @property
    def cw_lengths(self):
        r"""Each list element defines the codeword length of each of the
        codewords after LDPC encoding and rate-matching. The total number of
        coded bits is :math:`\sum` `cw_lengths`."""
        return self._cw_lengths

    @property
    def output_perm_inv(self):
        r"""Inverse interleaver pattern for output bit interleaver."""
        return self._output_perm_inv

    #########################
    # Keras layer functions
    #########################

    def build(self, input_shapes):
        """Test input shapes for consistency."""

        assert input_shapes[-1]==self.k, \
            f"Invalid input shape. Expected TB length is {self.k}."

    def call(self, inputs):
        """Apply transport block encoding procedure."""

        # store shapes
        input_shape = inputs.shape.as_list()
        u = tf.cast(inputs, tf.float32)
        print('u', poly_hash(u), u.shape)
        # apply zero padding if tb_size is slightly different to target_tb_size
        if self._k_padding>0:
            s = tf.shape(u)
            s = tf.concat((s[:-1], [self._k_padding]), axis=0)
            u = tf.concat((u, tf.zeros(s, u.dtype)), axis=-1)

        # apply TB CRC
        u_crc = self._tb_crc_encoder(u)
        print('u_crc', poly_hash(u_crc), u_crc.shape)
        # CB segmentation
        u_cb = tf.reshape(u_crc,
                          (-1, self._num_tx, self._num_cbs,
                          self._cb_size-self._cb_crc_length))
        print('u_cb', poly_hash(u_cb), u_cb.shape)
        # if relevant apply CB CRC
        if self._cb_crc_length==24:
            u_cb_crc = self._cb_crc_encoder(u_cb)
        else:
            u_cb_crc = u_cb # no CRC applied if only one CB exists
        print('u_cb_crc', poly_hash(u_cb_crc), u_cb_crc.shape)
        c_cb = self._encoder(u_cb_crc)
        print('c_cb', poly_hash(c_cb), u_cb.shape)
        # CB concatenation
        c = tf.reshape(c_cb,
                       (-1, self._num_tx,
                       self._num_cbs*np.max(self._cw_lengths)))

        # apply interleaver (done after CB concatenation)
        print('c', poly_hash(c), c.shape)
        print('output_perm', poly_hash(self._output_perm), self._output_perm.shape)
        c = tf.gather(c, self._output_perm, axis=-1)
        print('c', poly_hash(c), c.shape, np.sum(self._cw_lengths), self._n_rnti, self._n_id)
        # puncture last bits
        c = c[:, :, :np.sum(self._cw_lengths)]
        print('c', poly_hash(c), c.shape, np.sum(self._cw_lengths), self._n_rnti, self._n_id)
        # scrambler
        if self._use_scrambler:
            c_scr = self._scrambler(c)
        else: # disable scrambler (non-standard compliant)
            c_scr = c
        print('c_scr', poly_hash(c_scr[:,:,:]), c_scr.shape)
        # cast to output dtype
        c_scr = tf.cast(c_scr, self.dtype)

        # ensure output shapes
        output_shape = input_shape
        output_shape[0] = -1
        output_shape[-1] = np.sum(self._cw_lengths)
        c_tb = tf.reshape(c_scr, output_shape)
        print('c_tb', poly_hash(c_tb), c_tb.shape)
        return c_tb


In [35]:
class MySimulator():
    def __init__(self, pusch_config: MyPUSCHConfig):

        tb_size = pusch_config.tb_size
        num_coded_bits = pusch_config.num_coded_bits
        target_coderate = pusch_config.tb.target_coderate
        num_bits_per_symbol = pusch_config.tb.num_bits_per_symbol

        num_layers = pusch_config.num_layers
        n_rnti = pusch_config.n_rnti
        n_id = pusch_config.tb.n_id

        self.Binary_Source = BinarySource(dtype=tf.float32)
        self.TB_Encoder = TBEncoder(target_tb_size=tb_size,
                            num_coded_bits=num_coded_bits,
                            target_coderate=target_coderate,
                            num_bits_per_symbol=num_bits_per_symbol,
                            num_layers=num_layers,
                            n_rnti=n_rnti,
                            n_id=n_id,
                            channel_type="PUSCH",
                            codeword_index=0,
                            use_scrambler=True,
                            verbose=False,
                            output_dtype=tf.float32)
        
        self.Constellation_Mapper = Mapper("qam", num_bits_per_symbol, dtype=tf.complex64)

        self.Layer_Mapper = LayerMapper(num_layers=num_layers, dtype=tf.complex64)
    
        self.Pilot_Pattern = PUSCHPilotPattern([pusch_config], dtype=tf.complex64)

        num_subcarriers = pusch_config.num_subcarriers
        subcarrier_spacing = pusch_config.carrier.subcarrier_spacing*1e3
        fft_size = num_subcarriers
        cp_length = min(num_subcarriers, 288)
        guard_subcarriers = (0,0)
        # Define the resource grid.
        resource_grid = ResourceGrid(
            num_ofdm_symbols=14,
            fft_size=fft_size,
            subcarrier_spacing=subcarrier_spacing,
            num_tx=NUM_TX,
            num_streams_per_tx=NUM_STREAMS_PER_TX,
            cyclic_prefix_length=cp_length,
            num_guard_carriers=guard_subcarriers,
            dc_null=False,
            pilot_pattern=self.Pilot_Pattern,
            dtype=tf.complex64
        )

        self.Resource_Grid_Mapper = ResourceGridMapper(resource_grid, dtype=tf.complex64)        
        
        self.AWGN = AWGN()

 
        self.Channel_Estimator = PUSCHLSChannelEstimator(
                        resource_grid,
                        pusch_config.dmrs.length,
                        pusch_config.dmrs.additional_position,
                        pusch_config.dmrs.num_cdm_groups_without_data,
                        interpolation_type='nn',
                        dtype=tf.complex64)

        rxtx_association = np.ones([NUM_RX, NUM_TX], bool)
        stream_management = StreamManagement(rxtx_association, pusch_config.num_layers)
        self.Mimo_Detector = LinearDetector("lmmse", "bit", "maxlog", resource_grid, stream_management,
                                    "qam", pusch_config.tb.num_bits_per_symbol, dtype=tf.complex64)
        
        self.Equalizer = LinearDetector("lmmse", "symbol", "maxlog", resource_grid, stream_management,
                                    "qam", pusch_config.tb.num_bits_per_symbol, dtype=tf.complex64)

        self.Layer_Demapper = LayerDemapper(self.Layer_Mapper, num_bits_per_symbol=num_bits_per_symbol)
        # self.TB_Decoder = TBDecoder(self.TB_Encoder, output_dtype=tf.float32)

        self.tb_size = tb_size
        self.resource_grid = resource_grid
        self.pusch_config = pusch_config
        
    def update_pilots(self, pilots):
        self.Resource_Grid_Mapper._resource_grid.pilot_pattern.pilots = pilots
        """Channel Estimationand Detection will reflect this update since they reference the same object."""

    def sim(self, batch_size, channel_model, no_scaling, gen_prng_seq=None, return_tx_iq=False, return_channel=False):
        if gen_prng_seq:
            b = tf.reshape(tf.constant(generate_prng_seq(batch_size * NUM_TX * self.tb_size, gen_prng_seq), dtype=tf.float32), [batch_size, NUM_TX, self.tb_size])
        else:
            b = self.Binary_Source([batch_size, NUM_TX, self.tb_size])

        c = self.TB_Encoder(b)
        x_map = self.Constellation_Mapper(c)
        x_layer = self.Layer_Mapper(x_map)
        x = self.Resource_Grid_Mapper(x_layer)

        y, h = channel_model(x)
        no = no_scaling * tf.math.reduce_variance(y, axis=[-1,-2,-3,-4])

        y = self.AWGN([y, no])

        if return_channel:
            if return_tx_iq:
                return b, c, y, x, h
            return b, c, y, h
        
        if return_tx_iq:
            return b, c, y, x
        return b, c, y
    
    def ref(self, batch_size, dtype=tf.complex64):
        return tf.repeat(tf.transpose(tf.constant(self.pusch_config.dmrs_grid, dtype=dtype), [0, 2, 1])[None, None], repeats=batch_size, axis=0)
    
    def rec(self, y, snr_ = 1e3):
        no_ = tf.math.reduce_variance(y, axis=[-1,-2,-3,-4])/(snr_ + 1)
        h_hat, err_var = self.Channel_Estimator([y, no_])
        x_hat = self.Equalizer([y, h_hat, err_var, no_])
        llr_det = self.Mimo_Detector([y, h_hat, err_var, no_])
        llr_layer = self.Layer_Demapper(llr_det)
        b_hat, tb_crc_status = self.TB_Decoder(llr_layer)

        return h_hat, x_hat, llr_det, b_hat, tb_crc_status

In [36]:
simulator = MySimulator(puschCfg)
channel = OFDMChannel(channel_model=channel_model, resource_grid=simulator.resource_grid,
                                    add_awgn=False, normalize_channel=True, return_channel=True)

In [37]:
b, c, y, x, h = simulator.sim(batch_size, channel, no, gen_prng_seq=20044,return_tx_iq=True, return_channel=True)
r = simulator.ref(batch_size)

u [826608888.0] (1, 1, 4744)
u_crc [28112077.0] (1, 1, 4768)
u_cb [566695538.0, 782310056.0] (1, 1, 2, 2384)
u_cb_crc [151042419.0, 911073634.0] (1, 1, 2, 2408)
u [151042419.0, 911073634.0] (2, 2408)
u_fill [334937524.0, 48362687.0] (2, 2560)
s_2[:, 2*self._z:] [539127621.0, 348169825.0] (2, 2048)
w [635687629.0, 75742688.0] (2, 10752)
c[:, 2*self._z:] [521379575.0, 984834594.0] (2, 12800)
c [197274699.0, 330434006.0] (2, 13312)
c[:, 2*self._z:] [521379575.0, 984834594.0] (2, 12800)
z 256 n 20160 n_lpdc 13312 k 2408 k_lpdc 2560
shorten_size 20312 shorten_size-self._k_ldpc+2*self._z 18264 self._k-2*self._z 1896


InvalidArgumentError: Exception encountered when calling layer 'ldpc5g_encoder_5' (type LDPC5GEncoder).

{{function_node __wrapped__Slice_device_/job:localhost/replica:0/task:0/device:CPU:0}} Expected size[1] in [0, 12800], but got 20312 [Op:Slice]

Call arguments received by layer 'ldpc5g_encoder_5' (type LDPC5GEncoder):
  • inputs=tf.Tensor(shape=(1, 1, 2, 2408), dtype=float32)

# Evaluate

In [None]:
"""Evaluate"""
preds = predict(_model, y, r)
c_pred = tf.reshape(preds, [preds.shape[0], 1, c.shape[-1]])
b_hat, crc = simulator.TB_Decode(c_pred)
# loss_cal(c_pred, c)
compute_ber(b, b_hat), crc

(<tf.Tensor: shape=(), dtype=float64, numpy=0.003543534856812265>,
 <tf.Tensor: shape=(8, 1), dtype=bool, numpy=
 array([[ True],
        [ True],
        [ True],
        [False],
        [ True],
        [ True],
        [ True],
        [ True]])>)

In [None]:
h_est, x_hat, llr_det, b_hat, crc = simulator.rec(y)
compute_ber(b, b_hat), crc

(<tf.Tensor: shape=(), dtype=float64, numpy=0.13842385015909747>,
 <tf.Tensor: shape=(8, 1), dtype=bool, numpy=
 array([[False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False]])>)