In [65]:
#!/usr/bin/env python
# coding: utf-8
import pandas as pd
import numpy as np
from datetime import date
class TimeseriesTestTrainSplit:
    """Build windowed train/test arrays from one stock's 'Adj Close' series.

    The pipeline is: load the pickled stocks frame, select one company's
    'Adj Close' column, split it chronologically into a training and a test
    portion, then turn each portion into (X, y) pairs where every row of X
    holds ``TIME_STEPS`` consecutive values and y holds the value that
    immediately follows that window.

    All helpers are static so they can be called either on the class
    (``TimeseriesTestTrainSplit.get_data(...)``) or on an instance.
    """

    # Fraction of the series assigned to the training portion.
    TRAIN_FRACTION = 0.8
    # Number of look-back values in every X row.
    TIME_STEPS = 60
    # Smallest test portion kept when the series is long enough; with the
    # default 60-step window, 63 points yield 3 windowed test samples.
    MIN_TEST_LENGTH = 63

    def __init__(self):
        pass

    @classmethod
    def timeseries_test_train_split(cls, stock_name='GSIT'):
        """Run the whole pipeline for one company.

        Args:
            stock_name: value matched against the ``company_name`` column.

        Returns:
            Tuple ``(X_train, y_train, X_test, y_test)`` of numpy arrays.
        """
        stocks_df = cls.load_df()
        data = cls.get_data(stocks_df, stock_name)
        train_length = cls.get_train_length(data)
        training_set, test_set = cls.get_train_test_split(data, train_length)
        X_train, y_train = cls.get_x_y_split(training_set)
        X_test, y_test = cls.get_x_y_split(test_set)
        return X_train, y_train, X_test, y_test

    @staticmethod
    def load_df(path=None):
        """Load the pickled stocks frame and return it sorted by 'Date'.

        Args:
            path: pickle file to read. When omitted, the file
                ``./data/stocks_df_<today's date>.pickle`` is used.

        Returns:
            The loaded DataFrame sorted by 'Date'.
        """
        if path is None:
            path = './data/stocks_df_{}.pickle'.format(date.today())
        # SECURITY: unpickling executes arbitrary code; only load files that
        # were produced by a trusted process.
        stocks_df = pd.read_pickle(path)
        # Non-mutating sort: same returned frame, no in-place side effect.
        return stocks_df.sort_values(by=['Date'])

    @staticmethod
    def get_data(stocks_df, stock_name):
        """Return the 'Adj Close' series of the rows matching ``stock_name``."""
        is_selected = stocks_df['company_name'] == stock_name
        # Single .loc lookup instead of chained indexing.
        return stocks_df.loc[is_selected, 'Adj Close']

    @staticmethod
    def get_train_length(data, train_fraction=TRAIN_FRACTION,
                         min_test_length=MIN_TEST_LENGTH):
        """Return how many leading observations belong to the training set.

        The training set takes ``ceil(len(data) * train_fraction)`` points,
        reduced if necessary so that ``min_test_length`` points remain for
        testing. The result is never negative: a series shorter than
        ``min_test_length`` yields 0 (everything goes to the test set)
        rather than a negative number that would be read as a
        "from the end" slice bound by ``get_train_test_split``.
        """
        total = len(data)
        train_length = int(np.ceil(total * train_fraction))
        if total - train_length < min_test_length:
            train_length = total - min_test_length
        return max(train_length, 0)

    @staticmethod
    def get_train_test_split(data, train_length):
        """Split ``data`` positionally at ``train_length``.

        Returns:
            Tuple ``(training_set, test_set)`` of numpy arrays holding the
            first ``train_length`` values and the remaining values.
        """
        training_set = data.iloc[:train_length].values
        test_set = data.iloc[train_length:].values
        return training_set, test_set

    @staticmethod
    def get_x_y_split(data_set, time_steps=TIME_STEPS):
        """Turn a 1-D sequence into sliding windows and next-value targets.

        Args:
            data_set: 1-D sequence of values.
            time_steps: window length (number of values per X row).

        Returns:
            Tuple ``(X, y)``. ``X`` has shape ``(n, time_steps)`` and ``y``
            has shape ``(n,)`` with ``n = max(len(data_set) - time_steps, 0)``;
            ``y[k]`` is the value that directly follows window ``X[k]``.

        Raises:
            ValueError: if ``time_steps`` is not positive.
        """
        if time_steps <= 0:
            raise ValueError('time_steps must be a positive integer')
        values = np.asarray(data_set)
        windows = [values[end - time_steps:end]
                   for end in range(time_steps, len(values))]
        if windows:
            X = np.array(windows)
        else:
            # Keep X two-dimensional even when the input is too short to
            # produce a single window.
            X = np.empty((0, time_steps), dtype=values.dtype)
        y = np.array(values[time_steps:])
        return X, y

In [68]:
# Load the pickled stocks frame through the class helper.
stocks_df = TimeseriesTestTrainSplit.load_df()

In [69]:
# Pick the 'Adj Close' series of a single company and display it.
name = 'GSIT'
data = TimeseriesTestTrainSplit.get_data(stocks_df,stock_name = name)
data

Date
2016-02-18    3.500
2016-02-19    3.450
2016-02-22    3.410
2016-02-23    3.400
2016-02-24    3.440
              ...  
2021-02-11    8.010
2021-02-12    7.850
2021-02-16    7.760
2021-02-17    7.660
2021-02-18    7.665
Name: Adj Close, Length: 1260, dtype: float64

In [70]:
# Number of leading observations that will form the training portion.
train_length = TimeseriesTestTrainSplit.get_train_length(data)
train_length

1008

In [71]:
# ttsplit is the (training_set, test_set) tuple; show the test-set size.
ttsplit = TimeseriesTestTrainSplit.get_train_test_split(data, train_length)
len(ttsplit[1])

252

In [74]:
# Build the windowed inputs (X) and next-value targets (y) from the training portion.
TimeseriesTestTrainSplit.get_x_y_split(ttsplit[0])

(array([[3.5       , 3.45000005, 3.41000009, ..., 3.67000008, 3.83999991,
         3.83999991],
        [3.45000005, 3.41000009, 3.4000001 , ..., 3.83999991, 3.83999991,
         3.68000007],
        [3.41000009, 3.4000001 , 3.44000006, ..., 3.83999991, 3.68000007,
         3.67000008],
        ...,
        [7.53999996, 7.53999996, 7.40999985, ..., 7.76999998, 7.88999987,
         7.78999996],
        [7.53999996, 7.40999985, 7.32999992, ..., 7.88999987, 7.78999996,
         7.76000023],
        [7.40999985, 7.32999992, 7.30000019, ..., 7.78999996, 7.76000023,
         7.63000011]]),
 array([3.68000007, 3.67000008, 3.71000004, 3.76999998, 3.77999997,
        3.76999998, 3.9000001 , 3.78999996, 3.79999995, 3.80999994,
        3.83999991, 4.        , 4.03000021, 4.05000019, 4.09000015,
        4.19999981, 4.17999983, 4.19000006, 4.15999985, 4.07999992,
        4.05000019, 3.97000003, 3.98000002, 3.93000007, 3.94000006,
        4.        , 4.19999981, 4.1500001 , 4.26000023, 3.95000005,
 

In [75]:
# Run the whole pipeline with the default stock; returns (X_train, y_train, X_test, y_test).
TimeseriesTestTrainSplit.timeseries_test_train_split()

(array([[3.5       , 3.45000005, 3.41000009, ..., 3.67000008, 3.83999991,
         3.83999991],
        [3.45000005, 3.41000009, 3.4000001 , ..., 3.83999991, 3.83999991,
         3.68000007],
        [3.41000009, 3.4000001 , 3.44000006, ..., 3.83999991, 3.68000007,
         3.67000008],
        ...,
        [7.53999996, 7.53999996, 7.40999985, ..., 7.76999998, 7.88999987,
         7.78999996],
        [7.53999996, 7.40999985, 7.32999992, ..., 7.88999987, 7.78999996,
         7.76000023],
        [7.40999985, 7.32999992, 7.30000019, ..., 7.78999996, 7.76000023,
         7.63000011]]),
 array([3.68000007, 3.67000008, 3.71000004, 3.76999998, 3.77999997,
        3.76999998, 3.9000001 , 3.78999996, 3.79999995, 3.80999994,
        3.83999991, 4.        , 4.03000021, 4.05000019, 4.09000015,
        4.19999981, 4.17999983, 4.19000006, 4.15999985, 4.07999992,
        4.05000019, 3.97000003, 3.98000002, 3.93000007, 3.94000006,
        4.        , 4.19999981, 4.1500001 , 4.26000023, 3.95000005,
 