In [206]:

from keras.models import Sequential
from keras.layers import LSTM, Dense, Dropout
import pandas as pd
import numpy as np
from predict import date_to_int
from sklearn.preprocessing import MinMaxScaler, StandardScaler, RobustScaler

TRAIN_RATIO = 0.8
TIMESTEP = 12


In [207]:
def create_model(input_shape):
    
    model = Sequential()
    model.add(LSTM(units=50, input_shape=input_shape, return_sequences=False))
    model.add(Dropout(0.2))
    model.add(Dense(units=1))
    model.compile(optimizer='adam',loss='mean_squared_error')
        
    return model

In [208]:
df = pd.read_csv('data/yf_attributes/AAPL.csv')
df['Date'] = [date_to_int(date) for date in df['Date']]

X_sc = MinMaxScaler()
scaled_X = X_sc.fit_transform(df)

targets = df[['Close']]
y_sc = MinMaxScaler()
scaled_y = y_sc.fit_transform(targets)

In [209]:
def create_sequences(scaled_X, scaled_y, timestep):
    X, y = [], []
    for i in range(len(scaled_X) - timestep):
        X.append(scaled_X[i:i + timestep]) #Using all data between timesteps
        y.append(scaled_y[i + timestep]) #Using next timesteps close data as target
    
    return np.array(X), np.array(y)

In [210]:

X, y = create_sequences(scaled_X, scaled_y, TIMESTEP)

split_point = int(TRAIN_RATIO * len(X))
X_train, y_train = X[:split_point], y[:split_point]
X_test, y_test = X[split_point:], y[split_point:]



In [211]:
model = create_model((X_train.shape[1], X_train.shape[2]))
model.fit(X_train, y_train, epochs=10, batch_size=32)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.src.callbacks.History at 0x1b9c832c910>

In [212]:

loss = model.evaluate(X_test, y_test)
print("Test Loss:", loss)

Test Loss: 9.633042335510254


In [201]:

last_sequence = X_test[-1].reshape(1, X_test[-1].shape[0], X_test[-1].shape[1])

scaled_prediction = model.predict(last_sequence)

prediction = y_sc.inverse_transform(scaled_prediction)

last_y = y_sc.inverse_transform([y_test[-1]])
print(last_y, prediction)




[[180.75]] [[51.9936]]


In [202]:
test_predictions = model.predict(X_test)
test_predictions = y_sc.inverse_transform(test_predictions)

train_predictions = model.predict(X_train)
train_predictions = y_sc.inverse_transform(train_predictions)



train_target_values = y_sc.inverse_transform(y_train)
test_target_values = y_sc.inverse_transform(y_test)
print(test_target_values)

[[ 26.05249977]
 [ 26.52499962]
 [ 28.26250076]
 [ 28.38500023]
 [ 27.62999916]
 [ 28.95499992]
 [ 30.33749962]
 [ 34.24750137]
 [ 35.91500092]
 [ 35.91249847]
 [ 38.18999863]
 [ 36.00500107]
 [ 37.18249893]
 [ 41.        ]
 [ 38.52999878]
 [ 42.25999832]
 [ 42.96250153]
 [ 42.30749893]
 [ 41.85749817]
 [ 44.52999878]
 [ 41.94499969]
 [ 41.31499863]
 [ 46.71749878]
 [ 46.27750015]
 [ 47.57249832]
 [ 56.90750122]
 [ 56.43500137]
 [ 54.71500015]
 [ 44.64500046]
 [ 39.43500137]
 [ 41.61000061]
 [ 43.28749847]
 [ 47.48749924]
 [ 50.16749954]
 [ 43.76750183]
 [ 49.47999954]
 [ 53.25999832]
 [ 52.18500137]
 [ 55.99250031]
 [ 62.18999863]
 [ 66.8125    ]
 [ 73.41249847]
 [ 77.37750244]
 [ 68.33999634]
 [ 63.57249832]
 [ 73.44999695]
 [ 79.48500061]
 [ 91.19999695]
 [106.26000214]
 [129.03999329]
 [115.80999756]
 [108.86000061]
 [119.05000305]
 [132.69000244]
 [131.96000671]
 [121.26000214]
 [122.15000153]
 [131.46000671]
 [124.61000061]
 [136.96000671]
 [145.86000061]
 [151.83000183]
 [141.5 

In [203]:
print(test_predictions)

[[26.872982]
 [26.569664]
 [26.546322]
 [27.161312]
 [27.323883]
 [27.659855]
 [27.813541]
 [28.17834 ]
 [28.935535]
 [29.908094]
 [30.664635]
 [31.901613]
 [32.899326]
 [33.27334 ]
 [34.228157]
 [35.05977 ]
 [35.499035]
 [36.3427  ]
 [36.92878 ]
 [37.57807 ]
 [38.48694 ]
 [38.681664]
 [38.743896]
 [38.933487]
 [39.31376 ]
 [39.599167]
 [40.766953]
 [42.08364 ]
 [43.211765]
 [43.690994]
 [43.485523]
 [42.79779 ]
 [41.77313 ]
 [41.56621 ]
 [41.56438 ]
 [41.840046]
 [41.794674]
 [42.064877]
 [42.895733]
 [43.294975]
 [44.10669 ]
 [44.959328]
 [46.334385]
 [47.67264 ]
 [48.509995]
 [50.50307 ]
 [50.540337]
 [50.654476]
 [51.226807]
 [51.758263]
 [52.204243]
 [52.386868]
 [53.49925 ]
 [54.085766]
 [53.67122 ]
 [53.489193]
 [53.89283 ]
 [55.18344 ]
 [54.856155]
 [54.90177 ]
 [54.902325]
 [54.395416]
 [54.31245 ]
 [54.164574]
 [54.607513]
 [54.10277 ]
 [53.597927]
 [53.622883]
 [53.940895]
 [53.98676 ]
 [53.53815 ]
 [54.21916 ]
 [54.80177 ]
 [53.87698 ]
 [52.49629 ]
 [52.98363 ]
 [53.62741 ]

In [204]:
print(train_target_values)

[[ 0.103237  ]
 [ 0.111607  ]
 [ 0.12611599]
 [ 0.13504501]
 [ 0.165179  ]
 [ 0.160156  ]
 [ 0.13950901]
 [ 0.165179  ]
 [ 0.149554  ]
 [ 0.154576  ]
 [ 0.178571  ]
 [ 0.180804  ]
 [ 0.247768  ]
 [ 0.3125    ]
 [ 0.28794599]
 [ 0.35379499]
 [ 0.35267901]
 [ 0.36160699]
 [ 0.36830401]
 [ 0.48214301]
 [ 0.50446397]
 [ 0.34486601]
 [ 0.29464301]
 [ 0.375     ]
 [ 0.370536  ]
 [ 0.38392901]
 [ 0.35714301]
 [ 0.36607099]
 [ 0.370536  ]
 [ 0.41294599]
 [ 0.39620501]
 [ 0.35602701]
 [ 0.386161  ]
 [ 0.34486601]
 [ 0.33593801]
 [ 0.359375  ]
 [ 0.33705401]
 [ 0.323661  ]
 [ 0.31808001]
 [ 0.348214  ]
 [ 0.426339  ]
 [ 0.36830401]
 [ 0.354911  ]
 [ 0.39732099]
 [ 0.39732099]
 [ 0.41517901]
 [ 0.395089  ]
 [ 0.31473199]
 [ 0.30357099]
 [ 0.30357099]
 [ 0.359375  ]
 [ 0.35156301]
 [ 0.36830401]
 [ 0.39955401]
 [ 0.375     ]
 [ 0.33035699]
 [ 0.25892901]
 [ 0.27455401]
 [ 0.328125  ]
 [ 0.38392901]
 [ 0.495536  ]
 [ 0.51116103]
 [ 0.60714298]
 [ 0.49107099]
 [ 0.41964301]
 [ 0.370536  ]
 [ 0.41294

In [205]:
print(train_predictions)

[[ 9.49616134e-01]
 [ 1.08987486e+00]
 [ 1.10505927e+00]
 [ 9.98364151e-01]
 [ 9.47634399e-01]
 [ 9.22981918e-01]
 [ 8.79915893e-01]
 [ 8.25222671e-01]
 [ 7.66899765e-01]
 [ 6.48789108e-01]
 [ 5.96201599e-01]
 [ 6.23968780e-01]
 [ 7.05877960e-01]
 [ 8.10663879e-01]
 [ 8.59060943e-01]
 [ 7.31945693e-01]
 [ 6.06922805e-01]
 [ 5.03245056e-01]
 [ 3.26683670e-01]
 [ 2.76831299e-01]
 [ 2.91245133e-01]
 [ 3.73453766e-01]
 [ 5.56566894e-01]
 [ 6.97529495e-01]
 [ 6.68612182e-01]
 [ 7.26372421e-01]
 [ 6.88862503e-01]
 [ 5.97567260e-01]
 [ 5.62422454e-01]
 [ 4.67455536e-01]
 [ 4.60089356e-01]
 [ 4.73461777e-01]
 [ 4.87826020e-01]
 [ 5.58731735e-01]
 [ 6.67673767e-01]
 [ 6.99965179e-01]
 [ 7.00579345e-01]
 [ 9.03570831e-01]
 [ 9.26014602e-01]
 [ 7.53498733e-01]
 [ 5.98963439e-01]
 [ 5.56809127e-01]
 [ 5.12770355e-01]
 [ 4.14904267e-01]
 [ 3.55347306e-01]
 [ 2.94081360e-01]
 [ 3.50027710e-01]
 [ 3.98735672e-01]
 [ 5.36575973e-01]
 [ 6.00653350e-01]
 [ 4.76950318e-01]
 [ 4.71802384e-01]
 [ 4.8751321