# ECG Feature Extraction

This notebook demonstrates feature extraction techniques for ECG signal analysis, using record 100 of the MIT-BIH Arrhythmia Database (`mitdb`).

## Features Extracted:
1. R-peak detection
2. QRS complex extraction
3. Heart Rate Variability (HRV) metrics
4. Wavelet features
5. Statistical features

In [None]:
# Import required libraries
import wfdb
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy import signal
from scipy.stats import skew, kurtosis
import pywt
import warnings
warnings.filterwarnings('ignore')

# Reached only if every import above succeeded
print("Libraries imported successfully!")

## Load ECG Signal (raw MIT-BIH record, first channel)

In [None]:
# Load record 100 from the local mitdb directory plus its 'atr' annotation file
record_name = '../data/mitdb/100'
record = wfdb.rdrecord(record_name)
annotation = wfdb.rdann(record_name, 'atr')  # reference annotations (not used further in this notebook)

# Sampling frequency (Hz) and the first channel of the physical signal as a 1-D array
sampling_rate = record.fs
ecg_signal = record.p_signal[:, 0]

print(f"Loaded ECG signal: {len(ecg_signal)} samples at {sampling_rate} Hz")

## 1. R-Peak Detection

In [None]:
def detect_r_peaks(ecg_signal, sampling_rate, refractory_s=0.6, search_s=0.075):
    """Detect R-peaks with a simplified Pan-Tompkins-style pipeline.

    Steps: 5-15 Hz zero-phase Butterworth band-pass -> first difference ->
    squaring -> 120 ms moving-window integration -> fixed-threshold peak
    picking. Each candidate is then moved to the local maximum of the input
    ECG, so the returned indices point at the R-wave apex rather than at the
    centre of the integrated energy envelope (which is what downstream QRS
    windowing and plotting assume).

    Parameters
    ----------
    ecg_signal : array_like
        1-D ECG samples.
    sampling_rate : float
        Sampling frequency in Hz. Must exceed 30 Hz so the 15 Hz cutoff lies
        below Nyquist (scipy rejects normalized cutoffs >= 1).
    refractory_s : float, default 0.6
        Minimum spacing between detections, in seconds. The default caps the
        detectable heart rate at 60 / 0.6 = 100 BPM; lower it (e.g. 0.25-0.3)
        for faster rhythms.
    search_s : float, default 0.075
        Half-width, in seconds, of the window around each candidate in which
        the ECG maximum is searched. Use 0 to return the raw envelope peaks.

    Returns
    -------
    numpy.ndarray
        Sorted, unique sample indices into ``ecg_signal``.

    Notes
    -----
    Refinement uses ``argmax`` of the input, so it assumes upright (positive)
    R-waves in the analysed lead. The threshold is fixed (no adaptive
    thresholds or search-back as in the full Pan-Tompkins algorithm).
    """
    ecg = np.asarray(ecg_signal)

    # Band-pass 5-15 Hz to emphasise QRS energy; filtfilt keeps it zero-phase
    nyquist = sampling_rate / 2
    b, a = signal.butter(2, [5 / nyquist, 15 / nyquist], btype='band')
    filtered = signal.filtfilt(b, a, ecg)

    # Derivative + squaring highlight the steep QRS slopes
    squared = np.diff(filtered) ** 2

    # 120 ms moving-window integration (at least one sample wide)
    window_size = max(1, int(0.12 * sampling_rate))
    integrated = np.convolve(squared, np.ones(window_size) / window_size, mode='same')

    # Fixed threshold at 60% of the mean envelope; refractory spacing between beats
    threshold = np.mean(integrated) * 0.6
    distance = max(1, int(refractory_s * sampling_rate))
    peaks, _ = signal.find_peaks(integrated, height=threshold, distance=distance)

    half_search = int(search_s * sampling_rate)
    if half_search <= 0 or peaks.size == 0:
        return peaks

    # Snap each envelope peak to the ECG apex within +/- half_search samples
    refined = np.empty_like(peaks)
    for i, p in enumerate(peaks):
        lo = max(0, p - half_search)
        hi = min(len(ecg), p + half_search + 1)
        refined[i] = lo + np.argmax(ecg[lo:hi])

    # Two close candidates could snap to the same apex; keep each index once
    return np.unique(refined)

# Run the detector over the whole record and report a global average rate
r_peaks = detect_r_peaks(ecg_signal, sampling_rate)
print(f"Detected {len(r_peaks)} R-peaks")
print(
    f"Average heart rate: {len(r_peaks) / (len(ecg_signal)/sampling_rate) * 60:.1f} BPM"
)

In [None]:
# Visualize R-peaks over the first 10 seconds of the record
plt.figure(figsize=(15, 5))
time = np.arange(len(ecg_signal)) / sampling_rate
# Derive the sample count from the sampling rate instead of hard-coding 3600
# (which is only 10 s at 360 Hz)
segment_end = int(10 * sampling_rate)

plt.plot(time[:segment_end], ecg_signal[:segment_end], linewidth=0.8, color='#2E86AB', label='ECG Signal')
peaks_in_segment = r_peaks[r_peaks < segment_end]
plt.scatter(time[peaks_in_segment], ecg_signal[peaks_in_segment],
           color='#E63946', s=100, zorder=5, label='R-peaks', marker='^')
plt.title('ECG Signal with Detected R-peaks', fontsize=14, fontweight='bold')
plt.xlabel('Time (seconds)', fontsize=12)
plt.ylabel('Amplitude (mV)', fontsize=12)
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

## 2. QRS Complex Extraction

In [None]:
def extract_qrs_complexes(ecg_signal, r_peaks, sampling_rate, window_ms=200):
    """Cut a fixed-length window around each R-peak.

    Each window covers ``window_samples = int(window_ms * sampling_rate / 1000)``
    samples, starting ``window_samples // 2`` samples before the peak, so the
    R-peak always sits at index ``window_samples // 2``. Windows that run past
    either end of the signal are padded on *that* side by repeating the edge
    sample, which keeps the peak aligned across all rows.

    Parameters
    ----------
    ecg_signal : array_like
        1-D ECG samples.
    r_peaks : array_like of int
        Sample indices of the R-peaks.
    sampling_rate : float
        Sampling frequency in Hz.
    window_ms : float, default 200
        Window length in milliseconds.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(len(r_peaks), window_samples)``. It is ``(0, window_samples)``
        when no peaks are given.
    """
    ecg = np.asarray(ecg_signal)
    n = len(ecg)
    window_samples = int(window_ms * sampling_rate / 1000)
    half_window = window_samples // 2

    qrs_complexes = []
    for peak in r_peaks:
        # Ideal window bounds; may fall outside [0, n) for beats near the edges
        start = peak - half_window
        end = start + window_samples
        qrs = ecg[max(0, start):min(n, end)]

        # Pad on the side that was clipped so the R-peak stays at half_window
        pad_before = max(0, -start)
        pad_after = max(0, end - n)
        if pad_before or pad_after:
            qrs = np.pad(qrs, (pad_before, pad_after), mode='edge')

        qrs_complexes.append(qrs)

    if not qrs_complexes:
        return np.empty((0, window_samples), dtype=ecg.dtype)
    return np.array(qrs_complexes)

# Window every detected beat and report the resulting array size
qrs_complexes = extract_qrs_complexes(ecg_signal, r_peaks, sampling_rate)
print(f"Extracted {len(qrs_complexes)} QRS complexes")
first_qrs_shape = qrs_complexes[0].shape
print(f"QRS complex shape: {first_qrs_shape}")

In [None]:
# Visualize the first QRS complexes (up to 6; fewer if fewer beats were detected)
n_show = min(6, len(qrs_complexes))
fig, axes = plt.subplots(2, 3, figsize=(15, 8))
axes = axes.flatten()

for i in range(n_show):
    axes[i].plot(qrs_complexes[i], linewidth=1.5, color='#4361EE')
    axes[i].set_title(f'QRS Complex {i+1}', fontsize=12, fontweight='bold')
    axes[i].grid(True, alpha=0.3)
    axes[i].set_xlabel('Sample', fontsize=10)
    axes[i].set_ylabel('Amplitude', fontsize=10)

# Hide panels that have no complex to show instead of raising IndexError
for ax in axes[n_show:]:
    ax.set_visible(False)

plt.tight_layout()
plt.show()

# Ensemble-average QRS template with a +/- 1 standard deviation band
avg_qrs = qrs_complexes.mean(axis=0)
std_qrs = qrs_complexes.std(axis=0)
x = np.arange(avg_qrs.shape[0])
band_low, band_high = avg_qrs - std_qrs, avg_qrs + std_qrs

fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(x, avg_qrs, linewidth=2, color='#06A77D', label='Average QRS')
ax.fill_between(x, band_low, band_high, alpha=0.3, color='#06A77D')
ax.set_title('Average QRS Complex with Standard Deviation', fontsize=14, fontweight='bold')
ax.set_xlabel('Sample', fontsize=12)
ax.set_ylabel('Amplitude', fontsize=12)
ax.legend()
ax.grid(True, alpha=0.3)
fig.tight_layout()
plt.show()

## 3. Heart Rate Variability (HRV) Features

In [None]:
def calculate_hrv_features(r_peaks, sampling_rate):
    """Calculate time-domain HRV features from R-peak sample indices.

    Parameters
    ----------
    r_peaks : array_like of int
        Sample indices of successive R-peaks.
    sampling_rate : float
        Sampling frequency in Hz.

    Returns
    -------
    (dict, numpy.ndarray)
        ``hrv_features`` maps each feature name to a value:
        mean_rr / std_rr (ms), rmssd / sdsd (ms), nn50 (count of successive RR
        differences > 50 ms), pnn50 (%), and mean_hr / min_hr / max_hr (BPM).
        ``rr_intervals`` holds the RR intervals in milliseconds.

    Notes
    -----
    pnn50 divides NN50 by the number of RR intervals. Some tools divide by
    the number of successive differences instead.
    Features that are undefined for too few beats are NaN instead of raising.
    RR-based values need at least 2 R-peaks; successive-difference values
    (rmssd, sdsd) need at least 3.
    std_rr uses the population standard deviation (ddof=0).
    """
    # RR intervals in milliseconds and their successive differences
    rr_intervals = np.diff(np.asarray(r_peaks)) / sampling_rate * 1000
    rr_diffs = np.diff(rr_intervals)
    have_rr = rr_intervals.size > 0
    have_diffs = rr_diffs.size > 0
    nn50 = np.sum(np.abs(rr_diffs) > 50)

    hrv_features = {
        'mean_rr': np.mean(rr_intervals) if have_rr else np.nan,
        'std_rr': np.std(rr_intervals) if have_rr else np.nan,
        'rmssd': np.sqrt(np.mean(rr_diffs ** 2)) if have_diffs else np.nan,
        'sdsd': np.std(rr_diffs) if have_diffs else np.nan,
        'nn50': nn50,
        'pnn50': nn50 / rr_intervals.size * 100 if have_rr else np.nan,
        'mean_hr': 60000 / np.mean(rr_intervals) if have_rr else np.nan,
        # The longest RR interval gives the slowest rate, and vice versa
        'min_hr': 60000 / np.max(rr_intervals) if have_rr else np.nan,
        'max_hr': 60000 / np.min(rr_intervals) if have_rr else np.nan,
    }

    return hrv_features, rr_intervals

# Compute HRV metrics for the whole record and print them one per line
hrv_features, rr_intervals = calculate_hrv_features(r_peaks, sampling_rate)

print("Heart Rate Variability Features:")
print("\n".join(f"  {name}: {val:.2f}" for name, val in hrv_features.items()))

In [None]:
# Visualize RR intervals: tachogram (top) and distribution (bottom), both with the mean marked
rr_mean = np.mean(rr_intervals)
mean_label = f'Mean: {rr_mean:.1f} ms'

fig, axes = plt.subplots(2, 1, figsize=(15, 8))
ax_tacho, ax_hist = axes

ax_tacho.plot(rr_intervals, linewidth=1, color='#F77F00', marker='o', markersize=3)
ax_tacho.axhline(y=rr_mean, color='#E63946', linestyle='--', linewidth=2, label=mean_label)
ax_tacho.set_title('RR Interval Tachogram', fontsize=14, fontweight='bold')
ax_tacho.set_ylabel('RR Interval (ms)', fontsize=12)

ax_hist.hist(rr_intervals, bins=50, color='#4361EE', alpha=0.7, edgecolor='black')
ax_hist.axvline(x=rr_mean, color='#E63946', linestyle='--', linewidth=2, label=mean_label)
ax_hist.set_title('RR Interval Distribution', fontsize=14, fontweight='bold')
ax_hist.set_xlabel('RR Interval (ms)', fontsize=12)
ax_hist.set_ylabel('Frequency', fontsize=12)

# Shared finishing touches for both panels
for ax in axes:
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 4. Wavelet Features

In [None]:
def extract_wavelet_features(signal_segment, wavelet='db4', level=4):
    """Summarise a multilevel DWT of ``signal_segment``.

    The segment is decomposed with ``pywt.wavedec`` into
    ``[cA_level, cD_level, ..., cD1]``. For each coefficient band, in that
    order, four numbers are emitted: mean absolute value, standard deviation,
    maximum absolute value and energy (sum of squares). The result therefore
    has ``4 * (level + 1)`` entries.

    NOTE(review): for short segments (such as ~200 ms QRS windows), ``level``
    may exceed ``pywt.dwt_max_level(len(signal_segment), wavelet)``. pywt then
    warns about boundary effects, and those warnings are silenced globally in
    this notebook. Confirm the level suits the segment length.
    """
    bands = pywt.wavedec(signal_segment, wavelet, level=level)
    per_band = (
        (np.mean(np.abs(band)), np.std(band), np.max(np.abs(band)), np.sum(band ** 2))
        for band in bands
    )
    return np.array([value for summary in per_band for value in summary])

# Wavelet summary of the first extracted QRS complex (preview the first 10 values)
wavelet_features = extract_wavelet_features(qrs_complexes[0])
n_wavelet_features = len(wavelet_features)
print(f"Extracted {n_wavelet_features} wavelet features")
print(f"Features: {wavelet_features[:10]}...")

In [None]:
# Visualize each coefficient band of a 4-level db4 decomposition of the first QRS complex
coeffs = pywt.wavedec(qrs_complexes[0], 'db4', level=4)
n_levels = len(coeffs) - 1

fig, axes = plt.subplots(len(coeffs), 1, figsize=(12, 10))

for idx, (ax, band) in enumerate(zip(axes, coeffs)):
    if idx == 0:
        # First band is the approximation at the coarsest level
        ax.plot(band, linewidth=1, color='#E63946')
        ax.set_title(f'Approximation Coefficients (cA{n_levels})', fontsize=12, fontweight='bold')
    else:
        # Remaining bands are details from coarsest (cD{n_levels}) to finest (cD1)
        ax.plot(band, linewidth=1, color='#4361EE')
        ax.set_title(f'Detail Coefficients (cD{n_levels + 1 - idx})', fontsize=12, fontweight='bold')
    ax.grid(True, alpha=0.3)

axes[-1].set_xlabel('Sample', fontsize=12)
plt.tight_layout()
plt.show()

## 5. Statistical Features

In [None]:
def extract_statistical_features(signal_segment):
    """Extract amplitude-distribution and energy statistics from a 1-D segment.

    Parameters
    ----------
    signal_segment : array_like
        1-D signal samples (lists are accepted and converted to an array).

    Returns
    -------
    dict
        Keys, in this fixed order (downstream code relies on it): mean, std,
        var, min, max, range (peak-to-peak), median, skewness, kurtosis
        (scipy's Fisher definition), rms, energy (sum of squares) and
        zero_crossings.

    Notes
    -----
    zero_crossings counts sign changes between consecutive non-zero samples.
    Samples that are exactly zero are skipped, so ``[1, 0, -1]`` counts one
    crossing and ``[1, 0, 1]`` counts none. The value is measured about 0, not
    about the segment mean.
    """
    x = np.asarray(signal_segment)
    signs = np.sign(x)
    nonzero_signs = signs[signs != 0]

    features = {
        'mean': np.mean(x),
        'std': np.std(x),
        'var': np.var(x),
        'min': np.min(x),
        'max': np.max(x),
        'range': np.ptp(x),
        'median': np.median(x),
        'skewness': skew(x),
        'kurtosis': kurtosis(x),
        'rms': np.sqrt(np.mean(x ** 2)),
        'energy': np.sum(x ** 2),
        'zero_crossings': np.sum(nonzero_signs[1:] != nonzero_signs[:-1]),
    }
    return features

# Extract statistical features from every QRS complex so the distribution
# plots below reflect all detected beats, then preview the first 10 rows.
all_stats = [extract_statistical_features(qrs) for qrs in qrs_complexes]

stats_df = pd.DataFrame(all_stats)
print("Statistical Features (first 10 QRS complexes):")
print(stats_df.head(10))

In [None]:
# Histogram of six statistical features across the QRS complexes in stats_df
features_to_plot = ['mean', 'std', 'skewness', 'kurtosis', 'energy', 'zero_crossings']
colors = ['#E63946', '#F77F00', '#4361EE', '#06A77D', '#9D4EDD', '#F72585']

fig, axes = plt.subplots(2, 3, figsize=(15, 8))

for ax, feature, color in zip(axes.flat, features_to_plot, colors):
    label = feature.capitalize()
    ax.hist(stats_df[feature], bins=20, color=color, alpha=0.7, edgecolor='black')
    ax.set_title(f'{label} Distribution', fontsize=12, fontweight='bold')
    ax.set_xlabel(label, fontsize=10)
    ax.set_ylabel('Frequency', fontsize=10)
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## Feature Summary

In [None]:
# Combine all features for one QRS complex
def extract_all_features(qrs_complex, r_peaks, sampling_rate):
    """Build one flat feature vector for a single QRS complex.

    The vector is the statistical features (in their dict order) followed by
    the wavelet features, returned as a 1-D numpy array.

    ``r_peaks`` and ``sampling_rate`` are accepted for interface stability but
    are not used by the current implementation.
    """
    statistical_part = np.array(list(extract_statistical_features(qrs_complex).values()))
    wavelet_part = extract_wavelet_features(qrs_complex)
    return np.concatenate((statistical_part, wavelet_part))

# One row per QRS complex, one column per combined feature
feature_matrix = np.array(
    [extract_all_features(qrs, r_peaks, sampling_rate) for qrs in qrs_complexes]
)
print(f"\nFeature Matrix Shape: {feature_matrix.shape}")
print(f"Total features per QRS complex: {feature_matrix.shape[1]}")
print(f"Total QRS complexes: {feature_matrix.shape[0]}")

## Summary

This notebook demonstrated:
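- R-peak detection using a simplified Pan-Tompkins-style pipeline (band-pass filter, derivative, squaring, moving-window integration, fixed threshold)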
- QRS complex extraction and analysis
- Heart Rate Variability (HRV) feature calculation
- Wavelet decomposition features
- Statistical feature extraction

These features can now be paired with the beat annotations loaded at the start of the notebook to train and evaluate classification models.