In [3]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
import pandas_ta as ta
import getData

In [5]:
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

class PriceHistoryDataset(Dataset):
    """Torch ``Dataset`` wrapping a preprocessed price-history dict.

    Args:
        dataset: mapping produced by the preprocessing step. Keys read here:
            ``'x'`` (numpy array of model inputs), ``'y'`` (numpy array indexed
            with three axes whose last axis holds Open/High/Low/Close in that
            order), and ``'columns'``, ``'initial price'``, ``'current date'``
            (stored unchanged as attributes).
        to_predict: target column name(s) to keep from ``'y'``; any of
            ``'Open'``, ``'High'``, ``'Low'``, ``'Close'``. A single string is
            treated as one column. The default is a tuple so no mutable object
            is shared between calls.

    Raises:
        KeyError: if ``to_predict`` contains an unknown column name.
    """

    # Position of each OHLC column on the last axis of dataset['y'].
    _TARGET_INDEX = {'Open': 0, 'High': 1, 'Low': 2, 'Close': 3}

    def __init__(self, dataset, to_predict=('Open', 'High', 'Low', 'Close')):
        # List indexing on the last axis keeps it, so y stays 3-D even for one target.
        y = dataset['y'][:, :, self.__map_to_indices(to_predict)]
        x = dataset['x']
        self.columns = dataset['columns']
        self.initial_price = dataset['initial price']
        self.current_date = dataset['current date']

        # Cast once to float32 (torch's default parameter dtype) so batches need no conversion.
        self.X = torch.from_numpy(x).float()
        self.y = torch.from_numpy(y).float()

    def __len__(self):
        """Number of samples (size of the first axis of ``X``)."""
        return self.X.shape[0]

    def __getitem__(self, idx):
        """Return the ``(input, target)`` tensor pair for sample ``idx``."""
        return self.X[idx], self.y[idx]

    def __map_to_indices(self, args):
        """Translate target column names to positions on the last axis of ``'y'``."""
        if isinstance(args, str):
            # A bare string would otherwise be iterated character by character.
            args = [args]
        unknown = [arg for arg in args if arg not in self._TARGET_INDEX]
        if unknown:
            raise KeyError(
                f"Unknown target column(s) {unknown}; "
                f"expected a subset of {list(self._TARGET_INDEX)}"
            )
        return [self._TARGET_INDEX[arg] for arg in args]

In [36]:
# Preprocessing settings for the validation set, handed to getData.preprocessor.
v_preprocess_param = dict(
    win_size=22,
    stride=1,
    split=False,
    number_y=1,
    random_state=420,
)
to_predict = ['Close']

tickers = 'BTC-USD'

# Daily ("1d") price history for the ticker, starting 2023-01-01.
prices_df_val = getData.loader(
    tickers=tickers,
    interval="1d",
    start='2023-01-01',
).dataframe

val_sets = getData.preprocessor(
    prices_df_val,
    preprocess_param=v_preprocess_param,
).dataset

# Wrap as a torch Dataset; shuffle=False preserves the sample order.
val_set = PriceHistoryDataset(val_sets, to_predict)
val_loader = DataLoader(val_set, shuffle=False, batch_size=1)

In [24]:
import torch.nn as nn
import pytorch_lightning as pl

class LSTMModel(pl.LightningModule):
    """LSTM encoder followed by a stack of linear layers (regression head).

    The LSTM consumes a window of feature vectors. The output at the last
    time step goes through ``head_layers - 1`` hidden-to-hidden ``nn.Linear``
    layers, then a final ``nn.Linear`` projecting to ``output_size``.

    Args:
        hidden_size: LSTM hidden size; also the width of the head layers.
        lstm_layers: number of stacked LSTM layers.
        head_layers: total number of linear layers in the head.
        input_size: number of features per time step.
        output_size: number of predicted values per sample.
        dropout: dropout between stacked LSTM layers. It only has an effect
            when ``lstm_layers > 1``; it was previously accepted but ignored.
        lr: Adam learning rate.
        batch_first: if True (default), inputs are ``(batch, seq, feature)``.
            ``forward`` takes the last step via ``[:, -1:, :]``, which is the
            last *time step* only in batch-first layout.
            NOTE(review): checkpoints trained before this fix ran the LSTM
            seq-first. Load them with ``batch_first=False`` to reproduce their
            old outputs, or retrain.
    """

    def __init__(self, hidden_size, lstm_layers, head_layers, input_size=8, output_size=3, dropout=0.05,
                 lr=0.01, batch_first=True):
        super().__init__()

        # Attribute keeps the historical name "gru" so existing state_dict keys
        # ("gru.weight_ih_l0", ...) still match, even though this is an LSTM.
        self.gru = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=lstm_layers,
            batch_first=batch_first,
            # nn.LSTM warns if dropout > 0 with a single layer, where it is a no-op.
            dropout=dropout if lstm_layers > 1 else 0.0,
        )

        self.linears = nn.ModuleList([
            nn.Linear(hidden_size, hidden_size) for _ in range(head_layers - 1)
        ])

        self.out_linear = nn.Linear(hidden_size, output_size)

        self.lr = lr

        # Per-epoch loss buffers (detached tensors), emptied at each epoch end.
        self.train_losses = []
        self.test_losses = []
        self.loss_func = nn.L1Loss()

    def forward(self, x):
        """Return predictions shaped ``(batch, 1, output_size)``."""
        lstm_out, _ = self.gru(x)
        # Slice with -1: (not -1) to keep a length-1 time axis; presumably this
        # matches targets shaped (batch, 1, k) — TODO confirm against the dataset.
        o = lstm_out[:, -1:, :]

        # NOTE(review): there is no non-linearity between head layers, so the
        # stack is equivalent to a single affine map.
        for linear in self.linears:
            o = linear(o)

        return self.out_linear(o)

    def training_step(self, batch, batch_idx):
        """Compute the L1 loss for one training batch."""
        x, y = batch
        y_hat = self(x)
        loss = self.loss_func(y_hat, y)
        # Store a detached copy: the graph-attached loss would keep every
        # step's autograd graph alive for the whole run.
        self.train_losses.append(loss.detach())
        return loss

    def test_step(self, batch, batch_idx):
        """Compute the L1 loss for one test batch."""
        x, y = batch
        y_hat = self(x)
        loss = self.loss_func(y_hat, y)
        self.test_losses.append(loss.detach())
        return loss

    @staticmethod
    def _drain_mean(losses):
        """Return the mean of the collected losses and empty the list (None if empty)."""
        if not losses:
            return None
        avg = torch.stack(losses).mean()
        losses.clear()
        return avg

    def on_test_epoch_end(self):
        """Print the mean test loss of the epoch that just finished."""
        avg_loss = self._drain_mean(self.test_losses)
        if avg_loss is None:
            return None
        print(f'Test Loss: {avg_loss}')
        return {'L1_loss': avg_loss}

    def on_train_epoch_end(self):
        """Print the mean training loss of the epoch that just finished."""
        avg_loss = self._drain_mean(self.train_losses)
        if avg_loss is None:
            return None
        print(f'Train Loss: {avg_loss}')
        return {'L1_loss': avg_loss}

    def configure_optimizers(self):
        """Adam over all parameters with the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)


# Initialize the model; these hyper-parameters must match the ones used when
# the checkpoint was trained, otherwise load_state_dict reports mismatches.
model = LSTMModel(output_size=1, hidden_size=128, lstm_layers=5, head_layers=2, dropout=0.0)

CHECKPOINT_PATH = "model/BTC/checkpoints/epoch=1999-step=18000.ckpt"
# map_location="cpu" lets a checkpoint saved from a GPU run load on a CPU-only machine.
checkpoint = torch.load(CHECKPOINT_PATH, map_location="cpu")
model.load_state_dict(checkpoint['state_dict'])

<All keys matched successfully>

In [37]:
# Run the model over the whole validation set in one batch.
x = val_loader.dataset.X
y = val_loader.dataset.y
ref = val_loader.dataset.initial_price
date = val_loader.dataset.current_date

# eval() disables training-only behaviour (e.g. dropout) before inference;
# no_grad() skips building the autograd graph.
model.eval()
with torch.no_grad():
    y_hat = model(x)

# Map outputs back to prices as (y_hat + 1) * ref. Presumably the targets are
# relative to the window's initial price — TODO confirm against getData.preprocessor.
predict_out = np.multiply((y_hat.cpu().numpy() + 1).flatten(), ref)


import plotly.graph_objects as go

# Only draw candles from this row onward to keep the chart readable.
PLOT_START_ROW = 200
df = prices_df_val.iloc[PLOT_START_ROW:]

fig = go.Figure(data=[
        go.Candlestick(
            x=df['Date'],
            open=df['Open'],
            high=df['High'],
            low=df['Low'],
            close=df['Close'],
            name='Actual Price'
        ),
        go.Scatter(
            x=date,
            y=predict_out,
            line=dict(color='blue'),
            name='Predicted Price'
            )
    ]).update_layout(title_text=tickers+' price predictions', title_x=0.3)

fig.show()