Skip to content

Commit

Permalink
save changes
Browse files Browse the repository at this point in the history
  • Loading branch information
derejehinsermu committed Jun 26, 2024
2 parents d462f5d + 0cb1a9c commit 6eb658f
Show file tree
Hide file tree
Showing 2 changed files with 3,277 additions and 611 deletions.
3,811 changes: 3,225 additions & 586 deletions notebooks/Chronos.ipynb

Large diffs are not rendered by default.

77 changes: 52 additions & 25 deletions scripts/forecasting.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,47 +3,38 @@
from chronos import ChronosPipeline
import matplotlib.pyplot as plt
import numpy as np
from sqlalchemy import create_engine
import os

# RDS connection information
rds_host = os.getenv('PG_HOST')
rds_port = os.getenv('PG_PORT')
rds_db = os.getenv('PG_DATABASE')
rds_user = os.getenv('PG_USER')
rds_password = os.getenv('PG_PASSWORD')

# Create engine for database connection
engine = create_engine(f'postgresql+psycopg2://{rds_user}:{rds_password}@{rds_host}:{rds_port}/{rds_db}')

# Define the crypto coins data source dictionary
crypto_data_dict = {
'BTC': "/content/drive/MyDrive/backtesting/datas/yfinance/BTC-USD.csv",
'BNB': "/content/drive/MyDrive/backtesting/datas/yfinance/BNB-USD.csv",
'ETH': "/content/drive/MyDrive/backtesting/datas/yfinance/ETH-USD.csv"
}

def predict_and_plot_crypto_data(
coin_name,
crypto_data_dict,
df,
model_name="amazon/chronos-t5-small",
prediction_length=12,
num_samples=20):
"""
Predicts and plots cryptocurrency data for a single coin.
Args:
coin_name (str): Name of the cryptocurrency (e.g., 'BTC', 'ETH').
crypto_data_dict (dict): Dictionary containing data for each cryptocurrency.
Keys should be coin names, values should be file paths to CSV data.
df (pd.DataFrame): DataFrame containing OHLCV data with columns ['timestamp', 'open', 'high', 'low', 'close', 'volume'].
model_name (str): Name of the pre-trained Chronos model (default: "amazon/chronos-t5-small").
prediction_length (int): Number of future data points to predict (default: 12).
num_samples (int): Number of prediction samples to generate (default: 20).
Raises:
ValueError: If coin_name is not found in crypto_data_dict.
Returns:
tuple: Tuple containing forecast index and median prediction array.
"""

# Check if coin exists in data
if coin_name not in crypto_data_dict:
raise ValueError(f"Coin '{coin_name}' not found in data dictionary.")

# Load data for the specified coin
csv_file = crypto_data_dict[coin_name]
df = pd.read_csv(csv_file)

# Initialize Chronos pipeline
pipeline = ChronosPipeline.from_pretrained(
model_name,
Expand All @@ -53,7 +44,7 @@ def predict_and_plot_crypto_data(

# Perform prediction
forecast = pipeline.predict(
context=torch.tensor(df["Close"]),
context=torch.tensor(df["close"].values), # Assuming 'close' is the column name in your DataFrame
prediction_length=prediction_length,
num_samples=num_samples,
)
Expand All @@ -64,14 +55,50 @@ def predict_and_plot_crypto_data(

# Plot and visualize predictions
plt.figure(figsize=(10, 6)) # Adjust figure size as needed
plt.plot(df["Close"], label="History")
plt.plot(df["close"], label="History") # Assuming 'close' is the column name in your DataFrame
plt.plot(forecast_index, median, label="Median Prediction")
plt.fill_between(forecast_index, low, high, alpha=0.2, label="Prediction Range")
plt.title(f"Predicted {coin_name} Prices")
plt.title(f"Predicted Prices")
plt.xlabel("Time")
plt.ylabel("Price")
plt.legend()
plt.grid(True) # Add gridlines
plt.show()

return forecast_index, median # Optionally return forecast data for further use


def fetch_data_from_db(symbol):
try:
# Construct table name based on symbol
table_name = f'ohlcv_{symbol.replace("-", "_")}'

# Query to fetch data from database
query = f"SELECT timestamp, open, high, low, close, volume FROM {table_name};"

# Fetch data from database into a DataFrame
df = pd.read_sql(query, con=engine, parse_dates=['timestamp'])

# Set 'timestamp' column as index
df.set_index('timestamp', inplace=True)

return df

except Exception as e:
print(f"Error fetching data from database for {symbol}: {str(e)}")
return None


if __name__ == "__main__":
# Example usage: Fetch data for a specific symbol from database
symbol = 'BTC-USD'
df = fetch_data_from_db(symbol)

if df is not None:
print(f"Fetched data for {symbol}:")
print(df.head())

# Perform prediction and plotting
predict_and_plot_crypto_data(df)
else:
print(f"Failed to fetch data for {symbol}")

0 comments on commit 6eb658f

Please sign in to comment.