<a href="https://colab.research.google.com/github/adidror005/HomeAssignment/blob/main/GetBatchDollarBarsNumba.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import pandas as pd
import numpy as np
from numba import njit
from typing import List
from dataclasses import dataclass
import glob
BARS_PER_DAY = 50

# ================== Data Class for Dollar Bars ==================
@dataclass
class DollarBar:
    ts_start: pd.Timestamp
    ts_end: pd.Timestamp
    open: float
    high: float
    low: float
    close: float
    volume: int
    dollar_amount: float

# ================== Dollar Bar Creation Logic ==================
@njit
def _create_dollar_bars(price, size, day_int, daily_dollar_amounts, period_len=28):
    dollar_threshold = np.median(np.array(daily_dollar_amounts[-period_len:])) / BARS_PER_DAY
    n = len(price)
    max_bars = n
    bar_indices = np.empty((max_bars, 2), dtype=np.int64)
    opens = np.empty(max_bars)
    highs = np.empty(max_bars)
    lows = np.empty(max_bars)
    closes = np.empty(max_bars)
    volumes = np.empty(max_bars, dtype=np.int64)
    dollar_amts = np.empty(max_bars)

    count = 0
    dollar_sum = 0.0
    volume = 0
    start_idx = 0
    open_price = high = low = 0.0
    prev_day = day_int[0]
    new_daily_dollar_amount = 0.0

    for i in range(n):
        if day_int[i] != prev_day:
            if i > start_idx:
                bar_indices[count] = start_idx, i - 1
                opens[count] = open_price
                highs[count] = high
                lows[count] = low
                closes[count] = price[i - 1]
                volumes[count] = volume
                dollar_amts[count] = dollar_sum
                count += 1

                daily_dollar_amounts.append(new_daily_dollar_amount)
                dollar_threshold = np.median(np.array(daily_dollar_amounts[-period_len:])) / BARS_PER_DAY
                new_daily_dollar_amount = 0.0

            dollar_sum = 0.0
            volume = 0
            start_idx = i
            open_price = price[i]
            high = price[i]
            low = price[i]

        high = max(high, price[i])
        low = min(low, price[i])
        volume += size[i]
        trade_value = price[i] * size[i]
        dollar_sum += trade_value
        new_daily_dollar_amount += trade_value

        if dollar_sum >= dollar_threshold:
            bar_indices[count] = start_idx, i
            opens[count] = open_price
            highs[count] = high
            lows[count] = low
            closes[count] = price[i]
            volumes[count] = volume
            dollar_amts[count] = dollar_sum
            count += 1
            dollar_sum = 0.0
            volume = 0


        prev_day = day_int[i]

    # Handle the last bar if not closed
    if dollar_sum > 0 and start_idx < n:
        bar_indices[count] = start_idx, n - 1
        opens[count] = open_price
        highs[count] = high
        lows[count] = low
        closes[count] = price[-1]
        volumes[count] = volume
        dollar_amts[count] = dollar_sum
        daily_dollar_amounts.append(new_daily_dollar_amount)
        count += 1

    return bar_indices[:count], opens[:count], highs[:count], lows[:count], closes[:count], volumes[:count], dollar_amts[:count]

# ================== Create Dollar Bars Fast ==================
def create_dollar_bars_fast(df: pd.DataFrame, daily_dollar_amounts: List[float], period_len=28) -> List[DollarBar]:
    price = df['price'].values
    size = df['size'].values
    day_int = df['day_int'].values
    ts = df.index.values

    bar_indices, opens, highs, lows, closes, volumes, dollar_amts = _create_dollar_bars(price, size, day_int, daily_dollar_amounts, period_len)

    dollar_bars = [
        DollarBar(pd.Timestamp(ts[start]), pd.Timestamp(ts[end]), opens[i], highs[i], lows[i], closes[i], volumes[i], dollar_amts[i])
        for i, (start, end) in enumerate(bar_indices)
    ]
    return dollar_bars


# ================== Main Processing Loop ==================
symbol = 'AMD'
files = sorted(glob.glob(f'/content/drive/MyDrive/stock_trades_daily/{symbol}/*.parquet'))
span = 200
alpha = 2 / (span + 1)
ewms = np.empty(0, dtype=np.float64)
dollar_bars = []
daily_dollar_amounts = []
for i, file in enumerate(files):
    print(f"Processing File: {file}")
    df = pd.read_parquet(file)

    # === Reintroduced Your Filtering Logic ===
    df = df.droplevel(0)
    df['conditions'] = df.conditions.apply(lambda l: "".join(l))
    df = df[df.conditions.isin(['@', '@F', '@I', '@FI'])]
    df.index = df.index.tz_convert("US/Eastern")
    df = df.between_time('9:30', '16:00')

    # Add necessary columns for processing
    df['day'] = df.index.date
    df['day_int'] = df['day'].astype('datetime64[s]').astype('int64')
    df['dollar_amount'] = df['price'] * df['size']

    # Initialize the first 28-day median with the first file's daily sum
    if i == 0:
        median_dollar_amount = np.median(df.groupby('day')['dollar_amount'].sum())
        daily_dollar_amounts = [median_dollar_amount] * 28

    # Create dollar bars
    new_bars = create_dollar_bars_fast(df, daily_dollar_amounts)
    dollar_bars.extend(new_bars)


Processing File: /content/drive/MyDrive/stock_trades_daily/AMD/stock_trades_2017-01-02_2017-01-30.parquet
Processing File: /content/drive/MyDrive/stock_trades_daily/AMD/stock_trades_2017-01-30_2017-02-27.parquet
Processing File: /content/drive/MyDrive/stock_trades_daily/AMD/stock_trades_2017-02-27_2017-03-27.parquet
Processing File: /content/drive/MyDrive/stock_trades_daily/AMD/stock_trades_2017-03-27_2017-04-24.parquet
Processing File: /content/drive/MyDrive/stock_trades_daily/AMD/stock_trades_2017-04-24_2017-05-22.parquet
Processing File: /content/drive/MyDrive/stock_trades_daily/AMD/stock_trades_2017-05-22_2017-06-19.parquet
Processing File: /content/drive/MyDrive/stock_trades_daily/AMD/stock_trades_2017-06-19_2017-07-17.parquet
Processing File: /content/drive/MyDrive/stock_trades_daily/AMD/stock_trades_2017-07-17_2017-08-14.parquet
Processing File: /content/drive/MyDrive/stock_trades_daily/AMD/stock_trades_2017-08-14_2017-09-11.parquet
Processing File: /content/drive/MyDrive/stock_

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df['day'] = df.index.date


Processing File: /content/drive/MyDrive/stock_trades_daily/AMD/stock_trades_2021-06-14_2021-07-12.parquet
Processing File: /content/drive/MyDrive/stock_trades_daily/AMD/stock_trades_2021-07-12_2021-08-09.parquet
Processing File: /content/drive/MyDrive/stock_trades_daily/AMD/stock_trades_2021-08-09_2021-09-06.parquet
Processing File: /content/drive/MyDrive/stock_trades_daily/AMD/stock_trades_2021-09-06_2021-10-04.parquet
Processing File: /content/drive/MyDrive/stock_trades_daily/AMD/stock_trades_2021-10-04_2021-11-01.parquet
Processing File: /content/drive/MyDrive/stock_trades_daily/AMD/stock_trades_2021-11-01_2021-11-29.parquet
Processing File: /content/drive/MyDrive/stock_trades_daily/AMD/stock_trades_2021-11-29_2021-12-27.parquet
Processing File: /content/drive/MyDrive/stock_trades_daily/AMD/stock_trades_2021-12-27_2022-01-24.parquet


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df['day'] = df.index.date


Processing File: /content/drive/MyDrive/stock_trades_daily/AMD/stock_trades_2022-01-24_2022-02-21.parquet
Processing File: /content/drive/MyDrive/stock_trades_daily/AMD/stock_trades_2022-02-21_2022-03-21.parquet
Processing File: /content/drive/MyDrive/stock_trades_daily/AMD/stock_trades_2022-03-21_2022-04-18.parquet
Processing File: /content/drive/MyDrive/stock_trades_daily/AMD/stock_trades_2022-04-18_2022-05-16.parquet
Processing File: /content/drive/MyDrive/stock_trades_daily/AMD/stock_trades_2022-05-16_2022-06-13.parquet
Processing File: /content/drive/MyDrive/stock_trades_daily/AMD/stock_trades_2022-06-13_2022-07-11.parquet
Processing File: /content/drive/MyDrive/stock_trades_daily/AMD/stock_trades_2022-07-11_2022-08-08.parquet
Processing File: /content/drive/MyDrive/stock_trades_daily/AMD/stock_trades_2022-08-08_2022-09-05.parquet
Processing File: /content/drive/MyDrive/stock_trades_daily/AMD/stock_trades_2022-09-05_2022-10-03.parquet
Processing File: /content/drive/MyDrive/stock_

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df['day'] = df.index.date


Processing File: /content/drive/MyDrive/stock_trades_daily/AMD/stock_trades_2024-12-23_2025-01-20.parquet
Processing File: /content/drive/MyDrive/stock_trades_daily/AMD/stock_trades_2025-01-20_2025-02-17.parquet
Processing File: /content/drive/MyDrive/stock_trades_daily/AMD/stock_trades_2025-02-17_2025-03-17.parquet
Processing File: /content/drive/MyDrive/stock_trades_daily/AMD/stock_trades_2025-03-17_2025-04-14.parquet
Processing File: /content/drive/MyDrive/stock_trades_daily/AMD/stock_trades_2025-04-14_2025-05-09_partial.parquet


### Save Dollar Bars and Daily Dollar Bar Amounts to Parquet

In [3]:
df_dollar_bars=pd.DataFrame(dollar_bars)
df_daily_dollar_amounts=pd.DataFrame(daily_dollar_amounts)

### Save Dollar Bars to Parquet

In [4]:
symbol = "AMD"
df_dollar_bars.to_parquet(f"/content/drive/MyDrive/dollar_bars_{symbol}.parquet")
df_daily_dollar_amounts.to_parquet(f"/content/drive/MyDrive/daily_dollar_amounts_{symbol}.parquet")

### Save to MYSQL PythonAnywhere!

In [14]:
!pip install sshtunnel pandas pymysql sqlalchemy

import pandas as pd
import pymysql
from sqlalchemy import create_engine
from sshtunnel import SSHTunnelForwarder
from getpass import getpass
import logging

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

def create_ssh_tunnel(ssh_host, ssh_username, ssh_password, mysql_host, mysql_port):
    """Create and return an SSH tunnel."""
    try:
        tunnel = SSHTunnelForwarder(
            (ssh_host, 22),
            ssh_username=ssh_username,
            ssh_password=ssh_password,
            remote_bind_address=(mysql_host, mysql_port),
            local_bind_address=('127.0.0.1', 3307)
        )
        tunnel.start()
        logger.info(f"SSH tunnel established to {ssh_host}. MySQL available at 127.0.0.1:{tunnel.local_bind_port}")
        return tunnel
    except Exception as e:
        logger.error(f"Failed to create SSH tunnel: {e}")
        raise

def create_sql_engine(mysql_username, mysql_password, db_name, tunnel):
    """Create and return SQLAlchemy engine."""
    try:
        connection_string = f"mysql+pymysql://{mysql_username}:{mysql_password}@127.0.0.1:{tunnel.local_bind_port}/{db_name}"
        engine = create_engine(connection_string, pool_pre_ping=True)
        logger.info("SQLAlchemy engine created successfully")
        return engine
    except Exception as e:
        logger.error(f"Failed to create SQL engine: {e}")
        raise

def main():
    # Configuration
    SSH_HOST = "ssh.pythonanywhere.com"
    SSH_USERNAME = "trademamba"
    MYSQL_HOST = "trademamba.mysql.pythonanywhere-services.com"  # Replace with your MySQL hostname
    MYSQL_USERNAME = "trademamba"
    MYSQL_PORT = 3306
    DB_NAME = "trademamba$default"

    # Get passwords securely
    try:
        ssh_password = getpass("Enter SSH password: ")
        mysql_password = getpass("Enter MySQL password: ")
    except Exception as e:
        logger.error(f"Error getting passwords: {e}")
        return


    tunnel = None
    try:
        # Create SSH tunnel
        tunnel = create_ssh_tunnel(SSH_HOST, SSH_USERNAME, ssh_password, MYSQL_HOST, MYSQL_PORT)

        # Create SQLAlchemy engine
        engine = create_sql_engine(MYSQL_USERNAME, mysql_password, DB_NAME, tunnel)

        # Test connection and write data
        with engine.connect() as conn:
            logger.info("Database connection successful")
            df_daily_dollar_amounts.reset_index().to_sql("daily_dollar_amounts", engine, if_exists="replace", index=False)
            logger.info("Data successfully written to 'test' table")
            df_dollar_bars.to_sql("dollar_bars", engine, if_exists="replace",index=False)
            logger.info("Data successfully written to 'test' table")


    except Exception as e:
        logger.error(f"Error: {e}")
        logger.info("\nTroubleshooting tips:")
        logger.info("1. Verify MySQL hostname (e.g., trademba.mysql.pythonanywhere-services.com)")
        logger.info("2. Ensure SSH access is enabled in your PythonAnywhere account")
        logger.info("3. Check SSH and MySQL credentials")
        logger.info("4. Confirm sshtunnel is installed: pip install sshtunnel")
        logger.info("5. Verify port 22 (SSH) is not blocked by your network")
        logger.info("6. Contact PythonAnywhere support if MySQL server issues persist")

    finally:
        if tunnel is not None and tunnel.is_active:
            tunnel.stop()
            logger.info("SSH tunnel closed")

if __name__ == "__main__":
    main()

Enter SSH password: ··········
Enter MySQL password: ··········
