In [1]:
print('simpal')

simpal


In [9]:
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras import layers, models
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report


In [10]:
BASE_DIR = "ecg_images"     # already exists
CLASS_IDS = list(range(1, 10))  # class_1 to class_9
NUM_CLASSES = 9

IMG_SIZE = (224, 224)
BATCH_SIZE = 32
EPOCHS = 30


In [11]:
class_files = {}

for cls in CLASS_IDS:
    class_dir = os.path.join(BASE_DIR, f"class_{cls}")
    files = sorted(os.listdir(class_dir))   # FIRST images
    class_files[cls] = files

min_count = min(len(files) for files in class_files.values())
print("Using FIRST", min_count, "images per class")


Using FIRST 1229 images per class


In [12]:
image_paths = []
labels = []

for cls in CLASS_IDS:
    class_dir = os.path.join(BASE_DIR, f"class_{cls}")
    selected_files = class_files[cls][:min_count]

    for fname in selected_files:
        image_paths.append(os.path.join(class_dir, fname))
        labels.append(cls - 1)   # map 1–9 → 0–8

image_paths = np.array(image_paths)
labels = np.array(labels)


In [13]:
X_train, X_temp, y_train, y_temp = train_test_split(
    image_paths, labels,
    test_size=0.3,
    shuffle=False
)

X_val, X_test, y_val, y_test = train_test_split(
    X_temp, y_temp,
    test_size=0.5,
    shuffle=False
)


In [14]:
def load_image(path, label):
    img = tf.io.read_file(path)
    img = tf.image.decode_png(img, channels=3)
    img = tf.image.resize(img, IMG_SIZE)
    img = tf.keras.applications.efficientnet.preprocess_input(img)
    return img, label

train_ds = tf.data.Dataset.from_tensor_slices((X_train, y_train))
train_ds = train_ds.map(load_image).batch(BATCH_SIZE)

val_ds = tf.data.Dataset.from_tensor_slices((X_val, y_val))
val_ds = val_ds.map(load_image).batch(BATCH_SIZE)

test_ds = tf.data.Dataset.from_tensor_slices((X_test, y_test))
test_ds = test_ds.map(load_image).batch(BATCH_SIZE)


In [15]:
base_model = EfficientNetB0(
    include_top=False,
    weights="imagenet",
    input_shape=(224, 224, 3)
)

base_model.trainable = False   # observation run

inputs = layers.Input(shape=(224, 224, 3))
x = base_model(inputs, training=False)
x = layers.GlobalAveragePooling2D()(x)
x = layers.BatchNormalization()(x)
outputs = layers.Dense(NUM_CLASSES, activation="softmax")(x)

model = models.Model(inputs, outputs)


In [16]:
model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-4),
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"]
)

model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=EPOCHS,
    callbacks=[
        tf.keras.callbacks.EarlyStopping(
            monitor="val_loss",
            patience=5,
            restore_best_weights=True
        )
    ]
)


Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30


<keras.callbacks.History at 0x20f31dca940>

In [17]:
y_true = []
y_pred = []

for images, labels_batch in test_ds:
    preds = model.predict(images, verbose=0)
    y_pred.extend(np.argmax(preds, axis=1))
    y_true.extend(labels_batch.numpy())

print(classification_report(
    y_true,
    y_pred,
    labels=list(range(9)),
    digits=2
))


              precision    recall  f1-score   support

           0       0.00      0.00      0.00         0
           1       0.00      0.00      0.00         0
           2       0.00      0.00      0.00         0
           3       0.00      0.00      0.00         0
           4       0.00      0.00      0.00         0
           5       0.00      0.00      0.00         0
           6       0.00      0.00      0.00         0
           7       0.10      0.00      0.00       431
           8       0.60      0.02      0.05      1229

    accuracy                           0.02      1660
   macro avg       0.08      0.00      0.01      1660
weighted avg       0.47      0.02      0.03      1660



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
