# TorchSig Iterable Dataset Example
This notebook showcases the `TorchSigIterableDataset`.

---

In [None]:
# Define Variables
num_signals_max = 5
num_signals_min = 1

## Dataset Metadata
To create a TorchSigIterableDataset, you must define parameters for any signals included in the dataset. Different signal types may require different parameters, but most will require a sample_rate, bandwidth_min/bandwidth_max, and frequency_min/frequency_max.
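
Because many of these parameters come in min/max pairs, it can help to sanity-check a metadata dictionary before building a dataset. The sketch below is only an illustration: `find_bad_ranges` is a hypothetical helper, not part of TorchSig, the values are made up, and it assumes the metadata behaves like a plain `dict`.

```python
# Hypothetical helper (not a TorchSig API): checks "<name>_min"/"<name>_max" pairs.
def find_bad_ranges(metadata):
    """Return the base names whose *_min value exceeds the matching *_max value."""
    bad = []
    for key, low in metadata.items():
        if not key.endswith("_min"):
            continue
        base = key[: -len("_min")]
        high = metadata.get(base + "_max")
        if high is not None and low > high:
            bad.append(base)
    return bad


# Illustrative values only; real defaults come from the library.
example_metadata = {
    "sample_rate": 10e6,
    "num_signals_min": 1,
    "num_signals_max": 5,
    "bandwidth_min": 0.5e6,
    "bandwidth_max": 0.1e6,  # deliberately inverted to trigger the check
}

print(find_bad_ranges(example_metadata))
```

With the inverted bandwidth pair above, the helper reports `bandwidth` as the only problem.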

Here we import the default parameters from TorchSig's defaults module. They are intended to let users quickly generate reasonable-looking signal data, and should correspond to the default parameters used in previous TorchSig versions.

In [None]:
from torchsig.utils.defaults import TorchSigDefaults

dataset_metadata = TorchSigDefaults().default_dataset_metadata
dataset_metadata["num_signals_min"] = num_signals_min
dataset_metadata["num_signals_max"] = num_signals_max
print(dataset_metadata)

## Synthetic Dataset
To create a new dataset, simply instantiate a TorchSigIterableDataset object.

In [None]:
from matplotlib import pyplot as plt
from torchsig.datasets.datasets import TorchSigIterableDataset
from torchsig.transforms.transforms import Spectrogram

# Without target_labels, returns Signal objects with rich metadata
dataset = TorchSigIterableDataset(
    metadata=dataset_metadata,
    transforms=[Spectrogram(fft_size=dataset_metadata["fft_size"])],
)

In [None]:
# Draw one sample from the iterable dataset and show its spectrogram
plt.imshow(next(dataset).data)

The resulting Dataset has several SignalGenerators attached to it by default, each responsible for creating data for different signal types.
Right now these generators all share the same metadata, but they can also be modified individually as needed.
For example:

In [None]:
for generator in dataset.signal_generators:
    if "ask" in generator["class_name"]:
        generator["signal_duration_max"] = 240000  # lower than default max duration
        generator["anomalous"] = True
    else:
        generator["anomalous"] = False

In [None]:
# Draw one sample and list the 'anomalous' flag of each component signal
[signal.anomalous for signal in next(dataset).component_signals]

Here we've shortened the maximum duration for the ASK signals in our dataset, so that they will on average be shorter than other signals.
We've also set a new metadata variable 'anomalous', which is True for our ASK signals and False otherwise.
This naturally sets us up for classification or anomaly detection models.
If we wanted, we could add other, non-anomalous ASK signal generators with different parameters to make detecting them harder.
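
One way to turn the per-signal flag into a training target is to call a sample anomalous when any of its component signals is flagged. The sketch below shows only that reduction. `sample_label` is a hypothetical helper, and the `SimpleNamespace` objects are stand-ins for component signals with illustrative class names, not real TorchSig objects.

```python
from types import SimpleNamespace


def sample_label(component_signals):
    """Return 1 if any component signal is flagged anomalous, else 0."""
    return int(any(sig.anomalous for sig in component_signals))


# Stand-in component signals; only the `anomalous` attribute matters here.
clean = [SimpleNamespace(class_name="qpsk", anomalous=False)]
mixed = clean + [SimpleNamespace(class_name="2ask", anomalous=True)]

print(sample_label(clean), sample_label(mixed))
```

A per-signal detector would instead keep the individual flags; which form fits best depends on the model being trained.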

In [None]:
from torchsig.datasets.datasets import TorchSigIterableDataset

# Without target_labels, returns Signal objects with rich metadata
dataset = TorchSigIterableDataset(
    metadata=dataset_metadata,
)

for i in range(5):
    signal = next(dataset)

    print(f"IQ Data shape: {signal.data.shape}")
    print(f"Component Signals: {len(signal.component_signals)}")

    # Access metadata from component signals
    for j, comp_signal in enumerate(signal.component_signals):
        print(f"  Signal {j}: {comp_signal.class_name}")
    print()

### Time Domain Plot
The default data returned from datasets are IQ samples.

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# With target_labels, returns (data, metadata) tuples
dataset_time_series = TorchSigIterableDataset(
    metadata=dataset_metadata, target_labels=["class_index"]
)

data, metadata = next(dataset_time_series)
t = np.arange(0, len(data)) / dataset_time_series.sample_rate

fig = plt.figure(figsize=(12, 4))
ax = fig.add_subplot(1, 1, 1)
ax.plot(t, np.real(data), alpha=0.5, label="Real")
ax.plot(t, np.imag(data), alpha=0.5, label="Imag")
ax.set_xlim([t[0], t[-1]])
ax.set_xlabel("Time (sec)")
ax.set_ylabel("Amplitude")
ax.set_title("Time Domain")
ax.legend()  # show the "Real"/"Imag" labels set above
ax.grid()

### Spectrogram Plot
Spectrograms can be generated by passing the `Spectrogram()` transform to the dataset's `transforms` argument. For more detailed spectrogram plotting code, see the sections that follow.

In [None]:
from torchsig.transforms.transforms import Spectrogram

dataset_spectrogram = TorchSigIterableDataset(
    metadata=dataset_metadata,
    target_labels=["class_index"],
    transforms=[Spectrogram(fft_size=dataset_metadata["fft_size"])],
)

data, metadata = next(dataset_spectrogram)

fig = plt.figure(figsize=(6, 6))
ax = fig.add_subplot(1, 1, 1)
ax.imshow(data, aspect="auto", cmap="Wistia", vmin=0)
ax.set_xlabel("Time Axis")
ax.set_ylabel("Frequency Axis")

### Writing Dataset to Disk
To save a finite dataset for later, or to revisit previously generated examples, use the `DatasetCreator`. Pass in a dataloader wrapping the dataset to be saved, where to write the dataset (`root`), and whether to overwrite any existing dataset (`overwrite`). `dataset_length` must be defined, otherwise the `DatasetCreator` will attempt to create an infinite dataset.

**Note:** The new DatasetCreator API is simpler and uses HDF5 storage only.
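
The reason a length is needed is that an iterable dataset never runs out on its own. The sketch below shows the general idea with the standard library only. `endless_samples` is a stand-in generator, not a TorchSig class, and nothing here reflects how `DatasetCreator` is implemented internally.

```python
import itertools


def endless_samples():
    """Stand-in for an iterable dataset: yields sample indices forever."""
    index = 0
    while True:
        yield index
        index += 1


dataset_length = 30  # illustrative
# islice stops after dataset_length items, so the loop terminates
written = list(itertools.islice(endless_samples(), dataset_length))
print(len(written))
```

Without the bound, `list(endless_samples())` would never return.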

In [None]:
from torchsig.utils.writer import DatasetCreator
from torchsig.datasets.datasets import TorchSigIterableDataset
from torchsig.utils.data_loading import WorkerSeedingDataLoader

root = "./datasets/create_dataset_example"
seed = 123456789

In [None]:
# Omit target_labels so the dataset yields Signal objects with rich metadata
dataset = TorchSigIterableDataset(
    metadata=dataset_metadata,
    transforms=[Spectrogram(fft_size=dataset_metadata["fft_size"])],
)

dataset_length = len(dataset.signal_generators) * 10

dataloader = WorkerSeedingDataLoader(dataset, batch_size=11)
dataloader.seed(seed)

# New simplified DatasetCreator API
dataset_creator = DatasetCreator(
    dataset_length=dataset_length,
    dataloader=dataloader,
    root=root,
    overwrite=True,
    multithreading=False,
)

dataset_creator.create()

### Reading Dataset from Disk
Assuming you wrote a dataset to disk, you can load it back in by instantiating a `StaticTorchSigDataset`.

Warning: The following code assumes you have run the code in the section **Writing Dataset to Disk**.

In [None]:
from torchsig.datasets.datasets import StaticTorchSigDataset

static_dataset = StaticTorchSigDataset(
    root=root, target_labels=["class_name", "start", "stop", "lower_freq", "upper_freq"]
)

# can access any sample
print(static_dataset[100][1])
print(static_dataset[4])

If the dataset written to disk is raw (that is, no Transforms or Target Transforms were applied to it before writing), then you can define whatever transforms and target transforms you like for the static dataset to apply.

In [None]:
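def update_spectrogram(ax, start_time, stop_time, lower, upper):
    # Draws one labeled bounding box on ax. Note: classname, xmin, xmax, ymin,
    # ymax, and xunits are read from notebook-level variables set in the next cell.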
    start_time = float(start_time)
    stop_time = float(stop_time)
    lower = float(lower)
    upper = float(upper)
    ax.plot([start_time, start_time], [lower, upper], "b", alpha=0.5)
    ax.plot([stop_time, stop_time], [lower, upper], "b", alpha=0.5)
    ax.plot([start_time, stop_time], [lower, lower], "b", alpha=0.5)
    ax.plot([start_time, stop_time], [upper, upper], "b", alpha=0.5)
    textDisplay = str(classname)
    ax.text(
        start_time, lower, textDisplay, bbox=dict(facecolor="w", alpha=0.5, linewidth=0)
    )
    ax.set_xlim([xmin, xmax])
    ax.set_ylim([ymin, ymax])
    ax.set_ylabel("Frequency (Hz)")
    ax.set_xlabel(f"Time ({xunits})")

In [None]:
from torchsig.datasets.datasets import StaticTorchSigDataset
import matplotlib.pyplot as plt

sample_rate = dataset.sample_rate
noise_power_db = dataset.noise_power_db
num_iq_samples_dataset = dataset.num_iq_samples_dataset

num_show = 4
num_cols = 2
num_rows = num_show // num_cols

fig = plt.figure(figsize=(18, 12))
fig.tight_layout()

max_time = num_iq_samples_dataset / sample_rate

ns_time = 1e-9
us_time = 1e-6
ms_time = 1e-3
s_time = 1

if max_time < ns_time:
    max_time /= ns_time
    xunits = "ns"
elif max_time < us_time:
    max_time /= us_time
    xunits = "us"
elif max_time < ms_time:
    max_time /= ms_time
    xunits = "ms"
else:
    max_time /= s_time
    xunits = "s"

xmin = 0 * max_time
xmax = 1 * max_time

ymin = -sample_rate / 2
ymax = sample_rate / 2

for i in range(num_show):

    data, targets = static_dataset[i]
    ax = fig.add_subplot(num_rows, num_cols, i + 1)
    pos = ax.imshow(
        data,
        extent=[xmin, xmax, ymin, ymax],
        aspect="auto",
        cmap="Wistia",
        vmin=noise_power_db,
    )
    fig.colorbar(pos, ax=ax, label="SNR")
    for j in range(len(targets[0])):  # j avoids shadowing the outer loop's i
        classname = targets[0][j]
        start = targets[1][j]
        stop = targets[2][j]
        lower = targets[3][j]
        upper = targets[4][j]

        # convert normalized time into real-world time
        start_time = float(start) * max_time
        stop_time = float(stop) * max_time

        update_spectrogram(ax, start_time, stop_time, lower, upper)
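
The conversion used in the loop above can be isolated into a small function: start and stop are treated as fractions of the capture and scaled by its duration, while the frequency bounds are used as given. `to_plot_box` is a hypothetical helper and the numbers are illustrative.

```python
def to_plot_box(start, stop, lower_freq, upper_freq, duration):
    """Scale normalized start/stop by the capture duration; pass frequencies through."""
    return {
        "x": (float(start) * duration, float(stop) * duration),
        "y": (float(lower_freq), float(upper_freq)),
    }


duration_ms = 2.0  # illustrative capture length, already expressed in ms
box = to_plot_box(0.25, 0.75, -1e6, 2e6, duration_ms)
print(box)
```

A signal occupying the middle half of a 2 ms capture therefore spans 0.5 ms to 1.5 ms on the time axis.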

## Dataset Statistics

Below are some plots and statistics about the dataset.

In [None]:
# By default, `TorchSigIterableDataset` uses all the signals available.
from torchsig.signals.signal_lists import TORCHSIG_NUM_SIGNALS, TORCHSIG_NUM_FAMILIES

print(f"Number of signal classes in torchsig: {TORCHSIG_NUM_SIGNALS}")
print(f"Number of signal families in torchsig: {TORCHSIG_NUM_FAMILIES}")

In [None]:
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

class_list = dataset.class_names
class_counter = {class_name: 0 for class_name in class_list}
num_signals_per_sample = []

for sample in tqdm(static_dataset, desc="Calculating Dataset Stats"):
    data, targets = sample
    # targets[0] holds the class names in this sample; record its size once per sample
    num_signals_per_sample.append(len(targets[0]))
    for classname in targets[0]:
        class_counter[classname] += 1

class_counts = list(class_counter.values())
class_names = list(class_counter.keys())
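
The same bookkeeping can be tried on toy data with `collections.Counter`. In this sketch `samples` is a stand-in for the class-name lists of three samples, with illustrative names, and the `demo_` variables are hypothetical so they do not clash with the cell above.

```python
from collections import Counter

# Stand-in targets: one list of class names per sample (illustrative names).
samples = [["bpsk", "qpsk"], ["bpsk"], ["ook", "bpsk", "qpsk"]]

demo_class_counter = Counter()
demo_signals_per_sample = []
for class_names_in_sample in samples:
    demo_class_counter.update(class_names_in_sample)            # one count per signal
    demo_signals_per_sample.append(len(class_names_in_sample))  # one entry per sample

print(demo_class_counter)
print(demo_signals_per_sample)
```

The per-sample list has one entry per sample, and its sum equals the total number of signals counted by class.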

In [None]:
# Class Distribution Pie Chart
# by default, the class distribution is None aka uniform
plt.figure(figsize=(15, 15))
plt.pie(class_counts, labels=class_names)
plt.title("Class Distribution Pie Chart")

In [None]:
# Class Distribution Bar Chart
plt.figure(figsize=(18, 4))
plt.bar(class_names, class_counts)
plt.xticks(rotation=90)
plt.title("Class Distribution Bar Chart")
plt.xlabel("Modulation Class Name")
plt.ylabel("Counts")

In [None]:
# Number of signals per sample Distribution
import numpy as np

# Default number of signals per sample is 3
print(f"Min num signals Setting: {dataset.num_signals_min}")
print(f"Max num signals Setting: {dataset.num_signals_max}")

total = sum(num_signals_per_sample)
avg = np.mean(np.asarray(num_signals_per_sample))

plt.figure(figsize=(11, 8))
plt.hist(
    x=num_signals_per_sample, bins=np.arange(max(num_signals_per_sample) + 1) + 0.5
)
plt.xticks(np.arange(0, max(num_signals_per_sample) + 1, 1).tolist())
plt.title(
    f"Distribution of Number of Signals Per Sample\nTotal Number: {total}, Average: {avg:.2f}"
)
plt.xlabel("Number of Signals per Sample")
plt.ylabel("Counts")

## Applying Impairments to datasets
To add realism to synthetic datasets, we've created sets of transforms that simulate real-world signal impairments on the data. These can be found in `transforms/impairments.py`, and come in multiple impairment 'levels' which determine how heavily the signal is impaired. More on impairments can be found in the `examples/transforms/impairments` notebook.

In the example below, we are applying level 1 impairments to a dataset.

In [None]:
from torchsig.transforms.impairments import Impairments

In [None]:
impairments = Impairments(level=1)
burst_impairments = impairments.signal_transforms
whole_signal_impairments = impairments.dataset_transforms

burst_impairments, whole_signal_impairments

In [None]:
dataset_impaired = TorchSigIterableDataset(
    metadata=dataset_metadata,
    transforms=[
        whole_signal_impairments,
        Spectrogram(fft_size=dataset_metadata["fft_size"]),
    ],
    component_transforms=[burst_impairments],
    target_labels=[],
)

dataset_unimpaired = TorchSigIterableDataset(
    metadata=dataset_metadata,
    transforms=[Spectrogram(fft_size=dataset_metadata["fft_size"])],
    component_transforms=[],
    target_labels=[],
)

dataset_impaired.seed(seed)
dataset_unimpaired.seed(seed)
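
Seeding both datasets with the same value is presumably meant to make them draw the same underlying signals, so that the two plots differ mainly by the impairments. The sketch below shows the general principle with Python's `random` module. `draw_parameters` is a stand-in, and nothing here describes how TorchSig seeds its generators internally.

```python
import random


def draw_parameters(seed, count=3):
    """Stand-in for a dataset's random draws: returns `count` pseudo-random values."""
    rng = random.Random(seed)
    return [rng.random() for _ in range(count)]


same_a = draw_parameters(123456789)
same_b = draw_parameters(123456789)
other = draw_parameters(42)

# Identical seeds reproduce the same draws; a different seed does not.
print(same_a == same_b, same_a == other)
```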

In [None]:
unimpaired_data = next(dataset_unimpaired)
impaired_data = next(dataset_impaired)

fig = plt.figure(figsize=(12, 12))
ax = fig.add_subplot(1, 2, 1)
ax.imshow(unimpaired_data, cmap="Wistia", vmin=0)
ax.set_xlabel("Time Axis")
ax.set_ylabel("Frequency Axis")
ax.set_title("Un-Impaired Data")

ax2 = fig.add_subplot(1, 2, 2)
ax2.imshow(impaired_data, cmap="Wistia", vmin=0)
ax2.set_xlabel("Time Axis")
ax2.set_ylabel("Frequency Axis")
ax2.set_title("Impaired Data")