In [1]:
import os
if os.getenv("CUDA_VISIBLE_DEVICES") is None:
    gpu_num = 0 # Use "" to use the CPU
    os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu_num}"
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import sionna.phy


import tensorflow as tf
# Configure the notebook to use only a single GPU and allocate only as much memory as needed
# For more details, see https://www.tensorflow.org/guide/gpu
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        tf.config.experimental.set_memory_growth(gpus[0], True)
    except RuntimeError as e:
        print(e)
# Avoid warnings from TensorFlow
tf.get_logger().setLevel('ERROR')

sionna.phy.config.seed = 42 # Set seed for reproducible results

E0000 00:00:1746781684.292198   70065 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1746781684.303833   70065 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1746781684.325254   70065 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1746781684.325283   70065 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1746781684.325285   70065 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1746781684.325287   70065 computation_placer.cc:177] computation placer already registered. Please check linka

In [2]:
import sys
sys.path.append('../')
from src import SystemConfig, UeConfig, MyConfig, MySimulator, MyPUSCHConfig, tic, toc

In [3]:
from src.my_pusch_config import MyPUSCHConfig
from src.my_encoder import MyTBEncoder
from src.my_decoder import MyTBDecoder

from sionna.phy.mapping import BinarySource, Mapper
from sionna.phy.nr import PUSCHPilotPattern, LayerMapper, LayerDemapper, PUSCHLSChannelEstimator
from sionna.phy.ofdm import ResourceGrid, ResourceGridMapper, LinearDetector
from sionna.phy.nr.utils import generate_prng_seq
from sionna.phy.channel import AWGN, OFDMChannel, gen_single_sector_topology as gen_topology
from sionna.phy.mimo import StreamManagement
from sionna.phy.channel.tr38901 import Antenna, AntennaArray, UMi, UMa, CDL
from sionna.phy import Block


import tensorflow as tf
import numpy as np

In [4]:
sys_cfg = SystemConfig(
    NCellId=442
)
ue_cfg = UeConfig(NPrb=23)
my_cfg = MyConfig(sys_cfg, ue_cfg)
my_pusch_cfg = MyPUSCHConfig(my_config=my_cfg)

In [5]:
my_pusch_cfg.show()

Carrier Configuration
cyclic_prefix : normal
cyclic_prefix_length : 2.3437500000000002e-06
frame_duration : 0.01
frame_number : 0
kappa : 64.0
mu : 1
n_cell_id : 442
n_size_grid : 162
n_start_grid : 0
num_slots_per_frame : 20
num_slots_per_subframe : 2
num_symbols_per_slot : 14
slot_number : 4
sub_frame_duration : 0.001
subcarrier_spacing : 30
t_c : 5.086263020833334e-10
t_s : 3.2552083333333335e-08

PUSCH Configuration
dmrs_grid : shape (1, 276, 14)
dmrs_grid_precoded : shape ()
dmrs_mask : shape (276, 14)
dmrs_symbol_indices : [3, 11]
first_resource_block : 0
first_subcarrier : 0
frequency_hopping : neither
l : [3, 11]
l_0 : 3
l_bar : [3, 11]
l_d : 14
l_prime : [0]
l_ref : 0
mapping_type : A
n : shape (69,)
n_rnti : 20002
n_size_bwp : 162
n_start_bwp : 0
num_antenna_ports : 1
num_coded_bits : 6624
num_layers : 1
num_ov : 0
num_res_per_prb : 144
num_resource_blocks : 23
num_subcarriers : 276
phy_cell_id : 442
precoding : non-codebook
precoding_matrix : None
symbol_allocation : [0, 14]

In [6]:
my_pusch_cfg.c_init(3)

517473140

In [7]:
# # @tf.function(input_signature=[tf.TensorSpec([], tf.int32) for _ in range(5)],
# #     jit_compile=True)
# def c_init_tf(l, n_id, n_scid, slot_number, num_symbols_per_slot):
#     """TensorFlow version of c_init from 3GPP 38.211"""
#     term1 = 2**17 * (
#         num_symbols_per_slot * slot_number + l + 1
#     ) * (2 * n_id + 1)
#     term2 = 2**17 * 0
#     term3 = 2 * n_id + n_scid
#     print(term1,term2,term3)
#     c_init = tf.math.mod(term1 + term2 + term3, 2**31)
    
#     return c_init

#     # print("num_symbols_per_slot", num_symbols_per_slot, "slot_number", slot_number, "n_id", n_id, "n_scid_bar", n_scid)

#     # lambda_bar = tf.constant(0, dtype=tf.float32)
#     # l = tf.cast(l, tf.int64)
#     # n_id = tf.cast(n_id, tf.int64)
#     # n_scid = tf.cast(n_scid, tf.int64)
#     # slot_number = tf.cast(slot_number, tf.int64)
#     # num_symbols_per_slot = tf.cast(num_symbols_per_slot, tf.int64)

#     # term1 = tf.constant(2**17, dtype=tf.int64) * (
#     #     num_symbols_per_slot * slot_number + l + 1
#     # ) * (2 * n_id + 1)

#     # term2 = tf.constant(2**17, dtype=tf.int64) * tf.cast(tf.floor(lambda_bar / 2), tf.int64)

#     # term3 = 2 * n_id + n_scid

#     # print("terms", term1, term2, term3)

#     # c_init = (term1 + term2 + term3) % tf.constant(2**31, dtype=tf.int64)
#     # print("c_init", c_init)
#     # return tf.cast(c_init, tf.int32)

In [8]:
your_pusch_config = my_pusch_cfg.clone()

In [9]:
my_simulator = MySimulator(my_pusch_cfg)
my_simulator.reference

I0000 00:00:1746781688.417945   70065 service.cc:152] XLA service 0x3a0db860 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1746781688.417987   70065 service.cc:160]   StreamExecutor device (0): Host, Default Version
I0000 00:00:1746781688.564675   70065 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


<tf.Tensor: shape=(1, 1, 14, 276), dtype=complex64, numpy=
array([[[[ 0.+0.j,  0.+0.j,  0.+0.j, ...,  0.+0.j,  0.+0.j,  0.+0.j],
         [ 0.+0.j,  0.+0.j,  0.+0.j, ...,  0.+0.j,  0.+0.j,  0.+0.j],
         [ 0.+0.j,  0.+0.j,  0.+0.j, ...,  0.+0.j,  0.+0.j,  0.+0.j],
         ...,
         [-1.+1.j,  0.+0.j,  1.-1.j, ...,  0.+0.j,  1.+1.j,  0.+0.j],
         [ 0.+0.j,  0.+0.j,  0.+0.j, ...,  0.+0.j,  0.+0.j,  0.+0.j],
         [ 0.+0.j,  0.+0.j,  0.+0.j, ...,  0.+0.j,  0.+0.j,  0.+0.j]]]],
      dtype=complex64)>

In [10]:
b, c, x = my_simulator(1, prng_seed=20044, return_items=('b', 'c', 'x'))
np.sum(b), np.sum(c), np.sum(x)

(394.0, 3321.0, (-8.485285-20.242645j))

In [15]:
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
import tensorflow as tf
import tempfile
import random
import os

class MyWriter:
    def __init__(self, num_shards=256,
                 samples_per_shard=1024,
                 output_dir="tfrecords",
                 isTempfile=False):

        self.samples_per_shard = samples_per_shard
        self.num_shards = num_shards

        self.output_dir = output_dir
        if isTempfile:
            self.temp_dir = tempfile.TemporaryDirectory()
            self.output_dir = os.path.join(self.temp_dir.name, output_dir)
        
        os.makedirs(self.output_dir, exist_ok=True)

    def _serialize_example(self, feature_tensors):
        def _bytes_feature(tensor):
            return tf.train.Feature(bytes_list=tf.train.BytesList(value=[tf.io.serialize_tensor(tensor).numpy()]))
        
        # Tạo dictionary feature từ danh sách tên và tensor tương ứng
        feature_dict = {
            name: _bytes_feature(tensor) 
            for name, tensor in zip(self.return_items, feature_tensors)
        }
        return tf.train.Example(features=tf.train.Features(feature=feature_dict)).SerializeToString()

    # @tf.function(jit_compile=True)
    # def update_fn(rnti, slot_num, pci, rb_start):
            
    @tf.function(jit_compile=True)
    def _get_example(self, simulator):
        bz = 1
        b, c, y, x, h = simulator(bz, return_items=('b', 'c', 'y', 'x', 'h'))
        r = simulator.get_r_rg(bz)
        c = simulator.get_c_rg(c)
        
        output_map = {
            'b': b[0],
            'c': c[0],
            'y': y[0],
            'x': x[0],
            'r': r[0],
            'h': h[0]
        }
        return tuple(output_map[item] for item in self.return_items)

    def _write_shard(self, shard_id, progress_callback, simulator):
        shard_path = os.path.join(self.output_dir, f"data-{shard_id:04d}-of-{self.num_shards:04d}.tfrecord")
        buffer = []
        with tf.io.TFRecordWriter(shard_path) as tfwriter:
            for _ in range(self.samples_per_shard):
                # tic()
                # Lấy tất cả các tensor từ simulator
                simulator.update(rnti=1230+shard_id,
                        slot_num=[4,5,14,15][shard_id%4],
                        pci=100+shard_id,
                        rb_start=12
                )
                # toc("total update")
                
                # tic()
                example_tensors = self._get_example(simulator)
                # toc("get example")

                # tf.print("=======================")
                # Serialize và ghi vào buffer
                serialized = self._serialize_example(example_tensors)
                buffer.append(serialized)

                # Ghi buffer khi đủ 256 mẫu
                if len(buffer) == 256:
                    for example in buffer:
                        tfwriter.write(example)
                        progress_callback()
                    buffer.clear()

            # Ghi phần còn lại trong buffer
            for example in buffer:
                tfwriter.write(example)
                progress_callback()
            

    def write(self, simulator, return_items):        
        self.return_items = return_items

        simulator = simulator.clone()
        _ = simulator(1)  # Khởi tạo simulator
        
        original_stdout = sys.stdout
        with open('output.txt', 'w') as f:
            sys.stdout = f
            simulator.pusch_config.show()
        sys.stdout = original_stdout

        total_samples = self.num_shards * self.samples_per_shard
        pbar = tqdm(total=total_samples, desc="Generating TFRecords", unit=" sample")
        
        def progress_callback_gen():
            # pass
            return lambda: pbar.update(1)

        # Gọi trực tiếp _write_shard cho từng shard
        for shard_id in range(self.num_shards):
            progress_callback = progress_callback_gen()
            self._write_shard(shard_id, progress_callback, simulator)

        # pbar.close()

In [16]:
tf.config.run_functions_eagerly(False)

In [17]:
writer = MyWriter(num_shards=4, samples_per_shard=32, output_dir="tfrecords")

In [None]:
writer.write(simulator=my_simulator, return_items=('b', 'c', 'y', 'x', 'r', 'h'))

: 

In [15]:
@tf.function(
    input_signature=[
        tf.TensorSpec([], tf.int32),
        tf.TensorSpec([], tf.int32)
    ],
    jit_compile=True)
def generate_prng_seq_tf(length, c_init):
    n_seq = 31
    n_c = 1600
    total_len = length + n_c + n_seq

    # Convert c_init to 31-bit tensor (LSB first)
    c_init_bits = tf.bitwise.bitwise_and(
        tf.bitwise.right_shift(c_init, tf.range(n_seq, dtype=tf.int32)),
        1
    )

    # Initialize x1 and x2 as TensorArrays
    x1 = tf.TensorArray(dtype=tf.int32, size=total_len, dynamic_size=False, clear_after_read=False)
    x2 = tf.TensorArray(dtype=tf.int32, size=total_len, dynamic_size=False, clear_after_read=False)

    # Set initial conditions
    x1 = x1.write(0, 1)
    for i in range(1, n_seq):
        x1 = x1.write(i, 0)
    for i in range(n_seq):
        x2 = x2.write(i, c_init_bits[i])

    # Define the loop body
    def body(idx, x1, x2):
        x1_val = tf.bitwise.bitwise_and(x1.read(idx + 3) + x1.read(idx), 1)
        x2_val = tf.bitwise.bitwise_and(
            x2.read(idx + 3) + x2.read(idx + 2) + x2.read(idx + 1) + x2.read(idx),
            1
        )
        x1 = x1.write(idx + n_seq, x1_val)
        x2 = x2.write(idx + n_seq, x2_val)
        return idx + 1, x1, x2

    # Run the loop
    idx = 0
    cond = lambda i, *_: i < (length + n_c)
    idx, x1, x2 = tf.while_loop(cond, body, [idx, x1, x2])

    # Compute c = x1[n_c:n_c+length] + x2[n_c:n_c+length] mod 2
    c = tf.TensorArray(dtype=tf.int32, size=length)
    for i in range(length):
        val = tf.bitwise.bitwise_and(x1.read(i + n_c) + x2.read(i + n_c), 1)
        c = c.write(i, val)

    return c.stack()

In [16]:
@tf.function(
    input_signature=[
        tf.TensorSpec([], tf.int32),
        tf.TensorSpec([], tf.int32)
    ],
    jit_compile=True)
def generate_prng_seq_tf_2(length, c_init):
    n_seq = 31
    n_c = 1600
    total_len = length + n_c + n_seq

    # Convert c_init to 31-bit tensor (LSB first)
    c_init_bits = tf.bitwise.bitwise_and(
        tf.bitwise.right_shift(c_init, tf.range(n_seq, dtype=tf.int32)),
        1
    )

    # 1. Chuẩn bị mảng ban đầu
    x1_init = tf.concat([
        tf.constant([1], dtype=tf.int32),
        tf.zeros(n_seq - 1, dtype=tf.int32),
        tf.zeros(total_len - n_seq, dtype=tf.int32)
    ], axis=0)

    x2_init = tf.concat([
        c_init_bits,
        tf.zeros(total_len - n_seq, dtype=tf.int32)
    ], axis=0)

    # 2. Khởi tạo TensorArray bằng unstack
    x1 = tf.TensorArray(dtype=tf.int32, size=total_len, clear_after_read=False)
    x1 = x1.unstack(x1_init)

    x2 = tf.TensorArray(dtype=tf.int32, size=total_len, clear_after_read=False)
    x2 = x2.unstack(x2_init)

    # Define the loop body
    def body(idx, x1, x2):
        # Cache các giá trị
        x1_i = x1.read(idx)
        x1_i3 = x1.read(idx + 3)
        x1_val = tf.bitwise.bitwise_xor(x1_i, x1_i3)

        x2_0 = x2.read(idx)
        x2_1 = x2.read(idx + 1)
        x2_2 = x2.read(idx + 2)
        x2_3 = x2.read(idx + 3)
        x2_val = tf.bitwise.bitwise_xor(
            tf.bitwise.bitwise_xor(x2_0, x2_1),
            tf.bitwise.bitwise_xor(x2_2, x2_3)
        )

        x1 = x1.write(idx + n_seq, x1_val)
        x2 = x2.write(idx + n_seq, x2_val)
        return idx + 1, x1, x2

    # Run the loop
    idx = 0
    cond = lambda i, *_: i < (length + n_c)
    idx, x1, x2 = tf.while_loop(cond, body, [idx, x1, x2])

    c = tf.TensorArray(dtype=tf.int32, size=length)
    for i in range(length):
        x1_inc = x1.read(i + n_c)
        x2_inc = x2.read(i + n_c)
        val = tf.bitwise.bitwise_xor(x1_inc, x2_inc)
        c = c.write(i, val)

    return c.stack()

In [18]:
def generate_prng_seq(length, c_init):
    r"""Implements pseudo-random sequence generator as defined in Sec. 5.2.1
    in [3GPP38211]_ based on a length-31 Gold sequence.

    Parameters
    ----------
    length: `int`
        Desired output sequence length

    c_init: `int`
        Initialization sequence of the PRNG. Must be in the range of 0 to
        :math:`2^{32}-1`.

    Output
    ------
    :[``length``], `ndarray` of 0s and 1s
        Containing the scrambling sequence

    Note
    ----
    The initialization sequence ``c_init`` is application specific and is
    usually provided be higher layer protocols.
    """

    # check inputs for consistency
    assert(length%1==0), "length must be a positive integer."
    length = int(length)
    assert(length>0), "length must be a positive integer."

    assert(c_init%1==0), "c_init must be integer."
    c_init = int(c_init)
    assert(c_init<2**32), "c_init must be in [0, 2^32-1]."
    assert(c_init>=0), "c_init must be in [0, 2^32-1]."

    # internal parameters
    n_seq = 31 # length of gold sequence
    n_c = 1600 # defined in 5.2.1 in 38.211

    # init sequences
    c = np.zeros(length)
    x1 = np.zeros(length + n_c + n_seq)
    x2 = np.zeros(length + n_c + n_seq)

    #int2bin
    bin_ = format(c_init, f'0{n_seq}b')
    c_init = [int(x) for x in bin_[-n_seq:]] if n_seq else []
    c_init = np.flip(c_init) # reverse order
    # init x1 and x2
    x1[0] = 1
    x2[0:n_seq] = c_init


    # and run the generator
    for idx in range(length + n_c):
        x1[idx+31] = np.mod(x1[idx+3] + x1[idx], 2)
        x2[idx+31] = np.mod(x2[idx+3] + x2[idx+2] + x2[idx+1] + x2[idx], 2)

    # update output sequence
    for idx in range(length):
        c[idx] = np.mod(x1[idx+n_c] + x2[idx+n_c], 2)

    return c

In [19]:
generate_prng_seq_tf_2(44, 1000).numpy(), generate_prng_seq_tf(44, 1000).numpy()

(array([0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1,
        1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0],
       dtype=int32),
 array([0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1,
        1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0],
       dtype=int32))

In [20]:
import timeit

for i in range(3000,3010):
    t = timeit.timeit(lambda: generate_prng_seq_tf(1200, i), number=1)
    print(f"i = {i}, time = {t:.6f} s")
    t = timeit.timeit(lambda: generate_prng_seq_tf(1200, i), number=1)
    print(f"i = {i}, time = {t:.6f} s")
    t = timeit.timeit(lambda: generate_prng_seq_tf_2(1200, i), number=1)
    print(f"i = {i}, time = {t:.6f} s")
    t = timeit.timeit(lambda: generate_prng_seq_tf_2(1200, i), number=1)
    print(f"i = {i}, time = {t:.6f} s")
    print("="*32)

i = 3000, time = 0.234441 s
i = 3000, time = 0.004801 s
i = 3000, time = 0.155830 s
i = 3000, time = 0.002451 s
i = 3001, time = 0.002427 s
i = 3001, time = 0.002217 s
i = 3001, time = 0.002474 s
i = 3001, time = 0.002552 s
i = 3002, time = 0.002433 s
i = 3002, time = 0.002289 s
i = 3002, time = 0.002233 s
i = 3002, time = 0.002212 s
i = 3003, time = 0.002299 s
i = 3003, time = 0.002123 s
i = 3003, time = 0.002171 s
i = 3003, time = 0.002099 s
i = 3004, time = 0.002181 s
i = 3004, time = 0.002078 s
i = 3004, time = 0.002090 s
i = 3004, time = 0.002072 s
i = 3005, time = 0.002923 s
i = 3005, time = 0.003832 s
i = 3005, time = 0.002491 s
i = 3005, time = 0.002345 s
i = 3006, time = 0.003984 s
i = 3006, time = 0.002340 s
i = 3006, time = 0.002190 s
i = 3006, time = 0.002218 s
i = 3007, time = 0.002217 s
i = 3007, time = 0.002111 s
i = 3007, time = 0.002102 s
i = 3007, time = 0.002144 s
i = 3008, time = 0.002190 s
i = 3008, time = 0.002173 s
i = 3008, time = 0.002299 s
i = 3008, time = 0.0

In [108]:
%timeit generate_prng_seq_tf(96, 100)

1.46 ms ± 51.5 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [107]:
%timeit generate_prng_seq_tf_2(96, 100)

1.41 ms ± 46.4 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [109]:
%timeit generate_prng_seq(34, 122)

6.18 ms ± 342 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [111]:
for i in range(2000,2050):
    print(np.var(generate_prng_seq_tf(96, i).numpy() - generate_prng_seq(96, i)), end=" ")
    print(np.var(generate_prng_seq_tf_2(96, i).numpy() - generate_prng_seq(96, i)), end=" ")
    print("\n")

0.0 0.0 

0.0 0.0 

0.0 0.0 

0.0 0.0 

0.0 0.0 

0.0 0.0 

0.0 0.0 

0.0 0.0 

0.0 0.0 

0.0 0.0 

0.0 0.0 

0.0 0.0 

0.0 0.0 

0.0 0.0 

0.0 0.0 

0.0 0.0 

0.0 0.0 

0.0 0.0 

0.0 0.0 

0.0 0.0 

0.0 0.0 

0.0 0.0 

0.0 0.0 

0.0 0.0 

0.0 0.0 

0.0 0.0 

0.0 0.0 

0.0 0.0 

0.0 0.0 

0.0 0.0 

0.0 0.0 

0.0 0.0 

0.0 0.0 

0.0 0.0 

0.0 0.0 

0.0 0.0 

0.0 0.0 

0.0 0.0 

0.0 0.0 

0.0 0.0 

0.0 0.0 

0.0 0.0 

0.0 0.0 

0.0 0.0 

0.0 0.0 

0.0 0.0 

0.0 0.0 

0.0 0.0 

0.0 0.0 

0.0 0.0 



In [20]:
class MyLoader:
    def __init__(self, shards_dir, batch_size=64, shuffle_buffer_size=10000, return_request=None):
        if return_request is None:
            return_request = {
                'b': tf.float32,
                'c': tf.float32,
                'y': tf.complex64,
                'x': tf.complex64,
                'r': tf.complex64,
                'h': tf.complex64,
                # 'rnti': tf.int32,
                # 'slot_num': tf.int32,
                # 'pci': tf.int32,
                # 'rb_start': tf.int32,
                # 'snr': tf.float32
            }
        
        self.shards_dir = shards_dir
        self.batch_size = batch_size
        self.shuffle_buffer_size = shuffle_buffer_size
        self.return_request = return_request
        self.dataset = self._load_data()

    def _parse_tfrecord_fn(self, example_proto):
        feature_description = {item_name: tf.io.FixedLenFeature([], tf.string) for item_name in self.return_request.keys()}
        parsed_tmp = tf.io.parse_single_example(example_proto, feature_description)
        parsed = {item_name: tf.io.parse_tensor(parsed_tmp[item_name], out_type=item_type) 
                  for item_name, item_type in self.return_request.items()}
        return parsed

    def _load_data(self):
        file_pattern = os.path.join(self.shards_dir, "*.tfrecord")
        dataset = tf.data.Dataset.list_files(file_pattern, shuffle=True)

        dataset = dataset.interleave(
            lambda filename: tf.data.TFRecordDataset(filename),
            cycle_length=tf.data.AUTOTUNE,
            num_parallel_calls=tf.data.AUTOTUNE
        )

        dataset = dataset.map(self._parse_tfrecord_fn, num_parallel_calls=tf.data.AUTOTUNE)
        dataset = dataset.shuffle(self.shuffle_buffer_size)
        dataset = dataset.batch(self.batch_size)
        dataset = dataset.prefetch(tf.data.AUTOTUNE)
        dataset = dataset.cache()
        return dataset

    def get_dataset(self):
        return self.dataset

    def take(self, num_batches):
        return self.dataset.take(num_batches)

In [24]:
# Khởi tạo DataLoader
data_loader = MyLoader(shards_dir='tfrecords', batch_size=1)

# Lấy dataset từ DataLoader
ds = data_loader.get_dataset()

# Duyệt thử với update_fn
clone3 = my_simulator.clone()

@tf.function(jit_compile=True)
def update_fn(rnti, slot_num, pci, rb_start):
    clone3.update(rnti=rnti,
                  slot_num=slot_num,
                  pci=pci,
                  rb_start=rb_start)

In [26]:
for batch in ds.take(10):
    for key, value in batch.items():
        if len(value.shape) == 1: print(key, value)
        else: print(key, value.shape)

    update_fn(rnti=123+7*0,
                        slot_num=5,
                        pci=1+0,
                        rb_start=12)

    h_hat, llr_det, b_hat, crc = clone3.rec(batch["y"])
    print(f"-----{crc}-----")

b (1, 1, 808)
c (1, 1, 1, 14, 276, 2)
y (1, 1, 4, 14, 276)
x (1, 1, 1, 14, 276)
r (1, 1, 1, 14, 276)
h (1, 1, 4, 1, 1, 14, 276)


2025-05-06 04:45:38.292083: W tensorflow/core/kernels/data/cache_dataset_ops.cc:916] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.


TypeError: in user code:

    File "/tmp/ipykernel_84399/1755984049.py", line 12, in update_fn  *
        clone3.update(rnti=rnti,
    File "/workspaces/dsp/notebook/../src/my_simulator.py", line 177, in update  *
        self.tbEnc.scrambler.c_init = pusch_config._scb_c_init
    File "/workspaces/dsp/notebook/../src/my_encoder.py", line 356, in c_init
        self.sequence = self._generate_scrambling(self._input_shape)
    File "/workspaces/dsp/.venv/lib/python3.10/site-packages/sionna/phy/fec/scrambling.py", line 416, in _generate_scrambling
        seq = generate_prng_seq(input_shape[-1], self._c_init[0])

    TypeError: 'NoneType' object is not subscriptable
