diff --git a/.all-contributorsrc b/.all-contributorsrc
index 8aca5c8e5f4..b641b5a2979 100644
--- a/.all-contributorsrc
+++ b/.all-contributorsrc
@@ -1747,7 +1747,17 @@
       "contributions": [
         "bug"
       ]
-    }
+    },
+    {
+      "login": "haskarb",
+      "name": "Bhaskar Dhariyal",
+      "avatar_url": "https://avatars.githubusercontent.com/u/20501023?v=4",
+      "profile": "https://haskarb.github.io/",
+      "contributions": [
+        "code",
+        "test"
+      ]
+    }
   ],
   "projectName": "sktime",
   "projectOwner": "alan-turing-institute",
diff --git a/CODEOWNERS b/CODEOWNERS
index 63a8a64ea0f..de3020a15ed 100644
--- a/CODEOWNERS
+++ b/CODEOWNERS
@@ -36,6 +36,7 @@
 sktime/transformations/series/kalman_filter.py @NoaBenAmi
 sktime/transformations/panel/augmenter.py @MrPr3ntice @iljamaurer
 sktime/transformations/multiplex.py @miraep8
 sktime/transformations/tests/test_multiplexer.py @miraep8
+sktime/transformations/panel/channel_selection.py @haskarb @a-pasos-ruiz @TonyBagnall
 sktime/forecasting/base/ @fkiraly @mloning @aiwalter
 sktime/forecasting/base/adapters/_statsforecast.py @FedericoGarza
diff --git a/examples/channel_selection.ipynb b/examples/channel_selection.ipynb
new file mode 100644
index 00000000000..7c5a799e3fa
--- /dev/null
+++ b/examples/channel_selection.ipynb
@@ -0,0 +1,534 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "89a363cd-8944-475a-88ea-4401785218c5",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import warnings\n",
+    "\n",
+    "warnings.filterwarnings(\"ignore\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "b45cbeda-65bc-4b10-9f9c-9afd43a54d20",
+   "metadata": {},
+   "source": [
+    "# Channel Selection in Multivariate Time Series Classification \n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "08decb5b-8dfb-4666-a66b-3a2349960956",
+   "metadata": {},
+   "source": [
+    "## Overview"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "da743484-17d3-4cec-8eaa-8d8293ce6f35",
+   "metadata": {},
+   "source": [
+    "Sometimes every channel is not required to perform 
classification; only a few are useful. The [1] proposed a fast channel selection technique for Multivariate Time Classification. " + ] + }, + { + "cell_type": "markdown", + "id": "dcbe2174-a691-4093-ab80-d796edb5121d", + "metadata": {}, + "source": [ + "[1] : Fast Channel Selection for Scalable Multivariate Time Series Classification [Link](https://www.researchgate.net/publication/354445008_Fast_Channel_Selection_for_Scalable_Multivariate_Time_Series_Classification)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "d1779970-eefb-4577-9c4e-e0a19ceadcc1", + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.linear_model import RidgeClassifierCV\n", + "from sklearn.pipeline import make_pipeline\n", + "\n", + "from sktime.datasets import load_UCR_UEA_dataset\n", + "from sktime.transformations.panel import channel_selection\n", + "from sktime.transformations.panel.rocket import Rocket" + ] + }, + { + "cell_type": "markdown", + "id": "0437ca7a-5b5a-4e28-b565-0b2df4eac60d", + "metadata": {}, + "source": [ + "# 1 Initialise the Pipeline" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "830137a3-10c3-49b9-9a98-7062dc7ab1d8", + "metadata": {}, + "outputs": [], + "source": [ + "# cs = channel_selection.ElbowClassSum() # ECS\n", + "cs = channel_selection.ElbowClassPairwise() # ECP" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "89443793-7cf0-4a4c-a4b1-d928a46a1bb2", + "metadata": {}, + "outputs": [], + "source": [ + "rocket_pipeline = make_pipeline(cs, Rocket(), RidgeClassifierCV())" + ] + }, + { + "cell_type": "markdown", + "id": "5a268cc1-d5bf-4b02-916b-f3417c1cd3ff", + "metadata": {}, + "source": [ + "# 2 Load and Fit the Training Data" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "68f508e3-ecc9-4b3a-b4de-7073cd1dfb90", + "metadata": {}, + "outputs": [], + "source": [ + "data = \"BasicMotions\"\n", + "X_train, y_train = load_UCR_UEA_dataset(data, split=\"train\", 
return_X_y=True)\n", + "X_test, y_test = load_UCR_UEA_dataset(data, split=\"test\", return_X_y=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "94f421bc-e384-4b98-89af-111a4d8c378b", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Pipeline(steps=[('elbowclasspairwise', ElbowClassPairwise()),\n", + " ('rocket', Rocket()),\n", + " ('ridgeclassifiercv',\n", + " RidgeClassifierCV(alphas=array([ 0.1, 1. , 10. ])))])" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "rocket_pipeline.fit(X_train, y_train)" + ] + }, + { + "cell_type": "markdown", + "id": "0c867fdb-5126-44f6-b7f0-f999d3f60457", + "metadata": {}, + "source": [ + "# 3 Classify the Test Data" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "04573f4d-0b61-4ab8-8355-79f0aa1ca04f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "1.0" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "rocket_pipeline.score(X_test, y_test)" + ] + }, + { + "cell_type": "markdown", + "id": "d18ac8bc-a83a-4dd7-b577-aefc25d7bed6", + "metadata": {}, + "source": [ + "# 4 Identify channels" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "35a44d68-7bce-44b0-baf3-e4f11606001c", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[0, 1]" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "rocket_pipeline.steps[0][1].channels_selected_" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "358ab28f-edbe-49f4-95f5-8a7d0fb5d166", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Centroid_badminton_runningCentroid_badminton_standingCentroid_badminton_walkingCentroid_running_standingCentroid_running_walkingCentroid_standing_walking
039.59467955.75278548.44077963.61022057.24738310.717044
157.68176724.39054327.77026960.45812562.33912016.370347
220.17591124.12696922.33162125.67197922.9915554.897452
312.54621212.43915212.7418546.3176546.6957433.585273
410.1011968.8658719.2219086.5201726.7157021.299989
523.46425114.56868513.95344518.87842919.7685497.228389
\n", + "
" + ], + "text/plain": [ + " Centroid_badminton_running Centroid_badminton_standing \\\n", + "0 39.594679 55.752785 \n", + "1 57.681767 24.390543 \n", + "2 20.175911 24.126969 \n", + "3 12.546212 12.439152 \n", + "4 10.101196 8.865871 \n", + "5 23.464251 14.568685 \n", + "\n", + " Centroid_badminton_walking Centroid_running_standing \\\n", + "0 48.440779 63.610220 \n", + "1 27.770269 60.458125 \n", + "2 22.331621 25.671979 \n", + "3 12.741854 6.317654 \n", + "4 9.221908 6.520172 \n", + "5 13.953445 18.878429 \n", + "\n", + " Centroid_running_walking Centroid_standing_walking \n", + "0 57.247383 10.717044 \n", + "1 62.339120 16.370347 \n", + "2 22.991555 4.897452 \n", + "3 6.695743 3.585273 \n", + "4 6.715702 1.299989 \n", + "5 19.768549 7.228389 " + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "rocket_pipeline.steps[0][1].distance_frame_" + ] + }, + { + "cell_type": "markdown", + "id": "c75f99ea-3966-483b-9f08-89e3f0dbffeb", + "metadata": {}, + "source": [ + "# 5 Standalone" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "82607728-1095-4f15-a06e-d2463bb5c642", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "ElbowClassPairwise()" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "cs.fit(X_train, y_train)" + ] + }, + { + "cell_type": "markdown", + "id": "a1f3ec36-ce86-4388-b7b5-b3b2a087b43a", + "metadata": {}, + "source": [ + "# 6 Distance Matrix" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "f4a19774-368e-43d4-a45a-d8109ae2d17f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Centroid_badminton_runningCentroid_badminton_standingCentroid_badminton_walkingCentroid_running_standingCentroid_running_walkingCentroid_standing_walking
039.59467955.75278548.44077963.61022057.24738310.717044
157.68176724.39054327.77026960.45812562.33912016.370347
220.17591124.12696922.33162125.67197922.9915554.897452
312.54621212.43915212.7418546.3176546.6957433.585273
410.1011968.8658719.2219086.5201726.7157021.299989
523.46425114.56868513.95344518.87842919.7685497.228389
\n", + "
" + ], + "text/plain": [ + " Centroid_badminton_running Centroid_badminton_standing \\\n", + "0 39.594679 55.752785 \n", + "1 57.681767 24.390543 \n", + "2 20.175911 24.126969 \n", + "3 12.546212 12.439152 \n", + "4 10.101196 8.865871 \n", + "5 23.464251 14.568685 \n", + "\n", + " Centroid_badminton_walking Centroid_running_standing \\\n", + "0 48.440779 63.610220 \n", + "1 27.770269 60.458125 \n", + "2 22.331621 25.671979 \n", + "3 12.741854 6.317654 \n", + "4 9.221908 6.520172 \n", + "5 13.953445 18.878429 \n", + "\n", + " Centroid_running_walking Centroid_standing_walking \n", + "0 57.247383 10.717044 \n", + "1 62.339120 16.370347 \n", + "2 22.991555 4.897452 \n", + "3 6.695743 3.585273 \n", + "4 6.715702 1.299989 \n", + "5 19.768549 7.228389 " + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "cs.distance_frame_" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "a29b0ece", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "13" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "cs.train_time_" + ] + } + ], + "metadata": { + "interpreter": { + "hash": "30ff7f6bb2505d289b6e6022e217e794dc64e9153f959b8a264cb3c597a35999" + }, + "kernelspec": { + "display_name": "Python 3.7.5 ('sktime-test')", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/sktime/transformations/panel/channel_selection.py b/sktime/transformations/panel/channel_selection.py new file mode 100644 index 00000000000..103cbe002ba --- /dev/null +++ b/sktime/transformations/panel/channel_selection.py @@ -0,0 +1,332 @@ +# -*- coding: utf-8 
-*- +"""Channel Selection techniques for Multivariate Time Series Classification. + +A transformer that selects a subset of channels/dimensions for time series +classification using a scoring system with an elbow point method. +""" + +__author__ = ["haskarb", "a-pasos-ruiz", "TonyBagnall"] +__all__ = ["ElbowClassSum", "ElbowClassPairwise"] + + +import itertools +import time + +import numpy as np +import pandas as pd +from scipy.spatial.distance import euclidean +from sklearn.neighbors import NearestCentroid + +from sktime.datatypes._panel._convert import from_3d_numpy_to_nested +from sktime.transformations.base import BaseTransformer + + +def _eu_dist(x, y): + """Calculate the euclidean distance.""" + return euclidean(x, y) + + +def _detect_knee_point(values, indices): + """Find elbow point.""" + n_points = len(values) + all_coords = np.vstack((range(n_points), values)).T + first_point = all_coords[0] + line_vec = all_coords[-1] - all_coords[0] + line_vec_norm = line_vec / np.sqrt(np.sum(line_vec**2)) + vec_from_first = all_coords - first_point + scalar_prod = np.sum(vec_from_first * np.tile(line_vec_norm, (n_points, 1)), axis=1) + vec_from_first_parallel = np.outer(scalar_prod, line_vec_norm) + vec_to_line = vec_from_first - vec_from_first_parallel + dist_to_line = np.sqrt(np.sum(vec_to_line**2, axis=1)) + knee_idx = np.argmax(dist_to_line) + knee = values[knee_idx] + best_dims = [idx for (elem, idx) in zip(values, indices) if elem > knee] + if len(best_dims) == 0: + return [knee_idx], knee_idx + + return (best_dims,) + + +class _distance_matrix: + """Create distance matrix.""" + + def distance(self, centroid_frame): + """Fuction to create DM.""" + distance_pair = list( + itertools.combinations(range(0, centroid_frame.shape[0]), 2) + ) + # exit() + + map_cls = centroid_frame.class_vals.to_dict() + distance_frame = pd.DataFrame() + for class_ in distance_pair: + + class_pair = [] + # calculate the distance of centroid here + for _, (q, t) in enumerate( + zip( + 
centroid_frame.drop(["class_vals"], axis=1).iloc[class_[0], :], + centroid_frame.iloc[class_[1], :], + ) + ): + # print(eu_dist(q.values, t.values)) + class_pair.append(_eu_dist(q.values, t.values)) + dict_ = { + f"Centroid_{map_cls[class_[0]]}_{map_cls[class_[1]]}": class_pair + } + # print(class_[0]) + + distance_frame = pd.concat([distance_frame, pd.DataFrame(dict_)], axis=1) + + return distance_frame + + +class _shrunk_centroid: + """Create centroid.""" + + def __init__(self, shrink=0): + self.shrink = shrink + + def create_centroid(self, X, y): + """Create the centroid for each class.""" + _, ncols, _ = X.shape + cols = ["dim_" + str(i) for i in range(ncols)] + ts = X + centroids = [] + + # le = LabelEncoder() + # y_ind = le.fit_transform(y) + + for dim in range(ts.shape[1]): + train = ts[:, dim, :] + clf = NearestCentroid(train) + clf = NearestCentroid(shrink_threshold=self.shrink) + clf.fit(train, y) + centroids.append(clf.centroids_) + + centroid_frame = from_3d_numpy_to_nested( + np.stack(centroids, axis=1), column_names=cols + ) + centroid_frame["class_vals"] = clf.classes_ + + return centroid_frame.reset_index(drop=True) + + +class ElbowClassSum(BaseTransformer): + """Elbow Class Sum (ECS) transformer to select a subset of channels. + + Overview: From the input of multivariate time series data, create a distance + matrix [1] by calculating the distance between each class centroid. The + ECS selects the subset of channels using the elbow method, which maximizes the + distance between the class centroids by aggregating the distance for every + class pair across each channel. + + Note: Channels, variables, dimensions, features are used interchangeably in + literature. + + Attributes + ---------- + channels_selected_ : list of integers; integer being the index of the channel + List of channels selected by the ECS. + distance_frame_ : DataFrame + distance matrix of the class centroids pair and channels. 
+ ``shape = [n_channels, n_class_centroids_pairs]`` + Table 1 provides an illustration in [1]. + train_time_ : int + Time taken to train the ECS. + + Notes + ----- + Original repository: https://github.com/mlgig/Channel-Selection-MTSC + + References + ---------- + ..[1]: Bhaskar Dhariyal et al. “Fast Channel Selection for Scalable Multivariate + Time Series Classification.” AALTD, ECML-PKDD, Springer, 2021 + + Examples + -------- + >>> from sktime.transformations.panel.channel_selection import ElbowClassSum + >>> from sktime.datasets import load_UCR_UEA_dataset + >>> cs = ElbowClassSum() + >>> X_train, y_train = load_UCR_UEA_dataset( + ... "Cricket", split="train", return_X_y=True + ... ) + >>> cs.fit(X_train, y_train) + ElbowClassSum(...) + >>> Xt = cs.transform(X_train) + """ + + _tags = { + "scitype:transform-input": "Series", + # what is the scitype of X: Series, or Panel + # "scitype:transform-output": "Primitives", + # what scitype is returned: Primitives, Series, Panel + "scitype:instancewise": True, # is this an instance-wise transform? + "univariate-only": False, # can the transformer handle multivariate X? + "X_inner_mtype": "numpy3D", # which mtypes do _fit/_predict support for X? + "y_inner_mtype": "numpy1D", # which mtypes do _fit/_predict support for y? + "requires_y": True, # does y need to be passed in fit? + "fit_is_empty": False, # is fit empty and can be skipped? Yes = True + "skip-inverse-transform": True, # is inverse-transform skipped when called? + "capability:unequal_length": False, + # can the transformer handle unequal length time series (if passed Panel)? + } + + def __init__(self): + + super(ElbowClassSum, self).__init__() + + def _fit(self, X, y): + """Fit ECS to a specified X and y. + + Parameters + ---------- + X: pandas DataFrame or np.ndarray + The training input samples. + y: array-like or list + The class values for X. + + Returns + ------- + self : reference to self. 
+ + """ + self.channels_selected_ = [] + start = int(round(time.time() * 1000)) + centroid_obj = _shrunk_centroid(0) + centroids = centroid_obj.create_centroid(X.copy(), y) + distances = _distance_matrix() + self.distance_frame_ = distances.distance(centroids) + distance = self.distance_frame_.sum(axis=1).sort_values(ascending=False).values + indices = self.distance_frame_.sum(axis=1).sort_values(ascending=False).index + self.channels_selected_.extend(_detect_knee_point(distance, indices)[0]) + self.train_time_ = int(round(time.time() * 1000)) - start + return self + + def _transform(self, X, y=None): + """ + Transform X and return a transformed version. + + Parameters + ---------- + X : pandas DataFrame or np.ndarray + The input data to transform. + + Returns + ------- + output : pandas DataFrame + X with a subset of channels + """ + return X[:, self.channels_selected_, :] + + +class ElbowClassPairwise(BaseTransformer): + """Elbow Class Pairwise (ECP) transformer to select a subset of channels. + + Overview: From the input of multivariate time series data, create a distance + matrix [1] by calculating the distance between each class centroid. The ECP + selects the subset of channels using the elbow method that maximizes the + distance between each class centroids pair across all channels. + + Note: Channels, variables, dimensions, features are used interchangeably in + literature. + + Attributes + ---------- + channels_selected_ : list of integers; integer being the index of the channel + List of channels selected by the ECS. + distance_frame_ : DataFrame + distance matrix of the class centroids pair and channels. + ``shape = [n_channels, n_class_centroids_pairs]`` + Table 1 provides an illustration in [1]. + train_time_ : int + Time taken to train the ECP. + + Notes + ----- + Original repository: https://github.com/mlgig/Channel-Selection-MTSC + + References + ---------- + ..[1]: Bhaskar Dhariyal et al. 
“Fast Channel Selection for Scalable Multivariate + Time Series Classification.” AALTD, ECML-PKDD, Springer, 2021 + + Examples + -------- + >>> from sktime.transformations.panel.channel_selection import ElbowClassPairwise + >>> from sktime.datasets import load_UCR_UEA_dataset + >>> cs = ElbowClassPairwise() + >>> X_train, y_train = load_UCR_UEA_dataset( + ... "Cricket", split="train", return_X_y=True + ... ) + >>> cs.fit(X_train, y_train) + ElbowClassPairwise(...) + >>> Xt = cs.transform(X_train) + """ + + _tags = { + "scitype:transform-input": "Series", + # what is the scitype of X: Series, or Panel + # "scitype:transform-output": "Primitives", + # what scitype is returned: Primitives, Series, Panel + "scitype:instancewise": True, # is this an instance-wise transform? + "univariate-only": False, # can the transformer handle multivariate X? + "X_inner_mtype": "numpy3D", # which mtypes do _fit/_predict support for X? + "y_inner_mtype": "numpy1D", # which mtypes do _fit/_predict support for y? + "requires_y": True, # does y need to be passed in fit? + "fit_is_empty": False, # is fit empty and can be skipped? Yes = True + "skip-inverse-transform": True, # is inverse-transform skipped when called? + "capability:unequal_length": False, + # can the transformer handle unequal length time series (if passed Panel)? + } + + def __init__(self): + super(ElbowClassPairwise, self).__init__() + + def _fit(self, X, y): + """Fit ECP to a specified X and y. + + Parameters + ---------- + X: pandas DataFrame or np.ndarray + The training input samples. + y: array-like or list + The class values for X. + + Returns + ------- + self : reference to self. 
+ + """ + self.channels_selected_ = [] + start = int(round(time.time() * 1000)) + centroid_obj = _shrunk_centroid(0) + df = centroid_obj.create_centroid(X.copy(), y) + obj = _distance_matrix() + self.distance_frame_ = obj.distance(df) + + for pairdistance in self.distance_frame_.iteritems(): + distance = pairdistance[1].sort_values(ascending=False).values + indices = pairdistance[1].sort_values(ascending=False).index + self.channels_selected_.extend(_detect_knee_point(distance, indices)[0]) + self.channels_selected_ = list(set(self.channels_selected_)) + self.train_time_ = int(round(time.time() * 1000)) - start + # self._is_fitted = True + return self + + def _transform(self, X, y=None): + """ + Transform X and return a transformed version. + + Parameters + ---------- + X : pandas DataFrame or np.ndarray + The input data to transform. + + Returns + ------- + output : pandas DataFrame + X with a subset of channels + """ + return X[:, self.channels_selected_, :] diff --git a/sktime/transformations/panel/tests/test_channel_selection.py b/sktime/transformations/panel/tests/test_channel_selection.py new file mode 100644 index 00000000000..9af05f45945 --- /dev/null +++ b/sktime/transformations/panel/tests/test_channel_selection.py @@ -0,0 +1,26 @@ +# -*- coding: utf-8 -*- +"""Channel selection test code.""" +from sktime.datasets import load_basic_motions +from sktime.transformations.panel.channel_selection import ElbowClassPairwise + + +def test_cs_basic_motions(): + """Test channel selection on basic motions dataset.""" + X, y = load_basic_motions(split="train", return_X_y=True) + + ecp = ElbowClassPairwise() + + ecp.fit(X, y) + + # transform the training data + + ecp.transform(X, y) + + # test shape pf transformed data should be (n_samples, n_channels_selected) + assert ecp.transform(X, y).shape == (X.shape[0], len(ecp.channels_selected_)) + + # test shape of transformed data should be (n_samples, n_channels_selected) + + X_test, y_test = 
load_basic_motions(split="test", return_X_y=True) + + assert ecp.transform(X_test).shape == (X_test.shape[0], len(ecp.channels_selected_)) diff --git a/sktime/transformations/series/tests/test_optional_passthrough.py b/sktime/transformations/series/tests/test_optional_passthrough.py new file mode 100644 index 00000000000..490d5d5a3dd --- /dev/null +++ b/sktime/transformations/series/tests/test_optional_passthrough.py @@ -0,0 +1,40 @@ +# -*- coding: utf-8 -*- +# copyright: sktime developers, BSD-3-Clause License (see LICENSE file) +"""Tests for OptionalPassthrough transformer.""" + +import pytest +from pandas.testing import assert_series_equal + +from sktime.transformations.series.boxcox import BoxCoxTransformer +from sktime.transformations.series.compose import OptionalPassthrough +from sktime.utils._testing.series import _make_series + + +@pytest.mark.parametrize("passthrough", [True, False]) +def test_passthrough(passthrough): + """Test that passthrough works as expected.""" + y = _make_series(n_columns=1) + + optional_passthourgh = OptionalPassthrough( + BoxCoxTransformer(), passthrough=passthrough + ) + box_cox = BoxCoxTransformer() + + y_hat_passthrough = optional_passthourgh.fit_transform(y) + y_inv_passthrough = optional_passthourgh.inverse_transform(y_hat_passthrough) + + y_hat_boxcox = box_cox.fit_transform(y) + y_inv_boxcox = box_cox.inverse_transform(y_hat_boxcox) + + assert_series_equal(y, y_inv_passthrough) + + if passthrough: + assert_series_equal(y, y_hat_passthrough) + with pytest.raises(AssertionError): + assert_series_equal(y_hat_boxcox, y_hat_passthrough) + else: + assert_series_equal(y_hat_passthrough, y_hat_boxcox) + assert_series_equal(y_inv_passthrough, y_inv_boxcox) + + with pytest.raises(AssertionError): + assert_series_equal(y, y_hat_passthrough)