In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from datetime import datetime
from pathlib import Path

import numpy as np
import pandas as pd
from tqdm.notebook import tqdm

In [None]:
import warnings

# Suppress every Python warning for the rest of the session.
# NOTE(review): this blanket filter also hides any warnings raised by the
# selector / LinearSVC fits later in the notebook — consider narrowing it.
warnings.filterwarnings("ignore")

In [None]:
from sklearn.ensemble import ExtraTreesClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import KFold
from sklearn.svm import LinearSVC

In [None]:
from sklearn.feature_selection import RFECV, SelectFromModel

In [None]:
# Presumably the root of the UCR 2018 archive, relative to the working directory.
# NOTE(review): not referenced by the cells visible here — confirm it is still needed.
DATA_PATH = Path("../UCRArchive_2018/")
# Directory in which the experiment loop opens its results CSV for writing.
RESULTS_PATH = Path("../results/")

In [None]:
def load_y(path):
    """Read the first column of a header-less CSV file.

    Parameters
    ----------
    path : str or path-like
        Location of the CSV file; handed straight to ``pandas.read_csv``.

    Returns
    -------
    numpy.ndarray
        One entry per row of the file, taken from column 0 (used as the
        label vector ``y`` by the experiment loop).
    """
    table = pd.read_csv(path, header=None, index_col=None).values
    return table[:, 0]

In [None]:
from sklearn.feature_selection import SelectorMixin
from sklearn.utils import check_random_state


class RandomSelector(SelectorMixin):
    """Selector that keeps a randomly drawn subset of the input columns.

    Parameters
    ----------
    n_features : float, default=0.5
        Fraction of columns to keep; the number kept is
        ``ceil(n_columns * n_features)``.
    random_state : optional, default=None
        Forwarded to ``sklearn.utils.check_random_state`` on every ``fit``.
    """

    def __init__(self, n_features=0.5, random_state=None):
        self.n_features = n_features
        self.random_state = random_state

    def fit(self, X, y=None):
        """Draw the columns to keep for a 2-D ``X``; ``y`` is ignored.

        Stores the drawn column indices in ``_index`` and the matching
        boolean mask in ``_mask``, then returns ``self``.
        """
        _, n_columns = X.shape

        rng = check_random_state(self.random_state)
        n_keep = np.ceil(n_columns * self.n_features).astype(int)

        # Sampling without replacement guarantees distinct column indices.
        self._index = rng.choice(n_columns, n_keep, replace=False)

        # Boolean mask with True exactly at the drawn positions.
        support = np.zeros(n_columns, dtype=bool)
        support[self._index] = True
        self._mask = support

        return self

    def _get_support_mask(self):
        """Return the boolean column mask computed by the last ``fit``."""
        return self._mask

    def _get_tags(self):
        """Return an empty tag dictionary.

        NOTE(review): presumably present because the mixin's ``transform``
        queries estimator tags and this class does not derive from
        ``BaseEstimator`` — confirm against the installed scikit-learn.
        """
        return {}

In [None]:
# (name, selector) pairs compared by the experiment loop; the name ends up in
# the "method" field of every result record.
selectors = [
    # Random baselines: keep ceil(10 % / 50 % / 30 %) of the columns, seeded with 42.
    ("Random_10", RandomSelector(n_features=0.1, random_state=42)),
    ("Random_50", RandomSelector(n_features=0.5, random_state=42)),
    ("Random_30", RandomSelector(n_features=0.3, random_state=42)),
    # Model-based selection driven by an L1-penalised linear SVC.
    ("LassoSVC", SelectFromModel(LinearSVC(penalty="l1", dual=False, random_state=42))),
    # Model-based selection driven by an extra-trees classifier.
    ("Tree", SelectFromModel(ExtraTreesClassifier(random_state=42))),
]

In [None]:
def load_matrix_pair(def_path, metric):
    """Load the base and derivative matrices stored for one metric.

    Parameters
    ----------
    def_path : str
        Dataset path with the ``.csv`` suffix removed.
    metric : str
        Metric name used in the file names ``{def_path}_{metric}.gz`` and
        ``{def_path}_der_{metric}.gz``.

    Returns
    -------
    tuple of numpy.ndarray or None
        ``(X, X_der)``, or ``None`` when either file cannot be opened
        (``OSError``), so the caller can skip that metric.
    """
    try:
        X = np.loadtxt(f"{def_path}_{metric}.gz", delimiter=",")
        X_der = np.loadtxt(f"{def_path}_der_{metric}.gz", delimiter=",")
    except OSError:
        return None

    return X, X_der


def cross_validate_selector(selector, X_c, y, kfold):
    """Cross-validate ``selector`` followed by a ``LinearSVC`` on ``X_c``.

    ``X_c`` is indexed with the fold indices on both axes, so it must be a
    square sample-by-sample matrix. Each fold fits the selector on the
    train-by-train sub-matrix and scores the classifier on the
    test-by-train sub-matrix.

    Parameters
    ----------
    selector : object
        Provides ``fit_transform(X, y)`` and ``transform(X)``.
    X_c : numpy.ndarray
        Square matrix whose rows and columns both correspond to samples.
    y : numpy.ndarray
        Labels, one per sample.
    kfold : sklearn.model_selection.KFold
        Splitter that defines the folds.

    Returns
    -------
    tuple of float
        ``(accuracy, n_features)``: accuracy averaged over the folds, and the
        fold-averaged ratio of selected columns to training samples.
    """
    n_folds = kfold.get_n_splits()
    accuracy = 0
    n_features = 0

    for train_index, test_index in kfold.split(X_c):
        y_train = y[train_index]
        X_train = selector.fit_transform(X_c[train_index][:, train_index], y_train)

        # Share of the training columns that survived selection.
        n_features += X_train.shape[1] / train_index.shape[0] / n_folds

        # Test rows are described by the same (training) columns.
        X_test = selector.transform(X_c[test_index][:, train_index])

        svc = LinearSVC(random_state=42)
        svc.fit(X_train, y_train)
        y_pred = svc.predict(X_test)

        accuracy += accuracy_score(y_true=y[test_index], y_pred=y_pred) / n_folds

    return accuracy, n_features


n_splits = 5
n_datasets = 50

kfold = KFold(n_splits=n_splits, random_state=42, shuffle=True)

results = []

# First `n_datasets` rows of the frame, processed in order of "samples".
datasets = files_frame[:n_datasets].sort_values("samples")

# Create the output directory up front so `open` cannot fail on a fresh checkout.
RESULTS_PATH.mkdir(parents=True, exist_ok=True)

with open(RESULTS_PATH / "classification_select_features_fdtw.csv", "w") as res_file:
    # `total` follows the real row count, so the bar is right even with fewer rows.
    for dataset in tqdm(datasets.itertuples(), total=len(datasets)):

        y = load_y(dataset.path)

        def_path = dataset.path.replace(".csv", "")

        # Each matrix pair is parsed once per dataset and metric and then
        # reused for every selector and blend weight.
        matrix_cache = {}

        for name, selector in selectors:
            for metric in ("dtw", "fdtw", "itakura", "sakoe_chiba"):
                if metric not in matrix_cache:
                    matrix_cache[metric] = load_matrix_pair(def_path, metric)

                pair = matrix_cache[metric]
                if pair is None:
                    # Files for this metric are missing: nothing to evaluate.
                    continue

                X, X_der = pair

                for a in (0.6, 0.8):
                    # `a` weights the derivative matrix against the base one;
                    # the blend does not depend on the fold.
                    X_c = (1 - a) * X + a * X_der

                    accuracy, n_features = cross_validate_selector(
                        selector, X_c, y, kfold
                    )

                    record = {
                        "dataset": def_path.split("/")[-1],
                        "metric": f"dd_{metric}_{a:g}",
                        "method": name,
                        "accuracy": accuracy,
                        "n_features": n_features,
                    }

                    results.append(record)

                    res_file.write(
                        "{dataset},{metric},{method},{accuracy:.5g},{n_features:.3g}\n".format(
                            **record
                        )
                    )

                # Persist finished rows so an interrupted run keeps them.
                res_file.flush()