-
Notifications
You must be signed in to change notification settings - Fork 89
/
_signature_method.py
124 lines (106 loc) · 4.26 KB
/
_signature_method.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
"""Signature transformer."""
from sklearn.pipeline import Pipeline
from aeon.transformations.collection import BaseCollectionTransformer
from aeon.transformations.collection.signature_based._augmentations import (
_make_augmentation_pipeline,
)
from aeon.transformations.collection.signature_based._compute import (
_WindowSignatureTransform,
)
class SignatureTransformer(BaseCollectionTransformer):
    """Transformation class from the signature method.

    Follows the methodology laid out in the paper:
    "A Generalised Signature Method for Multivariate Time Series".

    The transformer chains an augmentation step and a windowed signature
    transform into a single sklearn ``Pipeline`` stored in
    ``signature_method``.

    Parameters
    ----------
    augmentation_list : tuple of str, default=("basepoint", "addtime")
        The augmentations to be applied before the signature transform.
    window_name : str, default="dyadic"
        The name of the window transform to apply, e.g. ``"global"``,
        ``"dyadic"``, ``"sliding"`` or ``"expanding"``.
    window_depth : int, default=3
        The depth of the dyadic window. Active only if
        ``window_name == "dyadic"``.
    window_length : int or None, default=None
        The length of the sliding/expanding window. Active only if
        ``window_name in ["sliding", "expanding"]``.
    window_step : int or None, default=None
        The step of the sliding/expanding window. Active only if
        ``window_name in ["sliding", "expanding"]``.
    rescaling : str or None, default=None
        The method of signature rescaling.
    sig_tfm : str, default="signature"
        The type of signature transform. One of
        ``["signature", "logsignature"]``.
    depth : int, default=4
        Signature truncation depth.

    Attributes
    ----------
    signature_method : sklearn.pipeline.Pipeline
        A sklearn pipeline that contains all the steps to extract the
        signature features. It is built on construction and rebuilt from the
        current parameter values at the start of every ``fit``.
    """

    _tags = {
        "output_data_type": "Tabular",
        "capability:multivariate": True,
        "python_dependencies": "esig",
        "python_version": "<3.11",
    }

    def __init__(
        self,
        augmentation_list=("basepoint", "addtime"),
        window_name="dyadic",
        window_depth=3,
        window_length=None,
        window_step=None,
        rescaling=None,
        sig_tfm="signature",
        depth=4,
    ):
        self.augmentation_list = augmentation_list
        self.window_name = window_name
        self.window_depth = window_depth
        self.window_length = window_length
        self.window_step = window_step
        self.rescaling = rescaling
        self.sig_tfm = sig_tfm
        self.depth = depth
        super().__init__()
        # Built eagerly so that ``signature_method`` exists before ``fit``
        # (existing behaviour); ``_fit`` rebuilds it from current parameters.
        self.setup_feature_pipeline()

    def setup_feature_pipeline(self):
        """Set up the signature method as an sklearn pipeline.

        Builds ``self.signature_method`` from the current values of the
        constructor parameters: an augmentation step followed by a windowed
        signature transform.
        """
        augmentation_step = _make_augmentation_pipeline(self.augmentation_list)
        transform_step = _WindowSignatureTransform(
            window_name=self.window_name,
            window_depth=self.window_depth,
            window_length=self.window_length,
            window_step=self.window_step,
            sig_tfm=self.sig_tfm,
            sig_depth=self.depth,
            rescaling=self.rescaling,
        )
        # The so-called 'signature method' as defined in the reference paper
        self.signature_method = Pipeline(
            [
                ("augmentations", augmentation_step),
                ("window_and_transform", transform_step),
            ]
        )

    def _fit(self, X, y=None):
        """Fit the signature pipeline to the collection ``X``.

        Parameters
        ----------
        X : collection of time series
            Training data, passed unchanged to the pipeline's ``fit``.
        y : ignored
            Present for API consistency.

        Returns
        -------
        self : SignatureTransformer
            The fitted transformer.
        """
        # Rebuild from the current parameters so the pipeline never reflects
        # stale values (e.g. parameters changed after construction without
        # re-running ``__init__``) and each fit starts from a fresh pipeline.
        self.setup_feature_pipeline()
        self.signature_method.fit(X)
        return self

    def _transform(self, X, y=None):
        """Transform ``X`` into signature features using the fitted pipeline.

        Parameters
        ----------
        X : collection of time series
            Data to transform, passed unchanged to the pipeline's
            ``transform``.
        y : ignored
            Present for API consistency.

        Returns
        -------
        The output of ``signature_method.transform(X)`` (tabular features,
        per the ``output_data_type`` tag).
        """
        return self.signature_method.transform(X)

    @classmethod
    def get_test_params(cls, parameter_set="default"):
        """Return testing parameter settings for the estimator.

        Parameters
        ----------
        parameter_set : str, default="default"
            Name of the set of test parameters to return, for use in tests. If no
            special parameters are defined for a value, will return `"default"` set.

        Returns
        -------
        params : dict or list of dict, default = {}
            Parameters to create testing instances of the class
            Each dict are parameters to construct an "interesting" test instance, i.e.,
            `MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance.
            `create_test_instance` uses the first (or only) dictionary in `params`
        """
        params = {
            "augmentation_list": ("basepoint", "addtime"),
            "depth": 3,
            "window_name": "global",
        }
        return params