# B-P Phoneme DL Data Preparation

Preparation of PyTorch datasets for deep learning models:
- Load features.parquet and spectrograms.h5
- Create PyTorch Dataset classes for different input types
- Train/Val/Test split with stratification
- Data normalization
- DataLoader creation with batch sampling
- Handle class imbalance

In [1]:
from pathlib import Path
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import h5py
import librosa
import soundfile as sf
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, MinMaxScaler, LabelEncoder
from sklearn.utils.class_weight import compute_class_weight
from tqdm import tqdm
import warnings
import json
warnings.filterwarnings('ignore')

# Configuration
# Determine project root (parent of notebooks directory)
PROJECT_ROOT = Path.cwd().parent if Path.cwd().name in ['notebooks', 'b-p_first_experiments'] else Path.cwd()
DATA_DIR = PROJECT_ROOT / 'artifacts' / 'b-p_dataset'
FEATURES_FILE = DATA_DIR / 'features' / 'features.parquet'
SPECTROGRAMS_FILE = DATA_DIR / 'features' / 'spectrograms.h5'
PHONEMES_FILE = DATA_DIR / 'filtered_phonemes.csv'
PHONEME_WAV_DIR = PROJECT_ROOT / 'artifacts' / 'phoneme_wav'

OUTPUT_DIR = PROJECT_ROOT / 'artifacts' / 'b-p_dl_models'
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

RANDOM_STATE = 42
np.random.seed(RANDOM_STATE)
torch.manual_seed(RANDOM_STATE)

# Device setup
if torch.backends.mps.is_available():
    device = torch.device("mps")
    print(f"Using MPS device")
elif torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"Using CUDA device")
else:
    device = torch.device("cpu")
    print(f"Using CPU device")

print(f"Project root: {PROJECT_ROOT}")
print(f"Data directory: {DATA_DIR}")
print(f"Output directory: {OUTPUT_DIR}")

Using MPS device
Project root: /Volumes/SSanDisk/SpeechRec-German
Data directory: /Volumes/SSanDisk/SpeechRec-German/artifacts/b-p_dataset
Output directory: /Volumes/SSanDisk/SpeechRec-German/artifacts/b-p_dl_models


## 1. Load Data

In [2]:
# Load features
print("Loading features...")
df_features = pd.read_parquet(FEATURES_FILE)
print(f"Features shape: {df_features.shape}")
print(f"Features columns: {len(df_features.columns)}")

# Load phoneme metadata
print("\nLoading phoneme metadata...")
df_phonemes = pd.read_csv(PHONEMES_FILE)
print(f"Phonemes shape: {df_phonemes.shape}")
print(f"Phonemes columns: {list(df_phonemes.columns)}")

# Merge on phoneme_id
df = df_phonemes.merge(df_features, on='phoneme_id', how='inner', suffixes=('', '_features'))
print(f"\nMerged dataset shape: {df.shape}")

# Handle duplicate columns from merge (if 'class' exists in both, keep the one from df_phonemes)
if 'class_features' in df.columns:
    df = df.drop(columns=['class_features'])
if 'class' not in df.columns:
    if 'phoneme' in df.columns:
        print("\n'class' column not found, creating from 'phoneme' column...")
        df['class'] = df['phoneme']
    else:
        raise ValueError("Neither 'class' nor 'phoneme' column found in merged DataFrame")
else:
    print("\n'class' column found in merged DataFrame")

# Filter to only b and p classes (exclude pf if present)
if 'pf' in df['class'].values:
    print("\nFiltering out 'pf' class, keeping only 'b' and 'p'...")
    df = df[df['class'].isin(['b', 'p'])].copy()
    print(f"Dataset after filtering: {len(df)} samples")

# Check class distribution
print(f"\nClass distribution:")
print(df['class'].value_counts())
print(f"\nClass distribution (%):")
print(df['class'].value_counts(normalize=True) * 100)

# Encode target
le = LabelEncoder()
df['class_encoded'] = le.fit_transform(df['class'])  # b=0, p=1
print(f"\nClass encoding: {dict(zip(le.classes_, le.transform(le.classes_)))}")

# Get feature columns (exclude metadata and non-numeric columns)
exclude_cols = ['phoneme_id', 'utterance_id', 'phoneme', 'class', 'class_encoded', 
                'start_ms', 'end_ms', 'duration_ms', 'audio_path', 'is_outlier_iso',
                'class_x', 'class_y', 'class_features']  # Exclude merge suffixes
feature_cols = [col for col in df.columns if col not in exclude_cols]

# Filter to only numeric columns
numeric_cols = df[feature_cols].select_dtypes(include=[np.number]).columns.tolist()
feature_cols = [col for col in feature_cols if col in numeric_cols]

print(f"\nNumber of feature columns: {len(feature_cols)}")
print(f"First 10 features: {feature_cols[:10]}")

Loading features...
Features shape: (35660, 112)
Features columns: 112

Loading phoneme metadata...
Phonemes shape: (35660, 8)
Phonemes columns: ['phoneme_id', 'utterance_id', 'phoneme', 'class', 'start_ms', 'end_ms', 'duration_ms', 'audio_path']

Merged dataset shape: (35660, 119)

'class' column found in merged DataFrame

Class distribution:
class
b    24917
p    10743
Name: count, dtype: int64

Class distribution (%):
class
b    69.873808
p    30.126192
Name: proportion, dtype: float64

Class encoding: {'b': np.int64(0), 'p': np.int64(1)}

Number of feature columns: 109
First 10 features: ['energy_rms', 'energy_rms_std', 'energy_zcr', 'energy_zcr_std', 'spectral_centroid', 'spectral_centroid_std', 'spectral_rolloff', 'spectral_rolloff_std', 'spectral_bandwidth', 'spectral_bandwidth_std']


## 2. Load Spectrograms from H5

In [3]:
# Load spectrograms from H5 file
print("Loading spectrograms from H5...")
spectrograms_dict = {}
with h5py.File(SPECTROGRAMS_FILE, 'r') as f:
    phoneme_ids = list(f.keys())
    print(f"Found {len(phoneme_ids)} spectrograms in H5 file")
    
    # Load first spectrogram to check shape
    if phoneme_ids:
        first_key = phoneme_ids[0]
        first_spec = f[first_key][:]
        print(f"Spectrogram shape: {first_spec.shape}")
        
        # Load all spectrograms
        for phoneme_id in tqdm(phoneme_ids, desc="Loading spectrograms"):
            spectrograms_dict[phoneme_id] = f[phoneme_id][:]

print(f"\nLoaded {len(spectrograms_dict)} spectrograms")

# Check which phonemes have spectrograms
df['has_spectrogram'] = df['phoneme_id'].isin(spectrograms_dict.keys())
print(f"Phonemes with spectrograms: {df['has_spectrogram'].sum()} / {len(df)}")

Loading spectrograms from H5...
Found 35660 spectrograms in H5 file
Spectrogram shape: (128, 7)


Loading spectrograms: 100%|██████████| 35660/35660 [00:02<00:00, 16297.51it/s]



Loaded 35660 spectrograms
Phonemes with spectrograms: 35660 / 35660


## 3. Train/Val/Test Split

In [4]:
# Filter to only phonemes with spectrograms
df = df[df['has_spectrogram']].copy()
print(f"Dataset after filtering: {len(df)} samples")

# Train/Val/Test split (70/15/15) with stratification
X_temp, X_test, y_temp, y_test = train_test_split(
    df.index, df['class_encoded'], 
    test_size=0.15, 
    random_state=RANDOM_STATE, 
    stratify=df['class_encoded']
)

X_train, X_val, y_train, y_val = train_test_split(
    X_temp, y_temp, 
    test_size=0.176,  # 0.176 ≈ 15/85
    random_state=RANDOM_STATE, 
    stratify=y_temp
)

# Create split column
df['split'] = 'train'
df.loc[X_val, 'split'] = 'val'
df.loc[X_test, 'split'] = 'test'

print(f"\nTrain set: {len(X_train):,} samples ({len(X_train)/len(df)*100:.1f}%)")
print(f"  Class distribution: {np.bincount(df.loc[X_train, 'class_encoded'])}")
print(f"Val set: {len(X_val):,} samples ({len(X_val)/len(df)*100:.1f}%)")
print(f"  Class distribution: {np.bincount(df.loc[X_val, 'class_encoded'])}")
print(f"Test set: {len(X_test):,} samples ({len(X_test)/len(df)*100:.1f}%)")
print(f"  Class distribution: {np.bincount(df.loc[X_test, 'class_encoded'])}")

# Save split indices
split_indices = {
    'train': [int(idx) for idx in X_train],
    'val': [int(idx) for idx in X_val],
    'test': [int(idx) for idx in X_test]
}

with open(OUTPUT_DIR / 'split_indices.json', 'w') as f:
    json.dump(split_indices, f)
print(f"\nSplit indices saved to {OUTPUT_DIR / 'split_indices.json'}")

Dataset after filtering: 35660 samples

Train set: 24,976 samples (70.0%)
  Class distribution: [17451  7525]
Val set: 5,335 samples (15.0%)
  Class distribution: [3728 1607]
Test set: 5,349 samples (15.0%)
  Class distribution: [3738 1611]

Split indices saved to /Volumes/SSanDisk/SpeechRec-German/artifacts/b-p_dl_models/split_indices.json


## 4. Create PyTorch Dataset Classes

In [5]:
class SpectrogramDataset(Dataset):
    """Dataset for models using spectrograms only"""
    def __init__(self, df, spectrograms_dict, split='train', transform=None):
        self.df = df[df['split'] == split].reset_index(drop=True)
        self.spectrograms_dict = spectrograms_dict
        self.transform = transform
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        phoneme_id = row['phoneme_id']
        
        # Get spectrogram
        spectrogram = self.spectrograms_dict[phoneme_id].astype(np.float32)
        
        # Add channel dimension if needed (for CNN: [1, 128, 7])
        if len(spectrogram.shape) == 2:
            spectrogram = np.expand_dims(spectrogram, axis=0)
        
        # Normalize to [0, 1]
        spectrogram = (spectrogram - spectrogram.min()) / (spectrogram.max() - spectrogram.min() + 1e-8)
        
        if self.transform:
            spectrogram = self.transform(spectrogram)
        
        label = row['class_encoded']
        
        return torch.from_numpy(spectrogram), torch.tensor(label, dtype=torch.long)


class FeatureDataset(Dataset):
    """Dataset for models using extracted features only"""
    def __init__(self, df, feature_cols, scaler=None, split='train', fit_scaler=False):
        self.df = df[df['split'] == split].reset_index(drop=True)
        self.feature_cols = feature_cols
        
        # Extract features
        X = self.df[feature_cols].values.astype(np.float32)
        
        # Handle missing values
        X = np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0)
        
        # Scale features
        if fit_scaler:
            self.scaler = StandardScaler()
            X = self.scaler.fit_transform(X)
        elif scaler is not None:
            self.scaler = scaler
            X = self.scaler.transform(X)
        else:
            self.scaler = None
        
        self.X = torch.from_numpy(X)
        self.y = torch.from_numpy(self.df['class_encoded'].values).long()
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]


class HybridDataset(Dataset):
    """Dataset for hybrid models using both spectrograms and features"""
    def __init__(self, df, spectrograms_dict, feature_cols, scaler=None, split='train', fit_scaler=False, transform=None):
        self.df = df[df['split'] == split].reset_index(drop=True)
        self.spectrograms_dict = spectrograms_dict
        self.feature_cols = feature_cols
        self.transform = transform
        
        # Extract and scale features
        X_features = self.df[feature_cols].values.astype(np.float32)
        X_features = np.nan_to_num(X_features, nan=0.0, posinf=0.0, neginf=0.0)
        
        if fit_scaler:
            self.scaler = StandardScaler()
            X_features = self.scaler.fit_transform(X_features)
        elif scaler is not None:
            self.scaler = scaler
            X_features = self.scaler.transform(X_features)
        else:
            self.scaler = None
        
        self.X_features = torch.from_numpy(X_features)
        self.y = torch.from_numpy(self.df['class_encoded'].values).long()
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        phoneme_id = row['phoneme_id']
        
        # Get spectrogram
        spectrogram = self.spectrograms_dict[phoneme_id].astype(np.float32)
        if len(spectrogram.shape) == 2:
            spectrogram = np.expand_dims(spectrogram, axis=0)
        spectrogram = (spectrogram - spectrogram.min()) / (spectrogram.max() - spectrogram.min() + 1e-8)
        
        if self.transform:
            spectrogram = self.transform(spectrogram)
        
        features = self.X_features[idx]
        label = self.y[idx]
        
        return (torch.from_numpy(spectrogram), features), label


class RawAudioDataset(Dataset):
    """Dataset for models using raw audio waveforms"""
    def __init__(self, df, split='train', sample_rate=16000, max_length=None, transform=None):
        self.df = df[df['split'] == split].reset_index(drop=True)
        self.sample_rate = sample_rate
        self.max_length = max_length
        self.transform = transform
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        audio_path = row['audio_path']
        
        # Load audio
        try:
            audio, sr = librosa.load(audio_path, sr=self.sample_rate, mono=True)
        except:
            # If loading fails, return zeros
            audio = np.zeros(self.sample_rate // 10)  # 100ms of silence
        
        # Normalize audio
        if len(audio) > 0:
            audio = audio / (np.abs(audio).max() + 1e-8)
        
        # Pad or truncate to max_length
        if self.max_length is not None:
            if len(audio) < self.max_length:
                audio = np.pad(audio, (0, self.max_length - len(audio)), mode='constant')
            else:
                audio = audio[:self.max_length]
        
        if self.transform:
            audio = self.transform(audio)
        
        label = row['class_encoded']
        
        return torch.from_numpy(audio.astype(np.float32)), torch.tensor(label, dtype=torch.long)


class ContextAudioDataset(Dataset):
    """Dataset for models using raw audio with context from original utterance"""
    def __init__(self, df, split='train', sample_rate=16000, context_ms=500, transform=None):
        self.df = df[df['split'] == split].reset_index(drop=True)
        self.sample_rate = sample_rate
        self.context_samples = int(context_ms * sample_rate / 1000)
        self.transform = transform
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        phoneme_audio_path = row['audio_path']
        
        # Load phoneme audio
        try:
            phoneme_audio, sr = librosa.load(phoneme_audio_path, sr=self.sample_rate, mono=True)
        except:
            phoneme_audio = np.zeros(self.sample_rate // 10)
        
        # For now, use just the phoneme audio as context
        # In a full implementation, you would load the full utterance and extract context
        context_audio = phoneme_audio  # Placeholder
        
        # Normalize
        if len(phoneme_audio) > 0:
            phoneme_audio = phoneme_audio / (np.abs(phoneme_audio).max() + 1e-8)
        if len(context_audio) > 0:
            context_audio = context_audio / (np.abs(context_audio).max() + 1e-8)
        
        if self.transform:
            phoneme_audio = self.transform(phoneme_audio)
            context_audio = self.transform(context_audio)
        
        label = row['class_encoded']
        
        return (
            torch.from_numpy(phoneme_audio.astype(np.float32)),
            torch.from_numpy(context_audio.astype(np.float32))
        ), torch.tensor(label, dtype=torch.long)


class SequenceDataset(Dataset):
    """Dataset for sequence models (LSTM, Transformer) using spectrograms as sequences"""
    def __init__(self, df, spectrograms_dict, split='train', transform=None):
        self.df = df[df['split'] == split].reset_index(drop=True)
        self.spectrograms_dict = spectrograms_dict
        self.transform = transform
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        phoneme_id = row['phoneme_id']
        
        # Get spectrogram: shape (128, 7) -> (7, 128) for sequence models
        spectrogram = self.spectrograms_dict[phoneme_id].astype(np.float32)
        spectrogram = spectrogram.T  # Transpose: (7, 128) - 7 time steps, 128 features
        
        # Normalize
        spectrogram = (spectrogram - spectrogram.min()) / (spectrogram.max() - spectrogram.min() + 1e-8)
        
        if self.transform:
            spectrogram = self.transform(spectrogram)
        
        label = row['class_encoded']
        
        return torch.from_numpy(spectrogram), torch.tensor(label, dtype=torch.long)

print("Dataset classes defined!")

Dataset classes defined!


## 5. Create Datasets and Compute Class Weights

In [6]:
# Create feature scaler on training data
train_df = df[df['split'] == 'train']
X_train_features = train_df[feature_cols].values.astype(np.float32)
X_train_features = np.nan_to_num(X_train_features, nan=0.0, posinf=0.0, neginf=0.0)

feature_scaler = StandardScaler()
feature_scaler.fit(X_train_features)

# Save scaler
import joblib
joblib.dump(feature_scaler, OUTPUT_DIR / 'feature_scaler.joblib')
print(f"Feature scaler saved to {OUTPUT_DIR / 'feature_scaler.joblib'}")

# Compute class weights
class_weights = compute_class_weight(
    'balanced',
    classes=np.unique(df[df['split'] == 'train']['class_encoded']),
    y=df[df['split'] == 'train']['class_encoded']
)
class_weights_dict = {i: weight for i, weight in enumerate(class_weights)}
print(f"\nClass weights: {class_weights_dict}")

# Save class weights
with open(OUTPUT_DIR / 'class_weights.json', 'w') as f:
    json.dump(class_weights_dict, f)
print(f"Class weights saved to {OUTPUT_DIR / 'class_weights.json'}")

# Create datasets
train_spectrogram_ds = SpectrogramDataset(df, spectrograms_dict, split='train')
val_spectrogram_ds = SpectrogramDataset(df, spectrograms_dict, split='val')
test_spectrogram_ds = SpectrogramDataset(df, spectrograms_dict, split='test')

train_feature_ds = FeatureDataset(df, feature_cols, scaler=feature_scaler, split='train')
val_feature_ds = FeatureDataset(df, feature_cols, scaler=feature_scaler, split='val')
test_feature_ds = FeatureDataset(df, feature_cols, scaler=feature_scaler, split='test')

train_hybrid_ds = HybridDataset(df, spectrograms_dict, feature_cols, scaler=feature_scaler, split='train')
val_hybrid_ds = HybridDataset(df, spectrograms_dict, feature_cols, scaler=feature_scaler, split='val')
test_hybrid_ds = HybridDataset(df, spectrograms_dict, feature_cols, scaler=feature_scaler, split='test')

train_sequence_ds = SequenceDataset(df, spectrograms_dict, split='train')
val_sequence_ds = SequenceDataset(df, spectrograms_dict, split='val')
test_sequence_ds = SequenceDataset(df, spectrograms_dict, split='test')

train_raw_audio_ds = RawAudioDataset(df, split='train', sample_rate=16000, max_length=3200)  # 200ms at 16kHz
val_raw_audio_ds = RawAudioDataset(df, split='val', sample_rate=16000, max_length=3200)
test_raw_audio_ds = RawAudioDataset(df, split='test', sample_rate=16000, max_length=3200)

train_context_audio_ds = ContextAudioDataset(df, split='train', sample_rate=16000)
val_context_audio_ds = ContextAudioDataset(df, split='val', sample_rate=16000)
test_context_audio_ds = ContextAudioDataset(df, split='test', sample_rate=16000)

print("\nAll datasets created!")
print(f"Train spectrogram dataset: {len(train_spectrogram_ds)} samples")
print(f"Train feature dataset: {len(train_feature_ds)} samples")
print(f"Train hybrid dataset: {len(train_hybrid_ds)} samples")
print(f"Train sequence dataset: {len(train_sequence_ds)} samples")
print(f"Train raw audio dataset: {len(train_raw_audio_ds)} samples")

Feature scaler saved to /Volumes/SSanDisk/SpeechRec-German/artifacts/b-p_dl_models/feature_scaler.joblib

Class weights: {0: np.float64(0.7156036903329323), 1: np.float64(1.6595348837209303)}
Class weights saved to /Volumes/SSanDisk/SpeechRec-German/artifacts/b-p_dl_models/class_weights.json

All datasets created!
Train spectrogram dataset: 24976 samples
Train feature dataset: 24976 samples
Train hybrid dataset: 24976 samples
Train sequence dataset: 24976 samples
Train raw audio dataset: 24976 samples


## 6. Create DataLoaders with Weighted Sampling

In [7]:
# Compute sample weights for weighted sampling
train_labels = df[df['split'] == 'train']['class_encoded'].values
sample_weights = np.array([class_weights[label] for label in train_labels])
sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),
    replacement=True
)

# Create DataLoaders
BATCH_SIZE = 64

train_spectrogram_loader = DataLoader(train_spectrogram_ds, batch_size=BATCH_SIZE, sampler=sampler, num_workers=0)
val_spectrogram_loader = DataLoader(val_spectrogram_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
test_spectrogram_loader = DataLoader(test_spectrogram_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

train_feature_loader = DataLoader(train_feature_ds, batch_size=BATCH_SIZE, sampler=sampler, num_workers=0)
val_feature_loader = DataLoader(val_feature_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
test_feature_loader = DataLoader(test_feature_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

train_hybrid_loader = DataLoader(train_hybrid_ds, batch_size=BATCH_SIZE, sampler=sampler, num_workers=0)
val_hybrid_loader = DataLoader(val_hybrid_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
test_hybrid_loader = DataLoader(test_hybrid_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

train_sequence_loader = DataLoader(train_sequence_ds, batch_size=BATCH_SIZE, sampler=sampler, num_workers=0)
val_sequence_loader = DataLoader(val_sequence_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
test_sequence_loader = DataLoader(test_sequence_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

train_raw_audio_loader = DataLoader(train_raw_audio_ds, batch_size=BATCH_SIZE, sampler=sampler, num_workers=0)
val_raw_audio_loader = DataLoader(val_raw_audio_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
test_raw_audio_loader = DataLoader(test_raw_audio_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

train_context_audio_loader = DataLoader(train_context_audio_ds, batch_size=BATCH_SIZE, sampler=sampler, num_workers=0)
val_context_audio_loader = DataLoader(val_context_audio_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
test_context_audio_loader = DataLoader(test_context_audio_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

print("All DataLoaders created!")
print(f"\nTrain batches (spectrogram): {len(train_spectrogram_loader)}")
print(f"Train batches (feature): {len(train_feature_loader)}")
print(f"Train batches (hybrid): {len(train_hybrid_loader)}")

# Test a batch
print("\nTesting a batch from spectrogram dataset...")
sample_batch = next(iter(train_spectrogram_loader))
print(f"Batch shape: {sample_batch[0].shape}, Labels shape: {sample_batch[1].shape}")

All DataLoaders created!

Train batches (spectrogram): 391
Train batches (feature): 391
Train batches (hybrid): 391

Testing a batch from spectrogram dataset...
Batch shape: torch.Size([64, 1, 128, 7]), Labels shape: torch.Size([64])


## 7. Save Dataset Information

In [8]:
# Save dataset information
dataset_info = {
    'total_samples': len(df),
    'train_samples': len(df[df['split'] == 'train']),
    'val_samples': len(df[df['split'] == 'val']),
    'test_samples': len(df[df['split'] == 'test']),
    'n_features': len(feature_cols),
    'spectrogram_shape': list(spectrograms_dict[list(spectrograms_dict.keys())[0]].shape),
    'class_distribution': {
        'train': df[df['split'] == 'train']['class'].value_counts().to_dict(),
        'val': df[df['split'] == 'val']['class'].value_counts().to_dict(),
        'test': df[df['split'] == 'test']['class'].value_counts().to_dict()
    },
    'class_weights': class_weights_dict,
    'feature_columns': feature_cols
}

with open(OUTPUT_DIR / 'dataset_info.json', 'w') as f:
    json.dump(dataset_info, f, indent=2)

print(f"Dataset info saved to {OUTPUT_DIR / 'dataset_info.json'}")
print(f"\nDataset summary:")
print(f"  Total samples: {dataset_info['total_samples']}")
print(f"  Train: {dataset_info['train_samples']}")
print(f"  Val: {dataset_info['val_samples']}")
print(f"  Test: {dataset_info['test_samples']}")
print(f"  Features: {dataset_info['n_features']}")
print(f"  Spectrogram shape: {dataset_info['spectrogram_shape']}")

Dataset info saved to /Volumes/SSanDisk/SpeechRec-German/artifacts/b-p_dl_models/dataset_info.json

Dataset summary:
  Total samples: 35660
  Train: 24976
  Val: 5335
  Test: 5349
  Features: 109
  Spectrogram shape: [128, 7]
