diff --git a/1_HomePage.py b/1_HomePage.py index 3dd0530..3487cd2 100644 --- a/1_HomePage.py +++ b/1_HomePage.py @@ -1,25 +1,23 @@ -from datetime import datetime -import time - -import pandas as pd -from BatchProcess.DataSource.YahooFinance.YahooFinances_Services import YahooFinance from BatchProcess.DataSource.ListSnP500.ListSnP500Collect import ListSAndP500 +from BatchProcess.BatchProcess import BatchProcessManager from multiprocessing.pool import ThreadPool +import plotly.graph_objects as go from dotenv import load_dotenv from pathlib import Path import streamlit as st +import pandas as pd +import time import os -from Database.PostGreSQLInteraction import StockDatabaseManager - -load_dotenv(override=True) pool = ThreadPool(processes=6) +load_dotenv(override=True) current_dir = Path(__file__).parent if "__file__" in locals() else Path.cwd() css_file = current_dir / os.getenv("CSS_DIR") defaut_start_date = "2014-01-01" -st.set_page_config(page_title="Home Page", page_icon=":house:") +st.set_page_config(page_title="Home Page", page_icon=":house:", + initial_sidebar_state="collapsed") st.sidebar.header("Quantitative Trading Project") st.title("Welcome to the Home Page") st.markdown( @@ -36,70 +34,123 @@ st.markdown("".format(f.read()), unsafe_allow_html=True) +# --- CACHE DATA --- @st.cache_data(ttl=1800) def retrieve_list_ticket(): - list_of_symbols__ = ListSAndP500().tickers_list + list_of_symbols__ = BatchProcessManager().get_stock_list_in_database() + if list_of_symbols__ is None or len(list_of_symbols__) < 497: + list_of_symbols__ = ListSAndP500().tickers_list return list_of_symbols__ -PROCESS_TIME = 90 # seconds -_list_of_symbols = retrieve_list_ticket() - - @st.cache_data(ttl=1800) -def retrieve_data_from_yahoo(list_of_symbols, date_from, date_to): - transformed_data = YahooFinance(list_of_symbols, date_from, date_to) - return transformed_data.process_data() +def batch_process(list_of_symbols__): + return BatchProcessManager().run_process(list_of_symbols__) 
@st.cache_data(ttl=1800) -def update_datebase_func(list_of_symbols=retrieve_list_ticket(), date_from=defaut_start_date, date_to=datetime.now().strftime('%Y-%m-%d')): - st.write("Database Updated") - # Retrieve Data from yahoo finance - async_result = pool.apply_async( - retrieve_data_from_yahoo, args=(list_of_symbols, date_from, date_to,)) - bar = st.progress(0) - per = PROCESS_TIME / 100 - for i in range(100): - time.sleep(per) - bar.progress(i + 1) - df = async_result.get() - return df +def batch_process_retrieve_data_by_stock(the_stock_in): + return BatchProcessManager().get_stock_data_by_ticker(the_stock_in) -@st.cache_data(ttl=1800) -def process_data_retrieve_from_database(df_in, list_of_symbols__): - total_data_dict = dict() - for i in range(len(list_of_symbols__)): - filtered_data = df_in[df_in['stock_id'] == list_of_symbols__[i]] - filtered_data = filtered_data.reset_index() - total_data_dict[list_of_symbols__[i]] = filtered_data - return total_data_dict - - -update_database = st.button("Update Database") -if update_database: - df_historical_yahoo = update_datebase_func( - list_of_symbols=_list_of_symbols) - total_data_dict_ = process_data_retrieve_from_database( - df_historical_yahoo, _list_of_symbols) - db_manager = StockDatabaseManager() - db_manager.create_schema_and_tables(_list_of_symbols) - for key, value in total_data_dict_.items(): - if isinstance(value, pd.DataFrame): - db_manager.insert_data(key, value) - all_data = db_manager.fetch_all_data() - for table, df in all_data.items(): - st.write(f"Data for table {table}:") - st.write(df.head(10)) - db_manager.close_connection() - st.write("Done") - # st.write(total_data_dict_) +PROCESS_TIME = 180 # seconds +_list_of_symbols = retrieve_list_ticket() +# --- MAIN PAGE --- +if "stock_data" not in st.session_state: + st.session_state.stock_data = None +st.markdown('---') +st.markdown("### I. 
Retrieve stock data from database if available") + +the_stock = st.selectbox( + "Select the stock you want to retrieve from database", _list_of_symbols) +btn_prepare = st.button("Retrieve stock data from database...") + +if btn_prepare: + st.session_state.stock_data = the_stock + # df = batch_process_retrieve_data_by_stock(the_stock) + # df = pd.DataFrame(df) + # if df is not None: + # st.write(df) + # st.write("Done") + # else: + # st.write("No data found for this stock, please update the database first.") + +st.markdown('---') +# --- TABS --- +st.markdown( + "### II. List of 500 S&P, Historical data, In Day Data, Top News, Reddit News") List500, Historical_data, IndayData_RealTime, news, reddit_news = st.tabs( ["List 500 S&P", "Historical data", "In Day Data", "Top News", "Reddit News"]) +# --- TABS LIST500 S&P CONTENT--- with List500: st.write("List of 500 S&P") st.write(_list_of_symbols) + +# --- TABS HISTORICAL DATA CONTENT--- +with Historical_data: + if st.session_state.stock_data is not None: + df = batch_process_retrieve_data_by_stock(st.session_state.stock_data) + df = pd.DataFrame(df) + if df is not None: + fig = go.Figure(data=[go.Candlestick(x=df['date'], + open=df['open'], + high=df['high'], + low=df['low'], + close=df['close'])]) + # Add a title + fig.update_layout( + title=f"{st.session_state.stock_data} Price Candlestick Chart", + # Center the title + title_x=0.3, + + # Customize the font and size of the title + title_font=dict(size=24, family="Arial"), + + # Set the background color of the plot + plot_bgcolor='white', + + # Customize the grid lines + xaxis=dict(showgrid=True, gridwidth=1, gridcolor='lightgray'), + yaxis=dict(showgrid=True, gridwidth=1, gridcolor='lightgray'), + ) + + # Add a range slider and customize it + fig.update_layout( + xaxis_rangeslider_visible=True, # Show the range slider + + # Customize the range slider's appearance + xaxis_rangeslider=dict( + thickness=0.1, # Set the thickness of the slider + bordercolor='black', # Set 
the border color + borderwidth=1, # Set the border width + ) + ) + + # Display the chart in Streamlit + st.plotly_chart(fig) + st.markdown( + f"#### Dataframe of {st.session_state.stock_data} Prices") + st.write(df) + else: + st.write( + "No data found for this stock, please update the database first.") + else: + st.write("Please select the stock to retrieve the data") + +st.markdown('---') +# --- Set Up/ Update all data in database--- +st.markdown("### III. Set Up data in database for the first time") +update_database = st.button("Update Database") +if update_database: + async_result = pool.apply_async( + batch_process, args=(_list_of_symbols,)) + bar = st.progress(0) + per = PROCESS_TIME / 100 + for i in range(100): + time.sleep(per) + bar.progress(i + 1) + df_dict = async_result.get() + st.write("Please check the data in the database") diff --git a/BatchProcess/BatchProcess.py b/BatchProcess/BatchProcess.py new file mode 100644 index 0000000..e41b5fe --- /dev/null +++ b/BatchProcess/BatchProcess.py @@ -0,0 +1,77 @@ +from Database.PostGreSQLInteraction import DatabaseManager, StockDatabaseManager, TicketDimDatabaseManager, RedditNewsDatabaseManager +from BatchProcess.DataSource.YahooFinance.YahooFinances_Services import YahooFinance +from BatchProcess.DataSource.ListSnP500.ListSnP500Collect import ListSAndP500 +from datetime import datetime +import pandas as pd + + +defaut_start_date = "2014-01-01" + +date_to = datetime.now().strftime('%Y-%m-%d') + + +class BatchProcessManager: + def __init__(self): + self.list_of_symbols = None + self.dict_ticket = dict() + + def run_process(self, list_of_symbols_): + self.list_of_symbols = list_of_symbols_ + + # Get data from Yahoo Finance + transformed_data = YahooFinance( + self.list_of_symbols, defaut_start_date, date_to) + df = transformed_data.process_data() + + # Create Database Manager + db_manager = DatabaseManager() + + # Drop all tables exist in the database + db_manager.delete_schema() + + # Create Stock table + 
db_manager.StockDatabaseManager.create_schema_and_tables( + self.list_of_symbols) + + # Create TicketDim table + db_manager.TicketDimDatabaseManager.create_table() + + # Create RedditNews table + db_manager.RedditNewsDatabaseManager.create_schema_and_tables() + + # Apply multiprocessing to insert data into the database (Testing later) + for i in range(len(self.list_of_symbols)): + filtered_data = df[df['stock_id'] == self.list_of_symbols[i]] + filtered_data = filtered_data.reset_index() + self.dict_ticket[self.list_of_symbols[i]] = filtered_data + + # Insert data into the database Stock table + for key, value in self.dict_ticket.items(): + if isinstance(value, pd.DataFrame): + db_manager.StockDatabaseManager.insert_data(key, value) + + # Insert data into the database TicketDim table + db_manager.TicketDimDatabaseManager.insert_data(self.list_of_symbols) + + db_manager.close_connection() + return self.dict_ticket + + def get_stock_data_by_ticker(self, ticker): + try: + db_manager = StockDatabaseManager() + data = db_manager.get_data_by_table(ticker) + db_manager.close_connection() + return data + except Exception as e: + print(e) + return None + + def get_stock_list_in_database(self): + try: + db_manager = TicketDimDatabaseManager() + data = db_manager.get_data() + db_manager.close_connection() + return data + except Exception as e: + print(e) + return None diff --git a/Database/.env b/Database/.env index c8c3e7c..8ed62df 100644 --- a/Database/.env +++ b/Database/.env @@ -3,4 +3,10 @@ DATABASE_PORT="5432" DATABASE_NAME="postgres" DATABASE_USER="postgres" DATABASE_PASSWORD="admin" -CREATE_SCHEMA_QUERY="CREATE SCHEMA IF NOT EXISTS tickets" \ No newline at end of file +CREATE_SCHEMA_QUERY="CREATE SCHEMA IF NOT EXISTS tickets;" + +INSERT_QUERY_REDDIT_TABLE="INSERT INTO reddits.stock_reddit_news (id, subreddit, url, title, score, num_comments, downvotes, ups, date_created_utc) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s) ON CONFLICT (id) DO NOTHING;" 
+POSTGRE_CONNECTION="dbname=postgres user=postgres host=localhost password=admin" +CONFIGURE_REDDIT_TABLE = "CREATE INDEX IF NOT EXISTS idx_stock_reddit_news_id ON reddits.stock_reddit_news(id);" +CREATE_REDDIT_TABLE_QUERY = "CREATE TABLE IF NOT EXISTS reddits.stock_reddit_news (id VARCHAR PRIMARY KEY, subreddit VARCHAR, url VARCHAR, title TEXT, score TEXT, num_comments TEXT, downvotes TEXT, ups TEXT, date_created_utc TEXT);" +CREATE_REDDIT_SCHEMA_QUERY="CREATE SCHEMA IF NOT EXISTS reddits;" diff --git a/Database/PostGreSQLInteraction.py b/Database/PostGreSQLInteraction.py index f161e56..26c5c75 100644 --- a/Database/PostGreSQLInteraction.py +++ b/Database/PostGreSQLInteraction.py @@ -1,4 +1,6 @@ +import plotly.graph_objects as go from dotenv import load_dotenv +from datetime import datetime from psycopg2 import sql import pandas as pd import psycopg2 @@ -14,6 +16,11 @@ create_schema_query = os.getenv("CREATE_SCHEMA_QUERY") +CREATE_REDDIT_TABLE_QUERY = os.getenv("CREATE_REDDIT_TABLE_QUERY") +CONFIGURE_REDDIT_TABLE = os.getenv("CONFIGURE_REDDIT_TABLE") +INSERT_QUERY_REDDIT_TABLE = os.getenv("INSERT_QUERY_REDDIT_TABLE") +CREATE_REDDIT_SCHEMA_QUERY = os.getenv("CREATE_REDDIT_SCHEMA_QUERY") + class StockDatabaseManager: def __init__(self): @@ -31,99 +38,379 @@ def create_connection(self): """ Create a connection to the database """ - conn = psycopg2.connect( - dbname=self.dbname, - user=self.user, - password=self.password, - host=self.host, - port=self.port - ) - return conn + try: + conn = psycopg2.connect( + dbname=self.dbname, + user=self.user, + password=self.password, + host=self.host, + port=self.port + ) + return conn + except Exception as e: + print(e) + return None def create_schema_and_tables(self, tickers): """ Create the schema and tables for the given tickers list """ - cursor = self.conn.cursor() + try: + cursor = self.conn.cursor() - # Create schema if it doesn't exist - cursor.execute(create_schema_query) + # Create schema if it doesn't exist + 
cursor.execute(create_schema_query) - # Loop through each ticker and create the corresponding table - for ticker in tickers: - cursor.execute( - "CREATE TABLE IF NOT EXISTS tickets." + ticker + " (" - "stock_id VARCHAR(10)," - "date VARCHAR(10)," - "open VARCHAR(50)," - "high VARCHAR(50)," - "low VARCHAR(50)," - "close VARCHAR(50)," - "volume VARCHAR(50)," - "PRIMARY KEY (stock_id, date))" - ) + # Loop through each ticker and create the corresponding table + for ticker in tickers: + cursor.execute( + "CREATE TABLE IF NOT EXISTS tickets." + ticker + " (" + "stock_id VARCHAR(10)," + "date VARCHAR(10)," + "open VARCHAR(50)," + "high VARCHAR(50)," + "low VARCHAR(50)," + "close VARCHAR(50)," + "volume VARCHAR(50)," + "PRIMARY KEY (stock_id, date))" + ) - # Create index on date for faster queries - cursor.execute( - "CREATE INDEX IF NOT EXISTS " + ticker + - "_date_idx ON tickets." + ticker + " (date)" - ) + # Create index on date for faster queries + cursor.execute( + "CREATE INDEX IF NOT EXISTS " + ticker + + "_date_idx ON tickets." + ticker + " (date)" + ) - self.conn.commit() - cursor.close() + self.conn.commit() + cursor.close() + except Exception as e: + print(e) def insert_data(self, ticker, data): """ Insert data into the database """ - cursor = self.conn.cursor() - - # Ensure all data is treated as string - data = data.astype(str) - - insert_query = ( - "INSERT INTO tickets." 
+ ticker + - " (stock_id, date, open, high, low, close, volume)" - " VALUES (%s, %s, %s, %s, %s, %s, %s)" - " ON CONFLICT (stock_id, date) DO UPDATE SET" - " open = EXCLUDED.open," - " high = EXCLUDED.high," - " low = EXCLUDED.low," - " close = EXCLUDED.close," - " volume = EXCLUDED.volume" - ) - - for index, row in data.iterrows(): - cursor.execute(insert_query, (row['stock_id'], row['date'], - row['open'], row['high'], row['low'], row['close'], row['volume'])) - self.conn.commit() - cursor.close() + try: + cursor = self.conn.cursor() + + # Ensure all data is treated as string + data = data.astype(str) + + insert_query = ( + "INSERT INTO tickets." + ticker + + " (stock_id, date, open, high, low, close, volume)" + " VALUES (%s, %s, %s, %s, %s, %s, %s)" + " ON CONFLICT (stock_id, date) DO UPDATE SET" + " open = EXCLUDED.open," + " high = EXCLUDED.high," + " low = EXCLUDED.low," + " close = EXCLUDED.close," + " volume = EXCLUDED.volume" + ) + + for index, row in data.iterrows(): + cursor.execute(insert_query, (row['stock_id'], row['date'], + row['open'], row['high'], row['low'], row['close'], row['volume'])) + self.conn.commit() + cursor.close() + except Exception as e: + print(e) + + def get_data_by_table(self, table_name): + """ + Get data from the given ticket table + """ + try: + query = f"SELECT * FROM tickets.{table_name}" + data = pd.read_sql(query, self.conn) + return data + except Exception as e: + print(e) + return None def get_tables(self, schema='tickets'): """ Get all tables in the given schema """ - cursor = self.conn.cursor() - cursor.execute( - "SELECT table_name FROM information_schema.tables " - "WHERE table_schema = %s", (schema,) - ) - tables = cursor.fetchall() - cursor.close() - return [table[0] for table in tables] + try: + cursor = self.conn.cursor() + cursor.execute( + "SELECT table_name FROM information_schema.tables " + "WHERE table_schema = %s", (schema,) + ) + tables = cursor.fetchall() + cursor.close() + return [table[0] for table in 
tables] + except Exception as e: + print(e) + return None def fetch_all_data(self, schema='tickets'): """ Fetch all data from the given schema """ - tables = self.get_tables(schema) - all_data = {} - for table in tables: - query = "SELECT * FROM " + schema + "." + table - df = pd.read_sql(query, self.conn) - all_data[table] = df - return all_data + try: + tables = self.get_tables(schema) + all_data = {} + for table in tables: + query = "SELECT * FROM " + schema + "." + table + df = pd.read_sql(query, self.conn) + all_data[table] = df + return all_data + except Exception as e: + print(e) + return None + + def close_connection(self): + if self.conn: + try: + self.conn.close() + except Exception as e: + print(e) + + +class TicketDimDatabaseManager: + def __init__(self): + """ + Initialize the database connection + """ + self.dbname = postgres_dbname + self.user = postgres_user + self.password = postgres_pass + self.host = postgres_server + self.port = postgres_port + self.conn = self.create_connection() + + def create_connection(self): + """ + Create a connection to the database + """ + try: + conn = psycopg2.connect( + dbname=self.dbname, + user=self.user, + password=self.password, + host=self.host, + port=self.port + ) + return conn + except Exception as e: + print(e) + return None + + def create_table(self): + """ + Create the ticket_dim table + """ + try: + cursor = self.conn.cursor() + + # Create schema if it doesn't exist + cursor.execute(create_schema_query) + + # Create the ticket_dim table + cursor.execute( + "CREATE TABLE IF NOT EXISTS tickets.ticket_dim (" + "symbol VARCHAR(10) PRIMARY KEY," + "company_name VARCHAR(255) NULL," + "established DATE NULL," + "sector VARCHAR(100) NULL," + "industry VARCHAR(100) NULL," + "exchange VARCHAR(50) NULL" + ")" + ) + + self.conn.commit() + cursor.close() + except Exception as e: + print(e) + + def insert_data(self, data): + """ + Insert data into the ticket_dim table + """ + try: + if isinstance(data, list): + data = 
pd.DataFrame(data, columns=['symbol']) + data['company_name'] = None + data['established'] = None + data['sector'] = None + data['industry'] = None + data['exchange'] = None + + cursor = self.conn.cursor() + + # Ensure all data is treated as string + # df_data = pd.DataFrame(data_in) + # data = df_data.astype(str) + + insert_query = ( + "INSERT INTO tickets.ticket_dim (symbol, company_name, established, sector, industry, exchange)" + " VALUES (%s, %s, %s, %s, %s, %s)" + " ON CONFLICT (symbol) DO UPDATE SET" + " company_name = EXCLUDED.company_name," + " established = EXCLUDED.established," + " sector = EXCLUDED.sector," + " industry = EXCLUDED.industry," + " exchange = EXCLUDED.exchange" + ) + + for index, row in data.iterrows(): + cursor.execute(insert_query, (row['symbol'], row.get('company_name'), + row.get('established'), row.get('sector'), row.get('industry'), row.get('exchange'))) + self.conn.commit() + cursor.close() + except Exception as e: + print(e) + + def get_data(self): + """ + Get all data from the ticket_dim table + """ + try: + cursor = self.conn.cursor() + cursor.execute( + "SELECT * FROM tickets.ticket_dim" + ) + data = cursor.fetchall() + cursor.close() + # Convert data to list + data = [list(row)[0] for row in data] + return data + except Exception as e: + print(e) + return None + + def search_ticker(self, symbol): + """ + Search for a specific ticker by its symbol + """ + try: + cursor = self.conn.cursor() + cursor.execute( + "SELECT * FROM tickets.ticket_dim WHERE symbol = %s", (symbol,) + ) + data = cursor.fetchone() + cursor.close() + return data + except Exception as e: + print(e) + return None def close_connection(self): if self.conn: - self.conn.close() + try: + self.conn.close() + except Exception as e: + print(e) + + +class RedditNewsDatabaseManager: + def __init__(self): + """ + Initialize the database connection + """ + self.dbname = postgres_dbname + self.user = postgres_user + self.password = postgres_pass + self.host = 
postgres_server + self.port = postgres_port + self.conn = self.create_connection() + + def create_connection(self): + """ + Create a connection to the database + """ + try: + conn = psycopg2.connect( + dbname=self.dbname, + user=self.user, + password=self.password, + host=self.host, + port=self.port + ) + return conn + except Exception as e: + print(e) + return None + + def create_schema_and_tables(self): + """ + Create the stock_reddit_news table + """ + try: + cursor = self.conn.cursor() + cursor.execute(CREATE_REDDIT_SCHEMA_QUERY) + cursor.execute(CREATE_REDDIT_TABLE_QUERY) + cursor.execute(CONFIGURE_REDDIT_TABLE) + self.conn.commit() + cursor.close() + except Exception as e: + print(e) + + def insert_data(self, data): + """ + Insert data into the stock_reddit_news table + """ + try: + cursor = self.conn.cursor() + insert_query = INSERT_QUERY_REDDIT_TABLE + + for index, row in data.iterrows(): + cursor.execute(insert_query, (row["id"], row["subreddit"], row["url"], row["title"].replace('\'', ""), + row["score"], row["num_comments"], row["downvotes"], row["ups"], row["date_created_utc"])) + self.conn.commit() + cursor.close() + except Exception as e: + print(e) + + def get_news_by_ticker(self, ticker): + """ + Get news from the stock_reddit_news table corresponding to a specific ticker + """ + try: + cursor = self.conn.cursor() + query = """ + SELECT * FROM reddits.stock_reddit_news + WHERE LOWER(title) LIKE %s + """ + cursor.execute(query, ('%' + ticker.lower() + '%',)) + data = cursor.fetchall() + cursor.close() + return data + except Exception as e: + print(e) + return None + + def close_connection(self): + if self.conn: + try: + self.conn.close() + except Exception as e: + print(e) + + +class DatabaseManager: + def __init__(self): + """ + Initialize the database connection + """ + self.StockDatabaseManager = StockDatabaseManager() + self.TicketDimDatabaseManager = TicketDimDatabaseManager() + self.RedditNewsDatabaseManager = RedditNewsDatabaseManager() + + 
def delete_schema(self): + try: + # Fix later + cursor = self.StockDatabaseManager.conn.cursor() + cursor.execute("DROP SCHEMA IF EXISTS tickets CASCADE;") + cursor.execute("DROP SCHEMA IF EXISTS reddits CASCADE;") + self.StockDatabaseManager.conn.commit() + cursor.close() + except Exception as e: + print(e) + + def close_connection(self): + self.StockDatabaseManager.close_connection() + self.TicketDimDatabaseManager.close_connection() + self.RedditNewsDatabaseManager.close_connection() diff --git a/Database/PostGreSQLInteraction_Alchemy.py b/Database/PostGreSQLInteraction_Alchemy.py new file mode 100644 index 0000000..bfc96d8 --- /dev/null +++ b/Database/PostGreSQLInteraction_Alchemy.py @@ -0,0 +1,102 @@ +import pandas as pd +from sqlalchemy import create_engine, MetaData, Table, Column, String +from sqlalchemy.dialects.postgresql import insert +from dotenv import load_dotenv +import os + + +class StockDatabaseManager: + def __init__(self): + load_dotenv(override=True) + self.dbname = os.getenv("DATABASE_NAME") + self.user = os.getenv("DATABASE_USER") + self.password = os.getenv("DATABASE_PASSWORD") + self.host = os.getenv("DATABASE_SERVER") + self.port = os.getenv("DATABASE_PORT") + self.engine = self.create_engine() + self.metadata = MetaData(schema='tickets') + self.create_schema_query = os.getenv("CREATE_SCHEMA_QUERY") + + def create_engine(self): + db_url = f"postgresql://{self.user}:{self.password}@{self.host}:{self.port}/{self.dbname}" + engine = create_engine(db_url) + return engine + + def create_schema_and_tables(self, tickers): + try: + with self.engine.connect() as connection: + # Create schema if it doesn't exist + connection.execute(self.create_schema_query) + + for ticker in tickers: + table = Table(ticker, self.metadata, + Column('stock_id', String( + 10), primary_key=True), + Column('date', String(10), primary_key=True), + Column('open', String(50)), + Column('high', String(50)), + Column('low', String(50)), + Column('close', String(50)), + 
Column('volume', String(50)) + ) + table.create(self.engine, checkfirst=True) + return True + except Exception as e: + print(f"Error creating schema and tables: {e}") + return False + + def insert_data(self, ticker, data): + table = Table(ticker, self.metadata, autoload_with=self.engine) + + # Ensure all data is treated as string + data = data.astype(str) + try: + with self.engine.connect() as connection: + for index, row in data.iterrows(): + stmt = insert(table).values( + stock_id=row['stock_id'], + date=row['date'], + open=row['open'], + high=row['high'], + low=row['low'], + close=row['close'], + volume=row['volume'] + ).on_conflict_do_update( + index_elements=['stock_id', 'date'], + set_=dict( + open=row['open'], + high=row['high'], + low=row['low'], + close=row['close'], + volume=row['volume'] + ) + ) + connection.execute(stmt) + return True + except Exception as e: + print(f"Error inserting data: {e}") + return False + + def get_tables(self): + try: + inspector = self.engine.inspect(self.engine) + tables = inspector.get_table_names(schema='tickets') + return tables + except Exception as e: + print(f"Error getting tables: {e}") + return [] + + def fetch_all_data(self): + try: + tables = self.get_tables() + all_data = {} + for table_name in tables: + table = Table(table_name, self.metadata, + autoload_with=self.engine) + query = table.select() + df = pd.read_sql(query, self.engine) + all_data[table_name] = df + return all_data + except Exception as e: + print(f"Error fetching data: {e}") + return {} diff --git a/README.md b/README.md index ffd0c30..c6d3205 100644 --- a/README.md +++ b/README.md @@ -58,6 +58,10 @@ Install requirements.txt: Automatically installed dependencies that needed for t ```bash pip install -r requirements.txt ``` + or + ```bash + pip install streamlit pyspark yfinance psycopg2 python-dotenv + ```

diff --git a/pages/2_Admin.py b/pages/2_Admin.py deleted file mode 100644 index e69de29..0000000 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..4e36ba8 Binary files /dev/null and b/requirements.txt differ