In [2]:
import os
import wfdb
import pandas as pd
from collections import Counter
import shutil
import numpy as np
import scipy.io
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.metrics import accuracy_score, confusion_matrix
from imblearn.over_sampling import SMOTE
from imblearn.combine import SMOTETomek
from sklearn.preprocessing import LabelEncoder, StandardScaler
from scipy.signal import welch, find_peaks, butter, filtfilt
from scipy.stats import entropy

dataset_path = r"D:\College\Year Three\Second Term\Medical Equipments 2\ECG Task\Dataset\a-large-scale-12-lead-electrocardiogram-database-for-arrhythmia-study-1.0.0\WFDBRecords"
filtered_dataset_path = r"D:\College\Year Three\Second Term\Medical Equipments 2\ECG Task\Filtered Data"
diagnosis_codes_csv = "ConditionNames_SNOMED-CT.csv"

In [124]:
def extract_ecg_diagnoses(dataset_path, snomed_csv, output_csv):
    df = pd.read_csv(snomed_csv, encoding="utf-8")
    df.columns = df.columns.str.strip()
    snomed_dict = dict(zip(df['Snomed_CT'].astype(str), df['Full Name']))
    records_data = []
    print(f"Dataset Path: {dataset_path}")
    for root, _, files in os.walk(dataset_path):
        for file in files:
            if file.endswith(".hea"):
                # print("Found .hea file")
                record_base = os.path.splitext(file)[0]
                record_path = os.path.join(root, record_base)
                try:
                    record = wfdb.rdrecord(record_path)
                    diagnosis_codes = [comment.split(": ")[1] for comment in record.comments if "Dx" in comment]
                    diagnosis_codes = diagnosis_codes[0].split(", ") if diagnosis_codes else []
                    diagnosis_names = [snomed_dict.get(code, "Unknown Condition") for code in diagnosis_codes]
                    diagnosis_names_str = ", ".join(diagnosis_names)
                    records_data.append([record_base, diagnosis_names_str])
                except Exception as e:
                    print(f"Error processing {record_base}: {e}")
    df_out = pd.DataFrame(records_data, columns=["Record Name", "Diagnosis"])
    df_out.to_csv(output_csv, index=False)

extract_ecg_diagnoses(dataset_path, diagnosis_codes_csv,
                      "Decoded_Diagnosis.csv")

Dataset Path: D:\College\Year Three\Second Term\Medical Equipments 2\ECG Task\Dataset\a-large-scale-12-lead-electrocardiogram-database-for-arrhythmia-study-1.0.0\WFDBRecords
Error processing JS01052: time data '/' does not match format '%d/%m/%Y'
Error processing JS23074: list index out of range


In [3]:
def count_condition_occurrences(csv_file):
    df = pd.read_csv(csv_file, encoding="utf-8")
    df.columns = df.columns.str.strip()
    if "Record Name" not in df.columns or "Diagnosis" not in df.columns:
        raise ValueError("CSV file must contain 'Record Name' and 'Diagnosis' columns.")
    condition_counts = Counter(df["Diagnosis"])
    df_counts = pd.DataFrame(condition_counts.items(), columns=["Full Name", "Count"])
    df_counts = df_counts.sort_values(by="Count", ascending=False)
    return df_counts

print(count_condition_occurrences("Decoded_Diagnosis.csv"))

                                 Full Name  Count
0                        Unknown Condition  23540
1                        Sinus Bradycardia   8909
2                             Sinus Rhythm   5908
3                        Sinus Tachycardia   3223
5                           Atrial Flutter   1483
9                       Sinus Irregularity   1234
4                      Atrial Fibrillation    422
6             Supraventricular Tachycardia    390
7                       Atrial Tachycardia     31
8   Atrioventricular Reentrant Tachycardia      5
10                 ventricular escape beat      3
11         1 degree atrioventricular block      1
12                  atrial premature beats      1


In [None]:
def filter_and_copy_records(csv_file, source_dir, target_dir, selected_conditions):
    df = pd.read_csv(csv_file)
    df.columns = df.columns.str.strip()
    if "Record Name" not in df.columns or "Diagnosis" not in df.columns:
        raise ValueError("CSV file must contain 'Record Name' and 'Diagnosis' columns.")
    filtered_df = df[df["Diagnosis"].isin(selected_conditions)]
    for _, row in filtered_df.iterrows():
        record_name = row["Record Name"]
        for root, _, files in os.walk(source_dir):
            for file in files:
                if file.startswith(record_name) and (file.endswith(".hea") or file.endswith(".mat")):
                    source_file_path = os.path.join(root, file)
                    relative_path = os.path.relpath(root, source_dir)
                    target_folder = os.path.join(target_dir, relative_path)
                    os.makedirs(target_folder, exist_ok=True)
                    target_file_path = os.path.join(target_folder, file)
                    shutil.copy2(source_file_path, target_file_path)
                    # print(f"Copied: {source_file_path} -> {target_file_path}")

csv_path = "Decoded_Diagnosis.csv"
target_directory = "Filtered Data"
conditions_to_keep = ["Sinus Bradycardia", "Sinus Tachycardia", "Atrial Fibrillation", "Sinus Rhythm"]
# conditions_to_keep = ["Sinus Tachycardia", "Atrial Fibrillation", "Sinus Rhythm"]
filter_and_copy_records(csv_path, dataset_path, target_directory, conditions_to_keep)

In [None]:
# Testing Filtration Process (not necessary to rerun anymore)
extract_ecg_diagnoses(filtered_dataset_path, diagnosis_codes_csv, "test.csv")

In [6]:
def bandpass_filter(signal, fs=250, lowcut=0.5, highcut=40):
    nyquist = fs / 2
    low = lowcut / nyquist
    high = highcut / nyquist
    b, a = butter(4, [low, high], btype='band')
    return filtfilt(b, a, signal)

def extract_ecg_features(mat_file):
    data = scipy.io.loadmat(mat_file)
    if 'val' not in data:
        return None
    ecg_signal = bandpass_filter(data['val'][0])
    
    # Statistical features
    mean_val = np.mean(ecg_signal)
    std_dev = np.std(ecg_signal)
    skewness = np.mean((ecg_signal - mean_val) ** 3) / std_dev ** 3
    kurtosis = np.mean((ecg_signal - mean_val) ** 4) / std_dev ** 4
    peak_to_peak = np.ptp(ecg_signal)

    # Time-domain features
    peaks, _ = find_peaks(ecg_signal, distance=200)
    heart_rate = len(peaks)
    rr_intervals = np.diff(peaks) / 250
    mean_rr = np.mean(rr_intervals) if len(rr_intervals) > 0 else 0
    std_rr = np.std(rr_intervals) if len(rr_intervals) > 0 else 0

    # Frequency-domain features
    freqs, psd = welch(ecg_signal, fs=250)
    dominant_freq = freqs[np.argmax(psd)]
    spectral_entropy = entropy(psd)

    return [mean_val, std_dev, skewness, kurtosis, peak_to_peak, heart_rate, mean_rr, std_rr, dominant_freq, spectral_entropy]

In [None]:
# Load Data
df_labels = pd.read_csv("Decoded_Diagnosis.csv")
label_dict = dict(zip(df_labels["Record Name"], df_labels["Diagnosis"]))

X, y = [], []
for root, _, files in os.walk(filtered_dataset_path):
    for file in files:
        if file.endswith(".mat"):
            record_name = file.replace(".mat", "")
            if record_name in label_dict:
                mat_path = os.path.join(root, file)
                features = extract_ecg_features(mat_path)
                if features:
                    X.append(features)
                    y.append(label_dict[record_name])

In [8]:
# Encode labels and split
label_encoder = LabelEncoder()
y_encoded = label_encoder.fit_transform(y)
X_train, X_test, y_train, y_test = train_test_split(X, y_encoded, test_size=0.2, random_state=42, stratify=y_encoded)

# Balance dataset
smote_tomek = SMOTETomek(random_state=42)
X_train_balanced, y_train_balanced = smote_tomek.fit_resample(X_train, y_train)

# Scale features
scaler = StandardScaler()
X_train_balanced = scaler.fit_transform(X_train_balanced)
X_test = scaler.transform(X_test)

# Train and evaluate model with hyperparameter tuning
param_grid = {'n_estimators': [100, 200], 'max_depth': [10, None]}
grid_search = GridSearchCV(RandomForestClassifier(random_state=42), param_grid, cv=5)
grid_search.fit(X_train_balanced, y_train_balanced)
model = grid_search.best_estimator_

y_pred = model.predict(X_test)
print(f"Model Accuracy: {accuracy_score(y_test, y_pred):.2f}")
print("Confusion Matrix:\n", confusion_matrix(y_test, y_pred))

Model Accuracy: 0.86
Confusion Matrix:
 [[ 19  10   3   0]
 [ 17 140   7   0]
 [  2   4  91   2]
 [  0   2   0  44]]


In [9]:
import joblib
joblib.dump(model, "model.pkl")

['model.pkl']

In [12]:
joblib.dump(scaler, "scaler.pkl")

['scaler.pkl']

In [11]:
label_mapping = dict(zip(y, y_encoded))
print("Label Mapping:", label_mapping)

Label Mapping: {'Sinus Bradycardia': np.int64(1), 'Sinus Rhythm': np.int64(2), 'Sinus Tachycardia': np.int64(3), 'Atrial Fibrillation': np.int64(0)}


In [None]:
print(f"Sampling Frequency: {record.fs} Hz")
print(f"Number of Leads: {record.n_sig}")
print(f"Lead Names: {record.sig_name}")
print(f"Number of Samples: {record.sig_len}")
print(f"Signal Shape: {record.p_signal.shape}")
print(f"ADC Gain (per lead): {record.adc_gain}")
print(f"Baseline Values: {record.baseline}")
print(f"File Name: {record.file_name}")
print(f"Comments (Patient Info): {record.comments}")

# Sampling Frequency: 500 Hz
# Number of Leads: 12
# Lead Names: ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']
# Number of Samples: 5000
# Signal Shape: (5000, 12)
# ADC Gain (per lead): [1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0]
# Baseline Values: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
# File Name: ['JS00001.mat', 'JS00001.mat', 'JS00001.mat', 'JS00001.mat', 'JS00001.mat', 'JS00001.mat', 'JS00001.mat', 'JS00001.mat', 'JS00001.mat', 'JS00001.mat', 'JS00001.mat', 'JS00001.mat']
# Comments (Patient Info): ['Age: 85', 'Sex: Male', 'Dx: 164889003,59118001,164934002', 'Rx: Unknown', 'Hx: Unknown', 'Sx: Unknown']