In [2]:
import numpy as np
import wfdb
from pathlib import Path
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import confusion_matrix, classification_report
import tensorflow as tf
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Bidirectional, LSTM, Dense, Dropout, Layer
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras import backend as K
import matplotlib.pyplot as plt
import seaborn as sns

In [3]:
def load_and_prep_data(db_dir: Path, signal_length: int = 1024, records_to_process: list = None):
    """
    Memuat data, normalisasi, mapping label, dan otomatis unduh.
    (VERSI PERBAIKAN FINAL DENGAN LOGIKA AFIB YANG LEBIH KUAT)
    """
    if not db_dir.is_dir():
        print(f"Direktori '{db_dir}' tidak ditemukan. Memulai unduhan...")
        try:
            wfdb.dl_database('mitdb', dl_dir=db_dir)
            print("Unduhan selesai.")
        except Exception as e:
            print(f"Gagal mengunduh database: {e}")
            return np.array([]), np.array([])

    print(f"Memulai proses data dari '{db_dir}'...")
    if records_to_process is None: records_to_process = wfdb.get_record_list('mitdb')
    
    X, y = [], []
    label_map = {'N': 'Normal','L': 'Normal','R': 'Normal','V': 'PVC','E': 'PVC','A': 'Other','a': 'Other','J': 'Other','S': 'Other','F': 'Other','/': 'Other','f': 'Other','Q': 'Other'}
    
    for rec_name in records_to_process:
        print(f"  -> Memproses record: {rec_name}")
        try:
            record = wfdb.rdrecord(str(db_dir / rec_name))
            annotation = wfdb.rdann(str(db_dir / rec_name), 'atr')
            signal = record.p_signal[:, 1]
            
            # --- LOGIKA BARU YANG LEBIH ROBUST ---
            # 1. Buat dictionary event perubahan irama: {sample_index: rhythm_note}
            rhythm_change_indices = np.where(annotation.symbol == '+')[0]
            rhythm_events = {annotation.sample[i]: annotation.aux_note[i].strip('\x00') for i in rhythm_change_indices}
            rhythm_event_samples = sorted(rhythm_events.keys())

            # 2. Iterasi setiap detak jantung
            for i, loc in enumerate(annotation.sample):
                symbol = annotation.symbol[i]
                if symbol not in label_map: continue

                # Cari irama yang berlaku untuk detak jantung saat ini
                current_rhythm = '(N' # Default rhythm
                for event_sample in rhythm_event_samples:
                    if loc >= event_sample:
                        current_rhythm = rhythm_events[event_sample]
                    else:
                        break
                
                # Tentukan label final
                final_label = ''
                if '(AFIB' in current_rhythm:
                    final_label = 'AF'
                else:
                    final_label = label_map.get(symbol)
                
                # Proses dan simpan segmen jika label valid
                if final_label:
                    half_len = signal_length // 2
                    if loc > half_len and loc < len(signal) - half_len:
                        segment = signal[loc - half_len : loc + half_len]
                        if len(segment) == signal_length:
                           segment = (segment - np.mean(segment)) / np.std(segment)
                           X.append(segment)
                           y.append(final_label)
        except Exception as e:
            print(f"    Gagal memproses {rec_name}: {e}")

    print("Pemuatan dan persiapan data selesai.")
    if len(np.unique(y)) > 0: print(f"Label yang berhasil diekstrak: {np.unique(y)}")
    return np.array(X), np.array(y)

In [3]:
class Attention(Layer):
    """Custom Attention Layer."""
    def __init__(self, **kwargs): super(Attention, self).__init__(**kwargs)
    def build(self, input_shape):
        self.W = self.add_weight(name="att_weight", shape=(input_shape[-1], 1), initializer="normal")
        self.b = self.add_weight(name="att_bias", shape=(input_shape[1], 1), initializer="zeros")
        super(Attention, self).build(input_shape)
    def call(self, x):
        et = K.squeeze(K.tanh(K.dot(x, self.W) + self.b), axis=-1)
        at = K.softmax(et)
        at = K.expand_dims(at, axis=-1)
        output = x * at
        return K.sum(output, axis=1)
    def compute_output_shape(self, input_shape): return (input_shape[0], input_shape[-1])

def build_bilstm_attention_model(input_shape, num_classes):
    """Membangun model Keras dengan arsitektur Bi-LSTM + Attention + Dense."""
    print("\nMembangun model Klasifikasi Bi-LSTM Attention...")
    inputs = Input(shape=input_shape)
    x = Bidirectional(LSTM(64, return_sequences=True))(inputs)
    x = Dropout(0.3)(x)
    x = Attention()(x)
    x = Dense(64, activation='relu')(x)
    x = Dropout(0.3)(x)
    outputs = Dense(num_classes, activation='softmax')(x)
    
    model = Model(inputs, outputs)
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    print("Model Klasifikasi berhasil dibangun.")
    return model

In [4]:
SIGNAL_LENGTH = 1024
DB_DIR = Path('mitdb')

In [5]:
X_data, y_labels = load_and_prep_data(DB_DIR, SIGNAL_LENGTH, None)

Memulai proses data dari 'mitdb'...
  -> Memproses record: 100
  -> Memproses record: 101
  -> Memproses record: 102
  -> Memproses record: 103
  -> Memproses record: 104
  -> Memproses record: 105
  -> Memproses record: 106
  -> Memproses record: 107
  -> Memproses record: 108
  -> Memproses record: 109
  -> Memproses record: 111
  -> Memproses record: 112
  -> Memproses record: 113
  -> Memproses record: 114
  -> Memproses record: 115
  -> Memproses record: 116
  -> Memproses record: 117
  -> Memproses record: 118
  -> Memproses record: 119
  -> Memproses record: 121
  -> Memproses record: 122
  -> Memproses record: 123
  -> Memproses record: 124
  -> Memproses record: 200
  -> Memproses record: 201
  -> Memproses record: 202
  -> Memproses record: 203
  -> Memproses record: 205
  -> Memproses record: 207
  -> Memproses record: 208
  -> Memproses record: 209
  -> Memproses record: 210
  -> Memproses record: 212
  -> Memproses record: 213
  -> Memproses record: 214
  -> Memproses reco

In [6]:
print(X_data.shape[0] < 10)

False


In [7]:
X_data = X_data[..., np.newaxis]
print("\nMenyiapkan data untuk klasifikasi...")
le = LabelEncoder()
y_encoded = le.fit_transform(y_labels)
num_classes = len(le.classes_)
y_one_hot = to_categorical(y_encoded, num_classes=num_classes)
print(f"Kelas ditemukan ({num_classes}): {le.classes_}")


Menyiapkan data untuk klasifikasi...
Kelas ditemukan (3): ['Normal' 'Other' 'PVC']


In [8]:
class_weights = compute_class_weight('balanced', classes=np.unique(y_encoded), y=y_encoded)
class_weights_dict = dict(enumerate(class_weights))
print(f"Class Weights yang akan digunakan: {class_weights_dict}")

Class Weights yang akan digunakan: {0: 0.4029226092192216, 1: 3.131725523973586, 2: 5.029649098538295}


In [9]:
X_train, X_val, y_train, y_val = train_test_split(
    X_data, y_one_hot, test_size=0.2, random_state=42, stratify=y_encoded
)
print(f"Ukuran Data Latih: {X_train.shape}")
print(f"Ukuran Data Validasi: {X_val.shape}")

Ukuran Data Latih: (87262, 1024, 1)
Ukuran Data Validasi: (21816, 1024, 1)


In [10]:
classifier_model = build_bilstm_attention_model(
    input_shape=(SIGNAL_LENGTH, 1), 
    num_classes=num_classes
)
classifier_model.summary()


Membangun model Klasifikasi Bi-LSTM Attention...
Model Klasifikasi berhasil dibangun.
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 1024, 1)]         0         
                                                                 
 bidirectional (Bidirectiona  (None, 1024, 128)        33792     
 l)                                                              
                                                                 
 dropout (Dropout)           (None, 1024, 128)         0         
                                                                 
 attention (Attention)       (None, 128)               1152      
                                                                 
 dense (Dense)               (None, 64)                8256      
                                                                 
 dropout_1 (Dropout)         (None, 64) 

In [11]:
model_checkpoint = ModelCheckpoint(
    filepath='model_terbaik.keras',
    save_best_only=True, 
    monitor='val_loss', 
    mode='min', 
    verbose=1
)
early_stopping = EarlyStopping(
    monitor='val_loss', 
    patience=5, 
    restore_best_weights=True, 
    verbose=1
)

In [12]:
print("\nMelatih model Klasifikasi dengan Class Weighting...")
history = classifier_model.fit(
    X_train, y_train,
    epochs=25,
    batch_size=64,
    validation_data=(X_val, y_val),
    verbose=1,
    callbacks=[early_stopping, model_checkpoint],
    class_weight=class_weights_dict
)
print("\nPipeline klasifikasi selesai!")


Melatih model Klasifikasi dengan Class Weighting...
Epoch 1/25
  26/1364 [..............................] - ETA: 28:20 - loss: 1.0530 - accuracy: 0.3456

KeyboardInterrupt: 

In [None]:
print("\n--- Mengevaluasi Model pada Data Validasi ---")
# Buat prediksi
y_pred_probs = classifier_model.predict(X_val)
y_pred = np.argmax(y_pred_probs, axis=1)
y_true = np.argmax(y_val, axis=1)

In [None]:
class_names = le.classes_

In [None]:
print("\nClassification Report:")
print(classification_report(y_true, y_pred, target_names=class_names))


In [None]:
print("\nConfusion Matrix:")
cm = confusion_matrix(y_true, y_pred)
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
plt.title('Confusion Matrix')
plt.ylabel('Label Sebenarnya (True Label)')
plt.xlabel('Label Prediksi (Predicted Label)')
plt.show()

In [None]:
print("\nGrafik Performa Training:")
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

In [None]:
# Plot akurasi
ax1.plot(history.history['accuracy'], label='Training Accuracy')
ax1.plot(history.history['val_accuracy'], label='Validation Accuracy')
ax1.set_title('Akurasi Model')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Accuracy')
ax1.legend()

In [None]:
# Plot loss
ax2.plot(history.history['loss'], label='Training Loss')
ax2.plot(history.history['val_loss'], label='Validation Loss')
ax2.set_title('Loss Model')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Loss')
ax2.legend()

In [None]:
plt.show()