In [10]:
import IPython

# Enable IPython's autoreload extension in mode 2 (reload modules before
# executing code) so edits to the local `modules` package are picked up
# without restarting the kernel.
ipython = IPython.get_ipython()
run_magic = ipython.run_line_magic
run_magic('load_ext', 'autoreload')
run_magic('autoreload', '2')

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [11]:
from dotenv import load_dotenv
import sys, os, time

import pandas as pd
import numpy as np

from modules.database import DB
from modules.data_fetcher import DataFetcher
from modules.trading_env import TradingEnv

In [12]:
# Connect to the Postgres database. Connection settings are read from the
# POSTGRES_* environment variables, which load_dotenv() can populate from a
# .env file. A missing variable comes through as None.
load_dotenv()
postgres_env_keys = (
    "POSTGRES_DB",
    "POSTGRES_USER",
    "POSTGRES_PASSWORD",
    "POSTGRES_HOST",
    "POSTGRES_PORT",
)
DB_NAME, DB_USER, DB_PASS, DB_HOST, DB_PORT = (os.getenv(key) for key in postgres_env_keys)

db = DB(DB_NAME, DB_USER, DB_PASS, DB_HOST, DB_PORT)
res = db.list_tables()  # sanity check that the connection works
print(res)

[('stock_data',)]


In [13]:
# Fetch one year of AAPL data to inspect the frame layout DataFetcher returns.
# The same frame is then handed to create_table_from_df to build the table.
stock_data_table_name = "stock_data"
data_fetcher = DataFetcher(db)

test_ticker = "AAPL"
start_date, end_date = pd.to_datetime("2020-01-01"), pd.to_datetime("2021-01-01")
df = data_fetcher.fetch_stock_data(test_ticker, start_date, end_date)
print(df.head(), df.columns, sep="\n")

# (Re)create the table. NOTE(review): drop_table presumably discards any rows
# already ingested, so re-running this cell wipes the table.
db.drop_table(stock_data_table_name)
# NOTE(review): the reason for this pause between drop and create is not
# visible here. Presumably it waits for the drop to settle; confirm or remove.
time.sleep(1)
print("Creating table")
# NOTE(review): if create_table_from_df also inserts df's rows, the AAPL 2020
# rows may be duplicated by the 2015-2022 ingest below. Verify.
db.create_table_from_df(stock_data_table_name, df)

[*********************100%***********************]  1 of 1 completed

        Date Ticker      Close       High        Low       Open     Volume
0 2020-01-02   AAPL  72.716072  72.776598  71.466812  71.721019  135480400
1 2020-01-03   AAPL  72.009132  72.771760  71.783977  71.941343  146322800
2 2020-01-06   AAPL  72.582901  72.621639  70.876068  71.127858  118387200
3 2020-01-07   AAPL  72.241554  72.849231  72.021238  72.592601  108872000
4 2020-01-08   AAPL  73.403648  73.706279  71.943759  71.943759  132079200
Index(['Date', 'Ticker', 'Close', 'High', 'Low', 'Open', 'Volume'], dtype='object')





Creating table


In [14]:
# Ingest each ticker's history, from 2015-01-01 up to 2022-01-01, into the
# stock-data table.
tickers = ["AAPL", "MSFT", "GOOGL", "AMZN", "TSLA"]
start_date, end_date = (pd.to_datetime(day) for day in ("2015-01-01", "2022-01-01"))

for ticker in tickers:
    data_fetcher.ingest_stock_data_to_db(stock_data_table_name, ticker, start_date, end_date)

[*********************100%***********************]  1 of 1 completed
[*********************100%***********************]  1 of 1 completed
[*********************100%***********************]  1 of 1 completed
[*********************100%***********************]  1 of 1 completed
[*********************100%***********************]  1 of 1 completed


In [15]:
# Add an RSI column for every ingested ticker.
RSI_PERIOD = 14  # presumably the RSI lookback window; confirm against DB.add_rsi_column
for ticker in tickers:
    db.add_rsi_column(stock_data_table_name, ticker, RSI_PERIOD)


In [16]:
# Smoke-test reading one ticker back from the table and time the query.
# `time` is already imported in the imports cell. perf_counter is a monotonic,
# high-resolution clock, which makes it the right tool for elapsed-time
# measurement; time.time() follows the wall clock and can jump.
before = time.perf_counter()
data = db.get_ticker_data(stock_data_table_name, "MSFT", limit=2000)
after = time.perf_counter()
print(f"Time taken: {after - before}")
print(data.head(25))
print(data.tail())

Time taken: 0.004533529281616211
         date ticker      close       high        low       open     volume  \
0  2015-01-02   MSFT  40.152470  40.719207  39.963559  40.066602   27913900   
1  2015-01-05   MSFT  39.783249  40.126724  39.714552  39.817594   39673900   
2  2015-01-06   MSFT  39.199333  40.143894  39.104876  39.826179   36447900   
3  2015-01-07   MSFT  39.697384  39.894883  39.061953  39.482711   29114100   
4  2015-01-08   MSFT  40.865196  41.002587  40.118134  40.143894   29645200   
5  2015-01-09   MSFT  40.521721  41.062698  40.272702  40.882374   23944200   
6  2015-01-12   MSFT  40.015083  40.822257  39.808999  40.719211   23651900   
7  2015-01-13   MSFT  39.809013  41.139987  39.551406  40.332817   35270600   
8  2015-01-14   MSFT  39.465527  39.705963  39.173571  39.465527   29719600   
9  2015-01-15   MSFT  39.053356  39.826181  38.993248  39.688790   32750800   
10 2015-01-16   MSFT  39.705959  39.740305  38.787155  38.907375   35695300   
11 2015-01-20   MSF

In [None]:
# Drop NaN rows from the table, then re-read MSFT to check the result.
# NOTE(review): presumably this removes the RSI warm-up rows at the start of
# each ticker's history. Confirm against DB.drop_nans.
db.drop_nans(stock_data_table_name)

# Time this fetch. Previously this cell printed `after - before` left over
# from the earlier test-fetch cell, so the reported time was stale.
before = time.perf_counter()
data = db.get_ticker_data(stock_data_table_name, "MSFT", limit=2000)
after = time.perf_counter()
print(f"Time taken: {after - before}")
print(data.head(25))

In [None]:
# Build the trading environment over the ingested tickers, using the
# `close` and `rsi` table columns as inputs.
headers = ["close", "rsi"]
window_size = 10        # presumably past rows per observation; TODO confirm against TradingEnv
simulation_length = 60  # presumably steps per simulated episode; TODO confirm
n_envs = 3              # presumably number of parallel environments; TODO confirm

te = TradingEnv(
    db,
    stock_data_table_name,
    tickers,
    window_size,
    simulation_length,
    n_envs,
    headers,
)

