# A Tutorial on Model Training in Python with Keras

In [1]:
# cell 1: 永久性路径配置
import sys
from pathlib import Path

def add_project_root():
    """add python project root to sys.path"""
    notebook_path = Path().absolute()  # Get the absolute path of the notebook
    project_root = notebook_path.parent  # Go up two levels to the project root

    if str(project_root) not in sys.path:
        sys.path.insert(0, str(project_root))  # search project root first
        print(f"✅ project root added: {project_root}")
    else:
        print("⏩ project root already exists")
    
    return project_root

project_root = add_project_root()


✅ project root added: e:\Data Files\Projects\Embedded Systems\f103_demo_vscode


In [2]:
from data_processing import create_windows
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from model import MyCNNModel
import keras


## 1. Pre-process data

In [3]:
import yaml
with open((project_root / 'config.yaml'), 'r') as file:
    config = yaml.safe_load(file)
# Load data
ds_config = config['datasets'][config['active_dataset']]
print(f"Loading dataset from {ds_config['path']}")
data = pd.read_csv(project_root / ds_config['path'])
data_columns = ds_config['features']
norm_windows, labels, _ = create_windows(
    data,
    window_length=config['data_processing']['window']['length'],
    step_size=config['data_processing']['window']['step_size'],
    overlap=None,
    feature_columns=data_columns,
    label_column=ds_config['label_column'],
)
print(norm_windows[0])
print(norm_windows.shape)

Loading dataset from ./data/mafaulda19.csv
[[0.1268389  0.50210248 0.42535302 0.49732183 0.45425745 0.51582176
  0.34823428]
 [0.12390014 0.79841099 0.56313183 0.57897221 0.47493213 0.53154488
  0.36121152]
 [0.13653318 0.38906528 0.36739579 0.56287525 0.45272538 0.52438231
  0.32922548]
 [0.11404636 0.76992723 0.61350561 0.61937753 0.48076297 0.53406649
  0.34672639]
 [0.14071314 0.49854315 0.33360714 0.56348694 0.46513774 0.55279058
  0.32424605]
 [0.13252692 0.65480122 0.58415628 0.52554504 0.47726446 0.59921238
  0.3411011 ]
 [0.11675141 0.64901484 0.39020151 0.42395862 0.46699559 0.58795604
  0.34841398]
 [0.12332849 0.53862245 0.50149669 0.36988451 0.46345459 0.59455294
  0.36118287]
 [0.12170199 0.77000319 0.48721797 0.34031024 0.46977881 0.59797623
  0.38057138]
 [0.13794176 0.44185517 0.39682818 0.30800992 0.45284342 0.59034973
  0.37364967]
 [0.13480719 0.81156668 0.55976619 0.38185251 0.47349449 0.60342638
  0.40327999]
 [0.13193159 0.47252951 0.31846178 0.40229121 0.4571634

## 2. split train, validation and test data

In [4]:
X_train, X_test, y_train, y_test = train_test_split(
    norm_windows, labels, test_size=0.2, random_state=42
)
y_train = keras.utils.to_categorical(y_train, num_classes=len(np.unique(labels)))
y_test = keras.utils.to_categorical(y_test, num_classes=len(np.unique(labels)))

print(X_train.shape, y_train.shape)
model_input_shape = (X_train.shape[1], X_train.shape[2]) # (window_length, num_features)
num_classes = y_train.shape[1]  # Number of classes for one-hot encoding
print(y_train)
print(num_classes)

(71898, 128, 7) (71898, 6)
[[0. 0. 0. 0. 0. 1.]
 [0. 0. 1. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0.]
 ...
 [0. 0. 0. 0. 0. 1.]
 [1. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0.]]
6


In [5]:
# Model initialization
tf_model = MyCNNModel(input_shape=model_input_shape,
                      num_classes=num_classes)
# Model training
training_config = config['model']['training']
evaluation_config = config['model']['evaluation']
# tf_model.compile(
#     optimizer=keras.optimizers.Adam(learning_rate=training_config['learning_rate']),
#     loss=training_config['loss'],
#     metrics=training_config['metrics'],
#     # learning_rate=training_config['learning_rate'],
#     # batch_size=training_config['batch_size'],
# )

# tf_model.fit(
#     X_train,
#     y_train,
#     validation_split=training_config['validation_split'],
#     epochs=training_config['epochs'],
#     batch_size=training_config['batch_size'],
#     callbacks=[
#         keras.callbacks.EarlyStopping(
#             monitor='val_loss',
#             patience=training_config['early_stopping_patience'],
#             restore_best_weights=True
#         ),
#         keras.callbacks.ModelCheckpoint(
#             filepath=str(project_root / 'model.h5'),
#             save_best_only=True,
#             monitor='val_loss'
#         )
#     ]
# )

tf_model.compile_model(learning_rate=training_config['learning_rate'],
                       loss=training_config['loss'],
                       metrics=training_config['metrics'])
history = tf_model.train(
                        X_train,
                        y_train,
                        validation_split=training_config['validation_split'],
                        epochs=training_config['epochs'],
                        batch_size=training_config['batch_size'],
                        callbacks=[
                            keras.callbacks.EarlyStopping(
                                monitor='val_loss',
                                patience=training_config['early_stopping_patience'],
                                restore_best_weights=True
                            ),
                            keras.callbacks.ModelCheckpoint(
                                filepath=str(project_root / 'model.h5'),
                                save_best_only=True,
                                monitor='val_loss'
                            )
                        ]
                        )
print(history)

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/50
[1m19/20[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 81ms/step - accuracy: 0.2389 - auc: 0.5823 - f1_score: 0.2004 - loss: 1.7542 - precision: 0.3039 - recall: 0.0029  



[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 110ms/step - accuracy: 0.2464 - auc: 0.5892 - f1_score: 0.2089 - loss: 1.7468 - precision: 0.3443 - recall: 0.0044 - val_accuracy: 0.3770 - val_auc: 0.7581 - val_f1_score: 0.3159 - val_loss: 1.4907 - val_precision: 0.8769 - val_recall: 0.0728
Epoch 2/50
[1m19/20[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 80ms/step - accuracy: 0.4487 - auc: 0.8019 - f1_score: 0.4189 - loss: 1.3926 - precision: 0.8487 - recall: 0.1276



[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 87ms/step - accuracy: 0.4531 - auc: 0.8052 - f1_score: 0.4244 - loss: 1.3841 - precision: 0.8524 - recall: 0.1309 - val_accuracy: 0.5974 - val_auc: 0.8977 - val_f1_score: 0.5988 - val_loss: 1.0919 - val_precision: 0.9240 - val_recall: 0.2469
Epoch 3/50
[1m19/20[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 83ms/step - accuracy: 0.6183 - auc: 0.9092 - f1_score: 0.6209 - loss: 1.0252 - precision: 0.9140 - recall: 0.3171



[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 90ms/step - accuracy: 0.6207 - auc: 0.9102 - f1_score: 0.6233 - loss: 1.0193 - precision: 0.9136 - recall: 0.3217 - val_accuracy: 0.7067 - val_auc: 0.9406 - val_f1_score: 0.7146 - val_loss: 0.8291 - val_precision: 0.9170 - val_recall: 0.4427
Epoch 4/50
[1m19/20[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 82ms/step - accuracy: 0.7086 - auc: 0.9438 - f1_score: 0.7131 - loss: 0.8001 - precision: 0.9037 - recall: 0.4772



[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 89ms/step - accuracy: 0.7101 - auc: 0.9444 - f1_score: 0.7145 - loss: 0.7966 - precision: 0.9034 - recall: 0.4799 - val_accuracy: 0.7720 - val_auc: 0.9628 - val_f1_score: 0.7757 - val_loss: 0.6694 - val_precision: 0.9096 - val_recall: 0.5812
Epoch 5/50
[1m19/20[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 84ms/step - accuracy: 0.7761 - auc: 0.9644 - f1_score: 0.7791 - loss: 0.6516 - precision: 0.9071 - recall: 0.6007



[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 92ms/step - accuracy: 0.7771 - auc: 0.9647 - f1_score: 0.7800 - loss: 0.6490 - precision: 0.9072 - recall: 0.6028 - val_accuracy: 0.8187 - val_auc: 0.9753 - val_f1_score: 0.8179 - val_loss: 0.5571 - val_precision: 0.9183 - val_recall: 0.6793
Epoch 6/50
[1m19/20[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 84ms/step - accuracy: 0.8212 - auc: 0.9766 - f1_score: 0.8227 - loss: 0.5455 - precision: 0.9118 - recall: 0.6914



[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 91ms/step - accuracy: 0.8217 - auc: 0.9767 - f1_score: 0.8233 - loss: 0.5435 - precision: 0.9117 - recall: 0.6931 - val_accuracy: 0.8096 - val_auc: 0.9769 - val_f1_score: 0.8062 - val_loss: 0.5145 - val_precision: 0.8783 - val_recall: 0.7327
Epoch 7/50
[1m19/20[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 83ms/step - accuracy: 0.8241 - auc: 0.9796 - f1_score: 0.8249 - loss: 0.4888 - precision: 0.8910 - recall: 0.7476



[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 91ms/step - accuracy: 0.8251 - auc: 0.9798 - f1_score: 0.8260 - loss: 0.4870 - precision: 0.8916 - recall: 0.7485 - val_accuracy: 0.8421 - val_auc: 0.9825 - val_f1_score: 0.8436 - val_loss: 0.4418 - val_precision: 0.8872 - val_recall: 0.7796
Epoch 8/50
[1m19/20[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 84ms/step - accuracy: 0.8511 - auc: 0.9840 - f1_score: 0.8522 - loss: 0.4305 - precision: 0.9022 - recall: 0.7816



[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 91ms/step - accuracy: 0.8519 - auc: 0.9841 - f1_score: 0.8530 - loss: 0.4297 - precision: 0.9028 - recall: 0.7825 - val_accuracy: 0.8872 - val_auc: 0.9897 - val_f1_score: 0.8865 - val_loss: 0.3770 - val_precision: 0.9402 - val_recall: 0.8159
Epoch 9/50
[1m19/20[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 84ms/step - accuracy: 0.8815 - auc: 0.9888 - f1_score: 0.8821 - loss: 0.3789 - precision: 0.9266 - recall: 0.8191



[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 92ms/step - accuracy: 0.8818 - auc: 0.9888 - f1_score: 0.8824 - loss: 0.3783 - precision: 0.9267 - recall: 0.8196 - val_accuracy: 0.9035 - val_auc: 0.9915 - val_f1_score: 0.9028 - val_loss: 0.3401 - val_precision: 0.9439 - val_recall: 0.8411
Epoch 10/50
[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 90ms/step - accuracy: 0.9003 - auc: 0.9910 - f1_score: 0.9008 - loss: 0.3400 - precision: 0.9345 - recall: 0.8459 - val_accuracy: 0.8836 - val_auc: 0.9890 - val_f1_score: 0.8846 - val_loss: 0.3489 - val_precision: 0.9165 - val_recall: 0.8503
Epoch 11/50
[1m19/20[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 84ms/step - accuracy: 0.8933 - auc: 0.9903 - f1_score: 0.8938 - loss: 0.3364 - precision: 0.9252 - recall: 0.8515



[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 92ms/step - accuracy: 0.8941 - auc: 0.9904 - f1_score: 0.8947 - loss: 0.3350 - precision: 0.9258 - recall: 0.8523 - val_accuracy: 0.9203 - val_auc: 0.9935 - val_f1_score: 0.9196 - val_loss: 0.2927 - val_precision: 0.9479 - val_recall: 0.8752
Epoch 12/50
[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 91ms/step - accuracy: 0.9182 - auc: 0.9929 - f1_score: 0.9184 - loss: 0.2952 - precision: 0.9436 - recall: 0.8791 - val_accuracy: 0.8892 - val_auc: 0.9914 - val_f1_score: 0.8896 - val_loss: 0.3021 - val_precision: 0.9140 - val_recall: 0.8586
Epoch 13/50
[1m19/20[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 85ms/step - accuracy: 0.9138 - auc: 0.9930 - f1_score: 0.9143 - loss: 0.2862 - precision: 0.9373 - recall: 0.8829



[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 94ms/step - accuracy: 0.9146 - auc: 0.9931 - f1_score: 0.9152 - loss: 0.2849 - precision: 0.9380 - recall: 0.8839 - val_accuracy: 0.9026 - val_auc: 0.9925 - val_f1_score: 0.9008 - val_loss: 0.2831 - val_precision: 0.9237 - val_recall: 0.8825
Epoch 14/50
[1m19/20[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 84ms/step - accuracy: 0.9163 - auc: 0.9932 - f1_score: 0.9166 - loss: 0.2729 - precision: 0.9347 - recall: 0.8887



[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 91ms/step - accuracy: 0.9168 - auc: 0.9933 - f1_score: 0.9172 - loss: 0.2721 - precision: 0.9353 - recall: 0.8893 - val_accuracy: 0.9166 - val_auc: 0.9940 - val_f1_score: 0.9150 - val_loss: 0.2586 - val_precision: 0.9390 - val_recall: 0.8965
Epoch 15/50
[1m19/20[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 84ms/step - accuracy: 0.9353 - auc: 0.9952 - f1_score: 0.9353 - loss: 0.2383 - precision: 0.9530 - recall: 0.9112



[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 91ms/step - accuracy: 0.9358 - auc: 0.9952 - f1_score: 0.9358 - loss: 0.2378 - precision: 0.9534 - recall: 0.9118 - val_accuracy: 0.9278 - val_auc: 0.9947 - val_f1_score: 0.9279 - val_loss: 0.2429 - val_precision: 0.9428 - val_recall: 0.9083
Epoch 16/50
[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 91ms/step - accuracy: 0.9411 - auc: 0.9955 - f1_score: 0.9413 - loss: 0.2274 - precision: 0.9556 - recall: 0.9204 - val_accuracy: 0.9083 - val_auc: 0.9933 - val_f1_score: 0.9083 - val_loss: 0.2587 - val_precision: 0.9282 - val_recall: 0.8877
Epoch 17/50
[1m19/20[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 84ms/step - accuracy: 0.9304 - auc: 0.9951 - f1_score: 0.9307 - loss: 0.2295 - precision: 0.9468 - recall: 0.9135



[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 92ms/step - accuracy: 0.9310 - auc: 0.9951 - f1_score: 0.9313 - loss: 0.2285 - precision: 0.9472 - recall: 0.9141 - val_accuracy: 0.9356 - val_auc: 0.9956 - val_f1_score: 0.9352 - val_loss: 0.2173 - val_precision: 0.9497 - val_recall: 0.9130
Epoch 18/50
[1m19/20[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 86ms/step - accuracy: 0.9407 - auc: 0.9959 - f1_score: 0.9408 - loss: 0.2102 - precision: 0.9532 - recall: 0.9221



[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 93ms/step - accuracy: 0.9412 - auc: 0.9959 - f1_score: 0.9413 - loss: 0.2091 - precision: 0.9537 - recall: 0.9229 - val_accuracy: 0.9526 - val_auc: 0.9968 - val_f1_score: 0.9523 - val_loss: 0.1858 - val_precision: 0.9622 - val_recall: 0.9396
Epoch 19/50
[1m19/20[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 86ms/step - accuracy: 0.9553 - auc: 0.9972 - f1_score: 0.9554 - loss: 0.1810 - precision: 0.9658 - recall: 0.9419



[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 94ms/step - accuracy: 0.9554 - auc: 0.9972 - f1_score: 0.9555 - loss: 0.1809 - precision: 0.9659 - recall: 0.9421 - val_accuracy: 0.9533 - val_auc: 0.9973 - val_f1_score: 0.9526 - val_loss: 0.1761 - val_precision: 0.9655 - val_recall: 0.9394
Epoch 20/50
[1m19/20[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 84ms/step - accuracy: 0.9560 - auc: 0.9974 - f1_score: 0.9559 - loss: 0.1716 - precision: 0.9663 - recall: 0.9437



[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 92ms/step - accuracy: 0.9558 - auc: 0.9974 - f1_score: 0.9557 - loss: 0.1718 - precision: 0.9661 - recall: 0.9435 - val_accuracy: 0.9653 - val_auc: 0.9979 - val_f1_score: 0.9650 - val_loss: 0.1563 - val_precision: 0.9739 - val_recall: 0.9546
Epoch 21/50
[1m19/20[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 85ms/step - accuracy: 0.9566 - auc: 0.9975 - f1_score: 0.9565 - loss: 0.1661 - precision: 0.9660 - recall: 0.9442



[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 92ms/step - accuracy: 0.9568 - auc: 0.9975 - f1_score: 0.9567 - loss: 0.1657 - precision: 0.9661 - recall: 0.9445 - val_accuracy: 0.9620 - val_auc: 0.9981 - val_f1_score: 0.9616 - val_loss: 0.1523 - val_precision: 0.9717 - val_recall: 0.9506
Epoch 22/50
[1m19/20[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 86ms/step - accuracy: 0.9646 - auc: 0.9979 - f1_score: 0.9645 - loss: 0.1509 - precision: 0.9724 - recall: 0.9541



[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 93ms/step - accuracy: 0.9648 - auc: 0.9979 - f1_score: 0.9648 - loss: 0.1503 - precision: 0.9726 - recall: 0.9544 - val_accuracy: 0.9706 - val_auc: 0.9985 - val_f1_score: 0.9703 - val_loss: 0.1335 - val_precision: 0.9774 - val_recall: 0.9614
Epoch 23/50
[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 91ms/step - accuracy: 0.9700 - auc: 0.9986 - f1_score: 0.9700 - loss: 0.1321 - precision: 0.9767 - recall: 0.9615 - val_accuracy: 0.9532 - val_auc: 0.9979 - val_f1_score: 0.9532 - val_loss: 0.1483 - val_precision: 0.9610 - val_recall: 0.9453
Epoch 24/50
[1m19/20[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 84ms/step - accuracy: 0.9474 - auc: 0.9975 - f1_score: 0.9476 - loss: 0.1597 - precision: 0.9557 - recall: 0.9382



[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 92ms/step - accuracy: 0.9470 - auc: 0.9975 - f1_score: 0.9471 - loss: 0.1604 - precision: 0.9553 - recall: 0.9377 - val_accuracy: 0.9680 - val_auc: 0.9987 - val_f1_score: 0.9675 - val_loss: 0.1263 - val_precision: 0.9752 - val_recall: 0.9592
Epoch 25/50
[1m19/20[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 87ms/step - accuracy: 0.9662 - auc: 0.9985 - f1_score: 0.9662 - loss: 0.1316 - precision: 0.9724 - recall: 0.9579



[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 95ms/step - accuracy: 0.9666 - auc: 0.9985 - f1_score: 0.9666 - loss: 0.1310 - precision: 0.9728 - recall: 0.9584 - val_accuracy: 0.9703 - val_auc: 0.9987 - val_f1_score: 0.9699 - val_loss: 0.1211 - val_precision: 0.9760 - val_recall: 0.9624
Epoch 26/50
[1m19/20[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 85ms/step - accuracy: 0.9721 - auc: 0.9989 - f1_score: 0.9722 - loss: 0.1162 - precision: 0.9775 - recall: 0.9652



[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 93ms/step - accuracy: 0.9723 - auc: 0.9989 - f1_score: 0.9723 - loss: 0.1159 - precision: 0.9776 - recall: 0.9654 - val_accuracy: 0.9732 - val_auc: 0.9988 - val_f1_score: 0.9730 - val_loss: 0.1130 - val_precision: 0.9773 - val_recall: 0.9672
Epoch 27/50
[1m19/20[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 84ms/step - accuracy: 0.9774 - auc: 0.9988 - f1_score: 0.9774 - loss: 0.1108 - precision: 0.9810 - recall: 0.9715



[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 92ms/step - accuracy: 0.9774 - auc: 0.9988 - f1_score: 0.9775 - loss: 0.1103 - precision: 0.9811 - recall: 0.9716 - val_accuracy: 0.9760 - val_auc: 0.9990 - val_f1_score: 0.9758 - val_loss: 0.1026 - val_precision: 0.9806 - val_recall: 0.9709
Epoch 28/50
[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 91ms/step - accuracy: 0.9792 - auc: 0.9991 - f1_score: 0.9792 - loss: 0.0993 - precision: 0.9831 - recall: 0.9738 - val_accuracy: 0.9627 - val_auc: 0.9987 - val_f1_score: 0.9623 - val_loss: 0.1150 - val_precision: 0.9684 - val_recall: 0.9566
Epoch 29/50
[1m19/20[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 86ms/step - accuracy: 0.9769 - auc: 0.9991 - f1_score: 0.9768 - loss: 0.0996 - precision: 0.9818 - recall: 0.9720



[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 94ms/step - accuracy: 0.9769 - auc: 0.9991 - f1_score: 0.9769 - loss: 0.0995 - precision: 0.9818 - recall: 0.9721 - val_accuracy: 0.9823 - val_auc: 0.9992 - val_f1_score: 0.9820 - val_loss: 0.0902 - val_precision: 0.9856 - val_recall: 0.9770
Epoch 30/50
[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 91ms/step - accuracy: 0.9786 - auc: 0.9992 - f1_score: 0.9786 - loss: 0.0929 - precision: 0.9824 - recall: 0.9741 - val_accuracy: 0.9745 - val_auc: 0.9990 - val_f1_score: 0.9742 - val_loss: 0.0962 - val_precision: 0.9798 - val_recall: 0.9702
Epoch 31/50
[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 92ms/step - accuracy: 0.9809 - auc: 0.9992 - f1_score: 0.9811 - loss: 0.0894 - precision: 0.9839 - recall: 0.9771 - val_accuracy: 0.9691 - val_auc: 0.9989 - val_f1_score: 0.9688 - val_loss: 0.0998 - val_precision: 0.9726 - val_recall: 0.9655
Epoch 32/50
[1m19/20[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37



[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 93ms/step - accuracy: 0.9781 - auc: 0.9992 - f1_score: 0.9781 - loss: 0.0883 - precision: 0.9814 - recall: 0.9744 - val_accuracy: 0.9819 - val_auc: 0.9994 - val_f1_score: 0.9817 - val_loss: 0.0785 - val_precision: 0.9854 - val_recall: 0.9793
Epoch 33/50
[1m19/20[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 87ms/step - accuracy: 0.9854 - auc: 0.9994 - f1_score: 0.9855 - loss: 0.0753 - precision: 0.9873 - recall: 0.9821



[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 94ms/step - accuracy: 0.9854 - auc: 0.9994 - f1_score: 0.9855 - loss: 0.0753 - precision: 0.9874 - recall: 0.9821 - val_accuracy: 0.9820 - val_auc: 0.9994 - val_f1_score: 0.9818 - val_loss: 0.0760 - val_precision: 0.9856 - val_recall: 0.9785
Epoch 34/50
[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 90ms/step - accuracy: 0.9823 - auc: 0.9993 - f1_score: 0.9823 - loss: 0.0790 - precision: 0.9856 - recall: 0.9792 - val_accuracy: 0.9809 - val_auc: 0.9993 - val_f1_score: 0.9805 - val_loss: 0.0773 - val_precision: 0.9838 - val_recall: 0.9776
Epoch 35/50
[1m19/20[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 85ms/step - accuracy: 0.9843 - auc: 0.9995 - f1_score: 0.9843 - loss: 0.0735 - precision: 0.9869 - recall: 0.9809



[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 92ms/step - accuracy: 0.9844 - auc: 0.9995 - f1_score: 0.9844 - loss: 0.0733 - precision: 0.9870 - recall: 0.9811 - val_accuracy: 0.9798 - val_auc: 0.9993 - val_f1_score: 0.9796 - val_loss: 0.0748 - val_precision: 0.9822 - val_recall: 0.9766
Epoch 36/50
[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 93ms/step - accuracy: 0.9863 - auc: 0.9995 - f1_score: 0.9863 - loss: 0.0663 - precision: 0.9885 - recall: 0.9846 - val_accuracy: 0.9794 - val_auc: 0.9993 - val_f1_score: 0.9792 - val_loss: 0.0763 - val_precision: 0.9817 - val_recall: 0.9766
Epoch 37/50
[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 92ms/step - accuracy: 0.9831 - auc: 0.9994 - f1_score: 0.9832 - loss: 0.0718 - precision: 0.9854 - recall: 0.9803 - val_accuracy: 0.9771 - val_auc: 0.9994 - val_f1_score: 0.9765 - val_loss: 0.0791 - val_precision: 0.9805 - val_recall: 0.9715
Epoch 38/50
[1m19/20[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37



[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 95ms/step - accuracy: 0.9853 - auc: 0.9995 - f1_score: 0.9853 - loss: 0.0647 - precision: 0.9872 - recall: 0.9830 - val_accuracy: 0.9896 - val_auc: 0.9996 - val_f1_score: 0.9894 - val_loss: 0.0569 - val_precision: 0.9908 - val_recall: 0.9866
Epoch 39/50
[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 93ms/step - accuracy: 0.9887 - auc: 0.9995 - f1_score: 0.9887 - loss: 0.0594 - precision: 0.9901 - recall: 0.9869 - val_accuracy: 0.9787 - val_auc: 0.9995 - val_f1_score: 0.9783 - val_loss: 0.0724 - val_precision: 0.9806 - val_recall: 0.9746
Epoch 40/50
[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 91ms/step - accuracy: 0.9719 - auc: 0.9991 - f1_score: 0.9720 - loss: 0.0867 - precision: 0.9745 - recall: 0.9688 - val_accuracy: 0.9202 - val_auc: 0.9957 - val_f1_score: 0.9208 - val_loss: 0.1929 - val_precision: 0.9246 - val_recall: 0.9168
Epoch 41/50
[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[3



[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 92ms/step - accuracy: 0.9859 - auc: 0.9996 - f1_score: 0.9859 - loss: 0.0605 - precision: 0.9873 - recall: 0.9838 - val_accuracy: 0.9896 - val_auc: 0.9997 - val_f1_score: 0.9895 - val_loss: 0.0519 - val_precision: 0.9909 - val_recall: 0.9877
Epoch 43/50
[1m19/20[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 86ms/step - accuracy: 0.9905 - auc: 0.9997 - f1_score: 0.9905 - loss: 0.0517 - precision: 0.9917 - recall: 0.9889



[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 93ms/step - accuracy: 0.9906 - auc: 0.9997 - f1_score: 0.9906 - loss: 0.0515 - precision: 0.9918 - recall: 0.9890 - val_accuracy: 0.9902 - val_auc: 0.9998 - val_f1_score: 0.9900 - val_loss: 0.0483 - val_precision: 0.9918 - val_recall: 0.9887
Epoch 44/50
[1m19/20[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 85ms/step - accuracy: 0.9920 - auc: 0.9997 - f1_score: 0.9920 - loss: 0.0466 - precision: 0.9928 - recall: 0.9909



[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 92ms/step - accuracy: 0.9919 - auc: 0.9997 - f1_score: 0.9919 - loss: 0.0467 - precision: 0.9927 - recall: 0.9908 - val_accuracy: 0.9926 - val_auc: 0.9998 - val_f1_score: 0.9925 - val_loss: 0.0424 - val_precision: 0.9934 - val_recall: 0.9914
Epoch 45/50
[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 92ms/step - accuracy: 0.9939 - auc: 0.9998 - f1_score: 0.9939 - loss: 0.0408 - precision: 0.9946 - recall: 0.9928 - val_accuracy: 0.9909 - val_auc: 0.9998 - val_f1_score: 0.9907 - val_loss: 0.0444 - val_precision: 0.9917 - val_recall: 0.9892
Epoch 46/50
[1m19/20[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 88ms/step - accuracy: 0.9931 - auc: 0.9998 - f1_score: 0.9932 - loss: 0.0428 - precision: 0.9940 - recall: 0.9921



[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 96ms/step - accuracy: 0.9932 - auc: 0.9998 - f1_score: 0.9932 - loss: 0.0427 - precision: 0.9941 - recall: 0.9922 - val_accuracy: 0.9933 - val_auc: 0.9998 - val_f1_score: 0.9932 - val_loss: 0.0414 - val_precision: 0.9942 - val_recall: 0.9928
Epoch 47/50
[1m19/20[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 87ms/step - accuracy: 0.9942 - auc: 0.9998 - f1_score: 0.9942 - loss: 0.0396 - precision: 0.9949 - recall: 0.9935



[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 94ms/step - accuracy: 0.9942 - auc: 0.9998 - f1_score: 0.9942 - loss: 0.0397 - precision: 0.9948 - recall: 0.9935 - val_accuracy: 0.9952 - val_auc: 0.9999 - val_f1_score: 0.9951 - val_loss: 0.0364 - val_precision: 0.9955 - val_recall: 0.9942
Epoch 48/50
[1m19/20[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 87ms/step - accuracy: 0.9933 - auc: 0.9997 - f1_score: 0.9933 - loss: 0.0402 - precision: 0.9941 - recall: 0.9927



[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 94ms/step - accuracy: 0.9934 - auc: 0.9997 - f1_score: 0.9934 - loss: 0.0401 - precision: 0.9941 - recall: 0.9927 - val_accuracy: 0.9944 - val_auc: 0.9999 - val_f1_score: 0.9943 - val_loss: 0.0348 - val_precision: 0.9948 - val_recall: 0.9932
Epoch 49/50
[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 91ms/step - accuracy: 0.9948 - auc: 0.9998 - f1_score: 0.9948 - loss: 0.0357 - precision: 0.9956 - recall: 0.9942 - val_accuracy: 0.9941 - val_auc: 0.9999 - val_f1_score: 0.9940 - val_loss: 0.0363 - val_precision: 0.9945 - val_recall: 0.9933
Epoch 50/50
[1m19/20[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 85ms/step - accuracy: 0.9953 - auc: 0.9998 - f1_score: 0.9953 - loss: 0.0350 - precision: 0.9958 - recall: 0.9949



[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 94ms/step - accuracy: 0.9954 - auc: 0.9998 - f1_score: 0.9954 - loss: 0.0348 - precision: 0.9958 - recall: 0.9950 - val_accuracy: 0.9944 - val_auc: 0.9999 - val_f1_score: 0.9943 - val_loss: 0.0337 - val_precision: 0.9950 - val_recall: 0.9935
<keras.src.callbacks.history.History object at 0x000002306A7E3D30>


## convert to tensorflow lite

In [12]:
import tensorflow as tf
# evaluate the model
history = tf_model.evaluate(X_test, y_test)
print(history)

# Save the model as SavedModel format
tf_model.export('saved_model_dir', save_format='tf')

# Convert to TFLite model
converter = tf.lite.TFLiteConverter.from_saved_model('saved_model_dir')
converter.optimizations = [tf.lite.Optimize.DEFAULT]

def representative_data_gen():
    for input_value in tf.data.Dataset.from_tensor_slices(X_test).batch(1).take(100):
        yield [tf.cast(input_value, tf.float32)]

# Optional: Provide a representative dataset to enable full integer quantization (better for STM32)
converter.representative_dataset = representative_data_gen
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8

tflite_model = converter.convert()

# Save as .tflite file
with open('model.tflite', 'wb') as f:
    f.write(tflite_model)

print("TFLite model is saved as: model.tflite")



[1m562/562[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 2ms/step - accuracy: 0.9959 - auc: 0.9998 - f1_score: 0.9959 - loss: 0.0350 - precision: 0.9967 - recall: 0.9956
[0.03541061282157898, 0.995326817035675, 0.9960452318191528, 0.9948261380195618, <tf.Tensor: shape=(6,), dtype=float32, numpy=
array([0.99449164, 0.99864036, 0.99284315, 0.9994994 , 0.9903515 ,
       0.99639106], dtype=float32)>, 0.999830961227417]
INFO:tensorflow:Assets written to: saved_model_dir\assets


INFO:tensorflow:Assets written to: saved_model_dir\assets


Saved artifact at 'saved_model_dir'. The following endpoints are available:

* Endpoint 'serve'
  args_0 (POSITIONAL_ONLY): TensorSpec(shape=(None, 128, 7), dtype=tf.float32, name=None)
Output Type:
  TensorSpec(shape=(None, 6), dtype=tf.float32, name=None)
Captures:
  2406968851296: TensorSpec(shape=(), dtype=tf.resource, name=None)
  2406968850768: TensorSpec(shape=(), dtype=tf.resource, name=None)
  2406968851472: TensorSpec(shape=(), dtype=tf.resource, name=None)
  2406981071120: TensorSpec(shape=(), dtype=tf.resource, name=None)
  2406968845840: TensorSpec(shape=(), dtype=tf.resource, name=None)
  2406981067072: TensorSpec(shape=(), dtype=tf.resource, name=None)
  2406981068480: TensorSpec(shape=(), dtype=tf.resource, name=None)
  2406981071824: TensorSpec(shape=(), dtype=tf.resource, name=None)
TFLite model is saved as: model.tflite


### record quantization parameters

In [14]:

# load the TFLite model and print input/output quantization parameters
interpreter = tf.lite.Interpreter(model_path='model.tflite')
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Print input quantization parameters
print("Input quantization params:")
print(" scale:", input_details[0]['quantization'][0])
print(" zero_point:", input_details[0]['quantization'][1])

# Print output quantization parameters
print("Output quantization params:")
print(" scale:", output_details[0]['quantization'][0])
print(" zero_point:", output_details[0]['quantization'][1])

# write the quantization to yaml file
quantization_params = {
    'input_scale': input_details[0]['quantization'][0],
    'input_zero_point': input_details[0]['quantization'][1],
    'output_scale': output_details[0]['quantization'][0],
    'output_zero_point': output_details[0]['quantization'][1]
}
with open('quantization_params.yaml', 'w') as f:
    yaml.dump(quantization_params, f)
print("Quantization parameters are saved as: quantization_params.yaml")


Input quantization params:
 scale: 0.003921568859368563
 zero_point: -128
Output quantization params:
 scale: 0.00390625
 zero_point: -128
Quantization parameters are saved as: quantization_params.yaml


## test model

In [28]:
from keras.utils import custom_object_scope

with custom_object_scope({'MyCNNModel': MyCNNModel}):
    model_test = keras.models.load_model(project_root / 'model.h5')
# Model evaluation
results = model_test.evaluate(X_test, y_test, batch_size=32)
print(f"Test loss: {results[0]}")
for name, value in zip(model_test.metrics_names, results):
    print(f"{name}: {value}")

TypeError: Unable to revive model from config. When overriding the `get_config()` method, make sure that the returned config contains all items used as arguments in the  constructor to <class 'model.MyCNNModel'>, which is the default behavior. You can override this default behavior by defining a `from_config(cls, config)` class method to specify how to create an instance of MyCNNModel from its config.

Received config={'trainable': True, 'dtype': {'module': 'keras', 'class_name': 'DTypePolicy', 'config': {'name': 'float32'}, 'registered_name': None}}

Error encountered during deserialization: MyCNNModel.__init__() got an unexpected keyword argument 'trainable'