From 94dbc9c4c06b1a38b8de81cf31a6b76d85881ba4 Mon Sep 17 00:00:00 2001 From: Garrett Vanhoy Date: Tue, 8 Nov 2022 16:15:08 -0500 Subject: [PATCH] Adding necessary transforms and visualizations. --- .gitignore | 3 + torchsig/transforms/__init__.py | 2 + .../deep_learning_techniques/__init__.py | 2 +- .../deep_learning_techniques/dlt.py | 455 +++++++- .../deep_learning_techniques/functional.py | 102 ++ .../transforms/expert_feature/__init__.py | 2 +- torchsig/transforms/expert_feature/eft.py | 8 +- .../transforms/expert_feature/functional.py | 201 ++++ torchsig/transforms/functional.py | 30 +- .../transforms/signal_processing/__init__.py | 2 +- .../signal_processing/functional.py | 92 ++ torchsig/transforms/signal_processing/sp.py | 2 +- .../spectrogram_transforms/__init__.py | 2 + .../spectrogram_transforms/functional.py | 168 +++ .../transforms/spectrogram_transforms/spec.py | 744 +++++++++++++ .../transforms/system_impairment/__init__.py | 2 +- .../system_impairment/functional.py | 635 ++++++++++++ torchsig/transforms/system_impairment/si.py | 196 +++- .../target_transforms/target_transforms.py | 974 +++++++++++++++++- torchsig/transforms/transforms.py | 30 + .../transforms/wireless_channel/__init__.py | 2 +- .../transforms/wireless_channel/functional.py | 163 +++ torchsig/transforms/wireless_channel/wce.py | 34 +- torchsig/utils/visualize.py | 431 +++++++- 24 files changed, 4184 insertions(+), 98 deletions(-) create mode 100644 .gitignore create mode 100644 torchsig/transforms/deep_learning_techniques/functional.py create mode 100644 torchsig/transforms/expert_feature/functional.py create mode 100644 torchsig/transforms/signal_processing/functional.py create mode 100644 torchsig/transforms/spectrogram_transforms/__init__.py create mode 100644 torchsig/transforms/spectrogram_transforms/functional.py create mode 100644 torchsig/transforms/spectrogram_transforms/spec.py create mode 100644 torchsig/transforms/system_impairment/functional.py create mode 100644 torchsig/transforms/wireless_channel/functional.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..dfc9d4a --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +__pycache__/ +*.pyc +*.mdb diff --git a/torchsig/transforms/__init__.py b/torchsig/transforms/__init__.py index 8737964..8c8780a 100644 --- a/torchsig/transforms/__init__.py +++ b/torchsig/transforms/__init__.py @@ -2,6 +2,7 @@ from . import system_impairment from . import wireless_channel from . import signal_processing +from . import spectrogram_transforms from . import deep_learning_techniques from . import target_transforms from .functional import * @@ -10,5 +11,6 @@ from torchsig.transforms.wireless_channel import * from torchsig.transforms.expert_feature import * from torchsig.transforms.signal_processing import * +from torchsig.transforms.spectrogram_transforms import * from torchsig.transforms.deep_learning_techniques import * from torchsig.transforms.target_transforms import * diff --git a/torchsig/transforms/deep_learning_techniques/__init__.py b/torchsig/transforms/deep_learning_techniques/__init__.py index 8aeecd5..58b4958 100644 --- a/torchsig/transforms/deep_learning_techniques/__init__.py +++ b/torchsig/transforms/deep_learning_techniques/__init__.py @@ -1,2 +1,2 @@ from .dlt import * -from .dlt_functional import * +from .functional import * diff --git a/torchsig/transforms/deep_learning_techniques/dlt.py b/torchsig/transforms/deep_learning_techniques/dlt.py index 3be8742..4f996a0 100644 --- a/torchsig/transforms/deep_learning_techniques/dlt.py +++ b/torchsig/transforms/deep_learning_techniques/dlt.py @@ -1,13 +1,15 @@ import numpy as np from copy import deepcopy -from typing import Tuple, List, Any, Union, Optional +from typing import Tuple, List, Any, Union, Optional, Callable from torchsig.utils import SignalDescription, SignalData, SignalDataset from torchsig.transforms.transforms import SignalTransform +from torchsig.transforms.signal_processing import Normalize from torchsig.transforms.wireless_channel import TargetSNR from torchsig.transforms.functional import to_distribution, uniform_continuous_distribution, uniform_discrete_distribution -from torchsig.transforms.functional import NumericParameter, FloatParameter -from torchsig.transforms.deep_learning_techniques import dlt_functional +from torchsig.transforms.functional import NumericParameter, FloatParameter, IntParameter +from torchsig.transforms.deep_learning_techniques import functional +from torchsig.transforms.expert_feature import functional as eft_f class DatasetBasebandMixUp(SignalTransform): @@ -23,7 +25,7 @@ class DatasetBasebandMixUp(SignalTransform): than zero. This transform is loosely based on - `"mixup: Beyond Empirical Risk Minimization" `_. + `"mixup: Beyond Emperical Risk Minimization" `_. Args: @@ -356,10 +358,10 @@ def __call__(self, data: Any) -> Any: new_data.signal_description = new_signal_description # Perform data augmentation - new_data.iq_data = dlt_functional.cut_out(data.iq_data, cut_start, cut_dur, cut_type) + new_data.iq_data = functional.cut_out(data.iq_data, cut_start, cut_dur, cut_type) else: - new_data = dlt_functional.cut_out(data, cut_start, cut_dur, cut_type) + new_data = functional.cut_out(data, cut_start, cut_dur, cut_type) return new_data @@ -408,9 +410,446 @@ def __call__(self, data: Any) -> Any: ) # Perform data augmentation - new_data.iq_data = dlt_functional.patch_shuffle(data.iq_data, patch_size, shuffle_ratio) + new_data.iq_data = functional.patch_shuffle(data.iq_data, patch_size, shuffle_ratio) else: - new_data = dlt_functional.patch_shuffle(data, patch_size, shuffle_ratio) + new_data = functional.patch_shuffle(data, patch_size, shuffle_ratio) return new_data + +class DatasetWidebandCutMix(SignalTransform): + """SignalTransform that inputs a dataset to randomly sample from and insert + into the main dataset's examples, using an additional `alpha` input to set + the relative quantity in time to occupy, where + + cutmix_num_iq_samples = total_num_iq_samples * alpha + + This transform is loosely based on [CutMix: Regularization Strategy to + Train Strong Classifiers with Localizable Features] + (https://arxiv.org/pdf/1710.09412.pdf). + + Args: + dataset :obj:`SignalDataset`: + An SignalDataset of complex-valued examples to be used as a source for + the synthetic insertion/mixup + + alpha (:py:class:`~Callable`, :obj:`int`, :obj:`float`, :obj:`list`, :obj:`tuple`): + alpha sets the difference in durations between the main dataset + example and the inserted example + * If Callable, produces a sample by calling alpha() + * If int or float, alpha is fixed at the value provided + * If list, alpha is any element in the list + * If tuple, alpha is in range of (tuple[0], tuple[1]) + + Example: + >>> import torchsig.transforms as ST + >>> from torchsig.datasets import WidebandSig53 + >>> # Add signals from the `ModulationsDataset` + >>> dataset = WidebandSig53('.') + >>> transform = ST.DatasetWidebandCutMix(dataset=dataset,alpha=(0.2,0.7)) + + """ + def __init__( + self, + dataset: SignalDataset = None, + alpha: NumericParameter = uniform_continuous_distribution(0.2, 0.7), + ): + super(DatasetWidebandCutMix, self).__init__() + self.alpha = to_distribution(alpha, self.random_generator) + self.dataset = dataset + self.dataset_num_samples = len(dataset) + + def __call__(self, data: Any) -> Any: + alpha = self.alpha() + if isinstance(data, SignalData): + # Randomly sample from provided dataset + idx = np.random.randint(self.dataset_num_samples) + insert_data, insert_signal_description = self.dataset[idx] + num_iq_samples = data.iq_data.shape[0] + if insert_data.shape[0] != num_iq_samples: + raise ValueError( + "Input dataset's `num_iq_samples` does not match main dataset.\n\t\ + Found {}, but expected {} samples" + .format(insert_data.shape[0],data.shape[0]) + ) + + # Mask both data examples based on alpha and a random start value + insert_num_iq_samples = int(alpha * num_iq_samples) + insert_start = np.random.randint(num_iq_samples - insert_num_iq_samples) + insert_stop = insert_start+insert_num_iq_samples + data.iq_data[insert_start:insert_stop] = 0 + insert_data[:insert_start] = 0 + insert_data[insert_stop:] = 0 + insert_start /= num_iq_samples + insert_dur = insert_num_iq_samples / num_iq_samples + + # Create new SignalData object for transformed data + new_data = SignalData( + data=None, + item_type=np.dtype(np.float64), + data_type=np.dtype(np.complex128), + signal_description=[], + ) + new_data.iq_data = data.iq_data + insert_data + + # Update SignalDescription + new_signal_description = [] + signal_description = [data.signal_description] if isinstance(data.signal_description, SignalDescription) else data.signal_description + for signal_desc in signal_description: + new_signal_desc = deepcopy(signal_desc) + + # Update labels + if new_signal_desc.start > insert_start and new_signal_desc.start < insert_start + insert_dur: + # Label starts within cut region + if new_signal_desc.stop > insert_start and new_signal_desc.stop < insert_start + insert_dur: + # Label also stops within cut region --> Remove label + continue + else: + # Push label start to end of cut region + new_signal_desc.start = insert_start + insert_dur + elif new_signal_desc.stop > insert_start and new_signal_desc.stop < insert_start + insert_dur: + # Label stops within cut region but does not start in region --> Push stop to begining of cut region + new_signal_desc.stop = insert_start + elif new_signal_desc.start < insert_start and new_signal_desc.stop > insert_start + insert_dur: + # Label traverse cut region --> Split into two labels + new_signal_desc_split = deepcopy(signal_desc) + # Update first label region's stop + new_signal_desc.stop = insert_start + # Update second label region's start & append to description collection + new_signal_desc_split.start = insert_start + insert_dur + new_signal_description.append(new_signal_desc_split) + + # Append SignalDescription to list + new_signal_description.append(new_signal_desc) + + # Repeat for inserted example's SignalDescription(s) + for insert_signal_desc in insert_signal_description: + # Update labels + if insert_signal_desc.stop < insert_start or insert_signal_desc.start > insert_start + insert_dur: + # Label is outside inserted region --> Remove label + continue + elif insert_signal_desc.start < insert_start and insert_signal_desc.stop < insert_start + insert_dur: + # Label starts before and ends within region, push start to region start + insert_signal_desc.start = insert_start + elif insert_signal_desc.start >= insert_start and insert_signal_desc.stop > insert_start + insert_dur: + # Label starts within region and stops after, push stop to region stop + insert_signal_desc.stop = insert_start + insert_dur + elif insert_signal_desc.start < insert_start and insert_signal_desc.stop > insert_start + insert_dur: + # Label starts before and stops after, push both start & stop to region boundaries + insert_signal_desc.start = insert_start + insert_signal_desc.stop = insert_start + insert_dur + + # Append SignalDescription to list + new_signal_description.append(insert_signal_desc) + + # Set output data's SignalDescription to above list + new_data.signal_description = new_signal_description + + return new_data + else: + raise ValueError( + "Expected input type `SignalData`. Received {}. \n\t\ + The `DatasetWidebandCutMix` transform depends on metadata from a `SignalData` object." + .format(type(data)) + ) + + +class DatasetWidebandMixUp(SignalTransform): + """SignalTransform that inputs a dataset to randomly sample from and insert + into the main dataset's examples, using the `alpha` input to set the + difference in magnitudes between the two examples with the following + relationship: + + output_sample = main_sample * (1 - alpha) + mixup_sample * alpha + + This transform is loosely based on [mixup: Beyond Emperical Risk + Minimization](https://arxiv.org/pdf/1710.09412.pdf). + + Args: + dataset :obj:`SignalDataset`: + An SignalDataset of complex-valued examples to be used as a source for + the synthetic insertion/mixup + + alpha (:py:class:`~Callable`, :obj:`int`, :obj:`float`, :obj:`list`, :obj:`tuple`): + alpha sets the difference in power level between the main dataset + example and the inserted example + * If Callable, produces a sample by calling alpha() + * If int or float, alpha is fixed at the value provided + * If list, alpha is any element in the list + * If tuple, alpha is in range of (tuple[0], tuple[1]) + + Example: + >>> import torchsig.transforms as ST + >>> from torchsig.datasets import WidebandSig53 + >>> # Add signals from the `WidebandSig53` Dataset + >>> dataset = WidebandSig53('.') + >>> transform = ST.DatasetWidebandMixUp(dataset=dataset,alpha=(0.4,0.6)) + + """ + def __init__( + self, + dataset: SignalDataset = None, + alpha: NumericParameter = uniform_continuous_distribution(0.4, 0.6), + ): + super(DatasetWidebandMixUp, self).__init__() + self.alpha = to_distribution(alpha, self.random_generator) + self.dataset = dataset + self.dataset_num_samples = len(dataset) + + def __call__(self, data: Any) -> Any: + alpha = self.alpha() + if isinstance(data, SignalData): + # Randomly sample from provided dataset + idx = np.random.randint(self.dataset_num_samples) + insert_data, insert_signal_description = self.dataset[idx] + if insert_data.shape[0] != data.iq_data.shape[0]: + raise ValueError( + "Input dataset's `num_iq_samples` does not match main dataset.\n\t\ + Found {}, but expected {} samples" + .format(insert_data.shape[0],data.shape[0]) + ) + + # Create new SignalData object for transformed data + new_data = SignalData( + data=None, + item_type=np.dtype(np.float64), + data_type=np.dtype(np.complex128), + signal_description=[], + ) + new_data.iq_data = data.iq_data * (1 - alpha) + insert_data * alpha + + # Update SignalDescription + new_signal_description = [] + new_signal_description.extend(data.signal_description) + new_signal_description.extend(insert_signal_description) + new_data.signal_description = new_signal_description + + return new_data + else: + raise ValueError( + "Expected input type `SignalData`. Received {}. \n\t\ + The `DatasetWidebandMixUp` transform depends on metadata from a `SignalData` object." + .format(type(data)) + ) + + +class SpectrogramRandomResizeCrop(SignalTransform): + """The SpectrogramRandomResizeCrop transforms the input IQ data into a + spectrogram with a randomized FFT size and overlap. This randomization in + the spectrogram computation results in spectrograms of various sizes. The + width and height arguments specify the target output size of the transform. + To get to the desired size, the randomly generated spectrogram may be + randomly cropped or padded in either the time or frequency dimensions. This + transform is meant to emulate the Random Resize Crop transform often used + in computer vision tasks. + + Args: + nfft (:py:class:`~Callable`, :obj:`int`, :obj:`list`, :obj:`tuple`): + The number of FFT bins for the random spectrogram. + * If Callable, nfft is set by calling nfft() + * If int, nfft is fixed by value provided + * If list, nfft is any element in the list + * If tuple, nfft is in range of (tuple[0], tuple[1]) + overlap_ratio (:py:class:`~Callable`, :obj:`int`, :obj:`list`, :obj:`tuple`): + The ratio of the (nfft-1) value to use as the overlap parameter for + the spectrogram operation. Setting as ratio ensures the overlap is + a lower value than the bin size. + * If Callable, nfft is set by calling overlap_ratio() + * If float, overlap_ratio is fixed by value provided + * If list, overlap_ratio is any element in the list + * If tuple, overlap_ratio is in range of (tuple[0], tuple[1]) + window_fcn (:obj:`str`): + Window to be used in spectrogram operation. + Default value is 'np.blackman'. + mode (:obj:`str`): + Mode of the spectrogram to be computed. + Default value is 'complex'. + width (:obj:`int`): + Target output width (time) of the spectrogram + height (:obj:`int`): + Target output height (frequency) of the spectrogram + + Example: + >>> import torchsig.transforms as ST + >>> # Randomly sample NFFT size in range [128,1024] and randomly crop/pad output spectrogram to (512,512) + >>> transform = ST.SpectrogramRandomResizeCrop(nfft=(128,1024), overlap_ratio=(0.0,0.2), width=512, height=512) + + """ + def __init__( + self, + nfft: IntParameter = (256, 1024), + overlap_ratio: FloatParameter = (0.0,0.2), + window_fcn: Callable[[int], np.ndarray] = np.blackman, + mode: str = 'complex', + width: int = 512, + height: int = 512, + ): + super(SpectrogramRandomResizeCrop, self).__init__() + self.nfft = to_distribution(nfft, self.random_generator) + self.overlap_ratio = to_distribution(overlap_ratio, self.random_generator) + self.window_fcn = window_fcn + self.mode = mode + self.width = width + self.height = height + + def __call__(self, data: Any) -> Any: + nfft = int(self.nfft()) + nperseg = nfft + overlap_ratio = self.overlap_ratio() + noverlap = int(overlap_ratio * (nfft-1)) + + iq_data = data.iq_data if isinstance(data, SignalData) else data + + # First, perform the random spectrogram operation + spec_data = eft_f.spectrogram(iq_data, nperseg, noverlap, nfft, self.window_fcn, self.mode) + if self.mode == "complex": + new_tensor = np.zeros((2, spec_data.shape[0], spec_data.shape[1]), dtype=np.float32) + new_tensor[0, :, :] = np.real(spec_data).astype(np.float32) + new_tensor[1, :, :] = np.imag(spec_data).astype(np.float32) + spec_data = new_tensor + + # Next, perform the random cropping/padding + channels, curr_height, curr_width = spec_data.shape + pad_height, crop_height = False, False + pad_width, crop_width = False, False + pad_height_samps, pad_width_samps = 0, 0 + if curr_height < self.height: + pad_height = True + pad_height_samps = self.height - curr_height + elif curr_height > self.height: + crop_height = True + if curr_width < self.width: + pad_width = True + pad_width_samps = self.width - curr_width + elif curr_width > self.width: + crop_width = True + + if pad_height or pad_width: + def pad_func(vector, pad_width, iaxis, kwargs): + vector[:pad_width[0]] = np.random.rand(len(vector[:pad_width[0]]))*kwargs['pad_value'] + vector[-pad_width[1]:] = np.random.rand(len(vector[-pad_width[1]:]))*kwargs['pad_value'] + + pad_height_start = np.random.randint(0,pad_height_samps//2+1) + pad_height_end = pad_height_samps - pad_height_start + 1 + pad_width_start = np.random.randint(0,pad_width_samps//2+1) + pad_width_end = pad_width_samps - pad_width_start + 1 + + if self.mode == "complex": + new_data_real = np.pad( + spec_data[0], + ( + (pad_height_start,pad_height_end), + (pad_width_start,pad_width_end), + ), + pad_func, + pad_value = np.percentile(np.abs(spec_data[0]),50), + ) + new_data_imag = np.pad( + spec_data[1], + ( + (pad_height_start,pad_height_end), + (pad_width_start,pad_width_end), + ), + pad_func, + pad_value = np.percentile(np.abs(spec_data[1]),50), + ) + spec_data = np.concatenate( + [ + np.expand_dims(new_data_real,axis=0), + np.expand_dims(new_data_imag,axis=0) + ], + axis=0, + ) + else: + spec_data = np.pad( + spec_data, + ( + (pad_height_start,pad_height_end), + (pad_width_start,pad_width_end), + ), + pad_func, + min_value = np.percentile(np.abs(spec_data[0]),50), + ) + + crop_width_start = np.random.randint(0,max(1,curr_width-self.width)) + crop_height_start = np.random.randint(0,max(1,curr_height-self.height)) + spec_data = spec_data[ + :, + crop_height_start:crop_height_start+self.height, + crop_width_start:crop_width_start+self.width, + ] + + # Update SignalData object if necessary, otherwise return + if isinstance(data, SignalData): + # Create new SignalData object for transformed data + new_data = SignalData( + data=None, + item_type=np.dtype(np.float64), + data_type=np.dtype(np.complex128), + signal_description=[], + ) + new_data.iq_data = spec_data + + # Update SignalDescription + new_signal_description = [] + signal_description = [data.signal_description] if isinstance(data.signal_description, SignalDescription) else data.signal_description + for signal_desc in signal_description: + new_signal_desc = deepcopy(signal_desc) + + # Check bounds for partial signals + new_signal_desc.lower_frequency = -0.5 if new_signal_desc.lower_frequency < -0.5 else new_signal_desc.lower_frequency + new_signal_desc.upper_frequency = 0.5 if new_signal_desc.upper_frequency > 0.5 else new_signal_desc.upper_frequency + new_signal_desc.bandwidth = new_signal_desc.upper_frequency - new_signal_desc.lower_frequency + new_signal_desc.center_frequency = new_signal_desc.lower_frequency + new_signal_desc.bandwidth * 0.5 + + # Update labels based on padding/cropping + if pad_height: + new_signal_desc.lower_frequency = ((new_signal_desc.lower_frequency+0.5)*curr_height + pad_height_start) / self.height - 0.5 + new_signal_desc.upper_frequency = ((new_signal_desc.upper_frequency+0.5)*curr_height + pad_height_start) / self.height - 0.5 + new_signal_desc.center_frequency = ((new_signal_desc.center_frequency+0.5)*curr_height + pad_height_start) / self.height - 0.5 + new_signal_desc.bandwidth = new_signal_desc.upper_frequency - new_signal_desc.lower_frequency + + if crop_height: + if (new_signal_desc.lower_frequency+0.5)*curr_height >= crop_height_start+self.height or \ + (new_signal_desc.upper_frequency+0.5)*curr_height <= crop_height_start: + continue + if (new_signal_desc.lower_frequency+0.5)*curr_height <= crop_height_start: + new_signal_desc.lower_frequency = -0.5 + else: + new_signal_desc.lower_frequency = ((new_signal_desc.lower_frequency+0.5)*curr_height - crop_height_start) / self.height - 0.5 + if (new_signal_desc.upper_frequency+0.5)*curr_height >= crop_height_start+self.height: + new_signal_desc.upper_frequency = crop_height_start+self.height + else: + new_signal_desc.upper_frequency = ((new_signal_desc.upper_frequency+0.5)*curr_height - crop_height_start) / self.height - 0.5 + new_signal_desc.bandwidth = new_signal_desc.upper_frequency - new_signal_desc.lower_frequency + new_signal_desc.center_frequency = new_signal_desc.lower_frequency + new_signal_desc.bandwidth / 2 + + if pad_width: + new_signal_desc.start = (new_signal_desc.start * curr_width + pad_width_start) / self.width + new_signal_desc.stop = (new_signal_desc.stop * curr_width + pad_width_start) / self.width + new_signal_desc.duration = new_signal_desc.stop - new_signal_desc.start + + if crop_width: + if new_signal_desc.start*curr_width <= crop_width_start: + new_signal_desc.start = 0.0 + elif new_signal_desc.start*curr_width >= crop_width_start+self.width: + continue + else: + new_signal_desc.start = (new_signal_desc.start * curr_width - crop_width_start) / self.width + if new_signal_desc.stop*curr_width >= crop_width_start+self.width: + new_signal_desc.stop = 1.0 + elif new_signal_desc.stop*curr_width <= crop_width_start: + continue + else: + new_signal_desc.stop = (new_signal_desc.stop * curr_width - crop_width_start) / self.width + new_signal_desc.duration = new_signal_desc.stop - new_signal_desc.start + + # Append SignalDescription to list + new_signal_description.append(new_signal_desc) + + new_data.signal_description = new_signal_description + + else: + new_data = spec_data + + return new_data diff --git a/torchsig/transforms/deep_learning_techniques/functional.py b/torchsig/transforms/deep_learning_techniques/functional.py new file mode 100644 index 0000000..e9ea386 --- /dev/null +++ b/torchsig/transforms/deep_learning_techniques/functional.py @@ -0,0 +1,102 @@ +import numpy as np + + +def cut_out( + tensor: np.ndarray, + cut_start: float, + cut_dur: float, + cut_type: str, +) -> np.ndarray: + """Performs the CutOut using the input parameters + + Args: + tensor: (:class:`numpy.ndarray`): + (batch_size, vector_length, ...)-sized tensor. + + cut_start: (:obj:`float`): + Start of cut region in range [0.0,1.0) + + cut_dur: (:obj:`float`): + Duration of cut region in range (0.0,1.0] + + cut_type: (:obj:`str`): + String specifying type of data to fill in cut region with + + Returns: + transformed (:class:`numpy.ndarray`): + Tensor that has undergone cut out + + """ + num_iq_samples = tensor.shape[0] + cut_start = int(cut_start * num_iq_samples) + + # Create cut mask + cut_mask_length = int(num_iq_samples * cut_dur) + if cut_mask_length + cut_start > num_iq_samples: + cut_mask_length = num_iq_samples - cut_start + + if cut_type == "zeros": + cut_mask = np.zeros(cut_mask_length, dtype=np.complex64) + elif cut_type == "ones": + cut_mask = np.ones(cut_mask_length) + 1j*np.ones(cut_mask_length) + elif cut_type == "low_noise": + real_noise = np.random.randn(cut_mask_length) + imag_noise = np.random.randn(cut_mask_length) + noise_power_db = -100 + cut_mask = (10.0**(noise_power_db/20.0))*(real_noise + 1j*imag_noise)/np.sqrt(2) + elif cut_type == "avg_noise": + real_noise = np.random.randn(cut_mask_length) + imag_noise = np.random.randn(cut_mask_length) + avg_power = np.mean(np.abs(tensor)**2) + cut_mask = avg_power*(real_noise + 1j*imag_noise)/np.sqrt(2) + elif cut_type == "high_noise": + real_noise = np.random.randn(cut_mask_length) + imag_noise = np.random.randn(cut_mask_length) + noise_power_db = 40 + cut_mask = (10.0**(noise_power_db/20.0))*(real_noise + 1j*imag_noise)/np.sqrt(2) + else: + raise ValueError("cut_type must be: zeros, ones, low_noise, avg_noise, or high_noise. Found: {}".format(cut_type)) + + # Insert cut mask into tensor + tensor[cut_start:cut_start+cut_mask_length] = cut_mask + + return tensor + + +def patch_shuffle( + tensor: np.ndarray, + patch_size: int, + shuffle_ratio: float, +) -> np.ndarray: + """Apply shuffling of patches specified by `num_patches` + + Args: + tensor: (:class:`numpy.ndarray`): + (batch_size, vector_length, ...)-sized tensor. + + patch_size (:obj:`int`): + Size of each patch to shuffle + + shuffle_ratio (:obj:`float`): + Ratio of patches to shuffle + + Returns: + transformed (:class:`numpy.ndarray`): + Tensor that has undergone patch shuffling + + """ + num_patches = int(tensor.shape[0] / patch_size) + num_to_shuffle = int(num_patches * shuffle_ratio) + patches_to_shuffle = np.random.choice( + num_patches, + replace=False, + size=num_to_shuffle, + ) + + for patch_idx in patches_to_shuffle: + patch_start = int(patch_idx * patch_size) + patch = tensor[patch_start:patch_start+patch_size] + np.random.shuffle(patch) + tensor[patch_start:patch_start+patch_size] = patch + + return tensor diff --git a/torchsig/transforms/expert_feature/__init__.py b/torchsig/transforms/expert_feature/__init__.py index 01ac526..ec96146 100644 --- a/torchsig/transforms/expert_feature/__init__.py +++ b/torchsig/transforms/expert_feature/__init__.py @@ -1,2 +1,2 @@ from .eft import * -from .eft_functional import * +from .functional import * diff --git a/torchsig/transforms/expert_feature/eft.py b/torchsig/transforms/expert_feature/eft.py index aecbfdb..ea79d6a 100644 --- a/torchsig/transforms/expert_feature/eft.py +++ b/torchsig/transforms/expert_feature/eft.py @@ -2,7 +2,7 @@ from typing import Callable, Tuple, Any from torchsig.utils.types import SignalData -from torchsig.transforms.expert_feature import eft_functional as F +from torchsig.transforms.expert_feature import functional as F from torchsig.transforms.transforms import SignalTransform @@ -170,7 +170,7 @@ def __call__(self, data: Any) -> Any: class Spectrogram(SignalTransform): - """ Calculates power spectral density over time + """Calculates power spectral density over time Args: nperseg (:obj:`int`): @@ -224,14 +224,14 @@ def __init__( def __call__(self, data: Any) -> Any: if isinstance(data, SignalData): - data.iq_data = F.spectrogram(data.iq_data) + data.iq_data = F.spectrogram(data.iq_data, self.nperseg, self.noverlap, self.nfft, self.window_fcn, self.mode) if self.mode == "complex": new_tensor = np.zeros((2, data.iq_data.shape[0], data.iq_data.shape[1]), dtype=np.float32) new_tensor[0, :, :] = np.real(data.iq_data).astype(np.float32) new_tensor[1, :, :] = np.imag(data.iq_data).astype(np.float32) data.iq_data = new_tensor else: - data = F.spectrogram(data) + data = F.spectrogram(data, self.nperseg, self.noverlap, self.nfft, self.window_fcn, self.mode) if self.mode == "complex": new_tensor = np.zeros((2, data.shape[0], data.shape[1]), dtype=np.float32) new_tensor[0, :, :] = np.real(data).astype(np.float32) diff --git a/torchsig/transforms/expert_feature/functional.py b/torchsig/transforms/expert_feature/functional.py new file mode 100644 index 0000000..729d12d --- /dev/null +++ b/torchsig/transforms/expert_feature/functional.py @@ -0,0 +1,201 @@ +import pywt +import numpy as np +from scipy import signal +from typing import Callable + + +def interleave_complex(tensor: np.ndarray) -> np.ndarray: + """Converts complex vectors to real interleaved IQ vector + + Args: + tensor (:class:`numpy.ndarray`): + (batch_size, vector_length, ...)-sized tensor. + + Returns: + transformed (:class:`numpy.ndarray`): + Interleaved vectors. + """ + new_tensor = np.empty((tensor.shape[0]*2,)) + new_tensor[::2] = np.real(tensor) + new_tensor[1::2] = np.imag(tensor) + return new_tensor + + +def complex_to_2d(tensor: np.ndarray) -> np.ndarray: + """Converts complex IQ to two channels representing real and imaginary + + Args: + tensor (:class:`numpy.ndarray`): + (batch_size, vector_length, ...)-sized tensor. + + Returns: + transformed (:class:`numpy.ndarray`): + Expanded vectors + """ + + new_tensor = np.zeros((2, tensor.shape[0]), dtype=np.float64) + new_tensor[0] = np.real(tensor).astype(np.float64) + new_tensor[1] = np.imag(tensor).astype(np.float64) + return new_tensor + + +def real(tensor: np.ndarray) -> np.ndarray: + """Converts complex IQ to a real-only vector + + Args: + tensor (:class:`numpy.ndarray`): + (batch_size, vector_length, ...)-sized tensor. + + Returns: + transformed (:class:`numpy.ndarray`): + real(tensor) + """ + return np.real(tensor) + + +def imag(tensor: np.ndarray) -> np.ndarray: + """Converts complex IQ to a imaginary-only vector + + Args: + tensor (:class:`numpy.ndarray`): + (batch_size, vector_length, ...)-sized tensor. + + Returns: + transformed (:class:`numpy.ndarray`): + imag(tensor) + """ + return np.imag(tensor) + + +def complex_magnitude(tensor: np.ndarray) -> np.ndarray: + """Converts complex IQ to a complex magnitude vector + + Args: + tensor (:class:`numpy.ndarray`): + (batch_size, vector_length, ...)-sized tensor. + + Returns: + transformed (:class:`numpy.ndarray`): + abs(tensor) + """ + return np.abs(tensor) + + +def wrapped_phase(tensor: np.ndarray) -> np.ndarray: + """Converts complex IQ to a wrapped-phase vector + + Args: + tensor (:class:`numpy.ndarray`): + (batch_size, vector_length, ...)-sized tensor. + + Returns: + transformed (:class:`numpy.ndarray`): + angle(tensor) + """ + return np.angle(tensor) + + +def discrete_fourier_transform(tensor: np.ndarray) -> np.ndarray: + """Computes DFT of complex IQ vector + + Args: + tensor (:class:`numpy.ndarray`): + (batch_size, vector_length, ...)-sized tensor. + + Returns: + transformed (:class:`numpy.ndarray`): + fft(tensor). normalization is 1/sqrt(n) + """ + return np.fft.fft(tensor, norm="ortho") + + +def spectrogram( + tensor: np.ndarray, + nperseg: int, + noverlap: int, + nfft: int, + window_fcn: Callable[[int], np.ndarray], + mode: str, +) -> np.ndarray: + """Computes spectrogram of complex IQ vector + + Args: + tensor (:class:`numpy.ndarray`): + (batch_size, vector_length, ...)-sized tensor. + + nperseg (:obj:`int`): + Length of each segment. If window is str or tuple, is set to 256, + and if window is array_like, is set to the length of the window. + + noverlap (:obj:`int`): + Number of points to overlap between segments. + If None, noverlap = nperseg // 8. + + nfft (:obj:`int`): + Length of the FFT used, if a zero padded FFT is desired. + If None, the FFT length is nperseg. + + window_fcn (:obj:`Callable`): + Function generating the window for each FFT + + mode (:obj:`str`): + Mode of the spectrogram to be computed. + + Returns: + transformed (:class:`numpy.ndarray`): + Spectrogram of tensor along time dimension + """ + _, _, spectrograms = signal.spectrogram( + tensor, + nperseg=nperseg, + noverlap=noverlap, + nfft=nfft, + window=window_fcn(nperseg), + return_onesided=False, + mode=mode, + axis=0 + ) + return np.fft.fftshift(spectrograms, axes=0) + + +def continuous_wavelet_transform( + tensor: np.ndarray, + wavelet: str, + nscales: int, + sample_rate: float +) -> np.ndarray: + """Computes the continuous wavelet transform resulting in a Scalogram of the complex IQ vector + + Args: + tensor (:class:`numpy.ndarray`): + (batch_size, vector_length, ...)-sized tensor. + + wavelet (:obj:`str`): + Name of the mother wavelet. + If None, wavename = 'mexh'. + + nscales (:obj:`int`): + Number of scales to use in the Scalogram. + If None, nscales = 33. + + sample_rate (:obj:`float`): + Sample rate of the signal. + If None, fs = 1.0. + + Returns: + transformed (:class:`numpy.ndarray`): + Scalogram of tensor along time dimension + """ + scales = np.arange(1, nscales) + cwtmatr, _ = pywt.cwt( + tensor, + scales=scales, + wavelet=wavelet, + sampling_period=1.0/sample_rate + ) + + # if the dtype is complex then return the magnitude + if np.iscomplexobj(cwtmatr): + cwtmatr = abs(cwtmatr) + + return cwtmatr diff --git a/torchsig/transforms/functional.py b/torchsig/transforms/functional.py index c13e38f..2303bed 100644 --- a/torchsig/transforms/functional.py +++ b/torchsig/transforms/functional.py @@ -1,32 +1,13 @@ -from typing import Callable, List, Protocol, Sequence, Tuple, Union +from typing import Callable, Union, Tuple, List from functools import partial import numpy as np -import numpy.typing as npt FloatParameter = Union[Callable[[int], float], float, Tuple[float, float], List] IntParameter = Union[Callable[[int], int], int, Tuple[int, int], List] NumericParameter = Union[FloatParameter, IntParameter] -class RandomStatePartial(Protocol): - """Type definition for the partially applied random distribution function - returned by the functions in this module. - - These partials can be either called with zero arguments, in which case a - single value is returned, or by passing in a size parameter, in which case - a np.ndarray of the specified shape is returned. - - See: https://peps.python.org/pep-0544/ - See: https://mypy.readthedocs.io/en/stable/protocols.html#callback-protocols - """ - def __call__(self, size: Union[int, Sequence[int]] = ...) -> npt.ArrayLike: - ... - - -def uniform_discrete_distribution( - choices: List, - random_generator: np.random.RandomState = np.random.RandomState() -) -> RandomStatePartial: +def uniform_discrete_distribution(choices: List, random_generator: np.random.RandomState = np.random.RandomState()): return partial(random_generator.choice, choices) @@ -34,14 +15,11 @@ def uniform_continuous_distribution( lower: Union[int, float], upper: Union[int, float], random_generator: np.random.RandomState = np.random.RandomState() -) -> RandomStatePartial: +): return partial(random_generator.uniform, lower, upper) -def to_distribution( - param: NumericParameter, - random_generator: np.random.RandomState = np.random.RandomState() -) -> RandomStatePartial: +def to_distribution(param, random_generator: np.random.RandomState = np.random.RandomState()): if isinstance(param, Callable): return param diff --git a/torchsig/transforms/signal_processing/__init__.py b/torchsig/transforms/signal_processing/__init__.py index 22642dc..a6ce1d4 100644 --- a/torchsig/transforms/signal_processing/__init__.py +++ b/torchsig/transforms/signal_processing/__init__.py @@ -1,2 +1,2 @@ from .sp import * -from .sp_functional import * +from .functional import * diff --git a/torchsig/transforms/signal_processing/functional.py b/torchsig/transforms/signal_processing/functional.py new file mode 100644 index 0000000..e7c7057 --- /dev/null +++ b/torchsig/transforms/signal_processing/functional.py @@ -0,0 +1,92 @@ +import numpy as np +from scipy import signal + + +def normalize(tensor: np.ndarray, norm_order: int = 2, flatten: bool = False) -> np.ndarray: + """Scale a tensor so that a specfied norm computes to 1. For detailed information, see :func:`numpy.linalg.norm.` + * For norm=1, norm = max(sum(abs(x), axis=0)) (sum of the elements) + * for norm=2, norm = sqrt(sum(abs(x)^2), axis=0) (square-root of the sum of squares) + * for norm=np.inf, norm = max(sum(abs(x), axis=1)) (largest absolute value) + + Args: + tensor (:class:`numpy.ndarray`)): + (batch_size, vector_length, ...)-sized tensor to be normalized. + + norm_order (:class:`int`)): + norm order to be passed to np.linalg.norm + + flatten (:class:`bool`)): + boolean specifying if the input array's norm should be calculated on the flattened representation of the input tensor + + Returns: + Tensor: + Normalized complex array. + """ + if flatten: + flat_tensor = tensor.reshape(tensor.size) + norm = np.linalg.norm(flat_tensor, norm_order, keepdims=True) + else: + norm = np.linalg.norm(tensor, norm_order, keepdims=True) + return np.multiply(tensor, 1.0/norm) + + +def resample( + tensor: np.ndarray, + up_rate: int, + down_rate: int, + num_iq_samples: int, + keep_samples: bool, + anti_alias_lpf: bool = False, +) -> np.ndarray: + """Resample a tensor by rational value + + Args: + tensor (:class:`numpy.ndarray`): + tensor to be resampled. + + up_rate (:class:`int`): + rate at which to up-sample the tensor + + down_rate (:class:`int`): + rate at which to down-sample the tensor + + num_iq_samples (:class:`int`): + number of IQ samples to have after resampling + + keep_samples (:class:`bool`): + boolean to specify if the resampled data should be returned as is + + anti_alias_lpf (:class:`bool`)): + boolean to specify if an additional anti aliasing filter should be + applied + + Returns: + Tensor: + Resampled tensor + """ + if anti_alias_lpf: + new_rate = up_rate/down_rate + # Filter around center to future bandwidth + num_taps = int(2*np.ceil(50*2*np.pi/new_rate/.125/22)) # fred harris rule of thumb * 2 + taps = signal.firwin( + num_taps, + new_rate*0.98, + width=new_rate * .02, + window=signal.get_window("blackman", num_taps), + scale=True + ) + tensor = signal.fftconvolve(tensor, taps, mode="same") + + # Resample + resampled = signal.resample_poly(tensor, up_rate, down_rate) + + # Handle extra or not enough IQ samples + if keep_samples: + new_tensor = resampled + elif resampled.shape[0] > num_iq_samples: + new_tensor = resampled[-num_iq_samples:] + else: + new_tensor = np.zeros((num_iq_samples,), dtype=np.complex128) + new_tensor[:resampled.shape[0]] = resampled + + return new_tensor diff --git a/torchsig/transforms/signal_processing/sp.py b/torchsig/transforms/signal_processing/sp.py index 77dedc5..4f04770 100644 --- a/torchsig/transforms/signal_processing/sp.py +++ b/torchsig/transforms/signal_processing/sp.py @@ -4,7 +4,7 @@ from torchsig.utils.types import SignalData, SignalDescription from torchsig.transforms.transforms import SignalTransform -from torchsig.transforms.signal_processing import sp_functional as F +from torchsig.transforms.signal_processing import functional as F from torchsig.transforms.functional import NumericParameter, to_distribution diff --git a/torchsig/transforms/spectrogram_transforms/__init__.py b/torchsig/transforms/spectrogram_transforms/__init__.py new file mode 100644 index 0000000..220064b --- /dev/null +++ b/torchsig/transforms/spectrogram_transforms/__init__.py @@ -0,0 +1,2 @@ +from .spec import * +from .functional import * diff --git a/torchsig/transforms/spectrogram_transforms/functional.py b/torchsig/transforms/spectrogram_transforms/functional.py new file mode 100644 index 0000000..77c769b --- /dev/null +++ b/torchsig/transforms/spectrogram_transforms/functional.py @@ -0,0 +1,168 @@ +import numpy as np + + +def drop_spec_samples( + tensor: np.ndarray, + drop_starts: np.ndarray, + drop_sizes: np.ndarray, + fill: str, +) -> np.ndarray: + """Drop samples at specified input locations/durations with fill technique + + Args: + tensor: (:class:`numpy.ndarray`): + (batch_size, vector_length, ...)-sized tensor. + + drop_starts (:class:`numpy.ndarray`): + Indices of where drops start + + drop_sizes (:class:`numpy.ndarray`): + Durations of each drop instance + + fill (:obj:`str`): + String specifying how the dropped samples should be replaced + + Returns: + transformed (:class:`numpy.ndarray`): + Tensor that has undergone the dropped samples + + """ + flat_spec = tensor.reshape(tensor.shape[0], tensor.shape[1]*tensor.shape[2]) + for idx, drop_start in enumerate(drop_starts): + if fill == "ffill": + drop_region_real = np.ones(drop_sizes[idx])*flat_spec[0,drop_start-1] + drop_region_complex = np.ones(drop_sizes[idx])*flat_spec[1,drop_start-1] + flat_spec[0,drop_start:drop_start+drop_sizes[idx]] = drop_region_real + flat_spec[1,drop_start:drop_start+drop_sizes[idx]] = drop_region_complex + elif fill == "bfill": + drop_region_real = np.ones(drop_sizes[idx])*flat_spec[0,drop_start+drop_sizes[idx]] + drop_region_complex = np.ones(drop_sizes[idx])*flat_spec[1,drop_start+drop_sizes[idx]] + flat_spec[0,drop_start:drop_start+drop_sizes[idx]] = drop_region_real + flat_spec[1,drop_start:drop_start+drop_sizes[idx]] = drop_region_complex + elif fill == "mean": + drop_region_real = np.ones(drop_sizes[idx])*np.mean(flat_spec[0]) + drop_region_complex = np.ones(drop_sizes[idx])*np.mean(flat_spec[1]) + flat_spec[0,drop_start:drop_start+drop_sizes[idx]] = drop_region_real + flat_spec[1,drop_start:drop_start+drop_sizes[idx]] = drop_region_complex + elif fill == "zero": + drop_region = np.zeros(drop_sizes[idx]) + flat_spec[:,drop_start:drop_start+drop_sizes[idx]] = drop_region + elif fill == "min": + drop_region_real = np.ones(drop_sizes[idx])*np.min(np.abs(flat_spec[0])) + drop_region_complex = np.ones(drop_sizes[idx])*np.min(np.abs(flat_spec[1])) + flat_spec[0,drop_start:drop_start+drop_sizes[idx]] = drop_region_real + flat_spec[1,drop_start:drop_start+drop_sizes[idx]] = drop_region_complex + elif fill == "max": + drop_region_real = np.ones(drop_sizes[idx])*np.max(np.abs(flat_spec[0])) + drop_region_complex = np.ones(drop_sizes[idx])*np.max(np.abs(flat_spec[1])) + flat_spec[0,drop_start:drop_start+drop_sizes[idx]] = drop_region_real + flat_spec[1,drop_start:drop_start+drop_sizes[idx]] = drop_region_complex + elif fill == "low": + drop_region = np.ones(drop_sizes[idx])*1e-3 + flat_spec[:,drop_start:drop_start+drop_sizes[idx]] = drop_region + elif fill == "ones": + drop_region = np.ones(drop_sizes[idx]) + flat_spec[:,drop_start:drop_start+drop_sizes[idx]] = drop_region + else: + raise ValueError("fill expects ffill, bfill, mean, zero, min, max, low, ones. Found {}".format(fill)) + new_tensor = flat_spec.reshape(tensor.shape[0], tensor.shape[1], tensor.shape[2]) + return new_tensor + + +def spec_patch_shuffle( + tensor: np.ndarray, + patch_size: int, + shuffle_ratio: float, +) -> np.ndarray: + """Apply shuffling of patches specified by `num_patches` + + Args: + tensor: (:class:`numpy.ndarray`): + (batch_size, vector_length, ...)-sized tensor. + + patch_size (:obj:`int`): + Size of each patch to shuffle + + shuffle_ratio (:obj:`float`): + Ratio of patches to shuffle + + Returns: + transformed (:class:`numpy.ndarray`): + Tensor that has undergone patch shuffling + + """ + channels, height, width = tensor.shape + num_freq_patches = int(height/patch_size) + num_time_patches = int(width/patch_size) + num_patches = int(num_freq_patches * num_time_patches) + num_to_shuffle = int(num_patches * shuffle_ratio) + patches_to_shuffle = np.random.choice( + num_patches, + replace=False, + size=num_to_shuffle, + ) + + for patch_idx in patches_to_shuffle: + freq_idx = np.floor(patch_idx / num_freq_patches) + time_idx = patch_idx % num_time_patches + patch = tensor[ + :, + int(freq_idx*patch_size):int(freq_idx*patch_size+patch_size), + int(time_idx*patch_size):int(time_idx*patch_size+patch_size) + ] + patch = patch.reshape(int(2*patch_size*patch_size)) + np.random.shuffle(patch) + patch = patch.reshape(2,int(patch_size),int(patch_size)) + tensor[ + :, + int(freq_idx*patch_size):int(freq_idx*patch_size+patch_size), + int(time_idx*patch_size):int(time_idx*patch_size+patch_size) + ] = patch + return tensor + + +def spec_translate( + tensor: np.ndarray, + time_shift: int, + freq_shift: int, +) -> np.ndarray: + """Apply time/freq translation to input spectrogram + + Args: + tensor: (:class:`numpy.ndarray`): + (batch_size, vector_length, ...)-sized tensor. + + time_shift (:obj:`int`): + Time shift + + freq_shift (:obj:`int`): + Frequency shift + + Returns: + transformed (:class:`numpy.ndarray`): + Tensor that has undergone time/freq translation + + """ + # Pre-fill the data with background noise + new_tensor = np.random.rand(*tensor.shape)*np.percentile(np.abs(tensor),50) + + # Apply translation + channels, height, width = tensor.shape + if time_shift >= 0 and freq_shift >= 0: + valid_dur = width - time_shift + valid_bw = height - freq_shift + new_tensor[:,freq_shift:,time_shift:] = tensor[:,:valid_bw,:valid_dur] + elif time_shift < 0 and freq_shift >= 0: + valid_dur = width + time_shift + valid_bw = height - freq_shift + new_tensor[:,freq_shift:,:valid_dur] = tensor[:,:valid_bw,-time_shift:] + elif time_shift >= 0 and freq_shift < 0: + valid_dur = width - time_shift + valid_bw = height + freq_shift + new_tensor[:,:valid_bw,time_shift:] = tensor[:,-freq_shift:,:valid_dur] + elif time_shift < 0 and freq_shift < 0: + valid_dur = width + time_shift + valid_bw = height + freq_shift + new_tensor[:,:valid_bw,:valid_dur] = tensor[:,-freq_shift:,-time_shift:] + + return new_tensor diff --git a/torchsig/transforms/spectrogram_transforms/spec.py b/torchsig/transforms/spectrogram_transforms/spec.py new file mode 100644 index 0000000..54adf14 --- /dev/null +++ b/torchsig/transforms/spectrogram_transforms/spec.py @@ -0,0 +1,744 @@ +import numpy as np +from copy import deepcopy +from typing import Optional, Any, Union, List + +from torchsig.utils.dataset import SignalDataset +from torchsig.utils.types import SignalData, SignalDescription +from torchsig.transforms.transforms import SignalTransform +from torchsig.transforms.spectrogram_transforms import functional +from torchsig.transforms.functional import NumericParameter, FloatParameter, IntParameter +from torchsig.transforms.functional import to_distribution, uniform_continuous_distribution, uniform_discrete_distribution + + +class SpectrogramResize(SignalTransform): + """SpectrogramResize inputs data that has already been transformed into a + spectrogram, and then it crops and/or pads both the time and frequency + dimensions to reach a specified target width (time) and height (frequency). + + Args: + width (:obj:`int`): + Target output width (time) of the spectrogram + height (:obj:`int`): + Target output height (frequency) of the spectrogram + + Example: + >>> import torchsig.transforms as ST + >>> # Resize input spectrogram to (512,512) + >>> transform = ST.SpectrogramResize(width=512, height=512) + + """ + def __init__( + self, + width: int = 512, + height: int = 512, + ): + super(SpectrogramResize, self).__init__() + self.width = width + self.height = height + + def __call__(self, data: Any) -> Any: + spec_data = data.iq_data if isinstance(data, SignalData) else data + + # Next, perform the random cropping/padding + channels, curr_height, curr_width = spec_data.shape + pad_height, crop_height = False, False + pad_width, crop_width = False, False + pad_height_samps, pad_width_samps = 0, 0 + if curr_height < self.height: + pad_height = True + pad_height_samps = self.height - curr_height + elif curr_height > self.height: + crop_height = True + if curr_width < self.width: + pad_width = True + pad_width_samps = self.width - curr_width + elif curr_width > self.width: + crop_width = True + + if pad_height or pad_width: + def pad_func(vector, pad_width, iaxis, kwargs): + vector[:pad_width[0]] = np.random.rand(len(vector[:pad_width[0]]))*kwargs['pad_value'] + vector[-pad_width[1]:] = np.random.rand(len(vector[-pad_width[1]:]))*kwargs['pad_value'] + + if channels == 2: + new_data_real = np.pad( + spec_data[0], + ( + (pad_height_samps//2+1,pad_height_samps//2+1), + (pad_width_samps//2+1,pad_width_samps//2+1), + ), + pad_func, + pad_value = np.percentile(np.abs(spec_data[0]),50), + ) + new_data_imag = np.pad( + spec_data[1], + ( + (pad_height_samps//2+1,pad_height_samps//2+1), + (pad_width_samps//2+1,pad_width_samps//2+1), + ), + pad_func, + pad_value = np.percentile(np.abs(spec_data[1]),50), + ) + spec_data = np.concatenate( + [ + np.expand_dims(new_data_real,axis=0), + np.expand_dims(new_data_imag,axis=0) + ], + axis=0, + ) + else: + spec_data = np.pad( + spec_data, + ( + (pad_height_samps//2+1,pad_height_samps//2+1), + (pad_width_samps//2+1,pad_width_samps//2+1), + ), + pad_func, + min_value = np.percentile(np.abs(spec_data[0]),50), + ) + + spec_data = spec_data[:,:self.height,:self.width] + + # Update SignalData object if necessary, otherwise return + if isinstance(data, SignalData): + # Create new SignalData object for transformed data + new_data = SignalData( + data=None, + item_type=np.dtype(np.float64), + data_type=np.dtype(np.complex128), + signal_description=[], + ) + new_data.iq_data = spec_data + + # Update SignalDescription + new_signal_description = [] + signal_description = [data.signal_description] if isinstance(data.signal_description, SignalDescription) else data.signal_description + for signal_desc in signal_description: + new_signal_desc = deepcopy(signal_desc) + + # Check bounds for partial signals + new_signal_desc.lower_frequency = -0.5 if new_signal_desc.lower_frequency < -0.5 else new_signal_desc.lower_frequency + new_signal_desc.upper_frequency = 0.5 if new_signal_desc.upper_frequency > 0.5 else new_signal_desc.upper_frequency + new_signal_desc.bandwidth = new_signal_desc.upper_frequency - new_signal_desc.lower_frequency + new_signal_desc.center_frequency = new_signal_desc.lower_frequency + new_signal_desc.bandwidth * 0.5 + + # Update labels based on padding/cropping + if pad_height: + new_signal_desc.lower_frequency = ((new_signal_desc.lower_frequency+0.5)*curr_height + pad_height_samps//2+1) / self.height - 0.5 + new_signal_desc.upper_frequency = ((new_signal_desc.upper_frequency+0.5)*curr_height + pad_height_samps//2+1) / self.height - 0.5 + new_signal_desc.center_frequency = ((new_signal_desc.center_frequency+0.5)*curr_height + pad_height_samps//2+1) / self.height - 0.5 + new_signal_desc.bandwidth = new_signal_desc.upper_frequency - new_signal_desc.lower_frequency + + if crop_height: + if (new_signal_desc.lower_frequency+0.5)*curr_height >= crop_height_start+self.height or \ + (new_signal_desc.upper_frequency+0.5)*curr_height <= crop_height_start: + continue + if (new_signal_desc.lower_frequency+0.5)*curr_height <= crop_height_start: + new_signal_desc.lower_frequency = -0.5 + else: + new_signal_desc.lower_frequency = ((new_signal_desc.lower_frequency+0.5)*curr_height) / self.height - 0.5 + if (new_signal_desc.upper_frequency+0.5)*curr_height >= crop_height_start+self.height: + new_signal_desc.upper_frequency = crop_height_start+self.height + else: + new_signal_desc.upper_frequency = ((new_signal_desc.upper_frequency+0.5)*curr_height) / self.height - 0.5 + new_signal_desc.bandwidth = new_signal_desc.upper_frequency - new_signal_desc.lower_frequency + new_signal_desc.center_frequency = new_signal_desc.lower_frequency + new_signal_desc.bandwidth / 2 + + if pad_width: + new_signal_desc.start = (new_signal_desc.start * curr_width + pad_width_samps//2+1) / self.width + new_signal_desc.stop = (new_signal_desc.stop * curr_width + pad_width_samps//2+1) / self.width + new_signal_desc.duration = new_signal_desc.stop - new_signal_desc.start + + if crop_width: + if new_signal_desc.start*curr_width <= 0: + new_signal_desc.start = 0.0 + elif new_signal_desc.start*curr_width >= self.width: + continue + else: + new_signal_desc.start = (new_signal_desc.start * curr_width) / self.width + if new_signal_desc.stop*curr_width >= self.width: + new_signal_desc.stop = 1.0 + elif new_signal_desc.stop*curr_width <= 0: + continue + else: + new_signal_desc.stop = (new_signal_desc.stop * curr_width) / self.width + new_signal_desc.duration = new_signal_desc.stop - new_signal_desc.start + + # Append SignalDescription to list + new_signal_description.append(new_signal_desc) + + new_data.signal_description = new_signal_description + + else: + new_data = spec_data + + return new_data + + +class SpectrogramDropSamples(SignalTransform): + """Randomly drop samples from the input data of specified durations and + with specified fill techniques: + * `ffill` (front fill): replace drop samples with the last previous value + * `bfill` (back fill): replace drop samples with the next value + * `mean`: replace drop samples with the mean value of the full data + * `zero`: replace drop samples with zeros + * `low`: replace drop samples with low power samples + * `min`: replace drop samples with the minimum of the absolute power + * `max`: replace drop samples with the maximum of the absolute power + * `ones`: replace drop samples with ones + + Transform is based off of the + `TSAug Dropout Transform `_. + + Args: + drop_rate (:py:class:`~Callable`, :obj:`int`, :obj:`float`, :obj:`list`, :obj:`tuple`): + drop_rate sets the rate at which to drop samples + * If Callable, produces a sample by calling drop_rate() + * If int or float, drop_rate is fixed at the value provided + * If list, drop_rate is any element in the list + * If tuple, drop_rate is in range of (tuple[0], tuple[1]) + + size (:py:class:`~Callable`, :obj:`int`, :obj:`float`, :obj:`list`, :obj:`tuple`): + size sets the size of each instance of dropped samples + * If Callable, produces a sample by calling size() + * If int or float, size is fixed at the value provided + * If list, size is any element in the list + * If tuple, size is in range of (tuple[0], tuple[1]) + + fill (:py:class:`~Callable`, :obj:`list`, :obj:`str`): + fill sets the method of how the dropped samples should be filled + * If Callable, produces a sample by calling fill() + * If list, fill is any element in the list + * If str, fill is fixed at the method provided + + """ + def __init__( + self, + drop_rate: NumericParameter = uniform_continuous_distribution(0.001,0.005), + size: NumericParameter = uniform_discrete_distribution(np.arange(1,10)), + fill: Union[List, str] = uniform_discrete_distribution(["ffill", "bfill", "mean", "zero", "low", "min", "max", "ones"]), + ): + super(SpectrogramDropSamples, self).__init__() + self.drop_rate = to_distribution(drop_rate, self.random_generator) + self.size = to_distribution(size, self.random_generator) + self.fill = to_distribution(fill, self.random_generator) + + def __call__(self, data: Any) -> Any: + drop_rate = self.drop_rate() + fill = self.fill() + + if isinstance(data, SignalData): + # Create new SignalData object for transformed data + new_data = SignalData( + data=None, + item_type=np.dtype(np.float64), + data_type=np.dtype(np.float64), + signal_description=data.signal_description, + ) + + # Perform data augmentation + channels, height, width = data.iq_data.shape + spec_size = height * width + drop_instances = int(spec_size * drop_rate) + drop_sizes = self.size(drop_instances).astype(int) + drop_starts = np.random.uniform(1, spec_size-max(drop_sizes)-1, drop_instances).astype(int) + + new_data.iq_data = functional.drop_spec_samples(data.iq_data, drop_starts, drop_sizes, fill) + + else: + drop_instances = int(data.shape[0] * drop_rate) + drop_sizes = self.size(drop_instances).astype(int) + drop_starts = np.random.uniform(0, data.shape[0]-max(drop_sizes), drop_instances).astype(int) + + new_data = functional.drop_spec_samples(data, drop_starts, drop_sizes, fill) + return new_data + + +class SpectrogramPatchShuffle(SignalTransform): + """Randomly shuffle multiple local regions of samples. + + Transform is loosely based on + `PatchShuffle Regularization `_. + + Args: + patch_size (:py:class:`~Callable`, :obj:`int`, :obj:`float`, :obj:`list`, :obj:`tuple`): + patch_size sets the size of each patch to shuffle + * If Callable, produces a sample by calling patch_size() + * If int or float, patch_size is fixed at the value provided + * If list, patch_size is any element in the list + * If tuple, patch_size is in range of (tuple[0], tuple[1]) + + shuffle_ratio (:py:class:`~Callable`, :obj:`int`, :obj:`float`, :obj:`list`, :obj:`tuple`): + shuffle_ratio sets the ratio of the patches to shuffle + * If Callable, produces a sample by calling shuffle_ratio() + * If int or float, shuffle_ratio is fixed at the value provided + * If list, shuffle_ratio is any element in the list + * If tuple, shuffle_ratio is in range of (tuple[0], tuple[1]) + + """ + def __init__( + self, + patch_size: NumericParameter = uniform_continuous_distribution(2,16), + shuffle_ratio: FloatParameter = uniform_continuous_distribution(0.01,0.10), + ): + super(SpectrogramPatchShuffle, self).__init__() + self.patch_size = to_distribution(patch_size, self.random_generator) + self.shuffle_ratio = to_distribution(shuffle_ratio, self.random_generator) + + def __call__(self, data: Any) -> Any: + patch_size = int(self.patch_size()) + shuffle_ratio = self.shuffle_ratio() + + if isinstance(data, SignalData): + # Create new SignalData object for transformed data + new_data = SignalData( + data=None, + item_type=np.dtype(np.float64), + data_type=np.dtype(np.complex128), + signal_description=data.signal_description, + ) + + # Perform data augmentation + new_data.iq_data = functional.spec_patch_shuffle(data.iq_data, patch_size, shuffle_ratio) + else: + new_data = functional.spec_patch_shuffle(data, patch_size, shuffle_ratio) + return new_data + + +class SpectrogramTranslation(SignalTransform): + """Transform that inputs a spectrogram and applies a random time/freq + translation + + Args: + time_shift (:py:class:`~Callable`, :obj:`int`, :obj:`float`, :obj:`list`, :obj:`tuple`): + time_shift sets the translation along the time-axis + * If Callable, produces a sample by calling time_shift() + * If int, time_shift is fixed at the value provided + * If list, time_shift is any element in the list + * If tuple, time_shift is in range of (tuple[0], tuple[1]) + + freq_shift (:py:class:`~Callable`, :obj:`int`, :obj:`float`, :obj:`list`, :obj:`tuple`): + freq_shift sets the translation along the freq-axis + * If Callable, produces a sample by calling freq_shift() + * If int, freq_shift is fixed at the value provided + * If list, freq_shift is any element in the list + * If tuple, freq_shift is in range of (tuple[0], tuple[1]) + + """ + def __init__( + self, + time_shift: IntParameter = uniform_continuous_distribution(-128,128), + freq_shift: IntParameter = uniform_continuous_distribution(-128,128), + ): + super(SpectrogramTranslation, self).__init__() + self.time_shift = to_distribution(time_shift, self.random_generator) + self.freq_shift = to_distribution(freq_shift, self.random_generator) + + def __call__(self, data: Any) -> Any: + time_shift = int(self.time_shift()) + freq_shift = int(self.freq_shift()) + + if isinstance(data, SignalData): + # Create new SignalData object for transformed data + new_data = SignalData( + data=None, + item_type=np.dtype(np.float64), + data_type=np.dtype(np.complex128), + signal_description=data.signal_description, + ) + + new_data.iq_data = functional.spec_translate(data.iq_data, time_shift, freq_shift) + + # Update SignalDescription + new_signal_description = [] + signal_description = [data.signal_description] if isinstance(data.signal_description, SignalDescription) else data.signal_description + for signal_desc in signal_description: + new_signal_desc = deepcopy(signal_desc) + + # Update time fields + new_signal_desc.start = new_signal_desc.start + time_shift / new_data.iq_data.shape[1] + new_signal_desc.stop = new_signal_desc.stop + time_shift / new_data.iq_data.shape[1] + if new_signal_desc.start >= 1.0 or new_signal_desc.stop <= 0.0: + continue + new_signal_desc.start = 0.0 if new_signal_desc.start < 0.0 else new_signal_desc.start + new_signal_desc.stop = 1.0 if new_signal_desc.stop > 1.0 else new_signal_desc.stop + new_signal_desc.duration = new_signal_desc.stop - new_signal_desc.start + + # Trim any out-of-capture freq values + new_signal_desc.lower_frequency = -0.5 if new_signal_desc.lower_frequency < -0.5 else new_signal_desc.lower_frequency + new_signal_desc.upper_frequency = 0.5 if new_signal_desc.upper_frequency > 0.5 else new_signal_desc.upper_frequency + + # Update freq fields + new_signal_desc.lower_frequency = new_signal_desc.lower_frequency + freq_shift / new_data.iq_data.shape[2] + new_signal_desc.upper_frequency = new_signal_desc.upper_frequency + freq_shift / new_data.iq_data.shape[2] + if new_signal_desc.lower_frequency >= 0.5 or new_signal_desc.upper_frequency <= -0.5: + continue + new_signal_desc.lower_frequency = -0.5 if new_signal_desc.lower_frequency < -0.5 else new_signal_desc.lower_frequency + new_signal_desc.upper_frequency = 0.5 if new_signal_desc.upper_frequency > 0.5 else new_signal_desc.upper_frequency + new_signal_desc.bandwidth = new_signal_desc.upper_frequency - new_signal_desc.lower_frequency + new_signal_desc.center_frequency = new_signal_desc.lower_frequency + new_signal_desc.bandwidth * 0.5 + + # Append SignalDescription to list + new_signal_description.append(new_signal_desc) + + # Set output data's SignalDescription to above list + new_data.signal_description = new_signal_description + + else: + new_data = functional.spec_translate(data, time_shift, freq_shift) + return new_data + + +class SpectrogramMosaicCrop(SignalTransform): + """The SpectrogramMosaicCrop transform takes the original input tensor and + inserts it randomly into one cell of a 2x2 grid of 2x the size of the + orginal spectrogram input. The `dataset` argument is then read 3x to + retrieve spectrograms to fill the remaining cells of the 2x2 grid. Finally, + the 2x larger stitched view of 4x spectrograms is randomly cropped to the + original target size, containing pieces of each of the 4x stitched + spectrograms. + + Args: + dataset :obj:`SignalDataset`: + An SignalDataset of complex-valued examples to be used as a source for + the mosaic operation + + """ + def __init__(self, dataset: SignalDataset = None): + super(SpectrogramMosaicCrop, self).__init__() + self.dataset = dataset + + def __call__(self, data: Any) -> Any: + if isinstance(data, SignalData): + # Create new SignalData object for transformed data + new_data = SignalData( + data=None, + item_type=np.dtype(np.float64), + data_type=np.dtype(np.complex128), + signal_description=data.signal_description, + ) + + # Read shapes + channels, height, width = data.iq_data.shape + + # Randomly decide the new x0, y0 point of the stitched images + x0 = np.random.randint(0,width) + y0 = np.random.randint(0,height) + + # Initialize new SignalDescription object + new_signal_description = [] + + # First, create a 2x2 grid of (512+512,512+512) and randomly put the initial data into a grid cell + cell_idx = np.random.randint(0,4) + x_idx = 0 if cell_idx == 0 or cell_idx == 2 else 1 + y_idx = 0 if cell_idx == 0 or cell_idx == 1 else 1 + full_mosaic = np.empty( + (channels, height*2, width*2), + dtype=data.iq_data.dtype, + ) + full_mosaic[:,y_idx*height:(y_idx+1)*height,x_idx*width:(x_idx+1)*width] = data.iq_data + + # Update original data's SignalDescription objects given the cell index + signal_description = [data.signal_description] if isinstance(data.signal_description, SignalDescription) else data.signal_description + for signal_desc in signal_description: + new_signal_desc = deepcopy(signal_desc) + + # Update time fields + if x_idx == 0: + if new_signal_desc.stop * width < x0: + continue + new_signal_desc.start = 0 if new_signal_desc.start < (x0 / width) else new_signal_desc.start - (x0 / width) + new_signal_desc.stop = new_signal_desc.stop - (x0 / width) if new_signal_desc.stop < 1.0 else 1.0 - (x0 / width) + new_signal_desc.duration = new_signal_desc.stop - new_signal_desc.start + + else: + if new_signal_desc.start * width > x0: + continue + new_signal_desc.start = (width - x0) / width + new_signal_desc.start + new_signal_desc.stop = (width - x0) / width + new_signal_desc.stop + new_signal_desc.stop = 1.0 if new_signal_desc.stop > 1.0 else new_signal_desc.stop + new_signal_desc.duration = new_signal_desc.stop - new_signal_desc.start + + # Update frequency fields + new_signal_desc.lower_frequency = -0.5 if new_signal_desc.lower_frequency < -0.5 else new_signal_desc.lower_frequency + new_signal_desc.upper_frequency = 0.5 if new_signal_desc.upper_frequency > 0.5 else new_signal_desc.upper_frequency + if y_idx == 0: + if (new_signal_desc.upper_frequency+0.5) * height < y0: + continue + new_signal_desc.lower_frequency = -0.5 if (new_signal_desc.lower_frequency+0.5) < (y0 / height) else new_signal_desc.lower_frequency - (y0 / height) + new_signal_desc.upper_frequency = new_signal_desc.upper_frequency - (y0 / height) if new_signal_desc.upper_frequency < 0.5 else 0.5 - (y0 / height) + new_signal_desc.bandwidth = new_signal_desc.upper_frequency - new_signal_desc.lower_frequency + new_signal_desc.center_frequency = new_signal_desc.lower_frequency + new_signal_desc.bandwidth * 0.5 + + else: + if (new_signal_desc.lower_frequency+0.5) * height > y0: + continue + new_signal_desc.lower_frequency = (height - y0) / height + new_signal_desc.lower_frequency + new_signal_desc.upper_frequency = (height - y0) / height + new_signal_desc.upper_frequency + new_signal_desc.upper_frequency = 0.5 if new_signal_desc.upper_frequency > 0.5 else new_signal_desc.upper_frequency + new_signal_desc.bandwidth = new_signal_desc.upper_frequency - new_signal_desc.lower_frequency + new_signal_desc.center_frequency = new_signal_desc.lower_frequency + new_signal_desc.bandwidth * 0.5 + + # Append SignalDescription to list + new_signal_description.append(new_signal_desc) + + # Next, fill in the remaining cells with data randomly sampled from the input dataset + for cell_i in range(4): + if cell_i == cell_idx: + # Skip if the original data's cell + continue + x_idx = 0 if cell_i == 0 or cell_i == 2 else 1 + y_idx = 0 if cell_i == 0 or cell_i == 1 else 1 + dataset_idx = np.random.randint(len(self.dataset)) + curr_data, curr_signal_desc = self.dataset[dataset_idx] + full_mosaic[:,y_idx*height:(y_idx+1)*height,x_idx*width:(x_idx+1)*width] = curr_data + + # Update inserted data's SignalDescription objects given the cell index + signal_description = [curr_signal_desc] if isinstance(curr_signal_desc, SignalDescription) else curr_signal_desc + for signal_desc in signal_description: + new_signal_desc = deepcopy(signal_desc) + + # Update time fields + if x_idx == 0: + if new_signal_desc.stop * width < x0: + continue + new_signal_desc.start = 0 if new_signal_desc.start < (x0 / width) else new_signal_desc.start - (x0 / width) + new_signal_desc.stop = new_signal_desc.stop - (x0 / width) if new_signal_desc.stop < 1.0 else 1.0 - (x0 / width) + new_signal_desc.duration = new_signal_desc.stop - new_signal_desc.start + + else: + if new_signal_desc.start * width > x0: + continue + new_signal_desc.start = (width - x0) / width + new_signal_desc.start + new_signal_desc.stop = (width - x0) / width + new_signal_desc.stop + new_signal_desc.stop = 1.0 if new_signal_desc.stop > 1.0 else new_signal_desc.stop + new_signal_desc.duration = new_signal_desc.stop - new_signal_desc.start + + # Update frequency fields + new_signal_desc.lower_frequency = -0.5 if new_signal_desc.lower_frequency < -0.5 else new_signal_desc.lower_frequency + new_signal_desc.upper_frequency = 0.5 if new_signal_desc.upper_frequency > 0.5 else new_signal_desc.upper_frequency + if y_idx == 0: + if (new_signal_desc.upper_frequency+0.5) * height < y0: + continue + new_signal_desc.lower_frequency = -0.5 if (new_signal_desc.lower_frequency+0.5) < (y0 / height) else new_signal_desc.lower_frequency - (y0 / height) + new_signal_desc.upper_frequency = new_signal_desc.upper_frequency - (y0 / height) if new_signal_desc.upper_frequency < 0.5 else 0.5 - (y0 / height) + new_signal_desc.bandwidth = new_signal_desc.upper_frequency - new_signal_desc.lower_frequency + new_signal_desc.center_frequency = new_signal_desc.lower_frequency + new_signal_desc.bandwidth * 0.5 + + else: + if (new_signal_desc.lower_frequency+0.5) * height > y0: + continue + new_signal_desc.lower_frequency = (height - y0) / height + new_signal_desc.lower_frequency + new_signal_desc.upper_frequency = (height - y0) / height + new_signal_desc.upper_frequency + new_signal_desc.upper_frequency = 0.5 if new_signal_desc.upper_frequency > 0.5 else new_signal_desc.upper_frequency + new_signal_desc.bandwidth = new_signal_desc.upper_frequency - new_signal_desc.lower_frequency + new_signal_desc.center_frequency = new_signal_desc.lower_frequency + new_signal_desc.bandwidth * 0.5 + + # Append SignalDescription to list + new_signal_description.append(new_signal_desc) + + # After the data has been stitched into the large 2x2 gride, crop using x0, y0 + new_data.iq_data = full_mosaic[:,y0:y0+height,x0:x0+width] + + # Set output data's SignalDescription to above list + new_data.signal_description = new_signal_description + + else: + # Read shapes + channels, height, width = data.shape + + # Randomly decide the new x0, y0 point of the stitched images + x0 = np.random.randint(0,width) + y0 = np.random.randint(0,height) + + # Initialize new SignalDescription object + new_signal_description = [] + + # First, create a 2x2 grid of (512+512,512+512) and randomly put the initial data into a grid cell + cell_idx = np.random.randint(0,4) + x_idx = 0 if cell_idx == 0 or cell_idx == 2 else 1 + y_idx = 0 if cell_idx == 0 or cell_idx == 1 else 1 + full_mosaic = np.empty( + (channels, height*2, width*2), + dtype=data.dtype, + ) + full_mosaic[:,y_idx*height:(y_idx+1)*height,x_idx*width:(x_idx+1)*width] = data + + # Next, fill in the remaining cells with data randomly sampled from the input dataset + for cell_i in range(4): + if cell_i == cell_idx: + # Skip if the original data's cell + continue + x_idx = 0 if cell_i == 0 or cell_i == 2 else 1 + y_idx = 0 if cell_i == 0 or cell_i == 1 else 1 + dataset_idx = np.random.randint(len(self.dataset)) + curr_data, curr_signal_desc = self.dataset[dataset_idx] + full_mosaic[:,y_idx*height:(y_idx+1)*height,x_idx*width:(x_idx+1)*width] = curr_data + + # After the data has been stitched into the large 2x2 gride, crop using x0, y0 + new_data = full_mosaic[:,y0:y0+height,x0:x0+width] + + return new_data + + +class SpectrogramMosaicDownsample(SignalTransform): + """The SpectrogramMosaicDownsample transform takes the original input + tensor and inserts it randomly into one cell of a 2x2 grid of 2x the size + of the orginal spectrogram input. The `dataset` argument is then read 3x to + retrieve spectrograms to fill the remaining cells of the 2x2 grid. Finally, + the 2x oversized stitched spectrograms are downsampled by 2 to become the + desired, original shape + + Args: + dataset :obj:`SignalDataset`: + An SignalDataset of complex-valued examples to be used as a source for + the mosaic operation + + """ + def __init__(self, dataset: SignalDataset = None): + super(SpectrogramMosaicDownsample, self).__init__() + self.dataset = dataset + + def __call__(self, data: Any) -> Any: + if isinstance(data, SignalData): + # Create new SignalData object for transformed data + new_data = SignalData( + data=None, + item_type=np.dtype(np.float64), + data_type=np.dtype(np.complex128), + signal_description=data.signal_description, + ) + + # Read shapes + channels, height, width = data.iq_data.shape + + # Initialize new SignalDescription object + new_signal_description = [] + + # First, create a 2x2 grid of (512+512,512+512) and randomly put the initial data into a grid cell + cell_idx = np.random.randint(0,4) + x_idx = 0 if cell_idx == 0 or cell_idx == 2 else 1 + y_idx = 0 if cell_idx == 0 or cell_idx == 1 else 1 + full_mosaic = np.empty( + (channels, height*2, width*2), + dtype=data.iq_data.dtype, + ) + full_mosaic[:,y_idx*height:(y_idx+1)*height,x_idx*width:(x_idx+1)*width] = data.iq_data + + # Update original data's SignalDescription objects given the cell index + signal_description = [data.signal_description] if isinstance(data.signal_description, SignalDescription) else data.signal_description + for signal_desc in signal_description: + new_signal_desc = deepcopy(signal_desc) + + # Update time fields + if x_idx == 0: + new_signal_desc.start /= 2 + new_signal_desc.stop /= 2 + new_signal_desc.duration = new_signal_desc.stop - new_signal_desc.start + + else: + new_signal_desc.start = new_signal_desc.start / 2 + 0.5 + new_signal_desc.stop = new_signal_desc.stop / 2 + 0.5 + new_signal_desc.duration = new_signal_desc.stop - new_signal_desc.start + + # Update frequency fields + new_signal_desc.lower_frequency = -0.5 if new_signal_desc.lower_frequency < -0.5 else new_signal_desc.lower_frequency + new_signal_desc.upper_frequency = 0.5 if new_signal_desc.upper_frequency > 0.5 else new_signal_desc.upper_frequency + if y_idx == 0: + new_signal_desc.lower_frequency = (new_signal_desc.lower_frequency+0.5) / 2 - 0.5 + new_signal_desc.upper_frequency = (new_signal_desc.upper_frequency+0.5) / 2 - 0.5 + new_signal_desc.bandwidth = new_signal_desc.upper_frequency - new_signal_desc.lower_frequency + new_signal_desc.center_frequency = new_signal_desc.lower_frequency + new_signal_desc.bandwidth * 0.5 + + else: + new_signal_desc.lower_frequency = (new_signal_desc.lower_frequency+0.5) / 2 + new_signal_desc.upper_frequency = (new_signal_desc.upper_frequency+0.5) / 2 + new_signal_desc.bandwidth = new_signal_desc.upper_frequency - new_signal_desc.lower_frequency + new_signal_desc.center_frequency = new_signal_desc.lower_frequency + new_signal_desc.bandwidth * 0.5 + + # Append SignalDescription to list + new_signal_description.append(new_signal_desc) + + # Next, fill in the remaining cells with data randomly sampled from the input dataset + for cell_i in range(4): + if cell_i == cell_idx: + # Skip if the original data's cell + continue + x_idx = 0 if cell_i == 0 or cell_i == 2 else 1 + y_idx = 0 if cell_i == 0 or cell_i == 1 else 1 + dataset_idx = np.random.randint(len(self.dataset)) + curr_data, curr_signal_desc = self.dataset[dataset_idx] + full_mosaic[:,y_idx*height:(y_idx+1)*height,x_idx*width:(x_idx+1)*width] = curr_data + + # Update inserted data's SignalDescription objects given the cell index + signal_description = [curr_signal_desc] if isinstance(curr_signal_desc, SignalDescription) else curr_signal_desc + for signal_desc in signal_description: + new_signal_desc = deepcopy(signal_desc) + + # Update time fields + if x_idx == 0: + new_signal_desc.start /= 2 + new_signal_desc.stop /= 2 + new_signal_desc.duration = new_signal_desc.stop - new_signal_desc.start + + else: + new_signal_desc.start = new_signal_desc.start / 2 + 0.5 + new_signal_desc.stop = new_signal_desc.stop / 2 + 0.5 + new_signal_desc.duration = new_signal_desc.stop - new_signal_desc.start + + # Update frequency fields + new_signal_desc.lower_frequency = -0.5 if new_signal_desc.lower_frequency < -0.5 else new_signal_desc.lower_frequency + new_signal_desc.upper_frequency = 0.5 if new_signal_desc.upper_frequency > 0.5 else new_signal_desc.upper_frequency + if y_idx == 0: + new_signal_desc.lower_frequency = (new_signal_desc.lower_frequency+0.5) / 2 - 0.5 + new_signal_desc.upper_frequency = (new_signal_desc.upper_frequency+0.5) / 2 - 0.5 + new_signal_desc.bandwidth = new_signal_desc.upper_frequency - new_signal_desc.lower_frequency + new_signal_desc.center_frequency = new_signal_desc.lower_frequency + new_signal_desc.bandwidth * 0.5 + + else: + new_signal_desc.lower_frequency = (new_signal_desc.lower_frequency+0.5) / 2 + new_signal_desc.upper_frequency = (new_signal_desc.upper_frequency+0.5) / 2 + new_signal_desc.bandwidth = new_signal_desc.upper_frequency - new_signal_desc.lower_frequency + new_signal_desc.center_frequency = new_signal_desc.lower_frequency + new_signal_desc.bandwidth * 0.5 + + # Append SignalDescription to list + new_signal_description.append(new_signal_desc) + + # After the data has been stitched into the large 2x2 gride, downsample by 2 + new_data.iq_data = full_mosaic[:,::2,::2] + + # Set output data's SignalDescription to above list + new_data.signal_description = new_signal_description + + else: + # Read shapes + channels, height, width = data.shape + + # Initialize new SignalDescription object + new_signal_description = [] + + # First, create a 2x2 grid of (512+512,512+512) and randomly put the initial data into a grid cell + cell_idx = np.random.randint(0,4) + x_idx = 0 if cell_idx == 0 or cell_idx == 2 else 1 + y_idx = 0 if cell_idx == 0 or cell_idx == 1 else 1 + full_mosaic = np.empty( + (channels, height*2, width*2), + dtype=data.dtype, + ) + full_mosaic[:,y_idx*height:(y_idx+1)*height,x_idx*width:(x_idx+1)*width] = data + + # Next, fill in the remaining cells with data randomly sampled from the input dataset + for cell_i in range(4): + if cell_i == cell_idx: + # Skip if the original data's cell + continue + x_idx = 0 if cell_i == 0 or cell_i == 2 else 1 + y_idx = 0 if cell_i == 0 or cell_i == 1 else 1 + dataset_idx = np.random.randint(len(self.dataset)) + curr_data, curr_signal_desc = self.dataset[dataset_idx] + full_mosaic[:,y_idx*height:(y_idx+1)*height,x_idx*width:(x_idx+1)*width] = curr_data + + # After the data has been stitched into the large 2x2 gride, downsample by 2 + new_data = full_mosaic[:,::2,::2] + + return new_data diff --git a/torchsig/transforms/system_impairment/__init__.py b/torchsig/transforms/system_impairment/__init__.py index fc6988b..1b81242 100644 --- a/torchsig/transforms/system_impairment/__init__.py +++ b/torchsig/transforms/system_impairment/__init__.py @@ -1,2 +1,2 @@ from .si import * -from .si_functional import * +from .functional import * diff --git a/torchsig/transforms/system_impairment/functional.py b/torchsig/transforms/system_impairment/functional.py new file mode 100644 index 0000000..e9ea04d --- /dev/null +++ b/torchsig/transforms/system_impairment/functional.py @@ -0,0 +1,635 @@ +import numpy as np +from scipy import signal as sp +from numba import njit, int64, float64, complex64 + + +def time_shift( + tensor: np.ndarray, + t_shift: float +) -> np.ndarray: + """Shifts tensor in the time dimension by tshift samples. Zero-padding is applied to maintain input size. + + Args: + tensor: (:class:`numpy.ndarray`): + (batch_size, vector_length, ...)-sized tensor to be shifted. + + t_shift (:obj:`int` or :class:`numpy.ndarray`): + Number of samples to shift right or left (if negative) + + Returns: + transformed (:class:`numpy.ndarray`): + Tensor shifted in time of size tensor.shape + """ + # Valid Range Error Checking + if np.max(np.abs(t_shift)) >= tensor.shape[0]: + return np.zeros_like(tensor, dtype=np.complex64) + + # This overwrites tensor as side effect, modifies inplace + if t_shift > 0: + tmp = tensor[:-t_shift] # I'm sure there's a more compact way. + tensor = np.pad(tmp, (t_shift, 0), 'constant', constant_values=0 + 0j) + elif t_shift < 0: + tmp = tensor[-t_shift:] # I'm sure there's a more compact way. + tensor = np.pad(tmp, (0, -t_shift), 'constant', constant_values=0 + 0j) + return tensor + + +def time_crop( + tensor: np.ndarray, + start: int, + length: int +) -> np.ndarray: + """Crops a tensor in the time dimension from index start(inclusive) for length samples. + + Args: + tensor (:class:`numpy.ndarray`): + (batch_size, vector_length, ...)-sized tensor to be cropped. + + start (:obj:`int` or :class:`numpy.ndarray`): + index to begin cropping + + length (:obj:`int`): + number of samples to include + + Returns: + transformed (:class:`numpy.ndarray`): + Tensor cropped in time of size (tensor.shape[0], length) + """ + # Type and Size checking + if length < 0: + raise ValueError('Length must be greater than 0') + + if np.any(start < 0): + raise ValueError('Start must be greater than 0') + + if np.max(start) >= tensor.shape[0] or length == 0: + return np.empty(shape=(1, 1)) + + crop_len = min(length, tensor.shape[0] - np.max(start)) + + return tensor[start:start + crop_len] + + +def freq_shift(tensor: np.ndarray, f_shift: float) -> np.ndarray: + """Shifts each tensor in freq by freq_shift along the time dimension + + Args: + tensor (:class:`numpy.ndarray`): + (batch_size, vector_length, ...)-sized tensor to be frequency-shifted. + + f_shift (:obj:`float` or :class:`numpy.ndarray`): + Frequency shift relative to the sample rate in range [-.5, .5] + + Returns: + transformed (:class:`numpy.ndarray`): + Tensor that has been frequency shifted along time dimension of size tensor.shape + """ + sinusoid = np.exp(2j * np.pi * f_shift * np.arange(tensor.shape[0], dtype=np.float64)) + return np.multiply(tensor, np.asarray(sinusoid)) + + +def freq_shift_avoid_aliasing(tensor: np.ndarray, f_shift: float) -> np.ndarray: + """Similar to `freq_shift` function but performs the frequency shifting at + a higher sample rate with filtering to avoid aliasing + + Args: + tensor (:class:`numpy.ndarray`): + (batch_size, vector_length, ...)-sized tensor to be frequency-shifted. + + f_shift (:obj:`float` or :class:`numpy.ndarray`): + Frequency shift relative to the sample rate in range [-.5, .5] + + Returns: + transformed (:class:`numpy.ndarray`): + Tensor that has been frequency shifted along time dimension of size tensor.shape + """ + # Match output size to input + num_iq_samples = tensor.shape[0] + + # Interpolate up to avoid frequency wrap around during shift + up = 2 + down = 1 + tensor = sp.resample_poly(tensor, up, down) + + # Filter around center to remove original alias effects + num_taps = int(2*np.ceil(50*2*np.pi/(1/up)/.125/22)) # fred harris rule of thumb * 2 + taps = sp.firwin( + num_taps, + (1/up), + width=(1/up) * .02, + window=sp.get_window("blackman", num_taps), + scale=True + ) + tensor = sp.fftconvolve(tensor, taps, mode="same") + + # Freq shift to desired center freq + time_vector = np.arange(tensor.shape[0], dtype=np.float) + tensor = tensor * np.exp(2j * np.pi * f_shift / up * time_vector) + + # Filter to remove out-of-band regions + num_taps = int(2 * np.ceil(50 * 2 * np.pi / (1/up) / .125 / 22)) # fred harris rule-of-thumb * 2 + taps = sp.firwin( + num_taps, + 1 / up, + width=(1/up) * .02, + window=sp.get_window("blackman", num_taps), + scale=True + ) + tensor = sp.fftconvolve(tensor, taps, mode="same") + tensor = tensor[:int(num_iq_samples*up)] # prune to be correct size out of filter + + # Decimate back down to correct sample rate + tensor = sp.resample_poly(tensor, down, up) + + return tensor[:num_iq_samples] + + +@njit(cache=False) +def _fractional_shift_helper( + taps: np.ndarray, + raw_iq: np.ndarray, + stride: int, + offset: int +): + """Fractional shift. First, we up-sample by a large, fixed amount. Filter with 1/upsample_rate/2.0, + Next we down-sample by the same, large fixed amount with a chosen offset. Doing this efficiently means not actually zero-padding. + + The efficient way to do this is to decimate the taps and filter the signal with some offset in the taps. + """ + # We purposely do not calculate values within the group delay. + group_delay = ((taps.shape[0] - 1) // 2 - (stride - 1)) // stride + 1 + if offset < 0: + offset += stride + group_delay -= 1 + + # Decimate the taps. + taps = taps[offset::stride] + + # Determine output size + num_taps = taps.shape[0] + num_raw_iq = raw_iq.shape[0] + output = np.zeros(((num_taps + num_raw_iq - 1 - group_delay),), dtype=np.complex128) + + # This is a just convolution of taps and raw_iq + for o_idx in range(output.shape[0]): + idx_mn = o_idx - (num_raw_iq - 1) if o_idx >= num_raw_iq - 1 else 0 + idx_mx = o_idx if o_idx < num_taps - 1 else num_taps - 1 + for f_idx in range(idx_mn, idx_mx): + output[o_idx - group_delay] += taps[f_idx] * raw_iq[o_idx - f_idx] + return output + + +def fractional_shift( + tensor: np.ndarray, + taps: np.ndarray, + stride: int, + delay: int +) -> np.ndarray: + """Applies fractional sample delay of delay using a polyphase interpolator + + Args: + tensor (:class:`numpy.ndarray`): + (batch_size, vector_length, ...)-sized tensor to be shifted in time. + + taps (:obj:`float` or :class:`numpy.ndarray`): + taps to use for filtering + + stride (:obj:`int`): + interpolation rate of internal filter + + delay (:obj:`int` ): + Delay in number of samples in [-1, 1] + + Returns: + transformed (:class:`numpy.ndarray`): + Tensor that has been fractionally-shifted along time dimension of size tensor.shape + """ + real_part = _fractional_shift_helper(taps, tensor.real, stride, int(stride * float(delay))) + imag_part = _fractional_shift_helper(taps, tensor.imag, stride, int(stride * float(delay))) + tensor = real_part[:tensor.shape[0]] + 1j * imag_part[:tensor.shape[0]] + zero_idx = -1 if delay < 0 else 0 # do not extrapolate, zero-pad. + tensor[zero_idx] = 0 + return tensor + + +def iq_imbalance( + tensor: np.ndarray, + iq_amplitude_imbalance_db: float, + iq_phase_imbalance: float, + iq_dc_offset_db: float +) -> np.ndarray: + """Applies IQ imbalance to tensor + + Args: + tensor (:class:`numpy.ndarray`): + (batch_size, vector_length, ...)-sized tensor to be shifted in time. + + iq_amplitude_imbalance_db (:obj:`float` or :class:`numpy.ndarray`): + IQ amplitude imbalance in dB + + iq_phase_imbalance (:obj:`float` or :class:`numpy.ndarray`): + IQ phase imbalance in radians [-pi, pi] + + iq_dc_offset_db (:obj:`float` or :class:`numpy.ndarray`): + IQ DC Offset in dB + + Returns: + transformed (:class:`numpy.ndarray`): + Tensor that has an IQ imbalance applied across the time dimension of size tensor.shape + """ + # amplitude imbalance + tensor = 10 ** (iq_amplitude_imbalance_db / 10.0) * np.real(tensor) + \ + 1j * 10 ** (iq_amplitude_imbalance_db / 10.0) * np.imag(tensor) + + # phase imbalance + tensor = np.exp(-1j * iq_phase_imbalance / 2.0) * np.real(tensor) + \ + np.exp(1j * (np.pi / 2.0 + iq_phase_imbalance / 2.0)) * np.imag(tensor) + + tensor += 10 ** (iq_dc_offset_db / 10.0) * np.real(tensor) + \ + 1j * 10 ** (iq_dc_offset_db / 10.0) * np.imag(tensor) + return tensor + + +def spectral_inversion(tensor: np.ndarray) -> np.ndarray: + """Applies a spectral inversion + + Args: + tensor: (:class:`numpy.ndarray`): + (batch_size, vector_length, ...)-sized tensor. + + Returns: + transformed (:class:`numpy.ndarray`): + Tensor that has undergone a spectral inversion + + """ + tensor.imag *= -1 + return tensor + + +def channel_swap(tensor: np.ndarray) -> np.ndarray: + """Swap the I and Q channels of input complex data + + Args: + tensor: (:class:`numpy.ndarray`): + (batch_size, vector_length, ...)-sized tensor. + + Returns: + transformed (:class:`numpy.ndarray`): + Tensor that has undergone channel swapping + + """ + real_component = tensor.real + imag_component = tensor.imag + new_tensor = np.empty(*tensor.shape, dtype=tensor.dtype) + new_tensor.real = imag_component + new_tensor.imag = real_component + return new_tensor + + +def time_reversal(tensor: np.ndarray) -> np.ndarray: + """Applies a time reversal to the input tensor + + Args: + tensor: (:class:`numpy.ndarray`): + (batch_size, vector_length, ...)-sized tensor. + + Returns: + transformed (:class:`numpy.ndarray`): + Tensor that has undergone a time reversal + + """ + return np.flip(tensor, axis=0) + + +def amplitude_reversal(tensor: np.ndarray) -> np.ndarray: + """Applies an amplitude reversal to the input tensor by multiplying by -1 + + Args: + tensor: (:class:`numpy.ndarray`): + (batch_size, vector_length, ...)-sized tensor. + + Returns: + transformed (:class:`numpy.ndarray`): + Tensor that has undergone an amplitude reversal + + """ + return tensor*-1 + + +def roll_off( + tensor: np.ndarray, + lowercutfreq: float, + uppercutfreq: float, + fltorder: int, +) -> np.ndarray: + """Applies front-end filter to tensor. Rolls off lower/upper edges of bandwidth + + Args: + tensor: (:class:`numpy.ndarray`): + (batch_size, vector_length, ...)-sized tensor. + + lowercutfreq (:obj:`float`): + lower bandwidth cut-off to begin linear roll-off + + uppercutfreq (:obj:`float`): + upper bandwidth cut-off to begin linear roll-off + + fltorder (:obj:`int`): + order of each FIR filter to be applied + + Returns: + transformed (:class:`numpy.ndarray`): + Tensor that has undergone front-end filtering. + + """ + if (lowercutfreq == 0) & (uppercutfreq == 1): + return tensor + elif uppercutfreq == 1: + if fltorder % 2 == 0: + fltorder += 1 + bandwidth = uppercutfreq - lowercutfreq + center_freq = lowercutfreq - 0.5 + bandwidth/2 + num_taps = fltorder + sinusoid = np.exp(2j * np.pi * center_freq * np.linspace(0, num_taps - 1, num_taps)) + taps = sp.firwin( + num_taps, + bandwidth, + width=bandwidth * .02, + window=sp.get_window("blackman", num_taps), + scale=True + ) + taps = taps * sinusoid + return sp.fftconvolve(tensor, taps, mode="same") + + +def add_slope(tensor: np.ndarray) -> np.ndarray: + """The slope between each sample and its preceeding sample is added to + every sample + + Args: + tensor: (:class:`numpy.ndarray`): + (batch_size, vector_length, ...)-sized tensor. + + Returns: + transformed (:class:`numpy.ndarray`): + Tensor with added noise. + """ + slope = np.diff(tensor) + slope = np.insert(slope, 0, 0) + return tensor + slope + + +def mag_rescale( + tensor: np.ndarray, + start: float, + scale: float, +) -> np.ndarray: + """Apply a rescaling of input `scale` starting at time `start` + + Args: + tensor: (:class:`numpy.ndarray`): + (batch_size, vector_length, ...)-sized tensor. + + start (:obj:`float`): + Normalized start time of rescaling + + scale (:obj:`float`): + Scaling factor + + Returns: + transformed (:class:`numpy.ndarray`): + Tensor that has undergone rescaling + + """ + start = int(tensor.shape[0] * start) + tensor[start:] *= scale + return tensor + + +def drop_samples( + tensor: np.ndarray, + drop_starts: np.ndarray, + drop_sizes: np.ndarray, + fill: str, +) -> np.ndarray: + """Drop samples at specified input locations/durations with fill technique + + Args: + tensor: (:class:`numpy.ndarray`): + (batch_size, vector_length, ...)-sized tensor. + + drop_starts (:class:`numpy.ndarray`): + Indices of where drops start + + drop_sizes (:class:`numpy.ndarray`): + Durations of each drop instance + + fill (:obj:`str`): + String specifying how the dropped samples should be replaced + + Returns: + transformed (:class:`numpy.ndarray`): + Tensor that has undergone the dropped samples + + """ + for idx, drop_start in enumerate(drop_starts): + if fill == "ffill": + drop_region = np.ones(drop_sizes[idx], dtype=np.complex64)*tensor[drop_start-1] + elif fill == "bfill": + drop_region = np.ones(drop_sizes[idx], dtype=np.complex64)*tensor[drop_start+drop_sizes[idx]] + elif fill == "mean": + drop_region = np.ones(drop_sizes[idx], dtype=np.complex64)*np.mean(tensor) + elif fill == "zero": + drop_region = np.zeros(drop_sizes[idx], dtype=np.complex64) + else: + raise ValueError("fill expects ffill, bfill, mean, or zero. Found {}".format(fill)) + + # Update drop region + tensor[drop_start:drop_start+drop_sizes[idx]] = drop_region + + return tensor + + +def quantize( + tensor: np.ndarray, + num_levels: int, + round_type: str = 'floor', +) -> np.ndarray: + """Quantize the input to the number of levels specified + + Args: + tensor: (:class:`numpy.ndarray`): + (batch_size, vector_length, ...)-sized tensor. + + num_levels (:obj:`int`): + Number of quantization levels + + round_type (:obj:`str`): + Quantization rounding. Options: 'floor', 'middle', 'ceiling' + + Returns: + transformed (:class:`numpy.ndarray`): + Tensor that has undergone quantization + + """ + # Setup quantization resolution/bins + max_value = max(np.abs(tensor)) + 1e-9 + bins = np.linspace(-max_value,max_value,num_levels+1) + + # Digitize to bins + quantized_real = np.digitize(tensor.real, bins) + quantized_imag = np.digitize(tensor.imag, bins) + + if round_type == 'floor': + quantized_real -= 1 + quantized_imag -= 1 + + # Revert to values + quantized_real = bins[quantized_real] + quantized_imag = bins[quantized_imag] + + if round_type == 'nearest': + bin_size = np.diff(bins)[0] + quantized_real -= (bin_size/2) + quantized_imag -= (bin_size/2) + + quantized_tensor = quantized_real + 1j*quantized_imag + + return quantized_tensor + + +def clip(tensor: np.ndarray, clip_percentage: float) -> np.ndarray: + """Clips input tensor's values above/below a specified percentage of the + max/min of the input tensor + + Args: + tensor: (:class:`numpy.ndarray`): + (batch_size, vector_length, ...)-sized tensor. + + clip_percentage (:obj:`float`): + Percentage of max/min values to clip + + Returns: + transformed (:class:`numpy.ndarray`): + Tensor with added noise. + """ + real_tensor = tensor.real + max_val = np.max(real_tensor) * clip_percentage + min_val = np.min(real_tensor) * clip_percentage + real_tensor[real_tensor>max_val] = max_val + real_tensor[real_tensormax_val] = max_val + imag_tensor[imag_tensor np.ndarray: + """Create a complex-valued filter with `num_taps` number of taps, convolve + the random filter with the input data, and sum the original data with the + randomly-filtered data using an `alpha` weighting factor. + + Args: + tensor: (:class:`numpy.ndarray`): + (batch_size, vector_length, ...)-sized tensor. + + num_taps: (:obj:`int`): + Number of taps in random filter + + alpha: (:obj:`float`): + Weighting for the summation between the original data and the + randomly-filtered data, following: + + `output = (1 - alpha) * tensor + alpha * filtered_tensor` + + Returns: + transformed (:class:`numpy.ndarray`): + Tensor with weighted random filtering + + """ + filter_taps = np.random.rand(num_taps)+1j*np.random.rand(num_taps) + return (1 - alpha) * tensor + alpha * np.convolve(tensor, filter_taps, mode='same') + + +@njit(complex64[:](complex64[:], float64, float64, float64, float64, float64, float64, float64, float64, float64), cache=False) +def agc( + tensor: np.ndarray, + initial_gain_db: float, + alpha_smooth: float, + alpha_track: float, + alpha_overflow: float, + alpha_acquire: float, + ref_level_db: float, + track_range_db: float, + low_level_db: float, + high_level_db: float, +) -> np.ndarray: + """AGC implementation + + Args: + tensor: (:class:`numpy.ndarray`): + (batch_size, vector_length, ...)-sized tensor to be agc'd + + initial_gain_db (:obj:`float`): + Initial gain value in linear units + + alpha_smooth (:obj:`float`): + Alpha for averaging the measured signal level level_n = level_n*alpha + level_n-1*(1 - alpha) + + alpha_track (:obj:`float`): + Amount by which to adjust gain when in tracking state + + alpha_overflow (:obj:`float`): + Amount by which to adjust gain when in overflow state [level_db + gain_db] >= max_level + + alpha_acquire (:obj:`float`): + Amount by which to adjust gain when in acquire state abs([ref_level_db - level_db - gain_db]) >= track_range_db + + ref_level_db (:obj:`float`): + Level to which we intend to adjust gain to achieve + + track_range_db (:obj:`float`): + Range from ref_level_linear for which we can deviate before going into acquire state + + low_level_db (:obj:`float`): + Level below which we disable AGC + + high_level_db (:obj:`float`): + Level above which we go into overflow state + + Returns: + transformed (:class:`numpy.ndarray`): + Tensor with AGC applied + + """ + output = np.zeros_like(tensor) + gain_db = initial_gain_db + for sample_idx, sample in enumerate(tensor): + if np.abs(sample) == 0: + level_db = -200 + else: + level_db = level_db*alpha_smooth + np.log(np.abs(sample))*(1 - alpha_smooth) + output_db = level_db + gain_db + diff_db = ref_level_db - output_db + + if level_db <= low_level_db: + alpha_adjust = 0 + elif output_db >= high_level_db: + alpha_adjust = alpha_overflow + elif (abs(diff_db) > track_range_db): + alpha_adjust = alpha_acquire + else: + alpha_adjust = alpha_track + + gain_db += diff_db * alpha_adjust + output[sample_idx] = tensor[sample_idx] * np.exp(gain_db) + return output diff --git a/torchsig/transforms/system_impairment/si.py b/torchsig/transforms/system_impairment/si.py index 7dbf8fa..97e33c6 100644 --- a/torchsig/transforms/system_impairment/si.py +++ b/torchsig/transforms/system_impairment/si.py @@ -1,11 +1,11 @@ import numpy as np from copy import deepcopy from scipy import signal as sp -from typing import Optional, Any, Union, List, Callable +from typing import Optional, Any, Union, List from torchsig.utils.types import SignalData, SignalDescription from torchsig.transforms.transforms import SignalTransform -from torchsig.transforms.system_impairment import si_functional +from torchsig.transforms.system_impairment import functional from torchsig.transforms.functional import NumericParameter, IntParameter, FloatParameter from torchsig.transforms.functional import to_distribution, uniform_continuous_distribution, uniform_discrete_distribution @@ -67,13 +67,13 @@ def __call__(self, data: Any) -> Any: ) # Apply data transformation - new_data.iq_data = si_functional.fractional_shift( + new_data.iq_data = functional.fractional_shift( data.iq_data, self.taps, self.interp_rate, -decimal_part # this needed to be negated to be consistent with the previous implementation ) - new_data.iq_data = si_functional.time_shift(new_data.iq_data, int(integer_part)) + new_data.iq_data = functional.time_shift(new_data.iq_data, int(integer_part)) # Update SignalDescription new_signal_description = [] @@ -91,13 +91,13 @@ def __call__(self, data: Any) -> Any: new_data.signal_description = new_signal_description else: - new_data = si_functional.fractional_shift( + new_data = functional.fractional_shift( data, self.taps, self.interp_rate, -decimal_part # this needed to be negated to be consistent with the previous implementation ) - new_data = si_functional.time_shift(new_data, int(integer_part)) + new_data = functional.time_shift(new_data, int(integer_part)) return new_data @@ -167,7 +167,7 @@ def __call__(self, data: Any) -> Any: ) # Perform data augmentation - new_data.iq_data = si_functional.time_crop(iq_data, start, self.length) + new_data.iq_data = functional.time_crop(iq_data, start, self.length) # Update SignalDescription new_signal_description = [] @@ -190,7 +190,7 @@ def __call__(self, data: Any) -> Any: new_data.signal_description = new_signal_description else: - new_data = si_functional.time_crop(data, start, self.length) + new_data = functional.time_crop(data, start, self.length) return new_data @@ -228,10 +228,10 @@ def __call__(self, data: Any) -> Any: ) # Perform data augmentation - new_data.iq_data = si_functional.time_reversal(data.iq_data) + new_data.iq_data = functional.time_reversal(data.iq_data) if undo_spec_inversion: # If spectral inversion not desired, reverse effect - new_data.iq_data = si_functional.spectral_inversion(new_data.iq_data) + new_data.iq_data = functional.spectral_inversion(new_data.iq_data) # Update SignalDescription new_signal_description = [] @@ -258,10 +258,10 @@ def __call__(self, data: Any) -> Any: new_data.signal_description = new_signal_description else: - new_data = si_functional.time_reversal(data) + new_data = functional.time_reversal(data) if undo_spec_inversion: # If spectral inversion not desired, reverse effect - new_data = si_functional.spectral_inversion(new_data) + new_data = functional.spectral_inversion(new_data) return new_data @@ -284,10 +284,10 @@ def __call__(self, data: Any) -> Any: ) # Perform data augmentation - new_data.iq_data = si_functional.amplitude_reversal(data.iq_data) + new_data.iq_data = functional.amplitude_reversal(data.iq_data) else: - new_data = si_functional.amplitude_reversal(data) + new_data = functional.amplitude_reversal(data) return new_data @@ -373,16 +373,134 @@ def __call__(self, data: Any) -> Any: # Apply data augmentation if avoid_aliasing: # If any potential aliasing detected, perform shifting at higher sample rate - new_data.iq_data = si_functional.freq_shift_avoid_aliasing(data.iq_data, freq_shift) + new_data.iq_data = functional.freq_shift_avoid_aliasing(data.iq_data, freq_shift) else: # Otherwise, use faster freq shifter - new_data.iq_data = si_functional.freq_shift(data.iq_data, freq_shift) + new_data.iq_data = functional.freq_shift(data.iq_data, freq_shift) else: - new_data = si_functional.freq_shift(data, freq_shift) + new_data = functional.freq_shift(data, freq_shift) return new_data +class RandomDelayedFrequencyShift(SignalTransform): + """Apply a delayed frequency shift to the input data + + Args: + start_shift (:py:class:`~Callable`, :obj:`int`, :obj:`float`, :obj:`list`, :obj:`tuple`): + start_shift sets the start time of the delayed shift + * If Callable, produces a sample by calling start_shift() + * If int, start_shift is fixed at the value provided + * If list, start_shift is any element in the list + * If tuple, start_shift is in range of (tuple[0], tuple[1]) + + freq_shift (:py:class:`~Callable`, :obj:`int`, :obj:`float`, :obj:`list`, :obj:`tuple`): + freq_shift sets the translation along the freq-axis + * If Callable, produces a sample by calling freq_shift() + * If int, freq_shift is fixed at the value provided + * If list, freq_shift is any element in the list + * If tuple, freq_shift is in range of (tuple[0], tuple[1]) + + """ + def __init__( + self, + start_shift: IntParameter = uniform_continuous_distribution(0.1,0.9), + freq_shift: IntParameter = uniform_continuous_distribution(-0.2,0.2), + ): + super(RandomDelayedFrequencyShift, self).__init__() + self.start_shift = to_distribution(start_shift, self.random_generator) + self.freq_shift = to_distribution(freq_shift, self.random_generator) + + def __call__(self, data: Any) -> Any: + start_shift = self.start_shift() + # Randomly generate a freq shift that is not near the original fc + freq_shift = 0 + while freq_shift < 0.05 and freq_shift > -0.05: + freq_shift = self.freq_shift() + + if isinstance(data, SignalData): + # Create new SignalData object for transformed data + new_data = SignalData( + data=None, + item_type=np.dtype(np.float64), + data_type=np.dtype(np.complex128), + signal_description=[], + ) + new_data.iq_data = data.iq_data + num_iq_samples = data.iq_data.shape[0] + + # Setup new SignalDescription object + new_signal_description = [] + signal_description = [data.signal_description] if isinstance(data.signal_description, SignalDescription) else data.signal_description + avoid_aliasing = False + for signal_desc in signal_description: + new_signal_desc_first_seg = deepcopy(signal_desc) + new_signal_desc_sec_seg = deepcopy(signal_desc) + # Check bounds for partial signals + new_signal_desc_first_seg.lower_frequency = -0.5 if new_signal_desc_first_seg.lower_frequency < -0.5 else new_signal_desc_first_seg.lower_frequency + new_signal_desc_first_seg.upper_frequency = 0.5 if new_signal_desc_first_seg.upper_frequency > 0.5 else new_signal_desc_first_seg.upper_frequency + new_signal_desc_first_seg.bandwidth = new_signal_desc_first_seg.upper_frequency - new_signal_desc_first_seg.lower_frequency + new_signal_desc_first_seg.center_frequency = new_signal_desc_first_seg.lower_frequency + new_signal_desc_first_seg.bandwidth * 0.5 + + # Update time for original segment if present in segment and add to list + if new_signal_desc_first_seg.start < start_shift: + new_signal_desc_first_seg.stop = start_shift if new_signal_desc_first_seg.stop > start_shift else new_signal_desc_first_seg.stop + new_signal_desc_first_seg.duration = new_signal_desc_first_seg.stop - new_signal_desc_first_seg.start + # Append SignalDescription to list + new_signal_description.append(new_signal_desc_first_seg) + + # Begin second segment processing + new_signal_desc_sec_seg.lower_frequency = -0.5 if new_signal_desc_sec_seg.lower_frequency < -0.5 else new_signal_desc_sec_seg.lower_frequency + new_signal_desc_sec_seg.upper_frequency = 0.5 if new_signal_desc_sec_seg.upper_frequency > 0.5 else new_signal_desc_sec_seg.upper_frequency + new_signal_desc_sec_seg.bandwidth = new_signal_desc_sec_seg.upper_frequency - new_signal_desc_sec_seg.lower_frequency + new_signal_desc_sec_seg.center_frequency = new_signal_desc_sec_seg.lower_frequency + new_signal_desc_sec_seg.bandwidth * 0.5 + + # Update freqs for next segment + new_signal_desc_sec_seg.lower_frequency += freq_shift + new_signal_desc_sec_seg.upper_frequency += freq_shift + new_signal_desc_sec_seg.center_frequency += freq_shift + + # Check bounds for aliasing + if new_signal_desc_sec_seg.lower_frequency >= 0.5 or new_signal_desc_sec_seg.upper_frequency <= -0.5: + avoid_aliasing = True + continue + if new_signal_desc_sec_seg.lower_frequency < -0.45 or new_signal_desc_sec_seg.upper_frequency > 0.45: + avoid_aliasing = True + new_signal_desc_sec_seg.lower_frequency = -0.5 if new_signal_desc_sec_seg.lower_frequency < -0.5 else new_signal_desc_sec_seg.lower_frequency + new_signal_desc_sec_seg.upper_frequency = 0.5 if new_signal_desc_sec_seg.upper_frequency > 0.5 else new_signal_desc_sec_seg.upper_frequency + + # Update bw & fc + new_signal_desc_sec_seg.bandwidth = new_signal_desc_sec_seg.upper_frequency - new_signal_desc_sec_seg.lower_frequency + new_signal_desc_sec_seg.center_frequency = new_signal_desc_sec_seg.lower_frequency + new_signal_desc_sec_seg.bandwidth * 0.5 + + # Update time for shifted segment if present in segment and add to list + if new_signal_desc_sec_seg.stop > start_shift: + new_signal_desc_sec_seg.start = start_shift if new_signal_desc_sec_seg.start < start_shift else new_signal_desc_sec_seg.start + new_signal_desc_sec_seg.stop = new_signal_desc_sec_seg.stop + new_signal_desc_sec_seg.duration = new_signal_desc_sec_seg.stop - new_signal_desc_sec_seg.start + # Append SignalDescription to list + new_signal_description.append(new_signal_desc_sec_seg) + + # Update with the new SignalDescription + new_data.signal_description = new_signal_description + + # Perform augmentation + if avoid_aliasing: + # If any potential aliasing detected, perform shifting at higher sample rate + new_data.iq_data[int(start_shift*num_iq_samples):] = functional.freq_shift_avoid_aliasing( + data.iq_data[int(start_shift*num_iq_samples):], + freq_shift + ) + else: + # Otherwise, use faster freq shifter + new_data.iq_data[int(start_shift*num_iq_samples):] = functional.freq_shift( + data.iq_data[int(start_shift*num_iq_samples):], + freq_shift + ) + + return new_data + + class LocalOscillatorDrift(SignalTransform): """LocalOscillatorDrift is a transform modelling a local oscillator's drift in frequency by a random walk in frequency. @@ -600,7 +718,7 @@ def __call__(self, data: Any) -> Any: ref_level_db = np.random.uniform(-.5 + self.ref_level_db, .5 + self.ref_level_db, 1) - iq_data = si_functional.agc( + iq_data = functional.agc( np.ascontiguousarray(iq_data, dtype=np.complex64), np.float64(self.initial_gain_db), np.float64(alpha_smooth), @@ -677,14 +795,14 @@ def __call__(self, data: Any) -> Any: dc_offset = self.dc_offset() if isinstance(data, SignalData): - data.iq_data = si_functional.iq_imbalance( + data.iq_data = functional.iq_imbalance( data.iq_data, amp_imbalance, phase_imbalance, dc_offset ) else: - data = si_functional.iq_imbalance( + data = functional.iq_imbalance( data, amp_imbalance, phase_imbalance, @@ -742,9 +860,9 @@ def __call__(self, data: Any) -> Any: upper_freq = self.upper_freq() if np.random.rand() < self.upper_cut_apply else 1.0 order = self.order() if isinstance(data, SignalData): - data.iq_data = si_functional.roll_off(data.iq_data, low_freq, upper_freq, int(order)) + data.iq_data = functional.roll_off(data.iq_data, low_freq, upper_freq, int(order)) else: - data = si_functional.roll_off(data, low_freq, upper_freq, int(order)) + data = functional.roll_off(data, low_freq, upper_freq, int(order)) return data @@ -767,10 +885,10 @@ def __call__(self, data: Any) -> Any: ) # Apply data augmentation - new_data.iq_data = si_functional.add_slope(data.iq_data) + new_data.iq_data = functional.add_slope(data.iq_data) else: - new_data = si_functional.add_slope(data) + new_data = functional.add_slope(data) return new_data @@ -792,7 +910,7 @@ def __call__(self, data: Any) -> Any: ) # Perform data augmentation - new_data.iq_data = si_functional.spectral_inversion(data.iq_data) + new_data.iq_data = functional.spectral_inversion(data.iq_data) # Update SignalDescription new_signal_description = [] @@ -812,7 +930,7 @@ def __call__(self, data: Any) -> Any: new_data.signal_description = new_signal_description else: - new_data = si_functional.spectral_inversion(data) + new_data = functional.spectral_inversion(data) return new_data @@ -851,10 +969,10 @@ def __call__(self, data: Any) -> Any: new_data.signal_description = new_signal_description # Perform data augmentation - new_data.iq_data = si_functional.channel_swap(data.iq_data) + new_data.iq_data = functional.channel_swap(data.iq_data) else: - new_data = si_functional.channel_swap(data) + new_data = functional.channel_swap(data) return new_data @@ -901,10 +1019,10 @@ def __call__(self, data: Any) -> Any: ) # Perform data augmentation - new_data.iq_data = si_functional.mag_rescale(data.iq_data, start, scale) + new_data.iq_data = functional.mag_rescale(data.iq_data, start, scale) else: - new_data = si_functional.mag_rescale(data, start, scale) + new_data = functional.mag_rescale(data, start, scale) return new_data @@ -945,7 +1063,7 @@ def __init__( self, drop_rate: NumericParameter = uniform_continuous_distribution(0.01,0.05), size: NumericParameter = uniform_discrete_distribution(np.arange(1,10)), - fill: Union[Callable, List, str] = uniform_discrete_distribution(["ffill", "bfill", "mean", "zero"]), + fill: Union[List, str] = uniform_discrete_distribution(["ffill", "bfill", "mean", "zero"]), ): super(RandomDropSamples, self).__init__() self.drop_rate = to_distribution(drop_rate, self.random_generator) @@ -970,14 +1088,14 @@ def __call__(self, data: Any) -> Any: drop_sizes = self.size(drop_instances).astype(int) drop_starts = np.random.uniform(1, data.iq_data.shape[0]-max(drop_sizes)-1, drop_instances).astype(int) - new_data.iq_data = si_functional.drop_samples(data.iq_data, drop_starts, drop_sizes, fill) + new_data.iq_data = functional.drop_samples(data.iq_data, drop_starts, drop_sizes, fill) else: drop_instances = int(data.shape[0] * drop_rate) drop_sizes = self.size(drop_instances).astype(int) drop_starts = np.random.uniform(0, data.shape[0]-max(drop_sizes), drop_instances).astype(int) - new_data = si_functional.drop_samples(data, drop_starts, drop_sizes, fill) + new_data = functional.drop_samples(data, drop_starts, drop_sizes, fill) return new_data @@ -1022,10 +1140,10 @@ def __call__(self, data: Any) -> Any: ) # Perform data augmentation - new_data.iq_data = si_functional.quantize(data.iq_data, num_levels, round_type) + new_data.iq_data = functional.quantize(data.iq_data, num_levels, round_type) else: - new_data = si_functional.quantize(data, num_levels, round_type) + new_data = functional.quantize(data, num_levels, round_type) return new_data @@ -1063,10 +1181,10 @@ def __call__(self, data: Any) -> Any: ) # Apply data augmentation - new_data.iq_data = si_functional.clip(data.iq_data, clip_percentage) + new_data.iq_data = functional.clip(data.iq_data, clip_percentage) else: - new_data = si_functional.clip(data, clip_percentage) + new_data = functional.clip(data, clip_percentage) return new_data @@ -1117,8 +1235,8 @@ def __call__(self, data: Any) -> Any: ) # Apply data augmentation - new_data.iq_data = si_functional.random_convolve(data.iq_data, num_taps, alpha) + new_data.iq_data = functional.random_convolve(data.iq_data, num_taps, alpha) else: - new_data = si_functional.random_convolve(data, num_taps, alpha) + new_data = functional.random_convolve(data, num_taps, alpha) return new_data diff --git a/torchsig/transforms/target_transforms/target_transforms.py b/torchsig/transforms/target_transforms/target_transforms.py index b84b340..f41a046 100644 --- a/torchsig/transforms/target_transforms/target_transforms.py +++ b/torchsig/transforms/target_transforms/target_transforms.py @@ -1,3 +1,4 @@ +import torch import numpy as np from typing import Tuple, List, Any, Union, Optional @@ -117,6 +118,518 @@ def __call__( return classes[0], snrs[0] +class DescToMask(Transform): + """Transform to transform SignalDescriptions into spectrogram masks + + Args: + max_bursts (:obj:`int`): + Maximum number of bursts to label in their own target channel + width (:obj:`int`): + Width of resultant spectrogram mask + height (:obj:`int`): + Height of resultant spectrogram mask + + """ + def __init__(self, max_bursts: int, width: int, height: int): + super(DescToMask, self).__init__() + self.max_bursts = max_bursts + self.width = width + self.height = height + + def __call__(self, signal_description: Union[List[SignalDescription], SignalDescription]) -> np.ndarray: + # Handle cases of both SignalDescriptions and lists of SignalDescriptions + signal_description = [signal_description] if isinstance(signal_description, SignalDescription) else signal_description + masks = np.zeros((self.max_bursts, self.height, self.width)) + idx = 0 + for signal_desc in signal_description: + if signal_desc.lower_frequency < -0.5: + signal_desc.lower_frequency = -0.5 + if signal_desc.upper_frequency > 0.5: + signal_desc.upper_frequency = 0.5 + if int((signal_desc.lower_frequency+0.5) * self.height) == int((signal_desc.upper_frequency+0.5) * self.height): + masks[ + idx, + int((signal_desc.lower_frequency+0.5) * self.height) : int((signal_desc.upper_frequency+0.5) * self.height)+1, + int(signal_desc.start * self.width) : int(signal_desc.stop * self.width), + ] = 1.0 + else: + masks[ + idx, + int((signal_desc.lower_frequency+0.5) * self.height) : int((signal_desc.upper_frequency+0.5) * self.height), + int(signal_desc.start * self.width) : int(signal_desc.stop * self.width), + ] = 1.0 + idx += 1 + return masks + + +class DescToMaskSignal(Transform): + """Transform to transform SignalDescriptions into spectrogram masks for binary + signal detection + + Args: + width (:obj:`int`): + Width of resultant spectrogram mask + height (:obj:`int`): + Height of resultant spectrogram mask + + """ + def __init__(self, width: int, height: int): + super(DescToMaskSignal, self).__init__() + self.width = width + self.height = height + + def __call__(self, signal_description: Union[List[SignalDescription], SignalDescription]) -> np.ndarray: + # Handle cases of both SignalDescriptions and lists of SignalDescriptions + signal_description = [signal_description] if isinstance(signal_description, SignalDescription) else signal_description + masks = np.zeros((self.height, self.width)) + for signal_desc in signal_description: + if signal_desc.lower_frequency < -0.5: + signal_desc.lower_frequency = -0.5 + if signal_desc.upper_frequency > 0.5: + signal_desc.upper_frequency = 0.5 + if int((signal_desc.lower_frequency+0.5) * self.height) == int((signal_desc.upper_frequency+0.5) * self.height): + masks[ + int((signal_desc.lower_frequency+0.5) * self.height) : int((signal_desc.upper_frequency+0.5) * self.height)+1, + int(signal_desc.start * self.width) : int(signal_desc.stop * self.width), + ] = 1.0 + else: + masks[ + int((signal_desc.lower_frequency+0.5) * self.height) : int((signal_desc.upper_frequency+0.5) * self.height), + int(signal_desc.start * self.width) : int(signal_desc.stop * self.width), + ] = 1.0 + return masks + + +class DescToMaskFamily(Transform): + """Transform to transform SignalDescriptions into spectrogram masks with + different channels for each class's family. If no `class_family_dict` + provided, the default mapping for the WBSig53 modulation families is used. + + Args: + class_family_dict (:obj:`dict`): + Dictionary mapping all class names to their families + family_list (:obj:`list`): + List of all of the families + width (:obj:`int`): + Width of resultant spectrogram mask + height (:obj:`int`): + Height of resultant spectrogram mask + + """ + class_family_dict = { + '4ask':'ask', + '8ask':'ask', + '16ask':'ask', + '32ask':'ask', + '64ask':'ask', + 'ook':'pam', + '4pam':'pam', + '8pam':'pam', + '16pam':'pam', + '32pam':'pam', + '64pam':'pam', + '2fsk':'fsk', + '2gfsk':'fsk', + '2msk':'fsk', + '2gmsk':'fsk', + '4fsk':'fsk', + '4gfsk':'fsk', + '4msk':'fsk', + '4gmsk':'fsk', + '8fsk':'fsk', + '8gfsk':'fsk', + '8msk':'fsk', + '8gmsk':'fsk', + '16fsk':'fsk', + '16gfsk':'fsk', + '16msk':'fsk', + '16gmsk':'fsk', + 'bpsk':'psk', + 'qpsk':'psk', + '8psk':'psk', + '16psk':'psk', + '32psk':'psk', + '64psk':'psk', + '16qam':'qam', + '32qam':'qam', + '32qam_cross':'qam', + '64qam':'qam', + '128qam_cross':'qam', + '256qam':'qam', + '512qam_cross':'qam', + '1024qam':'qam', + 'ofdm-64':'ofdm', + 'ofdm-72':'ofdm', + 'ofdm-128':'ofdm', + 'ofdm-180':'ofdm', + 'ofdm-256':'ofdm', + 'ofdm-300':'ofdm', + 'ofdm-512':'ofdm', + 'ofdm-600':'ofdm', + 'ofdm-900':'ofdm', + 'ofdm-1024':'ofdm', + 'ofdm-1200':'ofdm', + 'ofdm-2048':'ofdm', + } + def __init__( + self, + width: int, + height: int, + class_family_dict: dict = None, + family_list: list = None, + label_encode: bool = False, + ): + super(DescToMaskFamily, self).__init__() + self.class_family_dict = class_family_dict if class_family_dict else self.class_family_dict + self.family_list = family_list if family_list else sorted(list(set(self.class_family_dict.values()))) + self.width = width + self.height = height + self.label_encode = label_encode + + def __call__(self, signal_description: Union[List[SignalDescription], SignalDescription]) -> np.ndarray: + # Handle cases of both SignalDescriptions and lists of SignalDescriptions + signal_description = [signal_description] if isinstance(signal_description, SignalDescription) else signal_description + masks = np.zeros((len(self.family_list), self.height, self.width)) + for signal_desc in signal_description: + if signal_desc.lower_frequency < -0.5: + signal_desc.lower_frequency = -0.5 + if signal_desc.upper_frequency > 0.5: + signal_desc.upper_frequency = 0.5 + if isinstance(signal_desc.class_name, list): + signal_desc.class_name = signal_desc.class_name[0] + family_name = self.class_family_dict[signal_desc.class_name] + family_idx = self.family_list.index(family_name) + if int((signal_desc.lower_frequency+0.5) * self.height) == int((signal_desc.upper_frequency+0.5) * self.height): + masks[ + family_idx, + int((signal_desc.lower_frequency+0.5) * self.height) : int((signal_desc.upper_frequency+0.5) * self.height)+1, + int(signal_desc.start * self.width) : int(signal_desc.stop * self.width), + ] = 1.0 + else: + masks[ + family_idx, + int((signal_desc.lower_frequency+0.5) * self.height) : int((signal_desc.upper_frequency+0.5) * self.height), + int(signal_desc.start * self.width) : int(signal_desc.stop * self.width), + ] = 1.0 + if self.label_encode: + background_mask = np.zeros((1, self.height, self.height)) + masks = np.concatenate([background_mask, masks], axis=0) + masks = np.argmax(masks, axis=0) + return masks + + +class DescToMaskClass(Transform): + """Transform to transform list of SignalDescriptions into spectrogram masks + with classes + + Args: + num_classes (:obj:`int`): + Integer number of classes, setting the channel dimension of the resultant mask + width (:obj:`int`): + Width of resultant spectrogram mask + height (:obj:`int`): + Height of resultant spectrogram mask + + """ + def __init__(self, num_classes: int, width: int, height: int): + super(DescToMaskClass, self).__init__() + self.num_classes = num_classes + self.width = width + self.height = height + + def __call__(self, signal_description: Union[List[SignalDescription], SignalDescription]) -> np.ndarray: + # Handle cases of both SignalDescriptions and lists of SignalDescriptions + signal_description = [signal_description] if isinstance(signal_description, SignalDescription) else signal_description + masks = np.zeros((self.num_classes, self.height, self.width)) + for signal_desc in signal_description: + if signal_desc.lower_frequency < -0.5: + signal_desc.lower_frequency = -0.5 + if signal_desc.upper_frequency > 0.5: + signal_desc.upper_frequency = 0.5 + if int((signal_desc.lower_frequency+0.5) * self.height) == int((signal_desc.upper_frequency+0.5) * self.height): + masks[ + signal_desc.class_index, + int((signal_desc.lower_frequency+0.5) * self.height) : int((signal_desc.upper_frequency+0.5) * self.height)+1, + int(signal_desc.start * self.width) : int(signal_desc.stop * self.width), + ] = 1.0 + else: + masks[ + signal_desc.class_index, + int((signal_desc.lower_frequency+0.5) * self.height) : int((signal_desc.upper_frequency+0.5) * self.height), + int(signal_desc.start * self.width) : int(signal_desc.stop * self.width), + ] = 1.0 + return masks + + +class DescToSemanticClass(Transform): + """Transform to transform SignalDescriptions into spectrogram semantic + segmentation mask with class information denoted as a value, rather than by + a one/multi-hot vector in an additional channel like the + DescToMaskClass does. Note that the class indicies are all + incremented by 1 in order to reserve the 0 class for "background". Note + that cases of overlapping bursts are currently resolved by comparing SNRs, + labeling the pixel by the stronger signal. Ties in SNR are awarded to the + burst that appears later in the burst collection. + + Args: + num_classes (:obj:`int`): + Integer number of classes, setting the channel dimension of the resultant mask + width (:obj:`int`): + Width of resultant spectrogram mask + height (:obj:`int`): + Height of resultant spectrogram mask + + """ + def __init__(self, num_classes: int, width: int, height: int): + super(DescToSemanticClass, self).__init__() + self.num_classes = num_classes + self.width = width + self.height = height + + def __call__(self, signal_description: Union[List[SignalDescription], SignalDescription]) -> np.ndarray: + # Handle cases of both SignalDescriptions and lists of SignalDescriptions + signal_description = [signal_description] if isinstance(signal_description, SignalDescription) else signal_description + masks = np.zeros((self.height, self.width)) + curr_snrs = np.ones((self.height, self.width))*-np.inf + for signal_desc in signal_description: + # Normalize freq values to [0,1] + if signal_desc.lower_frequency < -0.5: + signal_desc.lower_frequency = -0.5 + if signal_desc.upper_frequency > 0.5: + signal_desc.upper_frequency = 0.5 + + # Convert to pixels + height_start = max(0, int((signal_desc.lower_frequency+0.5) * self.height)) + height_stop = min(int((signal_desc.upper_frequency+0.5) * self.height), self.height) + width_start = max(0, int(signal_desc.start * self.width)) + width_stop = min(int(signal_desc.stop * self.width), self.width) + + # Account for signals with bandwidths < a pixel + if height_start == height_stop: + height_stop = min(height_stop+1, self.height) + + # Loop through pixels + for height_idx in range(height_start, height_stop): + for width_idx in range(width_start, width_stop): + # Check SNR against currently stored SNR at pixel + if signal_desc.snr >= curr_snrs[height_idx, width_idx]: + # If SNR >= currently stored class's SNR, update class & snr + masks[ + height_start : height_stop, + width_start : width_stop, + ] = signal_desc.class_index+1 + curr_snrs[ + height_start : height_stop, + width_start : width_stop, + ] = signal_desc.snr_db + return masks + + +class DescToBBox(Transform): + """Transform to transform SignalDescriptions into spectrogram bounding boxes + with dimensions: , where the last 5 represents: + - 0: presence ~ 1 if center of burst in current cell, else 0 + - 1: center_time ~ normalized to cell + - 2: dur_time ~ normalized to full spec time + - 3: center_freq ~ normalized to cell + - 4: bw_freq ~ normalized to full spec bw + + Args: + grid_width (:obj:`int`): + Width of grid celling + grid_height (:obj:`int`): + Height of grid celling + + """ + def __init__(self, grid_width: int, grid_height: int): + super(DescToBBox, self).__init__() + self.grid_width = grid_width + self.grid_height = grid_height + + def __call__(self, signal_description: Union[List[SignalDescription], SignalDescription]) -> np.ndarray: + # Handle cases of both SignalDescriptions and lists of SignalDescriptions + signal_description = [signal_description] if isinstance(signal_description, SignalDescription) else signal_description + boxes = np.zeros((self.grid_width, self.grid_height, 5)) + for signal_desc in signal_description: + # Time conversions + if signal_desc.start >= 1.0: + # Burst starts outside of window of capture + continue + elif signal_desc.start + signal_desc.duration * 0.5 >= 1.0: + # Center is outside grid cell; re-center to truncated burst + signal_desc.duration = 1 - signal_desc.start + x = (signal_desc.start + signal_desc.duration * 0.5) * self.grid_width + time_cell = int(np.floor(x)) + center_time = x - time_cell + + # Freq conversions + if signal_desc.lower_frequency > 0.5 or signal_desc.upper_frequency < -0.5: + # Burst is fully outside of capture bandwidth + continue + if signal_desc.lower_frequency < -0.5: + signal_desc.lower_frequency = -0.5 + if signal_desc.upper_frequency > 0.5: + signal_desc.upper_frequency = 0.5 + signal_desc.bandwidth = signal_desc.upper_frequency - signal_desc.lower_frequency + signal_desc.center_frequency = signal_desc.lower_frequency + signal_desc.bandwidth / 2 + y = (signal_desc.center_frequency + 0.5) * self.grid_height + freq_cell = int(np.floor(y)) + center_freq = y - freq_cell + + if time_cell >= self.grid_width: + print("Error: time_cell idx is greater than grid_width") + print("time_cell: {}".format(time_cell)) + print("burst.start: {}".format(signal_desc.start)) + print("burst.duration: {}".format(signal_desc.duration)) + print("x: {}".format(x)) + if freq_cell >= self.grid_height: + print("Error: freq_cell idx is greater than grid_height") + print("freq_cell: {}".format(freq_cell)) + print("burst.lower_frequency: {}".format(signal_desc.lower_frequency)) + print("burst.upper_frequency: {}".format(signal_desc.upper_frequency)) + print("burst.center_frequency: {}".format(signal_desc.center_frequency)) + print("y: {}".format(y)) + + # Assign to label + boxes[time_cell, freq_cell, 0] = 1 + boxes[time_cell, freq_cell, 1] = center_time + boxes[time_cell, freq_cell, 2] = signal_desc.duration + boxes[time_cell, freq_cell, 3] = center_freq + boxes[time_cell, freq_cell, 4] = signal_desc.bandwidth + return boxes + + +class DescToAnchorBoxes(Transform): + """Transform to transform BurstCollections into spectrogram bounding boxes + using anchor boxes, such that the output target shape will have the + dimensions: , where the last 5 represents: + - 0: objectness ~ 1 if burst associated with current cell & anchor, else 0 + - 1: center_time ~ normalized to cell + - 2: dur_offset ~ offset in duration with anchor box duration + - 3: center_freq ~ normalized to cell + - 4: bw_offset ~ offset in bandwidth with anchor box duration + + Args: + grid_width (:obj:`int`): + Width of grid celling + grid_height (:obj:`int`): + Height of grid celling + anchor_boxes: + List of tuples describing the anchor boxes (normalized values) + Example format: [(dur1, bw1), (dur2, bw2)] + + """ + def __init__(self, grid_width: int, grid_height: int, anchor_boxes: List): + super(DescToAnchorBoxes, self).__init__() + self.grid_width = grid_width + self.grid_height = grid_height + self.anchor_boxes = anchor_boxes + self.num_anchor_boxes = len(anchor_boxes) + + # IoU function + def iou(self, start_a, dur_a, center_freq_a, bw_a, start_b, dur_b, center_freq_b, bw_b): + # Convert to start/stops + x_start_a = start_a + x_stop_a = start_a + dur_a + y_start_a = center_freq_a - bw_a/2 + y_stop_a = center_freq_a + bw_a/2 + + x_start_b = start_b + x_stop_b = start_b + dur_b + y_start_b = center_freq_b - bw_b/2 + y_stop_b = center_freq_b + bw_b/2 + + # Determine the (x, y)-coordinates of the intersection + x_start_int = max(x_start_a, x_start_b) + y_start_int = max(y_start_a, y_start_b) + x_stop_int = min(x_stop_a, x_stop_b) + y_stop_int = min(y_stop_a, y_stop_b) + + # Compute the area of intersection + inter_area = abs(max((x_stop_int - x_start_int, 0)) * max((y_stop_int - y_start_int), 0)) + if inter_area == 0: + return 0 + # Compute the area of both the prediction and ground-truth + area_a = abs((x_stop_a - x_start_a) * (y_stop_a - y_start_a)) + area_b = abs((x_stop_b - x_start_b) * (y_stop_b - y_start_b)) + + # Compute the intersection over union + iou = inter_area / float(area_a + area_b - inter_area) + return iou + + def __call__(self, signal_description: Union[List[SignalDescription], SignalDescription]) -> np.ndarray: + # Handle cases of both SignalDescriptions and lists of SignalDescriptions + signal_description = [signal_description] if isinstance(signal_description, SignalDescription) else signal_description + boxes = np.zeros((self.grid_width, self.grid_height, 5*self.num_anchor_boxes)) + for signal_desc in signal_description: + # Time conversions + if signal_desc.start > 1.0: + # Error handling (TODO: should fix within dataset) + continue + elif signal_desc.start + signal_desc.duration * 0.5 > 1.0: + # Center is outside grid cell; re-center to truncated burst + signal_desc.duration = 1 - signal_desc.start + x = (signal_desc.start + signal_desc.duration * 0.5) * self.grid_width + time_cell = int(np.floor(x)) + center_time = x - time_cell + + # Freq conversions + y = (signal_desc.center_frequency + 0.5) * self.grid_height + freq_cell = int(np.floor(y)) + center_freq = y - freq_cell + + # Debugging messages for potential errors + if time_cell > self.grid_width: + print("Error: time_cell idx is greater than grid_width") + print("time_cell: {}".format(time_cell)) + print("burst.start: {}".format(signal_desc.start)) + print("burst.duration: {}".format(signal_desc.duration)) + print("x: {}".format(x)) + if freq_cell > self.grid_height: + print("Error: freq_cell idx is greater than grid_height") + print("freq_cell: {}".format(freq_cell)) + print("burst.center_frequency: {}".format(signal_desc.center_frequency)) + print("y: {}".format(y)) + + # Determine which anchor box to associate burst with + best_iou_score = -1 + best_iou_idx = 0 + best_anchor_duration = 0 + best_anchor_bw = 0 + for anchor_idx, anchor_box in enumerate(self.anchor_boxes): + #anchor_start = ((time_cell+0.5) / self.grid_width) - (anchor_box[0]*0.5) # Anchor centered on cell + anchor_start = signal_desc.start + 0.5*signal_desc.duration - anchor_box[0]*0.5 # Anchor overlaid on burst + anchor_duration = anchor_box[0] + #anchor_center_freq = (freq_cell+0.5) / self.grid_height # Anchor centered on cell + anchor_center_freq = signal_desc.center_frequency # Anchor overlaid on burst + anchor_bw = anchor_box[1] + iou_score = self.iou(signal_desc.start, signal_desc.duration, signal_desc.center_frequency, signal_desc.bandwidth, + anchor_start, anchor_duration, anchor_center_freq, anchor_bw) + if iou_score > best_iou_score and boxes[time_cell, freq_cell, 0+5*anchor_idx] != 1: + # If IoU score is the best out of all anchors and anchor hasn't already been used for another burst, save results + best_iou_score = iou_score + best_iou_idx = anchor_idx + best_anchor_duration = anchor_duration + best_anchor_bw = anchor_bw + + # Convert absolute coordinates to anchor-box offsets + # centers are normalized values like previous code segment below + # width/height are relative values to anchor boxes + # -- if anchor width is 0.6; true width is 0.5; label width should be 0.5/0.6 + # -- if anchor height is 0.6; true height is 0.7; label height should be 0.7/0.6 + # -- loss & inference will require predicted_box_wh = (sigmoid(model_output_wh)*2)**2 * anchor_wh + if best_iou_score > 0: + # Detection: + boxes[time_cell, freq_cell, 0+5*best_iou_idx] = 1 + # Center time & freq + boxes[time_cell, freq_cell, 1+5*best_iou_idx] = center_time + boxes[time_cell, freq_cell, 3+5*best_iou_idx] = center_freq + # Duration/Bandwidth (Width/Height) + boxes[time_cell, freq_cell, 2+5*best_iou_idx] = signal_desc.duration / best_anchor_duration + boxes[time_cell, freq_cell, 4+5*best_iou_idx] = signal_desc.bandwidth / best_anchor_bw + return boxes + + class DescPassThrough(Transform): """Transform to simply pass the SignalDescription through. Same as applying no transform in most cases. @@ -256,6 +769,465 @@ def __call__(self, signal_description: Union[List[SignalDescription], SignalDesc return encoding +class DescToBBoxDict(Transform): + """Transform to transform SignalDescriptions into the class bounding box format + using dictionaries of labels and boxes, similar to the COCO image dataset + + Args: + class_list (:obj:`list`): + List of class names. Used when converting SignalDescription class names + to indices + + """ + def __init__(self, class_list): + super(DescToBBoxDict, self).__init__() + self.class_list = class_list + + def __call__(self, signal_description: Union[List[SignalDescription], SignalDescription]) -> np.ndarray: + signal_description = [signal_description] if isinstance(signal_description, SignalDescription) else signal_description + labels = [] + boxes = np.empty((len(signal_description),4)) + for signal_desc_idx, signal_desc in enumerate(signal_description): + #xcycwh + duration = signal_desc.stop - signal_desc.start + bandwidth = signal_desc.upper_frequency - signal_desc.lower_frequency + boxes[signal_desc_idx] = np.array([ + signal_desc.start + 0.5*duration, + signal_desc.lower_frequency + 0.5 + 0.5*bandwidth, + duration, + bandwidth + ]) + labels.append(self.class_list.index(signal_desc.class_name)) + + targets = {"labels":torch.Tensor(labels).long(), "boxes":torch.Tensor(boxes)} + return targets + + +class DescToBBoxSignalDict(Transform): + """Transform to transform SignalDescriptions into the class bounding box format + using dictionaries of labels and boxes, similar to the COCO image dataset. + Differs from the `SignalDescriptionToBoundingBoxDictTransform` in the ommission + of signal-specific class labels, grouping all objects into the 'signal' + class. + + """ + def __init__(self): + super(DescToBBoxSignalDict, self).__init__() + self.class_list = ["signal"] + + def __call__(self, signal_description: Union[List[SignalDescription], SignalDescription]) -> np.ndarray: + signal_description = [signal_description] if isinstance(signal_description, SignalDescription) else signal_description + labels = [] + boxes = np.empty((len(signal_description),4)) + for signal_desc_idx, signal_desc in enumerate(signal_description): + #xcycwh + duration = signal_desc.stop - signal_desc.start + bandwidth = signal_desc.upper_frequency - signal_desc.lower_frequency + boxes[signal_desc_idx] = np.array([ + signal_desc.start + 0.5*duration, + signal_desc.lower_frequency + 0.5 + 0.5*bandwidth, + duration, + bandwidth + ]) + labels.append(self.class_list.index(self.class_list[0])) + + targets = {"labels":torch.Tensor(labels).long(), "boxes":torch.Tensor(boxes)} + return targets + + +class DescToBBoxFamilyDict(Transform): + """Transform to transform SignalDescriptions into the class bounding box format + using dictionaries of labels and boxes, similar to the COCO image dataset. + Differs from the `DescToBBoxDict` transform in the grouping + of fine-grain classes into their signal family as defined by an input + `class_family_dict` dictionary. + + Args: + class_family_dict (:obj:`dict`): + Dictionary mapping all class names to their families + + """ + class_family_dict = { + '4ask':'ask', + '8ask':'ask', + '16ask':'ask', + '32ask':'ask', + '64ask':'ask', + 'ook':'pam', + '4pam':'pam', + '8pam':'pam', + '16pam':'pam', + '32pam':'pam', + '64pam':'pam', + '2fsk':'fsk', + '2gfsk':'fsk', + '2msk':'fsk', + '2gmsk':'fsk', + '4fsk':'fsk', + '4gfsk':'fsk', + '4msk':'fsk', + '4gmsk':'fsk', + '8fsk':'fsk', + '8gfsk':'fsk', + '8msk':'fsk', + '8gmsk':'fsk', + '16fsk':'fsk', + '16gfsk':'fsk', + '16msk':'fsk', + '16gmsk':'fsk', + 'bpsk':'psk', + 'qpsk':'psk', + '8psk':'psk', + '16psk':'psk', + '32psk':'psk', + '64psk':'psk', + '16qam':'qam', + '32qam':'qam', + '32qam_cross':'qam', + '64qam':'qam', + '128qam_cross':'qam', + '256qam':'qam', + '512qam_cross':'qam', + '1024qam':'qam', + 'ofdm-64':'ofdm', + 'ofdm-72':'ofdm', + 'ofdm-128':'ofdm', + 'ofdm-180':'ofdm', + 'ofdm-256':'ofdm', + 'ofdm-300':'ofdm', + 'ofdm-512':'ofdm', + 'ofdm-600':'ofdm', + 'ofdm-900':'ofdm', + 'ofdm-1024':'ofdm', + 'ofdm-1200':'ofdm', + 'ofdm-2048':'ofdm', + } + def __init__(self, class_family_dict: dict = None, family_list: list = None): + super(DescToBBoxFamilyDict, self).__init__() + self.class_family_dict = class_family_dict if class_family_dict else self.class_family_dict + self.family_list = family_list if family_list else sorted(list(set(self.class_family_dict.values()))) + + def __call__(self, signal_description: Union[List[SignalDescription], SignalDescription]) -> np.ndarray: + signal_description = [signal_description] if isinstance(signal_description, SignalDescription) else signal_description + labels = [] + boxes = np.empty((len(signal_description),4)) + for signal_desc_idx, signal_desc in enumerate(signal_description): + #xcycwh + duration = signal_desc.stop - signal_desc.start + bandwidth = signal_desc.upper_frequency - signal_desc.lower_frequency + boxes[signal_desc_idx] = np.array([ + signal_desc.start + 0.5*duration, + signal_desc.lower_frequency + 0.5 + 0.5*bandwidth, + duration, + bandwidth + ]) + if isinstance(signal_desc.class_name, list): + signal_desc.class_name = signal_desc.class_name[0] + family_name = self.class_family_dict[signal_desc.class_name] + labels.append(self.family_list.index(family_name)) + + targets = {"labels":torch.Tensor(labels).long(), "boxes":torch.Tensor(boxes)} + return targets + + +class DescToInstMaskDict(Transform): + """Transform to transform SignalDescriptions into the class mask format + using dictionaries of labels and masks, similar to the COCO image dataset + + Args: + class_list (:obj:`list`): + List of class names. Used when converting SignalDescription class names + to indices + width (:obj:`int`): + Width of masks + heigh (:obj:`int`): + Height of masks + + """ + def __init__( + self, + class_list: List = [], + width: int = 512, + height: int = 512, + ): + super(DescToInstMaskDict, self).__init__() + self.class_list = class_list + self.width = width + self.height = height + + def __call__(self, signal_description: Union[List[SignalDescription], SignalDescription]) -> np.ndarray: + signal_description = [signal_description] if isinstance(signal_description, SignalDescription) else signal_description + num_objects = len(signal_description) + labels = [] + masks = np.zeros((num_objects, self.height, self.width)) + for signal_desc_idx, signal_desc in enumerate(signal_description): + labels.append(self.class_list.index(signal_desc.class_name)) + if signal_desc.lower_frequency < -0.5: + signal_desc.lower_frequency = -0.5 + if signal_desc.upper_frequency > 0.5: + signal_desc.upper_frequency = 0.5 + if int((signal_desc.lower_frequency+0.5) * self.height) == int((signal_desc.upper_frequency+0.5) * self.height): + masks[ + signal_desc_idx, + int((signal_desc.lower_frequency+0.5) * self.height) : int((signal_desc.upper_frequency+0.5) * self.height)+1, + int(signal_desc.start * self.width) : int(signal_desc.stop * self.width), + ] = 1.0 + else: + masks[ + signal_desc_idx, + int((signal_desc.lower_frequency+0.5) * self.height) : int((signal_desc.upper_frequency+0.5) * self.height), + int(signal_desc.start * self.width) : int(signal_desc.stop * self.width), + ] = 1.0 + + targets = {"labels":torch.Tensor(labels).long(), "masks":torch.Tensor(masks.astype(bool))} + return targets + + +class DescToSignalInstMaskDict(Transform): + """Transform to transform SignalDescriptions into the class mask format + using dictionaries of labels and masks, similar to the COCO image dataset + + Args: + width (:obj:`int`): + Width of masks + heigh (:obj:`int`): + Height of masks + + """ + def __init__( + self, + width: int = 512, + height: int = 512, + ): + super(DescToSignalInstMaskDict, self).__init__() + self.width = width + self.height = height + + def __call__(self, signal_description: Union[List[SignalDescription], SignalDescription]) -> np.ndarray: + signal_description = [signal_description] if isinstance(signal_description, SignalDescription) else signal_description + num_objects = len(signal_description) + labels = [] + masks = np.zeros((num_objects, self.height, self.width)) + for signal_desc_idx, signal_desc in enumerate(signal_description): + labels.append(0) + if signal_desc.lower_frequency < -0.5: + signal_desc.lower_frequency = -0.5 + if signal_desc.upper_frequency > 0.5: + signal_desc.upper_frequency = 0.5 + if int((signal_desc.lower_frequency+0.5) * self.height) == int((signal_desc.upper_frequency+0.5) * self.height): + masks[ + signal_desc_idx, + int((signal_desc.lower_frequency+0.5) * self.height) : int((signal_desc.upper_frequency+0.5) * self.height)+1, + int(signal_desc.start * self.width) : int(signal_desc.stop * self.width), + ] = 1.0 + else: + masks[ + signal_desc_idx, + int((signal_desc.lower_frequency+0.5) * self.height) : int((signal_desc.upper_frequency+0.5) * self.height), + int(signal_desc.start * self.width) : int(signal_desc.stop * self.width), + ] = 1.0 + + targets = {"labels":torch.Tensor(labels).long(), "masks":torch.Tensor(masks.astype(bool))} + return targets + + +class DescToSignalFamilyInstMaskDict(Transform): + """Transform to transform SignalDescriptions into the class mask format + using dictionaries of labels and masks, similar to the COCO image dataset. + The labels with this target transform are set to be the class's family. If + no `class_family_dict` is provided, the default mapping for the WBSig53 + modulation families is used. + + Args: + class_family_dict (:obj:`dict`): + Dictionary mapping all class names to their families + family_list (:obj:`list`): + List of all of the families + width (:obj:`int`): + Width of resultant spectrogram mask + height (:obj:`int`): + Height of resultant spectrogram mask + + """ + class_family_dict = { + '4ask':'ask', + '8ask':'ask', + '16ask':'ask', + '32ask':'ask', + '64ask':'ask', + 'ook':'pam', + '4pam':'pam', + '8pam':'pam', + '16pam':'pam', + '32pam':'pam', + '64pam':'pam', + '2fsk':'fsk', + '2gfsk':'fsk', + '2msk':'fsk', + '2gmsk':'fsk', + '4fsk':'fsk', + '4gfsk':'fsk', + '4msk':'fsk', + '4gmsk':'fsk', + '8fsk':'fsk', + '8gfsk':'fsk', + '8msk':'fsk', + '8gmsk':'fsk', + '16fsk':'fsk', + '16gfsk':'fsk', + '16msk':'fsk', + '16gmsk':'fsk', + 'bpsk':'psk', + 'qpsk':'psk', + '8psk':'psk', + '16psk':'psk', + '32psk':'psk', + '64psk':'psk', + '16qam':'qam', + '32qam':'qam', + '32qam_cross':'qam', + '64qam':'qam', + '128qam_cross':'qam', + '256qam':'qam', + '512qam_cross':'qam', + '1024qam':'qam', + 'ofdm-64':'ofdm', + 'ofdm-72':'ofdm', + 'ofdm-128':'ofdm', + 'ofdm-180':'ofdm', + 'ofdm-256':'ofdm', + 'ofdm-300':'ofdm', + 'ofdm-512':'ofdm', + 'ofdm-600':'ofdm', + 'ofdm-900':'ofdm', + 'ofdm-1024':'ofdm', + 'ofdm-1200':'ofdm', + 'ofdm-2048':'ofdm', + } + def __init__( + self, + width: int, + height: int, + class_family_dict: dict = None, + family_list: list = None, + ): + super(DescToSignalFamilyInstMaskDict, self).__init__() + self.class_family_dict = class_family_dict if class_family_dict else self.class_family_dict + self.family_list = family_list if family_list else sorted(list(set(self.class_family_dict.values()))) + self.width = width + self.height = height + + def __call__(self, signal_description: Union[List[SignalDescription], SignalDescription]) -> np.ndarray: + signal_description = [signal_description] if isinstance(signal_description, SignalDescription) else signal_description + num_objects = len(signal_description) + labels = [] + masks = np.zeros((num_objects, self.height, self.width)) + for signal_desc_idx, signal_desc in enumerate(signal_description): + family_name = self.class_family_dict[signal_desc.class_name] + family_idx = self.family_list.index(family_name) + labels.append(family_idx) + if signal_desc.lower_frequency < -0.5: + signal_desc.lower_frequency = -0.5 + if signal_desc.upper_frequency > 0.5: + signal_desc.upper_frequency = 0.5 + if int((signal_desc.lower_frequency+0.5) * self.height) == int((signal_desc.upper_frequency+0.5) * self.height): + masks[ + signal_desc_idx, + int((signal_desc.lower_frequency+0.5) * self.height) : int((signal_desc.upper_frequency+0.5) * self.height)+1, + int(signal_desc.start * self.width) : int(signal_desc.stop * self.width), + ] = 1.0 + else: + masks[ + signal_desc_idx, + int((signal_desc.lower_frequency+0.5) * self.height) : int((signal_desc.upper_frequency+0.5) * self.height), + int(signal_desc.start * self.width) : int(signal_desc.stop * self.width), + ] = 1.0 + + targets = {"labels":torch.Tensor(labels).long(), "masks":torch.Tensor(masks.astype(bool))} + return targets + + +class DescToListTuple(Transform): + """Transform to transform SignalDescription into a list of tuples containing + the modulation, start time, stop time, center frequency, bandwidth, and SNR + for each signal present + + Args: + precision (:obj: `np.dtype`): + Specify the data type precision for the tuple's information + + """ + def __init__(self, precision: np.dtype = np.dtype(np.float16)): + super(DescToListTuple, self).__init__() + self.precision = precision + + def __call__(self, signal_description: Union[List[SignalDescription], SignalDescription]) -> Union[List[str], str]: + output = [] + # Handle cases of both SignalDescriptions and lists of SignalDescriptions + signal_description = [signal_description] if isinstance(signal_description, SignalDescription) else signal_description + # Loop through SignalDescription's, converting values of interest to tuples + for signal_desc_idx, signal_desc in enumerate(signal_description): + curr_tuple = ( + signal_desc.class_name[0], + self.precision.type(signal_desc.start), + self.precision.type(signal_desc.stop), + self.precision.type(signal_desc.center_frequency), + self.precision.type(signal_desc.bandwidth), + self.precision.type(signal_desc.snr), + ) + output.append(curr_tuple) + return output + + +class ListTupleToDesc(Transform): + """Transform to transform a list of tuples to a list of SignalDescriptions + Sample rate and number of IQ samples optional arguments are provided in + order to fill in additional information if desired. If a class list is + provided, the class names are used with the list to fill in class indices + + Args: + sample_rate (:obj: `Optional[float]`): + Optionally provide the sample rate for the SignalDescriptions + + num_iq_samples (:obj: `Optional[int]`): + Optionally provide the number of IQ samples for the SignalDescriptions + + class_list (:obj: `List`): + Optionally provide the class list to fill in class indices + + """ + def __init__( + self, + sample_rate: Optional[float] = 1.0, + num_iq_samples: Optional[int] = int(512*512), + class_list: Optional[List] = None, + ): + super(ListTupleToDesc, self).__init__() + self.sample_rate = sample_rate + self.num_iq_samples = num_iq_samples + self.class_list = class_list + + def __call__(self, list_tuple: List[Tuple]) -> List[SignalDescription]: + output = [] + # Loop through SignalDescription's, converting values of interest to tuples + for tuple_idx, curr_tuple in enumerate(list_tuple): + curr_signal_desc = SignalDescription( + sample_rate=self.sample_rate, + num_iq_samples=self.num_iq_samples, + class_name=curr_tuple[0], + class_index=self.class_list.index(curr_tuple[0]) if self.class_list else None, + start=curr_tuple[1], + stop=curr_tuple[2], + center_frequency=curr_tuple[3], + bandwidth=curr_tuple[4], + lower_frequency=curr_tuple[3]-curr_tuple[4]/2, + upper_frequency=curr_tuple[3]+curr_tuple[4]/2, + snr=curr_tuple[5], + ) + output.append(curr_signal_desc) + return output + + class LabelSmoothing(Transform): """Transform to transform a numpy array encoding to a smoothed version to assist with overconfidence. The input hyperparameter `alpha` determines the @@ -286,4 +1258,4 @@ def __init__(self, alpha: float = 0.1) -> np.ndarray: def __call__(self, encoding: np.ndarray) -> np.ndarray: return (1 - self.alpha) / np.sum(encoding) * encoding + (self.alpha / encoding.shape[0]) - + \ No newline at end of file diff --git a/torchsig/transforms/transforms.py b/torchsig/transforms/transforms.py index 2496527..0c8cc78 100644 --- a/torchsig/transforms/transforms.py +++ b/torchsig/transforms/transforms.py @@ -225,3 +225,33 @@ def __call__(self, data: Any) -> Any: for t in transforms: data = t(data) return data + + +class RandChoice(SignalTransform): + """RandChoice inputs a list of transforms and their associated + probabilities. When called, a single transform will be sampled from the + list using the probabilities provided, and then the selected transform + will operate on the input data. + + Args: + transforms (:obj:`list`): + List of transforms to sample from and then apply + probabilities (:obj:`list`): + Probabilities used when sampling the above list of transforms + + """ + def __init__( + self, + transforms: List[SignalTransform], + probabilities: Optional[List[float]] = None, + **kwargs, + ): + super(RandChoice, self).__init__(**kwargs) + self.transforms = transforms + self.probabilities = probabilities if probabilities else np.ones(len(self.transforms))/len(self.transforms) + if sum(self.probabilities) != 1.0: + self.probabilities /= sum(self.probabilities) + + def __call__(self, data: Any) -> Any: + t = self.random_generator.choice(self.transforms, p=self.probabilities) + return t(data) \ No newline at end of file diff --git a/torchsig/transforms/wireless_channel/__init__.py b/torchsig/transforms/wireless_channel/__init__.py index fcfb11f..4da9f8b 100644 --- a/torchsig/transforms/wireless_channel/__init__.py +++ b/torchsig/transforms/wireless_channel/__init__.py @@ -1,2 +1,2 @@ from .wce import * -from .wce_functional import * +from .functional import * diff --git a/torchsig/transforms/wireless_channel/functional.py b/torchsig/transforms/wireless_channel/functional.py new file mode 100644 index 0000000..b96a4e4 --- /dev/null +++ b/torchsig/transforms/wireless_channel/functional.py @@ -0,0 +1,163 @@ +import numpy as np +from numba import njit +from scipy import signal as sp +from scipy import interpolate + + +@njit(cache=False) +def make_sinc_filter(beta, tap_cnt, sps, offset=0): + """ + return the taps of a sinc filter + """ + ntap_cnt = tap_cnt + ((tap_cnt + 1) % 2) + t_index = np.arange(-(ntap_cnt - 1) // 2, (ntap_cnt - 1) // 2 + 1) / np.double(sps) + + taps = np.sinc(beta * t_index + offset) + taps /= np.sum(taps) + + return taps[:tap_cnt] + + +def awgn(tensor: np.ndarray, noise_power_db: float) -> np.ndarray: + """Adds zero-mean complex additive white Gaussian noise with power of + noise_power_db. + + Args: + tensor: (:class:`numpy.ndarray`): + (batch_size, vector_length, ...)-sized tensor. + + noise_power_db (:obj:`float`): + Defined as 10*log10(E[|n|^2]). + + Returns: + transformed (:class:`numpy.ndarray`): + Tensor with added noise. + """ + real_noise = np.random.randn(*tensor.shape) + imag_noise = np.random.randn(*tensor.shape) + return tensor + (10.0**(noise_power_db/20.0))*(real_noise + 1j*imag_noise)/np.sqrt(2) + + +def time_varying_awgn( + tensor: np.ndarray, + noise_power_db_low: float, + noise_power_db_high: float, + inflections: int, + random_regions: bool, +) -> np.ndarray: + """Adds time-varying complex additive white Gaussian noise with power + levels in range (`noise_power_db_low`, `noise_power_db_high`) and with + `inflections` number of inflection points spread over the input tensor + randomly if `random_regions` is True or evely spread if False + + Args: + tensor: (:class:`numpy.ndarray`): + (batch_size, vector_length, ...)-sized tensor. + + noise_power_db_low (:obj:`float`): + Defined as 10*log10(E[|n|^2]). + + noise_power_db_high (:obj:`float`): + Defined as 10*log10(E[|n|^2]). + + inflections (:obj:`int`): + Number of inflection points for time-varying nature + + random_regions (:obj:`bool`): + Specify if inflection points are randomly spread throughout tensor + or if evenly spread + + Returns: + transformed (:class:`numpy.ndarray`): + Tensor with added noise. + """ + real_noise = np.random.randn(*tensor.shape) + imag_noise = np.random.randn(*tensor.shape) + noise_power_db = np.empty(*tensor.shape) + + if inflections == 0: + inflection_indices = np.array([0, tensor.shape[0]]) + else: + if random_regions: + inflection_indices = np.sort(np.random.choice(tensor.shape[0], size=inflections, replace=False)) + inflection_indices = np.append(inflection_indices, tensor.shape[0]) + inflection_indices = np.insert(inflection_indices, 0, 0) + else: + inflection_indices = np.arange(inflections+2) * int(tensor.shape[0] / (inflections+1)) + + for idx in range(len(inflection_indices)-1): + start_idx = inflection_indices[idx] + stop_idx = inflection_indices[idx+1] + duration = stop_idx - start_idx + start_power = noise_power_db_low if idx%2 == 0 else noise_power_db_high + stop_power = noise_power_db_high if idx%2 == 0 else noise_power_db_low + noise_power_db[start_idx:stop_idx] = np.linspace(start_power, stop_power, duration) + + return tensor + (10.0**(noise_power_db/20.0))*(real_noise + 1j*imag_noise)/np.sqrt(2) + + +def rayleigh_fading( + tensor: np.ndarray, + coherence_bandwidth: float, + power_delay_profile: np.ndarray, +) -> np.ndarray: + """Applies Rayleigh fading channel to tensor. Taps are generated by + interpolating and filtering Gaussian taps. + + Args: + tensor: (:class:`numpy.ndarray`): + (batch_size, vector_length, ...)-sized tensor. + + coherence_bandwidth (:obj:`float`): + coherence_bandwidth relative to the sample rate in [0, 1.0] + + power_delay_profile (:obj:`float`): + power_delay_profile assigned to channel + + Returns: + transformed (:class:`numpy.ndarray`): + Tensor that has undergone Rayleigh Fading. + + """ + num_taps = int(np.ceil(1.0 / coherence_bandwidth)) # filter length to get desired coherence bandwidth + power_taps = np.sqrt(np.interp( + np.linspace(0, 1.0, 100*num_taps), + np.linspace(0, 1.0, len(power_delay_profile)), + power_delay_profile + )) + # Generate initial taps + rayleigh_taps = (np.random.randn(num_taps) + 1j * np.random.randn(num_taps)) # multi-path channel + + # Linear interpolate taps by a factor of 100 -- so we can get accurate coherence bandwidths + old_time = np.linspace(0, 1.0, num_taps, endpoint=True) + real_tap_function = interpolate.interp1d(old_time, rayleigh_taps.real) + imag_tap_function = interpolate.interp1d(old_time, rayleigh_taps.imag) + + new_time = np.linspace(0, 1.0, 100*num_taps, endpoint=True) + rayleigh_taps = real_tap_function(new_time) + 1j*imag_tap_function(new_time) + rayleigh_taps *= power_taps + + # Ensure that we maintain the same amount of power before and after the transform + input_power = np.linalg.norm(tensor) + tensor = sp.upfirdn(rayleigh_taps, tensor, up=100, down=100)[-tensor.shape[0]:] + output_power = np.linalg.norm(tensor) + tensor = np.multiply(input_power/output_power, tensor) + return tensor + + +def phase_offset(tensor: np.ndarray, phase: float) -> np.ndarray: + """ Applies a phase rotation to tensor + + Args: + tensor: (:class:`numpy.ndarray`): + (batch_size, vector_length, ...)-sized tensor. + + phase (:obj:`float`): + phase to rotate sample in [-pi, pi] + + Returns: + transformed (:class:`numpy.ndarray`): + Tensor that has undergone a phase rotation + + """ + return tensor*np.exp(1j*phase) diff --git a/torchsig/transforms/wireless_channel/wce.py b/torchsig/transforms/wireless_channel/wce.py index 851e931..476d9ff 100644 --- a/torchsig/transforms/wireless_channel/wce.py +++ b/torchsig/transforms/wireless_channel/wce.py @@ -1,10 +1,10 @@ -from copy import deepcopy import numpy as np +from copy import deepcopy from typing import Optional, Tuple, List, Union, Any from torchsig.utils.types import SignalData, SignalDescription from torchsig.transforms.transforms import SignalTransform -from torchsig.transforms.wireless_channel import wce_functional as F +from torchsig.transforms.wireless_channel import functional as F from torchsig.transforms.functional import NumericParameter, FloatParameter, IntParameter from torchsig.transforms.functional import to_distribution, uniform_continuous_distribution, uniform_discrete_distribution @@ -81,11 +81,11 @@ def __call__(self, data: Any) -> Any: class AddNoise(SignalTransform): - """ Add random AWGN at specified power levels + """Add random AWGN at specified power levels Note: - Differs from the TargetSNR() transform in that this transform adds - noise at a specified power level, whereas AddNoise() + Differs from the TargetSNR() in that this transform adds + noise at a specified power level, whereas TargetSNR() assumes a basebanded signal and adds noise to achieve a specified SNR level for the signal of interest. This transform, AddNoise() is useful for simply adding a randomized @@ -100,6 +100,9 @@ class AddNoise(SignalTransform): * If int or float, target_snr is fixed at the value provided * If list, target_snr is any element in the list * If tuple, target_snr is in range of (tuple[0], tuple[1]) + + input_noise_floor_db (:obj:`float`): + The noise floor of the input data in dB linear (:obj:`bool`): If True, target_snr and signal_power is on linear scale not dB. @@ -107,18 +110,19 @@ class AddNoise(SignalTransform): Example: >>> import torchsig.transforms as ST >>> # Added AWGN power range is (-40, -20) dB - >>> transform = ST.AddNoiseTransform((-40, -20)) + >>> transform = ST.AddRandomNoiseTransform((-40, -20)) """ - def __init__( self, - noise_power_db : NumericParameter = uniform_continuous_distribution(-80, -60), + noise_power_db: NumericParameter = uniform_continuous_distribution(-80, -60), + input_noise_floor_db: float = 0.0, linear: Optional[bool] = False, **kwargs, ): super(AddNoise, self).__init__(**kwargs) - self.noise_power_db = to_distribution(noise_power_db) + self.noise_power_db = to_distribution(noise_power_db, self.random_generator) + self.input_noise_floor_db = input_noise_floor_db self.linear = linear def __call__(self, data: Any) -> Any: @@ -131,9 +135,17 @@ def __call__(self, data: Any) -> Any: signal_description=[], ) - # Apply data augmentation + # Retrieve random noise power value noise_power_db = self.noise_power_db() noise_power_db = 10*np.log10(noise_power_db) if self.linear else noise_power_db + + if self.input_noise_floor_db: + noise_floor = self.input_noise_floor_db + else: + # TODO: implement fast noise floor estimation technique? + noise_floor = 0 # Assumes 0dB noise floor + + # Apply data augmentation new_data.iq_data = F.awgn(data.iq_data, noise_power_db) # Update SignalDescription @@ -141,7 +153,7 @@ def __call__(self, data: Any) -> Any: signal_description = [data.signal_description] if isinstance(data.signal_description, SignalDescription) else data.signal_description for signal_desc in signal_description: new_signal_desc = deepcopy(signal_desc) - new_signal_desc.snr -= noise_power_db + new_signal_desc.snr = (new_signal_desc.snr - noise_power_db) if noise_power_db > noise_floor else new_signal_desc.snr new_signal_description.append(new_signal_desc) new_data.signal_description = new_signal_description diff --git a/torchsig/utils/visualize.py b/torchsig/utils/visualize.py index f81ec71..da3c78c 100644 --- a/torchsig/utils/visualize.py +++ b/torchsig/utils/visualize.py @@ -1,10 +1,11 @@ import pywt import numpy as np +from copy import deepcopy from scipy import ndimage from scipy import signal as sp from matplotlib import pyplot as plt from matplotlib.figure import Figure -from torch.utils.data import DataLoader +from torch.utils.data import dataloader from typing import Optional, Callable, Iterable, Union, Tuple, List @@ -24,7 +25,7 @@ class Visualizer: """ def __init__( self, - data_loader: DataLoader, + data_loader: dataloader, visualize_transform: Optional[Callable] = None, visualize_target_transform: Optional[Callable] = None ): @@ -278,6 +279,346 @@ def _visualize(self, data: np.ndarray, targets: np.ndarray) -> Figure: return figure + +class PSDVisualizer(Visualizer): + """ Visualize a PSD + + Args: + fft_size: + **kwargs: + """ + + def __init__(self, fft_size: int = 1024, **kwargs): + super(PSDVisualizer, self).__init__(**kwargs) + self.fft_size = fft_size + + def _visualize(self, iq_data: np.ndarray, targets: np.ndarray) -> Figure: + batch_size = iq_data.shape[0] + figure = plt.figure() + for sample_idx in range(batch_size): + plt.subplot(int(np.ceil(np.sqrt(batch_size))), + int(np.sqrt(batch_size)), sample_idx + 1) + Pxx, freqs = plt.psd(iq_data[sample_idx], NFFT=self.fft_size, Fs=1) + plt.xticks() + plt.yticks() + plt.title(str(targets[sample_idx])) + return figure + + +class MaskVisualizer(Visualizer): + """ Visualize data with mask label information overlaid + + Args: + **kwargs: + """ + def __init__(self, **kwargs): + super(MaskVisualizer, self).__init__(**kwargs) + + def __next__(self) -> Figure: + iq_data, targets = next(self.data_iter) + if self.visualize_transform: + iq_data = self.visualize_transform(deepcopy(iq_data)) + + if self.visualize_target_transform: + targets = self.visualize_target_transform(deepcopy(targets)) + else: + targets = None + + return self._visualize(iq_data, targets) + + def _visualize(self, data: np.ndarray, targets: np.ndarray) -> Figure: + batch_size = data.shape[0] + figure = plt.figure(frameon=False) + for sample_idx in range(batch_size): + plt.subplot(int(np.ceil(np.sqrt(batch_size))), + int(np.sqrt(batch_size)), sample_idx + 1) + extent = 0, data.shape[1], 0, data.shape[2] + data_img = plt.imshow( + data[sample_idx], + vmin=np.min(data[sample_idx]), + vmax=np.max(data[sample_idx]), + cmap="jet", + extent=extent, + ) + if targets is not None: + label = targets[sample_idx] + label_img = plt.imshow( + label, + vmin=np.min(label), + vmax=np.max(label), + cmap="gray", + alpha=0.5, + interpolation="none", + extent=extent, + ) + plt.xticks([]) + plt.yticks([]) + plt.title("Data") + + return figure + + +class MaskClassVisualizer(Visualizer): + """ + Visualize data with mask label information overlaid and the class of the + mask included in the title + + Args: + **kwargs: + """ + def __init__(self, class_list, **kwargs): + super(MaskClassVisualizer, self).__init__(**kwargs) + self.class_list = class_list + + def __next__(self) -> Figure: + iq_data, targets = next(self.data_iter) + if self.visualize_transform: + iq_data = self.visualize_transform(deepcopy(iq_data)) + + if self.visualize_target_transform: + classes, targets = self.visualize_target_transform(deepcopy(targets)) + else: + targets = None + + return self._visualize(iq_data, targets, classes) + + def _visualize(self, data: np.ndarray, targets: np.ndarray, classes: List) -> Figure: + batch_size = data.shape[0] + figure = plt.figure(frameon=False) + for sample_idx in range(batch_size): + plt.subplot(int(np.ceil(np.sqrt(batch_size))), + int(np.sqrt(batch_size)), sample_idx + 1) + extent = 0, data.shape[1], 0, data.shape[2] + data_img = plt.imshow( + data[sample_idx], + vmin=np.min(data[sample_idx]), + vmax=np.max(data[sample_idx]), + cmap="jet", + extent=extent, + ) + title = [] + if targets is not None: + class_idx = classes[sample_idx] + mask = targets[sample_idx] + mask_img = plt.imshow( + mask, + vmin=np.min(mask), + vmax=np.max(mask), + cmap="gray", + alpha=0.5, + interpolation="none", + extent=extent, + ) + title = [self.class_list[idx] for idx in class_idx] + else: + title = "Data" + plt.xticks([]) + plt.yticks([]) + plt.title(title) + + return figure + + +class SemanticMaskClassVisualizer(Visualizer): + """ + Visualize data with mask label information overlaid and the class of the + mask included in the title + + Args: + **kwargs: + """ + def __init__(self, class_list, **kwargs): + super(SemanticMaskClassVisualizer, self).__init__(**kwargs) + self.class_list = class_list + + def __next__(self) -> Figure: + iq_data, targets = next(self.data_iter) + if self.visualize_transform: + iq_data = self.visualize_transform(deepcopy(iq_data)) + + if self.visualize_target_transform: + targets = self.visualize_target_transform(deepcopy(targets)) + + return self._visualize(iq_data, targets) + + def _visualize(self, data: np.ndarray, targets: np.ndarray) -> Figure: + batch_size = data.shape[0] + figure = plt.figure(frameon=False) + for sample_idx in range(batch_size): + plt.subplot(int(np.ceil(np.sqrt(batch_size))), + int(np.sqrt(batch_size)), sample_idx + 1) + extent = 0, data.shape[1], 0, data.shape[2] + data_img = plt.imshow( + data[sample_idx], + vmin=np.min(data[sample_idx]), + vmax=np.max(data[sample_idx]), + cmap="jet", + extent=extent, + ) + title = [] + if targets is not None: + mask = np.ma.masked_where(targets[sample_idx] < 1, targets[sample_idx]) + mask_img = plt.imshow( + mask, + alpha=0.5, + interpolation="none", + extent=extent, + ) + classes_present = list(set(targets[sample_idx].flatten().tolist())) + classes_present.remove(0.0) # Remove 'background' class + title = [self.class_list[int(class_idx-1)] for class_idx in classes_present] + else: + title = "Data" + plt.xticks([]) + plt.yticks([]) + plt.title(title) + + return figure + + +class BoundingBoxVisualizer(Visualizer): + """ Visualize data with bounding box label information overlaid + + Args: + **kwargs: + """ + def __init__(self, **kwargs): + super(BoundingBoxVisualizer, self).__init__(**kwargs) + + def __next__(self) -> Figure: + iq_data, targets = next(self.data_iter) + + if self.visualize_transform: + iq_data = self.visualize_transform(deepcopy(iq_data)) + + if self.visualize_target_transform: + targets = self.visualize_target_transform(deepcopy(targets)) + else: + targets = targets + + return self._visualize(iq_data, targets) + + def _visualize(self, data: np.ndarray, targets: np.ndarray) -> Figure: + batch_size = data.shape[0] + figure = plt.figure(frameon=False) + for sample_idx in range(batch_size): + ax = plt.subplot(int(np.ceil(np.sqrt(batch_size))), + int(np.sqrt(batch_size)), sample_idx + 1) + + # Retrieve individual label + ax.imshow( + data[sample_idx], + vmin=np.min(data[sample_idx]), + vmax=np.max(data[sample_idx]), + cmap="jet", + ) + label = targets[sample_idx] + pixels_per_cell_x = data[sample_idx].shape[0] / label.shape[0] + pixels_per_cell_y = data[sample_idx].shape[1] / label.shape[1] + + for grid_cell_x_idx in range(label.shape[0]): + for grid_cell_y_idx in range(label.shape[1]): + if label[grid_cell_x_idx, grid_cell_y_idx, 0] == 1: + duration = label[grid_cell_x_idx, grid_cell_y_idx, 2]*data[sample_idx].shape[0] + bandwidth = label[grid_cell_x_idx, grid_cell_y_idx, 4]*data[sample_idx].shape[1] + start_pixel = (grid_cell_x_idx*pixels_per_cell_x) + (label[grid_cell_x_idx, grid_cell_y_idx, 1]*pixels_per_cell_x) - duration/2 + low_freq = (grid_cell_y_idx*pixels_per_cell_y) + (label[grid_cell_x_idx, grid_cell_y_idx, 3]*pixels_per_cell_y) \ + - (label[grid_cell_x_idx, grid_cell_y_idx, 4]/2 * data[sample_idx].shape[1]) + + rect = patches.Rectangle( + (start_pixel,low_freq), + duration, + bandwidth, # Bandwidth (pixels) + linewidth=3, + edgecolor='b', + facecolor='none' + ) + ax.add_patch(rect) + plt.imshow(data[sample_idx], aspect='auto', cmap="jet",vmin=np.min(data[sample_idx]),vmax=np.max(data[sample_idx])) + plt.xticks([]) + plt.yticks([]) + plt.title("Data") + + return figure + + +class AnchorBoxVisualizer(Visualizer): + """ Visualize data with anchor box label information overlaid + + Args: + **kwargs: + """ + def __init__( + self, + data_loader: dataloader, + visualize_transform: Optional[Callable] = None, + visualize_target_transform: Optional[Callable] = None, + anchor_boxes: List = None, + ): + self.data_loader = iter(data_loader) + self.visualize_transform = visualize_transform + self.visualize_target_transform = visualize_target_transform + self.anchor_boxes = anchor_boxes + self.num_anchor_boxes = len(anchor_boxes) + + def __next__(self) -> Figure: + iq_data, targets = next(self.data_iter) + + if self.visualize_transform: + iq_data = self.visualize_transform(deepcopy(iq_data)) + + if self.visualize_target_transform: + targets = self.visualize_target_transform(deepcopy(targets)) + else: + targets = targets + + return self._visualize(iq_data, targets) + + def _visualize(self, data: np.ndarray, targets: np.ndarray) -> Figure: + batch_size = data.shape[0] + figure = plt.figure(frameon=False) + for sample_idx in range(batch_size): + ax = plt.subplot(int(np.ceil(np.sqrt(batch_size))), + int(np.sqrt(batch_size)), sample_idx + 1) + + # Retrieve individual label + ax.imshow( + data[sample_idx], + vmin=np.min(data[sample_idx]), + vmax=np.max(data[sample_idx]), + cmap="jet", + ) + label = targets[sample_idx] + pixels_per_cell_x = data[sample_idx].shape[0] / label.shape[0] + pixels_per_cell_y = data[sample_idx].shape[1] / label.shape[1] + + for grid_cell_x_idx in range(label.shape[0]): + for grid_cell_y_idx in range(label.shape[1]): + for anchor_idx in range(self.num_anchor_boxes): + if label[grid_cell_x_idx, grid_cell_y_idx, 0+5*anchor_idx] == 1: + duration = label[grid_cell_x_idx, grid_cell_y_idx, 2+5*anchor_idx]*self.anchor_boxes[anchor_idx][0]*data[sample_idx].shape[0] + bandwidth = label[grid_cell_x_idx, grid_cell_y_idx, 4+5*anchor_idx]*self.anchor_boxes[anchor_idx][1]*data[sample_idx].shape[1] + start_pixel = (grid_cell_x_idx*pixels_per_cell_x) + (label[grid_cell_x_idx, grid_cell_y_idx, 1+5*anchor_idx]*pixels_per_cell_x) - duration/2 + low_freq = (grid_cell_y_idx*pixels_per_cell_y) + (label[grid_cell_x_idx, grid_cell_y_idx, 3+5*anchor_idx]*pixels_per_cell_y) \ + - (label[grid_cell_x_idx, grid_cell_y_idx, 4+5*anchor_idx]*self.anchor_boxes[anchor_idx][1]/2 * data[sample_idx].shape[1]) + + rect = patches.Rectangle( + (start_pixel,low_freq), + duration, + bandwidth, # Bandwidth (pixels) + linewidth=3, + edgecolor='b', + facecolor='none' + ) + ax.add_patch(rect) + + plt.imshow(data[sample_idx], aspect='auto', cmap="jet",vmin=np.min(data[sample_idx]),vmax=np.max(data[sample_idx])) + plt.xticks([]) + plt.yticks([]) + plt.title("Data") + + return figure + ############################################################################### # Visualizer Transform Functions @@ -358,7 +699,7 @@ def onehot_label_format(tensor: np.ndarray) -> List[str]: return label -def multihot_label_format(tensor: np.ndarray, class_list: List[str]) -> List[List[str]]: +def multihot_label_format(tensor: np.ndarray, class_list: List[str]) -> List[str]: """Target Transform: Format multihot labels for titles in visualizer """ @@ -371,3 +712,87 @@ def multihot_label_format(tensor: np.ndarray, class_list: List[str]) -> List[Lis curr_label.append(class_list[class_idx]) label.append(curr_label) return label + + +def mask_to_outline(tensor: np.ndarray) -> List[str]: + """Target Transform: Transforms masks for all bursts to outlines for the + MaskVisualizer. Overlapping mask outlines are represented as a single + polygon. + + """ + batch_size = tensor.shape[0] + labels = [] + struct = ndimage.generate_binary_structure(2,2) + for idx in range(batch_size): + label = tensor[idx].numpy() + label = np.sum(label, axis=0) + label[label>0] = 1 + label = label - ndimage.binary_erosion(label) + label = ndimage.binary_dilation(label, structure=struct, iterations=3).astype(label.dtype) + label = np.ma.masked_where(label == 0, label) + labels.append(label) + return labels + + +def mask_to_outline_overlap(tensor: np.ndarray) -> List[str]: + """Target Transform: Transforms masks for each burst to individual outlines + for the MaskVisualizer. Overlapping mask outlines are still shown as + overlapping. + + """ + batch_size = tensor.shape[0] + labels = [] + struct = ndimage.generate_binary_structure(2,2) + for idx in range(batch_size): + label = tensor[idx].numpy() + for individual_burst_idx in range(label.shape[0]): + label[individual_burst_idx] = label[individual_burst_idx] - \ + ndimage.binary_erosion(label[individual_burst_idx]) + label = np.sum(label, axis=0) + label[label>0] = 1 + label = ndimage.binary_dilation(label, structure=struct, iterations=2).astype(label.dtype) + label = np.ma.masked_where(label == 0, label) + labels.append(label) + return labels + + +def overlay_mask(tensor: np.ndarray) -> List[str]: + """Target Transform: Transforms multi-dimensional mask to binary overlay of + full mask. + + """ + batch_size = tensor.shape[0] + labels = [] + for idx in range(batch_size): + label = torch.sum(tensor[idx], axis=0).numpy() + label[label>0] = 1 + label = np.ma.masked_where(label == 0, label) + labels.append(label) + return labels + + +def mask_class_to_outline(tensor: np.ndarray) -> List[str]: + """Target Transform: Transforms masks for each burst to individual outlines + for the MaskClassVisualizer. Overlapping mask outlines are still shown as + overlapping. Each bursts' class index is also returned. + + """ + batch_size = tensor.shape[0] + labels = [] + class_idx = [] + struct = ndimage.generate_binary_structure(2,2) + for idx in range(batch_size): + label = tensor[idx].numpy() + class_idx_curr = [] + for individual_burst_idx in range(label.shape[0]): + if np.count_nonzero(label[individual_burst_idx]) > 0: + class_idx_curr.append(individual_burst_idx) + label[individual_burst_idx] = label[individual_burst_idx] - \ + ndimage.binary_erosion(label[individual_burst_idx]) + label = np.sum(label, axis=0) + label[label>0] = 1 + label = ndimage.binary_dilation(label, structure=struct, iterations=2).astype(label.dtype) + label = np.ma.masked_where(label == 0, label) + class_idx.append(class_idx_curr) + labels.append(label) + return class_idx, labels