# bilik bom

## Imports

In [86]:
import argparse
import sys
from collections import OrderedDict
import datetime
import scipy
import os
import shutil
import glob
import pickle
from pathlib import Path
import tensorflow as tf
import logging
import mlflow
import numpy as np
from tensorflow.keras import metrics
from tensorflow.keras.optimizers import Adam
import commentjson
from random import randint
from bunch import Bunch
from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger, Callback, EarlyStopping
logger = logging.getLogger("logger")
from tensorflow.keras.layers import Dense, Flatten, BatchNormalization, Input
from tensorflow.keras.layers import Dropout,  Activation, LeakyReLU, AveragePooling1D



## Config functions 

### Define Config class

In [87]:
# Keys that Config.print() does not log.
CONFIG_VERBOSE_WAIVER = ['save_model', 'tracking_uri', 'quiet', 'sim_dir', 'train_writer', 'test_writer', 'valid_writer']
class Config(Bunch):
    """Dictionary whose entries are also accessible as attributes (built on ``Bunch``)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def print(self):
        """Log all entries, sorted by key, as a two-column table via ``logger.info``.

        Keys listed in ``CONFIG_VERBOSE_WAIVER`` are skipped.

        Raises:
            NotImplementedError: if any value is an ``OrderedDict`` (nested config).
        """
        width = 122
        separator = "-" * width
        header = "| {:^35s} | {:^80} |\n".format('Feature', 'Value')
        logger.info(separator + "\n" + header + "=" * width)
        for key, val in sorted(self.items(), key=lambda item: item[0]):
            if isinstance(val, OrderedDict):
                raise NotImplementedError("Nested configs are not implemented")
            if key in CONFIG_VERBOSE_WAIVER:
                continue
            logger.info("| {:35s} | {:80} |\n".format(key, str(val)) + separator)
        logger.info("\n")



### Define command-line arguments

In [88]:
def get_args(argv=None):
    """Parse the command-line overrides for the JSON config.

    Args:
        argv: list of argument strings, e.g. ``['--config', 'cfg.json']``.
            ``None`` makes argparse read ``sys.argv[1:]``. Unrecognized tokens
            (for example extra flags a notebook kernel is launched with) are ignored.

    Returns:
        argparse.Namespace with ``config``, ``seed``, ``exp_name`` and
        ``num_targets`` (each ``None`` when not given) and ``quiet`` (``False``).
    """
    argparser = argparse.ArgumentParser(description=__doc__)
    argparser.add_argument('--config', default=None, type=str, help='path to config file')
    argparser.add_argument('--seed', default=None, type=int, help='randomization seed')
    # exp_name is a text label (it is joined into a results path), so parse it as str;
    # type=int made argparse exit on any non-numeric experiment name.
    argparser.add_argument('--exp_name', default=None, type=str, help='Experiment name')
    argparser.add_argument('--num_targets', default=None, type=int, help='Number of simulated targets')
    argparser.set_defaults(quiet=False)
    args, _unknown = argparser.parse_known_args(argv)
    return args


### Read the config file

In [89]:
def read_json_to_dict(fname):
    """Load a JSON config file with ``commentjson``; every JSON object becomes an ``OrderedDict``.

    Args:
        fname: path (str or Path) of the config file.

    Returns:
        OrderedDict with the parsed file content (key order preserved).
    """
    with open(Path(fname), 'rt') as handle:
        return commentjson.load(handle, object_hook=OrderedDict)
    

In [90]:
def read_config(args, default_json_file="/Users/arigra/Desktop/DL projects/Dafc/config.json"):
    """Read the JSON config file and override it with the command-line arguments.

    Args:
        args: argparse.Namespace (e.g. from ``get_args``). Every attribute whose
            value is not ``None`` is copied onto the config, so CLI values win.
        default_json_file: config file used when ``args.config`` is ``None``.
            The default is a path on the original author's machine; pass your own.

    Returns:
        Config: the merged configuration. If neither the CLI nor the file provide
        a seed, a random seed in ``[0, 2**32 - 1]`` is drawn.
    """
    json_file = args.config if args.config is not None else default_json_file
    config = Config(read_json_to_dict(json_file))

    for key in sorted(vars(args)):
        val = getattr(args, key)
        if val is not None:
            setattr(config, key, val)

    # A CLI seed was already copied above, so only the merged value matters.
    # .get() avoids an AttributeError when the file has no "seed" entry.
    if config.get('seed') is None:
        # np.random.seed only accepts seeds in [0, 2**32 - 1]; sys.maxsize seeds were unusable there.
        config.seed = randint(0, 2**32 - 1)

    return config


### GPU initialization

In [91]:
def gpu_init():
    """Enable on-demand memory growth on every visible GPU and log the device counts."""

    gpus = tf.config.experimental.list_physical_devices('GPU')
    # logging applies %-formatting to (msg, *args): the args need matching placeholders,
    # otherwise the record fails to format instead of being logged.
    logger.info("Num GPUs Available: %d", len(gpus))
    if gpus:
        try:
            # Currently, memory growth needs to be the same across GPUs
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
            logical_gpus = tf.config.experimental.list_logical_devices('GPU')
            logger.info("%d Physical GPUs, %d Logical GPUs", len(gpus), len(logical_gpus))
        except RuntimeError as e:
            # Memory growth must be set before GPUs have been initialized
            logger.info("MESSAGE %s", e)


### Logger and tracker

In [92]:
def set_logger_and_tracker(config):
    """Name the run and create its results directory, recording both on ``config``.

    Despite its name, this function makes no logging or mlflow calls. It sets:

    * ``config.exp_name_time``: ``"<exp_name>_<YYYY-mm-dd_HH-MM-SS>_<seed>"``
    * ``config.tensor_board_dir``: ``../results/<exp_name>/<exp_name_time>``,
      relative to the current working directory. It is created if missing.
    """
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    config.exp_name_time = "{}_{}_{}".format(config.exp_name, timestamp, config.seed)
    config.tensor_board_dir = os.path.join('..', 'results', config.exp_name, config.exp_name_time)

    # exist_ok replaces the racy "check exists, then create" pattern
    os.makedirs(config.tensor_board_dir, exist_ok=True)


In [93]:
def save_scripts(config, SRC_DIR):
    """Copy the project's ``.py`` files and the config file into ``<tensor_board_dir>/scripts``.

    Args:
        config: needs ``tensor_board_dir``. The optional ``config`` attribute is the
            config-file path; it is skipped when missing or ``None``.
        SRC_DIR: root searched recursively for ``*.py``. Files under
            ``<SRC_DIR>/results`` are not copied.

    Files are copied flat (by basename). A file that cannot be copied is reported
    and skipped rather than aborting the run.
    """
    path = os.path.join(config.tensor_board_dir, 'scripts')
    os.makedirs(path, exist_ok=True)

    scripts_to_save = glob.glob('{}/**/*.py'.format(SRC_DIR), recursive=True)
    # config.config is only set when a --config path was given
    config_file = getattr(config, 'config', None)
    if config_file is not None:
        scripts_to_save.append(config_file)

    results_dir = '{}/results'.format(SRC_DIR)
    for script in scripts_to_save:
        if results_dir in script:
            continue
        dst_file = os.path.join(path, os.path.basename(script))
        try:
            # Relative paths are resolved against the working directory, as read_config
            # does when opening the config. sys.argv[0] is usually the kernel launcher
            # inside a notebook, so it is not a meaningful base directory.
            shutil.copyfile(script, dst_file)
        except OSError as err:
            print('save_scripts(): could not copy {}: {}'.format(script, err))

In [94]:

def print_config(config):
    """Print every config entry as ``key, value`` lines under a banner.

    Args:
        config: any mapping (e.g. ``Config``); entries are printed in iteration order.
    """
    banner = '#' * 70
    print('')
    print(banner)
    print('Configurations at beginning of run')
    print(banner)
    # items() avoids the old config['{}'.format(key)] re-lookup, which failed for non-str keys
    for key, value in config.items():
        print('{}, {}'.format(key, value))
    print('')
    print('')

In [95]:

SRC_DIR = os.getcwd()
# NOTE: absolute path on the original author's machine; point this at your own config.json.
config_path = '/Users/arigra/Desktop/DL projects/Dafc/config.json'
# get_args() expects an argv-style list. Passing the bare path string made argparse
# split it into single characters (all "unknown"), so --config was silently ignored.
args = get_args(['--config', config_path])
config = read_config(args)
gpu_init()
set_logger_and_tracker(config)
#save_scripts(config,SRC_DIR)
print_config(config)



######################################################################
Configurations at beginning of run
######################################################################
model_input_dim, [None]
model_output_dim, 1
quiet, False
seed, 3775404699870012136
exp_name, temp
load_complete_model, False
load_model_path, 
con_inf_rng_path, 
con_inf_vel_path, 
eval_model_pth, 
detection_pfa_miss_M_valid, 5000
detection_exp_type, pd
ipix_pkl_path, 
ipix_pkl_path_dir, /Users/arigra/Desktop/DL projects/Dafc/datasets/IPIX/15m/pkl/hh
ipix_pkl_cv_hold_out, 
ipix_cdf_files_list, []
ipix_skip_cv_iters, []
ipix_predefined_cv_iters, []
ipix_cv_mode, False
ipix_cv_script, main_train
ipix_cv_rng_pth, ../results/IPIX_3m_HH_K64_8targets_CV_twostage_fc_rng/IPIX_3m_HH_K64_8targets_CV_twostage_fc_rng_2022-03-12_17-29-13_405075
ipix_cv_vel_pth, ../results/IPIX_3m_HH_K64_8targets_CV_twostage_fc_vel/IPIX_3m_HH_K64_8targets_CV_twostage_fc_vel_2022-03-12_09-02-17_826518
sweep_run_eval_con_inf, False
sweep_run_e

#### config properties

In [96]:
# Print every config entry as "key: value", alphabetically by key.
config_properties = sorted(config)
for attr_name in config_properties:
    attr_value = config[attr_name]
    print(f'{attr_name}: {attr_value}')


B_chirp: 50000000.0
CBBCE_penalize_interference: False
CBBCE_penalize_margin: 5
CBBCE_penalize_snr_use_geom_space: False
CBBCE_predefined_weight: 0
CBBCE_use_penalize_margin: False
CBBCE_use_penalize_snr: False
CNR_db: 15
FOV: 60
K: 64
L: 10
M_test: 10000
M_train: 10000
M_valid: 1000
N: 64
SCNR_db: -5
SCNR_db_random_choice: False
SCNR_db_random_constant: False
SCNR_db_range: [-5, 10]
SCNRs_eval: [-10, -7.5, -5, -2.5, 0, 2.5, 5, 7.5, 10]
T_PRI: 0.001
T_idle: 5e-05
activation: tanh
activation_sweep_list: ['relu', 'tanh']
additive_noise_std: 1.0
augment_list: []
augment_prob: 0.5
batch_size: 256
batch_size_sweep_list: [32, 64, 128, 256, 512]
beamforming_method_sweep_list: ['MVDR', 'MUSIC', 'MLE']
cfar_guard_cell: [0.1, 0.1, 0.1]
cfar_method: ca
cfar_num_censor_cells_largest: 0.25
cfar_num_censor_cells_largest_sweep_list: [0.1, 0.25, 0.5, 0.75]
cfar_num_censor_cells_smallest: 0.25
cfar_os_order_statistic: 0.5
cfar_os_order_statistic_sweep_list: [0.25, 0.5, 0.75]
cfar_single_param: []
cfar_

## Data loading


In [97]:
def _list_ipix_files(pkl_dir):
    """Return the non-hidden file names in ``pkl_dir`` in sorted (deterministic) order."""
    return sorted(f for f in os.listdir(pkl_dir) if not f.startswith('.'))


def _concat_per_file(data_dict_per_file, files_list):
    """Concatenate each set type across files, in ``files_list`` order.

    Only the set types present in the first file's dict are merged, so per-file
    dicts holding just ``'test'`` (``split_data=False``) are handled too.
    """
    first = data_dict_per_file[files_list[0]]
    merged = {}
    for set_type in ('train', 'valid', 'test'):
        if set_type not in first:
            continue
        ds = first[set_type]
        for cdf_file in files_list[1:]:
            ds = ds.concatenate(data_dict_per_file[cdf_file][set_type])
        merged[set_type] = ds
    return merged


def get_dataset_ipix(config, apply_tf_preprocess_pipe=True, split_data=True, return_dict=False):
    """Build IPIX ``tf.data.Dataset`` splits from the pickled recordings in ``config.ipix_pkl_path_dir``.

    CV mode (``config.ipix_cv_mode``): files not in ``config.ipix_pkl_cv_hold_out``
    give train (first 90% of slow time) and valid (last 10%). The held-out files
    give test.

    Otherwise every file is split into train [0, 0.85), valid [0.85, 0.9) and
    test [0.9, 1.0] of slow time. When ``split_data`` is False, only a ``'test'``
    set is built from the whole recording.

    Frame counts ``M_*`` are divided evenly across files. Mutates
    ``config.ipix_pkl_path`` and ``config.r_0_max`` as a side effect.

    Returns:
        dict set_type -> Dataset (files concatenated in sorted order), or, when
        ``return_dict`` is set in non-CV mode, dict file_name -> that per-file dict.
    """
    if config.ipix_cv_mode:
        # ---- train / valid: every file that is not held out ----
        cdf_files_list = [x for x in _list_ipix_files(config.ipix_pkl_path_dir)
                          if x not in config.ipix_pkl_cv_hold_out]
        assert len(cdf_files_list) > 0
        M_valid = int(config.M_valid / len(cdf_files_list))
        M0_valid = int(config.without_target_ratio_test * M_valid)
        M_train = int(config.M_train / len(cdf_files_list))
        M0_train = int(config.without_target_ratio * M_train)

        data_dict_per_file = {}
        for cdf_file in cdf_files_list:
            config.ipix_pkl_path = os.path.join(config.ipix_pkl_path_dir, cdf_file)
            # read raw data and convert to fast-time x slow-time complex data
            c_tensor_total, rng_bins_ipix, clutter_vel = read_data_ipix(config)
            config.r_0_max = rng_bins_ipix[-1]
            split_idx = int(0.9 * c_tensor_total.shape[1])
            data = {}
            # valid = [0.9, 1.0], train = [0.0, 0.9] of the slow-time axis
            data['valid'] = gen_ipix_pipeline_dataset(c_tensor_total[:, split_idx:], clutter_vel, config, M0_valid, M_valid)
            data['train'] = gen_ipix_pipeline_dataset(c_tensor_total[:, :split_idx], clutter_vel, config, M0_train, M_train)
            if apply_tf_preprocess_pipe:
                for set_type in ('train', 'valid'):
                    data[set_type] = tf_dataset_pipeline(config, data[set_type])
            data_dict_per_file[cdf_file] = data
        data = _concat_per_file(data_dict_per_file, cdf_files_list)

        # ---- test: the held-out files (kept in user-given order) ----
        hold_out = list(config.ipix_pkl_cv_hold_out)
        test_per_file = {}
        for cdf_file in hold_out:
            assert cdf_file not in cdf_files_list
            config.ipix_pkl_path = os.path.join(config.ipix_pkl_path_dir, cdf_file)
            c_tensor_total_test, _, clutter_vel = read_data_ipix(config)
            M_test = int(config.M_test / len(hold_out))
            M0_test = int(config.without_target_ratio_test * M_test)
            test_ds = gen_ipix_pipeline_dataset(c_tensor_total_test, clutter_vel, config, M0_test, M_test)
            if apply_tf_preprocess_pipe:
                test_ds = tf_dataset_pipeline(config, test_ds)
            test_per_file[cdf_file] = {'test': test_ds}
        data['test'] = _concat_per_file(test_per_file, hold_out)['test']

    else:
        cdf_files_list = _list_ipix_files(config.ipix_pkl_path_dir)
        assert len(cdf_files_list) > 0
        n_files = len(cdf_files_list)
        M_test = int(config.M_test / n_files)
        M0_test = int(config.without_target_ratio_test * M_test)
        M_valid = int(config.M_valid / n_files)
        M0_valid = int(config.without_target_ratio_test * M_valid)
        M_train = int(config.M_train / n_files)
        M0_train = int(config.without_target_ratio * M_train)

        data_dict_per_file = {}
        for cdf_file in cdf_files_list:
            config.ipix_pkl_path = os.path.join(config.ipix_pkl_path_dir, cdf_file)
            # read raw data and convert to fast-time x slow-time complex data
            c_tensor_total, rng_bins_ipix, clutter_vel = read_data_ipix(config)
            config.r_0_max = rng_bins_ipix[-1]
            n_slow = c_tensor_total.shape[1]

            data = {}
            if split_data:
                # test = [0.9, 1.0], valid = [0.85, 0.9], train = [0.0, 0.85]
                data['test'] = gen_ipix_pipeline_dataset(c_tensor_total[:, int(0.9 * n_slow):], clutter_vel, config, M0_test, M_test)
                data['valid'] = gen_ipix_pipeline_dataset(c_tensor_total[:, int(0.85 * n_slow): int(0.9 * n_slow)], clutter_vel, config, M0_valid, M_valid)
                data['train'] = gen_ipix_pipeline_dataset(c_tensor_total[:, :int(0.85 * n_slow)], clutter_vel, config, M0_train, M_train)
            else:
                # whole recording as test data, with the full (undivided) M_test frames per file
                M1_test_full = config.M_test
                M0_test_full = int(config.M_test * config.without_target_ratio_test)
                data['test'] = gen_ipix_pipeline_dataset(c_tensor_total, clutter_vel, config, M0_test_full, M1_test_full)

            if apply_tf_preprocess_pipe:
                # only the sets actually built (split_data=False builds just 'test')
                for set_type in ('train', 'valid', 'test'):
                    if set_type in data:
                        data[set_type] = tf_dataset_pipeline(config, data[set_type])

            data_dict_per_file[cdf_file] = data

        if return_dict:
            return data_dict_per_file
        data = _concat_per_file(data_dict_per_file, cdf_files_list)

    return data


In [98]:
def get_dataset_compund_gaussian(config, apply_tf_preprocess_pipe=True):
    """Build train/valid/test compound-Gaussian datasets.

    Each split gets ``(M0, M1)`` frame counts: ``M1`` is the split's ``M_*`` size and
    ``M0`` is that size times the without-target ratio (``without_target_ratio`` for
    train, ``without_target_ratio_test`` otherwise). Presumably M0 counts clutter-only
    frames, as in ``gen_ipix_pipeline_dataset``; confirm against the generator.

    Raises:
        AssertionError: if ``config.embedded_target`` is set without
            ``config.compound_gaussian_single_clutter_vel``.
    """
    split_sizes = {
        'train': (int(config.M_train * config.without_target_ratio), config.M_train),
        'valid': (int(config.M_valid * config.without_target_ratio_test), config.M_valid),
        'test': (int(config.M_test * config.without_target_ratio_test), config.M_test),
    }

    if config.embedded_target:
        assert config.compound_gaussian_single_clutter_vel

    data = {
        set_type: gen_compound_gaussian_pipeline_dataset(config, m0, m1)
        for set_type, (m0, m1) in split_sizes.items()
    }

    if apply_tf_preprocess_pipe:
        data = {set_type: tf_dataset_pipeline(config, ds) for set_type, ds in data.items()}

    return data

In [99]:
def get_dataset_wgn(config, apply_tf_preprocess_pipe=True):
    """Build train/valid/test white-Gaussian-noise datasets.

    Each split gets ``(M0, M1)`` frame counts: ``M1`` is the split's ``M_*`` size and
    ``M0`` is that size times the without-target ratio (``without_target_ratio`` for
    train, ``without_target_ratio_test`` otherwise).
    """
    split_sizes = {
        'train': (int(config.M_train * config.without_target_ratio), config.M_train),
        'valid': (int(config.M_valid * config.without_target_ratio_test), config.M_valid),
        'test': (int(config.M_test * config.without_target_ratio_test), config.M_test),
    }

    data = {}
    for set_type, (m0, m1) in split_sizes.items():
        data[set_type] = gen_wgn_pipeline_dataset(config, m0, m1)

    if apply_tf_preprocess_pipe:
        for set_type, ds in list(data.items()):
            data[set_type] = tf_dataset_pipeline(config, ds)

    return data

In [100]:
def get_model_output_dim(data, set_type):
    """Return the model output dimension inferred from ``data[set_type]``'s label spec.

    Returns:
        list: the full label shape when the label spec is a tuple (shape of its first
        element); otherwise ``[first label dimension]``.
    """
    # Read the requested split. The old code always inspected data['train'] first,
    # which raised KeyError when only 'test' exists. It also called .shape on a
    # tuple spec, which made the tuple branch unreachable.
    label_spec = data[set_type].element_spec[1]
    if isinstance(label_spec, tuple):
        return list(label_spec[0].shape)
    return [list(label_spec.shape)[0]]

def get_model_input_dim(data, set_type):
    """Return the input shape(s) of ``data[set_type]`` as lists.

    A tuple input spec yields one shape list per element; a single spec yields
    its shape as a flat list.
    """
    input_spec = data[set_type].element_spec[0]
    if not isinstance(input_spec, tuple):
        return list(input_spec.shape)
    return [list(spec.shape) for spec in input_spec]

In [101]:

def make_iterators(data, config):
    """Shuffle and batch the datasets into per-split iterables.

    ``data['train']`` is replaced in-place by its shuffled version (full-size buffer,
    reshuffled every iteration). Train batches drop the remainder; valid/test keep it.

    Returns:
        dict with ``'train'``, ``'valid'``, ``'test'`` batched datasets.
    """
    M_train = len(data['train'])
    print('make_iterators(): M_train: {}'.format(M_train))
    data['train'] = data['train'].shuffle(M_train, reshuffle_each_iteration=True)

    # prefetch() counts *batches*: prefetching batch_size batches could buffer
    # batch_size**2 samples in memory. Let tf.data tune the buffer size instead.
    train_iter = data['train'].batch(config.batch_size, drop_remainder=True).prefetch(tf.data.AUTOTUNE)
    valid_iter = data['valid'].batch(config.batch_size)
    test_iter = data['test'].batch(config.batch_size)

    iterators = {'train': train_iter, 'valid': valid_iter, 'test': test_iter}
    return iterators

In [102]:
def load_data(config, use_make_iterators=True, apply_tf_preprocess_pipe=True):
    """Build the datasets selected by ``config.data_name`` and record model I/O dims on config.

    Sets ``config.model_input_dim`` and ``config.model_output_dim`` from the train
    split (the test split for ``"ipix"``).

    Returns:
        ``(config, iterators)`` when ``use_make_iterators`` is set (shuffled/batched,
        see ``make_iterators``), else ``(config, data)`` with the raw datasets.

    Raises:
        ValueError: if ``config.data_name`` is not "ipix", "compound_gaussian" or "wgn".
    """
    model_input_dim_set = 'train'
    if config.data_name == "ipix":
        data = get_dataset_ipix(config, apply_tf_preprocess_pipe=apply_tf_preprocess_pipe)
        model_input_dim_set = 'test'
    elif config.data_name == "compound_gaussian":
        data = get_dataset_compund_gaussian(config, apply_tf_preprocess_pipe=apply_tf_preprocess_pipe)
        if config.compound_gaussian_add_wgn:
            data_wgn = get_dataset_wgn(config, apply_tf_preprocess_pipe=apply_tf_preprocess_pipe)
            for key in data.keys():
                # rd_signal , label_tensor, (param_val_tensor, scnr_tensor, gamma_shape, clutter_vel, clutter_label_tensor)
                # clutter_vel / clutter_label_tensor are zeroed so the element structure matches the WGN set before concatenation
                data[key] = data[key].map(lambda x0, x1, x2: (x0, x1, (x2[0], x2[1], x2[2], tf.constant(0.0), tf.constant(0.0))))
                data[key] = data[key].concatenate(data_wgn[key])
    elif config.data_name == "wgn":
        data = get_dataset_wgn(config, apply_tf_preprocess_pipe=apply_tf_preprocess_pipe)
    else:
        raise ValueError('load_data(): unsupported data_name {!r}; expected "ipix", '
                         '"compound_gaussian" or "wgn"'.format(config.data_name))

    # set model_input_dim
    config.model_input_dim = get_model_input_dim(data, model_input_dim_set)
    config.model_output_dim = get_model_output_dim(data, model_input_dim_set)

    if use_make_iterators:
        # make data iterators (shuffle,batch,etc.)
        data_iterators = make_iterators(data, config)
        return config, data_iterators
    else:
        return config, data

In [103]:
def read_data_ipix(config):
    """Load one IPIX pickle, build fast-time x slow-time data and estimate the clutter velocity.

    Side effects on ``config``: sets ``N`` (only when ``ipix_file_range_bins``),
    ``T_PRI``, ``B_chirp`` and ``r_0_max`` from the file.

    Returns:
        c_tensor: ``steering @ adc_data`` with ``config.N`` rows, where the
            steering matrix is built from the range bins.
        rng_bins: range bins shifted so the first bin is 0.
        clutter_vel: velocity at the peak of the row-averaged Welch PSD of ``adc_data``.
    """
    # scipy.signal is a submodule; a bare `import scipy` does not expose it on older SciPy releases.
    import scipy.signal

    # SECURITY: pickle.load can execute arbitrary code; only load trusted files.
    with open(config.ipix_pkl_path, 'rb') as handle:
        ipix_data = pickle.load(handle)
    PRI = ipix_data['PRI']
    B = ipix_data['B']
    rng_bins = ipix_data['rng_bins']
    adc_data = ipix_data['adc_data']

    rng_bins = rng_bins[:config.ipix_max_nrange_bins]
    adc_data = adc_data[:config.ipix_max_nrange_bins, :]
    if not config.ipix_file_range_bins:
        # keep only the first N // 2 range bins
        assert config.N <= len(rng_bins) * 2
        adc_data = adc_data[:config.N // 2, :]
        rng_bins = rng_bins[:config.N // 2]
    else:
        config.N = len(rng_bins) * 2
    assert config.N == len(rng_bins) * 2
    config.T_PRI = PRI
    config.B_chirp = B
    if '19980205_191043' in config.ipix_pkl_path:
        # cut weird zero part of this file
        adc_data = adc_data[:, :50000]
    # convert to fast-time x slow-time data
    rng_bins = rng_bins - rng_bins[0]
    config.r_0_max = rng_bins[-1]
    clutter_omega_r = ((2 * np.pi) / config.N) * ((2*B) / 3e8) * rng_bins
    # workaround to prevent GPU overflow in multiple iterations: fall back to numpy on failure
    try:
        clutter_range_steering_tensor = tf.math.exp(-1j * tf.cast(tf.expand_dims(tf.range(config.N, dtype=tf.float32), -1) * tf.expand_dims(clutter_omega_r, 0), dtype=tf.complex128))
        c_tensor = clutter_range_steering_tensor @ adc_data
    except Exception:
        # `except Exception` (not a bare except) so KeyboardInterrupt/SystemExit still propagate
        clutter_range_steering_tensor = np.exp(-1j * tf.cast(tf.expand_dims(tf.range(config.N, dtype=tf.float32), -1) * tf.expand_dims(clutter_omega_r, 0), dtype=tf.complex128))
        c_tensor = clutter_range_steering_tensor @ adc_data

    # Estimate the clutter Doppler frequency with Welch's method:
    #     c_tensor[i] = e^{j 2 pi f_d k T_PRI},   f_d = (2 * f_c * clutter_vel) / c
    # The two-sided PSDs of all range bins are averaged; the peak frequency is f_d.
    Pxx_den_list = []
    for i in range(adc_data.shape[0]):
        f, Pxx_den = scipy.signal.welch(adc_data[i], 1 / PRI, return_onesided=False)
        Pxx_den_list.append(Pxx_den)

    PSD = np.mean(np.array(Pxx_den_list), 0)
    PSD = PSD / np.max(PSD)
    clutter_fd = f[np.argmax(PSD)]
    # NOTE(review): carrier frequency hard-coded as 9.39e9 here (config.f_c is not used); confirm intended.
    clutter_vel = -(3e8 * clutter_fd) / (2 * 9.39e9)

    return c_tensor, rng_bins, clutter_vel

In [104]:
def gen_ipix_pipeline_dataset(c_tensor_total, clutter_vel, config, M0, M1):
    """Build a dataset of random ``[N, K]`` IPIX clutter crops, with and without injected targets.

    Args:
        c_tensor_total: clutter data with ``config.N`` rows (asserted).
        clutter_vel: estimated clutter velocity of this recording.
        config: run configuration.
        M0: number of clutter-only frames (all-zero label map, -1000.0 placeholders).
        M1: number of frames with targets from ``gen_target_matrix``.

    Returns:
        tf.data.Dataset: the M1 target frames followed by the M0 clutter-only frames,
        each mapped through ``split_auxillary_structure``.
    """

    def gen_ipix_frame2d(ind, with_target):
        c_tensor = tf.image.random_crop(c_tensor_total, [config.N, config.K])
        clutter_vel_local = clutter_vel
        if config.ipix_random_shift_doppler:
            # shift the clutter so its velocity lands uniformly inside [v_r_min, v_r_max]
            shift_min = tf.cast(clutter_vel_local - config.v_r_min, dtype=tf.float32)
            shift_max = tf.cast(config.v_r_max - clutter_vel_local, dtype=tf.float32)
            doppler_shift_v = tf.cast(tf.random.uniform([], -shift_min, shift_max), dtype=tf.complex128)
            clutter_vel_local = clutter_vel + tf.cast(tf.math.real(doppler_shift_v), dtype=tf.float32)
            shift_factor = tf.expand_dims(tf.math.exp(-1j * 2 * np.pi * ((2 * config.f_c * doppler_shift_v) / 3e8) * config.T_PRI * tf.cast(tf.range(c_tensor.shape[1]), dtype=tf.complex128) ), 0 )
            c_tensor = c_tensor * shift_factor

        if not with_target:
            # -1000.0 marks "no target" in the parameter / SCNR slots
            param_val_tensor = tf.ones(config.num_targets) * -1000.0, tf.ones(config.num_targets) * -1000.0
            return c_tensor, tf.zeros((recon_vec_rng.shape[0], recon_vec_vel.shape[0]), dtype=tf.int64), \
                   param_val_tensor, tf.ones(config.num_targets) * -1000.0, tf.constant(0.0), clutter_vel_local, tf.constant(0.0)

        cn_norm = tf.abs(tf.linalg.norm(c_tensor))
        rd_signal, label_tensor, param_val_tensor, scnr_tensor = gen_target_matrix(config, cn_norm, clutter_vel_local, N, K, recon_vec_rng, recon_vec_vel)
        return rd_signal + c_tensor, label_tensor, param_val_tensor, scnr_tensor, tf.constant(0.0), clutter_vel_local, tf.constant(0.0)

    def get_ipix_tfds(M_tfds, with_target):
        # with_target is bound as an argument. The previous version read an enclosing
        # variable that was reassigned between the two builds, so correctness relied on
        # tf.data tracing the map function before the reassignment.
        return tf.data.Dataset.range(0, M_tfds).map(lambda ind: gen_ipix_frame2d(ind, with_target), num_parallel_calls=-1)

    assert config.N == c_tensor_total.shape[0]
    assert not (config.SCNR_db_random_constant and config.SCNR_db_random_choice)
    N = config.N
    K = config.K

    recon_vec_rng = tf.cast(get_reconstruction_point_cloud_vec(config, param_ind=0), dtype=tf.float32)
    recon_vec_vel = tf.cast(get_reconstruction_point_cloud_vec(config, param_ind=1), dtype=tf.float32)

    tfds0 = get_ipix_tfds(M0, with_target=False)
    tfds1 = get_ipix_tfds(M1, with_target=True)

    assert tfds1.element_spec[0] == tfds0.element_spec[0] and tfds1.element_spec[1] == tfds0.element_spec[1]
    tfds = tfds1.concatenate(tfds0)
    tfds = tfds.map(split_auxillary_structure)

    return tfds


In [105]:
def get_reconstruction_point_cloud_vec(config, param_ind):
    """Return the range (``param_ind=0``) or velocity (``param_ind=1``) bin values of the label grid.

    The grid is the FFT grid enlarged by ``config.point_cloud_reconstruction_fft_dim_factor``
    in every dimension and trimmed to the configured limits by ``get_valid_2d_bins``.

    Raises:
        Exception: if ``config.point_cloud_reconstruction_fft_dims`` is not set.
        ValueError: if ``param_ind`` is not 0 or 1.
    """
    if not config.point_cloud_reconstruction_fft_dims:
        raise Exception('get_reconstruction_point_cloud_res(): Unsupported...')
    if param_ind not in (0, 1):
        raise ValueError('get_reconstruction_point_cloud_vec(): param_ind must be 0 (range) or 1 (velocity), got {!r}'.format(param_ind))

    factor = config.point_cloud_reconstruction_fft_dim_factor
    N = factor * config.N
    K = factor * config.K
    L = factor * config.L

    # get_fft_resolutions reads config.B_chirp: scale it to rescale the range dimension.
    # Restore the saved value in `finally`; dividing back could drift by float rounding
    # and was skipped entirely when an exception escaped.
    original_B_chirp = config.B_chirp
    config.B_chirp = factor * original_B_chirp
    try:
        _, _, recon_vec_rng, recon_vec, _ = get_fft_resolutions(config, [1, N, K, L], T_PRI=config.T_PRI)
        bin_values_list, _ = get_valid_2d_bins(config, [N, K], recon_vec_rng, recon_vec)
    finally:
        config.B_chirp = original_B_chirp

    # bin_values_list = [range_bins_values, vel_bins_values]
    return bin_values_list[param_ind]


In [106]:
def get_fft_resolutions(config, data_shape, T_PRI=None):
    """Compute FFT bin resolutions and centred bin values for a 4-D data shape.

    Args:
        config: provides ``f_c``, ``B_chirp`` and, when ``T_PRI`` is None, ``f_s``/``T_idle``.
        data_shape: ``[_, n_fast, n_slow, n_ant]``; exactly 4 entries (asserted).
        T_PRI: pulse repetition interval; defaults to ``n_fast / f_s + T_idle``.

    Returns:
        (range_res, vel_res, range_bins_values, vel_bins_values, azimuth_bins_values).
        Bin values are centred on index ``n // 2``; azimuths are in degrees.
    """
    assert len(data_shape) == 4
    _, n_fast, n_slow, n_ant = data_shape
    if T_PRI is None:
        T_PRI = n_fast * (1 / config.f_s) + config.T_idle

    speed_of_light = 3e8
    vel_res = speed_of_light / (2 * config.f_c * n_slow * T_PRI)
    range_res = speed_of_light / (2 * config.B_chirp)

    range_bins_values = range_res * (np.arange(n_fast) - n_fast // 2)
    vel_bins_values = vel_res * (np.arange(n_slow) - n_slow // 2)

    sin_grid = 2.0 * (np.arange(n_ant) - n_ant // 2) / n_ant
    azimuth_bins_values = np.arcsin(sin_grid) * (180.0 / np.pi)

    return range_res, vel_res, range_bins_values, vel_bins_values, azimuth_bins_values

In [107]:

def get_valid_2d_bins(config, full_shape, range_bins_values, vel_bins_values):
    """Select the range / velocity bins that lie inside the configured target limits.

    Velocity bins in ``[config.v_0_min, config.v_0_max]`` are kept. The selection is
    then widened by one neighbouring bin on each side, when that bin exists within
    ``full_shape[1]``. Range bins in ``[config.r_0_min, config.r_0_max]`` are kept
    without widening.

    Returns:
        ([range_values, vel_values], [range_indices, vel_indices])
    """
    assert len(full_shape) == 2

    in_vel_window = np.logical_and(config.v_0_min <= vel_bins_values, vel_bins_values <= config.v_0_max)
    vel_idx = np.nonzero(in_vel_window)[0]
    last_vel_bin = full_shape[1] - 1
    if vel_idx[-1] < last_vel_bin:
        vel_idx = np.append(vel_idx, vel_idx[-1] + 1)
    if vel_idx[0] > 0:
        vel_idx = np.insert(vel_idx, 0, vel_idx[0] - 1)

    in_rng_window = np.logical_and(config.r_0_min <= range_bins_values, range_bins_values <= config.r_0_max)
    rng_idx = np.nonzero(in_rng_window)[0]

    return [range_bins_values[rng_idx], vel_bins_values[vel_idx]], [rng_idx, vel_idx]

### Target and label tensors generator

***This function generates a target matrix along with corresponding labels, parameter values, and SCNR tensors***

In [108]:
def gen_target_matrix(config, cn_norm, clutter_vel_local, N, K, recon_vec_rng, recon_vec_vel):
    """Synthesize random point targets on an N x K fast-time / slow-time grid.

    Args:
        config: run configuration (target count, velocity/range limits, SCNR and phase options).
        cn_norm: norm of the clutter frame. Each target's signal norm is scaled to
            ``10**(SCNR_db / 20) * cn_norm``.
        clutter_vel_local: clutter velocity of the frame; used when ``config.embedded_target``.
        N, K: number of fast-time and slow-time samples.
        recon_vec_rng, recon_vec_vel: range / velocity bin centres of the label grid.

    Returns:
        rd_signal: (N, K) complex128 sum of all target signals.
        label_tensor: (len(recon_vec_rng), len(recon_vec_vel)) int64 map, 1 at target cells.
        param_val_tensor: (ranges, velocities), each padded with -1000.0 up to config.num_targets.
        scnr_tensor: per-target SCNR [dB], padded with -1000.0 up to config.num_targets.
    """

    # randomly determine the number of targets or set it to a constant value
    if config.random_num_targets:
        targets_num = tf.random.uniform([1, ], minval=1, maxval=config.num_targets + 1, dtype=tf.int64)
    else:
        targets_num = tf.cast(tf.constant([config.num_targets]), dtype=tf.int64)

    # Generate target velocities within specified range considering the presence of embedded targets
    if config.embedded_target:
        targets_vel = tf.random.uniform(targets_num, tf.maximum(clutter_vel_local - config.embedded_target_vel_offset, config.v_0_min),
                                                     tf.minimum(clutter_vel_local + config.embedded_target_vel_offset, config.v_0_max))
    else:
        targets_vel = tf.random.uniform(targets_num, config.v_0_min, config.v_0_max)

    # Generate target ranges within specified range
    targets_rng = tf.random.uniform(targets_num, config.r_0_min, config.r_0_max)

    # compute doppler and range target frequencies
    targets_omega_d = tf.cast(2 * np.pi * config.T_PRI * ((2 * config.f_c * targets_vel) / 3e8), dtype=tf.complex128)
    targets_omega_r = tf.cast(2 * np.pi * ((2 * config.B_chirp * targets_rng) / (3e8 * N)), dtype=tf.complex128)

    # compute doppler (targets x K) and range (targets x N) steering tensors
    doppler_steering_tensor = tf.math.exp(-1j * tf.expand_dims(targets_omega_d, 1) * tf.cast(tf.expand_dims(tf.range(K), 0), dtype=tf.complex128))
    range_steering_tensor = tf.math.exp(-1j * tf.expand_dims(targets_omega_r, 1) * tf.cast(tf.expand_dims(tf.range(N), 0), dtype=tf.complex128))

    # per-target range-doppler signal (targets x N x K) and the SCNR
    rd_signal = tf.expand_dims(range_steering_tensor, 2) * tf.expand_dims(doppler_steering_tensor, 1)
    SCNR_db = get_SCNR_db(config, targets_num)

    # Adjust phase of the range-Doppler signal based on the configuration
    if config.signal_random_phase:
        rd_signal = rd_signal * tf.math.exp(1j * tf.cast(tf.expand_dims(tf.expand_dims(tf.random.uniform(targets_num, 0, 2 * np.pi), 1), 1), dtype=tf.complex128))
    elif config.signal_physical_phase:
        targets_tau0 = (2 * targets_rng) / 3e8
        rd_signal = rd_signal * tf.expand_dims(tf.expand_dims(tf.math.exp(1j * (tf.cast(-2 * np.pi * config.f_c * targets_tau0 + np.pi * (config.B_chirp / (config.N * config.f_s)) * (targets_tau0 ** 2), dtype=tf.complex128))), 1), 1)

    # compensate for the appropriate SCNR level, then sum the targets into one (N x K) frame
    s_norm = tf.math.real(tf.norm(rd_signal, axis=[1, 2]))
    sig_amp = (10 ** (tf.cast(SCNR_db, dtype=tf.float64) / 20.0)) * (tf.cast(cn_norm, dtype=tf.float64) / s_norm)
    rd_signal = tf.reduce_sum(tf.cast(tf.expand_dims(tf.expand_dims(sig_amp, -1), -1), dtype=tf.complex128) * rd_signal, axis=0)

    # gen label map: nearest grid cell of every target
    trgt_inds_vel = tf.expand_dims(tf.math.argmin(tf.abs(tf.expand_dims(targets_vel, 1) - tf.expand_dims(recon_vec_vel, 0)), axis=1), 1)
    trgt_inds_rng = tf.expand_dims(tf.math.argmin(tf.abs(tf.expand_dims(targets_rng, 1) - tf.expand_dims(recon_vec_rng, 0)), axis=1), 1)
    tf.debugging.assert_type(trgt_inds_vel, tf.int64)
    tf.debugging.assert_type(trgt_inds_rng, tf.int64)
    trgt_inds = tf.concat((trgt_inds_rng, trgt_inds_vel), 1)
    tf.debugging.assert_type(trgt_inds, tf.int64)
    label_tensor = tf.scatter_nd(trgt_inds, tf.squeeze(tf.ones_like(trgt_inds_vel), 1), (recon_vec_rng.shape[0], recon_vec_vel.shape[0]))
    # tf.scatter_nd *sums* duplicate indices: two targets in the same cell would yield 2.
    # Clip so the label map stays binary, like the all-zero map of clutter-only frames.
    label_tensor = tf.minimum(label_tensor, 1)

    # gen parameter value and SCNR tensors, padded to num_targets with -1000.0
    param_val_tensor = (tf.concat((targets_rng, tf.ones(config.num_targets - targets_num) * -1000.0), axis=0),
                        tf.concat((targets_vel, tf.ones(config.num_targets - targets_num) * -1000.0), axis=0))
    scnr_tensor = tf.concat((SCNR_db, tf.ones(config.num_targets - targets_num) * -1000.0), axis=0)

    return rd_signal, label_tensor, param_val_tensor, scnr_tensor

### SCNR

***Draws the per-target signal-to-clutter-plus-noise ratio (SCNR, in dB):***
- If `config.SCNR_db_random_choice`: one random index per target is drawn uniformly from `[0, len(config.SCNRs_eval))`. The SCNR values are gathered from `SCNRs_eval` with `tf.gather()`, so different targets may get different values.
- Else if `config.SCNR_db_random_constant`: a single random value is picked from `config.SCNRs_eval` and broadcast to all `targets_num` targets (by multiplying with `tf.ones(targets_num)`).
- Else if `config.random_SCNR`: each target gets an SCNR drawn uniformly from `[config.SCNR_db_range[0], config.SCNR_db_range[1] + 0.001)`.
- Otherwise: the constant `config.SCNR_db` is repeated for all `targets_num` targets (again via `tf.ones(targets_num)`).

In [109]:
def get_SCNR_db(config, targets_num):
    """Draw the SCNR value (in dB) of every target of one generated sample.

    The mode is selected by the first truthy flag on ``config``:

    * ``SCNR_db_random_choice``   -- each target gets an independent value drawn
      uniformly from the list ``config.SCNRs_eval``.
    * ``SCNR_db_random_constant`` -- one value is drawn from ``config.SCNRs_eval`` and
      shared by all targets.
    * ``random_SCNR``             -- each target gets a value drawn uniformly from
      ``[config.SCNR_db_range[0], config.SCNR_db_range[1] + 0.001)``.
    * otherwise                   -- every target gets the constant ``config.SCNR_db``.

    Args:
        config: experiment configuration (attribute access).
        targets_num: number of targets, as a scalar or a one-element vector.

    Returns:
        float32 tensor holding one SCNR value (dB) per target.
    """
    # tf.random.uniform requires a rank-1 shape tensor (tf.ones also accepts a scalar).
    # Normalizing to a 1-D int32 vector lets every branch accept a scalar targets_num
    # as well as a one-element vector.
    num_shape = tf.reshape(tf.cast(targets_num, dtype=tf.int32), [-1])

    if config.SCNR_db_random_choice:
        SCNRs_eval = tf.constant(config.SCNRs_eval, dtype=tf.float32)
        scnr_eval_inds = tf.random.uniform(num_shape, minval=0, maxval=len(config.SCNRs_eval), dtype=tf.int32)
        SCNR_db = tf.gather(SCNRs_eval, scnr_eval_inds)
    elif config.SCNR_db_random_constant:
        SCNRs_eval = tf.constant(config.SCNRs_eval, dtype=tf.float32)
        scnr_eval_inds = tf.random.uniform([1], minval=0, maxval=len(config.SCNRs_eval), dtype=tf.int32)
        SCNR_db = tf.gather(SCNRs_eval, scnr_eval_inds) * tf.ones(num_shape)
    elif config.random_SCNR:
        # maxval is exclusive for float tf.random.uniform; the +0.001 presumably lets
        # values up to SCNR_db_range[1] itself be drawn.
        SCNR_db = tf.random.uniform(num_shape, minval=config.SCNR_db_range[0], maxval=config.SCNR_db_range[1] + 0.001)
    else:
        SCNR_db = tf.constant(config.SCNR_db, dtype=tf.float32) * tf.ones(num_shape)

    return SCNR_db

### Split auxiliary

***This function removes all size-1 dimensions from mat_complex, keeps mat_label untouched, and bundles the remaining variables into a tuple.***

In [110]:
def split_auxillary_structure(mat_complex, mat_label, param_val, scnr, gamma_shape, clutter_vel, clutter_label_tensor):
    """Regroup a flat sample into ``(matrix, label, aux)``.

    ``mat_complex`` loses all of its size-1 dimensions (``tf.squeeze``), ``mat_label``
    is passed through unchanged, and the other five values are bundled, in their
    original order, into the auxiliary tuple.
    """
    aux = (param_val, scnr, gamma_shape, clutter_vel, clutter_label_tensor)
    squeezed = tf.squeeze(mat_complex)
    return squeezed, mat_label, aux


### Preprocessing pipeline

***This function attaches the model-specific preprocessing steps (`Dataset.map` calls) to the data pipeline***

In [111]:
def tf_dataset_pipeline(config, data):
    """Attach the model-specific preprocessing maps to a ``tf.data`` pipeline.

    Every element of ``data`` is a ``(mat_complex, mat_label, aux)`` triple; each map
    below passes ``aux`` through untouched and only transforms the matrix and/or label.

    Handled combinations (any other combination returns ``data`` unchanged):

    * ``data_name == "ipix"`` with ``model_name == "Detection-TwoStage-FC"``:
      two-stage matrix preprocessing, then the label is summed along the reduce axis
      (1 for ``estimation_params == ["rng"]``, 0 otherwise) and clipped to [0, 1].
    * ``data_name`` in ``("compound_gaussian", "wgn")``:

      - ``"Detection-TwoStage-FC"``: same matrix preprocessing; the label is reduced
        as above when ``config.compound_gaussian_dims == 2``, otherwise it is only
        clipped to [0, 1], cast to float32 and squeezed.
      - ``"Detection-FC"``: real and imaginary parts are concatenated along axis 0
        and squeezed; the label is squeezed, clipped to [0, 1] and cast to float32.

    Two-stage matrix preprocessing: transpose (only for ``["rng"]``), subtract the
    column mean and flatten to ``(-1, last_dim)``, optionally divide each column by
    its standard deviation (``config.two_stage_fc_stdize``), then concatenate the real
    and imaginary parts along the last axis.

    Args:
        config: experiment configuration (attribute access).
        data: ``tf.data.Dataset`` of ``(mat_complex, mat_label, aux)`` triples.

    Returns:
        The dataset with the preprocessing maps attached.

    Raises:
        AssertionError: two-stage model with ``config.estimation_params`` other than
            ``["rng"]`` or ``["vel"]``.
    """

    def cube_center_and_reshape(mat):
        # Subtract the mean over all rows of `mat` viewed as (-1, last_dim), then
        # return the centered tensor flattened to (-1, last_dim).
        mat_center = mat - tf.reduce_mean(tf.reshape(mat, (-1, mat.shape[-1])), axis=0)
        return tf.reshape(mat_center, (-1, mat_center.shape[-1]))

    def two_stage_fc_preprocess_cg(mat_complex, mat_label, aux):
        # (matrix, label, aux) wrapper so the centering step can be used with Dataset.map.
        return cube_center_and_reshape(mat_complex), mat_label, aux

    def two_stage_fc_stdize(mat_complex, mat_label, aux):
        # Element-wise division by the per-column (axis 0) standard deviation.
        mat_complex = mat_complex / tf.cast(tf.math.reduce_std(mat_complex, axis=0), dtype=tf.complex128)
        return mat_complex, mat_label, aux

    def transpose_mat_complex(mat_complex, mat_label, aux):
        return tf.transpose(mat_complex), mat_label, aux

    def concat_real_imag_cg(mat_complex, mat_label, aux):
        # Real and imaginary parts side by side on the last axis.
        return tf.concat((tf.math.real(mat_complex), tf.math.imag(mat_complex)), axis=-1), mat_label, aux

    def make_label_2dims_map(reduce_axis):
        # Label map for 2-D labels: sum along `reduce_axis`, clip to [0, 1], cast to
        # float32. The axis is bound explicitly instead of being read from an
        # enclosing variable that is only assigned further down.
        def cg_preprocess_label_2dims(mat_complex, mat_label, aux):
            mat_label = tf.cast(tf.clip_by_value(tf.reduce_sum(mat_label, axis=reduce_axis), 0, 1), dtype=tf.float32)
            return mat_complex, mat_label, aux
        return cg_preprocess_label_2dims

    def cg_preprocess_label_1dim(mat_complex, mat_label, aux):
        # Clip to [0, 1], cast to float32 and drop size-1 dimensions.
        mat_label = tf.squeeze(tf.cast(tf.clip_by_value(mat_label, 0, 1), dtype=tf.float32))
        return mat_complex, mat_label, aux

    def detection_fc_map(t, label, aux):
        # "Detection-FC": real/imag stacked along axis 0; label squeezed and clipped.
        mat = tf.squeeze(tf.concat((tf.math.real(t), tf.math.imag(t)), 0))
        return mat, tf.cast(tf.clip_by_value(tf.squeeze(label), 0, 1), dtype=tf.float32), aux

    def label_reduce_axis():
        return 1 if config.estimation_params == ["rng"] else 0

    def apply_two_stage_fc_matrix_maps(dataset):
        # Matrix preprocessing shared by every "Detection-TwoStage-FC" dataset.
        assert config.estimation_params == ["rng"] or config.estimation_params == ["vel"]
        if config.estimation_params == ["rng"]:
            dataset = dataset.map(transpose_mat_complex)
        dataset = dataset.map(two_stage_fc_preprocess_cg)
        if config.two_stage_fc_stdize:
            dataset = dataset.map(two_stage_fc_stdize)
        return dataset.map(concat_real_imag_cg)

    if config.data_name == "ipix":
        if config.model_name == "Detection-TwoStage-FC":
            data = apply_two_stage_fc_matrix_maps(data)
            data = data.map(make_label_2dims_map(label_reduce_axis()))
    elif config.data_name == "compound_gaussian" or config.data_name == "wgn":
        if config.model_name == "Detection-TwoStage-FC":
            data = apply_two_stage_fc_matrix_maps(data)
            if config.compound_gaussian_dims == 2:
                data = data.map(make_label_2dims_map(label_reduce_axis()))
            else:
                data = data.map(cg_preprocess_label_1dim)
        elif config.model_name == "Detection-FC":
            data = data.map(detection_fc_map)

    return data

## Model

In [112]:
class TwoStageFcLayer(tf.keras.layers.Layer):
    """
    Perform TwoStage Fully-Connected

    A Dense layer is applied along the last (column) axis, the result is transposed,
    a second Dense layer is applied along what was the row axis, and the result is
    transposed back:

    input tensor: (batch, input_row_dim, input_col_dim)
    output tensor: (batch, row_units, column_units)

    Each stage is followed by its activation, then optional batch normalization,
    then optional dropout.

    Args:
        row_units: output size of the second (row-wise) Dense layer.
        column_units: output size of the first (column-wise) Dense layer.
        l2_lamda: L2 coefficient for kernel and bias regularizers; <= 0 disables them.
        activation_name: name passed to ``tf.keras.layers.Activation``.
        use_batchnorm: add a BatchNormalization after each stage's activation.
        dropout_rate: dropout rate after each stage; <= 0 disables dropout.
        is_first: not used by the computation; stored so it round-trips through
            ``get_config``.
        **kwargs: standard ``tf.keras.layers.Layer`` arguments (``name``, ``dtype``, ...).
    """
    def __init__(self, row_units, column_units, l2_lamda, activation_name, use_batchnorm, dropout_rate, is_first, **kwargs):
        # Forward the standard Layer kwargs; previously they were accepted but dropped.
        super(TwoStageFcLayer, self).__init__(**kwargs)

        self.row_units = row_units
        self.column_units = column_units
        self.l2_lamda = l2_lamda
        self.activation_name = activation_name
        self.use_batchnorm = use_batchnorm
        self.dropout_rate = dropout_rate
        self.is_first = is_first

        if l2_lamda > 0.0:
            self.col_dense = tf.keras.layers.Dense(column_units,
                                bias_regularizer=tf.keras.regularizers.l2(l2_lamda),
                                kernel_regularizer=tf.keras.regularizers.l2(l2_lamda))
            self.row_dense = tf.keras.layers.Dense(row_units,
                                bias_regularizer=tf.keras.regularizers.l2(l2_lamda),
                                kernel_regularizer=tf.keras.regularizers.l2(l2_lamda))
        else:
            self.col_dense = tf.keras.layers.Dense(column_units)
            self.row_dense = tf.keras.layers.Dense(row_units)
        self.col_activation = Activation(activation_name)
        self.row_activation = Activation(activation_name)

        if self.use_batchnorm:
            self.col_bnorm = tf.keras.layers.BatchNormalization()
            self.row_bnorm = tf.keras.layers.BatchNormalization()

        if self.dropout_rate > 0.0:
            self.col_dropout = Dropout(rate=self.dropout_rate)
            self.row_dropout = Dropout(rate=self.dropout_rate)

    def call(self, input_tensor, training=False):
        # input_tensor: [batch, input_row_dim, input_col_dim]

        x = self.col_dense(input_tensor) # [batch, input_row_dim, column_units]
        x = self.col_activation(x)
        if self.use_batchnorm:
            x = self.col_bnorm(x, training=training)
        if self.dropout_rate > 0.0:
            x = self.col_dropout(x, training=training)

        x = tf.transpose(x, perm=[0, 2, 1]) # [batch, column_units, input_row_dim]

        x = self.row_dense(x) # [batch, column_units, row_units]
        x = self.row_activation(x)
        if self.use_batchnorm:
            x = self.row_bnorm(x, training=training)
        if self.dropout_rate > 0.0:
            x = self.row_dropout(x, training=training)

        x = tf.transpose(x, perm=[0, 2, 1]) # [batch, row_units, col_units]

        return x

    def get_config(self):
        """Constructor arguments, so a saved full model can rebuild this layer via ``from_config``."""
        config = super(TwoStageFcLayer, self).get_config()
        config.update({
            "row_units": self.row_units,
            "column_units": self.column_units,
            "l2_lamda": self.l2_lamda,
            "activation_name": self.activation_name,
            "use_batchnorm": self.use_batchnorm,
            "dropout_rate": self.dropout_rate,
            "is_first": self.is_first,
        })
        return config


In [113]:
def TwoStageFCModel(config, include_top=True, name_str=""):
    """Build the "TwoStageFcModel" Keras model described by ``config``.

    Architecture: ``Input(config.model_input_dim)`` -> one ``TwoStageFcLayer`` per
    entry ``(row_units, column_units)`` of ``config.two_stage_fc_dims`` ->
    GlobalAveragePooling1D (when ``config.two_stage_fc_use_gap``) or Flatten ->
    optional dense stack (``FCSkeletonModel``) -> Dense head.

    Head size / activation:
        * ``config.point_cloud_reconstruction``: ``config.model_output_dim[0]`` units, sigmoid
        * ``config.mode == "Estimation"``: ``len(config.estimation_params)`` units, no activation
        * ``config.mode == "Detection"``: 2 units, softmax

    Side effect: copies ``config.two_stage_fc_dense_sizes`` / ``two_stage_fc_dense_dropout``
    / ``two_stage_fc_dense_batchnorm`` into ``config.dense_sizes`` / ``dense_dropout`` /
    ``fc_batchnorm``, which ``FCSkeletonModel`` reads.

    ``include_top`` and ``name_str`` are accepted but not used.

    Raises:
        Exception: unsupported ``config.mode`` when not doing point-cloud reconstruction.
    """
    l2_lamda = config.l2_reg_parameter
    layer_dims = config.two_stage_fc_dims
    batch_norm_flags = config.two_stage_fc_use_batch_norm
    dropout_rates = config.two_stage_fc_dropout_rate
    activation_name = config.activation

    input_layer = Input(shape=config.model_input_dim, name="input")
    x = input_layer
    for layer_idx, dims in enumerate(layer_dims):
        two_stage = TwoStageFcLayer(row_units=dims[0], column_units=dims[1], l2_lamda=l2_lamda,
                                    activation_name=activation_name,
                                    use_batchnorm=batch_norm_flags[layer_idx],
                                    dropout_rate=dropout_rates[layer_idx],
                                    is_first=(layer_idx == 0))
        x = two_stage(x)

    pooling = tf.keras.layers.GlobalAveragePooling1D() if config.two_stage_fc_use_gap else Flatten()
    x = pooling(x)

    # FCSkeletonModel reads these generic keys
    config.dense_sizes = config.two_stage_fc_dense_sizes
    config.dense_dropout = config.two_stage_fc_dense_dropout
    config.fc_batchnorm = config.two_stage_fc_dense_batchnorm
    if config.dense_sizes != []:
        x = FCSkeletonModel(config, x, create_model=False)

    if config.point_cloud_reconstruction:
        out_units, head_activation = config.model_output_dim[0], 'sigmoid'
    elif config.mode == "Estimation":
        out_units, head_activation = len(config.estimation_params), None
    elif config.mode == "Detection":
        out_units, head_activation = 2, 'softmax'
    else:
        raise Exception("TwoStageFCModel(): Unsupported config.mode")

    if l2_lamda > 0:
        head = Dense(out_units,
                     kernel_regularizer=tf.keras.regularizers.l2(l2_lamda),
                     bias_regularizer=tf.keras.regularizers.l2(l2_lamda))
    else:
        head = Dense(out_units)
    y_hat = head(x)

    if head_activation is not None:
        y_hat = activation(config, head_activation, y_hat)

    return tf.keras.Model(inputs=input_layer, outputs=y_hat, name="TwoStageFcModel")


In [114]:

def activation(config, activation_name, x):
    """Apply the activation named ``activation_name`` to ``x`` and return the result.

    ``'leaky_relu'`` becomes a ``LeakyReLU`` layer with slope ``config.leaky_alpha``;
    any other name is handed to a Keras ``Activation`` layer.
    """
    layer = (LeakyReLU(alpha=config.leaky_alpha) if activation_name == 'leaky_relu'
             else Activation(activation_name))
    return layer(x)

In [115]:
def FCSkeletonModel(config, input_layer, create_model=True, output_name="", dense_activations=None):
    """Stack one Dense block per entry of ``config.dense_sizes`` on top of ``input_layer``.

    Each block is: Dense (L2-regularized with ``config.l2_reg_parameter`` when it is
    non-zero) -> activation (skipped when its name is the string ``"None"``) ->
    BatchNormalization (if ``config.fc_batchnorm``) -> Dropout (if its rate is non-zero).

    Args:
        config: reads ``dense_sizes``, ``dense_dropout`` (``None`` = no dropout),
            ``activation``, ``l2_reg_parameter``, ``fc_batchnorm``, and ``leaky_alpha``
            through ``activation``.
        input_layer: Keras tensor to build on.
        create_model: return a ``tf.keras.Model`` (True) or just the output tensor.
        output_name: if non-empty, renamed onto the model's last layer.
        dense_activations: per-block activation names; defaults to ``config.activation``
            for every block. Must have the same length as ``config.dense_sizes``.

    Returns:
        A ``tf.keras.Model`` when ``create_model`` is True, otherwise the output tensor.
    """
    dense_sizes = config.dense_sizes
    dense_dropout = [0 for _ in range(len(dense_sizes))] if config.dense_dropout is None else config.dense_dropout
    activation_name = config.activation
    if dense_activations is None:
        dense_activations = [activation_name for _ in range(len(dense_sizes))]
    else:
        assert len(dense_activations) == len(dense_sizes)
    l2_lamda = config.l2_reg_parameter

    x = input_layer
    for i, size in enumerate(dense_sizes):
        if l2_lamda != 0:
            x = Dense(size, kernel_regularizer=tf.keras.regularizers.l2(l2_lamda),
                      bias_regularizer=tf.keras.regularizers.l2(l2_lamda))(x)
        else:
            x = Dense(size)(x)
        # the literal string "None" (not Python None) means: no activation for this block
        if dense_activations[i] != "None":
            x = activation(config, dense_activations[i], x)
        if config.fc_batchnorm:
            x = BatchNormalization()(x)
        if dense_dropout[i] != 0:
            x = Dropout(rate=dense_dropout[i])(x)

    output_layer = x
    if not create_model:
        return output_layer

    # Use the fully-qualified class: a bare `Model` is not among the notebook's
    # imports, so `Model(...)` raised NameError here.
    model = tf.keras.Model(input_layer, output_layer)
    if output_name != "":
        model.layers[-1]._name = output_name
    return model


In [116]:
def EstimationFCModel(config):
    """Fully-connected estimation model: dense trunk plus a linear regression head.

    The trunk is ``FCSkeletonModel(config, ...)``. The head is a Dense layer with one
    output per entry of ``config.estimation_params`` and L2(``config.l2_reg_parameter``)
    regularization on kernel and bias (applied even when the coefficient is 0).
    """
    inputs = Input(shape=config.model_input_dim)
    reg_coeff = config.l2_reg_parameter

    trunk = FCSkeletonModel(config, inputs)
    head = Dense(len(config.estimation_params),
                 kernel_regularizer=tf.keras.regularizers.l2(reg_coeff),
                 bias_regularizer=tf.keras.regularizers.l2(reg_coeff))
    outputs = head(trunk.output)

    return tf.keras.Model(inputs, outputs, name="EstimationFCModel")

In [117]:

def build_model(config):
    """Create (or load) the Keras model selected by ``config``, print its summary and plot it.

    ``config.load_complete_model`` takes precedence: the model is then loaded
    (uncompiled) from ``config.load_model_path``. Otherwise ``config.model_name``
    selects the builder. An architecture PNG is written to
    ``<config.tensor_board_dir>/model_plot/<config.model_name>.png``; plotting errors
    are printed and otherwise ignored.

    Args:
        config: experiment configuration (attribute access).

    Returns:
        The uncompiled ``tf.keras.Model``.

    Raises:
        ValueError: unknown ``config.model_name`` when no complete model is loaded.
    """
    def print_load_warning():
        print("\n" + "!" * 25 + "\n" + "WARNING: LOADING MODEL FROM: {}".format(config.load_model_path) + "\n" + "!" * 25 + "\n")

    if config.load_complete_model:
        model = tf.keras.models.load_model(config.load_model_path, compile=False)
        print_load_warning()
    elif config.model_name == "Estimation-FC":
        model = EstimationFCModel(config)
    elif config.model_name in ("Estimation-TwoStage-FC", "Detection-TwoStage-FC"):
        model = TwoStageFCModel(config)
    else:
        # include the offending name (the '{}' placeholder used to be left unfilled)
        raise ValueError("'{}' is an invalid model name".format(config.model_name))

    model.summary(line_length=140)
    if config.load_complete_model:
        # repeated so the warning is also visible below the (long) summary
        print_load_warning()
    try:
        model_plot_dir = os.path.join(config.tensor_board_dir, "model_plot")
        os.makedirs(model_plot_dir, exist_ok=True)
        img_pth = os.path.join(model_plot_dir, config.model_name + ".png")
        # `plot_model` is not among the notebook's imports; use the fully-qualified name.
        tf.keras.utils.plot_model(model, img_pth, show_shapes=False)
        print("saved model plot at: {}".format(img_pth))
    except Exception as e:
        # Plotting is best-effort (it depends on optional packages such as pydot/graphviz).
        print("Failed to plot model:" + str(e))

    return model


## Trainer

In [118]:
class FocalLoss(tf.keras.losses.Loss):
    """Binary focal loss averaged over the last axis.

    Per entry: ``(1 - p_t) ** gamma * BCE(y_true, y_pred)``. Here ``p_t = y_pred``
    for positive labels (``y_true >= 0.999``), ``p_t = 1 - y_pred`` for negative
    labels (``y_true < 0.999``), and ``p_t = 0`` for entries that satisfy neither
    comparison (NaN labels).

    Args:
        gamma: focusing exponent (default 5).
    """

    def __init__(self, gamma=5):
        super().__init__()
        self.gamma = gamma

    def call(self, y_true, y_pred):
        # Element-wise BCE: the trailing singleton axis stops binary_crossentropy from
        # averaging over the label axis.
        bce = tf.keras.losses.binary_crossentropy(tf.expand_dims(y_true, -1), tf.expand_dims(y_pred, -1))
        p_t = tf.where(y_true >= 0.999, y_pred, tf.zeros_like(bce))
        p_t = tf.where(y_true < 0.999, 1 - y_pred, p_t)
        focal_weight = tf.math.pow(1 - p_t, self.gamma)
        return tf.reduce_mean(focal_weight * bce, -1)

In [119]:
class ClassBalancedBinaryCrossEntropy(tf.keras.losses.Loss):
    """Binary cross-entropy with an extra weight on the positive label entries.

    The element-wise BCE is computed for every label entry. The entries returned by
    ``self.ind1_getter`` (the positives, optionally widened by a margin) are multiplied
    by ``self.weight1``, and the result is averaged over the last axis.

    Args:
        e1, e0: exponents in the class-balanced weights ``(1 - beta) / (1 - beta ** e)``.
            The trainer in this notebook passes the positive / negative label fractions.
        predefined_weight: if > 0, used directly as ``weight1`` (with ``weight0 = 1.0``)
            instead of the class-balanced formula.
        use_penalize_margin: also up-weight up to ``penalize_margin - 1`` neighbouring
            entries on each side of every positive along axis 1.
        penalize_margin: margin width; must be > 1 when the margin is enabled.
        balanced_loss_beta: ``beta`` in the class-balanced formula.
        recon_dim: size of axis 1; margin indices are clamped to ``[0, recon_dim - 1]``.

    NOTE(review): ``self.weight0`` is computed but never used in ``call``, so negative
    entries always keep weight 1. Confirm whether that is intended.
    NOTE(review): ``get_ind1`` selects ``y_true >= 0.9999`` but ``get_ind1_with_margin``
    selects ``y_true > 0.9999``. This only matters for labels exactly equal to 0.9999.
    """
    def __init__(self, e1, e0, predefined_weight, use_penalize_margin=False, penalize_margin=8, balanced_loss_beta=0.99, recon_dim=128):
        super().__init__()
        self.balanced_loss_beta = balanced_loss_beta
        self.e1 = e1
        self.e0 = e0
        self.predefined_weight = predefined_weight
        self.use_penalize_margin = use_penalize_margin
        self.penalize_margin = penalize_margin

        if self.predefined_weight > 0.0:
            self.weight0 = 1.0
            self.weight1 = self.predefined_weight
        else:
            # class-balanced weights: (1 - beta) / (1 - beta ** e)
            self.weight0 = ((1.0 - self.balanced_loss_beta) / (1.0 - self.balanced_loss_beta ** self.e0))
            self.weight1 = ((1.0 - self.balanced_loss_beta) / (1.0 - self.balanced_loss_beta ** self.e1))

        # Choose how the up-weighted indices are found: exact positives, or positives
        # plus a neighbourhood along axis 1.
        if self.use_penalize_margin:
            assert self.penalize_margin > 1
            self.ind1_getter = self.get_ind1_with_margin
        else:
            self.ind1_getter = self.get_ind1

        # Identity label transform, applied to y_true before the BCE in `call`.
        self.label_map = lambda y_true: y_true

        self.recon_dim = recon_dim


    def get_ind1(self, y_true):
        """Indices (``tf.where`` rows) of the positive entries, ``y_true >= 0.9999``."""
        ind1 = tf.where(y_true >= 0.9999)
        return ind1

    def get_ind1_with_margin(self, y_true):
        """Indices of the positive entries plus ``penalize_margin - 1`` neighbours on each side.

        Uses only the first two coordinates of every ``tf.where`` row, i.e. it assumes
        ``y_true`` is rank 2 ``[batch_dim, recon_dim]`` (as the comment in ``call``
        states). Returns a ``(-1, 2)`` int64 index tensor. Neighbours are clamped at the
        edges, so duplicates can occur.
        """
        def get_ind_margin_inds(t):
            # t = (row, col) of one positive entry.
            # ind_plus: columns col, col+1, ..., col+margin-1 (clamped to recon_dim - 1), all at `row`.
            # ind_minus: columns col-1, ..., col-(margin-1) (clamped to 0), all at `row`.
            # t = tf.where(y_true >= 0.9999)[0]
            ind_plus = tf.concat((tf.expand_dims(tf.gather(t, 0) * tf.cast(tf.ones(self.penalize_margin), dtype=tf.int64), 1),
                 tf.expand_dims(tf.minimum(tf.gather(t, 1) + tf.range(self.penalize_margin, dtype=tf.int64), self.recon_dim - 1), 1)), axis=1)
            ind_minus = tf.concat((tf.expand_dims(tf.gather(t, 0) * tf.cast(tf.ones(self.penalize_margin - 1), dtype=tf.int64), 1),
                 tf.expand_dims(tf.maximum(tf.gather(t, 1) - tf.range(1, self.penalize_margin, dtype=tf.int64), 0), 1)), axis=1)

            return tf.concat((ind_plus, ind_minus), axis=0)
        # one (2 * penalize_margin - 1, 2) block per positive, flattened to (-1, 2)
        ind1 = tf.reshape(tf.map_fn(get_ind_margin_inds, tf.where(y_true > 0.9999)),(-1, 2))
        return ind1

    def call(self, y_true, y_pred):
        """Per-sample loss: mean over the last axis of the (re-weighted) element-wise BCE."""
        # call function for the regular single-label tensor
        # y_true = [batch_dim, recon_dim], y_pred = [batch_dim, recon_dim]
        ind1 = self.ind1_getter(y_true)
        # y_true = self.label_map(y_true)
        # Element-wise BCE: the trailing singleton axis stops binary_crossentropy from
        # averaging over the label axis.
        _nll2 = tf.keras.losses.binary_crossentropy(tf.expand_dims(self.label_map(y_true), -1), tf.expand_dims(y_pred, -1))
        # Multiply the selected entries by weight1. Duplicate indices (from edge
        # clamping or overlapping margins) gather the same value, so they write
        # identical updates.
        _nll_subset = self.weight1 * tf.gather_nd(_nll2, ind1)
        _nll2 = tf.tensor_scatter_nd_update(_nll2, ind1, _nll_subset)

        return tf.reduce_mean(_nll2, -1)


In [120]:
def CBBCE_get_n0_n1(config, data):
    """Count negative / positive label entries in ``data['train']``.

    An entry is positive when its value is ``>= 0.9999``. For rank-3 label specs only
    channel 0 of the last axis (``y[:, :, 0]``) is counted; otherwise every entry of
    ``y`` is counted.

    Totals are accumulated from the labels the iterator actually yields, not from
    ``element_spec`` shape x ``len(dataset)``. The static form overcounts when the
    last batch is partial, and fails when a dimension (e.g. the batch size) is
    ``None`` or the dataset length is unknown.

    Args:
        config: unused; kept for signature compatibility.
        data: mapping whose ``'train'`` entry is a ``tf.data.Dataset`` of
            ``(X, y, aux)`` elements.

    Returns:
        ``(n0, n1, n_total, None)``: negatives, positives, total counted entries, and a
        placeholder for callers that unpack four values.
    """
    train = data['train']
    only_first_channel = len(train.element_spec[1].shape) == 3

    n1 = 0
    n_total = 0
    for _x, y, _aux in train.as_numpy_iterator():
        y = np.asarray(y)
        if only_first_channel:
            y = y[:, :, 0]
        n1 += int(np.count_nonzero(y >= 0.9999))
        n_total += int(y.size)

    n0 = n_total - n1
    return n0, n1, n_total, None


In [121]:
class KerasTrainer(object):
    """
    General Keras Trainer class

    Wraps ``model`` (kept as ``self.model_train``), the ``data`` splits (indexed by
    ``'train'`` / ``'valid'``) and ``config``. ``add_callbacks`` assembles the Keras
    callbacks enabled in ``config``; ``compile`` and ``train`` forward to Keras.
    """

    def __init__(self, model, data, config, sweep_string=None):
        self.model_train = model
        self.data = data
        self.config = config
        self.exp_name_time = config.exp_name_time
        self.sweep_string = sweep_string
        self.callback_list = []
        self.optimizer = self.get_optimizer(config.optimizer)

    def get_optimizer(self, name):
        """Return a new optimizer for ``name`` (only ``"adam"`` is supported)."""
        if name == "adam":
            return Adam(learning_rate=self.config.learning_rate)
        raise Exception('Unsupported optimizer !!')

    def _periodic_checkpoint_save_freq(self):
        """``save_freq`` that fires the periodic checkpoint every ``model_checkpoint_epoch_period`` epochs.

        ``ModelCheckpoint`` only accepts ``'epoch'`` or a number of batches, so a period
        > 1 is converted using the training set's cardinality. When that cardinality is
        unknown, infinite or zero, the checkpoint falls back to every epoch.
        """
        period = int(self.config.model_checkpoint_epoch_period)
        if period <= 1:
            return 'epoch'
        steps_per_epoch = int(tf.data.experimental.cardinality(self.data['train']))
        if steps_per_epoch <= 0:
            return 'epoch'
        return period * steps_per_epoch

    def add_callbacks(self):
        """Rebuild ``self.callback_list`` from the config flags (``None`` when empty).

        Checkpoint and log directories under ``config.tensor_board_dir`` are created
        on demand.
        """
        self.callback_list = []
        checkpoint_dir = os.path.join(self.config.tensor_board_dir, 'checkpoints')
        checkpoint_best_filepath = os.path.join(checkpoint_dir, 'model_checkpoint_best')
        checkpoint_epoch_filepath = os.path.join(checkpoint_dir, 'model_checkpoint_epoch')

        # Save the best model according to config.model_checkpoint_best_metric
        if self.config.use_model_checkpoint_best:
            os.makedirs(checkpoint_dir, exist_ok=True)
            save_model_best_callback = ModelCheckpoint(filepath=checkpoint_best_filepath, save_weights_only=False,
                                                       monitor=self.config.model_checkpoint_best_metric,
                                                       save_best_only=True, verbose=1)
            self.callback_list.append(save_model_best_callback)

        # Save the model every `model_checkpoint_epoch_period` epochs (previously the
        # period was only used as an on/off flag and the model was saved every epoch)
        if self.config.model_checkpoint_epoch_period > 0:
            os.makedirs(checkpoint_dir, exist_ok=True)
            save_model_epoch_callback = ModelCheckpoint(filepath=checkpoint_epoch_filepath, save_weights_only=False,
                                                        save_best_only=False,
                                                        save_freq=self._periodic_checkpoint_save_freq(),
                                                        verbose=1)
            self.callback_list.append(save_model_epoch_callback)

        # CSV Logger for per-epoch logging
        if self.config.save_fit_history:
            fit_log_dir = os.path.join(self.config.tensor_board_dir, 'fit_log')
            os.makedirs(fit_log_dir, exist_ok=True)
            # An empty sweep string (ClassificationTrainerKeras turns None into '') also
            # falls back to the experiment name; otherwise the file is named '_fit_log.csv'.
            log_name = self.sweep_string if self.sweep_string else self.config.exp_name_time
            csv_logger_path = os.path.join(fit_log_dir, '{}_fit_log.csv'.format(log_name))
            self.callback_list.append(CSVLogger(csv_logger_path))

        # Early-Stopping
        if self.config.use_early_stop:
            self.callback_list.append(EarlyStoppingCallback(self.config))

        # Learning-rate schedulers (mutually exclusive, enforced by the asserts)
        if self.config.use_lr_scheduler:
            assert not self.config.use_lr_scheduler_deriv
            lr_scheduler = LrScheduler(self.config)
            self.callback_list.append(tf.keras.callbacks.LearningRateScheduler(lr_scheduler.schedule))
        if self.config.use_lr_scheduler_plateau:
            assert not self.config.use_lr_scheduler_deriv
            assert not self.config.use_lr_scheduler
            self.callback_list.append(LrSchedulerPlateau(self.config))

        if self.config.use_lr_scheduler_deriv:
            assert not self.config.use_lr_scheduler
            self.callback_list.append(LrSchedulerDeriv(self.config))

        if self.config.stop_max_acc:
            self.callback_list.append(StoppingAtMaxAccuracy())

        self.callback_list = None if self.callback_list == [] else self.callback_list
        return

    def compile(self, loss_fn, metrics):
        """Compile the wrapped model with ``self.optimizer``."""
        self.model_train.compile(optimizer=self.optimizer, loss=loss_fn, metrics=metrics)

    def train(self):
        """Fit on ``data['train']``, validating on ``data['valid']``.

        Uses the current ``self.callback_list``; call ``add_callbacks`` first to populate it.
        """
        history = self.model_train.fit(self.data['train'], epochs=self.config.num_epochs,
                                       validation_data=self.data['valid'],
                                       callbacks=self.callback_list,
                                       verbose=self.config.fit_verbose)
        return history


In [122]:
def positive_binary_cross_entropy(y_true, y_pred):
    """Mean binary cross-entropy over the positive label entries (``y_true >= 0.9999``).

    Returns 0.0 when the batch contains no positive entries.
    """
    pos_idx = tf.where(y_true >= 0.9999)

    def _mean_bce_on_positives():
        labels = tf.expand_dims(tf.gather_nd(y_true, pos_idx), -1)
        preds = tf.expand_dims(tf.gather_nd(y_pred, pos_idx), -1)
        return tf.reduce_mean(metrics.binary_crossentropy(labels, preds), axis=-1)

    return tf.cond(tf.size(pos_idx) == 0, lambda: tf.constant(0.0), _mean_bce_on_positives)
def negative_binary_cross_entropy(y_true, y_pred):
    """Mean binary cross-entropy over the negative label entries (``y_true < 0.9999``).

    Returns 0.0 when the batch has no negative entries, the same guard that
    ``positive_binary_cross_entropy`` uses. Without it, the mean over an empty
    selection is NaN and poisons the epoch-averaged metric.
    """
    ind0 = tf.where(y_true < 0.9999)
    value = tf.cond(tf.size(ind0) == 0, lambda: tf.constant(0.0),
                    lambda: tf.reduce_mean(metrics.binary_crossentropy(tf.expand_dims(tf.gather_nd(y_true, ind0), -1),
                                                                       tf.expand_dims(tf.gather_nd(y_pred, ind0), -1)), axis=-1))
    return value


In [123]:
class ClassificationTrainerKeras(KerasTrainer):
    """Trainer for the detection / point-cloud-reconstruction classification setup.

    Requires ``config.data_name`` in ``("compound_gaussian", "ipix", "wgn")`` and a
    truthy ``config.point_cloud_reconstruction``. The loss is
    ``ClassBalancedBinaryCrossEntropy`` when ``config.use_CBBCE`` (weights derived from
    the training-set label counts), otherwise ``FocalLoss``.

    Dataset elements are ``(X, y, aux)`` triples. ``aux`` is dropped before every Keras
    ``fit`` / ``evaluate`` call, because Keras reads a third tuple element as sample
    weights.
    """

    # data sets whose elements carry the extra `aux` element
    _AUX_DATA_NAMES = ("compound_gaussian", "ipix", "wgn")

    def __init__(self, model, data, config, sweep_string=None):
        super().__init__(model, data, config, sweep_string)
        assert config.data_name in self._AUX_DATA_NAMES
        self.sweep_string = '' if sweep_string is None else sweep_string
        tune_hist_dir = os.path.join(self.config.tensor_board_dir, 'tune_hist')
        os.makedirs(tune_hist_dir, exist_ok=True)
        self.tune_hist_path = os.path.join(tune_hist_dir, 'tune_hist_' + self.sweep_string + '.csv')
        assert config.point_cloud_reconstruction
        if self.config.use_CBBCE:
            n0, n1, n_total, _ = CBBCE_get_n0_n1(config, data)
            self.loss_fn = ClassBalancedBinaryCrossEntropy(e0=n0 / n_total, e1=n1 / n_total,
                predefined_weight=self.config.CBBCE_predefined_weight,
                use_penalize_margin=self.config.CBBCE_use_penalize_margin, penalize_margin=self.config.CBBCE_penalize_margin,
                recon_dim=self.config.model_output_dim[0])
        else:
            self.loss_fn = FocalLoss()

        self.metrics = [metrics.binary_crossentropy, positive_binary_cross_entropy, negative_binary_cross_entropy,
                    metrics.AUC(name='auc'), 'accuracy', metrics.FalsePositives(name="fp"), metrics.TruePositives(name="tp")]

    @staticmethod
    def _drop_aux(mat, label, aux):
        """``(X, y, aux) -> (X, y)``."""
        return mat, label

    def _without_aux(self, dataset):
        """Strip ``aux`` from every element for the data sets that carry it."""
        if self.config.data_name in self._AUX_DATA_NAMES:
            return dataset.map(self._drop_aux)
        return dataset

    def train(self):
        """Compile, attach the configured callbacks and fit; returns the Keras History."""
        self.model_train.compile(optimizer=self.optimizer, loss=self.loss_fn, metrics=self.metrics)
        super().add_callbacks()

        data_train = self._without_aux(self.data['train'])
        data_valid = self._without_aux(self.data['valid'])

        history = self.model_train.fit(data_train, epochs=self.config.num_epochs, validation_data=data_valid, callbacks=self.callback_list, verbose=self.config.fit_verbose)
        return history

    def evaluate(self):
        """Evaluate on ``data['valid']`` (``aux`` stripped so it is not read as sample weights)."""
        return self.model_train.evaluate(self._without_aux(self.data['valid']))

    def test(self):
        """Evaluate on ``data['test']`` (``aux`` stripped)."""
        return self.model_train.evaluate(self._without_aux(self.data['test']))

    def train_eval(self):
        """Evaluate on ``data['train']`` (``aux`` stripped)."""
        return self.model_train.evaluate(self._without_aux(self.data['train']))

    def predict(self, X):
        return self.model_train.predict(X)


In [124]:
def build_trainer(model, data, config, sweep_string=None):
    """Return the trainer selected by ``config.trainer_name``.

    Only ``"detection_classification"`` (``ClassificationTrainerKeras``) is supported.

    Raises:
        ValueError: unknown ``config.trainer_name``; the message names the bad value.
    """
    if config.trainer_name == "detection_classification":
        return ClassificationTrainerKeras(model, data, config, sweep_string)
    raise ValueError("'{}' is an invalid trainer name".format(config.trainer_name))


## Next step

In [132]:
class LrSchedulerPlateau(Callback):
    """Decay the learning rate when ``val_loss`` rises above its recent average.

    At every epoch end ``val_loss`` is recorded. Once ``epoch`` (0-based) exceeds both
    ``int(lr_scheduler_plateau_epoch_threshold * num_epochs)`` and
    ``lr_scheduler_plateau_window + 2``, the current ``val_loss`` is compared with the
    mean of the previous ``lr_scheduler_plateau_window`` values. If it is larger by
    more than 1e-4, and more than ``lr_scheduler_plateau_cooldown`` epochs have passed
    since the last decay, the optimizer's learning rate is multiplied by
    ``lr_scheduler_plateau_decay``.
    """
    def __init__(self, config):
        super(LrSchedulerPlateau, self).__init__()
        self.decay = config.lr_scheduler_plateau_decay
        # span of the buffer slice: the averaging window plus the current epoch
        self.ma_window = config.lr_scheduler_plateau_window + 1
        self.cooldown = config.lr_scheduler_plateau_cooldown
        self.epoch_threshold = int(config.lr_scheduler_plateau_epoch_threshold * config.num_epochs)
        self.val_loss_buffer = []
        self.last_update_epoch = 0

    def on_epoch_end(self, epoch, logs=None):
        current_loss = logs.get("val_loss")
        self.val_loss_buffer.append(current_loss)

        warmed_up = epoch > self.epoch_threshold and epoch > self.ma_window + 1
        if not warmed_up:
            return

        # mean of the previous window, excluding the value just appended
        previous_mean = np.mean(self.val_loss_buffer[-self.ma_window:-1])
        got_worse = current_loss - previous_mean > 1e-4
        cooled_down = epoch - self.last_update_epoch > self.cooldown
        if got_worse and cooled_down:
            self.model.optimizer.learning_rate = self.model.optimizer.learning_rate * self.decay
            self.last_update_epoch = epoch

In [130]:
class LrScheduler(object):
    """Step-decay schedule for ``tf.keras.callbacks.LearningRateScheduler``.

    Before epoch ``int(lr_scheduler_epoch_threshold * num_epochs)`` the learning rate is
    returned unchanged. From then on it is multiplied by ``lr_scheduler_decay`` on every
    epoch that is a multiple of ``lr_scheduler_period``.
    """
    def __init__(self, config):
        self.decay = config.lr_scheduler_decay
        self.period = config.lr_scheduler_period
        self.epoch_threshold = int(config.lr_scheduler_epoch_threshold * config.num_epochs)

    def schedule(self, epoch, lr):
        decays_now = epoch >= self.epoch_threshold and epoch % self.period == 0
        return self.decay * lr if decays_now else lr

class EarlyStoppingCallback(Callback):
    """Stop training when ``val_loss`` has risen over the last ``epoch_patience`` epochs.

    ``config.early_stop_patience`` is a fraction of ``config.num_epochs`` when <= 1,
    otherwise an epoch count; the result is clamped to at least 1. At every epoch end
    past ``epoch_patience + 1``, the current ``val_loss`` is compared with the value
    recorded exactly ``epoch_patience`` epochs earlier, and training stops when it is
    higher by more than 1e-4.

    NOTE(review): ``config.early_stop_metric`` is stored in ``self.metric``, but the
    comparison always uses ``val_loss`` (and assumes lower is better).
    """
    def __init__(self, config):
        super(EarlyStoppingCallback, self).__init__()
        self.metric = config.early_stop_metric
        if config.early_stop_patience <= 1:
            patience = int(config.early_stop_patience * config.num_epochs)
        else:
            patience = int(config.early_stop_patience)
        # a patience of 0 would compare against the first epoch / the value itself
        self.epoch_patience = max(1, patience)
        self.val_loss_buffer = []
        self.stopped_epoch = None

    def on_epoch_end(self, epoch, logs=None):
        val_loss = (logs or {}).get("val_loss")
        if val_loss is None:
            # no validation loss for this epoch (e.g. no validation data): nothing to compare
            return
        self.val_loss_buffer.append(val_loss)

        # buffer[-1] is this epoch; buffer[-(patience + 1)] is exactly `patience` epochs
        # earlier. The old buffer[-patience] was one epoch short, so patience=1
        # compared a value with itself and never stopped.
        lookback = self.epoch_patience + 1
        if epoch > self.epoch_patience + 1 and len(self.val_loss_buffer) >= lookback:
            val_loss_diff = self.val_loss_buffer[-1] - self.val_loss_buffer[-lookback]
            if val_loss_diff > 1e-4:
                self.stopped_epoch = epoch
                self.model.stop_training = True

    def on_train_end(self, logs=None):
        if self.stopped_epoch is not None:
            print("EarlyStoppingCallback(): Epoch %05d: early stopping" % (self.stopped_epoch + 1))

## Activate

In [126]:
# Load the data splits; `load_data` also returns a config object that replaces the current one.
config, data = load_data(config)


make_iterators(): M_train: 19988


In [127]:
# Build (or, if config.load_complete_model, load) the model and print its summary.
model = build_model(config)

Model: "TwoStageFcModel"
____________________________________________________________________________________________________________________________________________
 Layer (type)                                                  Output Shape                                            Param #              
 input (InputLayer)                                            [(None, 54, 128)]                                       0                    
                                                                                                                                            
 two_stage_fc_layer_3 (TwoStageFcLayer)                        (None, 128, 1024)                                       139136               
                                                                                                                                            
 two_stage_fc_layer_4 (TwoStageFcLayer)                        (None, 16, 256)                                         264464    

In [128]:
# Wrap model, data and config in the trainer selected by config.trainer_name.
trainer = build_trainer(model, data, config)



In [134]:
def compound_gaussian_split_aux_trainer(mat, label, aux):
    """Map a ``(mat, label, aux)`` dataset element to ``(mat, label)``.

    The auxiliary element is discarded, presumably so that Keras ``fit`` receives
    plain ``(inputs, targets)`` pairs.
    """
    del aux  # not needed downstream
    return (mat, label)


In [135]:
# NOTE(review): ModelCheckpoint, CSVLogger and Callback are already imported in the
# Imports cell at the top of the notebook; this re-import looks redundant.
from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger, Callback
history = trainer.train()



Epoch 1/300

Epoch 1: saving model to ../results/temp/temp_2023-07-11_14-46-44_3775404699870012136/checkpoints/model_checkpoint_epoch
INFO:tensorflow:Assets written to: ../results/temp/temp_2023-07-11_14-46-44_3775404699870012136/checkpoints/model_checkpoint_epoch/assets


INFO:tensorflow:Assets written to: ../results/temp/temp_2023-07-11_14-46-44_3775404699870012136/checkpoints/model_checkpoint_epoch/assets


78/78 - 28s - loss: 1.4326 - binary_crossentropy: 0.6961 - positive_binary_cross_entropy: 0.6632 - negative_binary_cross_entropy: 0.6973 - auc: 0.5702 - accuracy: 0.0232 - fp: 638883.0000 - tp: 26703.0000 - val_loss: 1.5388 - val_binary_crossentropy: 0.6813 - val_positive_binary_cross_entropy: 0.5969 - val_negative_binary_cross_entropy: 0.6854 - val_auc: 0.6881 - val_accuracy: 0.0553 - val_fp: 42029.0000 - val_tp: 3125.0000 - 28s/epoch - 354ms/step
Epoch 2/300


KeyboardInterrupt: 