# Part 3: Advanced Analysis

In this part, we will implement advanced analysis techniques for physiological time series data, including time-domain feature extraction, frequency analysis, and wavelet transforms.

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import signal
import pywt
import os

# Set plotting style
# The bundled seaborn style sheets were renamed to 'seaborn-v0_8*' in
# matplotlib 3.6 and the old 'seaborn' alias was later removed, so try the
# current name first and fall back to the legacy one on older installations.
try:
    plt.style.use('seaborn-v0_8')
except OSError:
    plt.style.use('seaborn')
sns.set_context('notebook')

## 1. Time-Domain Feature Extraction

Implement the `extract_time_domain_features` function to extract various time-domain features from physiological signals.

In [22]:
def extract_time_domain_features(data, window_size=60, sampling_rate=1.0):
    """Extract rolling-window time-domain features from physiological signals.

    Parameters
    ----------
    data : pandas.DataFrame
        Input samples. The signal columns 'heart_rate', 'eda' and
        'temperature' are each optional; the metadata columns 'timestamp',
        'subject_id' and 'session' are copied to the output when present.
    window_size : float, default 60
        Length of the rolling window in seconds.
    sampling_rate : float, default 1.0
        Samples per second of `data`. The default keeps the original
        behaviour of treating one row as one second.

    Returns
    -------
    pandas.DataFrame
        One row per input row for which every feature could be computed
        (rows with any NaN, e.g. the warm-up of the rolling window, are
        dropped). Columns: the metadata columns, `<signal>_mean/_std/_min/_max`
        for each available signal, and the signal-specific features
        `hrv_rmssd_ms`, `hrv_sdnn_ms`, `hrv_pnn50_percent`, `eda_tonic_uS`,
        `eda_response_freq_per_min`, `temp_change_rate_C_per_min`.
    """
    # convert window_size from seconds to number of samples
    window_samples = max(1, int(round(window_size * sampling_rate)))
    # unit conversions: per-window counts -> per minute, per-sample change -> per minute
    windows_per_minute = 60 / window_size
    samples_per_minute = 60 * sampling_rate

    # start from whichever metadata columns are present (keeps the input index)
    meta_columns = [c for c in ('timestamp', 'subject_id', 'session') if c in data.columns]
    features = data[meta_columns].copy()

    # basic rolling statistics for each available physiological signal
    for signal_name in ('heart_rate', 'eda', 'temperature'):
        if signal_name not in data.columns:
            continue
        rolling = data[signal_name].rolling(window=window_samples)
        features[f'{signal_name}_mean'] = rolling.mean()
        features[f'{signal_name}_std'] = rolling.std()
        features[f'{signal_name}_min'] = rolling.min()
        features[f'{signal_name}_max'] = rolling.max()

    # heart rate variability measures
    if 'heart_rate' in data.columns:
        # a non-positive heart rate would give an infinite/negative RR interval;
        # treat it as missing so it cannot poison every window it touches
        heart_rate = data['heart_rate'].where(data['heart_rate'] > 0)
        # convert HR (beats per minute) to RR intervals in milliseconds
        rr_intervals = 60000 / heart_rate
        # successive differences between RR intervals
        rr_diff = rr_intervals.diff()

        # RMSSD (Root Mean Square of Successive Differences) in ms
        mean_sq_diff = (rr_diff ** 2).rolling(window=window_samples).mean()
        features['hrv_rmssd_ms'] = np.sqrt(mean_sq_diff.clip(lower=0))

        # SDNN (Standard Deviation of NN intervals) in ms
        features['hrv_sdnn_ms'] = rr_intervals.rolling(window=window_samples).std()

        # pNN50 (percentage of successive RR differences larger than 50 ms);
        # `.where(notna)` keeps missing differences missing instead of counting
        # them as "not exceeding"
        exceeds_50ms = (rr_diff.abs() > 50).astype(float).where(rr_diff.notna())
        features['hrv_pnn50_percent'] = 100 * exceeds_50ms.rolling(window=window_samples).mean()

    # EDA-specific features if available
    if 'eda' in data.columns:
        # skin conductance level (tonic component)
        features['eda_tonic_uS'] = data['eda'].rolling(window=window_samples).mean()

        # skin conductance response frequency: number of sample-to-sample rises
        # above 0.05 µS inside the window, scaled to responses per minute
        eda_diff = data['eda'].diff()
        significant_rise = (eda_diff > 0.05).astype(float).where(eda_diff.notna())
        features['eda_response_freq_per_min'] = (
            significant_rise.rolling(window=window_samples).sum() * windows_per_minute
        )

    # temperature-specific features if available
    if 'temperature' in data.columns:
        # the rolling mean of the first difference is the change per *sample*;
        # multiply by samples per minute to express it in °C/min
        features['temp_change_rate_C_per_min'] = (
            data['temperature'].diff().rolling(window=window_samples).mean() * samples_per_minute
        )

    return features.dropna()

In [None]:
# Time-Domain Feature Data Test
# Runs extract_time_domain_features on every per-subject CSV found in
# data/processed, saves one three-panel figure per file under
# plots/time_features, and writes the concatenated features to
# data/processed/all_time_features.csv.
# load sample data for testing
from pathlib import Path

# check if processed data exists
processed_files = list(Path('data/processed').glob('*.csv'))
# files starting with 'all_' are combined outputs (such as the one written at
# the end of this cell), not per-subject inputs
subject_files = [f for f in processed_files if not f.name.startswith('all_')]

if subject_files:
    print(f"Found {len(subject_files)} subject data files")
    
    # create a directory for individual feature plots
    os.makedirs('plots/time_features', exist_ok=True)
    
    # dictionary to store features for all subjects
    all_features = {}
    
    # process each subject file
    for sample_file in subject_files:
        # first '_'-separated token of the file stem is the subject id; the
        # tokens between the first and the last form the session label
        # NOTE(review): a two-token stem (e.g. 'S1_processed') yields an empty
        # session string — confirm the expected file naming
        subject_id = sample_file.stem.split('_')[0] 
        session = '_'.join(sample_file.stem.split('_')[1:-1])  
        
        print(f"\nProcessing: {subject_id}, {session}")
        
        sample_data = pd.read_csv(sample_file)
        
        # convert timestamp to datetime
        sample_data['timestamp'] = pd.to_datetime(sample_data['timestamp'])
        # extract features
        time_features = extract_time_domain_features(sample_data, window_size=60)
        # store in dictionary
        all_features[f"{subject_id}_{session}"] = time_features
        # display basic info
        print(f"  Extracted {time_features.shape[0]} feature windows")
        # create a visualization for this subject
        plt.figure(figsize=(14, 10))
        
        # heart rate plot with variability
        plt.subplot(3, 1, 1)
        plt.plot(time_features['timestamp'], time_features['heart_rate_mean'], 'b-', label='Mean HR')
        # shaded band spans mean ± one rolling standard deviation
        plt.fill_between(
            time_features['timestamp'],
            time_features['heart_rate_mean'] - time_features['heart_rate_std'],
            time_features['heart_rate_mean'] + time_features['heart_rate_std'],
            alpha=0.2, color='b'
        )
        plt.title(f'Heart Rate with Standard Deviation - {subject_id}, {session}')
        plt.ylabel('BPM')
        plt.legend()
        
        # HRV measures
        plt.subplot(3, 1, 2)
        plt.plot(time_features['timestamp'], time_features['hrv_rmssd_ms'], 'r-', label='RMSSD')
        plt.plot(time_features['timestamp'], time_features['hrv_sdnn_ms'], 'g-', label='SDNN')
        plt.title('Heart Rate Variability Measures')
        plt.ylabel('ms')
        plt.legend()
        
        # EDA with response frequency
        # (both series share one unlabeled y-axis although their units differ)
        plt.subplot(3, 1, 3)
        plt.plot(time_features['timestamp'], time_features['eda_tonic_uS'], 'b-', label='Tonic EDA')
        plt.plot(time_features['timestamp'], time_features['eda_response_freq_per_min'], 'm-', label='EDA Responses/min')
        plt.title('Electrodermal Activity Features')
        plt.xlabel('Time')
        plt.legend()
        
        plt.tight_layout()
        plt.savefig(f'plots/time_features/{subject_id}_{session}_time_features.png', dpi=300)
        # close the figure so figures do not accumulate across the loop
        plt.close()
    
    # combined feature dataset
    combined_features = pd.concat(all_features.values(), ignore_index=True)
    combined_features.to_csv('data/processed/all_time_features.csv', index=False)
    
    print(f"\nCombined features dataset created with shape: {combined_features.shape}")
    print("Saved to: data/processed/all_time_features.csv")
    
else:
    print("No processed data files found. Please run part1_exploration.ipynb first.")

## 2. Frequency Analysis

Implement the `analyze_frequency_components` function to perform frequency-domain analysis on the signals.

In [24]:
def _clean_window(window_data):
    """Return `window_data` with NaN/Inf values repaired, or None if the window is unusable."""
    if not (np.isnan(window_data).any() or np.isinf(window_data).any()):
        return window_data

    # replace infinities with NaNs so they are handled like missing samples
    window_data = window_data.replace([np.inf, -np.inf], np.nan)
    nan_ratio = window_data.isna().mean()
    if nan_ratio > 0.5:
        print(f"  Skipping window with {nan_ratio*100:.1f}% missing data")
        return None

    # interpolate interior gaps, then pad the edges
    window_data = window_data.interpolate(method='linear').ffill().bfill()
    if window_data.isna().any():
        print(f"  Skipping window with remaining NaN/Inf values")
        return None
    return window_data


def _band_powers(frequencies, power, bands):
    """Integrate the PSD over each half-open band [low, high) and add the LF/HF ratio."""
    # Welch frequencies are evenly spaced, so a rectangle rule is sum * bin width
    freq_step = frequencies[1] - frequencies[0] if len(frequencies) > 1 else 1.0
    band_power = {}
    for band_name, (low, high) in bands.items():
        mask = (frequencies >= low) & (frequencies < high)
        band_power[band_name] = float(np.sum(power[mask]) * freq_step) if mask.any() else 0

    if band_power['HF'] > 0:
        band_power['LF/HF'] = band_power['LF'] / band_power['HF']
    else:
        band_power['LF/HF'] = np.nan
    return band_power


def analyze_frequency_components(data, sampling_rate, window_size=60, subject_id=None, session=None, output_dir='plots/frequency'):
    """Estimate the power spectrum and band powers of each physiological signal.

    For each of 'heart_rate', 'eda' and 'temperature' present in `data`, the
    signal is cut into consecutive non-overlapping windows, each window is
    linearly detrended and passed to Welch's method, and the resulting PSDs
    are averaged. A plot (.png), the spectrum (.csv) and spectrum plus band
    powers (.npz) are written to `output_dir`.

    Parameters
    ----------
    data : pandas.DataFrame
        Samples with any of the signal columns listed above.
    sampling_rate : float
        Samples per second of `data`.
    window_size : float, default 60
        Analysis window length in seconds.
    subject_id, session : str, optional
        When both are truthy they are added to the plot title and used as a
        prefix for the output file names.
    output_dir : str, default 'plots/frequency'
        Directory for all generated files (created if missing).

    Returns
    -------
    dict
        `{signal_name: {'frequencies', 'power', 'bands'}}` where `bands` maps
        'VLF', 'LF', 'HF' to the PSD integrated over the band and 'LF/HF' to
        their ratio (NaN when HF power is zero). Signals without a usable
        window are omitted.
    """
    os.makedirs(output_dir, exist_ok=True)

    # convert window_size from seconds to number of samples (at least one, so
    # the window count below never divides by zero)
    window_samples = max(1, int(window_size * sampling_rate))

    # frequency bands in Hz, treated as half-open [low, high) so a bin that
    # sits exactly on a boundary is counted in one band only
    bands = {
        'VLF': (0.003, 0.04),
        'LF': (0.04, 0.15),
        'HF': (0.15, 0.4)
    }
    colors = {'VLF': 'green', 'LF': 'orange', 'HF': 'red'}
    results = {}

    # process physiological signals
    for signal_name in ['heart_rate', 'eda', 'temperature']:
        if signal_name not in data.columns:
            continue

        # process data in windows
        n_windows = max(1, len(data) // window_samples)
        all_frequencies = []
        all_power = []

        for i in range(n_windows):
            start_idx = i * window_samples
            end_idx = min((i + 1) * window_samples, len(data))
            window_data = data[signal_name].iloc[start_idx:end_idx]

            # skip windows with not enough data
            if len(window_data) < window_samples // 2:
                continue

            window_data = _clean_window(window_data)
            if window_data is None:
                continue

            # detrend the data to remove linear trends
            try:
                window_data_detrended = signal.detrend(window_data.values)
            except Exception as e:
                print(f"  Error during detrending: {e}")
                continue

            # calculate PSD using Welch's method; welch applies the Hann taper
            # to every segment itself, so the data must not be tapered
            # beforehand (that would window the signal twice)
            segment_length = min(256, len(window_data_detrended))
            try:
                frequencies, power = signal.welch(
                    window_data_detrended,
                    fs=sampling_rate,
                    window='hann',
                    nperseg=segment_length,
                    noverlap=segment_length // 2,
                    scaling='density'
                )
            except Exception as e:
                print(f"  Error in Welch's method: {e}")
                continue

            all_frequencies.append(frequencies)
            all_power.append(power)

        # skip signals with no valid windows
        if not all_frequencies:
            print(f"  No valid windows for {signal_name}")
            continue

        # calculate average results
        signal_results = {
            'frequencies': np.mean(all_frequencies, axis=0),
            'power': np.mean(all_power, axis=0),
        }
        signal_results['bands'] = _band_powers(
            signal_results['frequencies'], signal_results['power'], bands
        )

        # create a plot of the power spectrum
        plt.figure(figsize=(12, 6))
        plt.semilogy(signal_results['frequencies'], signal_results['power'], 'b-')

        # highlight frequency bands
        for band_name, (low, high) in bands.items():
            mask = (signal_results['frequencies'] >= low) & (signal_results['frequencies'] < high)
            if mask.any():
                plt.fill_between(
                    signal_results['frequencies'][mask],
                    signal_results['power'][mask],
                    alpha=0.3,
                    color=colors[band_name],
                    label=f"{band_name}: {low:.3f}-{high:.3f} Hz"
                )

        # add labels and title
        plt.xlabel('Frequency (Hz)')
        plt.ylabel('Power Spectral Density')

        title = f'Power Spectrum of {signal_name.replace("_", " ").title()}'
        if subject_id and session:
            title += f' - Subject {subject_id}, {session}'
            data_filename = f"{subject_id}_{session}_{signal_name}_fft"
        else:
            data_filename = f"{signal_name}_fft"

        plt.title(title)
        plt.grid(True, which='both', linestyle='--', alpha=0.5)
        plt.legend()

        plt.savefig(os.path.join(output_dir, f"{data_filename}.png"), dpi=300, bbox_inches='tight')
        plt.close()

        # persist the spectrum and the band summary
        np.savez(
            os.path.join(output_dir, f"{data_filename}.npz"),
            frequencies=signal_results['frequencies'],
            power=signal_results['power'],
            vlf_power=signal_results['bands']['VLF'],
            lf_power=signal_results['bands']['LF'],
            hf_power=signal_results['bands']['HF'],
            lf_hf_ratio=signal_results['bands']['LF/HF']
        )

        freq_df = pd.DataFrame({
            'frequency': signal_results['frequencies'],
            'power': signal_results['power']
        })
        freq_df.to_csv(os.path.join(output_dir, f"{data_filename}.csv"), index=False)

        results[signal_name] = signal_results

    return results

In [None]:
# Frequency Analysis Data Test
# Runs analyze_frequency_components on every per-subject CSV, prints the band
# powers, writes a summary table and a heart-rate LF/HF comparison plot.
freq_output_dir = 'plots/frequency'
os.makedirs(freq_output_dir, exist_ok=True)

# the physiological columns that are checked and cleaned before analysis
signal_columns = ['heart_rate', 'eda', 'temperature']

processed_files = list(Path('data/processed').glob('*.csv'))
subject_files = [f for f in processed_files if not f.name.startswith('all_')]

if subject_files:
    print(f"Found {len(subject_files)} subject files for frequency analysis")
    all_freq_results = {}
    
    for sample_file in subject_files:
        try:
            subject_id = sample_file.stem.split('_')[0]
            session = '_'.join(sample_file.stem.split('_')[1:-1])
            
            print(f"\nPerforming frequency analysis for Subject {subject_id}, Session {session}")
            
            # load the data
            sample_data = pd.read_csv(sample_file)
            
            # convert timestamp to datetime
            sample_data['timestamp'] = pd.to_datetime(sample_data['timestamp'])
            
            # check for and report missing data
            missing_data = sample_data[signal_columns].isna().sum()
            if missing_data.sum() > 0:
                print(f"  Warning: Dataset contains missing values:\n{missing_data}")
            
            # check for infinite values
            inf_data = np.isinf(sample_data[signal_columns]).sum()
            if inf_data.sum() > 0:
                print(f"  Warning: Dataset contains infinite values:\n{inf_data}")
                
                # clean infinities in the signal columns only: interpolating the
                # whole frame would also touch timestamp/string columns, and
                # fillna(method=...) is deprecated in favour of ffill()/bfill()
                sample_data[signal_columns] = (
                    sample_data[signal_columns]
                    .replace([np.inf, -np.inf], np.nan)
                    .interpolate(method='linear')
                    .ffill()
                    .bfill()
                )
                
                print("  Replaced infinite values with interpolated values")
            
            # determine sampling rate from the median timestamp spacing in
            # seconds (independent of the datetime storage resolution)
            time_diff = sample_data['timestamp'].diff().dt.total_seconds().median()
            if not np.isfinite(time_diff) or time_diff <= 0:
                print(f"  Could not determine sampling rate for {sample_file.name}; skipping")
                continue
            sampling_rate = 1.0 / time_diff
            
            print(f"  Detected sampling rate: {sampling_rate:.2f} Hz")
            
            # frequency analysis
            freq_results = analyze_frequency_components(
                sample_data,
                sampling_rate,
                window_size=60,
                subject_id=subject_id,
                session=session,
                output_dir=freq_output_dir
            )
            
            if not freq_results:
                print(f"  No valid results for {subject_id}, {session}")
                continue
                
            all_freq_results[f"{subject_id}_{session}"] = freq_results
            
            # key results
            for signal_name, results in freq_results.items():
                print(f"\n  {signal_name.upper()} frequency analysis:")
                print(f"    VLF power: {results['bands']['VLF']:.4f}")
                print(f"    LF power: {results['bands']['LF']:.4f}")
                print(f"    HF power: {results['bands']['HF']:.4f}")
                print(f"    LF/HF ratio: {results['bands']['LF/HF']:.4f}")
        
        except Exception as e:
            # keep going with the remaining files; the message names the culprit
            print(f"Error processing {sample_file.name}: {e}")
    
    # continue with summary table if there are results
    if all_freq_results:
        # summary table of LF/HF ratios (a key stress indicator)
        lf_hf_summary = []
        for subject_session, freq_results in all_freq_results.items():
            subject_id, session = subject_session.split('_', 1)
            
            for signal_name, results in freq_results.items():
                lf_hf_summary.append({
                    'subject_id': subject_id,
                    'session': session,
                    'signal': signal_name,
                    'VLF_power': results['bands']['VLF'],
                    'LF_power': results['bands']['LF'],
                    'HF_power': results['bands']['HF'],
                    'LF_HF_ratio': results['bands']['LF/HF']
                })
        
        lf_hf_df = pd.DataFrame(lf_hf_summary)
        lf_hf_df.to_csv(os.path.join(freq_output_dir, 'frequency_analysis_summary.csv'), index=False)
        
        print("\nFrequency analysis complete!")
        print(f"Results saved to {freq_output_dir}")
        print(f"Summary file: {os.path.join(freq_output_dir, 'frequency_analysis_summary.csv')}")
        
        # comparison plot of LF/HF ratios across all subjects and sessions
        # filter for heart rate only; skip the plot when no file had heart rate
        hr_data = lf_hf_df[lf_hf_df['signal'] == 'heart_rate']
        
        if not hr_data.empty:
            plt.figure(figsize=(14, 8))
            
            # plot LF/HF ratios by subject and session
            ax = sns.barplot(x='subject_id', y='LF_HF_ratio', hue='session', data=hr_data)
            
            plt.title('Heart Rate LF/HF Ratio by Subject and Session')
            plt.xlabel('Subject ID')
            plt.ylabel('LF/HF Ratio (higher values indicate more stress)')
            plt.xticks(rotation=45)
            plt.legend(title='Session')
            plt.tight_layout()
            
            plt.savefig(os.path.join(freq_output_dir, 'lf_hf_comparison.png'), dpi=300, bbox_inches='tight')
            plt.close()
    
else:
    print("No processed data files found. Please run part1_exploration.ipynb first.")

## 3. Time-Frequency Analysis

Implement the `analyze_time_frequency_features` function to analyze time-frequency features using wavelet transforms.

In [19]:
def analyze_time_frequency_features(data, sampling_rate, window_size=60, signal_name='heart_rate', subject_id=None, session=None, output_dir='plots/wavelet'):
    """Analyze one signal with a continuous wavelet transform (Morlet).

    The signal is cleaned (Inf -> NaN, linear interpolation, edge padding),
    cut into consecutive non-overlapping windows and transformed with
    `pywt.cwt`. Coefficients and energy are averaged across windows; the
    dominant frequency per time point is concatenated over all windows.
    A figure (.png), the arrays (.npz) and two tables (.csv) are written to
    `output_dir`.

    Parameters
    ----------
    data : pandas.DataFrame
        Samples containing the column `signal_name`.
    sampling_rate : float
        Samples per second of `data`.
    window_size : float, default 60
        Analysis window length in seconds.
    signal_name : str, default 'heart_rate'
        Column to analyze.
    subject_id, session : str, optional
        When both are truthy they prefix the output file names.
    output_dir : str, default 'plots/wavelet'
        Directory for all generated files (created if missing).

    Returns
    -------
    dict
        Keys 'scales', 'frequencies', 'coefficients', 'time_frequency_energy'
        and 'dominant_frequency' on success. An empty dict on every failure
        (missing column, too many missing values, no valid window) so callers
        can test the result for truthiness.
    """
    os.makedirs(output_dir, exist_ok=True)
    # convert window_size from seconds to number of samples (at least one, so
    # the window count below never divides by zero)
    window_samples = max(1, int(window_size * sampling_rate))

    # check if signal exists in data
    if signal_name not in data.columns:
        print(f"Signal {signal_name} not found in data")
        return {}

    # replace NaN and infinity values
    signal_data = data[signal_name].replace([np.inf, -np.inf], np.nan)
    nan_ratio = signal_data.isna().mean()
    if nan_ratio > 0.5:
        print(f"Too many missing values in {signal_name}: {nan_ratio*100:.1f}%")
        return {}

    # interpolate missing values
    signal_data = signal_data.interpolate(method='linear').ffill().bfill()
    if signal_data.isna().any():
        print(f"Could not interpolate all missing values in {signal_name}")
        return {}

    # define wavelet scales (log-spaced from 1 to 128)
    # NOTE(review): the smallest scales map to frequencies above sampling_rate/2,
    # which the sampled signal cannot resolve — consider a larger minimum scale
    scales = np.logspace(0, np.log10(128), num=64)
    # convert scales to frequencies for interpretation
    frequencies = pywt.scale2frequency('morl', scales) * sampling_rate

    # process data in windows
    n_windows = max(1, len(signal_data) // window_samples)
    all_coefficients = []
    all_energy = []
    dominant_freqs = []

    for i in range(n_windows):
        start_idx = i * window_samples
        end_idx = min((i + 1) * window_samples, len(signal_data))
        window_data = signal_data.iloc[start_idx:end_idx].values

        if len(window_data) < window_samples // 2:
            continue
        try:
            coefficients, _ = pywt.cwt(
                window_data,
                scales,
                'morl',
                sampling_period=1/sampling_rate
            )
        except Exception as e:
            print(f"Error in wavelet transform: {e}")
            continue

        # calculate energy distribution
        energy = np.abs(coefficients)**2
        # find dominant frequency at each time point
        dominant_scale_idx = np.argmax(energy, axis=0)

        all_coefficients.append(coefficients)
        all_energy.append(energy)
        dominant_freqs.append(frequencies[dominant_scale_idx])

    if len(all_coefficients) == 0:
        print(f"No valid windows for wavelet analysis")
        # same empty-dict failure signal as the validation failures above
        return {}

    # average results across windows
    results = {
        'scales': scales,
        'frequencies': frequencies,
        'coefficients': np.mean(all_coefficients, axis=0),
        'time_frequency_energy': np.mean(all_energy, axis=0),
        'dominant_frequency': np.concatenate(dominant_freqs),
    }

    # visualization
    plt.figure(figsize=(14, 10))

    # time-frequency plot (scalogram); the y-axis is the scale *index*
    plt.subplot(211)
    T, S = np.meshgrid(
        np.arange(results['time_frequency_energy'].shape[1]),
        np.arange(results['time_frequency_energy'].shape[0])
    )

    plt.contourf(T, S, results['time_frequency_energy'], 100, cmap='jet')
    plt.ylabel('Scale')
    plt.title(f'Wavelet Scalogram - {signal_name}')

    plt.colorbar(label='Energy')

    # secondary axis labelling selected scale indices with their frequency
    ax1 = plt.gca()
    ax2 = ax1.twinx()
    tick_positions = np.arange(0, len(scales), len(scales)//5)
    ax2.set_yticks(tick_positions)
    ax2.set_yticklabels([f"{frequencies[i]:.2f}" for i in tick_positions])
    # share the scale-index range so each frequency label lines up with its row
    ax2.set_ylim(ax1.get_ylim())
    ax2.set_ylabel('Frequency (Hz)')

    # plot dominant frequency over time
    plt.subplot(212)

    # concatenated dominant frequencies
    if len(results['dominant_frequency']) > 0:
        plt.plot(np.arange(len(results['dominant_frequency'])), results['dominant_frequency'], 'r-')
        plt.xlabel('Time (samples)')
        plt.ylabel('Dominant Frequency (Hz)')
        plt.title('Dominant Frequency Components Over Time')

    plt.tight_layout()

    if subject_id and session:
        data_name = f"{subject_id}_{session}_{signal_name}_wavelet"
    else:
        data_name = f"{signal_name}_wavelet"

    plt.savefig(os.path.join(output_dir, f"{data_name}.png"), dpi=300, bbox_inches='tight')
    plt.close()

    # persist the arrays and tabular views of the result
    np.savez(
        os.path.join(output_dir, f"{data_name}.npz"),
        scales=results['scales'],
        frequencies=results['frequencies'],
        coefficients=results['coefficients'],
        energy=results['time_frequency_energy'],
        dominant_frequency=results['dominant_frequency']
    )

    energy_df = pd.DataFrame(
        results['time_frequency_energy'],
        index=[f"scale_{i}" for i in range(len(scales))],
        columns=[f"time_{i}" for i in range(results['time_frequency_energy'].shape[1])]
    )
    energy_df.to_csv(os.path.join(output_dir, f"{data_name}_energy.csv"))

    if len(results['dominant_frequency']) > 0:
        dom_freq_df = pd.DataFrame({
            'time': np.arange(len(results['dominant_frequency'])),
            'dominant_frequency': results['dominant_frequency']
        })
        dom_freq_df.to_csv(os.path.join(output_dir, f"{data_name}_dominant_freq.csv"), index=False)

    return results

In [None]:
# Time-Frequency Analysis Data Test 
# Runs analyze_time_frequency_features on every signal of every per-subject
# CSV, writes a summary table and a heart-rate dominant-frequency plot.
wavelet_output_dir = 'plots/wavelet'
os.makedirs(wavelet_output_dir, exist_ok=True)

processed_files = list(Path('data/processed').glob('*.csv'))
subject_files = [f for f in processed_files if not f.name.startswith('all_')]

if subject_files:
    print(f"Found {len(subject_files)} subject files for wavelet analysis")
    all_wavelet_results = {}
    # (subject_id, session, signal_name) for each result key; signal names such
    # as 'heart_rate' contain '_' themselves, so they cannot be recovered by
    # splitting the key string
    wavelet_result_labels = {}
    
    # process each subject file
    for sample_file in subject_files:
        try:
            subject_id = sample_file.stem.split('_')[0]
            session = '_'.join(sample_file.stem.split('_')[1:-1])
            
            print(f"\nPerforming wavelet analysis for Subject {subject_id}, Session {session}")
            
            sample_data = pd.read_csv(sample_file)
            
            # convert timestamp to datetime
            sample_data['timestamp'] = pd.to_datetime(sample_data['timestamp'])
            
            # determine sampling rate from the median timestamp spacing in
            # seconds (independent of the datetime storage resolution)
            time_diff = sample_data['timestamp'].diff().dt.total_seconds().median()
            if not np.isfinite(time_diff) or time_diff <= 0:
                print(f"  Could not determine sampling rate for {sample_file.name}; skipping")
                continue
            sampling_rate = 1.0 / time_diff
            
            print(f"  Detected sampling rate: {sampling_rate:.2f} Hz")
            
            # perform wavelet analysis on different signals
            signals_to_analyze = ['heart_rate', 'eda', 'temperature']
            
            for signal_name in signals_to_analyze:
                if signal_name not in sample_data.columns:
                    print(f"  Signal {signal_name} not found in data")
                    continue
                    
                print(f"  Analyzing {signal_name} signal...")
                
                # wavelet analysis
                wavelet_results = analyze_time_frequency_features(
                    sample_data,
                    sampling_rate,
                    window_size=60,
                    signal_name=signal_name,
                    subject_id=subject_id,
                    session=session,
                    output_dir=wavelet_output_dir
                )
                
                # only a result that actually carries an energy map is a success
                if 'time_frequency_energy' in wavelet_results:
                    result_key = f"{subject_id}_{session}_{signal_name}"
                    all_wavelet_results[result_key] = wavelet_results
                    wavelet_result_labels[result_key] = (subject_id, session, signal_name)
                    print(f"  ✓ Wavelet analysis complete for {signal_name}")
                else:
                    print(f"  ✗ Wavelet analysis failed for {signal_name}")
                    
        except Exception as e:
            # keep going with the remaining files; the message names the culprit
            print(f"Error processing {sample_file.name}: {e}")
    
    # results summary
    wavelet_summary = []
    
    for result_key, wavelet_result in all_wavelet_results.items():
        # `signal_label` rather than `signal`: a global named `signal` would
        # shadow the scipy.signal module used by the frequency analysis
        subject_id, session, signal_label = wavelet_result_labels[result_key]
        
        # summary metrics
        energy_matrix = wavelet_result['time_frequency_energy']
        total_energy = np.sum(energy_matrix)
        max_energy = np.max(energy_matrix)
        # frequency of the scale with the largest energy summed over time
        frequency_with_max_energy = wavelet_result['frequencies'][np.argmax(np.sum(energy_matrix, axis=1))]
        
        wavelet_summary.append({
            'subject_id': subject_id,
            'session': session,
            'signal': signal_label,
            'total_energy': total_energy,
            'max_energy': max_energy,
            'frequency_with_max_energy': frequency_with_max_energy
        })
    
    if wavelet_summary:
        summary_df = pd.DataFrame(wavelet_summary)
        summary_file = os.path.join(wavelet_output_dir, 'wavelet_analysis_summary.csv')
        summary_df.to_csv(summary_file, index=False)
        
        print("\nWavelet analysis complete!")
        print(f"Results saved to {wavelet_output_dir}")
        print(f"Summary file: {summary_file}")
        
        # comparison plot of dominant frequencies for heart rate
        hr_summary = summary_df[summary_df['signal'] == 'heart_rate']
        
        if not hr_summary.empty:
            plt.figure(figsize=(12, 8))
            
            ax = sns.barplot(
                x='subject_id', 
                y='frequency_with_max_energy', 
                hue='session', 
                data=hr_summary
            )
            
            plt.title('Dominant Frequency of Heart Rate by Subject and Session')
            plt.xlabel('Subject ID')
            plt.ylabel('Frequency (Hz)')
            plt.xticks(rotation=45)
            plt.legend(title='Session')
            plt.tight_layout()
            
            comparison_file = os.path.join(wavelet_output_dir, 'dominant_frequency_comparison.png')
            plt.savefig(comparison_file, dpi=300, bbox_inches='tight')
            plt.close()
            
            print(f"Comparison plot saved as: {comparison_file}")
    else:
        print("No wavelet results were generated")
        
else:
    print("No processed data files found. Please run part1_exploration.ipynb first.")

## Example Usage

Here's how to use these functions with your data:

In [None]:
# Load your data
data = pd.read_csv('data/processed/S1_processed.csv')

# Extract time-domain features
features = extract_time_domain_features(data, window_size=60)
print("Time-domain features:")
print(features.head())

# Analyze frequency components
sampling_rate = 4.0  # Hz
freq_results = analyze_frequency_components(data, sampling_rate, window_size=60)
print("\nFrequency analysis results:")
# the result is keyed by signal name; each entry carries its own 'bands' dict
for signal_name, signal_results in freq_results.items():
    print(f"Frequency bands ({signal_name}):", signal_results['bands'])

# Analyze time-frequency features
tf_results = analyze_time_frequency_features(data, sampling_rate, window_size=60)
print("\nTime-frequency analysis results:")
# the wavelet analysis returns no coefficients when it could not process the signal
if 'coefficients' in tf_results:
    print("Wavelet scales:", tf_results['scales'].shape)
    print("Coefficients shape:", tf_results['coefficients'].shape)
else:
    print("No wavelet results were generated")