# TorchSig Iterable Dataset Example
This notebook shows how to create, save, reload, and inspect a `TorchSigIterableDataset`.

---

In [None]:
# Define Variables

num_iq_samples_dataset = 4096 # 64^2
fft_size = 64
num_signals_max = 5
num_signals_min = 1

## Dataset Metadata
In order to create a TorchSigIterableDataset, you must define its parameters in DatasetMetadata, either in code or in a YAML file. The cells below show two in-code options (a `DatasetMetadata` object and a plain dictionary); see `create_dataset_example.yaml` for a sample YAML file.

There are three required parameters:
1. `num_iq_samples_dataset`: number of IQ samples in each dataset sample.
2. `fft_size`: size of the FFT (number of bins) used in the spectrogram.
3. `num_signals_max`: maximum number of signals per sample.

Additionally, there are several optional parameters that can be overridden.
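
As a rough sketch of how the first two parameters relate: if the spectrogram is built from non-overlapping FFT frames (an assumption; the library's actual windowing and overlap may differ), then choosing `num_iq_samples_dataset = fft_size ** 2` gives a square image. The helper name `spectrogram_shape` is hypothetical and not part of TorchSig.

```python
from typing import Tuple

def spectrogram_shape(num_iq_samples: int, fft_size: int) -> Tuple[int, int]:
    """Return (frequency_bins, time_frames) for non-overlapping FFT frames."""
    if fft_size <= 0:
        raise ValueError("fft_size must be positive")
    # Each frame consumes fft_size samples; leftover samples are dropped.
    time_frames = num_iq_samples // fft_size
    return fft_size, time_frames

demo_shape = spectrogram_shape(num_iq_samples=4096, fft_size=64)
print(demo_shape)
```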

In [None]:
# Option 1: Instantiate DatasetMetadata object
from torchsig.datasets.dataset_metadata import DatasetMetadata

dataset_metadata_1 = DatasetMetadata(
    num_iq_samples_dataset = num_iq_samples_dataset,
    fft_size = fft_size,
    num_signals_max = num_signals_max,
    num_signals_min = num_signals_min,
)
print(dataset_metadata_1)

In [None]:
# Option 2: Instantiate as a dictionary object

dataset_metadata_2 = dict(
    num_iq_samples_dataset = num_iq_samples_dataset,
    fft_size = fft_size,
    num_signals_max = num_signals_max,
    num_signals_min = num_signals_min,
    impairment_level = 0  # Added required parameter
)
print(dataset_metadata_2)

## Synthetic Dataset
To create a new dataset, simply instantiate a TorchSigIterableDataset object.

In [None]:
from torchsig.datasets.dataset_metadata import DatasetMetadata
from torchsig.datasets.datasets import TorchSigIterableDataset

dataset_infinite_metadata = DatasetMetadata(
    num_iq_samples_dataset = num_iq_samples_dataset,
    fft_size = fft_size,
    num_signals_max = num_signals_max,
    num_signals_min = num_signals_min
)

# Without target_labels, returns Signal objects with rich metadata
dataset_infinite = TorchSigIterableDataset(
    dataset_metadata = dataset_infinite_metadata
)

for i in range(5):
    signal = next(dataset_infinite)
    
    print(f"IQ Data shape: {signal.data.shape}")
    print(f"Component Signals: {len(signal.component_signals)}")
    
    # Access metadata from component signals
    for j, comp_signal in enumerate(signal.component_signals):
        print(f"  Signal {j}: {comp_signal.metadata.class_name}, SNR: {comp_signal.metadata.snr_db:.1f}dB")
    print()

### Time Domain Plot
The default data returned from datasets are IQ samples.
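
A minimal sketch of what "IQ samples" means, using only the standard library: each sample is a complex number whose real part is the in-phase (I) component and whose imaginary part is the quadrature (Q) component, and sample `n` occurs at time `n / sample_rate`. The tone frequency and sample rate below are illustrative values, not TorchSig defaults.

```python
import cmath
import math

demo_sample_rate = 1_000.0  # Hz, illustrative value
demo_tone_freq = 100.0      # Hz, illustrative value
demo_num_samples = 8

# Complex exponential: I = cos(2*pi*f*t), Q = sin(2*pi*f*t)
demo_iq = [
    cmath.exp(2j * math.pi * demo_tone_freq * n / demo_sample_rate)
    for n in range(demo_num_samples)
]
# Time stamp of each sample in seconds
demo_t = [n / demo_sample_rate for n in range(demo_num_samples)]

for ti, s in zip(demo_t, demo_iq):
    print(f"t={ti:.3f}s  I={s.real:+.3f}  Q={s.imag:+.3f}")
```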

In [None]:
import matplotlib.pyplot as plt
import numpy as np

dataset_finite_time_series_metadata = DatasetMetadata(
    num_iq_samples_dataset = num_iq_samples_dataset,
    fft_size = fft_size,
    num_signals_max = num_signals_max,
    num_signals_min = num_signals_min
)

# With target_labels, returns (data, metadata) tuples
dataset_finite_time_series = TorchSigIterableDataset(
    dataset_metadata = dataset_finite_time_series_metadata,
    target_labels=["class_index"]
)

data, metadata = next(dataset_finite_time_series)
t = np.arange(0,len(data))/dataset_finite_time_series_metadata.sample_rate

fig = plt.figure(figsize=(12,4))
ax = fig.add_subplot(1,1,1)
ax.plot(t,np.real(data),alpha=0.5,label='Real')
ax.plot(t,np.imag(data),alpha=0.5,label='Imag')
ax.set_xlim([t[0], t[-1]])
ax.set_xlabel('Time (sec)')
ax.set_ylabel('Amplitude')
ax.set_title('Time Domain')
ax.grid()
ax.legend()  # display the 'Real' and 'Imag' labels set above

### Spectrogram Plot
Spectrograms can be generated by passing the `Spectrogram()` transform to the dataset through its `transforms` argument. For more detailed code and information on spectrogram plots, see the following section.
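
Conceptually, a spectrogram splits the IQ stream into frames, takes a DFT of each frame, and stores the power in each frequency bin. The sketch below does this with a naive pure-Python DFT on non-overlapping frames. It illustrates the idea only and is not how the `Spectrogram` transform is implemented (windowing, overlap, and scaling are omitted); the function names are hypothetical.

```python
import cmath
import math

def naive_dft(frame):
    """Naive O(N^2) discrete Fourier transform of one frame."""
    n = len(frame)
    return [
        sum(x * cmath.exp(-2j * math.pi * k * i / n) for i, x in enumerate(frame))
        for k in range(n)
    ]

def power_spectrogram(iq, frame_size):
    """Return a list of frames, each a list of per-bin power values."""
    num_frames = len(iq) // frame_size
    frames = (iq[f * frame_size:(f + 1) * frame_size] for f in range(num_frames))
    return [[abs(c) ** 2 for c in naive_dft(frame)] for frame in frames]

# A complex tone that falls exactly on bin 2 of an 8-point DFT
demo_fft_size = 8
demo_tone = [
    cmath.exp(2j * math.pi * 2 * n / demo_fft_size)
    for n in range(4 * demo_fft_size)
]
demo_spec = power_spectrogram(demo_tone, demo_fft_size)

# Index of the strongest bin in each time frame
peak_bins = [max(range(demo_fft_size), key=frame.__getitem__) for frame in demo_spec]
print(peak_bins)
```

Note that this sketch stores the result frame by frame (time-major), whereas the plot in this section puts frequency on the vertical axis.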

In [None]:
from torchsig.transforms.transforms import Spectrogram

dataset_finite_spectrogram_metadata = DatasetMetadata(
    num_iq_samples_dataset = num_iq_samples_dataset,
    fft_size = fft_size,
    num_signals_max = num_signals_max,
    num_signals_min = num_signals_min
)

dataset_finite_spectrogram = TorchSigIterableDataset(
    dataset_metadata = dataset_finite_spectrogram_metadata,
    target_labels=["class_index"],
    transforms = [Spectrogram(fft_size=fft_size)]
)

data, metadata = next(dataset_finite_spectrogram)

fig = plt.figure(figsize=(6,6))
ax = fig.add_subplot(1,1,1)
ax.imshow(data,aspect='auto',cmap='Wistia',vmin=0)
ax.set_xlabel('Time Axis')
ax.set_ylabel('Frequency Axis')

### Writing Dataset to Disk
To save a finite dataset for later use, or to access previously generated examples, use the `DatasetCreator`. Pass in the dataloader that wraps the dataset, the directory to write to (`root`), and whether to overwrite an existing dataset (`overwrite`). `dataset_length` must be defined; otherwise the `DatasetCreator` will attempt to create an infinite dataset.

**Note:** The new DatasetCreator API is simpler and uses HDF5 storage only.
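
Why a length is needed can be illustrated with a plain Python generator: an endless iterable never raises `StopIteration` by itself, so whatever consumes it has to decide when to stop. The sketch below uses `itertools.islice` as a stand-in for that cap; `fake_infinite_dataset` is a hypothetical placeholder, not a TorchSig object, and this is not how `DatasetCreator` is implemented.

```python
import itertools

def fake_infinite_dataset():
    """Yield sample indices forever, mimicking an endless iterable dataset."""
    index = 0
    while True:
        yield index
        index += 1

demo_length = 5  # number of samples to keep
finite_samples = list(itertools.islice(fake_infinite_dataset(), demo_length))
print(finite_samples)
```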

In [None]:
from torchsig.datasets.dataset_metadata import DatasetMetadata
from torchsig.utils.writer import DatasetCreator, default_collate_fn
from torchsig.signals.signal_lists import TorchSigSignalLists
from torchsig.datasets.datasets import TorchSigIterableDataset
from torchsig.utils.data_loading import WorkerSeedingDataLoader
from torchsig.transforms.transforms import Spectrogram

root = "./datasets/create_dataset_example"
class_list = TorchSigSignalLists.all_signals
dataset_length = len(class_list) * 10
seed = 123456789

dataset_finite_metadata = DatasetMetadata(
    num_iq_samples_dataset = num_iq_samples_dataset,
    fft_size = fft_size,
    num_signals_max = num_signals_max,
    num_signals_min = num_signals_min,
)

# target_labels selects which metadata fields are returned (and written) alongside the data
dataset = TorchSigIterableDataset(
    dataset_metadata = dataset_finite_metadata,
    transforms = [Spectrogram(fft_size = fft_size)],
    target_labels = ["class_name", "start", "stop", "lower_freq", "upper_freq", "snr_db"],
)

dataloader = WorkerSeedingDataLoader(dataset, batch_size=11, num_workers=4, collate_fn=default_collate_fn)
dataloader.seed(seed)

# New simplified DatasetCreator API
dataset_creator = DatasetCreator(
    dataset_length=dataset_length,
    dataloader = dataloader,
    root = root,
    overwrite = True,
    multithreading=False
)

dataset_creator.create()

### Reading Dataset from Disk
Assuming you wrote a dataset to disk, you can load it back in by instantiating a `StaticTorchSigDataset`.

Warning: The following code assumes you have run the code in the section **Writing Dataset to Disk**.

In [None]:
from torchsig.datasets.datasets import StaticTorchSigDataset

static_dataset = StaticTorchSigDataset(
    root = root,
    target_labels=dataset.target_labels
)

# can access any sample
print(static_dataset[0][1])
print(static_dataset[5])

If the dataset written to disk is raw (that is, no transforms or target transforms were applied before writing), you can define any transforms and target transforms for the static dataset to apply when loading.
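
As a generic illustration of applying transforms at load time, the sketch below feeds a sample through a list of callables in order. It assumes transforms are plain callables applied in list order; `apply_transforms`, `scale_by_two`, and `to_magnitude` are hypothetical stand-ins, not TorchSig transforms.

```python
def apply_transforms(sample, transforms):
    """Apply each callable in order, feeding each result into the next."""
    for transform in transforms:
        sample = transform(sample)
    return sample

# Hypothetical stand-ins for real transforms
def scale_by_two(values):
    return [2 * v for v in values]

def to_magnitude(values):
    return [abs(v) for v in values]

raw_sample = [1 + 1j, -3 + 0j, 0 - 2j]
processed = apply_transforms(raw_sample, [scale_by_two, to_magnitude])
print(processed)
```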

In [None]:
static_dataset[10]

In [None]:
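def update_spectrogram(ax, start_time, stop_time, lower, upper):
    """Draw a labeled bounding box for one signal on a spectrogram axis.

    Note: reads classname, snr, xmin, xmax, ymin, ymax, and xunits from the
    notebook's global scope, so those must be defined before calling.
    """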
    start_time = float(start_time)
    stop_time = float(stop_time)
    lower = float(lower)
    upper = float(upper)
    ax.plot([start_time,start_time],[lower,upper],'b',alpha=0.5)
    ax.plot([stop_time, stop_time],[lower,upper],'b',alpha=0.5)
    ax.plot([start_time,stop_time],[lower,lower],'b',alpha=0.5)
    ax.plot([start_time,stop_time],[upper,upper],'b',alpha=0.5)
    textDisplay = str(classname) + ', SNR = ' + str(snr) + ' dB'
    ax.text(start_time,lower,textDisplay, bbox=dict(facecolor='w', alpha=0.5, linewidth=0))
    ax.set_xlim([xmin,xmax])
    ax.set_ylim([ymin,ymax])
    ax.set_ylabel("Frequency (Hz)")
    ax.set_xlabel(f"Time ({xunits})")

In [None]:
from torchsig.datasets.datasets import StaticTorchSigDataset
import matplotlib.pyplot as plt

static_dataset = StaticTorchSigDataset(
    root = root,
    target_labels=dataset.target_labels
)
sample_rate = dataset.dataset_metadata.sample_rate
noise_power_db = dataset.dataset_metadata.noise_power_db

num_show = 4
num_cols = 2
num_rows = num_show // num_cols

fig = plt.figure(figsize=(18,12))
fig.tight_layout()

max_time = num_iq_samples_dataset / sample_rate

ns_time = 1e-9
us_time = 1e-6
ms_time = 1e-3
s_time = 1

if (max_time < ns_time):
    max_time /= ns_time
    xunits = 'ns'
elif (max_time < us_time):
    max_time /= us_time
    xunits = 'us'
elif (max_time < ms_time):
    max_time /= ms_time
    xunits = 'ms'
else:
    max_time /= s_time
    xunits = 's'

xmin = 0 * max_time
xmax = 1 * max_time

ymin = -sample_rate / 2
ymax = sample_rate / 2

for i in range(num_show):

    data, targets = static_dataset[i]
    print(targets)
    ax = fig.add_subplot(num_rows,num_cols,i + 1)
    pos = ax.imshow(data,extent=[xmin,xmax,ymin,ymax],aspect='auto',cmap='Wistia',vmin=noise_power_db)
    fig.colorbar(pos, ax=ax, label='SNR')

    for t in targets:
        classname, start, stop, lower, upper, snr = t

        # convert normalized time into real-world time
        start_time = float(start) * max_time
        stop_time = float(stop) * max_time

        update_spectrogram( ax, start_time, stop_time, lower, upper)
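
The time conversion in the loop above can be isolated into a small helper. This sketch assumes, as the comment in the loop does, that `start` and `stop` are normalized to the range 0 to 1 over the duration of one sample; `to_seconds` is a hypothetical name and the sample rate is an illustrative value.

```python
def to_seconds(normalized, num_iq_samples, sample_rate):
    """Map a normalized position (0..1) to seconds within one sample."""
    duration = num_iq_samples / sample_rate  # total duration in seconds
    return float(normalized) * duration

# Illustrative values: 4096 samples at 1 MHz last 4.096 ms
demo_start = to_seconds(0.25, num_iq_samples=4096, sample_rate=1e6)
demo_stop = to_seconds(0.75, num_iq_samples=4096, sample_rate=1e6)
print(f"{demo_start * 1e3:.3f} ms to {demo_stop * 1e3:.3f} ms")
```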

## Dataset Statistics

Below are some plots and statistics about the dataset.
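
The counting pattern used in the next cell can be sketched on toy data with `collections.Counter`. The target tuples below are made up for illustration and follow the same field order as the `target_labels` used when writing the dataset (class name, start, stop, lower frequency, upper frequency, SNR).

```python
from collections import Counter

# Hypothetical targets for two samples; each tuple is
# (class_name, start, stop, lower_freq, upper_freq, snr_db)
toy_targets = [
    [("qpsk", 0.0, 0.5, -1e3, 1e3, 10.0), ("bpsk", 0.2, 0.9, 2e3, 3e3, 5.0)],
    [("qpsk", 0.1, 1.0, -2e3, -1e3, 20.0)],
]

# Count class occurrences across every signal in every sample
toy_class_counts = Counter(t[0] for sample in toy_targets for t in sample)
# Number of signals in each sample
toy_signals_per_sample = [len(sample) for sample in toy_targets]
# Flat list of SNR values, one per signal
toy_snrs = [t[5] for sample in toy_targets for t in sample]

print(toy_class_counts)
print(toy_signals_per_sample, sum(toy_snrs) / len(toy_snrs))
```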

In [None]:
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

class_list = dataset.dataset_metadata.class_list
class_counter = {class_name: 0 for class_name in class_list}
snr_list = []
num_signals_per_sample = []

for sample in tqdm(static_dataset, desc = "Calculating Dataset Stats"):
    data, targets = sample

    if (isinstance(targets,tuple)):
        num_signals_per_sample.append(1)
        t = targets
        classname, start, stop, lower, upper, snr = t
        class_counter[classname] += 1
        snr_list.append(snr)
    elif (isinstance(targets,list)):
        num_signals_per_sample.append(len(targets))
        for t in targets:
            classname, start, stop, lower, upper, snr = t
            class_counter[classname] += 1
            snr_list.append(snr)

    

class_counts = list(class_counter.values())
class_names = list(class_counter.keys())


In [None]:
# Class Distribution Pie Chart
# by default, the class distribution is None aka uniform
print(f"Class Distribution Setting: {dataset.dataset_metadata.class_distribution}")
plt.figure(figsize=(15, 15))
plt.pie(class_counts, labels = class_names)
plt.title("Class Distribution Pie Chart")

In [None]:
# Class Distribution Bar Chart
plt.figure(figsize=(18, 4))
plt.bar(class_names, class_counts)
plt.xticks(rotation=90)
plt.title("Class Distribution Bar Chart")
plt.xlabel("Modulation Class Name")
plt.ylabel("Counts")

In [None]:
# Number of signals per sample Distribution
import numpy as np
# Default number of signals per sample is 3
print(f"Min num signals Setting: {dataset.dataset_metadata.num_signals_min}")
print(f"Max num signals Setting: {dataset.dataset_metadata.num_signals_max}")

total = sum(num_signals_per_sample)
avg = np.mean(np.asarray(num_signals_per_sample))

plt.figure(figsize=(11,8))
plt.hist(x=num_signals_per_sample, bins=np.arange(max(num_signals_per_sample) + 1) + 0.5)
plt.xticks(np.arange(0, max(num_signals_per_sample) + 1, 1).tolist())
plt.title(f"Distribution of Number of Signals Per Sample\nTotal Number: {total}, Average: {avg:.2f}")
plt.xlabel("Number of Signals per Sample")
plt.ylabel("Counts")

In [None]:
# SNR Distributions
# Default min = 0 dB and max = 50 dB
print(f"Min SNR Setting: {dataset.dataset_metadata.snr_db_min}")
print(f"Max SNR Setting: {dataset.dataset_metadata.snr_db_max}")
plt.figure(figsize=(11, 4))
plt.hist(x=snr_list, bins=100)
plt.title("SNR Distribution")
plt.xlabel("SNR Bins (dB)")
plt.ylabel("Counts")

## Applying Impairments to Datasets
To add realism to synthetic datasets, we've created sets of transforms that simulate real-world signal impairments on the data. These can be found in `transforms/impairments.py` and come in multiple impairment levels, which determine how heavily the signal is impaired. More on impairments can be found in the `examples/transforms/impairments` notebook.

In the example below, we are applying level 1 impairments to a dataset.
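
Both datasets in the cells below are seeded with the same value so that the two streams are comparable. The idea can be sketched with Python's `random` module. This is an analogy only, not TorchSig's seeding mechanism, and whether the impaired and unimpaired datasets draw exactly the same underlying signals depends on how the library consumes random numbers.

```python
import random

demo_seed = 123456789

# Two independent generators seeded identically produce identical streams
rng_a = random.Random(demo_seed)
rng_b = random.Random(demo_seed)

stream_a = [rng_a.random() for _ in range(3)]
stream_b = [rng_b.random() for _ in range(3)]
print(stream_a == stream_b)
```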

In [None]:
from torchsig.transforms.impairments import Impairments

In [None]:
impairments = Impairments(level=1)
burst_impairments = impairments.signal_transforms
whole_signal_impairments = impairments.dataset_transforms

burst_impairments, whole_signal_impairments

In [None]:
dataset_impaired = TorchSigIterableDataset(
    dataset_metadata=dataset_finite_metadata, 
    transforms=[whole_signal_impairments, Spectrogram(fft_size=fft_size)], 
    component_transforms=[burst_impairments], 
    target_labels=[])

dataset_unimpaired = TorchSigIterableDataset(
    dataset_metadata=dataset_finite_metadata, 
    transforms=[Spectrogram(fft_size=fft_size)], 
    component_transforms=[], 
    target_labels=[])

dataset_impaired.seed(seed)
dataset_unimpaired.seed(seed)

In [None]:

fig = plt.figure(figsize=(12,12))
ax = fig.add_subplot(1,2,1)
ax.imshow(next(dataset_unimpaired),cmap='Wistia',vmin=0)
ax.set_xlabel('Time Axis')
ax.set_ylabel('Frequency Axis')
ax.set_title('Un-Impaired Data')

ax2 = fig.add_subplot(1,2,2)
ax2.imshow(next(dataset_impaired),cmap='Wistia',vmin=0)
ax2.set_xlabel('Time Axis')
ax2.set_ylabel('Frequency Axis')
ax2.set_title('Impaired Data')