In [2]:
import pandas as pd
import numpy as np
#import torch
try:
    from PyEMD import EEMD
except:
    !pip install EMD-signal
    from PyEMD import EEMD
    
from ta import add_all_ta_features

In [92]:
def _direction_labels(close, window_size, neutral_quantile):
    """Label each close-to-close move as neutral (0), decrease (1) or increase (2).

    A move is non-neutral only when its ``|diff|`` is strictly greater than the
    rolling ``neutral_quantile`` of ``|diff|`` over the last ``window_size`` diffs
    (the current diff included); a zero diff is always neutral.

    Args:
        close: price series ordered by date.
        window_size: rolling window length, counted in diffs.
        neutral_quantile: quantile of the rolling ``|diff|`` used as the threshold.

    Returns:
        One-column DataFrame (column ``0``) of int labels indexed by the dates that
        have a full rolling window, i.e. ``close.index[window_size:]`` when ``close``
        contains no NaNs.
    """
    price_diff = close.diff().dropna()
    threshold = price_diff.abs().rolling(window_size).quantile(neutral_quantile).dropna()
    # Positional slice (the index holds dates): keep only diffs that have a threshold.
    diff = price_diff.iloc[window_size - 1:]
    is_large = diff.abs() > threshold
    conditions = [(diff < 0) & is_large,   # class 1: decrease
                  (diff > 0) & is_large]   # class 2: increase
    # Class 0 (neutral) is the default when no condition holds.
    return pd.DataFrame(np.select(conditions, [1, 2], default=0), index=diff.index)


def _prepare_currency_frame(df, window_size, neutral_quantile, log_price,
                            include_indicators, include_imfs):
    """Turn one raw OHLCV frame into a date-indexed feature frame.

    Returns ``(frame, labels)``. ``frame`` has ``change_dir`` as its first column,
    followed by the remaining lower-cased columns (plus ``imf_<i>`` columns when
    ``include_imfs``). ``labels`` is the output of ``_direction_labels``.
    """
    if include_indicators:
        # NOTE(review): reset_index() looks like it turns the CSV's first (index) column
        # back into a regular column, which then survives as a feature — confirm intended.
        df = add_all_ta_features(df,
                                 open="Open", high="High", low="Low", close="Close",
                                 volume="Volume", fillna=True).reset_index()
    else:
        df = df.drop(columns="Volume")

    df["Date"] = df["Date"].apply(pd.Timestamp)
    df = (df.sort_values("Date", ascending=True)
            .set_index("Date")
            .drop(columns=["Open", "High", "Low"])
            .rename(columns=str.lower))

    if log_price:
        df["close"] = df["close"].apply(np.log)

    labels = _direction_labels(df["close"], window_size, neutral_quantile)
    # insert() aligns on the date index, so the first rows (no full window) get NaN.
    df.insert(loc=0, column="change_dir", value=labels.iloc[:, 0])

    if include_imfs:
        # NOTE(review): EEMD is fit on the entire close series, so IMF values at a date
        # can depend on later prices (look-ahead) — avoid for out-of-sample evaluation.
        imfs = EEMD()(df["close"].values)
        imf_features = [f"imf_{i}" for i in range(imfs.shape[0])]
        df = pd.concat((df, pd.DataFrame(imfs.T, columns=imf_features, index=df.index)),
                       axis=1)

    return df, labels


def get_data(currency_lst,
             frequency,
             window_size,
             neutral_quantile=0.25,
             beg_date=pd.Timestamp(2013, 1, 1),
             end_date=None,
             log_price=True,
             include_indicators=True,
             include_imfs=True):
    """Load daily OHLCV CSVs for several currencies, label price moves and align them.

    For each ticker ``cur`` the file ``../data/0_raw/Binance/<cur lower-cased>_usdt_1d.csv``
    is read. It is optionally enriched with ``ta`` indicators and EEMD IMFs, and gets a
    ``change_dir`` label column as its first column. All frames are then sampled on a
    shared date grid.

    Labels: 0 = neutral, 1 = decrease, 2 = increase. The label at date t describes
    ``close[t] - close[t-1]``. The first ``window_size`` rows of each frame have a NaN label.

    Args:
        currency_lst: ticker symbols, e.g. ``["BTC", "ETH"]``.
        frequency: pandas offset alias for the common date grid (e.g. ``"D"``).
        window_size: rolling window (in diffs) for the neutral threshold.
        neutral_quantile: rolling quantile of ``|diff|`` at or below which a move is neutral.
        beg_date: earliest date to keep; the later of this and the latest series start wins.
        end_date: latest date to keep; ``None`` (default) means "now" at call time.
        log_price: if True, ``close`` is replaced by ``log(close)`` before labelling.
        include_indicators: if True, add all ``ta`` features; otherwise drop ``Volume``.
        include_imfs: if True, append EEMD IMFs of ``close`` as ``imf_<i>`` columns.

    Returns:
        arr: ndarray of shape ``(len(currency_lst), len(grid), n_features)``.
        y: one-column label DataFrame (column ``0``) of the LAST processed currency
            only, over its full date range (not restricted to the grid).
        features: column names of the last processed frame (column order of ``arr``).
        currency_dfs: dict ticker -> processed DataFrame over its full date range.

    Raises:
        ValueError: if ``currency_lst`` is empty.
        KeyError: if a grid date is missing from some currency's frame.
    """
    if len(currency_lst) == 0:
        raise ValueError("currency_lst must contain at least one currency")
    if end_date is None:
        # Resolved per call: a pd.Timestamp.now() default would be frozen at def time.
        end_date = pd.Timestamp.now()

    currency_dfs = {
        cur: pd.read_csv(f"../data/0_raw/Binance/{cur.lower()}_usdt_1d.csv", index_col=0)
        for cur in currency_lst
    }

    y = None
    for cur in list(currency_dfs):
        currency_dfs[cur], y = _prepare_currency_frame(
            currency_dfs[cur], window_size, neutral_quantile, log_price,
            include_indicators, include_imfs)

    # Common grid: from the latest series start (or beg_date) to the earliest
    # series end (or end_date).
    beg_date = max(max(df.index.min() for df in currency_dfs.values()), beg_date)
    end_date = min(min(df.index.max() for df in currency_dfs.values()), end_date)
    common_range = pd.date_range(beg_date, end_date, freq=frequency)

    # NOTE(review): with include_imfs the number of imf_<i> columns may differ between
    # currencies, which would make this array ragged — TODO confirm.
    arr = np.array([currency_dfs[cur].loc[common_range].values for cur in currency_lst])
    features = currency_dfs[list(currency_dfs)[-1]].columns.tolist()

    return arr, y, features, currency_dfs

In [90]:
# Data configuration, consumed by the get_data(...) call cell below.
CURRENCY_LST = ["BTC", "ETH", "LTC"]  # one ../data/0_raw/Binance/<ticker>_usdt_1d.csv per ticker
PRICE_TYPE = "close"  # price column name (get_data lower-cases all column names)
# Grid frequency (pandas offset alias), rolling window for the neutral threshold, and
# the rolling |diff| quantile at or below which a move is labelled neutral (class 0).
FREQUENCY, WINDOW_SIZE, NEUTRAL_QUANTILE = "D", 50, 0.25

In [93]:
# Log-price series with change_dir labels only: no TA indicators, no EEMD IMFs.
arr, y, features, dfs = get_data(
    currency_lst=CURRENCY_LST,
    frequency=FREQUENCY,
    window_size=WINDOW_SIZE,
    neutral_quantile=NEUTRAL_QUANTILE,
    log_price=True,
    include_indicators=False,
    include_imfs=False,
)

In [96]:
np.shape(arr)  # (n_currencies, n_dates on the common grid, n_features)

(3, 1242, 2)

In [97]:
# Rows WINDOW_SIZE-1 and WINDOW_SIZE straddle the start of the labels: change_dir stays
# NaN until the rolling |diff| threshold has a full window of diffs.
dfs['BTC'].iloc[WINDOW_SIZE-1:WINDOW_SIZE+1]

Unnamed: 0_level_0,change_dir,close
Date,Unnamed: 1_level_1,Unnamed: 2_level_1
2017-10-05,,8.364608
2017-10-06,2.0,8.382289


In [100]:
N_CURRENCIES = 1  # NOTE(review): not used in the visible cells
INPUT_FEATURE_SIZE = 1  # NOTE(review): not used in the visible cells
# WINDOW_SIZE (also the default input sequence length) comes from the configuration
# cell above; it is deliberately not redefined here so the two values cannot drift.
TRAIN_PERCENTAGE, VAL_PERCENTAGE, TEST_PERCENTAGE = 0.70, 0.15, 0.15
assert abs(TRAIN_PERCENTAGE + VAL_PERCENTAGE + TEST_PERCENTAGE - 1.0) < 1e-9, \
    "split fractions must sum to 1"

In [101]:
class TimeSeriesDataset:
    """Sliding-window samples over one feature matrix, for one chronological split.

    Rows of ``x`` are split in time order into train / val / test blocks. Val and test
    sizes are ``int(len(x) * percentage)``, and train takes the remaining rows. Item
    ``i`` is ``(window, target)``:
    - ``window`` is ``seq_len`` consecutive values of column ``feature_col``.
    - ``target`` is column ``target_col`` of the row right after the window.

    Val/test windows may start inside the preceding split, so every val/test row is a
    target exactly once. For train, the first ``seq_len`` rows are inputs only.

    Args:
        x: 2-D array with rows ordered in time. For ``get_data(...,
            include_indicators=False, include_imfs=False)`` output, column 0 is
            ``change_dir`` and column 1 is ``close`` — TODO confirm for other features.
        data_use_type: one of ``"train"``, ``"val"``, ``"test"``.
        train_percentage: not used (kept for signature compatibility); train gets
            every row not reserved for val/test.
        val_percentage: fraction of rows reserved for validation targets.
        test_percentage: fraction of rows reserved for test targets.
        seq_len: window length.
        feature_col: column sliced into the input window (default 1).
        target_col: column returned as the target (default 0).

    Raises:
        ValueError: if ``data_use_type`` is not one of the split names.
    """

    _SPLITS = ("train", "val", "test")

    def __init__(self,
                 x: np.ndarray,
                 data_use_type,
                 train_percentage = TRAIN_PERCENTAGE,
                 val_percentage = VAL_PERCENTAGE,
                 test_percentage = TEST_PERCENTAGE,
                 seq_len = WINDOW_SIZE,
                 feature_col = 1,
                 target_col = 0,
                 ):
        # Imported locally: the notebook-level `import torch` is commented out.
        import torch

        if data_use_type not in self._SPLITS:
            raise ValueError(
                f"data_use_type must be one of {self._SPLITS}, got {data_use_type!r}")

        self.x = torch.tensor(x).float()
        self.seq_len = seq_len
        self.data_use_type = data_use_type
        self.feature_col = feature_col
        self.target_col = target_col

        self.val_size = int(len(self.x) * val_percentage)
        self.test_size = int(len(self.x) * test_percentage)
        self.train_size = len(self.x) - self.val_size - self.test_size

    def _first_window_start(self):
        """Row of ``x`` where the window of item 0 starts for this split.

        NOTE(review): assumes ``train_size >= seq_len``; otherwise val/test windows
        would start at negative rows.
        """
        if self.data_use_type == "val":
            return self.train_size - self.seq_len
        if self.data_use_type == "test":
            return self.train_size + self.val_size - self.seq_len
        return 0

    def __len__(self):
        if self.data_use_type == "train":
            # The first seq_len rows can only be window inputs, never targets.
            return max(0, self.train_size - self.seq_len)
        if self.data_use_type == "val":
            return self.val_size
        return self.test_size

    def __getitem__(self, index):
        """Return ``(window, target)`` for item ``index``; negative indices are allowed.

        Raises:
            IndexError: if ``index`` is outside ``[-len(self), len(self))``. This also
                makes plain ``for`` iteration stop at the end of the split.
        """
        n_items = len(self)
        position = index + n_items if index < 0 else index
        if not 0 <= position < n_items:
            raise IndexError(
                f"index {index} out of range for {self.data_use_type!r} split "
                f"of length {n_items}")

        start = self._first_window_start() + position
        window = self.x[start:start + self.seq_len, self.feature_col]
        target = self.x[start + self.seq_len, self.target_col]
        return (window, target)

In [102]:
# One dataset per chronological split of the first currency's matrix (arr[0]).
a, b, c = (TimeSeriesDataset(arr[0], split) for split in ("train", "val", "test"))

In [103]:
# Sanity check: train/val/test items plus the initial input-only window account for
# every row of arr[0] (asserted, so a broken split fails loudly on Run All).
n_rows_covered = len(a) + len(b) + len(c) + WINDOW_SIZE
assert n_rows_covered == arr.shape[1], (n_rows_covered, arr.shape[1])
n_rows_covered

1242

In [104]:
# First training sample: (window of WINDOW_SIZE values from column 1 of arr[0] — the
# log close here — and the column-0 change_dir label of the row right after the window).
a[0]

(array([9.68242246, 9.7010641 , 9.77222957, 9.85758287, 9.84479962,
        9.8445997 , 9.75818429, 9.71044756, 9.64812035, 9.49751807,
        9.49551931, 9.51044496, 9.52510294, 9.66071575, 9.64601068,
        9.55973659, 9.57351713, 9.42867317, 9.52634456, 9.50151633,
        9.59390814, 9.61042503, 9.61976696, 9.7386359 , 9.74506551,
        9.68967719, 9.60928695, 9.57498349, 9.60959222, 9.49090568,
        9.52806729, 9.56170122, 9.50859065, 9.51339838, 9.29651807,
        9.30463094, 9.30218729, 9.34792429, 9.45719576, 9.35270761,
        9.28359548, 9.28722548, 9.33697214, 9.32145858, 9.3137089 ,
        9.3493194 , 9.38260738, 9.32821229, 9.2338137 , 9.23845152]),
 1.0)