In [6]:
from pyspark.sql.types import StructType, StructField, StringType
from pyspark.sql import SparkSession
from dotenv import load_dotenv
from datetime import datetime
import yfinance as yf
import pandas as pd
import os

# Read the local .env file before any setting is looked up below.
load_dotenv(override=True)

# PostgreSQL-related settings, all taken from environment variables
# (each is None when the variable is not set).
postgres_v = os.environ.get("POSTGRES_VERSION")
postgres_url = os.environ.get("POSTGRES_URL")
postgres_user = os.environ.get("POSTGRES_USER")
postgres_pass = os.environ.get("POSTGRES_PASSWORD")
postgres_table = os.environ.get("POSTGRES_TABLE")
format_file = os.environ.get("FORMAT_FILE")
_mode = os.environ.get("MODE")

# Alias of postgres_v; presumably consumed by a later Spark/JDBC cell -- TODO confirm.
config_ = postgres_v

# Output column names COLUMN_1 .. COLUMN_8, in positional order.
(
    column_1_name,
    column_2_name,
    column_3_name,
    column_4_name,
    column_5_name,
    column_6_name,
    column_7_name,
    column_8_name,
) = (os.environ.get(f"COLUMN_{position}") for position in range(1, 9))

# Symbols that ListSAndP500 leaves out of tickers_list.
list_remove = ['GEV','SOLV','VLTO','BF.B','BRK.B']

class ListSAndP500:
    def __init__(self):
        """
        Read the S&P 500 constituents table from Wikipedia and expose its symbols

        Attributes:
        tickers_string (list): Every symbol from the table with '.' replaced by '-'
        tickers_list (list): The symbols exactly as listed in the table, without
            the entries named in the module-level list_remove

        """
        # read_html returns every table on the page; the constituents table is the first one.
        constituents = pd.read_html('https://en.wikipedia.org/wiki/List_of_S%26P_500_companies')[0]
        symbols = constituents['Symbol'].tolist()

        self.tickers_string = [symbol.replace('.','-') for symbol in symbols]
        # Keep a symbol only when it is not excluded and is truthy (drops empty strings).
        self.tickers_list = [
            symbol for symbol in symbols
            if symbol not in list_remove and symbol
        ]
        
# Building ListSAndP500 reads the Wikipedia table; only its filtered symbol list
# (tickers_list, i.e. without the list_remove entries) is kept for the download.
list_of_symbols__ = ListSAndP500().tickers_list

class YahooFinance:
    """
    Download daily price history for a list of symbols with yfinance and
    reshape it to one row per (date, stock).

    All work happens at construction time; the reshaped pandas DataFrame is
    available afterwards as the `results` attribute.
    """

    # yfinance price fields, in the order they map onto output columns 3 to 8.
    _PRICE_FIELDS = ('Open', 'High', 'Low', 'Close', 'Volume', 'Adj Close')

    def __init__(self, list_of_symbols, start, end):
        """
        Download and transform the price history immediately.

        Args:
        list_of_symbols (list): Stock symbols to download
        start (str): Start date passed to yfinance
        end (str): End date passed to yfinance
        """
        # Output column names, in positional order (module-level COLUMN_1 .. COLUMN_8 settings).
        self.column_names = [
            column_1_name,
            column_2_name,
            column_3_name,
            column_4_name,
            column_5_name,
            column_6_name,
            column_7_name,
            column_8_name,
        ]

        # Every field is declared as a nullable string.
        self.schema = StructType([
            StructField(name, StringType(), True) for name in self.column_names
        ])

        self.symbols = list_of_symbols
        self.interval = '1d'
        self.start = start
        self.end = end
        self.results = self.process_data()

    def process_data(self):
        """
        Process the historical stock data for the stock symbols

        Returns:
        DataFrame: The output of transform_data applied to the downloaded data
        """
        data = self.get_data()
        return self.transform_data(data)

    def get_data(self):
        """
        Get historical stock data from Yahoo Finance API using yfinance library

        Returns:
        DataFrame: A DataFrame containing historical stock data, or None when
        the download raised an exception (the error is printed)
        """
        try:
            data = yf.download(
                self.symbols,
                start=self.start,
                end=self.end,
                interval=self.interval,
                # transform_data reads the 'Adj Close' field, which yfinance only
                # returns when prices are not auto-adjusted; request it explicitly
                # instead of relying on the library default.
                auto_adjust=False,
                ignore_tz=True,
                threads=5,
                timeout=60,
                progress=True
            )
            return data
        except Exception as e:
            print(f"Error downloading data: {e}")
            return None

    def transform_data(self, df):
        """
        Transform the historical stock data into a format that can be stored in a database FactPrices table

        Args:
        df (DataFrame): Historical stock data as returned by get_data: indexed by
            date, with (field, symbol) MultiIndex columns. May be None or empty.

        Returns:
        DataFrame: One row per (date, stock), ordered by date and then by the
        order of self.symbols, with the columns named in self.column_names:
        1. the stock symbol
        2. the date of the stock data
        3. the opening price
        4. the highest price
        5. the lowest price
        6. the closing price
        7. the volume
        8. the adjusted closing price
        An empty DataFrame with those columns is returned when there is no data.
        Symbols lacking any of the six price fields are reported once and skipped.

        """
        # A failed download yields None; return an empty, correctly-shaped frame
        # instead of crashing on it.
        if df is None or df.empty:
            return pd.DataFrame(columns=self.column_names)

        # Decide once per symbol whether all of its fields are present.
        available = []
        missing = []
        for stock in self.symbols:
            if all((field, stock) in df.columns for field in self._PRICE_FIELDS):
                available.append(stock)
            else:
                missing.append(stock)

        if missing:
            print(f"No complete price data for stocks: {missing}")
        if not available:
            return pd.DataFrame(columns=self.column_names)

        n_dates = len(df.index)
        n_stocks = len(available)

        # Date-major layout: each date is repeated once per stock, and the stock
        # list is cycled once per date.
        long_format = {
            self.column_names[0]: available * n_dates,
            self.column_names[1]: df.index.repeat(n_stocks),
        }
        for name, field in zip(self.column_names[2:], self._PRICE_FIELDS):
            # Selecting [available] fixes the column order; a row-major ravel then
            # walks date by date, stock by stock, matching the two columns above.
            long_format[name] = df[field][available].to_numpy().ravel()

        return pd.DataFrame(long_format)

In [7]:
# Download from 2015-01-01 up to today's date and keep only the reshaped result.
download_end = datetime.now().strftime('%Y-%m-%d')
price_history = YahooFinance(list_of_symbols__, '2015-01-01', download_end)
transformed_data = price_history.results
print(transformed_data.head())

[*********************100%%**********************]  498 of 498 completed


  stock_id       date        open        high         low       close  \
0      MMM 2015-01-02  137.717392  138.026749  136.061874  137.173920   
1      AOS 2015-01-02   28.309999   28.415001   27.775000   28.004999   
2      ABT 2015-01-02   45.250000   45.450001   44.639999   44.900002   
3     ABBV 2015-01-02   65.440002   66.400002   65.440002   65.889999   
4      ACN 2015-01-02   89.669998   90.089996   88.430000   88.839996   

      volume  adjusted_close  
0  2531214.0       92.933594  
1  1540200.0       24.154337  
2  3216600.0       37.481434  
3  5086100.0       44.314426  
4  2021300.0       75.950943  


In [8]:
# Number of rows in the transformed table.
transformed_data.shape[0]

1185240

## Transform Data: split by ticker and save to CSV

In [18]:
# Write the full table to Summary.csv, then one CSV per ticker; the per-ticker
# frames are also kept in total_data_dict, keyed by symbol.
total_data_dict = dict()
fileroute="./dataset/"
fileroute_ticket="./dataset/tickets/"

# to_csv does not create missing directories. fileroute_ticket is nested under
# fileroute, so creating it covers both output locations.
os.makedirs(fileroute_ticket, exist_ok=True)
transformed_data.to_csv(fileroute + "Summary.csv", index=False, encoding='utf-8')

# Split the table once instead of scanning every row again for each symbol.
# The symbol column is named by column_1_name, the same setting used to build it.
rows_by_symbol = dict(tuple(transformed_data.groupby(column_1_name, sort=False)))
no_rows = transformed_data.iloc[0:0]

for symbol in list_of_symbols__:
    # A symbol without any rows still gets an (empty) frame and CSV.
    # reset_index() keeps the original row labels as an 'index' column.
    filtered_data = rows_by_symbol.get(symbol, no_rows).reset_index()
    total_data_dict[symbol] = filtered_data
    filtered_data.to_csv(fileroute_ticket + symbol + ".csv", index=False, encoding='utf-8')

## Checking for missing dates

In [13]:
# Distinct dates present in the transformed table. The date column is named by
# column_2_name, the same setting used when the table was built.
unique_dates = pd.DatetimeIndex(transformed_data[column_2_name].unique())

In [14]:
# Every calendar day from 2015-01-01 to today. The range is daily, so weekend
# days also count as "missing" when they are absent from the data.
range_end = datetime.now().strftime('%Y-%m-%d')
complete_dates = pd.date_range(start='2015-01-01', end=range_end)

# Calendar days that have no row in the downloaded data.
missing_dates = complete_dates.difference(unique_dates)

# Single-column frame of the missing days, for display.
missing_dates_df = missing_dates.to_frame(index=False, name='missing_dates')

print(missing_dates_df.shape[0])
print(missing_dates_df.head(20))
print(complete_dates.size)

1077
   missing_dates
0     2015-01-01
1     2015-01-03
2     2015-01-04
3     2015-01-10
4     2015-01-11
5     2015-01-17
6     2015-01-18
7     2015-01-19
8     2015-01-24
9     2015-01-25
10    2015-01-31
11    2015-02-01
12    2015-02-07
13    2015-02-08
14    2015-02-14
15    2015-02-15
16    2015-02-16
17    2015-02-21
18    2015-02-22
19    2015-02-28
3457
