In [1]:
from datetime import date

import matplotlib.pyplot as plt
import pandas as pd
import torch
import yfinance as yf
from chronos import ChronosPipeline

: 

In [2]:

# Download window for the price history, as ISO "YYYY-MM-DD" strings.
# yfinance treats the end date as exclusive, so today's (still-forming) bar is not fetched.
START = "2020-01-01"
TODAY = date.today().isoformat()

# Define a function to load the dataset
def load_data(ticker):
    data = yf.download(ticker, START, TODAY)
    data.reset_index(inplace=True)
    return data

# Load BTC-USD data
data = load_data('BTC-USD')
df = data[['Date', 'Close']]  # Extract Date and Close price columns

# Initialize Chronos pipeline
pipeline = ChronosPipeline.from_pretrained(
    "amazon/chronos-t5-small",
    device_map="cpu",  # Use "cpu" for CPU inference or "mps" for Apple Silicon
    torch_dtype=torch.float32,  # Adjust dtype as needed based on model requirements
)

# Prepare context data (closing prices as a 1D tensor)
context = torch.tensor(df['Close'].values, dtype=torch.float32)

# Perform forecasting
prediction_length = 10  # Adjust as needed
num_samples = 1000  # Adjust as needed
forecast = pipeline.predict(
    context=context,
    prediction_length=prediction_length,
    num_samples=num_samples,
)

# Print or use forecasted results
print(forecast.shape)  # Example of printing forecast shape


[*********************100%%**********************]  1 of 1 completed


2024-06-27 05:21:31.913347: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [None]:
import matplotlib.pyplot as plt
# Plotting the results
plt.figure(figsize=(12, 6))
plt.plot(df['Date'], df['Close'], label="Original Close Price")
# Extract the last `prediction_length` dates for x-axis
x_dates = df['Date'].iloc[-prediction_length:].values 
# Use the mean of forecast across samples for y-axis
y_pred = forecast.mean(axis=1).flatten()  # Flatten to match x_dates shape
plt.plot(x_dates, y_pred, 'r', label="Predicted Close Price")
plt.xlabel('Date')
plt.ylabel('Price')
plt.legend()
plt.grid(True)
plt.show()