-
Notifications
You must be signed in to change notification settings - Fork 89
/
_cnn.py
313 lines (269 loc) · 10.9 KB
/
_cnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
"""Time Convolutional Neural Network (CNN) for classification."""
__author__ = ["James-Large", "TonyBagnall", "hadifawaz1999"]
__all__ = ["CNNClassifier"]
import gc
import os
import time
from copy import deepcopy
from sklearn.utils import check_random_state
from aeon.classification.deep_learning.base import BaseDeepClassifier
from aeon.networks import CNNNetwork
class CNNClassifier(BaseDeepClassifier):
    """
    Time Convolutional Neural Network (CNN).

    Adapted from the implementation used in [1]_.

    Parameters
    ----------
    n_layers : int, default = 2
        The number of convolution layers in the network.
    kernel_size : int or list of int, default = 7
        Kernel size of convolution layers, if not a list, the same kernel size
        is used for all layers, len(list) should be n_layers.
    n_filters : int or list of int, default = None
        Number of filters for each convolution layer, if not a list, the same
        n_filters is used in all layers. If None, the ``CNNNetwork`` default is
        used (documented as [6, 12]).
    avg_pool_size : int or list of int, default = 3
        The size of the average pooling layer, if not a list, the same
        average pooling size is used for all convolution layers.
    activation : str or list of str, default = "sigmoid"
        Keras activation function used in the model for each layer, if not a list,
        the same activation is used for all layers. The dense output layer uses
        this activation; if a list is given, its last entry is used for the
        output layer.
    padding : str or list of str, default = 'valid'
        The method of padding in convolution layers, if not a list, the same padding
        is used for all convolution layers.
    strides : int or list of int, default = 1
        The strides of kernels in the convolution and pooling layers, if not a
        list, the same strides are used for all layers.
    dilation_rate : int or list of int, default = 1
        The dilation rate of the convolution layers, if not a list, the same dilation
        rate is used throughout the network.
    use_bias : bool or list of bool, default = True
        Condition on whether to use bias values for convolution layers,
        if not a list, the same condition is used for all layers. The dense
        output layer uses this value; if a list is given, its last entry is used
        for the output layer.
    random_state : int or None, default = None
        Seed to any needed random actions.
    n_epochs : int, default = 2000
        The number of epochs to train the model.
    batch_size : int, default = 16
        The number of samples per gradient update.
    verbose : boolean, default = False
        Whether to output extra information.
    loss : string, default = "mean_squared_error"
        Fit parameter for the keras model.
    optimizer : keras.optimizer, default = None
        If None, ``keras.optimizers.Adam()`` is used.
    metrics : list of strings, default = None
        If None, ``["accuracy"]`` is used.
    callbacks : keras.callbacks, default = None
        If None, a ``ModelCheckpoint`` callback is used to save the best model
        on training loss.
    file_path : str, default = "./"
        Prefix of the checkpoint file path. It is concatenated directly with the
        file name, so it should end with a path separator.
        Only used if checkpoint is used as callback.
    save_best_model : bool, default = False
        Whether to save the best model, if the modelcheckpoint callback is used by
        default, this condition, if True, will prevent the automatic deletion of the
        best saved model from file and the user can choose the file name.
    save_last_model : bool, default = False
        Whether to save the last model, last epoch trained, using the base class method
        save_last_model_to_file.
    best_file_name : str, default = "best_model"
        The name of the file of the best model, if save_best_model is set to False,
        this parameter is discarded.
    last_file_name : str, default = "last_model"
        The name of the file of the last model, if save_last_model is set to False,
        this parameter is discarded.

    Notes
    -----
    Adapted from the implementation from Fawaz et. al
    https://github.com/hfawaz/dl-4-tsc/blob/master/classifiers/cnn.py

    References
    ----------
    .. [1] Zhao et. al, Convolutional neural networks for time series classification,
    Journal of Systems Engineering and Electronics, 28(1):2017.

    Examples
    --------
    >>> from aeon.classification.deep_learning import CNNClassifier
    >>> from aeon.datasets import load_unit_test
    >>> X_train, y_train = load_unit_test(split="train")
    >>> X_test, y_test = load_unit_test(split="test")
    >>> cnn = CNNClassifier(n_epochs=20, batch_size=4) # doctest: +SKIP
    >>> cnn.fit(X_train, y_train) # doctest: +SKIP
    CNNClassifier(...)
    """

    def __init__(
        self,
        n_layers=2,
        kernel_size=7,
        n_filters=None,
        avg_pool_size=3,
        activation="sigmoid",
        padding="valid",
        strides=1,
        dilation_rate=1,
        n_epochs=2000,
        batch_size=16,
        callbacks=None,
        file_path="./",
        save_best_model=False,
        save_last_model=False,
        best_file_name="best_model",
        last_file_name="last_model",
        verbose=False,
        loss="mean_squared_error",
        metrics=None,
        random_state=None,
        use_bias=True,
        optimizer=None,
    ):
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.n_filters = n_filters
        self.padding = padding
        self.strides = strides
        self.dilation_rate = dilation_rate
        self.avg_pool_size = avg_pool_size
        self.activation = activation
        self.use_bias = use_bias
        self.n_epochs = n_epochs
        self.callbacks = callbacks
        self.file_path = file_path
        self.save_best_model = save_best_model
        self.save_last_model = save_last_model
        self.best_file_name = best_file_name
        self.verbose = verbose
        self.loss = loss
        self.metrics = metrics
        self.optimizer = optimizer
        self.history = None

        super().__init__(
            batch_size=batch_size,
            random_state=random_state,
            last_file_name=last_file_name,
        )

        self._network = CNNNetwork(
            n_layers=self.n_layers,
            kernel_size=self.kernel_size,
            n_filters=self.n_filters,
            avg_pool_size=self.avg_pool_size,
            activation=self.activation,
            padding=self.padding,
            strides=self.strides,
            dilation_rate=self.dilation_rate,
            use_bias=self.use_bias,
            random_state=self.random_state,
        )

    @staticmethod
    def _last_if_sequence(value):
        """Return ``value[-1]`` for a list/tuple, otherwise ``value`` unchanged."""
        return value[-1] if isinstance(value, (list, tuple)) else value

    def build_model(self, input_shape, n_classes, **kwargs):
        """Construct a compiled, un-trained, keras model that is ready for training.

        In aeon, time series are stored in numpy arrays of shape (d, m), where d
        is the number of dimensions, m is the series length. Keras/tensorflow assume
        data is in shape (m, d). This method also assumes (m, d). Transpose should
        happen in fit.

        Parameters
        ----------
        input_shape : tuple
            The shape of the data fed into the input layer, should be (m, d)
        n_classes : int
            The number of classes, which becomes the size of the output layer
        **kwargs
            Forwarded to ``CNNNetwork.build_network``.

        Returns
        -------
        output : a compiled Keras Model
        """
        import tensorflow as tf

        tf.random.set_seed(self.random_state)

        metrics = ["accuracy"] if self.metrics is None else self.metrics

        input_layer, output_layer = self._network.build_network(input_shape, **kwargs)

        # ``activation`` and ``use_bias`` may be per-layer lists for the
        # convolution layers, but Dense needs a single value: use the entry of
        # the last layer in that case (scalars are passed through unchanged).
        output_layer = tf.keras.layers.Dense(
            units=n_classes,
            activation=self._last_if_sequence(self.activation),
            use_bias=self._last_if_sequence(self.use_bias),
        )(output_layer)

        self.optimizer_ = (
            tf.keras.optimizers.Adam() if self.optimizer is None else self.optimizer
        )

        model = tf.keras.models.Model(inputs=input_layer, outputs=output_layer)
        model.compile(
            loss=self.loss,
            optimizer=self.optimizer_,
            metrics=metrics,
        )

        return model

    def _fit(self, X, y):
        """Fit the classifier on the training set (X, y).

        Parameters
        ----------
        X : np.ndarray of shape = (n_instances (n), n_channels (d), series_length (m))
            The training input samples.
        y : np.ndarray of shape n
            The training data class labels.

        Returns
        -------
        self : object
        """
        import tensorflow as tf

        y_onehot = self.convert_y_to_keras(y)
        # Transpose to conform to Keras input style: (n, d, m) -> (n, m, d).
        X = X.transpose(0, 2, 1)

        # Called for its validation of random_state; the result is not used.
        check_random_state(self.random_state)
        self.input_shape = X.shape[1:]
        self.training_model_ = self.build_model(self.input_shape, self.n_classes_)

        if self.verbose:
            self.training_model_.summary()

        # A time-based name avoids clobbering files when the best model is not kept.
        self.file_name_ = (
            self.best_file_name if self.save_best_model else str(time.time_ns())
        )
        checkpoint_path = self.file_path + self.file_name_ + ".hdf5"

        self.callbacks_ = (
            [
                tf.keras.callbacks.ModelCheckpoint(
                    filepath=checkpoint_path,
                    monitor="loss",
                    save_best_only=True,
                ),
            ]
            if self.callbacks is None
            else self.callbacks
        )

        self.history = self.training_model_.fit(
            X,
            y_onehot,
            batch_size=self.batch_size,
            epochs=self.n_epochs,
            verbose=self.verbose,
            callbacks=self.callbacks_,
        )

        # The checkpoint only exists if a ModelCheckpoint wrote it (e.g. not when
        # user callbacks omit one). Check explicitly instead of relying on the
        # exception load_model raises for a missing file: some Keras versions
        # raise a plain OSError/IOError there, which ``except FileNotFoundError``
        # would not catch.
        if os.path.isfile(checkpoint_path):
            self.model_ = tf.keras.models.load_model(checkpoint_path, compile=False)
            if not self.save_best_model:
                os.remove(checkpoint_path)
        else:
            self.model_ = deepcopy(self.training_model_)

        if self.save_last_model:
            self.save_last_model_to_file(file_path=self.file_path)

        gc.collect()
        return self

    @classmethod
    def get_test_params(cls, parameter_set="default"):
        """Return testing parameter settings for the estimator.

        Parameters
        ----------
        parameter_set : str, default = "default"
            Name of the set of test parameters to return, for use in tests. If no
            special parameters are defined for a value, will return "default" set.
            For classifiers, a "default" set of parameters should be provided for
            general testing, and a "results_comparison" set for comparing against
            previously recorded results if the general set does not produce suitable
            probabilities to compare against.

        Returns
        -------
        params : dict or list of dict, default = {}
            Parameters to create testing instances of the class.
            Each dict are parameters to construct an "interesting" test instance, i.e.,
            `MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance.
            `create_test_instance` uses the first (or only) dictionary in `params`.
        """
        param1 = {
            "n_epochs": 10,
            "batch_size": 4,
            "avg_pool_size": 4,
        }

        test_params = [param1]

        return test_params