# Custom File Handler Example

This notebook shows how to create a custom file handler that writes datasets in a different format. In this example, we write each sample's data as a PNG image and its target labels as a YAML file.

---

The `DatasetCreator` writes data to disk with a file writer, and a static dataset (e.g. `StaticTorchSigDataset`) reads it back with a file reader. By default, both use the HDF5-based `HDF5FileHandler`. If you want the dataset to be written and read differently, subclass the `FileWriter` and `FileReader` base classes from `torchsig.utils.file_handlers.base_handler`, as shown below.

**Note**: As of TorchSig 2.0, the default file handler is HDF5-based for improved performance and metadata storage.

In [ ]:
# TorchSig
from torchsig.utils.file_handlers.base_handler import FileWriter, FileReader
from torchsig.utils.writer import DatasetCreator
from torchsig.utils.data_loading import WorkerSeedingDataLoader
from torchsig.utils.defaults import TorchSigDefaults
from torchsig.signals.signal_types import Signal
from torchsig.datasets.datasets import TorchSigIterableDataset, StaticTorchSigDataset
from torchsig.transforms.transforms import Spectrogram

# Third Party
import cv2
import yaml
import pathlib
import matplotlib.pyplot as plt
import numpy as np

In [ ]:
# Conversion functions for YAML file handling.
def _to_builtin(val):
    """Convert NumPy types to built-in Python types."""
    if isinstance(val, np.generic):
        return val.item()
    else:
        return val


def signal_to_dict(signal: Signal) -> dict:
    """Recursively convert a Signal object into a dictionary for YAML serialization.

    The result has the shape::

        {"metadata": {...}, "component_signals": [<same shape>, ...]}

    Metadata values are passed through ``_to_builtin`` so NumPy scalars are
    stored as plain Python values. ``dict_to_signal`` performs the inverse
    conversion.

    Args:
        signal: Signal whose ``metadata`` and ``component_signals`` are converted.

    Returns:
        dict: Nested dictionary mirroring the Signal tree.
    """
    return {
        "metadata": {key: _to_builtin(val) for key, val in signal.metadata.items()},
        # Recurse into child signals so the full tree is serialized.
        "component_signals": [
            signal_to_dict(child) for child in signal.component_signals
        ],
    }


def dict_to_signal(data: dict) -> Signal:
    """Rebuild a Signal tree from the nested dictionary made by ``signal_to_dict``.

    Args:
        data: Dictionary with a ``"metadata"`` entry and an optional
            ``"component_signals"`` list of dictionaries of the same shape.

    Returns:
        Signal: The parent Signal with its component signals attached in order.
    """
    parent = Signal(metadata=data.get("metadata"))
    # Build every child subtree first, then attach them in their original order.
    children = [dict_to_signal(child) for child in data.get("component_signals", [])]
    for child in children:
        parent.component_signals.append(child)
    return parent

In [ ]:
class SpectrogramWriter(FileWriter):
    """TorchSig FileWriter for our custom storage format.

    Each sample's data is written as a colormapped PNG image, and its metadata
    as a human-readable YAML file, one file of each per sample:

    * ``<root>/spectrograms/spectrogram_<i>.png``
    * ``<root>/targets/target_<i>.yaml``

    where ``i`` is the sample's global index in the dataset.
    """

    def _setup(self) -> None:
        """Create the directories that hold spectrogram images and targets."""
        self.spectrogram_dir = self.root.joinpath("spectrograms")
        self.target_dir = self.root.joinpath("targets")
        self.spectrogram_dir.mkdir(parents=True, exist_ok=True)
        self.target_dir.mkdir(parents=True, exist_ok=True)

    @staticmethod
    def _to_uint8(spectrogram: np.ndarray) -> np.ndarray:
        """Min-max normalize ``spectrogram`` to the 0-255 range and cast to uint8.

        A constant input (max == min) would divide by zero and produce NaNs,
        so it is mapped to an all-zero image instead.
        """
        lo = spectrogram.min()
        span = spectrogram.max() - lo
        if span == 0:
            return np.zeros(np.shape(spectrogram), dtype=np.uint8)
        return (((spectrogram - lo) / span) * 255).astype(np.uint8)

    def write(self, batch_idx: int, batch: list) -> None:
        """Write a single batch of Signals as spectrograms and targets to disk.

        Args:
            batch_idx (int): Index of the batch being written.
            batch (list): List of Signal objects in batch.

        """
        if not batch:
            return

        # Every batch except possibly the last has the same size as the first
        # one, so that size is the stride between batches. Using len(batch)
        # instead would make a shorter final batch reuse (and overwrite) the
        # indices of earlier samples. Assumes the first batch this writer
        # receives is full-sized.
        batch_size = getattr(self, "_batch_size", None)
        if batch_size is None:
            batch_size = self._batch_size = len(batch)
        start = batch_idx * batch_size

        for offset, sig in enumerate(batch):
            sample_idx = start + offset
            metadata = signal_to_dict(sig)  # convert Signal to dictionary structure

            # Scale the spectrogram to 8-bit and apply a colormap for the PNG.
            image = cv2.applyColorMap(self._to_uint8(sig.data), cv2.COLORMAP_HOT)

            # Save spectrogram as PNG
            spectrogram_path = self.spectrogram_dir.joinpath(
                f"spectrogram_{sample_idx}.png"
            )
            cv2.imwrite(str(spectrogram_path), image)

            # Save target as YAML: one yaml file per element
            target_path = self.target_dir.joinpath(f"target_{sample_idx}.yaml")
            with open(target_path, "w") as f:
                # Dump YAML with block style for readability
                yaml.dump(metadata, f, default_flow_style=False, sort_keys=False)

    def __len__(self) -> int:
        """Return the number of saved spectrograms."""
        return len(list(self.spectrogram_dir.glob("spectrogram_*.png")))


class SpectrogramReader(FileReader):
    """TorchSig FileReader for handling our custom storage format.

    Reads ``<root>/spectrograms/spectrogram_<i>.png`` and
    ``<root>/targets/target_<i>.yaml`` for sample ``i``.
    """

    def __init__(self, root):
        super().__init__(root=root)
        self.spectrogram_dir = self.root.joinpath("spectrograms")
        self.target_dir = self.root.joinpath("targets")

    def read(self, idx: int) -> tuple:
        """Read a spectrogram and target by index.

        The PNG is loaded with ``cv2.IMREAD_GRAYSCALE``. If the image was
        written with a colormap, this grayscale conversion does not invert the
        colormap, so the returned values are not the original spectrogram
        values.

        Args:
            idx (int): Index of the data to read.

        Returns:
            tuple: (spectrogram, target)

        Raises:
            FileNotFoundError: If the spectrogram image or target file for
                ``idx`` is missing (or the image cannot be decoded).
        """
        # Read the spectrogram data as a NumPy array
        spectrogram_path = self.spectrogram_dir.joinpath(f"spectrogram_{idx}.png")
        spectrogram = cv2.imread(str(spectrogram_path), cv2.IMREAD_GRAYSCALE)
        # cv2.imread signals failure by returning None rather than raising.
        if spectrogram is None:
            raise FileNotFoundError(
                f"Could not read spectrogram image: {spectrogram_path}"
            )

        # Read the metadata from yaml file
        target_path = self.target_dir.joinpath(f"target_{idx}.yaml")
        with open(target_path, "r") as f:
            # NOTE(review): FullLoader is used because these files are produced
            # locally by SpectrogramWriter; do not point this reader at
            # untrusted YAML (prefer yaml.safe_load in that case).
            raw_metadata = yaml.load(f, Loader=yaml.FullLoader)  # parse YAML into dict
        targets = dict_to_signal(raw_metadata)

        return spectrogram, targets

    def size(self) -> int:
        """Return the total number of spectrograms in the dataset."""
        return len(list(self.spectrogram_dir.glob("spectrogram_*.png")))

    def __len__(self) -> int:
        """Return the total number of spectrograms in the dataset."""
        return self.size()

In [ ]:
# Define dataset metadata: start from TorchSig's defaults and cap the
# maximum number of signals ("num_signals_max") at 4.
torchsig_defaults = TorchSigDefaults()
dm = torchsig_defaults.default_dataset_metadata
dm["num_signals_max"] = 4

In [ ]:
# Exercise custom file writing and reading with TorchSig datasets

# ---- Parameters ----
root = "./datasets/filehandler_example"
fft_size = 512
dataset_length = 10
target_labels = ["class_name", "class_index"]
transforms = [Spectrogram(fft_size=fft_size)]

# ---- Source: iterable dataset + dataloader ----
dataset = TorchSigIterableDataset(
    metadata=dm,
    transforms=transforms,
    target_labels=target_labels,
)
# Identity collate: each batch stays a plain list of Signal objects,
# which is what SpectrogramWriter.write iterates over.
dataloader = WorkerSeedingDataLoader(dataset, collate_fn=lambda x: x)
dataloader.seed(42)  # fixed seed for reproducibility

# ---- Write the dataset to disk through the custom writer ----
dc = DatasetCreator(
    dataloader=dataloader,
    root=root,
    dataset_length=dataset_length,
    overwrite=True,
    file_handler=SpectrogramWriter,  # our custom writer class
)
dc.create()

# ---- Read it back; the reader class must match the writer's on-disk format ----
s = StaticTorchSigDataset(root=root, file_handler_class=SpectrogramReader)

data, targets = s[6]  # fetch one element by index
print(f"\nData: {data.shape}\nTargets: {targets}")

In [ ]:
# Plot the data. SpectrogramReader loads the PNG with cv2.IMREAD_GRAYSCALE,
# so `data` is a single-channel grayscale image, not RGB; show it in grayscale.
plt.imshow(data, cmap="gray")