In [10]:
import pandas as pd
import sys, os
# Put the sibling "src" directory on the import path (presumably where the
# local modules imported by later cells live).
sys.path.append(os.path.join(os.path.dirname(os.getcwd()), "src"))

# Read the raw price history, expose 'Price' under the name 'Close' with its
# commas removed and parsed as float, and parse 'Date' into datetimes.
df = (
    pd.read_csv(os.path.join("..", "data", "Bitcoin History 2010-2024.csv"))
    .rename(columns={"Price": "Close"})
    .assign(
        Close=lambda frame: frame['Close'].str.replace(',', '').astype(float),
        Date=lambda frame: pd.to_datetime(frame['Date']),
    )
)

# Keep only the rows whose date lies in [start_date, end_date] (both ends inclusive).
start_date = '2016-07-09'
end_date = '2018-01-09'
mask = df['Date'].between(start_date, end_date)

# Order the selection chronologically and give it a clean 0..n-1 index.
df_filtered2 = (
    df.loc[mask]
    .sort_values(by='Date')
    .reset_index(drop=True)
)
df_filtered2

Unnamed: 0,Date,Close,Open,High,Low,Vol.,Change %
0,2016-07-09,651.8,662.8,663.5,620.7,96.01K,-1.65%
1,2016-07-10,647.1,651.8,652.0,636.8,29.26K,-0.72%
2,2016-07-11,646.7,647.1,657.5,640.0,45.94K,-0.06%
3,2016-07-12,670.6,646.7,673.2,644.6,66.59K,3.69%
4,2016-07-13,661.2,670.6,672.0,656.0,62.96K,-1.40%
...,...,...,...,...,...,...,...
545,2018-01-05,16954.8,15180.1,17126.9,14832.4,141.96K,11.69%
546,2018-01-06,17172.3,16954.8,17252.8,16286.6,83.93K,1.28%
547,2018-01-07,16228.2,17174.5,17184.8,15791.1,79.01K,-5.50%
548,2018-01-08,14976.2,16228.3,16302.9,13902.3,142.45K,-7.71%


In [11]:
from prettytable import PrettyTable
from utils import print_stats, plot_multiple_conf_interval
import random
import warnings
from Environment import Environment
from Agent import Agent

def main(df=None, seed=None):
    """Train and evaluate a D3QN trading agent under two reward functions.

    For each reward function ("profit" and "sr") a fresh agent is trained on
    one window of ``train_size`` rows and then tested ``N_TEST`` times on
    randomly placed windows that start at or after the end of the training
    window. Summary statistics are printed as a table and the collected
    returns are passed to ``plot_multiple_conf_interval``.

    Parameters
    ----------
    df : pandas.DataFrame, optional
        Price history to trade on. Defaults to the notebook-level
        ``df_filtered2``, which must already have been created by an
        earlier cell.
    seed : int, optional
        If given, seeds Python's ``random`` module, which this function uses
        to place the train/test windows. Random number generators used inside
        ``Agent`` are not visible here and may need to be seeded separately.

    Raises
    ------
    ValueError
        If ``df`` is too short to place a full trading period.
    """
    # ----------------------------- LOAD DATA ---------------------------------------------------------------------------
    path = ''
    if df is None:
        # Fall back to the frame prepared by an earlier cell; it must be
        # defined before this function is called.
        df = df_filtered2

    if seed is not None:
        random.seed(seed)

    # ----------------------------- PARAMETERS --------------------------------
    REPLAY_MEM_SIZE = 10000
    BATCH_SIZE = 40
    GAMMA = 0.99
    EPS_START = 1
    EPS_END = 0.02
    EPS_STEPS = 200
    LEARNING_RATE = 0.0005
    INPUT_DIM = 24
    HIDDEN_DIM = 120
    ACTION_NUMBER = 3
    TARGET_UPDATE = 5
    N_TEST = 10
    TRADING_PERIOD = 500

    # ----------------------------- TRAINING SETUP --------------------------------
    # Exclusive upper bound for the first row of a trading period.
    max_start = len(df) - TRADING_PERIOD - 1
    if max_start <= 0:
        raise ValueError(
            f"df must have at least {TRADING_PERIOD + 2} rows to place a "
            f"trading period, got {len(df)}"
        )
    train_start = random.randrange(max_start)
    train_size = int(TRADING_PERIOD * 0.8)

    def build_agent():
        """Create a new, untrained agent with the hyper-parameters above."""
        return Agent(REPLAY_MEM_SIZE,
                     BATCH_SIZE,
                     GAMMA,
                     EPS_START,
                     EPS_END,
                     EPS_STEPS,
                     LEARNING_RATE,
                     INPUT_DIM,
                     HIDDEN_DIM,
                     ACTION_NUMBER,
                     TARGET_UPDATE,
                     MODEL='ddqn',
                     DOUBLE=True)

    def train_and_test(reward, label):
        """Train a fresh agent on ``reward`` and return its N_TEST test returns."""
        train_env = Environment(df.iloc[train_start:train_start + train_size], reward)

        # One agent per reward function, so that nothing left over from
        # training on the other reward can influence this model.
        agent = build_agent()
        agent.train(train_env, path)
        train_env.reset()

        returns = []
        for i in range(N_TEST):
            print(f"{label} Test {i + 1}")
            # Draw the test start from [train_start, max_start) so the test
            # slice begins at or after the end of the training slice and
            # never overlaps it. When train_start is close to max_start only
            # a few distinct test windows are available.
            test_start = random.randrange(train_start, max_start)
            test_env = Environment(
                df.iloc[test_start + train_size:test_start + TRADING_PERIOD], reward)

            agent.test(test_env)
            returns.append(test_env.cumulative_return)
            test_env.reset()
        return returns

    # ----------------------------- TRAIN + TEST PROFIT MODEL --------------------------------
    profit_dueling_ddqn_return = train_and_test("profit", "Profit")

    # ----------------------------- TRAIN + TEST SHARPE MODEL --------------------------------
    sharpe_dueling_ddqn_return = train_and_test("sr", "Sharpe")

    # ----------------------------- STATS + PLOTS --------------------------------
    t = PrettyTable(["Trading System", "Avg. Return (%)", "Max Return (%)", "Min Return (%)", "Std. Dev."])
    print_stats("Profit D3QN", profit_dueling_ddqn_return, t)
    print_stats("Sharpe D3QN", sharpe_dueling_ddqn_return, t)
    print(t)

    plot_multiple_conf_interval(["Profit D3QN", "Sharpe D3QN"],
                                [profit_dueling_ddqn_return, sharpe_dueling_ddqn_return])

# A notebook kernel runs cells with __name__ == "__main__", so executing this
# cell starts the full train/test run immediately.
if __name__ == "__main__":
    main()


Agent is using device:	cpu
Training:


  0%|          | 0/100 [00:04<?, ?it/s]


KeyboardInterrupt: 