In [3]:
import numpy as np
import pandas as pd
import h5py
import json
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import GradScaler, autocast
from torchinfo import summary
from typing import Tuple, Optional
import torch.utils.checkpoint as checkpoint
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.model_selection import train_test_split
from tqdm import tqdm, trange
import seaborn as sns
import gc
import time 
import seaborn 
from IPython.display import display, HTML
from typing import Dict, List, Tuple, Optional
import time
import warnings
from collections import deque
import psutil
import os
import math
import warnings
warnings.filterwarnings('ignore')

In [5]:
# Plot appearance: matplotlib's seaborn-style dark grid first, then the "husl"
# colour cycle on top of it (order matters: the palette is applied after the style).
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")

# DataFrame display: never truncate columns, show at most 100 rows.
pd.options.display.max_columns = None
pd.options.display.max_rows = 100

# Environment report so the captured output records which stack produced the results.
print("‚úÖ Libraries imported successfully")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    # Only query the device name when a CUDA device is actually present.
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

‚úÖ Libraries imported successfully
PyTorch version: 2.8.0+cu129
CUDA available: True
CUDA device: NVIDIA GeForce RTX 3050 Laptop GPU


In [44]:
# --- Dataset location ---------------------------------------------------------
# The dataset directory can be overridden with the AMC_DATA_DIR environment
# variable so the notebook also runs on machines other than the original
# author's. When the variable is unset (or empty) the original absolute
# location is used, so FILE_PATH / JSON_PATH keep their previous values.
DATA_DIR = os.environ.get("AMC_DATA_DIR") or "/home/lipplopp/research/AMC_Repository/dataset"
FILE_PATH = os.path.join(DATA_DIR, "GOLD_XYZ_OSC.0001_1024.hdf5")
JSON_PATH = os.path.join(DATA_DIR, "classes-fixed.json")

# --- Loader / training settings ------------------------------------------------
BATCH_SIZE = 64  # Adjust based on your GPU memory
NUM_WORKERS = 4  # Adjust based on your CPU cores
patience = 10    # lower-case name kept as-is because other cells may reference it
NUM_EPOCHS = 100

# --- Classes used in this experiment -------------------------------------------
# Commented-out entries are deliberately excluded from the experiment.
TARGET_MODULATIONS = [
                      #'OOK',
                      # '4ASK',
                      '8ASK',
                      'BPSK', 
                      #'QPSK',
                      '8PSK', 
                      '16QAM',
                      '64QAM', 
                      #'OQPSK'
                     ]
NUM_CLASSES = len(TARGET_MODULATIONS)

# --- Data split -----------------------------------------------------------------
SUBSAMPLE_TRAIN_RATIO = 0.2  # Use 20% of training data (set to 1.0 for full data)
TRAIN_SIZE = 0.7
VALID_SIZE = 0.2
TEST_SIZE = 0.1
SPLIT_SEED = 48
NORM_SEED = 49
CHUNK_SIZE = 10000 

# Fail in the config cell rather than deep inside the split function when the
# three ratios are edited inconsistently.
if not math.isclose(TRAIN_SIZE + VALID_SIZE + TEST_SIZE, 1.0, abs_tol=1e-6):
    raise ValueError(
        f"TRAIN_SIZE + VALID_SIZE + TEST_SIZE must equal 1.0, "
        f"got {TRAIN_SIZE + VALID_SIZE + TEST_SIZE}"
    )

In [18]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Echo the active configuration (one line per setting) into the cell output.
print(
    f"üîß Configuration loaded",
    f"üìä Split ratio: Train={TRAIN_SIZE}, Valid={VALID_SIZE}, Test={TEST_SIZE}",
    f"üíæ Batch size: {BATCH_SIZE}",
    f"üñ•Ô∏è Device: {device}",
    f"üì¶ Chunk size for processing: {CHUNK_SIZE}",
    sep="\n",
)

üîß Configuration loaded
üìä Split ratio: Train=0.7, Valid=0.2, Test=0.1
üíæ Batch size: 64
üñ•Ô∏è Device: cuda
üì¶ Chunk size for processing: 10000


In [19]:
def load_dataset_metadata(file_path: str, json_path: str) -> Tuple[np.ndarray, np.ndarray, List[str], int]:
    """
    Load only metadata and labels from HDF5 file (memory efficient).

    The signal dataset ``X`` is never read; only its shape and dtype are
    inspected. The label datasets ``Y`` and ``Z`` are read in full.

    Args:
        file_path: Path to the HDF5 file containing the ``X``, ``Y`` and ``Z`` datasets.
        json_path: Path to a JSON file holding the modulation class names,
            indexable by the integer class id obtained from ``argmax(Y, axis=1)``.

    Returns:
        (Y_strings, Z_data, available_modulations, total_samples)
        - Y_strings: per-sample modulation name.
        - Z_data: first column of ``Z`` (printed below as the SNR in dB).
        - available_modulations: sorted unique names present in ``Y_strings``.
        - total_samples: size of the first axis of ``X``.
    """
    print("üìÇ Loading dataset metadata (memory efficient)...")
    
    with h5py.File(file_path, 'r') as hdf5_file:
        # Get dataset shape without loading data
        signals = hdf5_file['X']
        total_samples = signals.shape[0]
        signal_length = signals.shape[1]
        num_channels = signals.shape[2]
        
        print(f"üìä Dataset shape: ({total_samples:,} √ó {signal_length} √ó {num_channels})")
        
        # Load only labels (much smaller than signal data)
        Y_int = np.argmax(hdf5_file['Y'][:], axis=1)
        Z_data = hdf5_file['Z'][:, 0]
        
        # Load modulation classes
        with open(json_path, 'r') as f:
            modulation_classes = json.load(f)
        
        # Convert integer labels to string labels with a single vectorised
        # lookup instead of a Python-level loop over every sample.
        Y_strings = np.asarray(modulation_classes)[Y_int]
        
        # Get available modulations
        available_modulations = list(np.unique(Y_strings))
        
        print(f"‚úÖ Metadata loaded: {total_samples:,} samples")
        print(f"üì° Available modulations: {len(available_modulations)}")
        print(f"üìä SNR range: {np.min(Z_data):.1f} to {np.max(Z_data):.1f} dB")
        
        # Memory usage estimate, based on the element size actually stored in
        # the file rather than an assumed 4 bytes per value.
        bytes_per_value = signals.dtype.itemsize
        data_size_gb = (total_samples * signal_length * num_channels * bytes_per_value) / (1024**3)
        print(f"üíæ Full dataset size: ~{data_size_gb:.2f} GB")
        
    return Y_strings, Z_data, available_modulations, total_samples

In [20]:
def stratified_dataset_split_memory_efficient(
    modulations: np.ndarray, 
    snrs: np.ndarray, 
    target_modulations: List[str],
    train_size: float = 0.7, 
    valid_size: float = 0.2, 
    test_size: float = 0.1,
    seed: int = 48, 
    subsample_train_ratio: Optional[float] = None) -> Tuple[Dict, Dict]:
    """
    Memory-efficient stratified split (doesn't require loading X data).

    Samples whose modulation is in ``target_modulations`` are split into
    train/valid/test index sets, stratified on the (modulation, SNR) pair.
    All randomness comes from ``seed`` through local generators, so NumPy's
    global random state is left untouched.

    Args:
        modulations: Modulation labels (Y), one per sample
        snrs: SNR labels (Z), one per sample
        target_modulations: Modulations to include
        train_size, valid_size, test_size: Split proportions (must sum to 1.0)
        seed: Random seed
        subsample_train_ratio: Optional subsampling of training data. Values
            below 1.0 keep that fraction of the training indices, drawn
            uniformly at random (this draw is NOT re-stratified). ``None`` or
            values >= 1.0 keep the full training set.

    Returns:
        (splits_dict, label_map)
        - splits_dict: ``{'train': ..., 'valid': ..., 'test': ...}`` arrays of
          indices into ``modulations`` / ``snrs``.
        - label_map: ``{modulation_name: class_id}`` following the order of
          ``target_modulations``.

    Raises:
        ValueError: If the proportions do not sum to 1.0, ``target_modulations``
            is empty, no sample matches the targets, or the subsampling ratio
            is non-positive / leaves no training samples.
    """
    
    # Input validation: every cheap check runs before any splitting work.
    if not np.isclose(train_size + valid_size + test_size, 1.0, atol=1e-6):
        raise ValueError(f"Split sizes must sum to 1.0")
    
    if len(target_modulations) == 0:
        raise ValueError("target_modulations cannot be empty")
    
    if subsample_train_ratio is not None and subsample_train_ratio <= 0:
        # Previously this was only detected after both splits had been computed.
        raise ValueError("Subsampling ratio too small")
    
    print("üîÑ Performing stratified split...")
    
    # Filter to include only target modulations
    target_indices = np.flatnonzero(np.isin(modulations, target_modulations))
    
    if len(target_indices) == 0:
        raise ValueError("No samples found for target modulations")
    
    print(f"üìä Found {len(target_indices):,} samples for target modulations")
    
    # Create stratification key: one "<modulation>_<snr>" string per kept sample
    stratify_key = [
        f"{mod}_{snr}"
        for mod, snr in zip(modulations[target_indices], snrs[target_indices])
    ]
    
    # Check sample distribution
    _, key_counts = np.unique(stratify_key, return_counts=True)
    min_samples_per_key = np.min(key_counts)
    
    if min_samples_per_key < 2:
        warnings.warn(f"‚ö†Ô∏è Some modulation-SNR combinations have only {min_samples_per_key} samples.")
    
    # First split: separate test set. The keys are split alongside the indices
    # so the second split can reuse them instead of rebuilding every string.
    train_val_indices, test_indices, remaining_stratify_key, _ = train_test_split(
        target_indices, 
        stratify_key,
        test_size=test_size, 
        random_state=seed, 
        stratify=stratify_key
    )
    
    # Second split: separate train and validation. valid_size is expressed as
    # a fraction of what is left after removing the test set.
    relative_valid_size = valid_size / (1 - test_size)
    
    train_indices, valid_indices = train_test_split(
        train_val_indices,
        test_size=relative_valid_size,
        random_state=seed,
        stratify=remaining_stratify_key
    )
    
    # Apply subsampling if requested
    if subsample_train_ratio is not None and subsample_train_ratio < 1.0:
        n_keep = int(len(train_indices) * subsample_train_ratio)
        if n_keep == 0:
            raise ValueError("Subsampling ratio too small")
        
        # A local RandomState yields the same draw as np.random.seed(seed)
        # followed by np.random.choice, without reseeding the global generator.
        subsample_rng = np.random.RandomState(seed)
        train_indices = subsample_rng.choice(train_indices, n_keep, replace=False)
        print(f"üìâ Subsampled training data to {n_keep:,} samples ({subsample_train_ratio:.1%})")
    
    # Create label mapping
    label_map = {name: i for i, name in enumerate(target_modulations)}
    
    # Print statistics
    print(f"\n‚úÖ Dataset split completed:")
    print(f"  üìö Train: {len(train_indices):,} samples")
    print(f"  üîç Valid: {len(valid_indices):,} samples") 
    print(f"  üß™ Test: {len(test_indices):,} samples")
    
    splits = {
        'train': train_indices,
        'valid': valid_indices,
        'test': test_indices
    }
    
    return splits, label_map

print("‚úÖ Memory-efficient functions defined")

‚úÖ Memory-efficient functions defined


In [10]:
"""Not memory-efficient: superseded by stratified_dataset_split_memory_efficient above; kept commented out for reference."""
# def stratified_dataset_split(data: np.ndarray, 
#                            modulations: np.ndarray, 
#                            snrs: np.ndarray, 
#                            target_modulations: List[str],
#                            train_size: float = 0.7, 
#                            valid_size: float = 0.2, 
#                            test_size: float = 0.1,
#                            seed: int = 48, 
#                            subsample_train_ratio: Optional[float] = None) -> Tuple[Dict, Dict]:
#     # Input validation
#     if not np.isclose(train_size + valid_size + test_size, 1.0, atol=1e-6):
#         raise ValueError(f"Split sizes must sum to 1.0, got {train_size + valid_size + test_size}")
    
#     if len(target_modulations) == 0:
#         raise ValueError("target_modulations cannot be empty")
    
#     if subsample_train_ratio is not None and (subsample_train_ratio <= 0 or subsample_train_ratio > 1.0):
#         raise ValueError("subsample_train_ratio must be between 0 and 1")
    
#     print("üîÑ Performing stratified split on dataset...")
    
#     # Filter data to include only target modulations
#     target_mask = np.isin(modulations, target_modulations)
#     target_indices = np.where(target_mask)[0]
    
#     if len(target_indices) == 0:
#         raise ValueError("No samples found for target modulations")
    
#     filtered_modulations = modulations[target_indices]
#     filtered_snrs = snrs[target_indices]
    
#     # Create stratification key combining modulation and SNR
#     stratify_key = [f"{mod}_{snr}" for mod, snr in zip(filtered_modulations, filtered_snrs)]
    
#     # Check if we have enough samples for stratification
#     unique_keys, key_counts = np.unique(stratify_key, return_counts=True)
#     min_samples_per_key = np.min(key_counts)
    
#     if min_samples_per_key < 2:
#         warnings.warn(f"‚ö†Ô∏è Some modulation-SNR combinations have only {min_samples_per_key} samples.")
    
#     # Set random seed for reproducibility
#     np.random.seed(seed)
    
#     # First split: separate test set
#     train_val_indices, test_indices = train_test_split(
#         target_indices, 
#         test_size=test_size, 
#         random_state=seed, 
#         stratify=stratify_key
#     )
    
#     # Second split: separate train and validation
#     remaining_modulations = modulations[train_val_indices]
#     remaining_snrs = snrs[train_val_indices]
#     remaining_stratify_key = [f"{mod}_{snr}" for mod, snr in zip(remaining_modulations, remaining_snrs)]
    
#     relative_valid_size = valid_size / (1 - test_size)
    
#     train_indices, valid_indices = train_test_split(
#         train_val_indices,
#         test_size=relative_valid_size,
#         random_state=seed,
#         stratify=remaining_stratify_key
#     )
    
#     # Apply training data subsampling if requested
#     if subsample_train_ratio is not None and subsample_train_ratio < 1.0:
#         n_keep = int(len(train_indices) * subsample_train_ratio)
#         if n_keep == 0:
#             raise ValueError("Subsampling ratio too small")
        
#         np.random.seed(seed)
#         train_indices = np.random.choice(train_indices, n_keep, replace=False)
#         print(f"üìâ Subsampled training data to {n_keep} samples ({subsample_train_ratio:.1%})")
    
#     # Create label mapping
#     label_map = {name: i for i, name in enumerate(target_modulations)}
    
#     # Print split statistics
#     print(f"\n‚úÖ Dataset split completed:")
#     print(f"  üìö Train: {len(train_indices)} samples")
#     print(f"  üîç Valid: {len(valid_indices)} samples") 
#     print(f"  üß™ Test: {len(test_indices)} samples")
    
#     splits = {
#         'train': train_indices,
#         'valid': valid_indices,
#         'test': test_indices
#     }
    
#     return splits, label_map

# print("‚úÖ Splitting function defined")

‚úÖ Splitting function defined


In [21]:
class DualStreamRadioMLDataset(Dataset):
    """Memory-efficient dataset class that reads from HDF5 on-demand.

    Each item is the tuple ``(amplitude_2d, phase_2d, iq_sequence, y, z)``:
    two ``(1, 32, 32)`` tensors derived from the standardized I/Q signal, the
    standardized I/Q sequence itself, the integer class id and the ``Z`` value
    of the sample (treated as SNR elsewhere in this notebook).

    HDF5 handling: label metadata is read once with a short-lived handle.
    The handle used for signal reads is opened lazily, once per process, and
    is dropped when the object is pickled. This keeps the dataset usable with
    ``DataLoader(num_workers > 0)``: every worker opens its own handle instead
    of sharing one that was created in the parent process.
    """
    
    def __init__(self, 
                 file_path: str, 
                 json_path: str, 
                 target_modulations: List[str], 
                 mode: str, 
                 indices: np.ndarray, 
                 label_map: Dict[str, int], 
                 normalization_stats: Optional[Dict] = None, 
                 seed: int = 49):
        """Initialize the dataset.

        Args:
            file_path: HDF5 file containing the ``X``, ``Y`` and ``Z`` datasets.
            json_path: JSON file with class names, indexable by integer class id.
            target_modulations: Modulations of interest (stored on the instance).
            mode: One of ``'train'``, ``'valid'``, ``'test'``.
            indices: Row indices of ``X`` that belong to this split.
            label_map: Mapping from modulation name to integer class id.
            normalization_stats: Pre-computed statistics. Required for
                ``'valid'``/``'test'``; computed from a sample of ``indices``
                when omitted in ``'train'`` mode.
            seed: Seed for the sample drawn when computing statistics.

        Raises:
            ValueError: On an unknown mode, empty ``indices``, a signal length
                other than 1024, or missing statistics outside train mode.
        """
        super(DualStreamRadioMLDataset, self).__init__()

        # Handle bookkeeping first, so close()/__del__ are safe even when a
        # validation error below aborts construction.
        self._h5_handle = None    # open h5py.File owned by this process, if any
        self._h5_signals = None   # the 'X' dataset of that file
        self._h5_pid = None       # pid of the process that opened the handle

        # Validate inputs
        if mode not in ['train', 'valid', 'test']:
            raise ValueError(f"mode must be 'train', 'valid', or 'test'")
        
        if len(indices) == 0:
            raise ValueError("indices cannot be empty")
        
        # Store parameters
        self.file_path = file_path
        self.json_path = json_path
        self.target_modulations = target_modulations
        self.mode = mode
        self.indices = np.array(indices, dtype=int)
        self.label_map = label_map
        self.seed = seed

        # Read the label metadata with a short-lived handle; nothing stays
        # open that could later be inherited by forked DataLoader workers.
        with h5py.File(self.file_path, 'r') as metadata_file:
            signal_length = metadata_file['X'].shape[1]
            # Load only labels (small memory footprint)
            self.Y_int = np.argmax(metadata_file['Y'][:], axis=1)
            self.Z = metadata_file['Z'][:, 0]

        # Load modulation classes
        with open(self.json_path, 'r') as f:
            self.modulation_classes = json.load(f)

        # Vectorised id -> name lookup (replaces a per-sample Python loop).
        self.Y_strings = np.asarray(self.modulation_classes)[self.Y_int]

        # Validate signal dimensions: 1024 samples are viewed as a 32x32 grid.
        if signal_length != 1024:
            raise ValueError(f"Expected signal length 1024, got {signal_length}")
        self.H, self.W = 32, 32

        # Handle normalization statistics
        if normalization_stats is not None:
            self.norm_stats = normalization_stats
        elif mode == 'train':
            print(f"üìä Calculating normalization stats for {mode} mode...")
            try:
                self.norm_stats = self._calculate_normalization_stats()
            finally:
                # Do not keep the parent's handle open after the one-off pass.
                self._release_handle()
            print(f"‚úÖ Stats calculated")
        else:
            raise ValueError(f"normalization_stats required for '{mode}' mode")

        print(f"‚úÖ {mode.capitalize()} dataset: {len(self.indices):,} samples (memory-efficient mode)")

    # ------------------------------------------------------------------
    # HDF5 handle management
    # ------------------------------------------------------------------
    def _signals(self):
        """Return the ``X`` dataset, opening the file in this process if needed."""
        current_pid = os.getpid()
        if self._h5_signals is None or self._h5_pid != current_pid:
            # Nothing is open yet, or this object was copied into another
            # process (fork) together with a handle opened by the parent.
            # Either way this process gets its own handle.
            self._h5_handle = h5py.File(self.file_path, 'r')
            self._h5_signals = self._h5_handle['X']
            self._h5_pid = current_pid
        return self._h5_signals

    def _release_handle(self) -> bool:
        """Close the current handle silently; return True if one was closed."""
        handle = getattr(self, '_h5_handle', None)
        self._h5_handle = None
        self._h5_signals = None
        self._h5_pid = None
        if handle is None:
            return False
        try:
            handle.close()
        except Exception:
            # Best effort: closing can fail late in interpreter shutdown.
            return False
        return True

    @property
    def X_h5(self):
        """The HDF5 ``X`` dataset (opened on first access in each process)."""
        return self._signals()

    @property
    def hdf5_file(self):
        """The open HDF5 file backing ``X_h5`` (opened on first access)."""
        self._signals()
        return self._h5_handle

    def __getstate__(self):
        """Pickle without the HDF5 handle (h5py objects cannot be pickled)."""
        state = self.__dict__.copy()
        state['_h5_handle'] = None
        state['_h5_signals'] = None
        state['_h5_pid'] = None
        return state

    # ------------------------------------------------------------------
    # Statistics
    # ------------------------------------------------------------------
    def _calculate_normalization_stats(self) -> Dict[str, float]:
        """Calculate normalization statistics using chunked processing.

        Up to 5000 of this split's indices are sampled without replacement
        and read in chunks of at most 500 rows.

        Returns:
            Dict with ``i_mean``, ``i_std``, ``q_mean``, ``q_std`` (per-channel
            statistics of the raw signal) and ``amp_max`` (largest raw
            amplitude ``sqrt(I^2 + Q^2)`` in the sample).
        """
        num_samples = min(5000, len(self.indices))
        
        # Local generator: same draw as np.random.seed(seed) + np.random.choice,
        # but NumPy's global random state is not modified.
        rng = np.random.RandomState(self.seed)
        sample_indices = rng.choice(self.indices, num_samples, replace=False)
        # Sorted so the fancy-indexed reads below use increasing row order.
        sorted_indices = np.sort(sample_indices)
        
        # Process in chunks to avoid memory issues
        chunk_size = min(500, num_samples)  # Smaller chunks for stats calculation
        
        i_vals = []
        q_vals = []
        amp_vals = []
        
        print(f"  Processing {num_samples} samples in chunks of {chunk_size}...")
        
        signals = self._signals()
        for i in range(0, len(sorted_indices), chunk_size):
            chunk_indices = sorted_indices[i:i+chunk_size]
            chunk_data = signals[chunk_indices, ...]
            
            # Convert to tensor
            chunk_tensor = torch.from_numpy(chunk_data).float()
            
            # Extract I/Q values (channel 0 = I, channel 1 = Q)
            i_chunk = chunk_tensor[:, :, 0].flatten()
            q_chunk = chunk_tensor[:, :, 1].flatten()
            amp_chunk = torch.sqrt(i_chunk**2 + q_chunk**2)
            
            i_vals.append(i_chunk)
            q_vals.append(q_chunk)
            amp_vals.append(amp_chunk)
            
            # Clear chunk data to free memory
            del chunk_data, chunk_tensor
        
        # Concatenate all chunks
        i_all = torch.cat(i_vals)
        q_all = torch.cat(q_vals)
        amp_all = torch.cat(amp_vals)
        
        stats = {
            'i_mean': i_all.mean().item(),
            'i_std': i_all.std().item(),
            'q_mean': q_all.mean().item(), 
            'q_std': q_all.std().item(),
            'amp_max': amp_all.max().item()
        }
        
        # Validate statistics: a zero std would cause division by zero later.
        if stats['i_std'] == 0 or stats['q_std'] == 0:
            warnings.warn("‚ö†Ô∏è Std dev is 0, adding epsilon")
            stats['i_std'] = max(stats['i_std'], 1e-8)
            stats['q_std'] = max(stats['q_std'], 1e-8)
        
        # Clean up
        del i_vals, q_vals, amp_vals, i_all, q_all, amp_all
        gc.collect()
        
        return stats

    # ------------------------------------------------------------------
    # Dataset protocol
    # ------------------------------------------------------------------
    def __len__(self) -> int:
        """Number of samples in this split."""
        return len(self.indices)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, float]:
        """Get a single sample (loaded on-demand from HDF5).

        Returns:
            ``(amplitude_2d, phase_2d, iq_sequence, y, z)``.

        Raises:
            IndexError: If ``idx`` is past the end of this split.
            ValueError: If the sample's modulation is missing from ``label_map``.
        """
        if idx >= len(self.indices):
            raise IndexError(f"Index {idx} out of range")
        
        true_index = int(self.indices[idx])
        
        # Load only this single sample from disk
        x_raw = self._signals()[true_index]
        y_string = self.Y_strings[true_index]
        z = float(self.Z[true_index])
        
        if y_string not in self.label_map:
            raise ValueError(f"Modulation '{y_string}' not found")
        y = self.label_map[y_string]

        # Convert to tensor and standardize each channel. The copy keeps the
        # in-place writes below away from the array returned by the store.
        iq_sequence = torch.from_numpy(x_raw.copy()).float()
        iq_sequence[:, 0] = (iq_sequence[:, 0] - self.norm_stats['i_mean']) / self.norm_stats['i_std']
        iq_sequence[:, 1] = (iq_sequence[:, 1] - self.norm_stats['q_mean']) / self.norm_stats['q_std']

        # Calculate amplitude and phase (from the standardized channels)
        i_signal = iq_sequence[:, 0]
        q_signal = iq_sequence[:, 1]
        amplitude = torch.sqrt(i_signal**2 + q_signal**2)
        phase = torch.atan2(q_signal, i_signal)

        # Reshape to 2D
        amplitude_2d = amplitude.view(1, self.H, self.W)
        phase_2d = phase.view(1, self.H, self.W)

        # Normalize. atan2 lies in [-pi, pi], so phase ends up in [-1, 1].
        # NOTE(review): 'amp_max' is measured on the RAW signal while the
        # amplitude here comes from the STANDARDIZED signal, so this ratio is
        # not guaranteed to stay within [0, 1] - confirm this is intended.
        amplitude_2d = amplitude_2d / self.norm_stats['amp_max']
        phase_2d = phase_2d / math.pi
        
        return amplitude_2d, phase_2d, iq_sequence, y, z

    # ------------------------------------------------------------------
    # Introspection helpers
    # ------------------------------------------------------------------
    def get_normalization_stats(self) -> Dict[str, float]:
        """Return a shallow copy of the normalization statistics."""
        return self.norm_stats.copy()

    def get_class_distribution(self) -> Dict[str, int]:
        """Return ``{modulation_name: count}`` over this split's indices."""
        y_strings_subset = self.Y_strings[self.indices]
        unique, counts = np.unique(y_strings_subset, return_counts=True)
        return dict(zip(unique, counts))

    def get_snr_distribution(self) -> Dict[float, int]:
        """Return ``{z_value: count}`` over this split's indices."""
        z_subset = self.Z[self.indices]
        unique, counts = np.unique(z_subset, return_counts=True)
        return dict(zip(unique, counts))

    def close(self):
        """Important: Close the HDF5 file when done.

        Safe to call repeatedly. A later item access re-opens the file.
        """
        if self._release_handle():
            print(f"üîí {self.mode.capitalize()} dataset: HDF5 file closed")

    def __del__(self):
        # Finalizers must never raise (this may run during interpreter shutdown).
        try:
            self.close()
        except Exception:
            pass

print("‚úÖ Memory-efficient dataset class defined")

‚úÖ Memory-efficient dataset class defined


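In [11]:
"""Not memory-efficient: superseded by the on-demand DualStreamRadioMLDataset above; kept commented out for reference."""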
# class DualStreamRadioMLDataset(Dataset):
#     """Dataset class for dual-stream radio ML data (Version 3)."""
    
#     def __init__(self, 
#                  file_path: str, 
#                  json_path: str, 
#                  target_modulations: List[str], 
#                  mode: str, 
#                  indices: np.ndarray, 
#                  label_map: Dict[str, int], 
#                  normalization_stats: Optional[Dict] = None, 
#                  seed: int = 49):
#         """Initialize the dataset."""
#         super(DualStreamRadioMLDataset, self).__init__()

#         # Validate inputs
#         if mode not in ['train', 'valid', 'test']:
#             raise ValueError(f"mode must be 'train', 'valid', or 'test', got '{mode}'")
        
#         if len(indices) == 0:
#             raise ValueError("indices cannot be empty")
        
#         # Store parameters
#         self.file_path = file_path
#         self.json_path = json_path
#         self.target_modulations = target_modulations
#         self.mode = mode
#         self.indices = np.array(indices, dtype=int)
#         self.label_map = label_map
#         self.seed = seed

#         # Open HDF5 file and load metadata
#         self.hdf5_file = h5py.File(self.file_path, 'r')
#         self.X_h5 = self.hdf5_file['X']
#         self.Y_int = np.argmax(self.hdf5_file['Y'][:], axis=1)
#         self.Z = self.hdf5_file['Z'][:, 0]

#         # Load modulation classes
#         with open(self.json_path, 'r') as f:
#             self.modulation_classes = json.load(f)

#         self.Y_strings = np.array([self.modulation_classes[i] for i in self.Y_int])

#         # Validate signal dimensions
#         signal_length = self.X_h5.shape[1]
#         if signal_length != 1024:
#             raise ValueError(f"Expected signal length 1024, got {signal_length}")
#         self.H, self.W = 32, 32

#         # Handle normalization statistics
#         if mode == 'train':
#             if normalization_stats is None:
#                 print(f"üìä Calculating normalization stats for {mode} mode...")
#                 self.norm_stats = self._calculate_normalization_stats()
#                 print(f"‚úÖ Stats: {self.norm_stats}")
#             else:
#                 print(f"üìà Using provided normalization stats for {mode} mode")
#                 self.norm_stats = normalization_stats
#         else:
#             if normalization_stats is None:
#                 raise ValueError(f"normalization_stats required for '{mode}' mode")
#             print(f"üìà Using provided normalization stats for {mode} mode")
#             self.norm_stats = normalization_stats

#         print(f"‚úÖ {mode.capitalize()} dataset: {len(self.indices)} samples")

#     def _calculate_normalization_stats(self) -> Dict[str, float]:
#         """Calculate normalization statistics."""
#         num_samples = min(5000, len(self.indices))
        
#         np.random.seed(self.seed)
#         sample_indices = np.random.choice(self.indices, num_samples, replace=False)
#         sorted_sample_indices = np.sort(sample_indices)
        
#         sample_data = torch.from_numpy(self.X_h5[sorted_sample_indices, ...]).float()

#         i_flat = sample_data[:, :, 0].flatten()
#         q_flat = sample_data[:, :, 1].flatten()
#         amplitude = torch.sqrt(i_flat**2 + q_flat**2)

#         stats = {
#             'i_mean': i_flat.mean().item(),
#             'i_std': i_flat.std().item(),
#             'q_mean': q_flat.mean().item(), 
#             'q_std': q_flat.std().item(),
#             'amp_max': amplitude.max().item()
#         }
        
#         if stats['i_std'] == 0 or stats['q_std'] == 0:
#             warnings.warn("‚ö†Ô∏è Std dev is 0, adding epsilon")
#             stats['i_std'] = max(stats['i_std'], 1e-8)
#             stats['q_std'] = max(stats['q_std'], 1e-8)
        
#         return stats

#     def __len__(self) -> int:
#         return len(self.indices)

#     def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, float]:
#         """Get a single sample."""
#         if idx >= len(self.indices):
#             raise IndexError(f"Index {idx} out of range")
        
#         true_index = int(self.indices[idx])
        
#         # Load raw data
#         x_raw = self.X_h5[true_index]
#         y_string = self.Y_strings[true_index]
#         z = float(self.Z[true_index])
        
#         if y_string not in self.label_map:
#             raise ValueError(f"Modulation '{y_string}' not found")
#         y = self.label_map[y_string]

#         # Normalize I/Q channels
#         iq_sequence = torch.from_numpy(x_raw).float()
#         iq_sequence[:, 0] = (iq_sequence[:, 0] - self.norm_stats['i_mean']) / self.norm_stats['i_std']
#         iq_sequence[:, 1] = (iq_sequence[:, 1] - self.norm_stats['q_mean']) / self.norm_stats['q_std']

#         # Calculate amplitude and phase
#         i_signal = iq_sequence[:, 0]
#         q_signal = iq_sequence[:, 1]
#         amplitude = torch.sqrt(i_signal**2 + q_signal**2)
#         phase = torch.atan2(q_signal, i_signal)

#         # Reshape to 2D format (32x32)
#         amplitude_2d = amplitude.view(1, self.H, self.W)
#         phase_2d = phase.view(1, self.H, self.W)

#         # Normalize
#         amplitude_2d = amplitude_2d / self.norm_stats['amp_max']
#         phase_2d = phase_2d / math.pi
        
#         return amplitude_2d, phase_2d, iq_sequence, y, z

#     def get_normalization_stats(self) -> Dict[str, float]:
#         return self.norm_stats.copy()

#     def get_class_distribution(self) -> Dict[str, int]:
#         y_strings_subset = self.Y_strings[self.indices]
#         unique, counts = np.unique(y_strings_subset, return_counts=True)
#         return dict(zip(unique, counts))

#     def get_snr_distribution(self) -> Dict[float, int]:
#         z_subset = self.Z[self.indices]
#         unique, counts = np.unique(z_subset, return_counts=True)
#         return dict(zip(unique, counts))

#     def close(self):
#         if hasattr(self, 'hdf5_file') and self.hdf5_file is not None:
#             try:
#                 self.hdf5_file.close()
#                 print(f"üîí {self.mode.capitalize()} dataset: HDF5 file closed")
#             except:
#                 pass
#             finally:
#                 self.hdf5_file = None

#     def __del__(self):
#         self.close()

# print("‚úÖ Dataset class defined")

‚úÖ Dataset class defined


In [13]:
def visualize_data_distribution(datasets: Dict, label_map: Dict):
    """Visualize the distribution of data across splits.

    Draws a grid with one column per entry of ``datasets``: the top row shows
    the per-class sample counts, the bottom row the per-SNR sample counts.

    Args:
        datasets: Mapping ``split_name -> dataset``. Each dataset must provide
            ``get_class_distribution()`` and ``get_snr_distribution()``, both
            returning a ``{key: count}`` dict.
        label_map: Accepted for interface compatibility; not used here.
    """
    # One column per split instead of a fixed 3, so any number of splits can
    # be plotted. max(..., 1) keeps an (empty) figure for an empty mapping.
    n_cols = max(len(datasets), 1)
    # squeeze=False keeps `axes` two-dimensional even with a single column.
    fig, axes = plt.subplots(2, n_cols, figsize=(5 * n_cols, 8), squeeze=False)
    fig.suptitle('Dataset Distribution Analysis', fontsize=16)
    
    for idx, (split_name, dataset) in enumerate(datasets.items()):
        # Class distribution
        class_dist = dataset.get_class_distribution()
        ax = axes[0, idx]
        ax.bar(range(len(class_dist)), list(class_dist.values()))
        ax.set_title(f'{split_name.capitalize()} - Class Distribution')
        ax.set_xlabel('Modulation Type')
        ax.set_ylabel('Sample Count')
        ax.set_xticks(range(len(class_dist)))
        ax.set_xticklabels(list(class_dist.keys()), rotation=45, ha='right')
        
        # SNR distribution
        snr_dist = dataset.get_snr_distribution()
        ax = axes[1, idx]
        snrs = list(snr_dist.keys())
        counts = list(snr_dist.values())
        ax.bar(snrs, counts)
        ax.set_title(f'{split_name.capitalize()} - SNR Distribution')
        ax.set_xlabel('SNR (dB)')
        ax.set_ylabel('Sample Count')
    
    plt.tight_layout()
    plt.show()

In [22]:
def test_batch_loading(loader: DataLoader, name: str = "Train", show_plot: bool = True):
    """Test loading a batch and display statistics."""
    print(f"\nüß™ Testing {name} loader...")
    
    for batch_idx, (amp, phase, iq, labels, snrs) in enumerate(loader):
        print(f"\n  Batch {batch_idx + 1}:")
        print(f"  ‚îú‚îÄ Amplitude shape: {amp.shape}")
        print(f"  ‚îú‚îÄ Phase shape: {phase.shape}")
        print(f"  ‚îú‚îÄ IQ sequence shape: {iq.shape}")
        print(f"  ‚îú‚îÄ Labels shape: {labels.shape}")
        print(f"  ‚îú‚îÄ SNRs shape: {snrs.shape}")
        print(f"  ‚îú‚îÄ Label range: [{labels.min().item()}, {labels.max().item()}]")
        print(f"  ‚îî‚îÄ SNR range: [{snrs.min().item():.1f}, {snrs.max().item():.1f}] dB")
        
        if batch_idx == 0 and show_plot:
            # Visualize first sample
            fig, axes = plt.subplots(1, 3, figsize=(12, 4))
            
            # Amplitude
            axes[0].imshow(amp[0, 0].cpu().numpy(), cmap='hot')
            axes[0].set_title(f'Amplitude (Label: {labels[0].item()})')
            axes[0].axis('off')
            
            # Phase
            axes[1].imshow(phase[0, 0].cpu().numpy(), cmap='hsv')
            axes[1].set_title(f'Phase (SNR: {snrs[0].item():.1f} dB)')
            axes[1].axis('off')
            
            # IQ time series
            axes[2].plot(iq[0, :, 0].cpu().numpy(), label='I', alpha=0.7)
            axes[2].plot(iq[0, :, 1].cpu().numpy(), label='Q', alpha=0.7)
            axes[2].set_title('I/Q Signals')
            axes[2].set_xlabel('Sample')
            axes[2].set_ylabel('Amplitude')
            axes[2].legend()
            axes[2].grid(True, alpha=0.3)
            
            plt.suptitle(f'Sample from {name} Batch', fontsize=14)
            plt.tight_layout()
            plt.show()
            break

# Confirmation marker: printed once this cell's helper definitions have executed.
print("‚úÖ Helper functions defined")

‚úÖ Helper functions defined


In [36]:
# %%
# Placeholders: every tracking name exists (as None) even if loading fails below.
X_data = Y_strings = Z_data = None
train_dataset = valid_dataset = test_dataset = None

try:
    # ============= Load Metadata Only =============
    print("\n" + "="*60)
    print("STEP 1: LOADING DATASET METADATA")
    print("="*60)

    (
        Y_strings,
        Z_data,
        available_modulations,
        total_samples,
    ) = load_dataset_metadata(FILE_PATH, JSON_PATH)

    # List every modulation reported by the metadata loader, with its index
    print(f"\nüì° Available modulations ({len(available_modulations)}):")
    for i, mod in enumerate(available_modulations):
        print(f"  {i:2d}. {mod}")

    # Modulations used for this experiment come from the configuration cell
    target_modulations = TARGET_MODULATIONS
    print(f"\nüéØ Using {len(target_modulations)} target modulations")

    # Release any temporaries created while loading
    gc.collect()

except MemoryError as e:
    print(f"\n‚ùå Memory Error: {e}")
    print("The dataset is too large. The memory-efficient version should handle this.")
    raise
except FileNotFoundError as e:
    print(f"\n‚ùå Error: File not found - {e}")
    print("Please update FILE_PATH and JSON_PATH in the configuration section.")
    raise


STEP 1: LOADING DATASET METADATA
üìÇ Loading dataset metadata (memory efficient)...
üìä Dataset shape: (2,555,904 √ó 1024 √ó 2)
‚úÖ Metadata loaded: 2,555,904 samples
üì° Available modulations: 24
üìä SNR range: -20.0 to 30.0 dB
üíæ Full dataset size: ~19.50 GB

üì° Available modulations (24):
   0. 128APSK
   1. 128QAM
   2. 16APSK
   3. 16PSK
   4. 16QAM
   5. 256QAM
   6. 32APSK
   7. 32PSK
   8. 32QAM
   9. 4ASK
  10. 64APSK
  11. 64QAM
  12. 8ASK
  13. 8PSK
  14. AM-DSB-SC
  15. AM-DSB-WC
  16. AM-SSB-SC
  17. AM-SSB-WC
  18. BPSK
  19. FM
  20. GMSK
  21. OOK
  22. OQPSK
  23. QPSK

üéØ Using 5 target modulations


In [24]:
# Split configuration gathered in one place; values come from the config constants.
_split_options = dict(
    train_size=TRAIN_SIZE,
    valid_size=VALID_SIZE,
    test_size=TEST_SIZE,
    seed=SPLIT_SEED,
    subsample_train_ratio=SUBSAMPLE_TRAIN_RATIO,
)
splits, label_map = stratified_dataset_split_memory_efficient(
    Y_strings, Z_data, target_modulations, **_split_options
)

# Show the modulation -> integer label mapping as a table
print("\nüè∑Ô∏è Label mapping:")
label_df = pd.DataFrame(
    list(label_map.items()),
    columns=['Modulation', 'Label'],
)
display(label_df)

üîÑ Performing stratified split...
üìä Found 532,480 samples for target modulations
üìâ Subsampled training data to 74,547 samples (20.0%)

‚úÖ Dataset split completed:
  üìö Train: 74,547 samples
  üîç Valid: 106,497 samples
  üß™ Test: 53,248 samples

üè∑Ô∏è Label mapping:


Unnamed: 0,Modulation,Label
0,8ASK,0
1,BPSK,1
2,8PSK,2
3,16QAM,3
4,64QAM,4


In [25]:
# Clean up - the full Y_strings and Z_data arrays are not needed anymore.
# Removing the names via globals().pop() keeps this cell idempotent: a plain
# `del` raises NameError when the cell is re-run after the names are already gone.
for _stale_name in ("Y_strings", "Z_data"):
    globals().pop(_stale_name, None)
gc.collect()

0

In [26]:
# ============= Create Datasets =============
print("\n" + "="*60)
print("STEP 3: CREATING DATASETS (Memory Efficient)")
print("="*60)

# Arguments shared by all three splits
_dataset_sources = (FILE_PATH, JSON_PATH, target_modulations)
_dataset_common = dict(label_map=label_map, seed=NORM_SEED)

# Training split: no stats are supplied, so the dataset computes them itself
print("Creating training dataset...")
train_dataset = DualStreamRadioMLDataset(
    *_dataset_sources,
    mode='train',
    indices=splits['train'],
    normalization_stats=None,  # Will be calculated
    **_dataset_common,
)

# Reuse the training statistics for the held-out splits
train_norm_stats = train_dataset.get_normalization_stats()

# Validation split
print("\nCreating validation dataset...")
valid_dataset = DualStreamRadioMLDataset(
    *_dataset_sources,
    mode='valid',
    indices=splits['valid'],
    normalization_stats=train_norm_stats,
    **_dataset_common,
)

# Test split
print("\nCreating test dataset...")
test_dataset = DualStreamRadioMLDataset(
    *_dataset_sources,
    mode='test',
    indices=splits['test'],
    normalization_stats=train_norm_stats,
    **_dataset_common,
)

print("\n‚úÖ All datasets created successfully!")


STEP 3: CREATING DATASETS (Memory Efficient)
Creating training dataset...
üìä Calculating normalization stats for train mode...
  Processing 5000 samples in chunks of 500...
‚úÖ Stats calculated
‚úÖ Train dataset: 74,547 samples (memory-efficient mode)

Creating validation dataset...
‚úÖ Valid dataset: 106,497 samples (memory-efficient mode)

Creating test dataset...
‚úÖ Test dataset: 53,248 samples (memory-efficient mode)

‚úÖ All datasets created successfully!


In [27]:
# ============= Create Data Loaders =============
print("\n" + "="*60)
print("STEP 4: CREATING DATA LOADERS")
print("="*60)

# Settings common to all three loaders.
# Note: For memory efficiency, consider using smaller num_workers
_loader_options = dict(
    batch_size=BATCH_SIZE,
    pin_memory=torch.cuda.is_available(),  # Only if using GPU
    num_workers=min(NUM_WORKERS, 2),       # Reduce for memory efficiency
    persistent_workers=False,              # Set to False for memory efficiency
)

# Only the training loader shuffles; validation and test keep dataset order
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_options)
valid_loader = DataLoader(valid_dataset, shuffle=False, **_loader_options)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_options)

print("‚úÖ Data loaders created successfully")


STEP 4: CREATING DATA LOADERS
‚úÖ Data loaders created successfully


In [37]:
# Build the multi-domain model. A failure here is reported and then re-raised:
# swallowing it would let later cells run against an undefined (or stale,
# left over from an earlier run) `model_new`.
try:
    from CNN_LSTM_new import create_multi_domain_model
    from torchinfo import summary
    model_new = create_multi_domain_model(num_classes=NUM_CLASSES, dropout_rate=0.7).to(device)
except Exception as e:
    print(f"‚ö†Ô∏è Could not build Multi domain Model: {e}")
    raise

# The layer/parameter summary is informational only, so it stays best-effort.
try:
    print(summary(model_new))
except Exception as e:
    print(f"Could not print model summary: {e}")

Layer (type:depth-idx)                   Param #
MultiDomainFusionModel                   --
‚îú‚îÄCnn2DBranch: 1-1                       --
‚îÇ    ‚îî‚îÄSequential: 2-1                   --
‚îÇ    ‚îÇ    ‚îî‚îÄConv2d: 3-1                  640
‚îÇ    ‚îÇ    ‚îî‚îÄBatchNorm2d: 3-2             128
‚îÇ    ‚îÇ    ‚îî‚îÄLeakyReLU: 3-3               --
‚îÇ    ‚îÇ    ‚îî‚îÄConv2d: 3-4                  73,856
‚îÇ    ‚îÇ    ‚îî‚îÄBatchNorm2d: 3-5             256
‚îÇ    ‚îÇ    ‚îî‚îÄLeakyReLU: 3-6               --
‚îÇ    ‚îÇ    ‚îî‚îÄMaxPool2d: 3-7               --
‚îÇ    ‚îÇ    ‚îî‚îÄDropout2d: 3-8               --
‚îÇ    ‚îÇ    ‚îî‚îÄConv2d: 3-9                  295,168
‚îÇ    ‚îÇ    ‚îî‚îÄBatchNorm2d: 3-10            512
‚îÇ    ‚îÇ    ‚îî‚îÄLeakyReLU: 3-11              --
‚îÇ    ‚îÇ    ‚îî‚îÄMaxPool2d: 3-12              --
‚îÇ    ‚îÇ    ‚îî‚îÄDropout2d: 3-13              --
‚îÇ    ‚îî‚îÄAdaptiveAvgPool2d: 2-2            --
‚îú‚îÄCnn2DBranch: 1-2                       --
‚îÇ    ‚îî‚îÄSequentia

In [38]:
# Loss, optimizer, LR scheduler and AMP gradient scaler for `model_new`
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer_new = optim.AdamW(
    model_new.parameters(),
    lr=1e-3,
    weight_decay=1e-4,
    betas=(0.9, 0.99)
)

# Plateau scheduler. mode='max' means the value passed to scheduler_new.step(...)
# must be a higher-is-better metric (e.g. validation accuracy).
scheduler_new = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer_new,
    mode='max',           # Monitor accuracy
    patience=5,           # Wait for 5 epochs without improvement
    factor=0.5,           # Reduce LR by half
    min_lr=1e-6,          # Minimum learning rate
)

# Gradient scaler for mixed precision.
# torch.amp.GradScaler("cuda") is the non-deprecated spelling of
# torch.cuda.amp.GradScaler().
scaler_new = torch.amp.GradScaler("cuda")


In [39]:
# Training history; lists get one entry per epoch from the training loop.
metrics = dict(
    train_losses=[],
    valid_losses=[],
    train_accuracies=[],
    valid_accuracies=[],
    training_times=[],
    best_accuracy=0.0,
    final_predictions=[],
    final_true_labels=[],
)

# Early-stopping bookkeeping: epochs since last improvement, and best weights so far
patience_counter, best_model_state = 0, None

# Default-configured gradient scaler
scaler = GradScaler()

In [40]:
def train_epoch(model, train_loader, optimizer, criterion, scaler, device):
    """
    Train the multi-domain model for one epoch with mixed precision.

    Args:
        model: Module called as ``model(amp, phase, iq_seq)`` and returning
            per-class scores of shape ``(batch, num_classes)``.
        train_loader: Iterable yielding 5-tuples
            ``(amp, phase, iq_seq, labels, extra)``; the fifth item is ignored.
        optimizer: Optimizer over ``model.parameters()``.
        criterion: Loss called as ``criterion(outputs, labels)``.
        scaler: Gradient scaler used for the scaled backward pass and step.
        device: Device the batch tensors are moved to.

    Returns:
        ``(epoch_loss, epoch_accuracy, epoch_time)``: sample-weighted mean loss,
        accuracy in percent, and wall-clock seconds. Loss and accuracy are
        averaged over the samples actually processed, so they stay correct
        with ``drop_last=True`` or a sub-sampling sampler. If the loader yields
        nothing, loss and accuracy are ``nan`` instead of raising
        ZeroDivisionError.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    start_time = time.time()

    # The dataloader yields 5 items; only the first 4 are needed here.
    for amp_inputs, phase_inputs, iq_seq_inputs, labels, _ in train_loader:
        amp_inputs = amp_inputs.to(device)
        phase_inputs = phase_inputs.to(device)
        iq_seq_inputs = iq_seq_inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        # torch.autocast(device_type="cuda") is the non-deprecated equivalent
        # of torch.cuda.amp.autocast().
        with torch.autocast(device_type="cuda"):
            outputs = model(amp_inputs, phase_inputs, iq_seq_inputs)
            loss = criterion(outputs, labels)

        # Scaled backward, then unscale BEFORE clipping so the 1.0 max-norm
        # applies to the true gradient magnitudes.
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()

        # criterion returns a batch mean; re-weight by batch size so the
        # epoch average is per-sample even when the last batch is smaller.
        batch_size = labels.size(0)
        running_loss += loss.item() * batch_size
        _, predicted = torch.max(outputs, 1)
        total += batch_size
        correct += (predicted == labels).sum().item()

    epoch_time = time.time() - start_time
    if total == 0:
        return float("nan"), float("nan"), epoch_time

    epoch_loss = running_loss / total
    epoch_accuracy = 100. * correct / total

    return epoch_loss, epoch_accuracy, epoch_time

def validate_epoch(model, valid_loader, criterion, device):
    """
    Evaluate the multi-domain model on a validation loader (no gradient updates).

    Args:
        model: Module called as ``model(amp, phase, iq_seq)`` and returning
            per-class scores of shape ``(batch, num_classes)``.
        valid_loader: Iterable yielding 5-tuples
            ``(amp, phase, iq_seq, labels, extra)``; the fifth item is ignored.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device the batch tensors are moved to.

    Returns:
        ``(epoch_loss, epoch_accuracy, predictions, true_labels)``:
        sample-weighted mean loss, accuracy in percent, and two Python lists
        with one predicted / true class index per processed sample. Loss and
        accuracy are averaged over the samples actually processed; if the
        loader yields nothing they are ``nan`` and both lists are empty.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    predictions = []
    true_labels = []

    with torch.no_grad():
        # The dataloader yields 5 items; only the first 4 are needed here.
        for amp_inputs, phase_inputs, iq_seq_inputs, labels, _ in valid_loader:
            amp_inputs = amp_inputs.to(device)
            phase_inputs = phase_inputs.to(device)
            iq_seq_inputs = iq_seq_inputs.to(device)
            labels = labels.to(device)

            outputs = model(amp_inputs, phase_inputs, iq_seq_inputs)
            loss = criterion(outputs, labels)

            # Re-weight the batch-mean loss by batch size for a per-sample average
            batch_size = labels.size(0)
            running_loss += loss.item() * batch_size
            _, predicted = torch.max(outputs, 1)
            total += batch_size
            correct += (predicted == labels).sum().item()

            predictions.extend(predicted.cpu().numpy())
            true_labels.extend(labels.cpu().numpy())

    if total == 0:
        return float("nan"), float("nan"), predictions, true_labels

    epoch_loss = running_loss / total
    epoch_accuracy = 100. * correct / total

    return epoch_loss, epoch_accuracy, predictions, true_labels

def test_epoch(model, test_loader, criterion, device):
    """
    Evaluate the multi-domain model on a test loader, showing a progress bar.

    Args:
        model: Module called as ``model(amp, phase, iq_seq)`` and returning
            per-class scores of shape ``(batch, num_classes)``.
        test_loader: Iterable yielding 5-tuples
            ``(amp, phase, iq_seq, labels, extra)``; the fifth item is ignored.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device the batch tensors are moved to.

    Returns:
        ``(epoch_loss, epoch_accuracy, predictions, true_labels)``:
        sample-weighted mean loss, accuracy in percent, and two Python lists
        with one predicted / true class index per processed sample. The final
        loss uses the same denominator (samples processed) as the running
        value on the progress bar; if the loader yields nothing, loss and
        accuracy are ``nan`` and both lists are empty.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    predictions = []
    true_labels = []

    test_iterator = tqdm(test_loader, desc="Testing", leave=False, ncols=100)

    with torch.no_grad():
        # The dataloader yields 5 items; only the first 4 are needed here.
        for amp_inputs, phase_inputs, iq_seq_inputs, labels, _ in test_iterator:
            amp_inputs = amp_inputs.to(device)
            phase_inputs = phase_inputs.to(device)
            iq_seq_inputs = iq_seq_inputs.to(device)
            labels = labels.to(device)

            outputs = model(amp_inputs, phase_inputs, iq_seq_inputs)
            loss = criterion(outputs, labels)

            # Re-weight the batch-mean loss by batch size for a per-sample average
            batch_size = labels.size(0)
            running_loss += loss.item() * batch_size
            _, predicted = torch.max(outputs, 1)
            total += batch_size
            correct += (predicted == labels).sum().item()

            predictions.extend(predicted.cpu().numpy())
            true_labels.extend(labels.cpu().numpy())

            # Update progress bar with the running averages
            avg_loss = running_loss / total
            accuracy = 100. * correct / total
            test_iterator.set_postfix(loss=f"{avg_loss:.4f}", accuracy=f"{accuracy:.2f}%")

    if total == 0:
        return float("nan"), float("nan"), predictions, true_labels

    epoch_loss = running_loss / total
    epoch_accuracy = 100. * correct / total

    return epoch_loss, epoch_accuracy, predictions, true_labels

In [None]:
# --- Simplified Training Loop ---
# Trains `model_new` for up to NUM_EPOCHS epochs with early stopping on
# validation accuracy, then evaluates the best checkpoint on the test set.
import copy
import time
print("\nüéØ Starting model training...")
print("=" * 80)

for epoch in tqdm(range(NUM_EPOCHS), desc="Training Progress"):
    # --- Train the model ---
    # train_epoch handles the 3 model inputs and reports its own wall-clock time
    train_loss, train_acc, train_time = train_epoch(
        model_new, train_loader, optimizer_new, criterion, scaler_new, device
    )

    # --- Validate the model ---
    valid_loss, valid_acc, predictions, true_labels = validate_epoch(
        model_new, valid_loader, criterion, device
    )

    # --- Update the flattened metrics dictionary ---
    metrics['train_losses'].append(train_loss)
    metrics['train_accuracies'].append(train_acc)
    metrics['valid_losses'].append(valid_loss)
    metrics['valid_accuracies'].append(valid_acc)
    metrics['training_times'].append(train_time)

    # --- Update scheduler ---
    # scheduler_new is a ReduceLROnPlateau created with mode='max', so it must
    # be stepped with the validation accuracy; without this call the learning
    # rate is never reduced.
    scheduler_new.step(valid_acc)

    # --- Check for best model and handle early stopping ---
    if valid_acc > metrics['best_accuracy']:
        metrics['best_accuracy'] = valid_acc
        # validate_epoch returns Python lists of per-sample class indices
        metrics['final_predictions'] = predictions
        metrics['final_true_labels'] = true_labels
        # Deep-copy so later optimizer steps cannot overwrite the snapshot
        best_model_state = copy.deepcopy(model_new.state_dict())
        patience_counter = 0 # Reset patience
    else:
        patience_counter += 1

    # --- Print epoch results every 5th epoch ---
    if (epoch + 1) % 5 == 0:
        print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}:")
        print(f"Train Acc: {train_acc:.2f}% | Valid Acc: {valid_acc:.2f}% | Time: {train_time:.2f}s")

    # --- Early stopping check ---
    if patience_counter >= patience:
        print(f"\nEarly stopping at epoch {epoch+1} - Model has not improved for {patience} epochs.")
        break

print("\nüéâ Training Complete!")
print(f"Best Validation Accuracy: {metrics['best_accuracy']:.2f}%")

# --- Final Testing on Test Set ---
# Only run when a test_loader exists and at least one epoch improved on the
# initial best accuracy (otherwise there is no checkpoint to restore).
if 'test_loader' in locals() and best_model_state is not None:
    print("\nüîç Final testing on the test set with the best model...")

    # Load the best model's weights
    model_new.load_state_dict(best_model_state)

    # Test the model
    test_loss, test_acc, test_predictions, test_true_labels = test_epoch(
        model_new, test_loader, criterion, device
    )

    # Store the final test results
    metrics['test_accuracy'] = test_acc
    metrics['test_loss'] = test_loss
    metrics['test_predictions'] = test_predictions
    metrics['test_true_labels'] = test_true_labels

    print(f"Final Test Accuracy: {test_acc:.2f}%")
else:
    print("\n‚ö†Ô∏è No test_loader found or model did not improve - skipping final testing.")


üéØ Starting model training...


Training Progress:   5%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñâ                                                                                                                 | 5/100 [11:22<3:38:19, 137.89s/it]


Epoch 5/100:
Train Acc: 62.42% | Valid Acc: 63.79% | Time: 80.67s


Training Progress:   7%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñé                                                                                                              | 7/100 [16:04<3:36:18, 139.55s/it]