Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Time series forcasting #40

Merged
merged 2 commits into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3,811 changes: 3,225 additions & 586 deletions notebooks/Chronos.ipynb

Large diffs are not rendered by default.

77 changes: 52 additions & 25 deletions scripts/forecasting.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,47 +3,38 @@
from chronos import ChronosPipeline
import matplotlib.pyplot as plt
import numpy as np
from sqlalchemy import create_engine
import os

# RDS connection information
rds_host = os.getenv('PG_HOST')
rds_port = os.getenv('PG_PORT')
rds_db = os.getenv('PG_DATABASE')
rds_user = os.getenv('PG_USER')
rds_password = os.getenv('PG_PASSWORD')

# Create engine for database connection
engine = create_engine(f'postgresql+psycopg2://{rds_user}:{rds_password}@{rds_host}:{rds_port}/{rds_db}')

# Define the crypto coins data source dictionary
crypto_data_dict = {
'BTC': "/content/drive/MyDrive/backtesting/datas/yfinance/BTC-USD.csv",
'BNB': "/content/drive/MyDrive/backtesting/datas/yfinance/BNB-USD.csv",
'ETH': "/content/drive/MyDrive/backtesting/datas/yfinance/ETH-USD.csv"
}

def predict_and_plot_crypto_data(
coin_name,
crypto_data_dict,
df,
model_name="amazon/chronos-t5-small",
prediction_length=12,
num_samples=20):
"""
Predicts and plots cryptocurrency data for a single coin.
Args:
coin_name (str): Name of the cryptocurrency (e.g., 'BTC', 'ETH').
crypto_data_dict (dict): Dictionary containing data for each cryptocurrency.
Keys should be coin names, values should be file paths to CSV data.
df (pd.DataFrame): DataFrame containing OHLCV data with columns ['timestamp', 'open', 'high', 'low', 'close', 'volume'].
model_name (str): Name of the pre-trained Chronos model (default: "amazon/chronos-t5-small").
prediction_length (int): Number of future data points to predict (default: 12).
num_samples (int): Number of prediction samples to generate (default: 20).
Raises:
ValueError: If coin_name is not found in crypto_data_dict.
Returns:
tuple: Tuple containing forecast index and median prediction array.
"""

# Check if coin exists in data
if coin_name not in crypto_data_dict:
raise ValueError(f"Coin '{coin_name}' not found in data dictionary.")

# Load data for the specified coin
csv_file = crypto_data_dict[coin_name]
df = pd.read_csv(csv_file)

# Initialize Chronos pipeline
pipeline = ChronosPipeline.from_pretrained(
model_name,
Expand All @@ -53,7 +44,7 @@ def predict_and_plot_crypto_data(

# Perform prediction
forecast = pipeline.predict(
context=torch.tensor(df["Close"]),
context=torch.tensor(df["close"].values), # Assuming 'close' is the column name in your DataFrame
prediction_length=prediction_length,
num_samples=num_samples,
)
Expand All @@ -64,14 +55,50 @@ def predict_and_plot_crypto_data(

# Plot and visualize predictions
plt.figure(figsize=(10, 6)) # Adjust figure size as needed
plt.plot(df["Close"], label="History")
plt.plot(df["close"], label="History") # Assuming 'close' is the column name in your DataFrame
plt.plot(forecast_index, median, label="Median Prediction")
plt.fill_between(forecast_index, low, high, alpha=0.2, label="Prediction Range")
plt.title(f"Predicted {coin_name} Prices")
plt.title(f"Predicted Prices")
plt.xlabel("Time")
plt.ylabel("Price")
plt.legend()
plt.grid(True) # Add gridlines
plt.show()

return forecast_index, median # Optionally return forecast data for further use


def fetch_data_from_db(symbol):
try:
# Construct table name based on symbol
table_name = f'ohlcv_{symbol.replace("-", "_")}'

# Query to fetch data from database
query = f"SELECT timestamp, open, high, low, close, volume FROM {table_name};"

# Fetch data from database into a DataFrame
df = pd.read_sql(query, con=engine, parse_dates=['timestamp'])

# Set 'timestamp' column as index
df.set_index('timestamp', inplace=True)

return df

except Exception as e:
print(f"Error fetching data from database for {symbol}: {str(e)}")
return None


if __name__ == "__main__":
# Example usage: Fetch data for a specific symbol from database
symbol = 'BTC-USD'
df = fetch_data_from_db(symbol)

if df is not None:
print(f"Fetched data for {symbol}:")
print(df.head())

# Perform prediction and plotting
predict_and_plot_crypto_data(df)
else:
print(f"Failed to fetch data for {symbol}")