In [1]:
import numpy as np
import pandas as pd
import pywt
from utils import read_data

def calculate_wavelet_features_per_channel(signals, wavelet='haar', level=5):
    """
    Build one feature row per signal by extracting wavelet features from every channel.

    Args:
        signals (np.ndarray): Multi-channel data of shape (n_signals, Time, N_channels);
            channel k of signal i is ``signals[i, :, k]``.
        wavelet (str): Wavelet name forwarded to ``calculate_wavelet_features``
            (e.g. 'haar', 'db4').
        level (int): Number of decomposition levels forwarded to
            ``calculate_wavelet_features``.

    Returns:
        pd.DataFrame: One row per signal. Columns are the feature keys produced for
        channels ``ch_1`` ... ``ch_<N_channels>`` (1-based names), in channel order.
    """
    n_signals, _, n_channels = signals.shape

    def _row(sample):
        # Flatten all channels of one sample into a single dict, channel by channel,
        # so the column order is ch_1 features, then ch_2 features, and so on.
        row = {}
        for ch_idx in range(n_channels):
            row.update(
                calculate_wavelet_features(sample[:, ch_idx], f'ch_{ch_idx + 1}', wavelet, level)
            )
        return row

    return pd.DataFrame([_row(signals[idx]) for idx in range(n_signals)])
    
def calculate_wavelet_features(signal, channel_name, wavelet, level):
    """
    Compute summary features for every detail subband of a 1-D signal's DWT.

    The signal is decomposed with ``pywt.wavedec(signal, wavelet, level=level)``,
    which returns ``[cA_level, cD_level, ..., cD_1]``. The approximation
    coefficients are skipped. Detail arrays are labelled in the order pywt
    returns them, so ``D_1`` is the COARSEST detail band (cD_level) and
    ``D_<level>`` is the finest (cD_1).

    Args:
        signal (array-like): 1-D time series for a single channel.
        channel_name (str): Prefix for every feature key, e.g. ``'ch_1'``.
        wavelet (str): Wavelet name passed to ``pywt.wavedec`` (e.g. 'haar', 'db4').
        level (int): Number of decomposition levels passed to ``pywt.wavedec``.

    Returns:
        dict: Keys ``f"{channel_name}_D_{j}_{feature}"``, with ten features per
        detail subband, in this order:
          - variance, mean, median, std, rms of the coefficients;
          - entropy: ``-sum(s * log s)`` with ``s = c**2 + 1e-10`` (not normalized);
          - hurst: ``log2(std + 1e-10)``. Despite the key name, this is not a
            Hurst-exponent estimate;
          - energy: ``sum(c**2)``;
          - weighted_variance: energy-weighted mean coefficient position
            (1-based), i.e. the energy centroid; 0 for an all-zero subband;
          - total_power: same value as energy. It is kept so the feature
            layout stays unchanged.
    """
    coeffs = pywt.wavedec(signal, wavelet, level=level)
    channel_features = {}
    # Skip coeffs[0] (approximation); iterate the detail subbands only.
    for i, detail_coeff in enumerate(coeffs[1:]):
        prefix = f"{channel_name}_D_{i + 1}"
        squared = detail_coeff ** 2

        # Plain energy: the log-guard epsilon below must not leak into it, or it
        # biases energy/total_power and swamps the centroid of tiny subbands.
        energy = np.sum(squared)

        # The epsilon only keeps log() finite for zero coefficients.
        guarded = squared + 1e-10
        detail_entropy = -np.sum(guarded * np.log(guarded))

        detail_std = np.std(detail_coeff)
        detail_hurst = np.log(detail_std + 1e-10) / np.log(2)
        detail_rms = np.sqrt(np.mean(squared))

        # 1-based positions weighted by per-coefficient energy.
        positions = np.arange(1, len(detail_coeff) + 1)
        weighted_variance = np.sum(positions * squared) / energy if energy > 0 else 0.0

        channel_features[f"{prefix}_variance"] = np.var(detail_coeff)
        channel_features[f"{prefix}_entropy"] = detail_entropy
        channel_features[f"{prefix}_mean"] = np.mean(detail_coeff)
        channel_features[f"{prefix}_median"] = np.median(detail_coeff)
        channel_features[f"{prefix}_std"] = detail_std
        channel_features[f"{prefix}_hurst"] = detail_hurst
        channel_features[f"{prefix}_rms"] = detail_rms
        channel_features[f"{prefix}_energy"] = energy
        channel_features[f"{prefix}_weighted_variance"] = weighted_variance
        channel_features[f"{prefix}_total_power"] = energy

    return channel_features


Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),
(to allow more performant data types, such as the Arrow string type, and better interoperability with other libraries)
but was not found to be installed on your system.
If this would cause problems for you,
please provide us feedback at https://github.com/pandas-dev/pandas/issues/54466
        
  import pandas as pd


In [2]:
train_18_X,train_18_y= read_data("data/train.hdf5")
val_18_X,val_18_y=read_data("data/test.hdf5")

In [3]:
train_18_X = calculate_wavelet_features_per_channel(train_18_X,wavelet='db4',level=4)
val_18_X = calculate_wavelet_features_per_channel(val_18_X,wavelet='db4',level=4)

In [4]:
train_18_y = np.where(train_18_y==1,'seizure','normal')
val_18_y = np.where(val_18_y==1,'seizure','normal')

In [5]:
train_features_df = pd.DataFrame(np.hstack((train_18_X,train_18_y[:,np.newaxis])))
test_features_df = pd.DataFrame(np.hstack((val_18_X,val_18_y[:,np.newaxis])))


In [6]:
train_features_df.to_csv('data/dwt_features_train.csv',index=False)
test_features_df.to_csv('data/dwt_features_test.csv',index=False)

In [7]:
train_features_df.head()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,711,712,713,714,715,716,717,718,719,720
0,6681.705078125,-16119238.0,9.156848907470703,-1.8285011053085327,81.74169921875,6.353000328815155,82.25298309326172,1664326.125,111.4685550904562,1664326.125,...,-299656.4375,-5.607783532468602e-05,-0.0771043747663497,5.881520748138428,2.5561892319793635,5.881520748138428,66520.96875,1249.6895057147117,66520.96875,seizure
1,4764.48193359375,-10791290.0,1.391047477722168,-0.1879464387893676,69.02523040771484,6.109051893458143,69.03923797607422,1172538.375,125.45376838291484,1172538.375,...,293.59228515625,0.0023175063543021,0.0055739879608154,0.53663569688797,-0.8979850690727593,0.5366407036781311,553.791748046875,913.9969829671588,553.791748046875,normal
2,1155.1351318359375,-2232761.5,0.136054590344429,0.7841262221336365,33.98727798461914,5.086922917043515,33.987552642822266,284167.8125,126.73402042566366,284167.8125,...,-772.6814575195312,0.0061958790756762,0.0108598172664642,0.7502501606941223,-0.4145563718643879,0.750275731086731,1082.48291015625,797.6871977444098,1082.48291015625,normal
3,179232.0,-576166400.0,25.7539176940918,20.391891479492188,423.3580017089844,8.725734346859209,424.140625,44254240.0,116.58953589851768,44254240.0,...,-25108756.0,0.3956546187400818,-0.6125924587249756,38.16904830932617,5.254331310353905,38.17110061645508,2801874.25,676.4616726145025,2801874.25,seizure
4,4072.689697265625,-9051136.0,-1.9151976108551023,-3.64087176322937,63.817626953125,5.9958830581510005,63.84635925292969,1002784.0,122.79834766924164,1002784.0,...,-4191.65185546875,-0.0029425693210214,-0.004854142665863,1.1750333309173584,0.2327016808851117,1.1750370264053345,2655.109375,1100.456555946981,2655.109375,normal
