diff --git a/sionna/utils/misc.py b/sionna/utils/misc.py index 3415167d..d7668b4c 100644 --- a/sionna/utils/misc.py +++ b/sionna/utils/misc.py @@ -410,7 +410,8 @@ def sim_ber(mc_fun, graph_mode=None, verbose=True, forward_keyboard_interrupt=True, - dtype=tf.complex64): + dtype=tf.complex64, + callback=None): """Simulates until target number of errors is reached and returns BER/BLER. The simulation continues with the next SNR point if either @@ -420,7 +421,7 @@ def sim_ber(mc_fun, Input ----- - mc_fun: + mc_fun: callable Callable that yields the transmitted bits `b` and the receiver's estimate `b_hat` for a given ``batch_size`` and ``ebno_db``. If ``soft_estimates`` is True, b_hat is interpreted as @@ -470,6 +471,22 @@ def sim_ber(mc_fun, dtype: tf.complex64 Datatype of the model / function to be used (``mc_fun``). + callback: callable + Defaults to `None`. If specified, ``callback`` + will be called after each Monte-Carlo step. Can be used for + logging or advanced early stopping. + Input signature of ``callback`` must match `callback(mc_iter, + ebno_dbs, bit_errors, block_errors, nb_bits, nb_blocks)` where + ``mc_iter`` denotes the number of processed batches for the current + SNR, ``ebno_dbs`` is the current SNR point, ``bit_errors`` the number + of bit errors, ``block_errors`` the number of block errors, ``nb_bits`` + the number of simulated bits, ``nb_blocks`` the number of simulated + blocks. If ``callable`` returns `sim_ber.CALLBACK_NEXT_SNR`, early + stopping is detected and the simulation will continue with the next SNR + point. If ``callable`` returns `sim_ber.CALLBACK_STOP`, the simulation + is stopped immediately. For `sim_ber.CALLBACK_CONTINUE` continues with + the simulation. + Output ------ (ber, bler) : @@ -567,7 +584,8 @@ def _print_progress(is_final, rt, idx_snr, idx_it, header_text=None): "reached max iter ", # status=1; spacing for impr. layout "no errors - early stop", # status=2 "reached target bit errors", # status=3 - "reached target block errors"] # status=4 + "reached target block errors", # status=4 + "callback triggered stopping"] # status=5 # check inputs for consistency assert isinstance(early_stop, bool), "early_stop must be bool." @@ -658,6 +676,18 @@ def _print_progress(is_final, rt, idx_snr, idx_it, header_text=None): nb_blocks = tf.tensor_scatter_nd_add( nb_blocks, [[i]], tf.cast([block_n], tf.int64)) + cb_state = sim_ber.CALLBACK_CONTINUE + if callback is not None: + cb_state = callback (ii, ebno_dbs[i], bit_errors[i], + block_errors[i], nb_bits[i], + nb_blocks[i]) + if cb_state in (sim_ber.CALLBACK_STOP, + sim_ber.CALLBACK_NEXT_SNR): + # stop runtime timer + runtime[i] = time.perf_counter() - runtime[i] + status[i] = 5 # change internal status for summary + break # stop for this SNR point have been simulated + # print progress summary if verbose: # print summary header during first iteration @@ -714,6 +744,14 @@ def _print_progress(is_final, rt, idx_snr, idx_it, header_text=None): print("\nSimulation stopped as no error occurred " \ f"@ EbNo = {ebno_dbs[i].numpy():.1f} dB.\n") break + # allow callback to end the entire simulation + if cb_state is sim_ber.CALLBACK_STOP: + # stop runtime timer + status[i] = 5 # change internal status for summary + if verbose: + print("\nSimulation stopped by callback funtion " \ + f"@ EbNo = {ebno_dbs[i].numpy():.1f} dB.\n") + break # Stop if KeyboardInterrupt is detected and set remaining SNR points to -1 except KeyboardInterrupt as e: @@ -745,6 +783,9 @@ def _print_progress(is_final, rt, idx_snr, idx_it, header_text=None): return ber, bler +sim_ber.CALLBACK_CONTINUE = None +sim_ber.CALLBACK_STOP = 2 +sim_ber.CALLBACK_NEXT_SNR = 1 def complex_normal(shape, var=1.0, dtype=tf.complex64): r"""Generates a tensor of complex normal random variables.