# In [None]:
# ECG Disease Detection System using PTB-XL Dataset
# Complete implementation for medical-grade ECG analysis

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
from sklearn.ensemble import RandomForestClassifier
import tensorflow as tf
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Conv1D, MaxPooling1D, LSTM, Dense, Dropout, BatchNormalization, Input, GlobalAveragePooling1D
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping, ModelCheckpoint
import wfdb
import os
import ast
import warnings
# Suppress every Python warning for the whole process. NOTE: this also hides
# deprecation notices raised by the libraries imported above.
warnings.filterwarnings('ignore')

# Set random seeds for reproducibility
# (seeds NumPy's global RNG and TensorFlow's global random seed).
np.random.seed(42)
tf.random.set_seed(42)

class ECGDiseaseDetector:
    """CNN-LSTM classifier for 12-lead ECG records of the PTB-XL dataset.

    Typical use: ``load_and_prepare_data()`` -> ``train_model(...)`` ->
    ``predict_disease(signal)`` -> ``generate_detailed_report(result)``.

    The output is a decision aid only; it is not a clinical diagnosis.
    """

    def __init__(self, data_path='./ptb-xl/'):
        """Create an untrained detector.

        Args:
            data_path: Root directory of the extracted PTB-XL dataset (the
                folder that contains ``ptbxl_database.csv``).
        """
        self.data_path = data_path
        self.sampling_rate = 100  # Using 100Hz version for faster processing
        self.signal_length = 1000  # 10 seconds * 100Hz
        self.n_leads = 12
        self.model = None
        self.scaler = StandardScaler()
        self.label_encoder = LabelEncoder()

        # Disease mapping with symptoms
        self.disease_symptoms = {
            'NORM': {
                'disease': 'Normal ECG',
                'symptoms': 'No cardiac abnormalities detected. Regular heart rhythm and electrical activity.'
            },
            'MI': {
                'disease': 'Myocardial Infarction (Heart Attack)',
                'symptoms': 'Chest pain, shortness of breath, nausea, sweating, arm or jaw pain, fatigue.'
            },
            'STTC': {
                'disease': 'ST/T Change',
                'symptoms': 'Chest discomfort, palpitations, shortness of breath, fatigue, dizziness.'
            },
            'CD': {
                'disease': 'Conduction Disturbance',
                'symptoms': 'Palpitations, dizziness, fainting, chest pain, fatigue, shortness of breath.'
            },
            'HYP': {
                'disease': 'Hypertrophy',
                'symptoms': 'Chest pain, shortness of breath, fatigue, palpitations, swelling in legs/ankles.'
            },
            'AFIB': {
                'disease': 'Atrial Fibrillation',
                'symptoms': 'Irregular heartbeat, palpitations, shortness of breath, weakness, fatigue, dizziness.'
            },
            'AFLT': {
                'disease': 'Atrial Flutter',
                'symptoms': 'Rapid heartbeat, palpitations, shortness of breath, chest discomfort, fatigue.'
            },
            'BRADY': {
                'disease': 'Bradycardia',
                'symptoms': 'Slow heart rate, fatigue, dizziness, fainting, confusion, shortness of breath.'
            },
            'TACHY': {
                'disease': 'Tachycardia',
                'symptoms': 'Rapid heart rate, palpitations, shortness of breath, chest pain, dizziness, fainting.'
            },
            'PAC': {
                'disease': 'Premature Atrial Contraction',
                'symptoms': 'Palpitations, feeling of skipped beats, chest discomfort, anxiety.'
            },
            'PVC': {
                'disease': 'Premature Ventricular Contraction',
                'symptoms': 'Palpitations, feeling of skipped beats, chest discomfort, fatigue.'
            },
            'LBBB': {
                'disease': 'Left Bundle Branch Block',
                'symptoms': 'May be asymptomatic, or chest pain, shortness of breath, fatigue, dizziness.'
            },
            'RBBB': {
                'disease': 'Right Bundle Branch Block',
                'symptoms': 'Often asymptomatic, occasionally palpitations or chest discomfort.'
            },
            'LVH': {
                'disease': 'Left Ventricular Hypertrophy',
                'symptoms': 'Chest pain, shortness of breath, fatigue, palpitations, dizziness.'
            },
            'LAD': {
                'disease': 'Left Axis Deviation',
                'symptoms': 'Usually asymptomatic, may indicate underlying heart condition.'
            },
            'RAD': {
                'disease': 'Right Axis Deviation',
                'symptoms': 'Usually asymptomatic, may indicate underlying heart condition.'
            },
            'QWAVE': {
                'disease': 'Q Wave Abnormality',
                'symptoms': 'May indicate previous heart attack, chest pain, shortness of breath.'
            }
        }

    # ------------------------------------------------------------------
    # Data loading
    # ------------------------------------------------------------------
    def load_and_prepare_data(self):
        """Load PTB-XL metadata and signals and split them 80/20.

        Each record is labelled with its highest-likelihood SCP code. A code
        that is not a key of ``self.disease_symptoms`` is mapped to its
        diagnostic superclass via ``scp_statements.csv``; records that still
        do not match a known key are dropped.

        Returns:
            ``(X_train, X_test, y_train, y_test)``; signals have shape
            ``(n, signal_length, n_leads)`` and labels are string codes.
            ``(None, None, None, None)`` when the data cannot be loaded.
        """
        print("Loading PTB-XL dataset...")

        # Load metadata
        try:
            df = pd.read_csv(os.path.join(self.data_path, 'ptbxl_database.csv'), index_col='ecg_id')
        except FileNotFoundError:
            print("Error: ptbxl_database.csv not found. Please ensure the PTB-XL dataset is properly extracted.")
            return None, None, None, None
        print(f"Loaded metadata for {len(df)} ECG records")

        superclass_of = self._load_superclass_map()

        # The CSV stores each label dict as its Python repr, e.g. "{'NORM': 100.0}".
        df['scp_codes'] = df['scp_codes'].apply(
            lambda raw: ast.literal_eval(raw) if pd.notna(raw) else {}
        )

        known_codes = set(self.disease_symptoms)
        df['main_diagnosis'] = df['scp_codes'].apply(
            lambda codes: self._main_diagnosis(codes, known_codes, superclass_of)
        )

        # Keep only records whose label this detector knows how to describe.
        df = df[df['main_diagnosis'].isin(sorted(known_codes))]

        print("Disease distribution:")
        print(df['main_diagnosis'].value_counts())

        # Load ECG signals
        print("Loading ECG signals...")
        X, y = self.load_ecg_signals(df)

        if X is None:
            return None, None, None, None

        # A stratified split needs at least two records per class; a class
        # with a single record would make train_test_split raise.
        labels, counts = np.unique(y, return_counts=True)
        rare_labels = labels[counts < 2]
        if rare_labels.size:
            print(f"Dropping classes with fewer than 2 records: {list(rare_labels)}")
            keep = ~np.isin(y, rare_labels)
            X, y = X[keep], y[keep]
            if len(y) == 0:
                print("Error: No classes with enough records to split.")
                return None, None, None, None

        # Split data
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=42, stratify=y
        )

        print(f"Training set: {X_train.shape}, Test set: {X_test.shape}")
        return X_train, X_test, y_train, y_test

    def _load_superclass_map(self):
        """Return ``{scp_code: diagnostic_class}`` from ``scp_statements.csv``.

        Only rows flagged ``diagnostic == 1`` are used. An empty dict is
        returned when the file or the expected columns are missing, in which
        case labels are matched by their raw SCP code only.
        """
        statements_path = os.path.join(self.data_path, 'scp_statements.csv')
        try:
            scp_statements = pd.read_csv(statements_path, index_col=0)
        except FileNotFoundError:
            print("Warning: scp_statements.csv not found; labels will not be mapped to superclasses.")
            return {}

        if not {'diagnostic', 'diagnostic_class'}.issubset(scp_statements.columns):
            return {}

        diagnostic_rows = scp_statements[scp_statements['diagnostic'] == 1]
        return diagnostic_rows['diagnostic_class'].dropna().to_dict()

    @staticmethod
    def _main_diagnosis(scp_codes, known_codes, superclass_of=None):
        """Pick the training label for one record.

        Args:
            scp_codes: ``{scp_code: likelihood}`` for the record.
            known_codes: Codes the detector supports.
            superclass_of: Optional ``{scp_code: superclass}`` mapping.

        Returns:
            ``'NORM'`` for an empty dict; otherwise the highest-likelihood
            code itself when it is known, else its superclass when one is
            mapped, else the raw code (which the caller filters out).
        """
        if not scp_codes:
            return 'NORM'
        top_code = max(scp_codes.items(), key=lambda item: item[1])[0]
        if top_code in known_codes:
            return top_code
        if superclass_of:
            return superclass_of.get(top_code, top_code)
        return top_code

    def _candidate_record_paths(self, ecg_id, row):
        """List record paths (without extension) to try, in priority order.

        The ``filename_lr`` / ``filename_hr`` metadata columns are preferred
        when present; the flat ``records100/<id>`` and ``records500/<id>``
        paths are kept as a fallback. 100 Hz candidates come first.
        """
        record_name = str(ecg_id).zfill(5)
        candidates = []
        for column, legacy_dir in (('filename_lr', 'records100'), ('filename_hr', 'records500')):
            relative = row.get(column)
            if isinstance(relative, str) and relative:
                candidates.append(os.path.join(self.data_path, relative))
            candidates.append(os.path.join(self.data_path, legacy_dir, record_name))
        return candidates

    def _standardize_signal(self, signal, fs):
        """Bring one record to shape ``(signal_length, n_leads)``.

        A record whose sampling rate is an integer multiple of
        ``self.sampling_rate`` is decimated by plain striding.
        NOTE(review): no anti-aliasing filter is applied before decimation.
        Missing samples or leads are zero-padded; extra ones are truncated.
        """
        signal = np.asarray(signal)
        if fs and fs > self.sampling_rate and fs % self.sampling_rate == 0:
            signal = signal[::int(fs // self.sampling_rate)]

        fixed = np.zeros((self.signal_length, self.n_leads), dtype=signal.dtype)
        n_rows = min(signal.shape[0], self.signal_length)
        n_cols = min(signal.shape[1], self.n_leads)
        fixed[:n_rows, :n_cols] = signal[:n_rows, :n_cols]
        return fixed

    def load_ecg_signals(self, df):
        """Load the waveform of every record in ``df``.

        Args:
            df: Metadata frame indexed by ECG id, with a ``main_diagnosis``
                column (and optionally ``filename_lr`` / ``filename_hr``).

        Returns:
            ``(X, y)`` as NumPy arrays, or ``(None, None)`` when nothing
            could be loaded. Records that are missing or unreadable are
            skipped and counted.
        """
        X = []
        y = []
        failed_loads = 0
        first_error = None

        for ecg_id, row in df.iterrows():
            record_path = next(
                (path for path in self._candidate_record_paths(ecg_id, row)
                 if os.path.exists(path + '.dat')),
                None,
            )
            if record_path is None:
                failed_loads += 1
                continue

            try:
                record = wfdb.rdrecord(record_path)
                signal = self._standardize_signal(record.p_signal, record.fs)
            except Exception as exc:
                # Best effort: one unreadable record must not abort the
                # load, but keep the first error so failures are diagnosable.
                failed_loads += 1
                if first_error is None:
                    first_error = f"{record_path}: {exc}"
                continue

            X.append(signal)
            y.append(row['main_diagnosis'])

        if first_error is not None:
            print(f"First load error: {first_error}")

        if not X:
            print("Error: No ECG signals could be loaded. Please check dataset structure.")
            return None, None

        print(f"Successfully loaded {len(X)} ECG signals ({failed_loads} failed)")

        return np.array(X), np.array(y)

    # ------------------------------------------------------------------
    # Preprocessing and model
    # ------------------------------------------------------------------
    def preprocess_signals(self, X, fit=True):
        """Standardize signals per lead.

        The array is flattened to ``(samples * time, leads)`` so the scaler
        learns one mean and variance per lead. The fitted statistics
        therefore do not depend on how many records are passed, and the same
        scaler can be applied to validation data and to single records.

        Args:
            X: Array of shape ``(samples, time, leads)``.
            fit: Fit the scaler on ``X`` (training data) when True; reuse
                the already fitted scaler when False.

        Returns:
            Scaled array with the same shape as ``X``.
        """
        X = np.asarray(X, dtype=float)
        flat = X.reshape(-1, X.shape[-1])
        scaled = self.scaler.fit_transform(flat) if fit else self.scaler.transform(flat)
        return scaled.reshape(X.shape)

    def build_model(self, input_shape, num_classes):
        """Build and compile the CNN-LSTM classifier.

        Args:
            input_shape: ``(time_steps, leads)`` of one record.
            num_classes: Number of output classes.

        Returns:
            A compiled Keras ``Sequential`` model that expects integer class
            labels (sparse categorical cross-entropy).
        """
        model = Sequential([
            # Input layer
            Input(shape=input_shape),

            # CNN layers for feature extraction
            Conv1D(filters=64, kernel_size=7, activation='relu', padding='same'),
            BatchNormalization(),
            MaxPooling1D(pool_size=2),
            Dropout(0.2),

            Conv1D(filters=128, kernel_size=5, activation='relu', padding='same'),
            BatchNormalization(),
            MaxPooling1D(pool_size=2),
            Dropout(0.2),

            Conv1D(filters=256, kernel_size=3, activation='relu', padding='same'),
            BatchNormalization(),
            MaxPooling1D(pool_size=2),
            Dropout(0.3),

            # LSTM layers for temporal patterns
            LSTM(128, return_sequences=True, dropout=0.3, recurrent_dropout=0.3),
            LSTM(64, dropout=0.3, recurrent_dropout=0.3),

            # Dense layers
            Dense(128, activation='relu'),
            BatchNormalization(),
            Dropout(0.5),

            Dense(64, activation='relu'),
            Dropout(0.3),

            # Output layer
            Dense(num_classes, activation='softmax')
        ])

        model.compile(
            optimizer=Adam(learning_rate=0.001),
            loss='sparse_categorical_crossentropy',
            metrics=['accuracy']
        )

        return model

    def train_model(self, X_train, y_train, X_val, y_val):
        """Fit the scaler, label encoder and network.

        Args:
            X_train, X_val: Signals of shape ``(n, time, leads)``.
            y_train, y_val: String class codes. Every validation label must
                also occur in ``y_train``.

        Returns:
            The Keras ``History`` object returned by ``model.fit``.
        """
        print("Training ECG classification model...")

        # The scaler is fitted on the training data only and reused as-is
        # for the validation data.
        X_train_scaled = self.preprocess_signals(X_train, fit=True)
        X_val_scaled = self.preprocess_signals(X_val, fit=False)

        # Encode labels
        y_train_encoded = self.label_encoder.fit_transform(y_train)
        y_val_encoded = self.label_encoder.transform(y_val)

        # Build model
        num_classes = len(self.label_encoder.classes_)
        self.model = self.build_model(X_train_scaled.shape[1:], num_classes)

        print("Model architecture:")
        self.model.summary()

        # Callbacks
        callbacks = [
            ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=5, min_lr=1e-7, verbose=1),
            EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True, verbose=1),
            ModelCheckpoint('best_ecg_model.h5', monitor='val_accuracy', save_best_only=True, verbose=1)
        ]

        # Train model
        history = self.model.fit(
            X_train_scaled, y_train_encoded,
            validation_data=(X_val_scaled, y_val_encoded),
            epochs=100,
            batch_size=32,
            callbacks=callbacks,
            verbose=1
        )

        return history

    # ------------------------------------------------------------------
    # Inference and reporting
    # ------------------------------------------------------------------
    def predict_disease(self, ecg_signal):
        """Classify one ECG record.

        Args:
            ecg_signal: Array of shape ``(time, leads)`` or
                ``(1, time, leads)``, unscaled.

        Returns:
            A dict with keys ``disease_code``, ``disease_name``,
            ``symptoms``, ``confidence`` (percent) and ``all_probabilities``
            (``{disease_name: percent}``). When no model has been trained,
            the string ``"Error: Model not trained yet"`` is returned
            instead.
        """
        if self.model is None:
            return "Error: Model not trained yet"

        ecg_signal = np.asarray(ecg_signal)
        if ecg_signal.ndim == 2:
            ecg_signal = ecg_signal[np.newaxis, ...]

        # Reuse the scaler fitted during training.
        signal_scaled = self.preprocess_signals(ecg_signal, fit=False)

        # Make prediction (quiet: this method is typically called in a loop)
        class_probs = self.model.predict(signal_scaled, verbose=0)[0]
        best_idx = int(np.argmax(class_probs))
        class_codes = self.label_encoder.classes_
        predicted_disease_code = str(class_codes[best_idx])

        # Get disease information
        disease_info = self.disease_symptoms.get(predicted_disease_code, {
            'disease': 'Unknown',
            'symptoms': 'Unknown symptoms'
        })

        all_probabilities = {}
        for code, prob in zip(class_codes, class_probs):
            disease_name = self.disease_symptoms.get(code, {}).get('disease', 'Unknown')
            all_probabilities[disease_name] = float(prob) * 100

        return {
            'disease_code': predicted_disease_code,
            'disease_name': disease_info['disease'],
            'symptoms': disease_info['symptoms'],
            'confidence': float(class_probs[best_idx]) * 100,
            'all_probabilities': all_probabilities
        }

    def generate_detailed_report(self, prediction_result):
        """Render a prediction dict as a plain-text report.

        Args:
            prediction_result: Dict as returned by ``predict_disease``.

        Returns:
            Multi-line report string listing the primary diagnosis, the five
            most probable classes, a recommendation and a timestamp.
        """
        report = f"""
═══════════════════════════════════════
         ECG ANALYSIS REPORT
═══════════════════════════════════════

PRIMARY DIAGNOSIS:
• Disease: {prediction_result['disease_name']}
• Code: {prediction_result['disease_code']}
• Confidence: {prediction_result['confidence']:.2f}%

CLINICAL SYMPTOMS:
{prediction_result['symptoms']}

PROBABILITY DISTRIBUTION:
"""

        # Sort probabilities by value
        sorted_probs = sorted(prediction_result['all_probabilities'].items(),
                              key=lambda item: item[1], reverse=True)

        # Top 5 most likely
        report += "".join(f"• {disease}: {prob:.2f}%\n" for disease, prob in sorted_probs[:5])

        if prediction_result['disease_code'] == 'NORM':
            recommendation = '• This ECG appears normal. Continue regular check-ups.'
        else:
            recommendation = '• Consult with a cardiologist for further evaluation and treatment planning.'

        report += f"""
RECOMMENDATION:
{recommendation}
• This AI analysis should be used as a diagnostic aid only.
• Final diagnosis should always be confirmed by a qualified physician.

Analysis Date: {pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S')}
═══════════════════════════════════════
"""
        return report

# Usage example and testing
def main():
    """Train, spot-check, evaluate and save an ECG classifier on PTB-XL.

    The validation data used for early stopping and checkpointing is carved
    out of the training split, so the held-out test split is never seen
    during training and the reported test accuracy is not inflated by
    model selection.

    Returns:
        The trained ``ECGDiseaseDetector``, or ``None`` when the dataset
        could not be loaded.
    """
    # Initialize the ECG detector
    detector = ECGDiseaseDetector('./ptb-xl/')

    # Load and prepare data
    X_train, X_test, y_train, y_test = detector.load_and_prepare_data()

    if X_train is None:
        print("Failed to load data. Please check the dataset path and structure.")
        return

    # Hold out 10% of the training data for validation. Stratify only when
    # every class has at least two records, otherwise the split would raise.
    _, class_counts = np.unique(y_train, return_counts=True)
    stratify_labels = y_train if class_counts.min() >= 2 else None
    X_fit, X_val, y_fit, y_val = train_test_split(
        X_train, y_train, test_size=0.1, random_state=42, stratify=stratify_labels
    )

    # The label encoder is fitted on the fitting split only, so drop any
    # validation record whose class is absent from that split.
    seen_in_fit = np.isin(y_val, np.unique(y_fit))
    X_val, y_val = X_val[seen_in_fit], y_val[seen_in_fit]

    # Train the model
    detector.train_model(X_fit, y_fit, X_val, y_val)

    # Test with a sample
    print("\n" + "="*50)
    print("TESTING WITH SAMPLE ECG")
    print("="*50)

    # Use first test sample
    actual_disease = y_test[0]
    result = detector.predict_disease(X_test[0])

    # Generate detailed report
    report = detector.generate_detailed_report(result)

    print(f"Actual Disease: {actual_disease}")
    print(report)

    # Calculate overall accuracy
    print("\n" + "="*50)
    print("MODEL EVALUATION")
    print("="*50)

    # Evaluate on the first 100 test samples only, for speed
    n_eval = min(100, len(X_test))
    correct_predictions = sum(
        detector.predict_disease(signal)['disease_code'] == label
        for signal, label in zip(X_test[:n_eval], y_test[:n_eval])
    )

    accuracy = (correct_predictions / n_eval) * 100
    print(f"Test Accuracy on {n_eval} samples: {accuracy:.2f}%")

    # Save the model
    detector.model.save('ecg_disease_detector.h5')
    print("\nModel saved as 'ecg_disease_detector.h5'")

    return detector

# Run the full pipeline only when this file is executed as a script; the
# value returned by main() is bound to the module-level name `detector`.
if __name__ == "__main__":
    detector = main()
