Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Loading ability in deep learning modules for classification/regression/clustering #1374

Merged
merged 13 commits into from
May 25, 2024
28 changes: 28 additions & 0 deletions aeon/classification/deep_learning/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,31 @@ def save_last_model_to_file(self, file_path="./"):
None
"""
self.model_.save(file_path + self.last_file_name + ".keras")

def load_model(self, model_path, classes):
"""Load a pre-trained keras model instead of fitting.

When calling this function, all functionalities can be used
such as predict, predict_proba etc. with the loaded model.

Parameters
----------
model_path : str (path including model name and extension)
The directory where the model will be saved including the model
name with a ".keras" extension.
Example: model_path="path/to/file/best_model.keras"
classes : np.ndarray
The set of unique classes the pre-trained loaded model is trained
to predict during the classification task.

Returns
-------
None
"""
import tensorflow as tf

self.model_ = tf.keras.models.load_model(model_path)
self._is_fitted = True

self.classes_ = classes
self.n_classes_ = len(self.classes_)
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Unit tests for classifiers deep learning base class functionality."""

import gc
import os
import tempfile
import time

import pytest
Expand All @@ -16,7 +16,7 @@
class _DummyDeepClassifier(BaseDeepClassifier):
"""Dummy Deep Classifier for testing empty base deep class save utilities."""

def __init__(self, last_file_name):
def __init__(self, last_file_name="last_model"):
self.last_file_name = last_file_name
super().__init__(last_file_name=last_file_name)

Expand Down Expand Up @@ -59,21 +59,35 @@ def _fit(self, X, y):
)
def test_dummy_deep_classifier():
"""Test dummy deep classifier."""
last_file_name = str(time.time_ns())
with tempfile.TemporaryDirectory() as tmp:
import numpy as np

# create a dummy deep classifier
dummy_deep_clf = _DummyDeepClassifier(last_file_name=last_file_name)
last_file_name = str(time.time_ns())

# generate random data
X, y = make_example_2d_numpy()
# create a dummy deep classifier
dummy_deep_clf = _DummyDeepClassifier(last_file_name=last_file_name)

# test fit function on random data
dummy_deep_clf.fit(X=X, y=y)
# generate random data
X, y = make_example_2d_numpy()

# test save last model to file than delete it
dummy_deep_clf.save_last_model_to_file()
# test fit function on random data
dummy_deep_clf.fit(X=X, y=y)

os.remove("./" + last_file_name + ".keras")
# test save last model to file than delete it
dummy_deep_clf.save_last_model_to_file(file_path=tmp)

# test summary of model
assert dummy_deep_clf.summary() is not None
# create a new dummy deep classifier
dummy_deep_clf2 = _DummyDeepClassifier()

# load without fitting
dummy_deep_clf2.load_model(
model_path=tmp + last_file_name + ".keras", classes=np.unique(y)
)

# predict
ypred = dummy_deep_clf2.predict(X=X)

assert len(ypred) == len(y)

# test summary of model
assert dummy_deep_clf.summary() is not None
22 changes: 22 additions & 0 deletions aeon/regression/deep_learning/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,25 @@ def save_last_model_to_file(self, file_path="./"):
None
"""
self.model_.save(file_path + self.last_file_name + ".keras")

def load_model(self, model_path):
"""Load a pre-trained keras model instead of fitting.

When calling this function, all functionalities can be used
such as predict etc. with the loaded model.

Parameters
----------
model_path : str (path including model name and extension)
The directory where the model will be saved including the model
name with a ".keras" extension.
Example: model_path="path/to/file/best_model.keras"

Returns
-------
None
"""
import tensorflow as tf

self.model_ = tf.keras.models.load_model(model_path)
self._is_fitted = True
38 changes: 24 additions & 14 deletions aeon/regression/deep_learning/tests/test_deep_regressor_base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Unit tests for regressors deep learning base class functionality."""

import gc
import os
import tempfile
import time

import pytest
Expand All @@ -16,7 +16,7 @@
class _DummyDeepRegressor(BaseDeepRegressor):
"""Dummy Deep Regressor for testing empty base deep class save utilities."""

def __init__(self, last_file_name):
def __init__(self, last_file_name="last_model"):
self.last_file_name = last_file_name
super().__init__(last_file_name=last_file_name)

Expand Down Expand Up @@ -56,24 +56,34 @@ def _fit(self, X, y):
)
def test_dummy_deep_regressor():
"""Test for DummyRegressor."""
last_file_name = str(time.time_ns())
with tempfile.TemporaryDirectory() as tmp:
last_file_name = str(time.time_ns())

# create a dummy regressor
dummy_deep_rg = _DummyDeepRegressor(last_file_name=last_file_name)
# create a dummy regressor
dummy_deep_rg = _DummyDeepRegressor(last_file_name=last_file_name)

# generate random data
# generate random data

X, y = make_example_2d_numpy()
X, y = make_example_2d_numpy()

# test fit function on random data
dummy_deep_rg.fit(X=X, y=y)
# test fit function on random data
dummy_deep_rg.fit(X=X, y=y)

# test save last model to file than delete it
# test save last model to file than delete it

dummy_deep_rg.save_last_model_to_file()
dummy_deep_rg.save_last_model_to_file(file_path=tmp)

os.remove("./" + last_file_name + ".keras")
# create a new dummy deep classifier
dummy_deep_clf2 = _DummyDeepRegressor()

# test summary of model
# load without fitting
dummy_deep_clf2.load_model(model_path=tmp + last_file_name + ".keras")

assert dummy_deep_rg.summary() is not None
# predict
ypred = dummy_deep_clf2.predict(X=X)

assert len(ypred) == len(y)

# test summary of model

assert dummy_deep_rg.summary() is not None
67 changes: 59 additions & 8 deletions examples/networks/deep_learning.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@
"\n",
"In this demo, we cover the usage of the deep learning models for both TSC, TSCL and TSER.\n",
"\n",
"Soon we plan to include more tasks into the deep learning domain suchas Time Series Forecasting."
"Soon we plan to include more tasks into the deep learning domain suchas Time Series Forecasting.\n",
"\n",
"\n",
"For all figures used in this demo, we use the one provided by the [Deep Learning for Time Series Classification webpage](https://msd-irimas.github.io/pages/dl4tsc/) with this reference figure for all legends needed:\n",
"![legends](./img/legend.png)"
]
},
{
Expand Down Expand Up @@ -98,7 +102,7 @@
"#### InceptionTime\n",
"\n",
"The InceptionTime model is an ensemble of multiple (by default five) Inception models. Each Inception model is a Convolutional Neural Network made of six Inception modules as seen in the Figure below. Each Inception module is composed of multiple (by default three) convolution layer in parallel and a max-pooling operation as well.\n",
"![InceptionTime](./img/Inception.png)"
"![InceptionTime](./img/Inception-archi.png)"
]
},
{
Expand Down Expand Up @@ -147,7 +151,7 @@
"#### H-InceptionTime\n",
"\n",
"Just as InceptionTime, H-InceptionTime is an ensemble of multiple H-Inception models. The model can be seen in the figure below, where the additional custom filters are added in parallel to the first Inception module.\n",
"![H-InceptionTime](./img/H-Inception.png)"
"![H-InceptionTime](./img/H-Inception-archi.png)"
]
},
{
Expand Down Expand Up @@ -296,20 +300,56 @@
"source": [
"### 5. Saving and Loading Model\n",
"\n",
"The functionalities presented here can be applied for deep classification, regression and clustering in the same way.\n",
"\n",
"#### Saving model\n",
"\n",
"While training a deep learning model with `aeon`, a default callback is used to save the best model during training based on the training loss. This saved model is then loaded to use for evaluation and deleted once no longer used. This intermediate step is to avoid saving the model to file when the user is not in need of it later. This however can be changed by setting the `save_best_model` flag to True. The default name for the model is \"best_model\" with a `.keras` extension and will be saved in the default working directory. To change this, simply change the `file_path` and `best_file_name` parameters to your preference.\n",
"\n",
"The same settings can be used if the user needs to save the last model during training as well. This is done by simply setting the flag `save_last_model` to True. The name of this saved model can be set by changing the `last_file_name` parameter and the `file_path` as well.\n",
"\n",
"\n",
"### Loading Model\n",
"\n",
"After training and saving the model in a `.keras` extension, we can then use this saved model to load it in another instance of the classifier instead of training again from scratch, this is done using the `load_model` functionality of the base deep classifier.\n",
"\n",
"Here is an example on saving the best and last models of an FCN model:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 2,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"11/11 [==============================] - 0s 29ms/step\n",
"11/11 [==============================] - 0s 29ms/step\n",
"['1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1'\n",
" '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1'\n",
" '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1'\n",
" '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1'\n",
" '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1'\n",
" '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1'\n",
" '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1'\n",
" '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1'\n",
" '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1'\n",
" '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1']\n",
"['1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1'\n",
" '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1'\n",
" '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1'\n",
" '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1'\n",
" '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1'\n",
" '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1'\n",
" '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1'\n",
" '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1'\n",
" '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1'\n",
" '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1' '1']\n"
]
}
],
"source": [
"xtrain, ytrain = load_classification(name=\"ArrowHead\", split=\"train\")\n",
"xtest, ytest = load_classification(name=\"ArrowHead\", split=\"test\")\n",
Expand All @@ -323,10 +363,21 @@
" n_epochs=2,\n",
")\n",
"\n",
"# The following is commented to avoid CI doing the saving for no reason\n",
"fcn.fit(X=xtrain, y=ytrain)\n",
"ypred = fcn.predict(X=xtest)\n",
"\n",
"# Loading the model ento another FCN model\n",
"\n",
"fcn2 = FCNClassifier()\n",
"fcn2.load_model(model_path=\"./best_fcn.keras\", classes=np.unique(ytrain))\n",
"# The classes parameter is only needed in the case of classification\n",
"ypred2 = fcn2.predict(X=xtest)\n",
"\n",
"print(ypred)\n",
"print(ypred2)\n",
"\n",
"# fcn.fit(X=xtrain, y=ytrain)\n",
"# ypred = fcn.predict(X=xtest)"
"os.remove(\"./best_fcn.keras\")\n",
"os.remove(\"./last_fcn.keras\")"
]
},
{
Expand Down
Binary file added examples/networks/img/H-Inception-archi.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/networks/img/Inception-archi.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/networks/img/legend.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.