Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 108 additions & 57 deletions 1_HomePage.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,23 @@
from datetime import datetime
import time

import pandas as pd
from BatchProcess.DataSource.YahooFinance.YahooFinances_Services import YahooFinance
from BatchProcess.DataSource.ListSnP500.ListSnP500Collect import ListSAndP500
from BatchProcess.BatchProcess import BatchProcessManager
from multiprocessing.pool import ThreadPool
import plotly.graph_objects as go
from dotenv import load_dotenv
from pathlib import Path
import streamlit as st
import pandas as pd
import time
import os

from Database.PostGreSQLInteraction import StockDatabaseManager

load_dotenv(override=True)
pool = ThreadPool(processes=6)
load_dotenv(override=True)

# Resolve the app directory; fall back to the working directory when
# __file__ is not defined in this scope.
current_dir = Path(__file__).parent if "__file__" in locals() else Path.cwd()
# CSS_DIR comes from the environment loaded by load_dotenv() above.
css_file = current_dir / os.getenv("CSS_DIR")
defaut_start_date = "2014-01-01"

# Configure the page exactly once (the earlier duplicate call without
# initial_sidebar_state is dropped); the sidebar starts collapsed.
st.set_page_config(
    page_title="Home Page",
    page_icon=":house:",
    initial_sidebar_state="collapsed",
)
st.sidebar.header("Quantitative Trading Project")
st.title("Welcome to the Home Page")
st.markdown(
Expand All @@ -36,70 +34,123 @@
st.markdown("<style>{}</style>".format(f.read()), unsafe_allow_html=True)


# --- CACHE DATA ---
@st.cache_data(ttl=1800)
def retrieve_list_ticket():
    """Return the ticker list, preferring the copy stored in the database.

    The list returned by ``BatchProcessManager().get_stock_list_in_database()``
    is used when it is not ``None`` and holds at least 497 entries; otherwise
    ``ListSAndP500().tickers_list`` is returned instead.
    """
    # Below this many tickers the stored list is treated as incomplete.
    min_expected_symbols = 497

    list_of_symbols__ = BatchProcessManager().get_stock_list_in_database()
    if list_of_symbols__ is None or len(list_of_symbols__) < min_expected_symbols:
        # Only construct ListSAndP500 when the database copy is unusable;
        # the previous unconditional call was overwritten immediately.
        list_of_symbols__ = ListSAndP500().tickers_list
    return list_of_symbols__


# Time budget for the update progress bar (divided into 100 steps by its users).
PROCESS_TIME = 90 # seconds
# Ticker list obtained from retrieve_list_ticket().
_list_of_symbols = retrieve_list_ticket()


@st.cache_data(ttl=1800)
def retrieve_data_from_yahoo(list_of_symbols, date_from, date_to):
    """Build a ``YahooFinance`` object for the arguments and return its ``process_data()`` result."""
    return YahooFinance(list_of_symbols, date_from, date_to).process_data()
def batch_process(list_of_symbols__):
    """Run ``BatchProcessManager.run_process`` on the symbols and return its result."""
    manager = BatchProcessManager()
    return manager.run_process(list_of_symbols__)


@st.cache_data(ttl=1800)
def update_datebase_func(list_of_symbols=retrieve_list_ticket(), date_from=defaut_start_date, date_to=datetime.now().strftime('%Y-%m-%d')):
    """Run ``retrieve_data_from_yahoo`` on the worker pool behind a progress bar.

    Args:
        list_of_symbols: Symbols forwarded to ``retrieve_data_from_yahoo``.
        date_from: Start date forwarded to ``retrieve_data_from_yahoo``.
        date_to: End date forwarded to ``retrieve_data_from_yahoo``.

    Returns:
        Whatever ``retrieve_data_from_yahoo`` returns for these arguments.

    Note:
        The default values are evaluated when this ``def`` statement runs,
        not on each call.
    """
    st.write("Database Updated")
    # Retrieve data from Yahoo Finance on a pool thread.
    async_result = pool.apply_async(
        retrieve_data_from_yahoo, args=(list_of_symbols, date_from, date_to,))
    bar = st.progress(0)
    per = PROCESS_TIME / 100
    for i in range(100):
        # Wait at most one step and stop as soon as the job has finished,
        # instead of always sleeping for the whole PROCESS_TIME.
        async_result.wait(per)
        if async_result.ready():
            break
        bar.progress(i + 1)
    bar.progress(100)
    # Blocks only if the job outlived the progress bar; re-raises any
    # exception raised in the worker.
    df = async_result.get()
    return df
def batch_process_retrieve_data_by_stock(the_stock_in):
    """Return ``BatchProcessManager.get_stock_data_by_ticker`` for the given ticker."""
    manager = BatchProcessManager()
    return manager.get_stock_data_by_ticker(the_stock_in)


@st.cache_data(ttl=1800)
def process_data_retrieve_from_database(df_in, list_of_symbols__):
    """Split a combined frame into one re-indexed DataFrame per symbol.

    Args:
        df_in: DataFrame with a ``stock_id`` column.
        list_of_symbols__: Symbols to extract; each becomes a key of the result.

    Returns:
        dict mapping every symbol to the rows of ``df_in`` whose ``stock_id``
        equals it, after ``reset_index()`` (so the old index is kept as a
        column). Symbols without rows map to an empty DataFrame.
    """
    return {
        symbol: df_in[df_in['stock_id'] == symbol].reset_index()
        for symbol in list_of_symbols__
    }


# Full refresh flow; the body runs only when the button reports a click.
update_database = st.button("Update Database")
if update_database:
# Fetch the combined data for every symbol via update_datebase_func.
df_historical_yahoo = update_datebase_func(
list_of_symbols=_list_of_symbols)
# Split the combined frame into one DataFrame per symbol.
total_data_dict_ = process_data_retrieve_from_database(
df_historical_yahoo, _list_of_symbols)
db_manager = StockDatabaseManager()
db_manager.create_schema_and_tables(_list_of_symbols)
# Insert each symbol's frame under that symbol's key.
for key, value in total_data_dict_.items():
if isinstance(value, pd.DataFrame):
db_manager.insert_data(key, value)
# Read everything back and preview the first 10 rows of each table.
all_data = db_manager.fetch_all_data()
for table, df in all_data.items():
st.write(f"Data for table {table}:")
st.write(df.head(10))
# NOTE(review): not reached if any call above raises - consider try/finally.
db_manager.close_connection()
st.write("Done")
# Time budget for the update progress bar (divided into 100 steps further down).
PROCESS_TIME = 180 # seconds
# Ticker list used by the selector and the "List 500 S&P" tab below.
_list_of_symbols = retrieve_list_ticket()


# --- MAIN PAGE ---
# Make sure the key exists in session_state before any widget reads it.
if "stock_data" not in st.session_state:
st.session_state.stock_data = None
st.markdown('---')
st.markdown("### I. Retrieve stock data from database if available")

the_stock = st.selectbox(
"Select the stock you want to retrieve from database", _list_of_symbols)
btn_prepare = st.button("Retrieve stock data from database...")

# Only remember the selection here; the "Historical data" tab below reads
# st.session_state.stock_data and performs the retrieval and rendering.
if btn_prepare:
st.session_state.stock_data = the_stock

st.markdown('---')
# --- TABS ---
st.markdown(
"### II. List of 500 S&P, Historical data, In Day Data, Top News, Reddit News")
# One tab object is unpacked per label, in the order the labels are listed.
# NOTE(review): IndayData_RealTime, news and reddit_news are not populated in
# the visible part of this file - confirm whether that is intentional.
List500, Historical_data, IndayData_RealTime, news, reddit_news = st.tabs(
["List 500 S&P", "Historical data", "In Day Data", "Top News", "Reddit News"])

# --- TABS LIST500 S&P CONTENT---
# Shows the ticker list held in _list_of_symbols as-is.
with List500:
st.write("List of 500 S&P")
st.write(_list_of_symbols)

# --- TABS HISTORICAL DATA CONTENT---
with Historical_data:
    if st.session_state.stock_data is not None:
        # Keep the raw result: the lookup returns None when it fails, but
        # pd.DataFrame(None) is an empty frame and never None, which made the
        # "No data found" branch unreachable and crashed on df['date'].
        raw_stock_data = batch_process_retrieve_data_by_stock(
            st.session_state.stock_data)
        df = pd.DataFrame(raw_stock_data) if raw_stock_data is not None else None
        if df is not None and not df.empty:
            fig = go.Figure(data=[go.Candlestick(x=df['date'],
                                                 open=df['open'],
                                                 high=df['high'],
                                                 low=df['low'],
                                                 close=df['close'])])
            # Add a title
            fig.update_layout(
                title=f"{st.session_state.stock_data} Price Candlestick Chart",
                # Horizontal position of the title
                title_x=0.3,

                # Customize the font and size of the title
                title_font=dict(size=24, family="Arial"),

                # Set the background color of the plot
                plot_bgcolor='white',

                # Customize the grid lines
                xaxis=dict(showgrid=True, gridwidth=1, gridcolor='lightgray'),
                yaxis=dict(showgrid=True, gridwidth=1, gridcolor='lightgray'),
            )

            # Add a range slider and customize it
            fig.update_layout(
                xaxis_rangeslider_visible=True,  # Show the range slider

                # Customize the range slider's appearance
                xaxis_rangeslider=dict(
                    thickness=0.1,  # Set the thickness of the slider
                    bordercolor='black',  # Set the border color
                    borderwidth=1,  # Set the border width
                )
            )

            # Display the chart in Streamlit
            st.plotly_chart(fig)
            st.markdown(
                f"#### Dataframe of {st.session_state.stock_data} Prices")
            st.write(df)
        else:
            st.write(
                "No data found for this stock, please update the database first.")
    else:
        st.write("Please select the stock to retrieve the data")

st.markdown('---')
# --- Set Up/ Update all data in database---
st.markdown("### III. Set Up data in database for the first time")
# Rebuild the whole database on a pool thread; runs only when the button
# reports a click.
update_database = st.button("Update Database")
if update_database:
    async_result = pool.apply_async(
        batch_process, args=(_list_of_symbols,))
    bar = st.progress(0)
    per = PROCESS_TIME / 100
    for i in range(100):
        # Wait at most one step and stop as soon as the batch job is done,
        # instead of always sleeping for the whole PROCESS_TIME.
        async_result.wait(per)
        if async_result.ready():
            break
        bar.progress(i + 1)
    bar.progress(100)
    # Blocks only if the job outlived the progress bar; re-raises any
    # exception raised in the worker.
    df_dict = async_result.get()
    st.write("Please check the data in the database")
77 changes: 77 additions & 0 deletions BatchProcess/BatchProcess.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
from Database.PostGreSQLInteraction import DatabaseManager, StockDatabaseManager, TicketDimDatabaseManager, RedditNewsDatabaseManager
from BatchProcess.DataSource.YahooFinance.YahooFinances_Services import YahooFinance
from BatchProcess.DataSource.ListSnP500.ListSnP500Collect import ListSAndP500
from datetime import datetime
import pandas as pd


# Start date (YYYY-MM-DD) passed to YahooFinance when loading data.
defaut_start_date = "2014-01-01"

# Today's date as YYYY-MM-DD.
# NOTE: evaluated once, when this module is first imported, so the value
# does not advance while a long-running process keeps the module loaded.
date_to = datetime.now().strftime('%Y-%m-%d')


class BatchProcessManager:
    """Load Yahoo Finance data into the database and read it back.

    Instance state:
        list_of_symbols: symbols given to the most recent ``run_process``
            call (``None`` before the first call).
        dict_ticket: maps each of those symbols to the DataFrame slice that
            was prepared for it.
    """

    def __init__(self):
        self.list_of_symbols = None
        self.dict_ticket = dict()

    @staticmethod
    def _close_quietly(db_manager):
        """Close ``db_manager`` if it was created; report, don't raise, on failure."""
        if db_manager is None:
            return
        try:
            db_manager.close_connection()
        except Exception as e:
            print(e)

    def run_process(self, list_of_symbols_):
        """Fetch data for the symbols, recreate the schema and insert the data.

        Args:
            list_of_symbols_: symbols to fetch and store.

        Returns:
            dict mapping each symbol to its rows of the fetched frame
            (``stock_id`` equal to the symbol, index reset).

        Raises:
            Any exception raised by the data source or the database layer;
            the database connection is closed before it propagates.
        """
        self.list_of_symbols = list_of_symbols_

        # Compute the end date per run: a module-level value is frozen at
        # import time and goes stale in a long-running process.
        date_to_ = datetime.now().strftime('%Y-%m-%d')

        # Get data from Yahoo Finance
        transformed_data = YahooFinance(
            self.list_of_symbols, defaut_start_date, date_to_)
        df = transformed_data.process_data()

        # Create Database Manager
        db_manager = DatabaseManager()
        try:
            # Drop the existing schema before recreating the tables
            db_manager.delete_schema()

            # Create Stock table
            db_manager.StockDatabaseManager.create_schema_and_tables(
                self.list_of_symbols)

            # Create TicketDim table
            db_manager.TicketDimDatabaseManager.create_table()

            # Create RedditNews table
            db_manager.RedditNewsDatabaseManager.create_schema_and_tables()

            # Rebuild the mapping from scratch so symbols from an earlier run
            # on this instance are not inserted into tables that no longer
            # exist.
            self.dict_ticket = {
                symbol: df[df['stock_id'] == symbol].reset_index()
                for symbol in self.list_of_symbols
            }

            # Insert data into the database Stock table
            for key, value in self.dict_ticket.items():
                if isinstance(value, pd.DataFrame):
                    db_manager.StockDatabaseManager.insert_data(key, value)

            # Insert data into the database TicketDim table
            db_manager.TicketDimDatabaseManager.insert_data(
                self.list_of_symbols)
        finally:
            # Release the connection even when a step above fails.
            db_manager.close_connection()
        return self.dict_ticket

    def get_stock_data_by_ticker(self, ticker):
        """Return the stored data for ``ticker``, or ``None`` if the lookup fails.

        Failures are printed rather than raised.
        """
        db_manager = None
        try:
            db_manager = StockDatabaseManager()
            return db_manager.get_data_by_table(ticker)
        except Exception as e:
            print(e)
            return None
        finally:
            # Previously skipped whenever the query raised, leaking the
            # connection.
            self._close_quietly(db_manager)

    def get_stock_list_in_database(self):
        """Return the stored ticker list, or ``None`` if the lookup fails.

        Failures are printed rather than raised.
        """
        db_manager = None
        try:
            db_manager = TicketDimDatabaseManager()
            return db_manager.get_data()
        except Exception as e:
            print(e)
            return None
        finally:
            self._close_quietly(db_manager)
8 changes: 7 additions & 1 deletion Database/.env
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,10 @@ DATABASE_PORT="5432"
DATABASE_NAME="postgres"
DATABASE_USER="postgres"
DATABASE_PASSWORD="admin"
CREATE_SCHEMA_QUERY="CREATE SCHEMA IF NOT EXISTS tickets"
CREATE_SCHEMA_QUERY="CREATE SCHEMA IF NOT EXISTS tickets;"

INSERT_QUERY_REDDIT_TABLE="INSERT INTO reddits.stock_reddit_news (id, subreddit, url, title, score, num_comments, downvotes, ups, date_created_utc) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s) ON CONFLICT (id) DO NOTHING;"
POSTGRE_CONNECTION="dbname=postgres user=postgres host=localhost password=admin"
CONFIGURE_REDDIT_TABLE = "CREATE INDEX IF NOT EXISTS idx_stock_reddit_news_id ON reddits.stock_reddit_news(id);"
CREATE_REDDIT_TABLE_QUERY = "CREATE TABLE IF NOT EXISTS reddits.stock_reddit_news (id VARCHAR PRIMARY KEY, subreddit VARCHAR, url VARCHAR, title TEXT, score TEXT, num_comments TEXT, downvotes TEXT, ups TEXT, date_created_utc TEXT);"
CREATE_REDDIT_SCHEMA_QUERY="CREATE SCHEMA IF NOT EXISTS reddits;"
Loading