In [3]:
import h5py
import numpy as np
import torch
import scipy.signal as signal
import matplotlib.pyplot as plt
import matplotlib as mpl
import time 
from mpl_toolkits.mplot3d import Axes3D
from torch.utils.data import Dataset, DataLoader

# Load the RF signal IQ data


# Helpers below load the HDF5 file and convert interleaved real/imaginary 32-bit float samples into 64-bit complex numbers
def hdf5_to_dataset(hdf5_file):
    """
    Open an HDF5 file read-only and list its top-level keys.

    Returns a ``(file, keys)`` tuple: the open ``h5py.File`` handle and a
    list of its top-level key names. The handle is returned still open, so
    closing it is left to the caller.
    """
    h5_handle = h5py.File(hdf5_file, "r")
    return h5_handle, list(h5_handle.keys())

def interleaved_to_complex(source_data):
    """
    Turn interleaved IQ samples into a complex64 array.

    The input is copied into a NumPy array, divided by 32768, cast to
    float32, and then reinterpreted so that each consecutive pair of float32
    values becomes one complex number (first value real, second imaginary).
    """
    scaled = np.array(source_data) / 32768
    # Viewing the float32 buffer as complex64 pairs up adjacent values.
    return scaled.astype(np.float32).view(np.complex64)

def iq_to_psd(iq_data, num_fft, sample_rate):
    """
    Estimate the power spectral density of IQ samples with Welch's method.

    ``sample_rate`` is passed to ``scipy.signal.welch`` as the sampling
    frequency and ``num_fft`` as the per-segment length. Only the PSD values
    are returned; the matching frequency axis is discarded.
    """
    welch_result = signal.welch(iq_data, fs=sample_rate, nperseg=num_fft)
    return welch_result[1]

def render_psd(data, sample_rate=30000000, center_freq=2437000000, num_fft=2048):
    """
    Plot the Welch power spectral density of ``data`` in dB/Hz.

    Args:
        data: Samples to analyse (real or complex); passed straight to
            ``scipy.signal.welch``.
        sample_rate: Sampling frequency in Hz. Defaults to 30 MHz.
        center_freq: Frequency in Hz added to the baseband frequency axis
            so the x-axis shows absolute frequency. Defaults to 2.437 GHz.
        num_fft: FFT length passed to ``welch`` as ``nfft``. Defaults to 2048.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    f, psd = signal.welch(data, fs=sample_rate, nfft=num_fft)

    # For complex input welch returns a two-sided spectrum in FFT order
    # (non-negative frequencies first, then the negative ones). Sorting by
    # frequency draws the curve left to right without a wrap-around line;
    # for a one-sided (already ascending) spectrum this is a no-op.
    order = np.argsort(f)
    f = f[order] + center_freq  # Shift the axis to the center frequency
    psd = psd[..., order]

    # Convert to dB so the plotted values match the y-axis label.
    # NOTE: bins with exactly zero power map to -inf.
    psd_log = 10.0 * np.log10(psd)

    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.plot(f, psd_log)
    ax.set_xlabel('Frequency (Hz)')
    ax.set_ylabel('Power Spectral Density (dB/Hz)')
    ax.set_title('Power Spectral Density')
    plt.show()


In [None]:
# Custom PyTorch Dataset to load IQ data from HDF5 and convert to PSD
class HDF5KeyValueDataset(Dataset):
    """
    Dataset that serves one PSD tensor per top-level key of an HDF5 file.

    Each item is a single-entry dict ``{key: psd_tensor}``, so items from
    different indices carry different keys. Batching more than one item
    therefore needs a collate function that tolerates differing keys.

    The HDF5 file is opened lazily, on first access in the current process,
    instead of in ``__init__``. This keeps an already-open h5py handle from
    being shared with (or pickled into) DataLoader worker processes.
    """

    def __init__(self, h5_file_path, sample_rate=30000000, nfft=2048):
        """
        Args:
            h5_file_path (str): Path to the HDF5 file.
            sample_rate (float): Sampling frequency in Hz passed to
                ``scipy.signal.welch``. Defaults to 30 MHz.
            nfft (int): FFT length passed to ``scipy.signal.welch``.
                Defaults to 2048.
        """
        self.h5_file_path = h5_file_path
        self.sample_rate = sample_rate
        self.nfft = nfft
        self._file = None  # opened on demand by the ``dataset`` property

        # Read the key list once and close the file again right away.
        with h5py.File(self.h5_file_path, 'r') as f:
            self.keys = list(f.keys())  # List all keys (dataset names)

    @property
    def dataset(self):
        """Open ``h5py.File`` handle, created on first use in this process."""
        if self._file is None:
            self._file = h5py.File(self.h5_file_path, 'r')
        return self._file

    def close(self):
        """Close the HDF5 file if it is open; it reopens on next access."""
        if self._file is not None:
            self._file.close()
            self._file = None

    def __getstate__(self):
        # Never carry an open file handle across a pickle boundary (e.g. to
        # spawned DataLoader workers); the copy reopens the file itself.
        state = self.__dict__.copy()
        state['_file'] = None
        return state

    def __len__(self):
        # Return the number of datasets (keys) in the HDF5 file
        return len(self.keys)

    def __getitem__(self, idx):
        key = self.keys[idx]

        # Convert the stored interleaved IQ data to complex numbers
        iq_data = interleaved_to_complex(self.dataset[key])

        # Two-sided Welch PSD of the IQ data; the frequency axis is not needed
        _, psd_data = signal.welch(
            iq_data, fs=self.sample_rate, nfft=self.nfft, return_onesided=False
        )

        # Single-entry dict keyed by the dataset name
        return {key: torch.tensor(psd_data)}


# Custom collate function to handle batching of dictionary data
def collate_psd_dicts(batch):
    """
    Merge the per-sample ``{key: psd}`` dicts of a batch into one dict.

    Every sample is keyed by its own dataset name, so the samples of a batch
    do not share keys and cannot be stacked key-by-key. Each tensor instead
    gets a leading axis of size 1, which keeps the ``[1, N]`` per-key shape
    for ``batch_size=1`` while also working for larger batches.
    """
    merged = {}
    for sample in batch:
        for key, psd_data in sample.items():
            merged[key] = psd_data.unsqueeze(0)
    return merged


# Usage example
h5_file_path = '2_4ghz_indoor.h5'

# Initialize dataset
dataset = HDF5KeyValueDataset(h5_file_path)

# Initialize DataLoader with the custom collate function
batch_size = 1
data_loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=16,
    collate_fn=collate_psd_dicts,
)

# Example loop through data (first 5 batches only)
for i, batch in enumerate(data_loader):
    if i >= 5:
        break
    # Access individual datasets within the batch
    for key, psd_data in batch.items():
        print(f"{key}: PSD shape {psd_data.shape}")
        # Optionally, visualize the PSD data



144700: PSD shape torch.Size([1, 2048])
144800: PSD shape torch.Size([1, 2048])
144900: PSD shape torch.Size([1, 2048])
145000: PSD shape torch.Size([1, 2048])
145100: PSD shape torch.Size([1, 2048])
