In [7]:
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras.preprocessing import image_dataset_from_directory
from sklearn.metrics import classification_report, confusion_matrix
import numpy as np


In [8]:
# --- Data pipeline settings (passed to the dataset loaders) ---
IMG_SIZE = (224, 224)
BATCH_SIZE = 32

# --- Model / training settings ---
NUM_CLASSES = 9   # width of the softmax output layer
EPOCHS = 20       # epochs for the first training run

In [9]:
def _load_split(directory, shuffle):
    """Load one image split with the notebook-wide size, batch size and int labels.

    Args:
        directory: Folder handed to ``image_dataset_from_directory``.
        shuffle: Forwarded as the loader's ``shuffle`` argument.

    Returns:
        The dataset object returned by ``image_dataset_from_directory``.
    """
    return image_dataset_from_directory(
        directory,
        image_size=IMG_SIZE,
        batch_size=BATCH_SIZE,
        label_mode="int",
        shuffle=shuffle,
    )


# Only the training split is shuffled; validation and test keep a fixed order.
train_ds = _load_split("ecg_images_split/train", shuffle=True)
val_ds = _load_split("ecg_images_split/val", shuffle=False)
test_ds = _load_split("ecg_images_split/test", shuffle=False)

# Class names as reported by the training dataset.
class_names = train_ds.class_names


Found 29850 files belonging to 9 classes.
Found 6394 files belonging to 9 classes.
Found 6406 files belonging to 9 classes.


In [10]:
AUTOTUNE = tf.data.AUTOTUNE

# Training pipeline: cache, then shuffle with a 1000-element buffer, then prefetch.
# NOTE(review): the dataset is already batched, so the buffer presumably holds
# batches rather than single images -- confirm this is the intended granularity.
train_ds = train_ds.cache()
train_ds = train_ds.shuffle(1000)
train_ds = train_ds.prefetch(AUTOTUNE)

# Evaluation pipelines: cache and prefetch only, so their order stays fixed.
val_ds = val_ds.cache()
val_ds = val_ds.prefetch(AUTOTUNE)

test_ds = test_ds.cache()
test_ds = test_ds.prefetch(AUTOTUNE)

In [11]:
# Collect the integer label of every training sample by walking the dataset once.
labels = np.concatenate(
    [batch_labels.numpy() for _, batch_labels in train_ds], axis=0
)

# minlength guarantees one count per class even when the highest-numbered
# classes have no training samples; without it the lookup below would raise
# IndexError for those classes.
class_counts = np.bincount(labels, minlength=NUM_CLASSES)

# "Balanced" weighting: n_samples / (n_classes * count).
# The ratio between classes is the same as with plain 1 / count, but the
# weights average to 1 over the training samples, so the weighted loss stays
# on the usual scale instead of being shrunk by several orders of magnitude.
total_samples = int(class_counts.sum())
class_weights = {
    class_idx: (
        total_samples / (NUM_CLASSES * int(count))
        if count > 0
        else 0.0  # class absent from training data: its weight is never used
    )
    for class_idx, count in enumerate(class_counts[:NUM_CLASSES])
}


In [12]:
# Shape shared by the backbone and the model's Input layer.
input_shape = (224, 224, 3)

# ImageNet-pretrained EfficientNetB0 without its classification top.
base_model = EfficientNetB0(
    weights="imagenet",
    include_top=False,
    input_shape=input_shape,
)
base_model.trainable = False  # stage 1: backbone weights stay frozen

inputs = layers.Input(shape=input_shape)

# EfficientNet input preprocessing, applied inside the model graph.
preprocessed = tf.keras.applications.efficientnet.preprocess_input(inputs)

# Backbone is called with training=False, i.e. in inference mode.
features = base_model(preprocessed, training=False)

# Classification head: global average pooling -> batch norm -> softmax.
pooled = layers.GlobalAveragePooling2D()(features)
normalized = layers.BatchNormalization()(pooled)
outputs = layers.Dense(NUM_CLASSES, activation="softmax")(normalized)

model = models.Model(inputs=inputs, outputs=outputs)


In [9]:
# Stage-1 optimizer: Adam with learning rate 1e-4.
head_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)

# Labels are integer class ids, hence the sparse cross-entropy loss.
model.compile(
    loss="sparse_categorical_crossentropy",
    optimizer=head_optimizer,
    metrics=["accuracy"],
)


In [10]:
# Stage 1: train for EPOCHS epochs, validating on val_ds after each epoch and
# weighting each sample's loss by the per-class weights.
history = model.fit(
    x=train_ds,
    epochs=EPOCHS,
    validation_data=val_ds,
    class_weight=class_weights,
)


Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


In [11]:
# Predict the whole test set with a single predict() call instead of one call
# per batch: this avoids per-call overhead and one progress bar per batch.
test_probs = model.predict(test_ds, verbose=0)
y_pred = np.argmax(test_probs, axis=1)

# Ground-truth labels, gathered in iteration order. test_ds was created with
# shuffle=False, so this order matches the order of the predictions above.
y_true = np.concatenate(
    [batch_labels.numpy() for _, batch_labels in test_ds], axis=0
)

print(classification_report(y_true, y_pred, digits=4))
print(confusion_matrix(y_true, y_pred))


              precision    recall  f1-score   support

           0     0.4301    0.4504    0.4400       826
           1     0.5138    0.5148    0.5143       977
           2     0.4216    0.4978    0.4566       681
           3     0.3247    0.8109    0.4637       201
           4     0.7361    0.5080    0.6011      1559
           5     0.1859    0.1436    0.1621       550
           6     0.3457    0.2536    0.2926       623
           7     0.3957    0.4154    0.4053       804
           8     0.1828    0.4486    0.2598       185

    accuracy                         0.4407      6406
   macro avg     0.3929    0.4493    0.3995      6406
weighted avg     0.4725    0.4407    0.4453      6406

[[372  34  63  21  45  52  24  93 122]
 [ 42 503  78  57  58  70  62  77  30]
 [ 46  80 339  34  32  25  31  45  49]
 [  1   9  12 163   4   2   6   1   3]
 [133 117  96 114 792  75  55 111  66]
 [ 70  69  69  25  40  79  60  92  46]
 [ 74  70  65  45  49  51 158  81  30]
 [103  90  66  33  38 

In [None]:
# model.save("efficientnet_stockwell_ecg_tf")


In [14]:
from sklearn.metrics import classification_report, confusion_matrix

# Pin the reported label set to 0..NUM_CLASSES-1 so that every class gets a
# row in the report, and round the metrics to two digits.
print(
    classification_report(
        y_true,
        y_pred,
        labels=list(range(NUM_CLASSES)),
        digits=2,
    )
)



              precision    recall  f1-score   support

           0       0.43      0.45      0.44       826
           1       0.51      0.51      0.51       977
           2       0.42      0.50      0.46       681
           3       0.32      0.81      0.46       201
           4       0.74      0.51      0.60      1559
           5       0.19      0.14      0.16       550
           6       0.35      0.25      0.29       623
           7       0.40      0.42      0.41       804
           8       0.18      0.45      0.26       185

    accuracy                           0.44      6406
   macro avg       0.39      0.45      0.40      6406
weighted avg       0.47      0.44      0.45      6406



### Fine-tune the model

In [13]:
# Number of leading backbone layers that stay frozen during fine-tuning.
num_frozen_layers = 200

# Mark the whole backbone trainable, then re-freeze its first
# `num_frozen_layers` layers so that only the later layers are updated.
base_model.trainable = True
for layer_index, backbone_layer in enumerate(base_model.layers):
    if layer_index < num_frozen_layers:
        backbone_layer.trainable = False


In [14]:
# Re-compile after changing the trainable flags, with a smaller learning rate
# (1e-5) than the first stage used (1e-4).
fine_tune_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-5)

model.compile(
    loss="sparse_categorical_crossentropy",
    optimizer=fine_tune_optimizer,
    metrics=["accuracy"],
)


In [15]:
FINE_TUNE_EPOCHS = 15

# Stage 2: continue training the same model, using the same datasets and
# class weights as the first stage.
history_ft = model.fit(
    x=train_ds,
    epochs=FINE_TUNE_EPOCHS,
    validation_data=val_ds,
    class_weight=class_weights,
)


Epoch 1/15
Epoch 2/15
Epoch 3/15
Epoch 4/15
Epoch 5/15
Epoch 6/15
Epoch 7/15
Epoch 8/15
Epoch 9/15
Epoch 10/15
Epoch 11/15
Epoch 12/15
Epoch 13/15
Epoch 14/15
Epoch 15/15


In [16]:
from sklearn.metrics import classification_report, confusion_matrix
import numpy as np

# Predict the whole test set with a single predict() call instead of one call
# per batch.
test_probs = model.predict(test_ds, verbose=0)
y_pred = np.argmax(test_probs, axis=1)

# Ground-truth labels, gathered in iteration order. test_ds was created with
# shuffle=False, so this order matches the order of the predictions above.
y_true = np.concatenate(
    [batch_labels.numpy() for _, batch_labels in test_ds], axis=0
)

# Use one explicit label set for both outputs so the report rows and the
# confusion-matrix rows/columns always refer to the same classes.
report_labels = list(range(NUM_CLASSES))

print(classification_report(
    y_true,
    y_pred,
    labels=report_labels,
    digits=2
))

print(confusion_matrix(y_true, y_pred, labels=report_labels))


              precision    recall  f1-score   support

           0       0.43      0.54      0.48       826
           1       0.51      0.62      0.56       977
           2       0.50      0.50      0.50       681
           3       0.40      0.83      0.54       201
           4       0.75      0.53      0.62      1559
           5       0.24      0.13      0.17       550
           6       0.41      0.27      0.33       623
           7       0.41      0.48      0.44       804
           8       0.23      0.45      0.31       185

    accuracy                           0.48      6406
   macro avg       0.43      0.48      0.44      6406
weighted avg       0.50      0.48      0.48      6406

[[442  36  45  24  45  24  16 118  76]
 [ 49 609  54  45  47  37  37  73  26]
 [ 62  92 338  28  24  17  42  48  30]
 [  1  11   5 166   3   3   4   4   4]
 [156 132  85  66 822  53  52 124  69]
 [ 91  97  49  26  40  72  51  92  32]
 [ 81  88  50  33  49  43 168  86  25]
 [113 110  42  22  43 

### Multi-window voting

In [17]:
import os
import numpy as np
from collections import defaultdict
from sklearn.metrics import classification_report, accuracy_score


In [18]:
# Same test folder as the first data-loading cell, given relative to the
# working directory instead of as a machine-specific absolute path.
test_dir = "ecg_images_split/test"

# Re-create the test dataset directly from disk: the voting cells below read
# `test_ds.file_paths`.
# NOTE(review): presumably the cached/prefetched `test_ds` built earlier no
# longer exposes `file_paths`, which is why it is reloaded here -- confirm.
# IMG_SIZE and BATCH_SIZE come from the configuration cell; shuffle=False
# keeps a fixed sample order.
test_ds = image_dataset_from_directory(
    test_dir,
    image_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    label_mode="int",
    shuffle=False,
)


Found 6406 files belonging to 9 classes.


In [19]:
# Predict every window image with a single predict() call instead of one call
# per batch.
window_probs = model.predict(test_ds, verbose=0)
y_pred = np.argmax(window_probs, axis=1)

# Ground-truth labels in iteration order (the dataset was built with
# shuffle=False, so the order is the same as for the predictions).
y_true = np.concatenate(
    [batch_labels.numpy() for _, batch_labels in test_ds], axis=0
)

# File path of every image.
# NOTE(review): the pairing with predictions below relies on `file_paths`
# being listed in iteration order, which requires shuffle=False -- confirm.
file_paths = test_ds.file_paths

# zip() would silently truncate on a length mismatch, so fail loudly instead.
assert len(file_paths) == len(y_pred) == len(y_true), (
    "file paths, predictions and labels must have the same length"
)


In [20]:
# ecg_preds: ECG id -> list of per-window predicted labels
# ecg_true:  ECG id -> label of the last window seen for that id
ecg_preds = defaultdict(list)
ecg_true = {}

for image_path, window_pred, window_label in zip(file_paths, y_pred, y_true):
    # ECG id is the file name up to the first "_win",
    # e.g. "A0001_win3.png" -> "A0001".
    file_name = os.path.split(image_path)[1]
    ecg_id = file_name.partition("_win")[0]

    ecg_true[ecg_id] = window_label
    ecg_preds[ecg_id].append(window_pred)


In [21]:
# Reduce the per-window predictions of each ECG to one label by majority vote.
final_preds = []
final_true = []

for ecg_id, window_preds in ecg_preds.items():
    # bincount gives the number of votes per label. argmax returns the first
    # maximum, so ties are always resolved towards the lowest label index
    # rather than depending on set iteration order.
    voted_pred = np.bincount(window_preds).argmax()

    final_preds.append(voted_pred)
    final_true.append(ecg_true[ecg_id])


In [None]:
# ECG-level metrics, computed on the voted labels (one per ECG id).
ecg_accuracy = accuracy_score(final_true, final_preds)
print("ECG-Level Accuracy:", ecg_accuracy)

print("\nECG-Level Classification Report:\n")
# The label set comes from the notebook-wide class count rather than a
# repeated literal.
print(classification_report(
    final_true,
    final_preds,
    labels=list(range(NUM_CLASSES)),
    digits=2
))

ECG-Level Accuracy: 0.495886804870023

ECG-Level Classification Report:

              precision    recall  f1-score   support

           0       0.42      0.62      0.50       396
           1       0.50      0.70      0.59       472
           2       0.53      0.52      0.53       322
           3       0.39      0.83      0.53        94
           4       0.78      0.51      0.62       748
           5       0.28      0.12      0.16       266
           6       0.49      0.23      0.31       282
           7       0.46      0.42      0.44       372
           8       0.23      0.54      0.32        87

    accuracy                           0.50      3039
   macro avg       0.45      0.50      0.45      3039
weighted avg       0.53      0.50      0.49      3039

