# Combined Data Extractor & IQ Scalogram Generator

This notebook extracts data from the RadioML HDF5 dataset and generates IQ scalograms in a single pipeline.

**Output:** `float32` NumPy arrays of shape `(224, 224, 2)`, saved as `.npy` files, where:
- Channel 0: CWT magnitude of the In-Phase (I) component
- Channel 1: CWT magnitude of the Quadrature (Q) component

In [None]:
import os
import h5py
import json
import numpy as np
import pywt
import cv2
from tqdm import tqdm

## Configuration

In [None]:
# ======================
# CONFIGURATION
# ======================
# All user-tunable settings for the pipeline live in this cell.

# --- Input files ---
HDF5_FILE = "Dataset/GOLD_XYZ_OSC.0001_1024.hdf5"   # RadioML frames (X), labels (Y), SNRs (Z)
CLASSES_JSON = "Dataset/classes-fixed.json"         # ordered list of modulation names

# --- Output location ---
OUTPUT_DIR = "Dataset/Scalograms"

# --- SNR levels (dB) to process: -20 up to and including 30, in 2 dB steps ---
SNR_VALUES = list(range(-20, 31, 2))

# --- Modulation classes ---
# A list restricts processing to those names; None processes every class in the JSON.
SELECTED_CLASSES = [
    "OOK", "4ASK", "BPSK", "QPSK", "8PSK", "8ASK", "16APSK", "64QAM",
    "AM-SSB-WC", "AM-DSB-WC", "FM", "GMSK", "OQPSK",
]

# --- Per-(class, SNR) frame cap; None means all 4096 frames ---
MAX_FRAMES_PER_CLASS_SNR = 200

# --- Debug images written alongside the scalograms ---
SAVE_SAMPLES = True
NUM_SAMPLES = 3  # images saved per class/SNR directory
SAMPLES_DIR = "Dataset/Samples"

# Echo the effective configuration.
print(f"HDF5 File: {HDF5_FILE}")
print(f"Output Directory: {OUTPUT_DIR}")
print(f"SNR Levels: {len(SNR_VALUES)} levels from {SNR_VALUES[0]} to {SNR_VALUES[-1]} dB")
print(f"Selected Classes: {SELECTED_CLASSES}")
print(f"Max Frames per Class-SNR: {MAX_FRAMES_PER_CLASS_SNR or 'All (4096)'}")

## Load HDF5 Dataset Info

In [None]:
# Load class names
with open(CLASSES_JSON, 'r') as f:
    all_class_names = json.load(f)

print(f"All modulation classes ({len(all_class_names)}): {all_class_names}")

# Filter to selected classes if specified
if SELECTED_CLASSES:
    class_names = SELECTED_CLASSES
    # Get indices of selected classes in the original list
    class_indices = {name: all_class_names.index(name) for name in SELECTED_CLASSES if name in all_class_names}
else:
    class_names = all_class_names
    class_indices = {name: i for i, name in enumerate(all_class_names)}

print(f"\nProcessing classes ({len(class_names)}): {class_names}")
print(f"Class indices in HDF5: {class_indices}")

# Verify HDF5 file structure
with h5py.File(HDF5_FILE, 'r') as hdf:
    print(f"\nHDF5 Keys: {list(hdf.keys())}")
    print(f"X shape: {hdf['X'].shape} (frames, samples, I/Q)")
    print(f"Y shape: {hdf['Y'].shape} (frames, one-hot labels)")
    print(f"Z shape: {hdf['Z'].shape} (frames, SNR)")

## Helper Functions

In [None]:
def compute_cwt(signal, sampling_rate=1e6, wavelet='cmor1.5-0.5', scales=None):
    """
    Compute the magnitude of the Continuous Wavelet Transform of a 1-D signal.

    Args:
        signal: 1-D array of samples (e.g. the I or Q component of one frame).
        sampling_rate: Sampling rate in Hz, passed to ``pywt.cwt`` as
            ``sampling_period = 1 / sampling_rate``. PyWavelets documents that
            this only affects the returned frequencies (discarded here), not
            the coefficients.
        wavelet: PyWavelets wavelet name. The default ``'cmor1.5-0.5'`` is a
            complex Morlet wavelet (bandwidth 1.5, center frequency 0.5).
        scales: Optional 1-D array of CWT scales. Defaults to 224 log-spaced
            scales from 10**0.2 to 10**1.5, which yields one output row per
            scale.

    Returns:
        Array of shape (len(scales), len(signal)) holding |CWT coefficients|.
    """
    if scales is None:
        scales = np.logspace(0.2, 1.5, num=224)
    coeffs, _ = pywt.cwt(signal, scales, wavelet, sampling_period=1 / sampling_rate)
    # Complex wavelet -> complex coefficients; keep only the magnitude.
    return np.abs(coeffs)


def process_frame_to_scalogram(i_signal, q_signal, size=224, interpolation=cv2.INTER_AREA):
    """
    Convert one frame's I/Q signals into a dual-channel CWT-magnitude scalogram.

    Args:
        i_signal: 1-D array of in-phase samples.
        q_signal: 1-D array of quadrature samples.
        size: Output height and width in pixels (default 224).
        interpolation: OpenCV interpolation flag used for resizing. The
            default ``cv2.INTER_AREA`` is OpenCV's recommendation for
            shrinking. It avoids the aliasing and negative overshoot that
            ``INTER_LANCZOS4`` produces when decimating the time axis.

    Returns:
        float32 array of shape (size, size, 2): channel 0 is |CWT(I)|,
        channel 1 is |CWT(Q)|.
    """
    channels = []
    for component in (i_signal, q_signal):
        cwt_mag = compute_cwt(component)
        # cv2.resize takes dsize as (width, height).
        channels.append(cv2.resize(cwt_mag, (size, size), interpolation=interpolation))

    scalogram = np.stack(channels, axis=-1)

    # Kernels with negative lobes (cubic, Lanczos) can overshoot below zero.
    # Magnitudes are non-negative, so clamp any such artifacts.
    np.clip(scalogram, 0, None, out=scalogram)

    return scalogram.astype(np.float32)


def save_debug_image(cwt_i, cwt_q, save_path):
    """
    Save a side-by-side grayscale image of the I (left) and Q (right) channels.

    Each channel is min-max normalized to 0-255 independently. The image
    therefore shows structure within each channel, but not the relative
    amplitude of I versus Q.

    Args:
        cwt_i: 2-D array for the I channel.
        cwt_q: 2-D array for the Q channel (same height as ``cwt_i``).
        save_path: Destination image path; its extension selects the format.

    Returns:
        True if the image was written, False otherwise (a warning is printed).
    """
    def normalize_to_img(arr):
        # Epsilon prevents division by zero for a constant-valued channel.
        arr = (arr - arr.min()) / (arr.max() - arr.min() + 1e-8)
        return (arr * 255).astype(np.uint8)

    # Stack side-by-side: Left = I, Right = Q
    combined = np.hstack([normalize_to_img(cwt_i), normalize_to_img(cwt_q)])

    # cv2.imwrite reports some failures (e.g. a missing directory) only via
    # its return value, so check it instead of failing silently.
    ok = bool(cv2.imwrite(save_path, combined))
    if not ok:
        print(f"[WARN] Failed to write debug image: {save_path}")
    return ok


print("Helper functions defined ✓")

## Create Output Directory Structure

In [None]:
# Build <root>/snr_<snr>/<class_name> for the scalogram tree and, when debug
# samples are enabled, for the sample-image tree too. exist_ok=True makes
# re-running this cell harmless.
target_roots = [OUTPUT_DIR]
if SAVE_SAMPLES:
    target_roots.append(SAMPLES_DIR)

for snr in SNR_VALUES:
    snr_folder = f"snr_{snr}"
    for class_name in class_names:
        for target_root in target_roots:
            os.makedirs(os.path.join(target_root, snr_folder, class_name), exist_ok=True)

print(f"Created directory structure for {len(SNR_VALUES)} SNR levels × {len(class_names)} classes")
print(f"Output path: {OUTPUT_DIR}")

## Main Processing Pipeline

Extract data from HDF5 and generate scalograms in one pass.

In [None]:
# Dataset layout assumed by this pipeline (RadioML 2018.01A): frames are
# stored class-major, then by SNR ascending over -20..30 dB in 2 dB steps,
# with 4096 frames per (class, SNR) pair. The per-frame label/SNR check in
# _process_class_snr reports any frame that violates this assumption.
_DATASET_SNRS = tuple(range(-20, 32, 2))
_FRAMES_PER_CLASS_SNR = 4096


def _class_snr_start_index(class_idx, snr):
    """Return the HDF5 row where the block of frames for (class_idx, snr) starts."""
    if snr not in _DATASET_SNRS:
        raise ValueError(f"SNR {snr} dB is not in the dataset SNR grid {_DATASET_SNRS}")
    # Use the dataset's full SNR grid, not SNR_VALUES. Otherwise processing
    # a subset of SNR levels would address the wrong rows.
    snr_pos = _DATASET_SNRS.index(snr)
    return (class_idx * len(_DATASET_SNRS) + snr_pos) * _FRAMES_PER_CLASS_SNR


def _process_class_snr(X, Y, Z, class_name, class_idx, snr, n_frames):
    """Generate and save scalograms for one (class, SNR) block; return how many were saved."""
    start = _class_snr_start_index(class_idx, snr)
    stop = start + n_frames

    # One contiguous read per block instead of three small HDF5 reads per frame.
    frames = X[start:stop]                    # (n_frames, samples, 2)
    labels = np.argmax(Y[start:stop], axis=1)  # one-hot -> class index
    snrs = Z[start:stop][:, 0]

    output_dir = os.path.join(OUTPUT_DIR, f"snr_{snr}", class_name)
    sample_dir = os.path.join(SAMPLES_DIR, f"snr_{snr}", class_name)

    sample_count = 0
    processed = 0
    for frame_num in tqdm(range(len(frames)), desc=f"SNR {snr:3d} dB", leave=False):
        actual_label = labels[frame_num]
        actual_snr = snrs[frame_num]
        if actual_label != class_idx or actual_snr != snr:
            # tqdm.write keeps the progress bar from being garbled.
            tqdm.write(f"[WARN] Mismatch at idx {start + frame_num}: expected {class_idx}/{snr}, got {actual_label}/{actual_snr}")
            continue

        frame_data = frames[frame_num]
        scalogram = process_frame_to_scalogram(frame_data[:, 0], frame_data[:, 1])
        np.save(os.path.join(output_dir, f"frame_{frame_num}.npy"), scalogram)

        if SAVE_SAMPLES and sample_count < NUM_SAMPLES:
            img_path = os.path.join(sample_dir, f"frame_{frame_num}_IQ.png")
            save_debug_image(scalogram[:, :, 0], scalogram[:, :, 1], img_path)
            sample_count += 1

        processed += 1
    return processed


def extract_and_generate_scalograms():
    """
    Main pipeline: Read HDF5 -> Extract I/Q -> Compute CWT -> Save Scalograms.

    Processes every (class in class_names, SNR in SNR_VALUES) pair, up to
    MAX_FRAMES_PER_CLASS_SNR frames each (all 4096 when unset). Classes
    missing from class_indices are skipped with a warning.

    Returns:
        Total number of scalograms saved.
    """
    max_frames = MAX_FRAMES_PER_CLASS_SNR or _FRAMES_PER_CLASS_SNR
    n_frames = min(max_frames, _FRAMES_PER_CLASS_SNR)

    total_processed = 0

    with h5py.File(HDF5_FILE, 'r') as hdf:
        X = hdf['X']  # I/Q data
        Y = hdf['Y']  # One-hot labels
        Z = hdf['Z']  # SNR values

        for class_name in class_names:
            class_idx = class_indices.get(class_name)
            if class_idx is None:
                print(f"[WARN] Skipping {class_name}: not found in class list")
                continue

            print(f"\n{'='*50}")
            print(f"Processing: {class_name} (index {class_idx})")
            print(f"{'='*50}")

            for snr in SNR_VALUES:
                processed = _process_class_snr(X, Y, Z, class_name, class_idx, snr, n_frames)
                total_processed += processed
                print(f"  SNR {snr:3d} dB: {processed} scalograms saved")

            print(f"✓ Completed {class_name}")

    return total_processed

print("Pipeline function defined ✓")
print("\nReady to process:")
print(f"  - {len(class_names)} classes")
print(f"  - {len(SNR_VALUES)} SNR levels")
# Cap at 4096: the dataset holds only that many frames per class-SNR pair,
# so a larger MAX_FRAMES_PER_CLASS_SNR cannot produce more scalograms.
frames_per_pair = min(MAX_FRAMES_PER_CLASS_SNR or 4096, 4096)
print(f"  - Up to {frames_per_pair} frames per class-SNR combination")
expected_total = len(class_names) * len(SNR_VALUES) * frames_per_pair
print(f"  - Total expected: {expected_total:,} scalograms")

## Run the Pipeline

In [None]:
# Run the full extract -> CWT -> save pipeline, then report where results went.
banner = "=" * 60
print("Starting extraction and scalogram generation...")
print(banner)

total = extract_and_generate_scalograms()

print("\n" + banner)
print(f"✅ COMPLETE: Generated {total:,} IQ scalograms")
print(f"Output location: {os.path.abspath(OUTPUT_DIR)}")
if SAVE_SAMPLES:
    print(f"Sample images: {os.path.abspath(SAMPLES_DIR)}")

## Verify Output

In [None]:
# Quick verification: load one saved scalogram and plot its I and Q channels.
import matplotlib.pyplot as plt

# Prefer 10 dB; fall back to the first configured SNR if 10 dB was not processed.
test_snr = 10 if 10 in SNR_VALUES else SNR_VALUES[0]
test_class = class_names[0]
test_dir = os.path.join(OUTPUT_DIR, f"snr_{test_snr}", test_class)
test_path = os.path.join(test_dir, "frame_0.npy")

# frame_0 is absent if it failed the pipeline's label/SNR sanity check;
# in that case use any other saved frame from the same directory.
if not os.path.exists(test_path) and os.path.isdir(test_dir):
    saved_frames = sorted(name for name in os.listdir(test_dir) if name.endswith(".npy"))
    if saved_frames:
        test_path = os.path.join(test_dir, saved_frames[0])

if os.path.exists(test_path):
    sample = np.load(test_path)
    print(f"Sample loaded: {test_path}")
    print(f"Shape: {sample.shape}")
    print(f"Dtype: {sample.dtype}")
    print(f"Value range: [{sample.min():.4f}, {sample.max():.4f}]")

    # Visualize
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    for ax, channel, label in zip(axes, (0, 1), ("I", "Q")):
        # Rows follow the CWT scale order (ascending). imshow's default
        # origin draws row 0 (smallest scale) at the top.
        ax.imshow(sample[:, :, channel], aspect='auto', cmap='viridis')
        ax.set_title(f'{test_class} @ {test_snr}dB - {label} Channel')
        ax.set_xlabel('Time')
        ax.set_ylabel('Frequency (Scale)')

    plt.tight_layout()
    plt.show()
else:
    print(f"Test file not found: {test_path}")
    print("Run the pipeline first.")
    print("Run the pipeline first.")