In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
import cv2
import tensorflow as tf
from sklearn.preprocessing import StandardScaler
import pickle
import os

class ECGTester:
    """Run a trained Keras ECG classifier and turn its output into a report.

    The tester loads a Keras model plus (optionally) a pickled scaler and
    label encoder, converts ECG images or WFDB records into a
    ``(time_points, leads)`` array, and produces a structured analysis dict
    and a formatted plain-text report.
    """

    def __init__(self, model_path='ecg_disease_detector.h5'):
        """Initialize ECG tester with trained model.

        Args:
            model_path: Path of the saved Keras model, passed to
                ``tf.keras.models.load_model``.
        """
        self.model = tf.keras.models.load_model(model_path)
        self.scaler = None
        self.label_encoder = None
        self.load_preprocessing_tools()

        # Disease information database, keyed by the class code returned by
        # the label encoder.
        self.disease_info = {
            'NORM': {
                'name': 'Normal ECG',
                'severity': 'None',
                'urgency': 'Routine',
                'symptoms': 'No cardiac abnormalities detected. Regular heart rhythm and electrical activity.',
                'recommendations': 'Continue regular check-ups. Maintain healthy lifestyle.',
                'risk_level': 'Low'
            },
            'MI': {
                'name': 'Myocardial Infarction (Heart Attack)',
                'severity': 'Critical',
                'urgency': 'Emergency',
                'symptoms': 'Severe chest pain, shortness of breath, nausea, sweating, arm/jaw pain, fatigue.',
                'recommendations': 'IMMEDIATE EMERGENCY CARE REQUIRED. Call 911. Administer aspirin if not contraindicated.',
                'risk_level': 'Critical'
            },
            'STTC': {
                'name': 'ST/T Changes',
                'severity': 'Moderate',
                'urgency': 'Urgent',
                'symptoms': 'Chest discomfort, palpitations, shortness of breath, fatigue, possible dizziness.',
                'recommendations': 'Cardiology consultation within 24-48 hours. Monitor symptoms closely.',
                'risk_level': 'Moderate to High'
            },
            'CD': {
                'name': 'Conduction Disturbance',
                'severity': 'Moderate',
                'urgency': 'Urgent',
                'symptoms': 'Palpitations, dizziness, syncope, chest pain, fatigue, shortness of breath.',
                'recommendations': 'Cardiology evaluation. Consider pacemaker assessment if symptomatic.',
                'risk_level': 'Moderate'
            },
            'HYP': {
                'name': 'Cardiac Hypertrophy',
                'severity': 'Moderate',
                'urgency': 'Non-urgent',
                'symptoms': 'Chest pain, shortness of breath, fatigue, palpitations, peripheral edema.',
                'recommendations': 'Echocardiogram, blood pressure management, lifestyle modifications.',
                'risk_level': 'Moderate'
            },
            'AFIB': {
                'name': 'Atrial Fibrillation',
                'severity': 'Moderate to High',
                'urgency': 'Urgent',
                'symptoms': 'Irregular rapid heartbeat, palpitations, shortness of breath, weakness, fatigue.',
                'recommendations': 'Anticoagulation assessment, rate/rhythm control, stroke risk evaluation.',
                'risk_level': 'High'
            },
            'AFLT': {
                'name': 'Atrial Flutter',
                'severity': 'Moderate',
                'urgency': 'Urgent',
                'symptoms': 'Rapid regular heartbeat, palpitations, shortness of breath, chest discomfort.',
                'recommendations': 'Cardioversion consideration, anticoagulation, electrophysiology consult.',
                'risk_level': 'Moderate to High'
            },
            'BRADY': {
                'name': 'Bradycardia',
                'severity': 'Mild to Moderate',
                'urgency': 'Non-urgent to Urgent',
                'symptoms': 'Slow heart rate, fatigue, dizziness, syncope, confusion, exercise intolerance.',
                'recommendations': 'Evaluate for underlying causes, consider pacemaker if symptomatic.',
                'risk_level': 'Low to Moderate'
            },
            'TACHY': {
                'name': 'Tachycardia',
                'severity': 'Mild to High',
                'urgency': 'Variable',
                'symptoms': 'Rapid heart rate, palpitations, shortness of breath, chest pain, dizziness.',
                'recommendations': 'Identify underlying cause, manage triggers, consider antiarrhythmic therapy.',
                'risk_level': 'Variable'
            }
        }

    def load_preprocessing_tools(self, scaler_path='scaler.pkl', encoder_path='label_encoder.pkl'):
        """Load the pickled scaler and label encoder saved at training time.

        Each file is loaded independently, so a missing label encoder no longer
        discards a scaler that loaded fine. A tool whose file is missing is set
        to ``None``. ``analyze_ecg`` then skips scaling and falls back to generic
        ``CLASS_<n>`` codes instead of calling an unfitted estimator.

        Args:
            scaler_path: Pickle file holding the fitted scaler.
            encoder_path: Pickle file holding the fitted label encoder.

        Security: ``pickle.load`` can execute arbitrary code embedded in the
        file. Only load artifacts you produced yourself.
        """
        missing = []
        try:
            with open(scaler_path, 'rb') as f:
                self.scaler = pickle.load(f)
        except FileNotFoundError:
            self.scaler = None
            missing.append(scaler_path)
        try:
            with open(encoder_path, 'rb') as f:
                self.label_encoder = pickle.load(f)
        except FileNotFoundError:
            self.label_encoder = None
            missing.append(encoder_path)
        if missing:
            print("Warning: Preprocessing tools not found. Please train the model first.")

    def process_ecg_image(self, image_path):
        """Convert an ECG image to a pseudo-signal array (simplified approach).

        This does not trace the waveform. The grayscale image is resized to
        12 rows x 1000 columns, and each row is used directly as one "lead".

        Returns:
            A ``(1000, 12)`` float32 array standardized to zero mean and unit
            std, or ``None`` if the image cannot be read or is uniform.
        """
        try:
            img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
            if img is None:
                raise ValueError("Could not load image")

            # cv2.resize takes (width, height): 1000 time points x 12 leads.
            img = cv2.resize(img, (1000, 12)).astype(np.float32)

            std = img.std()
            if std == 0:
                # A uniform image would divide by zero and feed NaNs to the model.
                raise ValueError("Image has no intensity variation")
            img = (img - img.mean()) / std

            # Transpose (leads, time_points) -> (time_points, leads).
            return img.T

        except Exception as e:
            print(f"Error processing image: {e}")
            return None

    def load_ecg_from_wfdb(self, file_path):
        """Load an ECG from a WFDB record as a ``(1000, 12)`` array.

        Args:
            file_path: WFDB record name passed to ``wfdb.rdrecord``.

        Returns:
            The physical signal truncated or zero-padded to 1000 samples and
            12 leads, or ``None`` if loading fails.
        """
        try:
            import wfdb
            record = wfdb.rdrecord(file_path)
            signal = record.p_signal

            # Ensure 1000 samples: truncate long records, zero-pad short ones.
            if len(signal) > 1000:
                signal = signal[:1000]
            elif len(signal) < 1000:
                padding = np.zeros((1000 - len(signal), signal.shape[1]))
                signal = np.vstack([signal, padding])

            # Ensure 12 leads: zero-pad missing leads, drop extra ones.
            if signal.shape[1] < 12:
                padding = np.zeros((signal.shape[0], 12 - signal.shape[1]))
                signal = np.hstack([signal, padding])
            elif signal.shape[1] > 12:
                signal = signal[:, :12]

            return signal
        except Exception as e:
            print(f"Error loading WFDB file: {e}")
            return None

    def analyze_ecg(self, ecg_signal, patient_info=None):
        """Run the model on one ECG and assemble a structured analysis.

        Args:
            ecg_signal: ``(time_points, leads)`` array, or an already batched
                ``(1, time_points, leads)`` array. ``None`` short-circuits.
            patient_info: Optional dict copied into the result.

        Returns:
            A dict with primary/differential diagnoses, clinical findings and
            technical details, or ``None`` when ``ecg_signal`` is ``None``.
        """
        if ecg_signal is None:
            return None

        ecg_signal = np.asarray(ecg_signal)
        # Add a batch dimension: (time_points, leads) -> (1, time_points, leads).
        if ecg_signal.ndim == 2:
            ecg_signal = ecg_signal[np.newaxis, ...]

        if self.scaler is not None:
            flat = ecg_signal.reshape(-1, ecg_signal.shape[-1])
            # NOTE(review): the scaler is applied to the transposed
            # (leads, time_points) matrix. For a single record this means it
            # must have been fitted with one feature per time point; confirm
            # against the training code.
            signal_scaled = self.scaler.transform(flat.T).T.reshape(ecg_signal.shape)
        else:
            signal_scaled = ecg_signal

        prediction_probs = self.model.predict(signal_scaled, verbose=0)
        probs = prediction_probs[0]
        predicted_idx = int(np.argmax(probs))
        confidence = float(probs[predicted_idx]) * 100

        if self.label_encoder is not None:
            disease_code = self.label_encoder.inverse_transform([predicted_idx])[0]
        else:
            disease_code = f"CLASS_{predicted_idx}"

        disease_info = self.disease_info.get(disease_code, {
            'name': 'Unknown Condition',
            'severity': 'Unknown',
            'urgency': 'Consult physician',
            'symptoms': 'Unknown symptoms',
            'recommendations': 'Consult with cardiologist',
            'risk_level': 'Unknown'
        })

        # Heart rate, rhythm and quality are measured on the unscaled signal.
        record = ecg_signal[0]
        result = {
            'patient_info': patient_info or {},
            'primary_diagnosis': {
                'code': disease_code,
                'name': disease_info['name'],
                'confidence': confidence,
                'severity': disease_info['severity'],
                'urgency': disease_info['urgency'],
                'risk_level': disease_info['risk_level']
            },
            'clinical_findings': {
                'symptoms': disease_info['symptoms'],
                'heart_rate': self.estimate_heart_rate(record),
                'rhythm_analysis': self.analyze_rhythm(record)
            },
            'recommendations': disease_info['recommendations'],
            'differential_diagnosis': {},
            'technical_details': {
                'model_confidence': confidence,
                'signal_quality': self.assess_signal_quality(record)
            }
        }

        # Differential diagnosis: the three most probable classes, best first.
        if self.label_encoder is not None:
            top_indices = np.argsort(probs)[::-1][:3]
            top_codes = self.label_encoder.inverse_transform(top_indices)
            for rank, (idx, code) in enumerate(zip(top_indices, top_codes), start=1):
                result['differential_diagnosis'][f'option_{rank}'] = {
                    'name': self.disease_info.get(code, {}).get('name', f'Class {idx}'),
                    'probability': float(probs[idx]) * 100,
                    'code': code
                }

        return result

    @staticmethod
    def _reference_lead(signal):
        """Return lead II (column 1) of a (time_points, leads) array, or lead I if single-lead."""
        return signal[:, 1] if signal.shape[1] > 1 else signal[:, 0]

    def estimate_heart_rate(self, signal, sampling_rate=100):
        """Estimate heart rate in BPM from R-peak spacing on lead II.

        Args:
            signal: ``(time_points, leads)`` array.
            sampling_rate: Samples per second (the file's default is 100 Hz).

        Returns:
            Rounded BPM as an ``int``, or ``"Cannot determine"`` when fewer
            than two peaks are found or analysis fails.
        """
        try:
            from scipy.signal import find_peaks
            lead = self._reference_lead(signal)
            # Simplified R-peak detection. A minimum spacing of 50 samples
            # (0.5 s at 100 Hz) caps the detectable rate at 120 BPM.
            peaks, _ = find_peaks(lead, height=np.std(lead), distance=50)
            if len(peaks) < 2:
                return "Cannot determine"
            avg_rr_seconds = np.mean(np.diff(peaks)) / sampling_rate
            return round(60 / avg_rr_seconds)
        except Exception:
            return "Cannot determine"

    def analyze_rhythm(self, signal):
        """Classify rhythm regularity from the coefficient of variation of RR intervals.

        Returns:
            One of ``"Regular rhythm"`` (CV < 0.1), ``"Slightly irregular
            rhythm"`` (CV < 0.3), ``"Irregular rhythm"``, ``"Insufficient data
            for rhythm analysis"`` (three or fewer peaks) or
            ``"Cannot determine rhythm"`` on error.
        """
        try:
            from scipy.signal import find_peaks
            lead = self._reference_lead(signal)
            peaks, _ = find_peaks(lead, height=np.std(lead), distance=30)

            if len(peaks) <= 2:
                return "Insufficient data for rhythm analysis"

            rr_intervals = np.diff(peaks)
            regularity = np.std(rr_intervals) / np.mean(rr_intervals)
            if regularity < 0.1:
                return "Regular rhythm"
            if regularity < 0.3:
                return "Slightly irregular rhythm"
            return "Irregular rhythm"
        except Exception:
            return "Cannot determine rhythm"

    def assess_signal_quality(self, signal):
        """Grade signal quality by a crude SNR: overall std / std of first differences.

        Returns:
            ``"Excellent"`` (>10), ``"Good"`` (>5), ``"Fair"`` (>2), ``"Poor"``,
            or ``"Cannot assess"`` on error.
        """
        try:
            signal_std = np.std(signal)
            noise_level = np.std(np.diff(signal, axis=0))
            snr = signal_std / (noise_level + 1e-8)

            if snr > 10:
                return "Excellent"
            elif snr > 5:
                return "Good"
            elif snr > 2:
                return "Fair"
            else:
                return "Poor"
        except Exception:
            return "Cannot assess"

    def generate_medical_report(self, analysis_result):
        """Render an analysis dict from ``analyze_ecg`` as a plain-text report.

        Returns:
            The report text, or ``"Error: Analysis failed"`` if
            ``analysis_result`` is ``None``.
        """
        if analysis_result is None:
            return "Error: Analysis failed"

        patient_info = analysis_result.get('patient_info') or {}
        primary = analysis_result['primary_diagnosis']
        clinical = analysis_result['clinical_findings']
        recommendations = analysis_result['recommendations']
        technical = analysis_result['technical_details']
        differential = analysis_result['differential_diagnosis']

        # Take one timestamp so "Date" and "Report Generated" always agree.
        timestamp = pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S')

        # estimate_heart_rate returns a number or a "Cannot determine" message;
        # only append the unit to numbers.
        heart_rate = clinical['heart_rate']
        heart_rate_text = heart_rate if isinstance(heart_rate, str) else f"{heart_rate} BPM"

        report = f"""
╔═══════════════════════════════════════════════════════════════╗
║                     ECG ANALYSIS REPORT                      ║
║                   AI-ASSISTED DIAGNOSIS                      ║
╚═══════════════════════════════════════════════════════════════╝

PATIENT INFORMATION:
• Name: {patient_info.get('name', 'Not provided')}
• Age: {patient_info.get('age', 'Not provided')}
• Gender: {patient_info.get('gender', 'Not provided')}
• ID: {patient_info.get('id', 'Not provided')}
• Date: {timestamp}

╔═══════════════════════════════════════════════════════════════╗
║                    PRIMARY DIAGNOSIS                         ║
╚═══════════════════════════════════════════════════════════════╝

CONDITION: {primary['name']}
DIAGNOSTIC CODE: {primary['code']}
CONFIDENCE LEVEL: {primary['confidence']:.1f}%
SEVERITY: {primary['severity']}
URGENCY: {primary['urgency']}
RISK LEVEL: {primary['risk_level']}

╔═══════════════════════════════════════════════════════════════╗
║                   CLINICAL FINDINGS                          ║
╚═══════════════════════════════════════════════════════════════╝

HEART RATE: {heart_rate_text}
RHYTHM: {clinical['rhythm_analysis']}
SIGNAL QUALITY: {technical['signal_quality']}

ASSOCIATED SYMPTOMS:
{clinical['symptoms']}

╔═══════════════════════════════════════════════════════════════╗
║                DIFFERENTIAL DIAGNOSIS                        ║
╚═══════════════════════════════════════════════════════════════╝
"""

        report += "".join(
            f"• {diff['name']}: {diff['probability']:.1f}%\n"
            for diff in differential.values()
        )

        # PTB-XL contains roughly 21.8k records; the previous "100,000+" was wrong.
        report += f"""
╔═══════════════════════════════════════════════════════════════╗
║                 CLINICAL RECOMMENDATIONS                     ║
╚═══════════════════════════════════════════════════════════════╝

{recommendations}

╔═══════════════════════════════════════════════════════════════╗
║                   IMPORTANT DISCLAIMERS                      ║
╚═══════════════════════════════════════════════════════════════╝

• This AI analysis is a diagnostic aid and should NOT replace
  clinical judgment by qualified medical professionals.
• Final diagnosis and treatment decisions must be made by a
  licensed physician after complete clinical evaluation.
• In case of emergency symptoms, seek immediate medical attention.
• This analysis is based on ECG signal patterns and may not
  account for patient history, medications, or other factors.

╔═══════════════════════════════════════════════════════════════╗
║                    TECHNICAL DETAILS                         ║
╚═══════════════════════════════════════════════════════════════╝

Model Confidence: {technical['model_confidence']:.2f}%
Signal Quality: {technical['signal_quality']}
Analysis Method: Deep Learning CNN-LSTM
Training Dataset: PTB-XL (21,000+ ECGs)
Model Version: ECG-AI v1.0

Report Generated: {timestamp}
╚═══════════════════════════════════════════════════════════════╝
"""
        return report

    def test_with_file(self, file_path, patient_info=None, file_type='auto'):
        """Load an ECG file, analyze it and build the report.

        Args:
            file_path: Image file, or a WFDB ``.dat``/``.hea`` file or record name.
            patient_info: Optional dict forwarded to ``analyze_ecg``.
            file_type: ``'image'``, ``'wfdb'`` or ``'auto'``. With ``'auto'``,
                ``.dat``/``.hea`` mean WFDB and everything else is treated as an image.

        Returns:
            ``(analysis, report)``, or ``None`` if loading fails or the type is unsupported.
        """
        print(f"Loading ECG from: {file_path}")

        root, ext = os.path.splitext(file_path)
        ext = ext.lower()

        if file_type == 'auto':
            # .jpg/.jpeg/.png/.bmp and unrecognised extensions default to image.
            file_type = 'wfdb' if ext in ('.dat', '.hea') else 'image'

        if file_type == 'image':
            ecg_signal = self.process_ecg_image(file_path)
        elif file_type == 'wfdb':
            # wfdb.rdrecord wants the record name without extension. Strip only a
            # trailing .dat/.hea, so directory names containing ".dat" are untouched.
            record_name = root if ext in ('.dat', '.hea') else file_path
            ecg_signal = self.load_ecg_from_wfdb(record_name)
        else:
            print("Unsupported file type")
            return None

        if ecg_signal is None:
            print("Failed to load ECG signal")
            return None

        analysis = self.analyze_ecg(ecg_signal, patient_info)
        report = self.generate_medical_report(analysis)
        return analysis, report

    def plot_ecg(self, ecg_signal, title="ECG Signal", sampling_rate=100):
        """Plot up to 12 leads of an ECG in a 4x3 grid.

        Args:
            ecg_signal: ``(time_points, leads)`` or ``(1, time_points, leads)`` array.
            title: Figure title.
            sampling_rate: Samples per second used for the time axis (default 100 Hz).
        """
        if ecg_signal is None:
            print("No signal to plot")
            return

        # Remove batch dimension if present.
        if len(ecg_signal.shape) == 3:
            ecg_signal = ecg_signal[0]

        lead_names = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']
        time_axis = np.arange(len(ecg_signal)) / sampling_rate

        plt.figure(figsize=(15, 12))
        for i in range(min(12, ecg_signal.shape[1])):
            plt.subplot(4, 3, i + 1)
            plt.plot(time_axis, ecg_signal[:, i], 'b-', linewidth=1)
            plt.title(f'Lead {lead_names[i]}')
            plt.xlabel('Time (s)')
            plt.ylabel('Amplitude (mV)')
            plt.grid(True, alpha=0.3)

        plt.suptitle(title, fontsize=16)
        plt.tight_layout()
        plt.show()

# Demo and testing functions
def _make_synthetic_ecg():
    """Return a crude synthetic 12-lead ECG of shape (1000, 12): 10 s at 100 Hz.

    Each lead is the same deterministic template (base sine, P wave, QRS
    bumps, T wave) plus independent Gaussian noise, then scaled per lead.
    """
    t = np.linspace(0, 10, 1000)  # 10 seconds

    # The waveform components are identical for every lead, so build them once.
    base_signal = np.sin(2 * np.pi * 1.2 * t)  # Base heart rate ~72 BPM
    p_wave = 0.1 * np.sin(2 * np.pi * 4 * t)
    t_wave = 0.2 * np.sin(2 * np.pi * 2 * t + np.pi / 4)

    qrs = np.zeros_like(t)
    for beat in range(12):  # ~12 beats in 10 seconds
        beat_idx = int(beat * 0.83 * 100)  # ~72 BPM, converted to a sample index
        if beat_idx < len(qrs) - 10:
            qrs[beat_idx:beat_idx + 10] = 0.8 * np.hanning(10)

    template = base_signal + p_wave + qrs + t_wave

    ecg = np.zeros((1000, 12))
    for lead in range(12):
        ecg[:, lead] = template + 0.05 * np.random.randn(1000)
        # Lead-specific amplitude variations.
        if lead in (0, 1, 2):  # Limb leads
            ecg[:, lead] *= 1.2
        elif lead >= 6:  # Precordial leads
            ecg[:, lead] *= 0.8 + 0.2 * (lead - 6) / 6
    return ecg


def demo_test():
    """Demo: analyze a synthetic ECG, print the report and plot the signal.

    Returns ``None``. If the model cannot be loaded, an error is printed and
    the demo stops early.
    """
    print("ECG Disease Detection System - Demo Test")
    print("=" * 50)

    try:
        tester = ECGTester()
        print("✓ Model loaded successfully")
    except Exception as e:
        print(f"✗ Error loading model: {e}")
        print("Please ensure the model file 'ecg_disease_detector.h5' exists")
        return

    print("\nGenerating synthetic ECG for demonstration...")
    ecg_synthetic = _make_synthetic_ecg()

    patient_info = {
        'name': 'John Doe (Demo Patient)',
        'age': 45,
        'gender': 'Male',
        'id': 'DEMO001'
    }

    print("Analyzing synthetic ECG...")

    # test_with_synthetic_ecg is added by the ECGTester extension defined later
    # in this file; the name is looked up at call time.
    analysis, report = tester.test_with_synthetic_ecg(ecg_synthetic, patient_info)

    if analysis:
        print("\n" + "=" * 80)
        print("ANALYSIS COMPLETE")
        print("=" * 80)
        print(report)

        tester.plot_ecg(ecg_synthetic, "Demo ECG - Synthetic Normal Pattern")
    else:
        print("Analysis failed")

def test_with_real_ecg(file_path, patient_name="Patient", patient_age="Unknown", patient_gender="Unknown"):
    """Test with real ECG file"""
    print(f"Testing ECG Disease Detector with real ECG: {file_path}")
    print("=" * 60)

    # Patient information
    patient_info = {
        'name': patient_name,
        'age': patient_age,
        'gender': patient_gender,
        'id': f"PAT_{pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')}"
    }

    # Initialize tester
    try:
        tester = ECGTester()
        print("✓ Model loaded successfully")
    except Exception as e:
        print(f"✗ Error loading model: {e}")
        return

    # Test with file
    result = tester.test_with_file(file_path, patient_info)

    if result:
        analysis, report = result
        print("\n" + "="*80)
        print("MEDICAL ANALYSIS COMPLETE")
        print("="*80)
        print(report)

        # Save report to file
        report_filename = f"ECG_Report_{patient_info['id']}.txt"
        with open(report_filename, 'w') as f:
            f.write(report)
        print(f"\n✓ Report saved to: {report_filename}")

        return analysis
    else:
        print("✗ Analysis failed")
        return None

# Additional utility functions
class ECGTester(ECGTester):  # Extend the class
    """Add synthetic-signal and batch helpers to ``ECGTester``.

    Rebinding the name means any ``ECGTester()`` created after this point
    gets these methods. That includes the calls inside ``demo_test`` and
    ``test_with_real_ecg``, which look the name up at call time.
    """

    def test_with_synthetic_ecg(self, ecg_signal, patient_info=None):
        """Analyze an in-memory ECG array and build its report.

        Returns:
            ``(analysis, report)``. If ``ecg_signal`` is ``None``, ``analysis``
            is ``None`` and ``report`` is an error message.
        """
        analysis = self.analyze_ecg(ecg_signal, patient_info)
        report = self.generate_medical_report(analysis)
        return analysis, report

    def batch_test(self, file_list, output_dir="./ecg_reports/"):
        """Analyze several ECG files, saving one report each plus a CSV summary.

        Args:
            file_list: Sized sequence of ECG file paths.
            output_dir: Directory (created if needed) for ``Report_NNN.txt``
                files and ``batch_summary.csv``.

        Returns:
            List of per-file dicts with ``file``, ``diagnosis``, ``confidence``
            and ``severity``. Failed files are recorded as ``'FAILED'``.
        """
        os.makedirs(output_dir, exist_ok=True)
        results = []
        total = len(file_list)

        for number, file_path in enumerate(file_list, start=1):
            print(f"Processing file {number}/{total}: {file_path}")

            patient_info = {
                'name': f'Patient_{number}',
                'id': f'BATCH_{number:03d}'
            }

            result = self.test_with_file(file_path, patient_info)
            if result is None:
                results.append({
                    'file': file_path,
                    'diagnosis': 'FAILED',
                    'confidence': 0,
                    'severity': 'N/A'
                })
                continue

            analysis, report = result

            # Reports contain box-drawing characters and bullets; force UTF-8 so
            # the write does not fail where the default encoding is not UTF-8.
            report_path = os.path.join(output_dir, f"Report_{number:03d}.txt")
            with open(report_path, 'w', encoding='utf-8') as f:
                f.write(report)

            primary = analysis['primary_diagnosis']
            results.append({
                'file': file_path,
                'diagnosis': primary['name'],
                'confidence': primary['confidence'],
                'severity': primary['severity']
            })

        summary_path = os.path.join(output_dir, "batch_summary.csv")
        pd.DataFrame(results).to_csv(summary_path, index=False)

        print(f"\nBatch processing complete. Summary saved to: {summary_path}")
        return results

# Main execution
if __name__ == "__main__":
    # When run as a script, only print a short usage guide; nothing is analyzed.
    usage_lines = (
        "ECG Disease Detection System - Ready for Testing",
        "=" * 60,
        "Available functions:",
        "1. demo_test() - Run demo with synthetic ECG",
        "2. test_with_real_ecg('file_path', 'name', age, 'gender') - Test with real ECG",
        "3. ECGTester().batch_test(['file1', 'file2', ...]) - Batch testing",
        "\nExample usage:",
        "demo_test()",
        "test_with_real_ecg('patient_ecg.png', 'John Smith', 65, 'Male')",
    )
    for usage_line in usage_lines:
        print(usage_line)

    # Uncomment the line below to run demo automatically
    # demo_test()