# Overview of the base class structure

`aeon` uses a core inheritance hierarchy of classes across the toolkit, with specialised subclasses in each module. The basic class hierarchy is shown in the following diagram.

<img src="img/aeon_uml_simple.drawio.png" alt="Basic class hierarchy">


## Scikit-learn `BaseEstimator` and aeon `BaseAeonEstimator`

To make sense of this, we break it down from the top.
Everything inherits from scikit-learn's `BaseEstimator`, which mainly handles getting and setting parameters through the `get_params` and `set_params` methods. These methods are used when estimators interact with other classes such as [GridSearchCV](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html#sklearn.model_selection.GridSearchCV), and are also used in aeon's `ComposableEstimatorMixin`, which we'll talk about later.
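
As a minimal sketch of the parameter mechanism described above, the snippet below uses only scikit-learn. `ToyEstimator` and its parameters are hypothetical names chosen for illustration; they are not part of aeon.

```python
from sklearn.base import BaseEstimator, clone


class ToyEstimator(BaseEstimator):
    """Hypothetical estimator used only to illustrate parameter handling."""

    def __init__(self, n_trees=10, random_state=None):
        # scikit-learn expects __init__ to store its arguments unchanged
        self.n_trees = n_trees
        self.random_state = random_state


est = ToyEstimator()
default_params = est.get_params()  # dict of constructor parameters

# set_params returns the estimator itself, so calls can be chained
est.set_params(n_trees=2, random_state=0)

# clone builds a new, unfitted estimator with the same parameters
est_copy = clone(est)
print(default_params)
print(est_copy.get_params())
```

Tools such as `GridSearchCV` rely on this same pair of methods to create one configured copy of the estimator per parameter combination.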

Then we have aeon's `BaseAeonEstimator` class. This class handles the following for all aeon estimators:
- management of tags: setting, getting, interaction with scikit-learn's tags, etc.
- cloning and resetting of the estimator
- creation of test instances using test parameters specified by each estimator. For example, this is used to define fast-running estimators (e.g. a forest classifier with only 2 trees) for the CI/CD pipelines.

#### A word on aeon's estimator tag system
Tags in aeon are used for various purposes, such as displaying estimator capabilities in the documentation and selecting which tests to run based on each estimator's capabilities. You can check [all existing tags in aeon](https://github.com/aeon-toolkit/aeon/blob/main/aeon/utils/tags/_tags.py) and the [developer documentation on the testing framework](https://www.aeon-toolkit.org/en/stable/developer_guide/testing.html#) to learn more about how we use tags.

## `BaseCollectionEstimator` and `BaseSeriesEstimator`

We distinguish between two types of inputs for aeon estimators, series and collections:
- Series represent a single time series in a 2D format `(n_channels, n_timepoints)`. Some estimators can also use a 1D format `(n_timepoints)` when they don't support multivariate series. Series estimators also have an `axis` parameter, which allows the input to be transposed so that the 2D format becomes `(n_timepoints, n_channels)` instead.
- Collections represent a set of time series in a 3D format `(n_samples, n_channels, n_timepoints)`. Again, this can sometimes be represented in a 2D format `(n_samples, n_timepoints)` for univariate estimators. Preferably, this should be avoided, as it obscures the meaning of the axes and can be confused with a 2D single series. More information on this problem can be found in [this notebook](series_estimator.ipynb).
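
The shapes listed above can be made concrete with plain NumPy arrays. This is only a sketch of the array layouts: the sizes are arbitrary, and no aeon estimator or `axis` argument is involved.

```python
import numpy as np

n_samples, n_channels, n_timepoints = 10, 3, 50
rng = np.random.default_rng(0)

# A single multivariate series: (n_channels, n_timepoints)
series = rng.normal(size=(n_channels, n_timepoints))

# The same series with time on the first axis: (n_timepoints, n_channels)
series_time_first = series.T

# A univariate series can be stored as 1D: (n_timepoints,)
univariate_series = series[0]

# A collection of series: (n_samples, n_channels, n_timepoints)
collection = rng.normal(size=(n_samples, n_channels, n_timepoints))

# A 2D univariate collection (n_samples, n_timepoints) looks just like a
# 2D single series, which is why the 3D form is preferred
collection_2d = collection[:, 0, :]
collection_3d = collection_2d[:, np.newaxis, :]  # explicit channel axis

print(series.shape, series_time_first.shape, univariate_series.shape)
print(collection.shape, collection_2d.shape, collection_3d.shape)
```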

For example, if we go back to the base class diagram, `BaseClassifier` inherits from `BaseCollectionEstimator`. This means that during `fit` and `predict`, all estimators inheriting from `BaseClassifier` take collections of time series as input.


## Collection base estimators

The `BaseCollectionEstimator` defines methods to check the shape of the input, extract metadata (e.g. whether the collection is multivariate) and check compatibility of the input against the tags of the estimator. For example, consider the following:

In [22]:
from aeon.classification.dictionary_based import TemporalDictionaryEnsemble
from aeon.testing.data_generation import make_example_3d_numpy_list

# TDE does not support unequal length collections
X_unequal, y_unequal = make_example_3d_numpy_list()
try:
    TemporalDictionaryEnsemble().fit(X_unequal, y_unequal)
except ValueError as e:
    print(e)

Data seen by instance of TemporalDictionaryEnsemble has unequal length series, but TemporalDictionaryEnsemble cannot handle these characteristics. 


What happens here is that `TemporalDictionaryEnsemble` inherits from `BaseClassifier`, which itself inherits from `BaseCollectionEstimator`. During `fit` and `predict`, `BaseClassifier` calls `_preprocess_collection`, a function defined in `BaseCollectionEstimator`. This function extracts the input metadata (whether it is multivariate, of unequal length, etc.) and compares it against the tags of `TemporalDictionaryEnsemble`. These tags state that the estimator does not support unequal length collections, and hence an exception is raised.
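
The idea behind this check can be sketched in a few lines. The code below is a simplified illustration and not aeon's implementation of `_preprocess_collection`: the function names, the `toy_tags` dictionary and the error message are all hypothetical.

```python
import numpy as np


def get_collection_metadata(X):
    """Extract metadata from a list of 2D arrays (n_channels, n_timepoints)."""
    lengths = {x.shape[1] for x in X}
    return {
        "multivariate": any(x.shape[0] > 1 for x in X),
        "unequal_length": len(lengths) > 1,
    }


def check_capabilities(X, tags, name):
    """Raise ValueError if X has a characteristic the tags do not allow."""
    metadata = get_collection_metadata(X)
    problems = [
        key
        for key, present in metadata.items()
        if present and not tags.get(f"capability:{key}", False)
    ]
    if problems:
        raise ValueError(f"{name} cannot handle data that is: {', '.join(problems)}")
    return metadata


# Hypothetical tags: multivariate input allowed, unequal length not allowed
toy_tags = {"capability:multivariate": True, "capability:unequal_length": False}
X_unequal = [np.zeros((1, 10)), np.zeros((1, 12))]

try:
    check_capabilities(X_unequal, toy_tags, "ToyClassifier")
    error_message = None
except ValueError as e:
    error_message = str(e)
print(error_message)
```

Keeping the check in a shared base class means each estimator only declares what it supports through its tags, rather than validating its input itself.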

### `BaseClassifier` (aeon.classification)

This is the base class for all classifiers. It uses the standard `fit`, `predict` and `predict_proba` structure from `sklearn`. `fit` and `predict` call the abstract methods `_fit` and `_predict`, which are implemented in the subclass to define the classification algorithm. All of the common format checking and conversion is done in final methods such as `fit` and `predict`, before they call `_fit` and `_predict`.
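
This split between public and private methods is the template method pattern. The sketch below shows the pattern in isolation using only the standard library. `ToyBaseClassifier` and `MajorityClassifier` are hypothetical classes, and the checks are far simpler than those in aeon's `BaseClassifier`.

```python
from abc import ABC, abstractmethod


class ToyBaseClassifier(ABC):
    """Hypothetical base class: public methods check, then delegate."""

    def fit(self, X, y):
        if len(X) != len(y):
            raise ValueError("X and y must contain the same number of cases")
        # Sorted unique labels, stored for use by subclasses
        self.classes_ = sorted(set(y))
        self.n_classes_ = len(self.classes_)
        self._fit(X, y)
        self.is_fitted = True
        return self

    def predict(self, X):
        if not getattr(self, "is_fitted", False):
            raise RuntimeError("call fit before predict")
        return self._predict(X)

    @abstractmethod
    def _fit(self, X, y): ...

    @abstractmethod
    def _predict(self, X): ...


class MajorityClassifier(ToyBaseClassifier):
    """Predict the most frequent training label for every case."""

    def _fit(self, X, y):
        self.majority_ = max(self.classes_, key=list(y).count)

    def _predict(self, X):
        return [self.majority_ for _ in X]


clf = MajorityClassifier().fit([[1, 2], [3, 4], [5, 6]], ["a", "b", "b"])
preds = clf.predict([[0, 0], [1, 1]])
print(preds)
```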

When implementing a new classifier inheriting from `BaseClassifier`, you thus only have to implement the `__init__`, `_fit` and `_predict` methods that handle the classification logic of the classifier. You will also need to set the correct tags so that the checks and conversions are done for you. Note that each base class also defines some attributes that are commonly used in the estimators. For example, `BaseClassifier` exposes `classes_`, `n_classes_` and `_class_dictionary`, which we can use in our new classifier:

In [11]:
from numpy.random import default_rng

from aeon.classification import BaseClassifier
from aeon.testing.data_generation import (
    make_example_3d_numpy,
    make_example_dataframe_list,
)


class RandomClassifier(BaseClassifier):
    """A dummy classifier returning random predictions."""

    _tags = {
        "capability:multivariate": True,  # allow multivariate collections
        "capability:unequal_length": True,  # allow unequal length collections
        "X_inner_type": ["np-list", "numpy3D"],  # Specify data format used internally
    }

    def __init__(self, random_state: int = 42):
        self.random_state = random_state
        super().__init__()

    def _fit(self, X, y):
        self.rng = default_rng(self.random_state)

    def _predict(self, X):
        # generate a random int between 0 and n_classes-1 and use _class_dictionary
        # to convert it to class label
        return [
            self._class_dictionary[i]
            for i in self.rng.integers(low=0, high=self.n_classes_, size=len(X))
        ]


X, y = make_example_3d_numpy(n_channels=2)
print(RandomClassifier().fit_predict(X, y))
X, y = make_example_dataframe_list()
print(RandomClassifier().fit(X, y).predict(X))

[1 0 1 1 0 0 1 0 1 0]
[0, 1, 1, 0, 0, 1, 0, 1, 0, 0]


### `BaseRegressor`, `BaseClusterer` and `BaseCollectionAnomalyDetector` 
These base classes are mostly similar to `BaseClassifier` in how they use the checks and conversion operations from `BaseCollectionEstimator`.

- `BaseRegressor` also defines `fit` and `predict` methods and requires `_fit` and `_predict` methods to be implemented by child classes. The difference is that it has no `predict_proba` method. The tests on `y` are also different, as we can have floats as values for `y`.

- `BaseClusterer` also has `fit` and `predict`, but does not take input `y` as child classes can be unsupervised estimators. It does include `predict_proba`.

- `BaseCollectionAnomalyDetector` also has `fit` and `predict`, but does not take input `y` as child classes can be unsupervised estimators.


### `BaseCollectionTransformer` 

Rather than `fit` and `predict`, the `BaseCollectionTransformer` implements `fit`, `transform` and `fit_transform` methods. It requires child classes to define `_fit` and `_transform` methods. The output of the transform method is not fixed and should be specified with the `output_data_type` tag.

For example, if the output is another collection of time series (e.g. after using `SAX`), then `output_data_type` must take the `Collection` value (note that this is the default value for all `BaseCollectionTransformer` child classes). If the output is no longer a time series, but rather a 2D array of features extracted from each input time series, such as in `Rocket` or `RandomShapeletTransform`, then `output_data_type` must take the `Tabular` value.
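
The difference between the two output types comes down to the shape of the result. The sketch below uses plain NumPy operations as stand-ins for a transformer. It does not use `SAX`, `Rocket` or any aeon class, and the chosen operations (z-normalisation and summary statistics) are assumptions made purely for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(10, 2, 50))  # (n_samples, n_channels, n_timepoints)

# Collection-like output: each series is z-normalised, so the result is
# still a 3D collection with the same shape as the input
mean = X.mean(axis=-1, keepdims=True)
std = X.std(axis=-1, keepdims=True)
X_collection = (X - mean) / std

# Tabular-like output: per-channel mean and standard deviation give one
# feature vector per case, shape (n_samples, n_features)
X_tabular = np.concatenate([X.mean(axis=-1), X.std(axis=-1)], axis=1)

print(X_collection.shape, X_tabular.shape)
```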

## Series base estimators
### `BaseForecaster`
### `BaseSegmenter`
### `BaseSeriesTransformer`
### `BaseSeriesAnomalyDetector`
