In [None]:
# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

This notebook trains the deep learning-based NPRACH synchronization algorithm from [AIT] on a 3GPP UMi channel model, using the [Sionna link-level simulator](https://nvlabs.github.io/sionna/).

[AIT] https://arxiv.org/abs/2205.10805

In [None]:
import tensorflow as tf
# Restrict TensorFlow to a single GPU and let its memory usage grow on demand
# instead of reserving the whole device up front. TensorFlow raises a
# RuntimeError if these settings are applied after the GPU runtime has been
# initialized; that error is printed (not re-raised) so the notebook continues.
gpus = tf.config.list_physical_devices('GPU')
num_gpus = len(gpus)
print('Number of GPUs available :', num_gpus)
if num_gpus > 0:
    gpu_num = 0  # Index of the GPU to use
    try:
        tf.config.set_visible_devices(gpus[gpu_num], 'GPU')
        print('Only GPU number', gpu_num, 'used.')
        tf.config.experimental.set_memory_growth(gpus[gpu_num], True)
    except RuntimeError as err:
        print(err)

In [None]:
import sionna as sn
sn.config.xla_compat = True

import datetime
import os
import pickle

import numpy as np

from parameters import *
from e2e import E2E

## Utility function for saving the weights

In [None]:
def save_weights(sys, path=None):
    """Pickle the weights of ``sys`` to disk, replacing any previous checkpoint.

    The weights are first written to ``<path>.tmp`` and then moved over
    ``path`` with ``os.replace``. An interruption mid-write (e.g. stopping the
    kernel during a periodic checkpoint) therefore cannot leave a truncated
    file in place of the last good checkpoint.

    Args:
        sys: Object exposing ``get_weights()``; its return value is pickled
            as-is.
        path: Destination file. Defaults to ``DEEPNSYNCH_WEIGHTS``.

    Raises:
        Exceptions from ``get_weights()``, pickling, or file I/O propagate
        unchanged. If the failure happens while writing, the temporary file
        is removed and ``path`` is left untouched.

    Note:
        The output is a pickle. Only unpickle it from a trusted source, since
        loading a pickle can execute arbitrary code.
    """
    if path is None:
        path = DEEPNSYNCH_WEIGHTS
    path = os.fspath(path)
    weights = sys.get_weights()
    # Same directory as the target, so os.replace stays on one file system.
    tmp_path = path + ".tmp"
    try:
        with open(tmp_path, "wb") as f:
            pickle.dump(weights, f)
        # os.replace is atomic on POSIX when source and target share a file system.
        os.replace(tmp_path, path)
    except BaseException:
        # Best-effort cleanup of the partial file; the original error is re-raised.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise

## Training loop

In [None]:
def training_loop(sys, log_every=128, save_every=1024, log_dir='logs'):
    """Train ``sys`` with Adam, logging losses and checkpointing periodically.

    Each iteration calls ``sys(BATCH_SIZE_TRAIN)``, which must return the three
    loss terms ``(loss_tx_ue, loss_toa, loss_cfo)``; their unweighted sum is
    minimized. Training runs for ``NUM_IT_TRAIN`` iterations. Both constants
    come from the star-imported ``parameters`` module.

    Args:
        sys: The end-to-end system to train. It must be callable as described
            above and expose ``get_weights()`` (used by ``save_weights``).
        log_every: Write the three loss terms to TensorBoard every this many
            iterations (default 128).
        save_every: Checkpoint the weights every this many iterations
            (default 1024). The weights are always saved once more at the end.
        log_dir: Root directory for TensorBoard logs. Each run writes to a
            timestamped sub-directory (default ``'logs'``).

    Raises:
        ValueError: If ``log_every`` or ``save_every`` is smaller than 1.
    """
    if log_every < 1 or save_every < 1:
        raise ValueError("log_every and save_every must be positive integers")

    # Created outside the compiled step so the optimizer state persists across calls.
    optimizer = tf.optimizers.Adam()

    @tf.function(jit_compile=True)
    def training_step():
        with tf.GradientTape() as tape:
            # Forward pass: the system returns the three loss terms.
            loss_tx_ue, loss_toa, loss_cfo = sys(BATCH_SIZE_TRAIN)
            # Loss aggregation: unweighted sum of the three terms.
            loss = loss_tx_ue + loss_toa + loss_cfo
        # The tape automatically watches the trainable variables read in the forward pass.
        variables = tape.watched_variables()
        grads = tape.gradient(loss, variables)
        optimizer.apply_gradients(zip(grads, variables))
        return loss_tx_ue, loss_toa, loss_cfo

    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    train_summary_writer = tf.summary.create_file_writer(f'{log_dir}/{current_time}')
    with train_summary_writer.as_default():
        for i in range(NUM_IT_TRAIN):
            loss_tx_ue, loss_toa, loss_cfo = training_step()
            # Periodically log the loss terms to TensorBoard (nothing is printed).
            if i % log_every == 0:
                tf.summary.scalar('Det', loss_tx_ue.numpy(), step=i)
                tf.summary.scalar('WMSE ToA', loss_toa.numpy(), step=i)
                tf.summary.scalar('WMSE CFO', loss_cfo.numpy(), step=i)
            # Periodically checkpoint the weights.
            if i % save_every == 0:
                save_weights(sys)

    # Final checkpoint so the iterations after the last periodic save are kept.
    save_weights(sys)

## Training

In [None]:
# Fix TensorFlow's global random seed for repeatable training runs.
tf.random.set_seed(42)

# End-to-end system to train.
# NOTE(review): 'dl' presumably selects the deep-learning synchronizer and
# True presumably enables training mode; confirm against e2e.E2E.
sys = E2E(
    'dl',
    True,
    nprach_num_rep=NPRACH_NUM_REP,
    nprach_num_sc=NPRACH_NUM_SC,
)
training_loop(sys)