# Part 3: Advanced Analysis

In this part, we will implement advanced analysis techniques for physiological time series data, including time-domain feature extraction, frequency analysis, and wavelet transforms.

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import signal
import pywt

# Set plotting style
# plt.style.use('seaborn')
sns.set_context('notebook')
plt.style.use('seaborn-v0_8')


## 1. Time-Domain Feature Extraction

Implement the `extract_time_domain_features` function to extract various time-domain features from physiological signals.

In [2]:
def extract_time_domain_features(data, window_size=60):
    """Extract time-domain features from physiological signals.
    
    Parameters
    ----------
    data : pd.DataFrame
        Input data with columns: ['timestamp', 'heart_rate', 'eda', 'temperature', 'subject_id', 'session']
    window_size : int
        Size of the rolling window in seconds (assumes 1 Hz sampling rate)
        
    Returns
    -------
    pd.DataFrame
        DataFrame with extracted time-domain features
    """
    # convert HR to RR intervals in milliseconds
    rr_intervals = 60000 / data['heart_rate']
    rr_diff = rr_intervals.diff()

    # Rolling calculations
    features = pd.DataFrame(index=data.index)

    # Basic time-domain stats (units: bpm)
    features['heart_rate_mean_bpm'] = data['heart_rate'].rolling(window=window_size).mean()
    features['heart_rate_std_bpm'] = data['heart_rate'].rolling(window=window_size).std()
    features['heart_rate_min_bpm'] = data['heart_rate'].rolling(window=window_size).min()
    features['heart_rate_max_bpm'] = data['heart_rate'].rolling(window=window_size).max()

    # Beat-to-beat variability features (units: ms)
    features['rr_mean_ms'] = rr_intervals.rolling(window=window_size).mean()
    features['rr_sdnn_ms'] = rr_intervals.rolling(window=window_size).std()
    features['rr_rmssd_ms'] = rr_diff.rolling(window=window_size).apply(lambda x: np.sqrt(np.mean(x**2)), raw=False)
    features['rr_pnn50_percent'] = rr_diff.rolling(window=window_size).apply(
        lambda x: 100 * np.sum(np.abs(x) > 50) / len(x) if len(x) > 0 else np.nan, raw=False
    )

    # Drop NaNs caused by rolling calculations
    return features.dropna()


## 2. Frequency Analysis

Implement the `analyze_frequency_components` function to perform frequency-domain analysis on the signals.

In [7]:
def analyze_frequency_components(data, sampling_rate, window_size=60, signal_list=None, subject_id=None, session=None, output_dir='plots/frequency'):
    os.makedirs(output_dir, exist_ok=True)
    
    if signal_list is None:
        signal_list = ['heart_rate', 'eda', 'temperature']
    
    window_samples = int(window_size * sampling_rate)
    results = {}
    
    for signal_name in signal_list:
        if signal_name not in data.columns:
            print(f"Skipping: {signal_name} not found in data.")
            continue
        
        signal_results = {}
        n_windows = max(1, len(data) // window_samples)
        all_frequencies, all_power = [], []
        
        for i in range(n_windows):
            start, end = i * window_samples, min((i + 1) * window_samples, len(data))
            window_data = data[signal_name].iloc[start:end]
            
            if len(window_data) < window_samples // 2:
                continue
            
            window_data = window_data.replace([np.inf, -np.inf], np.nan)
            if window_data.isna().mean() > 0.5:
                print(f"  Skipping window {i}: >50% missing")
                continue
            window_data = window_data.interpolate(method='linear').fillna(method='ffill').fillna(method='bfill')
            if window_data.isna().any() or np.isinf(window_data).any():
                print(f"  Skipping window {i}: contains unresolved NaN/Inf")
                continue
            
            try:
                detrended = signal.detrend(window_data.values)
                windowed = detrended * signal.windows.hann(len(detrended))
            except Exception as e:
                print(f"  Window {i} error during preprocessing: {e}")
                continue
            
            try:
                freqs, psd = signal.welch(windowed, fs=sampling_rate,
                                          nperseg=min(256, len(windowed)),
                                          noverlap=min(128, len(windowed) // 2),
                                          scaling='density')
                all_frequencies.append(freqs)
                all_power.append(psd)
            except Exception as e:
                print(f"  Welch error in window {i}: {e}")
                continue
        
        if not all_frequencies:
            print(f"No valid windows for {signal_name}")
            continue
        
        avg_freqs = np.mean(all_frequencies, axis=0)
        avg_power = np.mean(all_power, axis=0)
        
        signal_results['frequencies'] = avg_freqs
        signal_results['power'] = avg_power
        
        bands = {
            'VLF': (0.003, 0.04),
            'LF': (0.04, 0.15),
            'HF': (0.15, 0.4)
        }
        
        signal_results['bands'] = {}
        for band, (low, high) in bands.items():
            mask = (avg_freqs >= low) & (avg_freqs <= high)
            signal_results['bands'][band] = np.sum(avg_power[mask]) if any(mask) else 0
        
        hf_power = signal_results['bands']['HF']
        lf_power = signal_results['bands']['LF']
        signal_results['bands']['LF/HF'] = lf_power / hf_power if hf_power > 0 else np.nan
        
        plt.figure(figsize=(12, 6))
        plt.semilogy(avg_freqs, avg_power, 'b-', label='Power Spectrum')
        colors = {'VLF': 'green', 'LF': 'orange', 'HF': 'red'}
        for band, (low, high) in bands.items():
            mask = (avg_freqs >= low) & (avg_freqs <= high)
            if any(mask):
                plt.fill_between(avg_freqs[mask], avg_power[mask], alpha=0.3, color=colors[band], label=f'{band} Band')
        
        plt.xlabel('Frequency (Hz)')
        plt.ylabel('Power Spectral Density')
        title = f'{signal_name.replace("_", " ").title()} Power Spectrum'
        if subject_id and session:
            title += f' - Subject {subject_id}, {session}'
        plt.title(title)
        plt.grid(True, which='both', linestyle='--', alpha=0.5)
        plt.legend()
        
        plot_name = f"{subject_id}_{session}_{signal_name}_fft.png" if subject_id and session else f"{signal_name}_fft.png"
        plt.savefig(os.path.join(output_dir, plot_name), dpi=300, bbox_inches='tight')
        plt.close()
        
        base_name = f"{subject_id}_{session}_{signal_name}_fft" if subject_id and session else f"{signal_name}_fft"
        np.savez(
            os.path.join(output_dir, f"{base_name}.npz"),
            frequencies=avg_freqs,
            power=avg_power,
            **{f"{band.lower()}_power": signal_results['bands'][band] for band in bands},
            lf_hf_ratio=signal_results['bands']['LF/HF']
        )
        pd.DataFrame({'frequency': avg_freqs, 'power': avg_power}).to_csv(
            os.path.join(output_dir, f"{base_name}.csv"), index=False
        )
        
        results[signal_name] = signal_results
    
    return results

## 3. Time-Frequency Analysis

Implement the `analyze_time_frequency_features` function to analyze time-frequency features using wavelet transforms.

In [8]:
# Define the analyze_time_frequency_features function
import os 
def analyze_time_frequency_features(data, sampling_rate, window_size=60, signal_name='heart_rate', subject_id=None, session=None, output_dir='plots/wavelet'):
    """
    Analyze time-frequency features using wavelet transforms.
    """
    os.makedirs(output_dir, exist_ok=True)
    window_samples = int(window_size * sampling_rate)
    n_windows = len(data) // window_samples

    # Define wavelet scales
    scales = np.arange(1, 128)
    all_coefficients = []
    all_energy = []

    for i in range(n_windows):
        segment = data[signal_name].iloc[i*window_samples:(i+1)*window_samples]

        if segment.isna().mean() > 0.5:
            continue

        # Interpolate missing values
        segment = segment.replace([np.inf, -np.inf], np.nan)
        segment = segment.interpolate(method='linear').bfill().ffill()

        if segment.isna().any():
            continue

        # Apply wavelet transform
        coeffs, freqs = pywt.cwt(segment.values, scales, 'morl', sampling_period=1/sampling_rate)
        all_coefficients.append(coeffs)
        all_energy.append(np.abs(coeffs) ** 2)

    if not all_coefficients:
        print(f"No valid windows for {signal_name}.")
        return {}

    mean_coefficients = np.mean(all_coefficients, axis=0)
    mean_energy = np.mean(all_energy, axis=0)

    # Prepare output file base name
    base_name = f"{subject_id}_{session}_{signal_name}_wavelet" if subject_id and session else f"{signal_name}_wavelet"

    # Save coefficients & energy
    np.savez(os.path.join(output_dir, f"{base_name}.npz"),
             scales=scales,
             coefficients=mean_coefficients,
             time_frequency_energy=mean_energy)

    np.save(os.path.join(output_dir, f"{base_name}_energy.npy"), mean_energy)

    # Save as CSV (energy matrix)
    energy_df = pd.DataFrame(mean_energy, index=scales)
    energy_df.to_csv(os.path.join(output_dir, f"{base_name}_energy.csv"))

    # Optional visualization
    plt.figure(figsize=(12, 6))
    plt.imshow(mean_energy, extent=[0, window_size, scales[-1], scales[0]],
               cmap='viridis', aspect='auto')
    plt.colorbar(label='Energy')
    plt.title(f"Wavelet Time-Frequency Energy - {signal_name}")
    plt.xlabel("Time (s)")
    plt.ylabel("Scale")
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, f"{base_name}.png"), dpi=300)
    plt.close()

    return {
        'scales': scales,
        'coefficients': mean_coefficients,
        'time_frequency_energy': mean_energy
    }

# Now load your dataset (use the actual path to your data)
data = pd.read_csv('data/processed/preprocessed_data.csv')

# Set your sampling rate (adjust according to your data)
sampling_rate = 4.0  # Hz (example)

# Call the analyze_time_frequency_features function for heart rate signal
time_freq_results = analyze_time_frequency_features(data, sampling_rate, window_size=60, signal_name='heart_rate')

# Access the results
print("Wavelet Scales:", time_freq_results['scales'])
print("Wavelet Coefficients shape:", time_freq_results['coefficients'].shape)
print("Time-Frequency Energy shape:", time_freq_results['time_frequency_energy'].shape)


Wavelet Scales: [  1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17  18
  19  20  21  22  23  24  25  26  27  28  29  30  31  32  33  34  35  36
  37  38  39  40  41  42  43  44  45  46  47  48  49  50  51  52  53  54
  55  56  57  58  59  60  61  62  63  64  65  66  67  68  69  70  71  72
  73  74  75  76  77  78  79  80  81  82  83  84  85  86  87  88  89  90
  91  92  93  94  95  96  97  98  99 100 101 102 103 104 105 106 107 108
 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126
 127]
Wavelet Coefficients shape: (127, 240)
Time-Frequency Energy shape: (127, 240)


## Example Usage

Here's how to use these functions with your data:

In [10]:
import pandas as pd
import os 

# Load your data
data = pd.read_csv('data/processed/preprocessed_data.csv')

# Extract time-domain features
features = extract_time_domain_features(data, window_size=60)
print("Time-domain features:")
print(features.head())

# Analyze frequency components
sampling_rate = 4.0  # Hz
freq_results = analyze_frequency_components(data, sampling_rate, window_size=60)

# ✅ Corrected frequency analysis output
print("\nFrequency analysis results:")
for signal_name, result in freq_results.items():
    print(f"\nFrequency bands for {signal_name}:")
    if 'bands' in result:
        for band_name, power in result['bands'].items():
            print(f"  {band_name}: {power:.4f}")
    else:
        print("  No frequency band data available.")

# Analyze time-frequency features
tf_results = analyze_time_frequency_features(data, sampling_rate, window_size=60)
print("\nTime-frequency analysis results:")
print("Wavelet scales:", tf_results['scales'].shape)
print("Coefficients shape:", tf_results['coefficients'].shape)


Time-domain features:
    heart_rate_mean_bpm  heart_rate_std_bpm  heart_rate_min_bpm  \
60           164.041333           35.064915                50.0   
61           166.424333           31.898846                53.5   
62           168.789833           28.460003                61.0   
63           171.066500           24.936774                89.6   
64           172.846833           22.869215                89.6   

    heart_rate_max_bpm  rr_mean_ms  rr_sdnn_ms  rr_rmssd_ms  rr_pnn50_percent  
60              190.42  401.196225  179.982797  7591.209079         10.000000  
61              192.98  386.378109  146.607027    49.553774          8.333333  
62              195.43  372.803442  110.696442    48.508900          6.666667  
63              197.60  361.470728   76.686157    45.126647          5.000000  
64              198.07  355.560544   66.465694    16.259471          3.333333  


  window_data = window_data.interpolate(method='linear').fillna(method='ffill').fillna(method='bfill')



Frequency analysis results:

Frequency bands for heart_rate:
  VLF: 2241.8698
  LF: 419.2689
  HF: 29.1552
  LF/HF: 14.3806

Frequency bands for eda:
  VLF: 0.1494
  LF: 0.1239
  HF: 0.0823
  LF/HF: 1.5051

Frequency bands for temperature:
  VLF: 2.0533
  LF: 1.1603
  HF: 0.3633
  LF/HF: 3.1934

Time-frequency analysis results:
Wavelet scales: (127,)
Coefficients shape: (127, 240)
