Deep Learning for ECG Arrhythmia Classification: A PLRNN Framework on Large-Scale Clinical Data

# ECG 5-Class Classification Deep Learning Project

## Project Overview

This project implements a 5-class ECG classifier trained on the MIMIC-IV-ECG dataset. It assigns each 12-lead recording to one of five classes: atrial fibrillation (AF), bradycardia, bundle branch block (BBB), normal rhythm, or tachycardia.

### Key Features
- 🏥 **Medical Feature Engineering**: Extracts 8 core medical features including heart rate and HRV
- 🧠 **LSTM Temporal Modeling**: Captures sequential dependencies in ECG signals
- ⚖️ **Class Balancing**: Counters severe class imbalance with balanced sampling and class weights
- 🔧 **Data Augmentation**: Lightweight noise and amplitude augmentation
- 💻 **Hardware Compatibility**: Runs stably on Apple Silicon (M4) Macs in CPU-only mode

## Directory Structure

```
ECG-Classification/
├── too_feature.py               # Main training script
├── ecg_5_class_data.csv         # Label data file
├── ecg_stable_lstm_model.keras  # Trained model
├── README.md                    # Project documentation
└── mimic-iv-ecg/                # ECG waveform data directory
```

## Requirements

### Hardware Requirements
- **Recommended**: Apple Silicon (M1/M2/M3/M4) Mac
- **Memory**: At least 8GB RAM
- **Storage**: At least 20GB available space

### Software Environment
- Python 3.10/3.11 (Note: Python 3.13 not supported)
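- TensorFlow 2.12 to 2.15 (Keras 2 API; the script uses `tf.keras.optimizers.legacy.Adam`, which is not available under Keras 3, the default from TensorFlow 2.16)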
- See requirements.txt for other dependencies

### Environment Setup

```bash
# Create conda environment
conda create -n tf_final python=3.11
conda activate tf_final

# Install TensorFlow (Apple Silicon)
# tensorflow-metal is optional: training runs in CPU mode (see Hardware Optimization)
pip install tensorflow-macos tensorflow-metal

# Install other dependencies
pip install pandas numpy scipy scikit-learn wfdb tqdm
```

## Complete Workflow

### 1. Data Preprocessing Stage

#### 1.1 Data Loading and Validation
```python
import pandas as pd

# Load label file (no header row): one ECG record per line
full_df = pd.read_csv('ecg_5_class_data.csv', header=None,
                      names=['subject_id', 'waveform_path', 'ecg_category'])
```

**Key Points**:
- Dataset contains 366,301 records
- 5 classes: AF (240,717), Tachycardia (60,809), Bradycardia (32,508), Normal (21,950), BBB (10,317)
- Severe class imbalance issue
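The degree of imbalance follows directly from the counts above. This plain-Python calculation (no project code involved) gives each class's share of the 366,301 records: AF is about 65.7% and BBB about 2.8%, a majority-to-minority ratio of roughly 23:1.

```python
# Class counts as listed above for ecg_5_class_data.csv
counts = {
    "AF": 240_717,
    "Tachycardia": 60_809,
    "Bradycardia": 32_508,
    "Normal": 21_950,
    "BBB": 10_317,
}

total = sum(counts.values())
shares = {name: n / total for name, n in counts.items()}
# Ratio between the largest and the smallest class
imbalance_ratio = max(counts.values()) / min(counts.values())

print(f"Total records: {total:,}")
for name, share in shares.items():
    print(f"{name:<12} {share:6.1%}")
print(f"Majority/minority ratio: {imbalance_ratio:.1f}:1")
```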

#### 1.2 Dataset Splitting Strategy
```python
from sklearn.model_selection import train_test_split

# Split by patient ID to avoid data leakage (fixed seed for reproducibility)
all_subjects = full_df['subject_id'].unique()
train_val_subjects, test_subjects = train_test_split(all_subjects, test_size=0.15, random_state=42)
train_subjects, val_subjects = train_test_split(train_val_subjects, test_size=0.15, random_state=42)
```

**Key Points**:
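- **Patient-level splitting**: All ECGs from a given patient fall into exactly one of the train, validation, and test sets
- **Prevents data leakage**: The model cannot score well simply by recognizing patient-specific signal traits, which is critical for medical AI projects
- **Resulting proportions**: About 72% of patients for training, 13% for validation, and 15% for testing (validation takes 15% of the remaining 85%)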
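A minimal sketch of the complete split, assuming the three columns from the loading step. The helper `patient_level_split` and the toy table are hypothetical, not taken from `too_feature.py`; the sketch maps the subject-level splits back to ECG records with `isin` and checks that no patient appears in two splits.

```python
import pandas as pd
from sklearn.model_selection import train_test_split


def patient_level_split(df, test_size=0.15, val_size=0.15, random_state=42):
    """Hypothetical helper: split records so each subject_id lands in one split."""
    subjects = df["subject_id"].unique()
    train_val, test = train_test_split(subjects, test_size=test_size,
                                       random_state=random_state)
    train, val = train_test_split(train_val, test_size=val_size,
                                  random_state=random_state)
    # Map each subject-level split back to its ECG records
    return (df[df["subject_id"].isin(train)],
            df[df["subject_id"].isin(val)],
            df[df["subject_id"].isin(test)])


# Toy label table (hypothetical): 40 patients with 3 ECGs each
toy = pd.DataFrame({
    "subject_id": [s for s in range(40) for _ in range(3)],
    "ecg_category": ["Normal"] * 120,
})
train_df, val_df, test_df = patient_level_split(toy)

# No subject may appear in two splits
train_ids, val_ids, test_ids = (set(d["subject_id"]) for d in (train_df, val_df, test_df))
assert not (train_ids & val_ids or train_ids & test_ids or val_ids & test_ids)
print(len(train_ids), len(val_ids), len(test_ids))
```

Splitting `subject_id` values rather than rows is what guarantees disjointness; splitting rows directly would scatter one patient's ECGs across all three sets.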

#### 1.3 Balanced Sampling Mechanism
```python
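import pandas as pd

def balanced_sampling(df, target_samples, random_state=42):
    categories = df['ecg_category'].unique()
    samples_per_class = target_samples // len(categories)
    # Sample same number for each class, use replacement for minority classes
    parts = [df[df['ecg_category'] == c].sample(
                 n=samples_per_class, random_state=random_state,
                 replace=(df['ecg_category'] == c).sum() < samples_per_class)
             for c in categories]
    return pd.concat(parts, ignore_index=True)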
```

**Key Points**:
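- **Equal class sizes**: Each class contributes `target_samples // 5` records (300 per class for the 1,500 training samples)
- **Oversampling**: Classes with fewer records than needed are sampled with replacement, so some records repeat
- **Final data**: 1,500 train, 300 validation, 400 test samples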

### 2. Signal Preprocessing Stage

#### 2.1 ECG Signal Preprocessing
```python
def stable_preprocess_ecg(raw_signal, target_length=500):
    # 1. Data validation and type conversion
    # 2. Resample to fixed length
    # 3. Channel-wise normalization
    # 4. Outlier clipping
    ...
```

**Key Points**:
- **Fixed length**: Every record is resampled to 500 time points so records can be batched
- **12-lead**: Preserves complete ECG information
- **Per-lead normalization**: Each of the 12 leads is normalized independently
- **Outlier handling**: Normalized values are clipped to [-3, 3]
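The four preprocessing steps can be sketched as follows. This is a hypothetical re-implementation, not the script's code: it assumes a `(time, leads)` array, uses `scipy.signal.resample` (FFT-based) for step 2, and uses a per-lead z-score for step 3 (the exact normalization in the script may differ, for example median/IQR scaling). MIMIC-IV-ECG records are 10 s at 500 Hz (5,000 samples), so resampling to 500 points leaves an effective sampling rate of 50 Hz.

```python
import numpy as np
from scipy.signal import resample


def preprocess_ecg_sketch(raw_signal, target_length=500, clip_value=3.0):
    """Hypothetical version of the four steps; raw_signal has shape (n_samples, 12)."""
    # 1. Validation and type conversion: float64 internally, NaN/inf replaced by 0
    signal = np.nan_to_num(np.asarray(raw_signal, dtype=np.float64),
                           nan=0.0, posinf=0.0, neginf=0.0)
    if signal.ndim != 2:
        raise ValueError(f"expected a (time, leads) array, got shape {signal.shape}")
    # 2. Resample every lead to a fixed number of time points
    signal = resample(signal, target_length, axis=0)
    # 3. Per-lead z-score; the epsilon guards against flat (constant) leads
    mean = signal.mean(axis=0, keepdims=True)
    std = signal.std(axis=0, keepdims=True)
    signal = (signal - mean) / (std + 1e-8)
    # 4. Clip residual outliers to [-clip_value, clip_value]
    return np.clip(signal, -clip_value, clip_value).astype(np.float32)


# Example: random data shaped like a 10 s, 12-lead record at 500 Hz
rng = np.random.default_rng(0)
x = preprocess_ecg_sketch(rng.normal(size=(5000, 12)))
print(x.shape)  # (500, 12)
```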

#### 2.2 Data Augmentation Strategy
```python
def lightweight_augmentation(signal):
    # 30% probability Gaussian noise
    # 20% probability amplitude scaling
    ...
```

**Key Points**:
- **Lightweight design**: Mild perturbations only, so clinically meaningful structure such as RR timing is not distorted
- **Realism preservation**: Mimics natural variation in clinical recordings (sensor noise, gain differences)
- **Training only**: No augmentation is applied during validation or testing
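A hedged sketch of the two augmentations. The probabilities (30% noise, 20% scaling) come from the description above; the noise standard deviation (0.05, in normalized units), the scaling range (0.9 to 1.1), and the function name are assumptions. The two random draws are independent, so about 6% of signals receive both.

```python
import numpy as np


def lightweight_augmentation_sketch(signal, rng=None, noise_std=0.05,
                                    scale_range=(0.9, 1.1)):
    """Hypothetical 30% noise / 20% scaling augmentation.

    noise_std and scale_range are assumed values, not taken from too_feature.py.
    """
    rng = np.random.default_rng() if rng is None else rng
    augmented = signal.copy()  # never modify the caller's array
    # With probability 0.3: add zero-mean Gaussian noise to every sample
    if rng.random() < 0.3:
        augmented = augmented + rng.normal(0.0, noise_std, size=augmented.shape)
    # With probability 0.2: scale all leads by one shared amplitude factor
    if rng.random() < 0.2:
        augmented = augmented * rng.uniform(*scale_range)
    return augmented.astype(signal.dtype)


# Training batches only; validation/test signals are passed through unchanged
rng = np.random.default_rng(42)
x = np.zeros((500, 12), dtype=np.float32)
x_aug = lightweight_augmentation_sketch(x, rng=rng)
print(x_aug.shape, x_aug.dtype)
```

Using one scale factor for all leads keeps the amplitude ratios between leads intact, and neither transform moves R peaks in time.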

### 3. Medical Feature Engineering

#### 3.1 Core Medical Feature Extraction
```python
def extract_core_medical_features(signal, fs=100):
    # Statistical features: mean, std, skewness, kurtosis
    # Heart rate features: HR, HRV, RMSSD, CV_RR
    ...
```

**Key Feature Descriptions**:

| Feature Name | Medical Significance | Normal Range |
|-------------|---------------------|--------------|
| **Heart Rate (HR)** | Beats per minute | 60-100 bpm |
| **SDNN** | RR interval standard deviation, overall HRV | 20-50 ms |
| **RMSSD** | Root mean square of successive RR differences | 15-40 ms |
| **CV_RR** | Coefficient of variation of RR intervals | 0.03-0.07 |
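For reference, the standard time-domain definitions behind the table, for $N$ successive RR intervals $RR_1, \dots, RR_N$ in seconds. SDNN and RMSSD are multiplied by 1000 to give the millisecond values in the table. Tools differ on whether SDNN divides by $N-1$ or $N$; the sample form is shown here, and which one `too_feature.py` uses is not stated.

```latex
\begin{aligned}
\overline{RR} &= \frac{1}{N}\sum_{i=1}^{N} RR_i, \qquad
\mathrm{HR} = \frac{60}{\overline{RR}}\ \text{bpm} \\
\mathrm{SDNN} &= \sqrt{\frac{1}{N-1}\sum_{i=1}^{N}\left(RR_i - \overline{RR}\right)^2} \\
\mathrm{RMSSD} &= \sqrt{\frac{1}{N-1}\sum_{i=1}^{N-1}\left(RR_{i+1} - RR_i\right)^2} \\
\mathrm{CV}_{RR} &= \frac{\mathrm{SDNN}}{\overline{RR}}
\end{aligned}
```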

#### 3.2 QRS Detection Algorithm
```python
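import numpy as np
from scipy.signal import find_peaks

# lead_ii: 1-D Lead II signal; fs: its sampling rate in Hz
peaks, _ = find_peaks(lead_ii, height=np.std(lead_ii)*0.5, distance=fs//4)
rr_intervals = np.diff(peaks) / fs  # RR intervals in seconds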
```

**Key Points**:
- **Lead selection**: Uses Lead II for R-wave detection
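- **Adaptive threshold**: The minimum peak height is 0.5 × the standard deviation of Lead II, so it scales with signal amplitude
- **Distance constraint**: Peaks must be at least `fs//4` samples (250 ms) apart, which suppresses duplicate detections and caps the detectable heart rate at 240 bpm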
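Combining the QRS detection with the formulas for the table gives a hedged sketch of an 8-feature extractor. The function name is hypothetical, and several details are assumptions not confirmed by the source: the four statistics are computed over all leads flattened, Lead II is column index 1 (standard lead order I, II, III, aVR, ...), HRV is reported as SDNN in ms, and records with fewer than three detected beats get zeros for the RR features. Because every RR feature is peak spacing divided by `fs`, `fs` must be the sampling rate of the array actually passed in; the usage example shows the heart rate halving when `fs` is halved.

```python
import numpy as np
from scipy.signal import find_peaks
from scipy.stats import kurtosis, skew


def extract_core_features_sketch(signal, fs=100, lead_ii_index=1):
    """Hypothetical 8 features: mean, std, skewness, excess kurtosis,
    HR (bpm), SDNN (ms), RMSSD (ms), CV_RR. signal has shape (n_samples, 12)."""
    flat = signal.ravel()
    stats = [flat.mean(), flat.std(), skew(flat), kurtosis(flat)]

    lead_ii = signal[:, lead_ii_index]
    # R peaks: height above 0.5 * std, at least 250 ms (fs // 4 samples) apart
    peaks, _ = find_peaks(lead_ii, height=np.std(lead_ii) * 0.5, distance=fs // 4)
    rr = np.diff(peaks) / fs  # RR intervals in seconds
    if len(rr) < 2:
        return np.array(stats + [0.0] * 4, dtype=np.float32)  # too few beats

    mean_rr = rr.mean()
    sdnn = rr.std(ddof=1)                      # sample standard deviation
    rmssd = np.sqrt(np.mean(np.diff(rr) ** 2))  # mean over N-1 differences
    hr = 60.0 / mean_rr
    return np.array(stats + [hr, sdnn * 1000, rmssd * 1000, sdnn / mean_rr],
                    dtype=np.float32)


# Synthetic 10 s record at 100 Hz: one spike every 80 samples (RR = 0.8 s)
fs = 100
lead = np.zeros(10 * fs)
lead[::80] = 1.0
sig = np.tile(lead[:, None], (1, 12))
print(extract_core_features_sketch(sig, fs=fs)[4])  # HR, about 75 bpm
print(extract_core_features_sketch(sig, fs=50)[4])  # same array, wrong fs: about 37.5
```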

### 4. Model Architecture Design

#### 4.1 Lightweight LSTM Architecture
```python
def create_lightweight_lstm_model():
    # Waveform branch: LSTM(32) + Dense(24)
    # Feature branch: Dense(16)
    # Fusion layer: Concatenate + Dense(32) + Output(5)
    ...
```

**Architecture Features**:
- **Total parameters**: 8,461 parameters (33.05 KB as float32)
- **Memory efficient**: Small footprint suited to CPU training on Apple Silicon
- **Dual-branch design**: Fuses learned waveform features with the 8 hand-crafted medical features

#### 4.2 Model Component Details

```
Input Layers:
├── Waveform Input: (None, 500, 12) - 500 timepoints × 12 leads
└── Feature Input: (None, 8) - 8 medical features

Waveform Branch:
└── LSTM(32, dropout=0.2) → Dense(24) → BatchNorm → Dropout(0.3)

Feature Branch:
└── Dense(16) → BatchNorm → Dropout(0.2)

Fusion & Output:
└── Concatenate → Dense(32) → BatchNorm → Dropout(0.3) → Dense(5)
```
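The diagram translates almost line for line into the Keras functional API. In this sketch the ReLU activations on the hidden Dense layers and the softmax output are assumptions (the diagram does not state activations), and the builder name is hypothetical. With these layers the parameter count reproduces the reported total: LSTM 5,760 + Dense(24) 792 + Dense(16) 144 + Dense(32) 1,312 + Dense(5) 165 + BatchNorm 288 (4 per unit, 2 of them non-trainable) = 8,461, and 8,461 float32 values × 4 bytes ≈ 33.05 KB.

```python
import tensorflow as tf
from tensorflow.keras import layers


def build_lightweight_lstm_sketch(seq_len=500, n_leads=12, n_features=8, n_classes=5):
    """Dual-branch model following the component diagram.

    ReLU hidden activations and the softmax output are assumptions.
    """
    wave_in = layers.Input(shape=(seq_len, n_leads), name="waveform")
    feat_in = layers.Input(shape=(n_features,), name="medical_features")

    # Waveform branch: LSTM(32) -> Dense(24) -> BatchNorm -> Dropout(0.3)
    x = layers.LSTM(32, dropout=0.2)(wave_in)
    x = layers.Dense(24, activation="relu")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(0.3)(x)

    # Feature branch: Dense(16) -> BatchNorm -> Dropout(0.2)
    f = layers.Dense(16, activation="relu")(feat_in)
    f = layers.BatchNormalization()(f)
    f = layers.Dropout(0.2)(f)

    # Fusion: Concatenate -> Dense(32) -> BatchNorm -> Dropout(0.3) -> Dense(5)
    z = layers.Concatenate()([x, f])
    z = layers.Dense(32, activation="relu")(z)
    z = layers.BatchNormalization()(z)
    z = layers.Dropout(0.3)(z)
    out = layers.Dense(n_classes, activation="softmax")(z)
    return tf.keras.Model(inputs=[wave_in, feat_in], outputs=out)


model = build_lightweight_lstm_sketch()
print(model.count_params())  # 8461
```

The LSTM dominates the budget (about 68% of all parameters), so its unit count is the main lever if the model must shrink further.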

### 5. Training Strategy

#### 5.1 Optimizer Configuration
```python
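import tensorflow as tf

# Passed to model.compile(): Adam from the legacy (Keras 2) optimizer namespace
optimizer = tf.keras.optimizers.legacy.Adam(learning_rate=0.0005)
# Passed to model.fit(): per-class loss weights
class_weight = class_weight_dict  # Balanced class weights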
```
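How `class_weight_dict` is built is not shown. A common choice, sketched here under the assumption that labels are integers 0 to 4, is scikit-learn's `"balanced"` heuristic, where each class weight is `n_samples / (n_classes * count)`. The helper name is hypothetical.

```python
import numpy as np
from sklearn.utils.class_weight import compute_class_weight


def balanced_class_weights(labels):
    """Hypothetical helper: {class_index: n_samples / (n_classes * count)}."""
    classes = np.unique(labels)
    weights = compute_class_weight(class_weight="balanced", classes=classes, y=labels)
    return {int(c): float(w) for c, w in zip(classes, weights)}


# After balanced sampling every class has 300 records, so every weight is 1.0
print(balanced_class_weights(np.repeat(np.arange(5), 300)))
# Imbalanced toy labels: the rarer class gets the larger weight
print(balanced_class_weights(np.array([0] * 8 + [1] * 2)))  # {0: 0.625, 1: 2.5}
```

Note the interaction with balanced sampling: once the training classes are equal in size, balanced weights are all 1.0 and have no effect, so weights only matter if they are computed on data that is still imbalanced.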

#### 5.2 Callbacks
```python
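from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau

callbacks = [
    # Stop after 5 epochs without a gain in validation accuracy
    EarlyStopping(patience=5, monitor='val_accuracy'),
    # Halve the learning rate after 3 stagnant epochs (monitors val_loss by default)
    ReduceLROnPlateau(factor=0.5, patience=3, min_lr=1e-6)
]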
```

**Key Points**:
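- **Early stopping**: Ends training after 5 epochs without improvement in validation accuracy, limiting overfitting
- **Learning rate scheduling**: Halves the learning rate after 3 epochs without improvement in `val_loss` (the `ReduceLROnPlateau` default), down to a floor of 1e-6
- **Class weights**: Scale each class's contribution to the loss via `class_weight` in `model.fit`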

### 6. Hardware Optimization (Apple Silicon Specific)

#### 6.1 GPU Disable Configuration
```python
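import os

# Environment variables must be set before TensorFlow is imported
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # hides CUDA GPUs (does not affect Metal)
os.environ['TF_METAL_DEVICE_ENABLE'] = '0'

import tensorflow as tf

# Hide every GPU, including the Metal device; must run before devices are initialized
tf.config.set_visible_devices([], 'GPU')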
```

**Critical Optimizations**:
- **Force CPU mode**: Avoids Metal GPU compatibility issues on the M4 chip
- **Memory management**: Frequent garbage collection keeps memory use from growing during training
- **Batch size**: Set to 4 as a trade-off between throughput and stability

### 7. Performance Evaluation

#### 7.1 Final Results
- **Accuracy**: 32.5% (chance level for five balanced classes is 20%)
- **Confidence**: 32.7% ± 8.7% (mean ± SD)
- **Best class**: Tachycardia (F1 = 0.42)
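On the 400-sample test set, 32.5% accuracy means 130 correct predictions, compared with an average of 80 for uniform random guessing. A sketch of how such a summary can be computed from softmax outputs follows; it assumes "confidence" means the maximum softmax probability per sample, and the helper name is hypothetical.

```python
import numpy as np
from sklearn.metrics import accuracy_score, f1_score


def summarize_predictions(probs, y_true):
    """Hypothetical summary: accuracy, confidence mean/SD, per-class F1."""
    y_pred = probs.argmax(axis=1)
    confidence = probs.max(axis=1)  # probability of the predicted class
    return {
        "accuracy": accuracy_score(y_true, y_pred),
        "confidence_mean": float(confidence.mean()),
        "confidence_std": float(confidence.std()),
        "f1_per_class": f1_score(y_true, y_pred, average=None,
                                 labels=np.arange(probs.shape[1])),
    }


# Toy example: 4 samples, 3 classes
probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.8, 0.1],
                  [0.3, 0.3, 0.4],
                  [0.5, 0.4, 0.1]])
print(summarize_predictions(probs, np.array([0, 1, 2, 1])))
```

A mean confidence close to the accuracy, as in the reported 32.7% vs. 32.5%, suggests the model's probabilities are roughly calibrated but uniformly uncertain.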

#### 7.2 Medical Feature Statistics
```
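Heart Rate: 127 ± 16 bpm (mean ± SD)
HRV: 45 ± 55 ms (mean ± SD)
```

A mean heart rate of 127 bpm lies above the 60 to 100 bpm normal range even though the sampled data include normal and bradycardia records, so it is worth checking that the `fs` used for feature extraction matches the sampling rate of the signal it receives.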

## Usage

### Quick Start
```bash
# 1. Prepare data
# Place ecg_5_class_data.csv and the MIMIC-IV-ECG waveforms (mimic-iv-ecg/) as shown in Directory Structure

# 2. Run training
python too_feature.py

# 3. View results
# Model saved as ecg_stable_lstm_model.keras
# Training log shows detailed classification report
```

### Custom Configuration
```python
# Adjust training parameters
BATCH_SIZE = 4          # Batch size
EPOCHS = 15             # Training epochs
SEQUENCE_LENGTH = 500   # Sequence length
TRAIN_SAMPLES = 1500    # Training samples
```

## Core Advantages

### 1. Medical Domain Expertise
- ✅ Extracts clinically relevant heart rate variability features
- ✅ Uses standard RR interval analysis methods
- ✅ Detects R peaks on Lead II, the standard rhythm lead in ECG analysis

### 2. Technical Innovation
- ✅ LSTM temporal modeling captures rhythm changes
- ✅ Dual-branch architecture fuses waveform and feature information
- ✅ Lightweight design (8,461 parameters, about 33 KB) keeps deployment practical

### 3. Engineering Practicality
- ✅ Addresses real-world class imbalance issues
- ✅ Tuned to run stably within Apple Silicon hardware constraints (CPU-only, batch size 4)
- ✅ Error handling and robustness measures such as input validation and outlier clipping

## Issues and Solutions

### Common Issues

**Q1: SIGBUS Error**
```text
A: Apple Silicon (Metal GPU) compatibility issue
Solutions:
- Set the environment variables that disable the Metal GPU
- Reduce batch size and model complexity
- Use the legacy Adam optimizer (tf.keras.optimizers.legacy.Adam)
```

**Q2: Low Accuracy**
```text
A: Five-class rhythm classification is inherently hard; the current model reaches 32.5% (chance level: 20%)
Improvement directions:
- Increase the training sample size (currently 1,500)
- Extract more medical features
- Use ensemble learning methods
```

**Q3: Memory Issues**
```text
A: Reduce memory consumption
Solutions:
- Shorten the sequence length (500 → 250)
- Reduce the batch size (4 → 2)
- Use fewer training samples
```

## Future Improvements

### 1. Model Architecture
- [ ] Attention mechanism for enhanced temporal modeling
- [ ] Multi-scale CNN for local feature extraction
- [ ] Graph neural networks for lead relationship modeling

### 2. Feature Engineering
- [ ] P-wave and T-wave morphological features
- [ ] Frequency domain power spectral analysis
- [ ] ST-segment deviation detection

### 3. Data Strategy
- [ ] More complex data augmentation
- [ ] Active learning for difficult sample selection
- [ ] Transfer learning with pre-trained models