In [1]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
import pandas_ta as ta
import getData

In [2]:
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

class PriceHistoryDataset(Dataset):
    """Torch ``Dataset`` over a preprocessed price-history dictionary.

    The input dictionary must provide:

    * ``'x'`` -- NumPy array of model inputs; the first axis indexes samples.
    * ``'y'`` -- NumPy array of targets with at least three axes; the last
      axis is indexed by price column (Open=0, High=1, Low=2, Close=3).
    * ``'columns'``, ``'initial price'``, ``'current date'`` -- stored
      unchanged on the instance as ``columns``, ``initial_price`` and
      ``current_date``.

    Args:
        dataset: Mapping with the keys described above.
        to_predict: Price column name(s) to keep along the last axis of
            ``y``. Either a single name (``'Close'``) or an iterable of
            names. Defaults to all four OHLC columns.

    Raises:
        KeyError: If ``dataset`` lacks a required key or ``to_predict``
            contains a name other than Open/High/Low/Close.
    """

    # Position of each price column along the last axis of ``dataset['y']``.
    _COLUMN_INDEX = {'Open': 0, 'High': 1, 'Low': 2, 'Close': 3}

    def __init__(self, dataset, to_predict=('Open', 'High', 'Low', 'Close')):
        # Keep only the requested target columns (fancy indexing -> copy).
        y = dataset['y'][:, :, self.__map_to_indices(to_predict)]
        x = dataset['x']
        self.columns = dataset['columns']
        self.initial_price = dataset['initial price']
        self.current_date = dataset['current date']

        # Store both arrays as float32 tensors.
        self.X = torch.from_numpy(x).float()
        self.y = torch.from_numpy(y).float()

    def __len__(self):
        """Return the number of samples (size of the first axis of ``X``)."""
        return self.X.shape[0]

    def __getitem__(self, idx):
        """Return the ``(input, target)`` tensor pair at position ``idx``."""
        return self.X[idx], self.y[idx]

    def __map_to_indices(self, args):
        """Translate price column name(s) into indices along ``y``'s last axis."""
        mapping = self._COLUMN_INDEX
        # A bare string would otherwise be iterated character by character.
        if isinstance(args, str):
            args = [args]
        else:
            args = list(args)
        unknown = [arg for arg in args if arg not in mapping]
        if unknown:
            raise KeyError(
                f"Unknown price column(s) {unknown}; "
                f"expected a subset of {list(mapping)}"
            )
        return [mapping[arg] for arg in args]

In [3]:
# What to evaluate: the ticker symbol and the price column(s) used as target.
tickers = 'BTC-USD'
to_predict = ['Close']

# Options handed to getData.preprocessor as `preprocess_param`.
# NOTE(review): key semantics are defined in getData -- confirm there.
v_preprocess_param = {
    'win_size': 21,
    'stride': 1,
    'split': False,
    'number_y': 0,
    'random_state': 420,
}

# Daily price history from 2018-01-01 onwards, as a DataFrame.
prices_df_val = getData.loader(
    tickers=tickers,
    interval="1d",
    start='2018-01-01',
).dataframe

# Turn the raw prices into the dictionary consumed by PriceHistoryDataset.
val_sets = getData.preprocessor(
    prices_df_val,
    preprocess_param=v_preprocess_param,
).dataset

# Wrap in a Dataset / DataLoader; order is preserved (no shuffling).
val_set = PriceHistoryDataset(val_sets, to_predict)
val_loader = DataLoader(val_set, batch_size=64, shuffle=False)

In [4]:
from model.LSTM_BTC.load_model import model



In [6]:
import plotly.graph_objects as go

# Pull the full validation tensors and per-sample metadata off the dataset.
x = val_loader.dataset.X
y = val_loader.dataset.y
ref = val_loader.dataset.initial_price
# Shift every date forward by one day.
# NOTE(review): presumably each prediction refers to the day after the
# sample's "current date" -- confirm against getData.preprocessor.
date = [d + pd.DateOffset(1) for d in val_loader.dataset.current_date]

# Put the network in inference mode so layers such as dropout / batch-norm
# behave deterministically; a no-op if load_model already did this.
if isinstance(model, torch.nn.Module):
    model.eval()

# No gradients are needed for prediction.
with torch.no_grad():
    y_hat = model(x)

# Explicit tensor -> NumPy conversion (works regardless of the output device).
y_hat_np = torch.as_tensor(y_hat).detach().cpu().numpy()

# Rescale the flattened model output: (output + 1) * initial price.
# NOTE(review): this assumes the model emits a relative change with respect
# to each sample's initial price -- confirm with the preprocessing code.
predict_out = np.multiply((y_hat_np + 1).flatten(), ref)

df = prices_df_val

# Actual OHLC candles overlaid with the predicted series.
fig = go.Figure(data=[
        go.Candlestick(
            x=df['Date'],
            open=df['Open'],
            high=df['High'],
            low=df['Low'],
            close=df['Close'],
            name='Actual Price'
        ),
        go.Scatter(
            x=date,
            y=predict_out,
            line=dict(color='blue'),
            name='Predicted close Price'
            )
    ]).update_layout(title_text=tickers+' price predictions', title_x=0.3, height=700)

fig.show()