# Custom File Handler Example

This notebook shows how to create a custom file handler that writes a dataset in a different on-disk format. In this example, each sample's data is written as a PNG image and its target labels are written as a YAML file.

---

The `DatasetCreator` and any `StaticTorchSigDataset` use a `BaseFileHandler` to write and read data, respectively. By default, they both use the `HDF5FileHandler` class. If you want the dataset to be written and read differently, you can create a subclass of `BaseFileHandler`. In this example, the writing side subclasses `FileWriter` and the reading side subclasses `FileReader`.

**Note**: As of TorchSig 2.0, the default file handler is HDF5-based for improved performance and metadata storage.

In [None]:
from torchsig.datasets.dataset_metadata import load_dataset_metadata
from torchsig.utils.file_handlers.base_handler import FileWriter, FileReader, BaseFileHandler
import cv2
import yaml
import pathlib
import numpy as np

class SpectrogramWriter(FileWriter):
    """Write each signal as a colour-mapped PNG plus a YAML file of targets.

    Images are stored as ``<root>/spectrograms/spectrogram_<n>.png`` and
    targets as ``<root>/targets/target_<n>.yaml``.
    """

    def _setup(self) -> None:
        """Create the output directories for spectrogram images and targets."""
        # NOTE(review): presumably invoked by the FileWriter base class once
        # ``self.root`` is set -- confirm against base_handler.
        self.spectrogram_dir = self.root.joinpath("spectrograms")
        self.target_dir = self.root.joinpath("targets")
        self.spectrogram_dir.mkdir(parents=True, exist_ok=True)
        self.target_dir.mkdir(parents=True, exist_ok=True)

    def write(self, batch_idx: int, batch: list) -> None:
        """Write a single batch of Signals as spectrograms and targets to disk.

        Args:
            batch_idx (int): Index of the batch being written.
            batch (list): Signal objects in the batch. Each one must expose
                its array as ``.data`` and its per-signal metadata through
                ``get_full_metadata()``.
        """
        # NOTE(review): the file index is batch_idx * len(batch) + position,
        # which is only unique when every batch has the same length. A shorter
        # final batch would reuse earlier indices -- confirm the batch size
        # used by the caller.
        first_index = batch_idx * len(batch)

        for offset, sig in enumerate(batch):
            sample_index = first_index + offset
            spectrogram = sig.data

            # One list of non-None metadata values per signal, keyed by the
            # signal's position within the sample.
            targets = {
                i: [v for v in m.to_dict().values() if v is not None]
                for i, m in enumerate(sig.get_full_metadata())
            }

            # Min-max normalise to the 0-255 range.
            lowest = spectrogram.min()
            highest = spectrogram.max()
            span = highest - lowest
            if span > 0:
                scaled = ((spectrogram - lowest) / span) * 255
                image = scaled.astype(np.uint8)
            else:
                # Constant (or all-NaN) input: dividing by a zero span would
                # give NaN/inf, whose uint8 cast is undefined. Write black.
                image = np.zeros_like(spectrogram, dtype=np.uint8)

            # Map the 8-bit intensities onto the "hot" colormap.
            image = cv2.applyColorMap(image, cv2.COLORMAP_HOT)

            # Save spectrogram as PNG
            spectrogram_path = self.spectrogram_dir.joinpath(
                f"spectrogram_{sample_index}.png"
            )
            cv2.imwrite(str(spectrogram_path), image)

            # Save target as YAML
            target_path = self.target_dir.joinpath(f"target_{sample_index}.yaml")
            with open(target_path, "w") as f:
                yaml.dump(targets, f, default_flow_style=False)

    def __len__(self) -> int:
        """Return the number of spectrogram PNG files currently on disk."""
        return sum(1 for _ in self.spectrogram_dir.glob("spectrogram_*.png"))

class SpectrogramReader(FileReader):
    """Read PNG spectrograms and YAML targets stored under a dataset root.

    Expects ``<root>/spectrograms/spectrogram_<idx>.png`` and
    ``<root>/targets/target_<idx>.yaml``.
    """

    def __init__(self, root):
        """Set up the directory paths and load the dataset metadata.

        Args:
            root: Dataset root directory, forwarded to ``FileReader``.
        """
        super().__init__(root=root)
        self.spectrogram_dir = self.root.joinpath("spectrograms")
        self.target_dir = self.root.joinpath("targets")
        # NOTE(review): ``dataset_info_filepath`` is presumably provided by the
        # FileReader base class -- confirm.
        self.dataset_metadata = load_dataset_metadata(self.dataset_info_filepath)

    def read(self, idx: int) -> tuple:
        """Read a spectrogram and target by index.

        Args:
            idx (int): Index of the data to read.

        Returns:
            tuple: ``(spectrogram, targets)`` where ``spectrogram`` is the PNG
            loaded as a single-channel grayscale array and ``targets`` is a
            list holding the value stored under each key of the YAML mapping.

        Raises:
            FileNotFoundError: If the spectrogram image or the target file is
                missing (or the image cannot be decoded).
        """
        spectrogram_path = self.spectrogram_dir.joinpath(f"spectrogram_{idx}.png")
        # The image is loaded as grayscale, so any colour stored in the PNG is
        # collapsed to a single intensity channel.
        spectrogram = cv2.imread(str(spectrogram_path), cv2.IMREAD_GRAYSCALE)
        if spectrogram is None:
            # cv2.imread reports failure by returning None instead of raising;
            # surface it here rather than handing None to the caller.
            raise FileNotFoundError(
                f"Could not read spectrogram image: {spectrogram_path}"
            )

        target_path = self.target_dir.joinpath(f"target_{idx}.yaml")
        with open(target_path, "r") as f:
            # NOTE: FullLoader is not meant for untrusted input; these files
            # are assumed to have been produced locally for this dataset.
            raw_targets = yaml.load(f, Loader=yaml.FullLoader)

        return spectrogram, list(raw_targets.values())

    def size(self) -> int:
        """Return the total number of spectrograms in the dataset."""
        return sum(1 for _ in self.spectrogram_dir.glob("spectrogram_*.png"))

    def __len__(self) -> int:
        """Return the total number of spectrograms in the dataset."""
        return self.size()


In [None]:
# Use this file handler
from torchsig.datasets.dataset_metadata import DatasetMetadata
from torchsig.datasets.datasets import TorchSigIterableDataset, StaticTorchSigDataset
from torchsig.utils.writer import DatasetCreator, default_collate_fn
from torchsig.transforms.transforms import Spectrogram
from torchsig.utils.data_loading import WorkerSeedingDataLoader

# Output directory for the generated dataset.
root = "./datasets/filehandler_example"
fft_size = 512
# NOTE(review): fft_size ** 2 samples presumably yields a square
# fft_size x fft_size spectrogram -- confirm against the Spectrogram transform.
num_iq_samples_dataset = fft_size ** 2
# Number of samples to write to disk.
dataset_length = 10
# Bounds on the number of signals per sample.
num_signals_min = 1
num_signals_max = 4

# Turn each generated sample into a spectrogram before it is written.
transforms = [Spectrogram(fft_size=fft_size)]

md = DatasetMetadata(
    num_iq_samples_dataset=num_iq_samples_dataset,
    fft_size=fft_size,
    num_signals_min = num_signals_min,
    num_signals_max = num_signals_max
)

dataset = TorchSigIterableDataset(
    dataset_metadata=md, 
    transforms=transforms, 
    
    target_labels=["class_name","class_index"]
)
dataloader = WorkerSeedingDataLoader(dataset, collate_fn=default_collate_fn)

# Write the dataset to disk; overwrite=True is passed so an existing dataset
# at `root` does not block the run.
dc = DatasetCreator(
    dataloader=dataloader,
    dataset_length=dataset_length,
    root=root,
    overwrite=True,
    # use our custom writer class instead of the default file handler
    file_handler=SpectrogramWriter
) 
dc.create()

# Load the dataset back from disk.
s = StaticTorchSigDataset(
    root=root,
    # use the reader that matches the on-disk format produced by the writer
    file_handler_class=SpectrogramReader
)

# Fetch the first sample and show what the reader returned.
data, targets = s[0]
print(f"\nData: {data.shape} ({type(data)})")
print(f"Targets: {targets}")

In [None]:
# Display the spectrogram that was read back from disk.
# NOTE(review): SpectrogramReader.read loads the PNG with cv2.IMREAD_GRAYSCALE,
# so `data` is a single-channel array here; imshow renders it with matplotlib's
# default colormap rather than the colours stored in the PNG.

import matplotlib.pyplot as plt

plt.imshow(data)