In [None]:
import numpy as np
import csv
import os
from multistep_lstm_company import MultiStepLSTMCompany
from datetime import date
import math


# The results CSV has a fixed schema with one Trend/APRE/RMSE column triple per
# forecast step, for this many steps ahead.
_MAX_FORECAST_STEPS = 30


def _optimisation_csv_columns():
    """Return the ordered column names of the optimisation results CSV."""
    columns = ["Company", "LSTM Type", "n_epoch", "n_neuros", "n_batch",
               "n_lag", "n_seq", "Training Time",
               "Indicator Number", "Indicators", "Trained Date",
               "Start Train Date", "End Train/Start Test Date", "End Test Date",
               "Model Name"]
    for step in range(1, _MAX_FORECAST_STEPS + 1):
        columns.append("Trend_t+" + str(step))
        columns.append("APRE_t+" + str(step))
        columns.append("RMSE_t+" + str(step))
    return columns


def _build_result_row(symbol, obj, n_seq, trend_score, apre_score, rmse_score):
    """Collect one trained model's settings and per-step scores into a CSV row dict."""
    row = {"Company": symbol,
           "LSTM Type": obj.model_type,
           "n_epoch": obj.n_epochs,
           "n_neuros": obj.n_neurons,
           "n_batch": obj.n_batch,
           "n_lag": obj.n_lag,
           "n_seq": obj.n_seq,
           "Training Time": obj.time_taken_to_train,
           "Start Train Date": obj.train_start_date_string,
           "End Train/Start Test Date": obj.train_end_test_start_date_string,
           "End Test Date": obj.test_end_date_string,
           # +1 accounts for the share price, which is listed first below.
           "Indicator Number": len(obj.input_tech_indicators_list) + 1,
           "Indicators": "Share Price," + ",".join(obj.input_tech_indicators_list),
           "Trained Date": str(date.today()),
           "Model Name": obj.create_file_name()}
    for step in range(n_seq):
        label = str(step + 1)
        row["Trend_t+" + label] = trend_score[step]
        row["APRE_t+" + label] = apre_score[step]
        row["RMSE_t+" + label] = rmse_score[step]
    return row


def get_optimal_epochs_batch_neurons_params(symbol, start_train_date, end_train_start_test_date, end_test_date, n_lags, n_seqs,
                                            n_epochs, n_batches, n_neurons, indicators, model_types):
    """Grid-search LSTM hyper-parameters and log every run to a CSV file.

    For every combination of ``n_epochs`` x ``n_batches`` x ``n_neurons`` x
    ``n_lags`` x ``n_seqs`` x ``model_types`` a ``MultiStepLSTMCompany`` is
    built, trained, used to predict, and scored with the "trend", "rmse" and
    "apre" metrics. One row per combination is appended to
    ``./experiments/optimisation.csv``; the directory and the header row are
    created when missing.

    Args:
        symbol: Company ticker, forwarded to ``MultiStepLSTMCompany`` and
            written to the "Company" column.
        start_train_date: Forwarded to ``MultiStepLSTMCompany``.
        end_train_start_test_date: Forwarded to ``MultiStepLSTMCompany``.
        end_test_date: Forwarded to ``MultiStepLSTMCompany``.
        n_lags: Iterable of ``n_lag`` values to try.
        n_seqs: Iterable of ``n_seq`` values to try; each must be at most 30
            because the CSV only has score columns for t+1 .. t+30.
        n_epochs: Iterable of ``n_epochs`` values to try.
        n_batches: Iterable of ``n_batch`` values to try.
        n_neurons: Iterable of ``n_neurons`` values to try.
        indicators: Forwarded as ``tech_indicators``.
        model_types: Iterable of ``model_type`` values to try.

    Returns:
        None. Results are only written to the CSV file.

    Raises:
        ValueError: If any value in ``n_seqs`` exceeds the 30 forecast steps
            the CSV schema has columns for. Raised before any training starts,
            instead of from the CSV writer after a model has been trained.
    """
    import itertools  # local import: keeps this block self-contained

    # Fail fast: an over-long horizon would otherwise only blow up inside
    # csv.DictWriter.writerow, after the (expensive) training run finished.
    n_seqs = list(n_seqs)
    too_long = [n_s for n_s in n_seqs if n_s > _MAX_FORECAST_STEPS]
    if too_long:
        raise ValueError("n_seq values " + str(too_long) + " exceed the "
                         + str(_MAX_FORECAST_STEPS) + " forecast steps supported by the results CSV")

    csv_columns = _optimisation_csv_columns()

    # exist_ok avoids the check-then-create race of isdir() + mkdir().
    os.makedirs("./experiments", exist_ok=True)

    filename = "./experiments/optimisation.csv"
    # A missing or empty file has no header yet, so write one.
    if not os.path.isfile(filename) or os.path.getsize(filename) == 0:
        with open(filename, "w", newline="") as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=csv_columns)
            writer.writeheader()

    # product() iterates with the right-most argument varying fastest, i.e. in
    # the same order as the equivalent nested for-loops.
    grid = itertools.product(n_epochs, n_batches, n_neurons, n_lags, n_seqs, model_types)
    for n_e, n_b, n_n, n_l, n_s, m_t in grid:
        print("In process of training", symbol, "model type:", m_t, "n_lag:",
              n_l, "n_seq:", n_s, "n_epoch:", n_e, "n_batch:", n_b, "n_neurons:", n_n)
        obj = MultiStepLSTMCompany(symbol, start_train_date, end_train_start_test_date,
                                   end_test_date, n_lag=n_l, n_seq=n_s, n_epochs=n_e,
                                   n_neurons=n_n, n_batch=n_b, tech_indicators=indicators,
                                   model_type=m_t)
        obj.train()
        predictions = obj.predict()
        trend_score = obj.score(metric="trend", predictions=predictions)
        lstm_score = obj.score(metric="rmse", predictions=predictions)
        apre_score = obj.score(metric="apre", predictions=predictions)

        dic = _build_result_row(symbol, obj, n_s, trend_score, apre_score, lstm_score)
        # Append immediately so finished runs survive a later crash.
        append_dict_to_csv(filename, csv_columns, dic)

def append_dict_to_csv(csv_file_name, csv_columns, dic):
    """Append ``dic`` as a single row at the end of the CSV file ``csv_file_name``.

    ``csv_columns`` fixes the column order. Keys missing from ``dic`` are
    written as empty cells, while keys not listed in ``csv_columns`` make
    ``csv.DictWriter`` raise ``ValueError``. No header row is written here.
    """
    handle = open(csv_file_name, mode="a", newline="")
    with handle:
        csv.DictWriter(handle, fieldnames=csv_columns).writerow(dic)


# Hyper-parameter grid for the optimisation run.
n_lags = [5]
n_seqs = [3]
# Log-spaced integer grids. astype(int) truncates towards zero, so neighbouring
# points at the low end can collapse onto the same integer (the neuron grid
# would start 1, 1, ...). np.unique drops those repeats so the same
# configuration is not trained more than once.
# NOTE(review): because of the truncation the upper endpoint may land one below
# its nominal value if the floating-point power rounds down - confirm if the
# exact endpoints matter.
n_epochs = np.unique(np.logspace(math.log(100, 10), math.log(5000, 10), num=10).astype(int))
n_neurons = np.unique(np.logspace(math.log(1, 10), math.log(52, 10), num=10).astype(int))
n_batches = ["full_batch"]  # alternatives currently disabled: "half_batch", "online"
# Reference on selecting the number of hidden-layer neurons:
# http://firsttimeprogrammer.blogspot.com/2015/09/selecting-number-of-neurons-in-hidden.html?m=1

indicators = "all"
# Each model type is listed once; the original list contained "stacked" twice,
# which trained and logged that model type twice per grid point.
model_types = ["vanilla", "stacked", "bi", "cnn", "conv"]
start_train_date = "01/01/2016"
end_train_start_test_date = "01/01/2018"
end_test_date = "01/01/2019"
# The function has no return statement, so `pred` is always None; the results
# are in ./experiments/optimisation.csv. The name is kept for compatibility.
pred = get_optimal_epochs_batch_neurons_params("AMZN", start_train_date, end_train_start_test_date, end_test_date, n_lags, n_seqs, n_epochs, n_batches, n_neurons, indicators, model_types)

Using TensorFlow backend.


Preprocessing the data
Retrieved price series and raw pd from disk


'supervised filtered pd '

Unnamed: 0_level_0,Share Price(t-5),AD(t-5),ADOSC(t-5),ADX(t-5),ADXR(t-5),APO(t-5),AROON(t-5),AROONOSC(t-5),BBANDS(t-5),BOP(t-5),...,TEMA(t-1),TRANGE(t-1),TRIMA(t-1),TRIX(t-1),ULTSOC(t-1),WILLR(t-1),WMA(t-1),Share Price(t),Share Price(t+1),Share Price(t+2)
date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
1998-06-25,-0.0254,-7.758414e+05,-3.058647e+05,-0.4751,-0.5987,0.1448,-5.0,0.0,-0.3974,-0.6992,...,0.1942,0.4400,0.0720,-0.0014,4.8707,13.0178,0.0877,0.0097,-0.0116,0.0157
1998-06-26,-0.0079,-1.246000e+05,-1.370203e+05,-0.2236,-0.6879,0.1925,-5.0,5.0,-0.2589,0.0564,...,0.2412,-0.7500,0.0670,0.0017,-2.2817,9.1715,0.1101,-0.0116,0.0157,0.0000
1998-06-29,0.0098,6.378333e+05,1.871012e+05,-0.2868,-0.5296,0.1501,-5.0,65.0,-0.0520,0.9061,...,0.1019,0.3100,0.0643,0.0028,1.4522,-10.9467,0.0696,0.0157,0.0000,0.0393
1998-06-30,0.0138,5.928161e+05,2.883635e+05,-0.6110,-0.4637,0.1456,-5.0,0.0,0.0467,-0.0315,...,0.2127,-0.1200,0.0615,0.0052,5.6022,14.2466,0.1096,0.0000,0.0393,-0.0295
1998-07-01,0.0138,1.063681e+06,4.533947e+05,-0.4366,-0.6405,0.0930,-5.0,0.0,0.1075,-0.0436,...,0.1728,-0.0600,0.0563,0.0068,3.7042,0.0000,0.0981,0.0393,-0.0295,0.0430
1998-07-02,0.0097,0.000000e+00,1.412459e+05,-0.3107,-0.4916,0.0570,-5.0,0.0,0.1540,-0.3817,...,0.4634,0.8100,0.0627,0.0117,5.9134,8.2859,0.2084,-0.0295,0.0430,0.0041
1998-07-06,-0.0116,1.140345e+04,4.336466e+03,-0.5927,-0.4525,0.0635,-5.0,0.0,0.0563,-0.3563,...,0.1318,-0.4400,0.0615,0.0121,-8.0778,-27.2943,0.1016,0.0430,0.0041,0.0647
1998-07-07,0.0157,1.008984e+06,2.653458e+05,-0.4907,-0.3936,0.0802,-5.0,0.0,0.0734,0.9430,...,0.4579,0.3100,0.0871,0.0163,6.0270,28.7263,0.2217,0.0041,0.0647,-0.0274
1998-07-08,0.0000,7.631739e+05,2.697877e+05,-0.4661,-0.3740,0.0160,10.0,-15.0,0.1120,-0.4853,...,0.4021,-0.5000,0.1138,0.0195,-0.9484,-8.9806,0.2174,0.0647,-0.0274,0.0117
1998-07-09,0.0393,2.580232e+06,8.252336e+05,0.2089,0.0495,0.0783,-5.0,75.0,0.3311,0.6119,...,0.8550,1.5700,0.1559,0.0275,5.2414,2.8417,0.3981,-0.0274,0.0117,0.0590


'test supervised values'

array([[ 0.00000000e+00,  1.66780416e+06, -6.44569745e+05, ...,
         2.97250000e+00, -2.95000000e-02,  7.84900000e-01],
       [-4.35580000e+00, -1.65669181e+05, -8.15503647e+05, ...,
        -2.95000000e-02,  7.84900000e-01,  1.93260000e+00],
       [ 2.95000000e-02,  1.43805271e+07,  3.81704073e+06, ...,
         7.84900000e-01,  1.93260000e+00, -6.37700000e-01],
       ...,
       [-5.15790000e+00, -2.79264514e+07, -7.02369781e+06, ...,
        -1.01570000e+00,  7.97000000e-02,  1.50360000e+00],
       [-4.04270000e+00, -3.56679295e+07, -1.09380715e+07, ...,
         7.97000000e-02,  1.50360000e+00,  1.79200000e-01],
       [-6.07390000e+00, -7.10506390e+07, -2.28031708e+07, ...,
         1.50360000e+00,  1.79200000e-01, -1.56629000e+01]])

Preprocessed data in  0.010910336176554363 mins
Fitting the model
train X size 4912  train X data dimension (4912, 5, 52) train y size 4912  train X data dimension (4912, 3)
Training model with batch size 4912
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
lstm_1 (LSTM)                (4912, 1)                 216       
_________________________________________________________________
dense_1 (Dense)              (4912, 3)                 6         
Total params: 222
Trainable params: 222
Non-trainable params: 0
_________________________________________________________________


New model with batch size 1 for prediction
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
lstm_2 (LSTM)                (1, 1)                    216       
_________________________________________________________________
dense_2 (Dense)      

'Len(predictions)'

251

2018-01-02     [0.024825256, 0.03713826, 0.0055663027]
2018-01-03      [0.03704686, 0.044518873, 0.026067372]
2018-01-04      [0.036134977, 0.04396819, 0.024537737]
2018-01-05     [0.039557412, 0.046034995, 0.030278686]
2018-01-08        [0.0725749, 0.06597422, 0.085663706]
2018-01-09       [0.04680986, 0.05041474, 0.042444266]
2018-01-10      [0.055514142, 0.055671245, 0.05704522]
2018-01-11        [0.07684396, 0.06855231, 0.09282483]
2018-01-12      [0.061424468, 0.059240486, 0.06695947]
2018-01-16        [0.06735689, 0.06282307, 0.07691078]
2018-01-17      [0.029721241, 0.04009494, 0.013779048]
2018-01-18       [0.06459017, 0.061152253, 0.07226977]
2018-01-19        [0.07232131, 0.06582108, 0.08523832]
2018-01-22     [0.053063013, 0.054191012, 0.052933585]
2018-01-23       [0.05948942, 0.05807191, 0.063713536]
2018-01-24       [0.068538435, 0.06353661, 0.07889276]
2018-01-25      [0.03920406, 0.045821607, 0.029685957]
2018-01-26     [0.034562934, 0.043018833, 0.021900719]
2018-01-29

'Len(actual)'

249

array([[ 2.97250e+00, -2.95000e-02,  7.84900e-01],
       [-2.95000e-02,  7.84900e-01,  1.93260e+00],
       [ 7.84900e-01,  1.93260e+00, -6.37700e-01],
       [ 1.93260e+00, -6.37700e-01, -1.96000e-02],
       [-6.37700e-01, -1.96000e-02, -3.93000e-02],
       [-1.96000e-02, -3.93000e-02,  9.71200e-01],
       [-3.93000e-02,  9.71200e-01,  1.77570e+00],
       [ 9.71200e-01,  1.77570e+00, -8.82900e-01],
       [ 1.77570e+00, -8.82900e-01,  2.85470e+00],
       [-8.82900e-01,  2.85470e+00,  1.57000e-01],
       [ 2.85470e+00,  1.57000e-01, -7.84800e-01],
       [ 1.57000e-01, -7.84800e-01, -1.43230e+00],
       [-7.84800e-01, -1.43230e+00,  3.92000e-02],
       [-1.43230e+00,  3.92000e-02, -2.76640e+00],
       [ 3.92000e-02, -2.76640e+00, -3.05100e+00],
       [-2.76640e+00, -3.05100e+00,  3.92400e-01],
       [-3.05100e+00,  3.92400e-01, -3.48260e+00],
       [ 3.92400e-01, -3.48260e+00, -9.71200e-01],
       [-3.48260e+00, -9.71200e-01,  4.51300e-01],
       [-9.71200e-01,  4.51300e

'Len(predictions)'

249

array([[ -2.07905377,  -2.24966626,  -2.66012933],
       [  2.80119315,   2.6866508 ,   2.43193333],
       [ -0.20773439,  -0.32646025,  -0.59279829],
       [  0.63266568,   0.52964124,   0.30691691],
       [  2.03119788,   2.07965073,   2.27768362],
       [ -0.73483784,  -0.80458957,  -0.9348926 ],
       [ -0.05061184,  -0.08043018,  -0.09981048],
       [  0.09172974,   0.15976819,   0.4122038 ],
       [  0.98508861,   0.98238562,   1.03832337],
       [  1.83465694,   1.85917068,   1.99070802],
       [ -1.10985921,  -1.25800995,  -1.60608121],
       [  2.8926383 ,   2.90445891,   3.00073872],
       [  0.25367138,   0.30096085,   0.49576212],
       [ -0.83443295,  -0.87549654,  -0.92611274],
       [ -1.43311184,  -1.44469243,  -1.41341389],
       [  0.10713306,   0.13706749,   0.28366184],
       [ -2.92131871,  -3.02596425,  -3.25319152],
       [ -3.24117712,  -3.3671152 ,  -3.65348653],
       [  0.43112806,   0.4434256 ,   0.54103017],
       [ -3.54994376,  -3.60170


Calculating trend score for  1
Correct counts:  112   Size of test set: 251

Calculating trend score for  2
Correct counts:  169   Size of test set: 251

Calculating trend score for  3
Correct counts:  159   Size of test set: 251


'Len(actual)'

249

array([[ 2.97250e+00, -2.95000e-02,  7.84900e-01],
       [-2.95000e-02,  7.84900e-01,  1.93260e+00],
       [ 7.84900e-01,  1.93260e+00, -6.37700e-01],
       [ 1.93260e+00, -6.37700e-01, -1.96000e-02],
       [-6.37700e-01, -1.96000e-02, -3.93000e-02],
       [-1.96000e-02, -3.93000e-02,  9.71200e-01],
       [-3.93000e-02,  9.71200e-01,  1.77570e+00],
       [ 9.71200e-01,  1.77570e+00, -8.82900e-01],
       [ 1.77570e+00, -8.82900e-01,  2.85470e+00],
       [-8.82900e-01,  2.85470e+00,  1.57000e-01],
       [ 2.85470e+00,  1.57000e-01, -7.84800e-01],
       [ 1.57000e-01, -7.84800e-01, -1.43230e+00],
       [-7.84800e-01, -1.43230e+00,  3.92000e-02],
       [-1.43230e+00,  3.92000e-02, -2.76640e+00],
       [ 3.92000e-02, -2.76640e+00, -3.05100e+00],
       [-2.76640e+00, -3.05100e+00,  3.92400e-01],
       [-3.05100e+00,  3.92400e-01, -3.48260e+00],
       [ 3.92400e-01, -3.48260e+00, -9.71200e-01],
       [-3.48260e+00, -9.71200e-01,  4.51300e-01],
       [-9.71200e-01,  4.51300e

'Len(predictions)'

249

array([[ -2.07905377,  -2.24966626,  -2.66012933],
       [  2.80119315,   2.6866508 ,   2.43193333],
       [ -0.20773439,  -0.32646025,  -0.59279829],
       [  0.63266568,   0.52964124,   0.30691691],
       [  2.03119788,   2.07965073,   2.27768362],
       [ -0.73483784,  -0.80458957,  -0.9348926 ],
       [ -0.05061184,  -0.08043018,  -0.09981048],
       [  0.09172974,   0.15976819,   0.4122038 ],
       [  0.98508861,   0.98238562,   1.03832337],
       [  1.83465694,   1.85917068,   1.99070802],
       [ -1.10985921,  -1.25800995,  -1.60608121],
       [  2.8926383 ,   2.90445891,   3.00073872],
       [  0.25367138,   0.30096085,   0.49576212],
       [ -0.83443295,  -0.87549654,  -0.92611274],
       [ -1.43311184,  -1.44469243,  -1.41341389],
       [  0.10713306,   0.13706749,   0.28366184],
       [ -2.92131871,  -3.02596425,  -3.25319152],
       [ -3.24117712,  -3.3671152 ,  -3.65348653],
       [  0.43112806,   0.4434256 ,   0.54103017],
       [ -3.54994376,  -3.60170

t+1 RMSE: 4.685704
t+2 RMSE: 4.951855
t+3 RMSE: 4.944601


'Len(actual)'

249

array([[ 2.97250e+00, -2.95000e-02,  7.84900e-01],
       [-2.95000e-02,  7.84900e-01,  1.93260e+00],
       [ 7.84900e-01,  1.93260e+00, -6.37700e-01],
       [ 1.93260e+00, -6.37700e-01, -1.96000e-02],
       [-6.37700e-01, -1.96000e-02, -3.93000e-02],
       [-1.96000e-02, -3.93000e-02,  9.71200e-01],
       [-3.93000e-02,  9.71200e-01,  1.77570e+00],
       [ 9.71200e-01,  1.77570e+00, -8.82900e-01],
       [ 1.77570e+00, -8.82900e-01,  2.85470e+00],
       [-8.82900e-01,  2.85470e+00,  1.57000e-01],
       [ 2.85470e+00,  1.57000e-01, -7.84800e-01],
       [ 1.57000e-01, -7.84800e-01, -1.43230e+00],
       [-7.84800e-01, -1.43230e+00,  3.92000e-02],
       [-1.43230e+00,  3.92000e-02, -2.76640e+00],
       [ 3.92000e-02, -2.76640e+00, -3.05100e+00],
       [-2.76640e+00, -3.05100e+00,  3.92400e-01],
       [-3.05100e+00,  3.92400e-01, -3.48260e+00],
       [ 3.92400e-01, -3.48260e+00, -9.71200e-01],
       [-3.48260e+00, -9.71200e-01,  4.51300e-01],
       [-9.71200e-01,  4.51300e

'Len(predictions)'

249

array([[ -2.07905377,  -2.24966626,  -2.66012933],
       [  2.80119315,   2.6866508 ,   2.43193333],
       [ -0.20773439,  -0.32646025,  -0.59279829],
       [  0.63266568,   0.52964124,   0.30691691],
       [  2.03119788,   2.07965073,   2.27768362],
       [ -0.73483784,  -0.80458957,  -0.9348926 ],
       [ -0.05061184,  -0.08043018,  -0.09981048],
       [  0.09172974,   0.15976819,   0.4122038 ],
       [  0.98508861,   0.98238562,   1.03832337],
       [  1.83465694,   1.85917068,   1.99070802],
       [ -1.10985921,  -1.25800995,  -1.60608121],
       [  2.8926383 ,   2.90445891,   3.00073872],
       [  0.25367138,   0.30096085,   0.49576212],
       [ -0.83443295,  -0.87549654,  -0.92611274],
       [ -1.43311184,  -1.44469243,  -1.41341389],
       [  0.10713306,   0.13706749,   0.28366184],
       [ -2.92131871,  -3.02596425,  -3.25319152],
       [ -3.24117712,  -3.3671152 ,  -3.65348653],
       [  0.43112806,   0.4434256 ,   0.54103017],
       [ -3.54994376,  -3.60170

t+1 APRE: -0.633255
t+2 APRE: 0.666891
t+3 APRE: 2.081457


(4912, 263)

array([[ 0.05959629,  0.1684831 ,  0.15942954, ...,  0.45087173,
         0.05571315,  0.16291406],
       [-0.51376539,  0.15402169,  0.15592674, ...,  0.05571315,
         0.16291406,  0.31398785],
       [ 0.06347942,  0.26875388,  0.2508574 , ...,  0.16291406,
         0.31398785, -0.0243453 ],
       ...,
       [-0.61934724, -0.06493969,  0.02870768, ..., -0.07410211,
         0.07008734,  0.25751782],
       [-0.47255148, -0.1260001 , -0.05150614, ...,  0.07008734,
         0.25751782,  0.0831847 ],
       [-0.73992194, -0.40507894, -0.29464718, ...,  0.25751782,
         0.0831847 , -2.00213902]])

(4912,)

date
2018-01-02     2.9725
2018-01-03    -0.0295
2018-01-04     0.7849
2018-01-05     1.9326
2018-01-08    -0.6377
2018-01-09    -0.0196
2018-01-10    -0.0393
2018-01-11     0.9712
2018-01-12     1.7757
2018-01-16    -0.8829
2018-01-17     2.8547
2018-01-18     0.1570
2018-01-19    -0.7848
2018-01-22    -1.4323
2018-01-23     0.0392
2018-01-24    -2.7664
2018-01-25    -3.0510
2018-01-26     0.3924
2018-01-29    -3.4826
2018-01-30    -0.9712
2018-01-31     0.4513
2018-02-01     0.3433
2018-02-02    -7.1418
2018-02-05    -3.9339
2018-02-06     6.4159
2018-02-07    -3.4238
2018-02-08    -4.3067
2018-02-09     1.8542
2018-02-12     6.2053
2018-02-13     1.6055
               ...   
2018-11-15     4.5903
2018-11-16     2.1110
2018-11-19    -7.6373
2018-11-20    -8.8421
2018-11-21    -0.1992
2018-11-23    -4.4708
2018-11-26     2.3201
2018-11-27    -0.3784
2018-11-28     6.6714
2018-11-29    -1.3841
2018-11-30    -0.9658
2018-12-03     6.2133
2018-12-04    -8.0953
2018-12-06    -1.9616
2018-