Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,9 @@ DATABASE_URL=postgresql://trading_user:your_strong_password_here@db:5432/trading
# This key is required for the Database Agent to accept requests.
# Generate a secure key, e.g., using: openssl rand -hex 32
DATABASE_AGENT_API_KEY=

# Alpaca API Credentials
# These are required to fetch market data from Alpaca.
# Sign up for a free account at https://alpaca.markets/
ALPACA_API_KEY=
ALPACA_SECRET_KEY=
136 changes: 136 additions & 0 deletions alpaca_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
import logging
import os
from datetime import datetime, timedelta

from alpaca.data.historical import StockHistoricalDataClient
from alpaca.data.requests import StockBarsRequest
from alpaca.data.timeframe import TimeFrame, TimeFrameUnit
from dotenv import load_dotenv
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential

class AlpacaClient:
    """
    A client for fetching historical market data from the Alpaca API, with
    built-in retry logic. Uses the modern alpaca-py library.
    """

    def __init__(self, api_key: str, secret_key: str):
        """
        Args:
            api_key (str): Alpaca API key ID.
            secret_key (str): Alpaca API secret key.

        Raises:
            ValueError: If either credential is empty or None.
        """
        if not api_key or not secret_key:
            raise ValueError("API key and secret key cannot be empty.")
        # sandbox=True points the client at the paper-trading environment.
        self.client = StockHistoricalDataClient(api_key, secret_key, sandbox=True)
        logging.info("Alpaca API client (alpaca-py) initialized for paper trading.")

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=2, max=10),
        # NOTE(review): retrying on bare Exception also retries permanent
        # failures (bad credentials, malformed request); consider narrowing to
        # the SDK's transient network/API error types.
        retry=retry_if_exception_type(Exception),
        before_sleep=lambda retry_state: logging.warning(
            f"Retrying Alpaca API call due to: {retry_state.outcome.exception()}. "
            f"Attempt #{retry_state.attempt_number}..."
        )
    )
    def fetch_historical_prices(self, symbol: str, timeframe_str: str, start_date: str, end_date: str):
        """
        Fetches historical OHLCV data from Alpaca for a given symbol and timeframe.

        Args:
            symbol (str): The stock symbol (e.g., 'GOOG').
            timeframe_str (str): The timeframe for the bars ('4h', '1d').
            start_date (str): The start date in 'YYYY-MM-DD' format.
            end_date (str): The end date in 'YYYY-MM-DD' format.

        Returns:
            list[dict]: A list of dictionaries, where each dictionary represents a price bar.
                        Returns an empty list if the timeframe is unsupported or no data exists.

        Raises:
            Exception: Re-raises any API/transport error after logging, so the
                       tenacity decorator can retry the call.
        """
        logging.info(f"Fetching historical data for {symbol} with timeframe {timeframe_str} from {start_date} to {end_date}.")
        try:
            # Map our string timeframe onto the SDK's TimeFrame type.
            # alpaca-py supports arbitrary multiples of an hour, so '4h' bars
            # can be requested directly — the previous fallback fetched 1-hour
            # bars and stored them mislabeled as '4h'.
            tf = timeframe_str.lower()
            if tf == '4h':
                alpaca_timeframe = TimeFrame(4, TimeFrameUnit.Hour)
            elif tf == '1d':
                alpaca_timeframe = TimeFrame.Day
            else:
                logging.error(f"Unsupported timeframe: {timeframe_str}")
                return []

            request_params = StockBarsRequest(
                symbol_or_symbols=[symbol],
                timeframe=alpaca_timeframe,
                start=start_date,
                end=end_date
            )

            bars = self.client.get_stock_bars(request_params).df

            if bars.empty:
                logging.warning(f"No data returned for {symbol} in the given date range.")
                return []

            # The SDK returns a (symbol, timestamp) multi-index DataFrame;
            # flatten it so the bar fields become ordinary columns.
            bars.reset_index(inplace=True)

            # Shape each bar to match the prices-table schema. This replaces
            # the previous no-op column renames ('open' -> 'open', etc.) and
            # the 'symbol'/'symbol_col' clash workaround.
            formatted_data = [
                {
                    'symbol': symbol,
                    'timeframe': timeframe_str,
                    'timestamp': ts.isoformat(),
                    'open': o,
                    'high': h,
                    'low': lo,
                    'close': c,
                    'volume': v,
                }
                for ts, o, h, lo, c, v in zip(
                    bars['timestamp'], bars['open'], bars['high'],
                    bars['low'], bars['close'], bars['volume'],
                )
            ]

            logging.info(f"Successfully fetched {len(formatted_data)} data points for {symbol}.")
            return formatted_data

        except Exception as e:
            logging.error(f"Failed to fetch historical data for {symbol}: {e}", exc_info=True)
            raise

# Example usage: fetch ~2 years of daily GOOG bars and print a sample.
if __name__ == '__main__':
    load_dotenv()
    logging.basicConfig(level=logging.INFO)

    api_key = os.getenv("ALPACA_API_KEY")
    secret_key = os.getenv("ALPACA_SECRET_KEY")

    if api_key and secret_key:
        client = AlpacaClient(api_key, secret_key)

        # Request window: the last 2 years, ending today.
        window_end = datetime.now()
        window_start = window_end - timedelta(days=2 * 365)

        historical_data = client.fetch_historical_prices(
            'GOOG',
            '1d',
            window_start.strftime('%Y-%m-%d'),
            window_end.strftime('%Y-%m-%d')
        )

        if historical_data:
            print(f"Fetched {len(historical_data)} records.")
            print("First 5 records:")
            for record in historical_data[:5]:
                print(record)
    else:
        print("Please set ALPACA_API_KEY and ALPACA_SECRET_KEY environment variables.")
67 changes: 60 additions & 7 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@
import logging
import sys
import uuid
import schedule
import time
import threading
from contextvars import ContextVar
from datetime import datetime, timedelta
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Depends, Security, Request
from fastapi.security import APIKeyHeader
Expand All @@ -11,6 +15,7 @@
from decimal import Decimal

from trading_db import TradingDB
from alpaca_client import AlpacaClient
from models import (
AccountBalance, Position, Order, CreateOrderBody, CreateOrderResponse,
OrderExecutionResponse, Trade, PortfolioMetrics, Price
Expand Down Expand Up @@ -82,26 +87,74 @@ def get_api_key(api_key_header: str = Security(api_key_header)):
# This single instance will be shared across all requests.
db = TradingDB()

# Alpaca API Client
alpaca_client = AlpacaClient(
api_key=os.environ.get("ALPACA_API_KEY"),
secret_key=os.environ.get("ALPACA_SECRET_KEY")
)

# --- Scheduled Jobs ---
def run_ingestion_job():
    """
    Scheduled job: fetch historical prices from Alpaca and ingest them into
    the database for every configured symbol/timeframe pair.

    Failures for one pair are logged and do not abort the remaining pairs.
    """
    logging.info("Scheduler starting historical data ingestion job...")
    symbols_to_fetch = ["GOOG"]
    timeframes_to_fetch = ["4h", "1d"]

    # Snapshot "now" once so both ends of the 2-year window are derived from
    # the same instant (two separate datetime.now() calls could straddle
    # midnight and yield an inconsistent range).
    now = datetime.now()
    end_date = now.strftime('%Y-%m-%d')
    start_date = (now - timedelta(days=2 * 365)).strftime('%Y-%m-%d')

    for symbol in symbols_to_fetch:
        for timeframe in timeframes_to_fetch:
            try:
                price_data = alpaca_client.fetch_historical_prices(
                    symbol, timeframe, start_date, end_date
                )
                if price_data:
                    db.ingest_historical_prices(price_data)
                else:
                    logging.warning(f"No price data to ingest for {symbol} ({timeframe}).")
            except Exception as e:
                # Log the error but continue to the next symbol/timeframe
                logging.error(f"Failed to ingest data for {symbol} ({timeframe}): {e}", exc_info=True)

    logging.info("Scheduler finished historical data ingestion job.")

def run_scheduler():
    """
    Continuously runs pending scheduled jobs.

    Blocks forever: polls the `schedule` library once per second for due jobs.
    Intended to be run in a daemon thread (see the startup event) so it does
    not prevent process exit.
    """
    while True:
        schedule.run_pending()
        time.sleep(1)

# --- Events ---
@app.on_event("startup")
async def startup_event():
    """Ensure the database and tables are created and start the scheduler.

    Runs once at application startup:
      1. Verifies/creates the database tables.
      2. Registers the daily data-ingestion job.
      3. Launches the scheduler loop in a daemon background thread.

    Raises any startup failure after logging it, so the app fails fast when
    the database is not ready.
    """
    logging.info("Database Agent API starting up.")
    try:
        # The TradingDB __init__ now ensures the DB exists.
        # This call ensures the tables exist.
        db.setup_database()
        logging.info("Database tables verification/creation complete.")

        # Schedule the ingestion job to run once a day.
        schedule.every().day.at("00:00").do(run_ingestion_job)
        logging.info("Scheduled data ingestion job to run daily at 00:00.")

        # Run the scheduler in a background daemon thread so it never blocks
        # the event loop or keeps the process alive on shutdown.
        scheduler_thread = threading.Thread(target=run_scheduler, daemon=True)
        scheduler_thread.start()
        logging.info("Scheduler started in a background thread.")

    except Exception as e:
        logging.critical(f"FATAL: Application startup failed: {e}", exc_info=True)
        raise

@app.on_event("shutdown")
async def shutdown_event():
    """Log application shutdown; no explicit resource teardown is done here."""
    logging.info("Database Agent API shutting down.")
    # The db connection is closed automatically by the TradingDB destructor.

# --- API Endpoints ---

Expand Down
4 changes: 4 additions & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,6 @@
pytest
httpx
alpaca-py
tenacity
schedule
pytz
4 changes: 4 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,7 @@ uvicorn[standard]
python-dotenv
pydantic
psycopg2-binary
alpaca-py
tenacity
schedule
pytz
62 changes: 52 additions & 10 deletions trading_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,28 +230,29 @@ def setup_database(self):
CREATE TABLE IF NOT EXISTS prices (
price_id {pk_type},
symbol TEXT NOT NULL,
timeframe TEXT NOT NULL,
timestamp {timestamp_type} NOT NULL,
open {numeric_type} NOT NULL,
high {numeric_type} NOT NULL,
low {numeric_type} NOT NULL,
close {numeric_type} NOT NULL,
volume BIGINT NOT NULL,
UNIQUE (symbol, timestamp)
UNIQUE (symbol, timeframe, timestamp)
);
""")

# Insert sample data for prices if it doesn't exist
cursor.execute(f"SELECT * FROM prices WHERE symbol = {self.param_style}", ('AAPL',))
if cursor.fetchone() is None:
sample_prices = [
('AAPL', '2025-01-01T10:00:00Z', '150.00', '152.00', '149.50', '151.50', 1000000),
('AAPL', '2025-01-01T11:00:00Z', '151.50', '153.00', '151.00', '152.50', 1200000),
('GOOG', '2025-01-01T10:00:00Z', '2800.00', '2810.00', '2795.00', '2805.00', 500000)
('AAPL', '1h', '2025-01-01T10:00:00Z', '150.00', '152.00', '149.50', '151.50', 1000000),
('AAPL', '1h', '2025-01-01T11:00:00Z', '151.50', '153.00', '151.00', '152.50', 1200000),
('GOOG', '1d', '2025-01-01T10:00:00Z', '2800.00', '2810.00', '2795.00', '2805.00', 500000)
]
for price_data in sample_prices:
cursor.execute(f"""
INSERT INTO prices (symbol, timestamp, open, high, low, close, volume)
VALUES ({self.param_style}, {self.param_style}, {self.param_style}, {self.param_style}, {self.param_style}, {self.param_style}, {self.param_style})
INSERT INTO prices (symbol, timeframe, timestamp, open, high, low, close, volume)
VALUES ({self.param_style}, {self.param_style}, {self.param_style}, {self.param_style}, {self.param_style}, {self.param_style}, {self.param_style}, {self.param_style})
""", price_data)

cursor.execute(f"SELECT * FROM accounts WHERE account_name = {self.param_style}", ('main_account',))
Expand Down Expand Up @@ -489,16 +490,15 @@ def get_trade_history(self, account_id: int, limit: int = 50, offset: int = 0, s
cursor.close()

def get_price_history(self, symbol: str, timeframe: str = '1h', limit: int = 100) -> List[Dict[str, Any]]:
# Note: timeframe is not used in this MVP implementation.
# A real implementation would require time-series aggregation logic.
cursor = self.get_cursor()
try:
query = f"SELECT * FROM prices WHERE symbol = {self.param_style} ORDER BY timestamp DESC LIMIT {self.param_style}"
cursor.execute(query, (symbol.upper(), limit))
query = f"SELECT * FROM prices WHERE symbol = {self.param_style} AND timeframe = {self.param_style} ORDER BY timestamp DESC LIMIT {self.param_style}"
cursor.execute(query, (symbol.upper(), timeframe, limit))

return [
{
'symbol': row['symbol'],
'timeframe': row['timeframe'],
'timestamp': row['timestamp'],
'open': self._to_decimal(row['open']),
'high': self._to_decimal(row['high']),
Expand All @@ -521,6 +521,48 @@ def _get_latest_price(self, symbol: str) -> Optional[Decimal]:
finally:
cursor.close()

    def ingest_historical_prices(self, price_data: List[Dict[str, Any]]):
        """
        Ingests a list of historical price data points into the database.
        It uses an 'ON CONFLICT DO NOTHING' clause to prevent duplicates
        based on the (symbol, timeframe, timestamp) unique constraint.
        This is a highly efficient bulk operation.

        Args:
            price_data: Bar dicts, each keyed by symbol, timeframe, timestamp,
                open, high, low, close, volume (matching the named SQL params).
                An empty/None list is a no-op.

        Raises:
            Exception: Re-raises any database error after rolling back the
                transaction, so callers can decide how to react.
        """
        if not price_data:
            logging.info("No price data provided to ingest.")
            return

        if self.db_type == 'sqlite':
            # SQLite uses a different syntax for ON CONFLICT and the
            # :name named-parameter style instead of %(name)s.
            query = """
                INSERT INTO prices (symbol, timeframe, timestamp, open, high, low, close, volume)
                VALUES (:symbol, :timeframe, :timestamp, :open, :high, :low, :close, :volume)
                ON CONFLICT(symbol, timeframe, timestamp) DO NOTHING;
            """
        else:  # PostgreSQL
            query = """
                INSERT INTO prices (symbol, timeframe, timestamp, open, high, low, close, volume)
                VALUES (%(symbol)s, %(timeframe)s, %(timestamp)s, %(open)s, %(high)s, %(low)s, %(close)s, %(volume)s)
                ON CONFLICT (symbol, timeframe, timestamp) DO NOTHING;
            """

        cursor = self.get_cursor()
        try:
            # For SQLite, execute many with a list of dicts
            if self.db_type == 'sqlite':
                cursor.executemany(query, price_data)
            else:  # For psycopg2, execute_batch is highly efficient
                # NOTE(review): requires an explicit `import psycopg2.extras`
                # at module level — `import psycopg2` alone does not load the
                # extras submodule; verify the file's imports.
                psycopg2.extras.execute_batch(cursor, query, price_data)

            self.conn.commit()
            # Count reported includes rows skipped by ON CONFLICT.
            logging.info(f"Successfully ingested or skipped {len(price_data)} price records.")
        except Exception as e:
            logging.error(f"Database error during price ingestion: {e}", exc_info=True)
            # Roll back so a partial batch never commits.
            self.conn.rollback()
            raise
        finally:
            cursor.close()

def get_portfolio_metrics(self, account_id: int) -> Optional[Dict[str, Any]]:
cash_balance = self.get_account_balance(account_id)
if cash_balance is None:
Expand Down