In [27]:
import pandas as pd
import numpy as np
from sklearn.metrics import confusion_matrix
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.model_selection import TimeSeriesSplit, GridSearchCV
from scipy.stats import uniform, randint

from tools import create_x_y

In [28]:
df = pd.read_csv('../datasets/relevant/amzn.csv')

In [29]:
VITAL_COLS = ["date_", "ticker", "close", "diffs", "bin_2", "bin_3", "bin_5"]
OTHER_COLS = [i for i in df.columns if i not in VITAL_COLS]
print(OTHER_COLS)

['low', 'vol', 'max_pos', 'std_neg', 'std_pos', 'mean_comp', 'mean_neg', 'mean_pos', 'median_comp', 'count']


In [30]:
def weight_data(x):
    # least represented class is 1.0, others are percentage_least/percentage_other
    percentages = {}
    len_x = len(x)
    for uniq_val in set(x):
        percentages[uniq_val] = sum([1 for i in x if i==uniq_val])/len_x
    least = min(percentages, key=percentages.get)
    weights = {i:percentages[least]/percentages[i] for i in percentages}
    return weights

In [31]:
def measure(x, y):
    clf = DecisionTreeClassifier(class_weight=weight_data(y))

    params = {
        'max_depth': [2, 3, 5, 10, 20, 50, 100, 200],
        'min_samples_leaf': [2, 3, 5, 10],
        'criterion': ["gini", "entropy"]
    }
    time_split = TimeSeriesSplit(n_splits=5)

    tree_search = GridSearchCV(
        clf,
        param_grid=params,
        cv=time_split,
        verbose=1,
        n_jobs=4,
    )
    split = int(0.8 * len(x))
    tree_search.fit(x[:split], y[:split])
    y_pred = tree_search.predict(x[split:])
    cm = confusion_matrix(y[split:], y_pred)
    acc = sum(y_pred == y[split:])/len(y_pred)
    
    return cm, acc

In [32]:
acc = {}
cm = {}
for lag in [3, 6, 10, 16]:
    x, y = create_x_y(df, x_cols=OTHER_COLS, y_col="bin_3", lag=lag)
    cm[lag], acc[lag] = measure(x, y)

In [33]:
acc

{3: 0.5505226480836237,
 6: 0.5522648083623694,
 10: 0.5502614758861127,
 16: 0.5482558139534883}

In [34]:
cm

{3: array([[ 10, 279,  97],
        [ 16, 807, 112],
        [  7, 263, 131]], dtype=int64),
 6: array([[ 25, 257, 104],
        [ 25, 813,  97],
        [ 16, 272, 113]], dtype=int64),
 10: array([[ 13, 314,  58],
        [ 24, 850,  61],
        [ 12, 305,  84]], dtype=int64),
 16: array([[ 58, 326,   0],
        [ 50, 885,   0],
        [ 59, 342,   0]], dtype=int64)}