Skip to content

Commit

Permalink
improve ray performance using generators
Browse files Browse the repository at this point in the history
  • Loading branch information
aromanielloNTIA committed Apr 26, 2023
1 parent 0d68dd3 commit 4b7ae39
Showing 1 changed file with 69 additions and 57 deletions.
126 changes: 69 additions & 57 deletions scos_actions/actions/acquire_sea_data_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
Currently in development.
"""
import gc
import logging
import lzma
from time import perf_counter
Expand Down Expand Up @@ -59,10 +60,7 @@
calculate_pseudo_power,
create_power_detector,
)
from scos_actions.signal_processing.unit_conversion import (
convert_linear_to_dB,
convert_watts_to_dBm,
)
from scos_actions.signal_processing.unit_conversion import convert_linear_to_dB
from scos_actions.signals import measurement_action_completed, trigger_api_restart
from scos_actions.status import start_time
from scos_actions.utils import convert_datetime_to_millisecond_iso_format, get_days_up
Expand Down Expand Up @@ -121,7 +119,7 @@
}


@ray.remote
@ray.remote(num_returns=2)
def get_fft_results(iqdata: np.ndarray, params: dict) -> Tuple[np.ndarray, np.ndarray]:
"""
Compute data product mean/max FFT results from IQ samples.
Expand Down Expand Up @@ -151,8 +149,12 @@ def get_fft_results(iqdata: np.ndarray, params: dict) -> Tuple[np.ndarray, np.nd
fft_result = apply_power_detector(fft_result, FFT_DETECTOR) # (max, mean)
fft_result /= IMPEDANCE_OHMS # Finish conversion to Watts
fft_result = np.fft.fftshift(fft_result, axes=(1,)) # Shift frequencies
fft_result = convert_watts_to_dBm(fft_result)
fft_result -= 3 # Baseband/RF power conversion
# Scaling:
# - Convert Watts to dBm (10*log10(fft_result) + 30)
# - Baseband/RF power conversion (fft_result - 3)
# Note: convert_watts_to_dBm() is not used to avoid a NumExpr usage
# for this operation on a relatively small array
fft_result = 10.0 * np.log10(fft_result) + 27
fft_result -= 10.0 * np.log10(params[SAMPLE_RATE] * FFT_SIZE) # PSD scaling
fft_result += 2.0 * convert_linear_to_dB(FFT_WINDOW_ECF) # Window energy correction

Expand All @@ -164,10 +166,12 @@ def get_fft_results(iqdata: np.ndarray, params: dict) -> Tuple[np.ndarray, np.nd
bin_end = FFT_SIZE - bin_start # bin_end = 750 with FFT_SIZE 875
fft_result = fft_result[:, bin_start:bin_end] # See comments above

return fft_result[0], fft_result[1]
yield fft_result[0] # Max detector result
yield fft_result[1] # Mean detector result
del fft_result


@ray.remote
@ray.remote(num_returns=1)
def get_apd_results(iqdata: np.ndarray, params: dict) -> np.ndarray:
"""
Generate downsampled APD result from IQ samples.
Expand All @@ -181,17 +185,18 @@ def get_apd_results(iqdata: np.ndarray, params: dict) -> np.ndarray:
# Scale input to get_apd to account for:
# dBm -> dBW (-30)
# baseband -> RF power reference (+3)
p, a = get_apd(
p, _ = get_apd(
iqdata,
params[APD_BIN_SIZE_DB],
params[APD_MIN_BIN_DBM] - 27.0,
params[APD_MAX_BIN_DBM] - 27.0,
IMPEDANCE_OHMS,
)
return p
yield p
del p, _


@ray.remote
@ray.remote(num_returns=4)
def get_td_power_results(
iqdata: np.ndarray, params: dict
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
Expand All @@ -213,30 +218,37 @@ def get_td_power_results(
while the length of the other arrays depends on the configured
detector period.
"""
# Reshape IQ data into blocks
# Reshape IQ data into blocks and calculate power
block_size = int(params[TD_BIN_SIZE_MS] * params[SAMPLE_RATE] * 1e-3)
n_blocks = len(iqdata) // block_size
iqdata = iqdata.reshape((n_blocks, block_size))
iq_pwr = calculate_power_watts(iqdata, IMPEDANCE_OHMS)
iq_pwr = calculate_power_watts(
iqdata.reshape((n_blocks, block_size)), IMPEDANCE_OHMS
)

# Apply max/mean detectors
td_result = apply_power_detector(iq_pwr, TD_DETECTOR, axis=1)

# Get single value median/max statistics
td_channel_result = np.array([np.max(td_result[0]), np.median(td_result[1])])
td_channel_result = np.array([td_result[0].max(), np.median(td_result[1])])

# Convert to dBm and account for RF/baseband power difference
# Note: convert_watts_to_dBm is not used to avoid NumExpr usage
# for the relatively small arrays
td_result, td_channel_result = (
convert_watts_to_dBm(x) - 3.0 for x in [td_result, td_channel_result]
10.0 * np.log10(x) + 27.0 for x in [td_result, td_channel_result]
)

channel_max, channel_median = (np.array(a) for a in td_channel_result)
yield td_result[0] # Max detector result
yield td_result[1] # Mean detector result

# packed order of td_result is (max, mean)
return td_result[0], td_result[1], channel_max, channel_median
# Get channel summary statistics as 0-dim NumPy arrays
# Order is max-of-max, median-of-mean
for a in td_channel_result:
yield np.array(a)
del td_result


@ray.remote
@ray.remote(num_returns=6)
def get_periodic_frame_power(
iqdata: np.ndarray,
params: dict,
Expand Down Expand Up @@ -278,8 +290,7 @@ def get_periodic_frame_power(
chunked_shape = (iqdata.shape[0] // Nframes, Npts, Nframes // Npts) + tuple(
[iqdata.shape[1]] if iqdata.ndim == 2 else []
)
iq_bins = iqdata.reshape(chunked_shape)
power_bins = calculate_pseudo_power(iq_bins)
power_bins = calculate_pseudo_power(iqdata.reshape(chunked_shape))

# compute statistics first by cycle
mean_power = power_bins.mean(axis=0)
Expand All @@ -296,31 +307,35 @@ def get_periodic_frame_power(
# Finish conversion to power
pfp /= IMPEDANCE_OHMS

# Convert to dBm
pfp = convert_watts_to_dBm(pfp)
pfp -= 3 # RF/baseband
return tuple(pfp)
# Convert to dBm and subtract 3 dB for baseband/RF power conversion
# Note: convert_watts_to_dBm is not used here to avoid NumExpr
# usage for the relatively small array
pfp = 10.0 * np.log10(pfp) + 27.0

# Yield detector results one-at-a-time
yield from pfp
del power_bins, mean_power, max_power, pfp

@ray.remote

@ray.remote(num_returns=4)
def generate_data_product(
iqdata: ray.ObjectRef, params: dict, iir_sos: np.ndarray
iqdata: np.ndarray, params: dict, iir_sos: np.ndarray
) -> list:
"""Process IQ data and generate the SEA data product."""
print(f"GEN_DP @ {params[FREQUENCY]/1e6:.1f}: IQ data is {type(iqdata)}")
iqdata = sosfilt(iir_sos, iqdata)
print(f"FILTERED IQ @ {params[FREQUENCY]/1e6:.1f}: data is {type(iqdata)}")
iqdata_ref = ray.put(iqdata)
del iqdata
remote_procs = [
get_fft_results.remote(iqdata_ref, params),
get_td_power_results.remote(iqdata_ref, params),
get_periodic_frame_power.remote(iqdata_ref, params),
get_apd_results.remote(iqdata_ref, params),
]

# Return identifiers to avoid waiting for processing to complete
return remote_procs
procs = []
procs.append(get_fft_results.remote(iqdata_ref, params))
procs.append(get_td_power_results.remote(iqdata_ref, params))
procs.append(get_periodic_frame_power.remote(iqdata_ref, params))
procs.append(get_apd_results.remote(iqdata_ref, params))

for p in procs:
yield ray.get(p)

del iqdata, procs, p
gc.collect()


class NasctnSeaDataProduct(Action):
Expand Down Expand Up @@ -410,14 +425,14 @@ def __call__(self, schedule_entry, task_id):
measurement_result = self.capture_iq(parameters)
# Start data product processing but do not block next IQ capture
tic = perf_counter()
iqdata_ref = ray.put(measurement_result["data"])
toc = perf_counter()
logger.debug(
f"Called ray.put for channel IQ in {toc-tic:.2f} s: {measurement_result['data'][:5]}"
)
dp_procs.append(
generate_data_product.remote(iqdata_ref, parameters, self.iir_sos)
generate_data_product.remote(
measurement_result["data"], parameters, self.iir_sos
)
)
toc = perf_counter()
logger.debug(f"IQ data delivered for processing in {toc-tic:.2f} s")
del measurement_result["data"]
# Generate capture metadata before sigan reconfigured
cap_meta_tuple = self.create_channel_metadata(measurement_result)
cap_meta.append(cap_meta_tuple[0])
Expand All @@ -444,26 +459,23 @@ def __call__(self, schedule_entry, task_id):
# Now wait for channel data to be processed
channel_data = []
tic = perf_counter()
for j, d in enumerate(ray.get(channel_data_process)):
if j == 3:
# APD requires different handling
for j, d in enumerate(channel_data_process):
if j == 1: # Power-vs-Time results
data = ray.get(d)
channel_data.extend(data[:2])
max_max_ch_pwrs.append(DATA_TYPE(data[2]))
med_mean_ch_pwrs.append(DATA_TYPE(data[3]))
if j == 3: # APD results
channel_data.append(ray.get(d))
else:
channel_data.extend(ray.get(d))
toc = perf_counter()
logger.debug(f"Waited {toc-tic} s for channel {i} data")

# Pull out single value channel powers (max of max, median of mean)
max_max_ch_pwrs.append(DATA_TYPE(channel_data[4]))
med_mean_ch_pwrs.append(DATA_TYPE(channel_data[5]))
del channel_data[4:6]

all_data.extend(NasctnSeaDataProduct.transform_data(channel_data))
last_data_len = len(all_data)
result_toc = perf_counter()
logger.debug(f"Got all processed data in {result_toc-result_tic:.2f} s")

del dp_procs
logger.debug(f"Got all processed data in {result_toc-result_tic:.2f} s")

# Build metadata and convert data to compressed bytes
all_data = self.compress_bytes_data(np.array(all_data).tobytes())
Expand Down

0 comments on commit 4b7ae39

Please sign in to comment.