In [15]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

print("TensorFlow:", tf.__version__)
print("GPU:", tf.config.list_physical_devices("GPU"))


TensorFlow: 2.10.0
GPU: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


In [16]:
TRAIN_DIR = Path(r"D:\ECG_model\train")
VAL_DIR   = Path(r"D:\ECG_model\validation")
TEST_DIR  = Path(r"D:\ECG_model\test")

IMG_SIZE = (224, 224)
BATCH_SIZE = 16


In [17]:
train_set = tf.keras.utils.image_dataset_from_directory(
    TRAIN_DIR,
    label_mode="categorical",
    image_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    shuffle=True
)

val_set = tf.keras.utils.image_dataset_from_directory(
    VAL_DIR,
    label_mode="categorical",
    image_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    shuffle=True
)

test_set = tf.keras.utils.image_dataset_from_directory(
    TEST_DIR,
    label_mode="categorical",
    image_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    shuffle=False
)

NUM_CLASSES = train_set.element_spec[1].shape[-1]
CLASS_NAMES = train_set.class_names

print("Classes:", CLASS_NAMES)


Found 3528 files belonging to 7 classes.
Found 336 files belonging to 7 classes.
Found 144 files belonging to 7 classes.
Classes: ['Atrial_fibrillation(AF)', 'LBBB', 'Normal', 'PAC', 'PVC', 'RBBB', 'STD']


In [18]:
norm = tf.keras.layers.Rescaling(1./255)

train_ds = train_set.map(lambda x,y: (norm(x), y))
val_ds   = val_set.map(lambda x,y: (norm(x), y))
test_ds  = test_set.map(lambda x,y: (norm(x), y))

train_ds = train_ds.prefetch(tf.data.AUTOTUNE)
val_ds   = val_ds.prefetch(tf.data.AUTOTUNE)
test_ds  = test_ds.prefetch(tf.data.AUTOTUNE)


In [19]:
def focal_loss(gamma=2.0, alpha=0.25):
    def loss_fn(y_true, y_pred):
        y_pred = tf.clip_by_value(y_pred, 1e-7, 1 - 1e-7)
        cross_entropy = -y_true * tf.math.log(y_pred)
        weight = alpha * tf.pow(1 - y_pred, gamma)
        return tf.reduce_sum(weight * cross_entropy, axis=1)
    return loss_fn


In [None]:
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras.layers import Dense, Dropout, GlobalAveragePooling2D
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

base_model = EfficientNetB0(
    weights="imagenet",
    include_top=False,
    input_shape=(224,224,3)
)

for layer in base_model.layers:
    layer.trainable = True

x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(256, activation="relu")(x)
x = Dropout(0.4)(x)
outputs = Dense(NUM_CLASSES, activation="softmax")(x)

model = Model(base_model.input, outputs)

model.compile(
    optimizer=Adam(learning_rate=1e-6),
    loss=focal_loss(gamma=2.0, alpha=0.25),
    metrics=["accuracy"]
)

model.summary()


Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_2 (InputLayer)           [(None, 224, 224, 3  0           []                               
                                )]                                                                
                                                                                                  
 rescaling_4 (Rescaling)        (None, 224, 224, 3)  0           ['input_2[0][0]']                
                                                                                                  
 normalization_1 (Normalization  (None, 224, 224, 3)  7          ['rescaling_4[0][0]']            
 )                                                                                                
                                                                                            

In [21]:
callbacks = [
    tf.keras.callbacks.ModelCheckpoint(
        "efficientnet_final.h5",
        monitor="val_loss",
        save_best_only=True,
        verbose=1
    ),
    tf.keras.callbacks.EarlyStopping(
        monitor="val_loss",
        patience=15,
        restore_best_weights=True
    ),
    tf.keras.callbacks.ReduceLROnPlateau(
        monitor="val_loss",
        factor=0.5,
        patience=5,
        min_lr=1e-7,
        verbose=1
    )
]


In [None]:
_ = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=80,     # upper limit, early stopping will stop earlier
    callbacks=callbacks
)
print("Done")

Epoch 1/80
  1/221 [..............................] - ETA: 2:04 - loss: 0.3508 - accuracy: 0.1875

ResourceExhaustedError: Graph execution error:

Detected at node 'model_1/top_bn/FusedBatchNormV3' defined at (most recent call last):
    File "c:\Users\DELL\miniconda3\envs\tensorflow_env\lib\runpy.py", line 194, in _run_module_as_main
      return _run_code(code, main_globals, None,
    File "c:\Users\DELL\miniconda3\envs\tensorflow_env\lib\runpy.py", line 87, in _run_code
      exec(code, run_globals)
    File "c:\Users\DELL\miniconda3\envs\tensorflow_env\lib\site-packages\ipykernel_launcher.py", line 18, in <module>
      app.launch_new_instance()
    File "c:\Users\DELL\miniconda3\envs\tensorflow_env\lib\site-packages\traitlets\config\application.py", line 1075, in launch_instance
      app.start()
    File "c:\Users\DELL\miniconda3\envs\tensorflow_env\lib\site-packages\ipykernel\kernelapp.py", line 739, in start
      self.io_loop.start()
    File "c:\Users\DELL\miniconda3\envs\tensorflow_env\lib\site-packages\tornado\platform\asyncio.py", line 205, in start
      self.asyncio_loop.run_forever()
    File "c:\Users\DELL\miniconda3\envs\tensorflow_env\lib\asyncio\base_events.py", line 570, in run_forever
      self._run_once()
    File "c:\Users\DELL\miniconda3\envs\tensorflow_env\lib\asyncio\base_events.py", line 1859, in _run_once
      handle._run()
    File "c:\Users\DELL\miniconda3\envs\tensorflow_env\lib\asyncio\events.py", line 81, in _run
      self._context.run(self._callback, *self._args)
    File "c:\Users\DELL\miniconda3\envs\tensorflow_env\lib\site-packages\ipykernel\kernelbase.py", line 545, in dispatch_queue
      await self.process_one()
    File "c:\Users\DELL\miniconda3\envs\tensorflow_env\lib\site-packages\ipykernel\kernelbase.py", line 534, in process_one
      await dispatch(*args)
    File "c:\Users\DELL\miniconda3\envs\tensorflow_env\lib\site-packages\ipykernel\kernelbase.py", line 437, in dispatch_shell
      await result
    File "c:\Users\DELL\miniconda3\envs\tensorflow_env\lib\site-packages\ipykernel\ipkernel.py", line 362, in execute_request
      await super().execute_request(stream, ident, parent)
    File "c:\Users\DELL\miniconda3\envs\tensorflow_env\lib\site-packages\ipykernel\kernelbase.py", line 778, in execute_request
      reply_content = await reply_content
    File "c:\Users\DELL\miniconda3\envs\tensorflow_env\lib\site-packages\ipykernel\ipkernel.py", line 449, in do_execute
      res = shell.run_cell(
    File "c:\Users\DELL\miniconda3\envs\tensorflow_env\lib\site-packages\ipykernel\zmqshell.py", line 549, in run_cell
      return super().run_cell(*args, **kwargs)
    File "c:\Users\DELL\miniconda3\envs\tensorflow_env\lib\site-packages\IPython\core\interactiveshell.py", line 3009, in run_cell
      result = self._run_cell(
    File "c:\Users\DELL\miniconda3\envs\tensorflow_env\lib\site-packages\IPython\core\interactiveshell.py", line 3064, in _run_cell
      result = runner(coro)
    File "c:\Users\DELL\miniconda3\envs\tensorflow_env\lib\site-packages\IPython\core\async_helpers.py", line 129, in _pseudo_sync_runner
      coro.send(None)
    File "c:\Users\DELL\miniconda3\envs\tensorflow_env\lib\site-packages\IPython\core\interactiveshell.py", line 3269, in run_cell_async
      has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
    File "c:\Users\DELL\miniconda3\envs\tensorflow_env\lib\site-packages\IPython\core\interactiveshell.py", line 3448, in run_ast_nodes
      if await self.run_code(code, result, async_=asy):
    File "c:\Users\DELL\miniconda3\envs\tensorflow_env\lib\site-packages\IPython\core\interactiveshell.py", line 3508, in run_code
      exec(code_obj, self.user_global_ns, self.user_ns)
    File "C:\Users\DELL\AppData\Local\Temp\ipykernel_21596\2891836108.py", line 1, in <module>
      _ = model.fit(
    File "c:\Users\DELL\miniconda3\envs\tensorflow_env\lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "c:\Users\DELL\miniconda3\envs\tensorflow_env\lib\site-packages\keras\engine\training.py", line 1564, in fit
      tmp_logs = self.train_function(iterator)
    File "c:\Users\DELL\miniconda3\envs\tensorflow_env\lib\site-packages\keras\engine\training.py", line 1160, in train_function
      return step_function(self, iterator)
    File "c:\Users\DELL\miniconda3\envs\tensorflow_env\lib\site-packages\keras\engine\training.py", line 1146, in step_function
      outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "c:\Users\DELL\miniconda3\envs\tensorflow_env\lib\site-packages\keras\engine\training.py", line 1135, in run_step
      outputs = model.train_step(data)
    File "c:\Users\DELL\miniconda3\envs\tensorflow_env\lib\site-packages\keras\engine\training.py", line 993, in train_step
      y_pred = self(x, training=True)
    File "c:\Users\DELL\miniconda3\envs\tensorflow_env\lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "c:\Users\DELL\miniconda3\envs\tensorflow_env\lib\site-packages\keras\engine\training.py", line 557, in __call__
      return super().__call__(*args, **kwargs)
    File "c:\Users\DELL\miniconda3\envs\tensorflow_env\lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "c:\Users\DELL\miniconda3\envs\tensorflow_env\lib\site-packages\keras\engine\base_layer.py", line 1097, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "c:\Users\DELL\miniconda3\envs\tensorflow_env\lib\site-packages\keras\utils\traceback_utils.py", line 96, in error_handler
      return fn(*args, **kwargs)
    File "c:\Users\DELL\miniconda3\envs\tensorflow_env\lib\site-packages\keras\engine\functional.py", line 510, in call
      return self._run_internal_graph(inputs, training=training, mask=mask)
    File "c:\Users\DELL\miniconda3\envs\tensorflow_env\lib\site-packages\keras\engine\functional.py", line 667, in _run_internal_graph
      outputs = node.layer(*args, **kwargs)
    File "c:\Users\DELL\miniconda3\envs\tensorflow_env\lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "c:\Users\DELL\miniconda3\envs\tensorflow_env\lib\site-packages\keras\engine\base_layer.py", line 1097, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "c:\Users\DELL\miniconda3\envs\tensorflow_env\lib\site-packages\keras\utils\traceback_utils.py", line 96, in error_handler
      return fn(*args, **kwargs)
    File "c:\Users\DELL\miniconda3\envs\tensorflow_env\lib\site-packages\keras\layers\normalization\batch_normalization.py", line 850, in call
      outputs = self._fused_batch_norm(inputs, training=training)
    File "c:\Users\DELL\miniconda3\envs\tensorflow_env\lib\site-packages\keras\layers\normalization\batch_normalization.py", line 660, in _fused_batch_norm
      output, mean, variance = control_flow_util.smart_cond(
    File "c:\Users\DELL\miniconda3\envs\tensorflow_env\lib\site-packages\keras\utils\control_flow_util.py", line 108, in smart_cond
      return tf.__internal__.smart_cond.smart_cond(
    File "c:\Users\DELL\miniconda3\envs\tensorflow_env\lib\site-packages\keras\layers\normalization\batch_normalization.py", line 634, in _fused_batch_norm_training
      return tf.compat.v1.nn.fused_batch_norm(
Node: 'model_1/top_bn/FusedBatchNormV3'
OOM when allocating tensor with shape[16,1280,7,7] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc
	 [[{{node model_1/top_bn/FusedBatchNormV3}}]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.
 [Op:__inference_train_function_36961]