In [1]:
from __future__ import annotations

import math
from datetime import datetime
from pathlib import Path
from typing import Tuple, Dict, Any
from tqdm.notebook import tqdm

import gymnasium as gym
import numpy as np
import pandas as pd
from pybit.unified_trading import HTTP

In [2]:
def _datetime_to_ms(dt: str | datetime) -> int:
    ts = pd.Timestamp(dt, tz="UTC") if isinstance(dt, str) else pd.Timestamp(dt).tz_convert("UTC")
    return int(ts.timestamp() * 1000)

In [3]:
def _fetch_ohlcv_minute(session: HTTP, symbol: str, start_ms: int, end_ms: int, limit: int = 1000) -> pd.DataFrame:
    rows: list[list] = []
    cur = start_ms
    step_ms = 60_000 # 1 minute in milliseconds
    total_minutes = (end_ms - start_ms) // step_ms

    with tqdm(total=total_minutes, desc=f"Fetching {symbol} OHLCV") as pbar:
        while cur < end_ms:
            start_time_current_call = cur
            resp = session.get_kline(category="linear", symbol=symbol, interval="1", start=cur, limit=limit)
            data = resp["result"]["list"]
            if not data:
                # If no data is returned, assume we reached the end for the requested period
                # Update progress bar to reflect the remaining time as processed
                remaining_minutes = (end_ms - cur) // step_ms
                pbar.update(remaining_minutes)
                break

            rows.extend(data)
            last_ts = int(data[-1][0])

            # Calculate minutes fetched in this call and update progress bar
            minutes_fetched = (last_ts - start_time_current_call + step_ms) // step_ms
            pbar.update(minutes_fetched)

            # Set cursor for the next iteration
            cur = last_ts + step_ms

            # Ensure progress doesn't exceed total if API returns data beyond end_ms
            if pbar.n > total_minutes:
                 pbar.n = total_minutes
                 pbar.refresh()

        # Ensure the progress bar completes if the loop finishes early
        if pbar.n < total_minutes:
             pbar.update(total_minutes - pbar.n)


    df = pd.DataFrame(rows, columns=["startTime", "open", "high", "low", "close", "volume", "turnover"])
    df["startTime"] = pd.to_datetime(df["startTime"], unit="ms", utc=True)
    df.set_index("startTime", inplace=True)
    # Filter data strictly within the requested range [start_ms, end_ms)
    df = df[(df.index >= pd.to_datetime(start_ms, unit='ms', utc=True)) & (df.index < pd.to_datetime(end_ms, unit='ms', utc=True))]
    df = df.astype(float)[["open", "high", "low", "close", "volume"]]
    return df.sort_index()

In [4]:
def _fetch_long_short_ratio(session: HTTP, symbol: str, start_ms: int, end_ms: int) -> pd.Series:
    rows = []
    limit = 500
    interval_ms = 5 * 60_000 # 5 minutes in milliseconds
    total_intervals = (end_ms - start_ms) // interval_ms
    last_fetched_ts = start_ms

    with tqdm(total=total_intervals, desc=f"Fetching {symbol} Long/Short Ratio") as pbar:
        while True:
            start_time_current_call = last_fetched_ts
            try:
                resp = session.get_long_short_ratio(
                    category="linear",
                    symbol=symbol,
                    period="5min",
                    startTime=start_time_current_call, # Use last fetched timestamp to avoid overlap issues
                    endTime=end_ms,
                    limit=limit
                )
                data = resp["result"]["list"]
                if not data:
                    # No more data in the range for this call
                    remaining_intervals = max(0, (end_ms - last_fetched_ts) // interval_ms)
                    pbar.update(remaining_intervals)
                    break # Exit loop if no data is returned

                rows.extend(data)
                current_last_ts = int(data[-1]["timestamp"])

                # Calculate intervals fetched based on time covered
                intervals_fetched = max(0, (current_last_ts - last_fetched_ts) // interval_ms)
                # Add 1 interval for the last timestamp itself if it wasn't fully covered by the division
                if (current_last_ts - last_fetched_ts) % interval_ms > 0 or intervals_fetched == 0:
                     intervals_fetched += 1


                pbar.update(intervals_fetched)
                last_fetched_ts = current_last_ts + interval_ms # Set start for next potential fetch

                # Check if we have fetched data beyond the requested end_ms
                if last_fetched_ts >= end_ms:
                     # Ensure progress bar completes if we fetched up to or beyond end_ms
                     if pbar.n < total_intervals:
                         pbar.update(total_intervals - pbar.n)
                     break

            except Exception as e:
                print(f"An error occurred: {e}")
                # Update progress bar to reflect the assumed end if an error occurs
                if pbar.n < total_intervals:
                    pbar.update(total_intervals - pbar.n)
                break # Exit loop on error

        # Ensure the progress bar completes fully if the loop finishes early
        if pbar.n < total_intervals:
             pbar.update(total_intervals - pbar.n)


    if not rows:
        # Return an empty series with the correct dtype if no data was fetched
        return pd.Series(dtype=float, name="ls_ratio")

    df = pd.DataFrame(rows)
    df["timestamp"] = pd.to_datetime(df["timestamp"].astype(int), unit="ms", utc=True)
    df.set_index("timestamp", inplace=True)
    df = df.sort_index()
    # Filter data strictly within the requested range [start_ms, end_ms)
    df = df[(df.index >= pd.to_datetime(start_ms, unit='ms', utc=True)) & (df.index < pd.to_datetime(end_ms, unit='ms', utc=True))]
    # Remove potential duplicates from overlapping API calls if cursor wasn't effective
    df = df[~df.index.duplicated(keep='first')]

    if df.empty:
        return pd.Series(dtype=float, name="ls_ratio")

    df[["buyRatio", "sellRatio"]] = df[["buyRatio", "sellRatio"]].astype(float)
    # Avoid division by zero if both buyRatio and sellRatio are 0
    total_ratio = df["buyRatio"] + df["sellRatio"]
    ls_ratio = df["buyRatio"].divide(total_ratio).fillna(0.5) # Fill NaN with 0.5 (neutral) or 0

    # Resample to 1 minute and forward fill
    ls_ratio = ls_ratio.resample("1min").ffill()
    # Ensure the resampled series covers the full requested range, padding with ffill/bfill
    full_range_index = pd.date_range(start=pd.to_datetime(start_ms, unit='ms', utc=True),
                                     end=pd.to_datetime(end_ms - 1, unit='ms', utc=True), # end is exclusive
                                     freq='1min')
    ls_ratio = ls_ratio.reindex(full_range_index).ffill().bfill() # Forward fill then backfill NaNs

    return ls_ratio.rename("ls_ratio")

In [5]:
def _fetch_funding_rate(session: HTTP, symbol: str, start_ms: int, end_ms: int) -> pd.Series:
    rows = []
    cursor = None
    limit = 200 # Max limit for funding rate history
    # Estimate total intervals for progress bar (funding typically every 8 hours)
    interval_ms = 8 * 60 * 60_000
    total_intervals = max(1, (end_ms - start_ms) // interval_ms)
    last_fetched_ts = start_ms # Track the timestamp of the last fetched record for progress update

    with tqdm(total=total_intervals, desc=f"Fetching {symbol} Funding Rate") as pbar:
        while True:
            try:
                resp = session.get_funding_rate_history(
                    category="linear",
                    symbol=symbol,
                    # Rely primarily on cursor for pagination, filter by time later
                    limit=limit,
                    cursor=cursor
                )

                data = resp["result"]["list"]
                if not data:
                    # No more data from API for this cursor
                    if pbar.n < total_intervals:
                        pbar.update(total_intervals - pbar.n) # Complete the bar
                    break

                rows.extend(data)
                current_last_ts = int(data[-1]["fundingRateTimestamp"])

                # Update progress based on time covered since last fetch
                if current_last_ts > last_fetched_ts:
                    intervals_covered = (current_last_ts - last_fetched_ts) // interval_ms
                    # Ensure at least 1 interval is credited if any time passed and data received
                    if intervals_covered == 0 and current_last_ts > last_fetched_ts:
                         intervals_covered = 1
                    # Cap update to not exceed total
                    update_amount = min(intervals_covered, total_intervals - pbar.n)
                    if update_amount > 0:
                        pbar.update(update_amount)
                    last_fetched_ts = current_last_ts # Update last fetched timestamp

                cursor = resp["result"].get("nextPageCursor")
                if not cursor:
                    # No next page cursor means we are done fetching
                    if pbar.n < total_intervals:
                        pbar.update(total_intervals - pbar.n) # Complete the bar
                    break

            except Exception as e:
                print(f"An error occurred during funding rate fetch: {e}")
                # Update progress bar to reflect the assumed end if an error occurs
                if pbar.n < total_intervals:
                    pbar.update(total_intervals - pbar.n)
                break # Exit loop on error

        # Ensure the progress bar completes fully if the loop finished early
        if pbar.n < total_intervals:
             pbar.update(total_intervals - pbar.n)

    if not rows:
        # Return an empty series with the correct dtype and name if no data was fetched
        return pd.Series(dtype=float, name="fundingRate")

    df = pd.DataFrame(rows)
    df["fundingRateTimestamp"] = pd.to_datetime(df["fundingRateTimestamp"].astype(int), unit="ms", utc=True)
    df.set_index("fundingRateTimestamp", inplace=True)
    df = df.sort_index()

    # Filter data strictly within the requested range [start_ms, end_ms) AFTER collecting all data
    df = df[(df.index >= pd.to_datetime(start_ms, unit='ms', utc=True)) & (df.index < pd.to_datetime(end_ms, unit='ms', utc=True))]

    if df.empty:
        return pd.Series(dtype=float, name="fundingRate")

    # Remove potential duplicates just in case (e.g., overlapping calls if cursor logic had issues)
    df = df[~df.index.duplicated(keep='first')]

    df["fundingRate"] = df["fundingRate"].astype(float)

    # Resample to 1 minute and interpolate linearly
    funding_series = df["fundingRate"].resample("1min").interpolate(method='linear')

    # Ensure the resampled series covers the full requested range, padding with ffill/bfill
    full_range_index = pd.date_range(start=pd.to_datetime(start_ms, unit='ms', utc=True),
                                     end=pd.to_datetime(end_ms - 1, unit='ms', utc=True), # end is exclusive
                                     freq='1min')
    # Reindex to the full range, then fill any remaining NaNs at the beginning/end
    # Interpolation handles NaNs between points, ffill/bfill handle edges.
    funding_series = funding_series.reindex(full_range_index).ffill().bfill()

    return funding_series.rename("fundingRate") # Ensure series name is set

In [6]:
def fetch_bybit_data(
    symbol: str = "BTCUSDT",
    start: str | datetime = "2025-03-01 00:00:00",
    end: str | datetime = "2025-04-01 00:00:00",
    save_csv: bool = False,
    out_dir: str | Path = "data",
) -> Tuple[pd.DataFrame, pd.Series, pd.Series]:
    """Return 1-minute OHLCV, long/short ratio and funding rate for ``[start, end)``.

    With ``save_csv=True`` the three CSV files under ``out_dir`` act as a cache: they
    are reused when all of them exist *and* cover the requested range, and are
    (re)written after a fresh download otherwise. Both the cached and the freshly
    fetched results are put on a regular 1-minute grid, so the returned indexes carry
    ``freq='1min'`` in either case.

    Args:
        symbol: Contract symbol.
        start: Inclusive range start (naive values are treated as UTC).
        end: Exclusive range end (naive values are treated as UTC).
        save_csv: Enable the CSV cache (read and write).
        out_dir: Directory holding the cache files.

    Returns:
        ``(ohlcv, long_short_ratio, funding_rate)``.
    """
    out = Path(out_dir)
    ohlcv_path = out / f"{symbol}_ohlcv_1min.csv"
    lsr_path = out / f"{symbol}_long_short_ratio.csv"
    funding_path = out / f"{symbol}_funding_rate.csv"

    def _as_utc(value) -> pd.Timestamp:
        ts = pd.Timestamp(value)
        return ts.tz_localize("UTC") if ts.tzinfo is None else ts.tz_convert("UTC")

    def _minute_grid(obj):
        # Put a frame/series on a regular 1-minute index (sets index.freq).
        if obj.empty or not isinstance(obj.index, pd.DatetimeIndex):
            return obj
        obj = obj[~obj.index.duplicated(keep="first")].sort_index()
        return obj.asfreq("1min")

    # First and last minute that belong to the half-open range [start, end).
    first_ts = _as_utc(start).ceil("1min")
    last_ts = (_as_utc(end) - pd.Timedelta(milliseconds=1)).floor("1min")

    def _covers(obj) -> bool:
        return len(obj) > 0 and obj.index[0] <= first_ts and obj.index[-1] >= last_ts

    if save_csv and ohlcv_path.exists() and lsr_path.exists() and funding_path.exists():
        print(f"Loading data from CSV files in {out_dir}...")
        try:
            ohlcv = pd.read_csv(ohlcv_path, index_col="startTime", parse_dates=True)
            ohlcv.index = pd.to_datetime(ohlcv.index, utc=True)
            ohlcv = _minute_grid(ohlcv)

            lsr = pd.read_csv(lsr_path, index_col=0, parse_dates=True).squeeze("columns")
            lsr.index = pd.to_datetime(lsr.index, utc=True)
            lsr.name = "ls_ratio"
            lsr = _minute_grid(lsr)

            funding = pd.read_csv(funding_path, index_col=0, parse_dates=True).squeeze("columns")
            funding.index = pd.to_datetime(funding.index, utc=True)
            funding.name = "fundingRate"
            funding = _minute_grid(funding)

            if _covers(ohlcv) and _covers(lsr) and _covers(funding):
                print("Data loaded successfully from CSV.")
                # Return only the requested range, not whatever happens to be cached.
                return (
                    ohlcv.loc[first_ts:last_ts],
                    lsr.loc[first_ts:last_ts],
                    funding.loc[first_ts:last_ts],
                )
            print("Cached CSV files do not cover the requested range. Fetching from API instead.")
        except Exception as e:
            print(f"Error loading data from CSV: {e}. Fetching from API instead.")
            # Fall through to fetch from API if loading fails

    # Fetch from API if the cache is disabled, missing, stale or unreadable.
    print("Fetching data from Bybit API...")
    session = HTTP(testnet=False)
    start_ms, end_ms = _datetime_to_ms(start), _datetime_to_ms(end)
    ohlcv = _minute_grid(_fetch_ohlcv_minute(session, symbol, start_ms, end_ms))
    lsr = _minute_grid(_fetch_long_short_ratio(session, symbol, start_ms, end_ms))
    funding = _minute_grid(_fetch_funding_rate(session, symbol, start_ms, end_ms))

    if save_csv:
        print(f"Saving data to CSV files in {out_dir}...")
        out.mkdir(parents=True, exist_ok=True)
        ohlcv.to_csv(ohlcv_path)
        lsr.to_csv(lsr_path, header=True)  # Save header for Series
        funding.to_csv(funding_path, header=True)  # Save header for Series
        print("Data saved successfully.")

    return ohlcv, lsr, funding

In [7]:
# Load the default BTCUSDT range. With save_csv=True, fetch_bybit_data reuses the CSV
# files under data/ when all three exist and writes them after an API download otherwise.
ohlcv_df, ls_series, funding_series = fetch_bybit_data(save_csv=True)

Loading data from CSV files in data...
Data loaded successfully from CSV.


In [8]:
from IPython.display import display

# .info() prints its report and returns None, so it is called directly; wrapping it
# in display() rendered an extra "None" after every summary.
print("OHLCV Data:")
display(ohlcv_df.head())
display(ohlcv_df.tail())
ohlcv_df.info()

print("\nLong/Short Ratio Data:")
display(ls_series.head())
display(ls_series.tail())
ls_series.info()

print("\nFunding Rate Data:")
display(funding_series.head())
display(funding_series.tail())
funding_series.info()

OHLCV Data:


Unnamed: 0_level_0,open,high,low,close,volume
startTime,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
2025-03-01 00:00:00+00:00,84307.6,84344.5,84279.3,84292.4,22.369
2025-03-01 00:01:00+00:00,84292.4,84292.5,84221.4,84247.9,44.954
2025-03-01 00:02:00+00:00,84247.9,84255.0,84230.5,84230.5,9.405
2025-03-01 00:03:00+00:00,84230.5,84272.9,84209.4,84272.9,21.803
2025-03-01 00:04:00+00:00,84272.9,84313.7,84251.6,84251.6,17.315


Unnamed: 0_level_0,open,high,low,close,volume
startTime,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
2025-03-31 23:55:00+00:00,82500.0,82500.0,82467.4,82481.7,29.694
2025-03-31 23:56:00+00:00,82481.7,82511.6,82481.7,82511.6,22.108
2025-03-31 23:57:00+00:00,82511.6,82525.5,82508.8,82508.8,24.871
2025-03-31 23:58:00+00:00,82508.8,82546.9,82504.8,82521.7,45.531
2025-03-31 23:59:00+00:00,82521.7,82521.7,82494.0,82504.4,19.819


<class 'pandas.core.frame.DataFrame'>
DatetimeIndex: 44640 entries, 2025-03-01 00:00:00+00:00 to 2025-03-31 23:59:00+00:00
Freq: min
Data columns (total 5 columns):
 #   Column  Non-Null Count  Dtype  
---  ------  --------------  -----  
 0   open    44640 non-null  float64
 1   high    44640 non-null  float64
 2   low     44640 non-null  float64
 3   close   44640 non-null  float64
 4   volume  44640 non-null  float64
dtypes: float64(5)
memory usage: 2.0 MB


None


Long/Short Ratio Data:


2025-03-01 00:00:00+00:00    0.5598
2025-03-01 00:01:00+00:00    0.5598
2025-03-01 00:02:00+00:00    0.5598
2025-03-01 00:03:00+00:00    0.5598
2025-03-01 00:04:00+00:00    0.5598
Freq: min, Name: ls_ratio, dtype: float64

2025-03-31 23:55:00+00:00    0.5524
2025-03-31 23:56:00+00:00    0.5524
2025-03-31 23:57:00+00:00    0.5524
2025-03-31 23:58:00+00:00    0.5524
2025-03-31 23:59:00+00:00    0.5524
Freq: min, Name: ls_ratio, dtype: float64

<class 'pandas.core.series.Series'>
DatetimeIndex: 44640 entries, 2025-03-01 00:00:00+00:00 to 2025-03-31 23:59:00+00:00
Freq: min
Series name: ls_ratio
Non-Null Count  Dtype  
--------------  -----  
44640 non-null  float64
dtypes: float64(1)
memory usage: 697.5 KB


None


Funding Rate Data:


2025-03-01 00:00:00+00:00    0.0001
2025-03-01 00:01:00+00:00    0.0001
2025-03-01 00:02:00+00:00    0.0001
2025-03-01 00:03:00+00:00    0.0001
2025-03-01 00:04:00+00:00    0.0001
Freq: min, Name: fundingRate, dtype: float64

2025-03-31 23:55:00+00:00   -0.000003
2025-03-31 23:56:00+00:00   -0.000003
2025-03-31 23:57:00+00:00   -0.000003
2025-03-31 23:58:00+00:00   -0.000003
2025-03-31 23:59:00+00:00   -0.000003
Freq: min, Name: fundingRate, dtype: float64

<class 'pandas.core.series.Series'>
DatetimeIndex: 44640 entries, 2025-03-01 00:00:00+00:00 to 2025-03-31 23:59:00+00:00
Freq: min
Series name: fundingRate
Non-Null Count  Dtype  
--------------  -----  
44640 non-null  float64
dtypes: float64(1)
memory usage: 697.5 KB


None

In [9]:
def sma(series: pd.Series, period: int) -> pd.Series:
    """Simple moving average over ``period`` observations.

    A full window is required, so the first ``period - 1`` values are NaN.
    """
    full_windows = series.rolling(period, min_periods=period)
    return full_windows.mean()

In [10]:
def ema(series: pd.Series, period: int) -> pd.Series:
    """Exponential moving average with span ``period`` (recursive form, ``adjust=False``)."""
    weighted = series.ewm(span=period, adjust=False)
    return weighted.mean()

In [11]:
def macd(df: pd.DataFrame, fast: int = 12, slow: int = 26, signal: int = 9) -> pd.DataFrame:
    """MACD of the ``close`` column.

    Returns a frame with columns ``macd`` (fast EMA minus slow EMA), ``macd_signal``
    (EMA of the MACD line) and ``macd_hist`` (their difference).
    """
    close = df["close"]
    line = ema(close, fast) - ema(close, slow)
    trigger = ema(line, signal)
    return pd.DataFrame({"macd": line, "macd_signal": trigger, "macd_hist": line - trigger})

In [12]:
def rsi(series: pd.Series, period: int = 14) -> pd.Series:
    """Relative Strength Index based on simple rolling means of gains and losses.

    The undefined first difference counts as "no move"; a small epsilon in the
    denominator keeps the ratio finite when there are no losses in the window.
    """
    change = series.diff()
    ups = change.where(change > 0, 0.0)
    downs = change.where(change < 0, 0.0).mul(-1)
    avg_up = ups.rolling(period).mean()
    avg_down = downs.rolling(period).mean()
    strength = avg_up / (avg_down + 1e-12)
    return 100 - (100 / (1 + strength))

In [13]:
def connors_rsi(df: pd.DataFrame, rsi_period: int = 3, streak_rsi_period: int = 2, pct_rank_period: int = 100) -> pd.Series:
    """Connors RSI of the ``close`` column: the mean of three components.

    1. RSI of the close price.
    2. RSI of the signed up/down streak length.
    3. Percentile rank (0-100) of the latest return within the last
       ``pct_rank_period`` returns.
    """
    close = df["close"]

    # (1) RSI of the price itself
    price_component = rsi(close, rsi_period)

    # (2) RSI of the running streak: consecutive equal signs form one run, and the
    #     cumulative sum inside each run is the signed streak length
    direction = np.sign(close.diff()).fillna(0)
    run_id = direction.ne(direction.shift()).cumsum()
    streak_len = direction.groupby(run_id).cumsum()
    streak_component = rsi(streak_len, streak_rsi_period)

    # (3) percentile rank of the latest return inside the rolling window
    def _rank_of_last(window):
        return pd.Series(window).rank(pct=True).iloc[-1] * 100

    returns = close.pct_change().fillna(0)
    rank_component = returns.rolling(pct_rank_period).apply(_rank_of_last, raw=False)

    # CRSI is the plain average of the three components
    return (price_component + streak_component + rank_component) / 3.0

In [14]:
def support_resistance(df: pd.DataFrame, lookback: int = 60) -> Tuple[pd.Series, pd.Series]:
    """Return ``(support, resistance)``: rolling minimum of ``low`` and rolling maximum of ``high``."""
    low_windows = df["low"].rolling(lookback)
    high_windows = df["high"].rolling(lookback)
    return low_windows.min(), high_windows.max()

In [15]:
class BitcoinFuturesEnv(gym.Env):
    """BTC linear perpetual-futures trading environment (USDT margined).

    Action: one float in ``[-1, 1]``. The sign is the trade direction and the
    magnitude the fraction of free balance to commit (open / add) or the fraction of
    the position to close (opposite direction). Magnitudes up to ``1e-3`` are a no-op.

    Observation (flat ``float32`` vector):
        * ``window_size x 14`` per-bar features for the bars before the current one
          (OHLCV plus 9 indicators, see ``_FEATURE_COLUMNS``),
        * 5 account values (free balance, position size, entry price, position
          direction, unrealized PnL),
        * current funding rate and a long/short-ratio sample.

    Reward: PnL realized during the step (closed-trade PnL minus fees).

    Accounting: ``balance`` is the *free* balance; margin backing the open position
    is tracked separately in ``margin_locked``. Equity is
    ``balance + margin_locked + unrealized PnL`` and is what termination and
    liquidation are based on.
    """

    # Gymnasium reads the "render_modes" key ("render.modes" was the legacy gym spelling).
    metadata = {"render_modes": ["human"]}

    # Column order of the per-bar feature block inside the observation.
    _FEATURE_COLUMNS = (
        "open",
        "high",
        "low",
        "close",
        "volume",
        "sma_fast",
        "sma_slow",
        "ema",
        "macd",
        "macd_signal",
        "macd_hist",
        "crsi",
        "support",
        "resistance",
    )

    def __init__(
        self,
        ohlcv: pd.DataFrame,
        long_short_ratio: pd.Series,
        funding_rate: pd.Series,
        window_size: int = 60,
        initial_balance: float = 10_000.0,
        fee_rate: float = 0.00044,
        leverage: float = 10.0,
        maintenance_margin_ratio: float = 0.005,
        random_start: bool = True,
    ):
        """Create the environment.

        Args:
            ohlcv: 1-minute candles indexed by a ``DatetimeIndex`` with ``freq='1min'``.
            long_short_ratio: Per-minute ratio, positionally aligned with ``ohlcv``.
            funding_rate: Per-minute funding rate, positionally aligned with ``ohlcv``.
            window_size: Number of past bars in each observation.
            initial_balance: Starting balance in USDT.
            fee_rate: Fee charged on traded notional.
            leverage: Notional / margin multiple used when sizing trades.
            maintenance_margin_ratio: Factor applied to the position margin to get
                the liquidation threshold.
            random_start: Start each episode at a random bar instead of the first
                usable one.

        Raises:
            AssertionError: If ``ohlcv`` is not on a 1-minute frequency.
            ValueError: If the inputs differ in length or are too short for one window.
        """
        super().__init__()

        assert (
            ohlcv.index.freq == "1min"
        ), "OHLCV 必须是 1 分钟频率的 Pandas DataFrame，index 为 DateTimeIndex(freq='1min')，当前为 {}".format(
            ohlcv.index.freq
        )
        # The three inputs are aligned by position below, so they must be equally long.
        if not (len(ohlcv) == len(long_short_ratio) == len(funding_rate)):
            raise ValueError(
                "ohlcv, long_short_ratio and funding_rate must have the same length, got "
                f"{len(ohlcv)}, {len(long_short_ratio)} and {len(funding_rate)}"
            )
        if len(ohlcv) <= window_size + 1:
            raise ValueError(
                f"need more than window_size + 1 = {window_size + 1} rows, got {len(ohlcv)}"
            )

        self.ohlcv = ohlcv.reset_index(drop=False)
        self.long_short_ratio = long_short_ratio.reset_index(drop=True)
        self.funding_rate = funding_rate.reset_index(drop=True)
        self.window_size = window_size
        self.initial_balance = initial_balance
        self.fee_rate = fee_rate
        self.leverage_setting = leverage
        self.maintenance_margin_ratio = maintenance_margin_ratio
        self.random_start = random_start

        # Indicators are computed once over the whole series. Recomputing them on each
        # window_size-row slice left every indicator with a longer look-back at zero
        # (e.g. the 100-bar CRSI rank with a 60-bar window) and repeated the work on
        # every step. Each row only depends on rows at or before it, so slicing the
        # table later does not leak future data.
        self._close = self.ohlcv["close"].ffill().to_numpy(dtype=np.float64)
        self._features = self._build_feature_table()

        # ===== Gym spaces =====
        self.action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(1,), dtype=np.float32)
        obs_dim = (
            window_size * len(self._FEATURE_COLUMNS)  # OHLCV + 9 technical indicators
            + 5  # account state
            + 2  # funding rate and long/short ratio
        )
        self.observation_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(obs_dim,), dtype=np.float32)

        # Internal state
        self._reset_account()
        self._ptr: int = self.window_size  # index of the current bar

    # ---------------------------------
    # Reset / step
    # ---------------------------------
    def reset(self, *, seed: int | None = None, options: Dict[str, Any] | None = None):
        """Reset the account and the bar pointer; returns ``(observation, info)``."""
        super().reset(seed=seed)
        self._reset_account()
        if self.random_start:
            self._ptr = int(self.np_random.integers(self.window_size, len(self.ohlcv) - 1))
        else:
            self._ptr = self.window_size
        return self._get_observation(), {}

    def step(self, action: np.ndarray):
        """Apply ``action`` at the current bar and advance one bar.

        Returns:
            ``(observation, reward, terminated, truncated, info)``. As before, the
            third element is true both when the data is exhausted and when equity is
            gone, and the fourth is always ``False``.
        """
        raw_action = np.asarray(action, dtype=np.float64).reshape(-1)[0]
        action_val = float(np.clip(raw_action, -1.0, 1.0))
        info = {}
        price = self._current_price()

        # === Funding ===
        self._apply_funding(price)

        # === Interpret the action ===
        if abs(action_val) > 1e-3:
            if self.position_size == 0:
                self._open_position(action_val, price)
            elif (self.position_size > 0) == (action_val > 0):
                # same direction: increase the position
                self._add_position(action_val, price)
            else:
                # opposite direction: close part or all of the position
                self._reduce_or_close(action_val, price)

        # === Liquidation check ===
        self._check_liquidation(price)

        # === Advance time ===
        self._ptr += 1
        equity = self._equity(price)
        # Equity (not free balance) decides bankruptcy: a fully margined position
        # legitimately leaves the free balance at or slightly below zero.
        done = bool(self._ptr >= len(self.ohlcv) - 1 or equity <= 0)

        obs = self._get_observation()
        reward = float(self.realized_pnl)  # realized PnL is the reward
        self.realized_pnl = 0.0  # reset so it is not counted again next step
        info.update(
            {
                "equity": equity,
                "position_size": self.position_size,
                "entry_price": self.entry_price,
                "unrealized_pnl": self._unrealized_pnl(price),
            }
        )
        return obs, reward, done, False, info

    # ---------------------------------
    # Account logic
    # ---------------------------------
    def _reset_account(self):
        self.balance: float = self.initial_balance  # free (unmargined) balance
        self.margin_locked: float = 0.0  # margin backing the open position
        self.position_size: float = 0.0  # >0 long, <0 short (quantity in BTC)
        self.entry_price: float = 0.0
        self.realized_pnl: float = 0.0

    def _apply_fee(self, notional: float):
        fee = abs(notional) * self.fee_rate
        self.balance -= fee
        self.realized_pnl -= fee

    def _open_position(self, action_val: float, price: float):
        # Size from the free balance only; nothing can be opened without free funds
        # or a usable price (the price check also rejects NaN).
        notional = max(self.balance, 0.0) * abs(action_val) * self.leverage_setting
        if notional <= 0 or not price > 0:
            return
        qty = notional / price
        self.position_size = qty if action_val > 0 else -qty
        self.entry_price = price
        margin = notional / self.leverage_setting
        self.balance -= margin
        self.margin_locked += margin
        self._apply_fee(notional)

    def _add_position(self, action_val: float, price: float):
        additional_notional = max(self.balance, 0.0) * abs(action_val) * self.leverage_setting
        if additional_notional <= 0 or not price > 0:
            return
        add_qty = additional_notional / price
        new_position_size = self.position_size + (add_qty if action_val > 0 else -add_qty)
        # quantity-weighted average entry price (additional_notional == add_qty * price)
        self.entry_price = (
            abs(self.position_size) * self.entry_price + additional_notional
        ) / abs(new_position_size)
        self.position_size = new_position_size
        margin = additional_notional / self.leverage_setting
        self.balance -= margin
        self.margin_locked += margin
        self._apply_fee(additional_notional)

    def _reduce_or_close(self, action_val: float, price: float):
        # An opposite-direction action closes that fraction of the position.
        ratio = min(abs(action_val), 1.0)
        close_qty = abs(self.position_size) * ratio
        direction = 1.0 if self.position_size > 0 else -1.0
        pnl = close_qty * (price - self.entry_price) * direction
        # Release the margin that was actually locked for the closed part, instead of
        # re-deriving it from the current price.
        released_margin = self.margin_locked * ratio
        self.realized_pnl += pnl
        self.balance += released_margin + pnl
        self.margin_locked -= released_margin
        self._apply_fee(close_qty * price)
        # Update the remaining position
        remain_qty = abs(self.position_size) - close_qty
        if remain_qty > 0:
            self.position_size = math.copysign(remain_qty, self.position_size)
        else:
            self.position_size = 0.0
            self.entry_price = 0.0
            self.margin_locked = 0.0

    def _apply_funding(self, price: float):
        """Charge or credit one minute's share of the funding rate on the open position."""
        if self.position_size == 0:
            return
        notional = abs(self.position_size) * price
        funding_payment = notional * self._current_funding() / (8 * 60)  # per-minute share
        # Longs pay a positive rate and shorts receive it (reversed for negative rates).
        side = 1.0 if self.position_size > 0 else -1.0
        self.balance -= funding_payment * side

    def _unrealized_pnl(self, price: float) -> float:
        return abs(self.position_size) * (price - self.entry_price) * (
            1 if self.position_size > 0 else -1
        )

    def _equity(self, price: float) -> float:
        """Total account value: free balance + locked margin + unrealized PnL."""
        return float(self.balance + self.margin_locked + self._unrealized_pnl(price))

    def _check_liquidation(self, price: float):
        if self.position_size == 0:
            return
        notional = abs(self.position_size) * price
        margin = notional / self.leverage_setting
        # NOTE(review): the threshold is margin * ratio as in the original code;
        # exchanges usually define maintenance margin on notional — confirm intent.
        threshold = margin * self.maintenance_margin_ratio
        unrealized = self._unrealized_pnl(price)
        equity = self.balance + self.margin_locked + unrealized
        if equity < threshold:
            # Forced close: the open loss is realized and only the equity remains.
            self.realized_pnl += unrealized
            self.balance = equity
            self.margin_locked = 0.0
            self.position_size = 0.0
            self.entry_price = 0.0

    # ---------------------------------
    # Observation & helpers
    # ---------------------------------
    def _build_feature_table(self) -> np.ndarray:
        """Compute the ``(n_bars, 14)`` float32 feature table for the whole series."""
        bars = self.ohlcv[["open", "high", "low", "close", "volume"]].ffill()
        bars["sma_fast"] = sma(bars["close"], 20)
        bars["sma_slow"] = sma(bars["close"], 50)
        bars["ema"] = ema(bars["close"], 20)
        bars = pd.concat([bars, macd(bars)], axis=1)
        bars["crsi"] = connors_rsi(bars)
        bars["support"], bars["resistance"] = support_resistance(bars)
        # Warm-up rows have no indicator value yet; they become 0 as before.
        table = bars[list(self._FEATURE_COLUMNS)].ffill().fillna(0.0)
        return table.to_numpy(dtype=np.float32)

    def _current_price(self) -> float:
        return float(self._close[self._ptr])

    def _current_funding(self) -> float:
        return float(self.funding_rate.iloc[self._ptr])

    def _current_long_short_ratio(self) -> float:
        # Uniform sample between the previous and the current value.
        prev_val = self.long_short_ratio.iloc[self._ptr - 1]
        next_val = self.long_short_ratio.iloc[self._ptr]
        return float(self.np_random.uniform(min(prev_val, next_val), max(prev_val, next_val)))

    def _get_observation(self) -> np.ndarray:
        # Feature rows of the window_size bars before the current one, flattened row by row.
        start = self._ptr - self.window_size
        tech_np = self._features[start:self._ptr].reshape(-1)

        # Account state
        price = self._current_price()
        pos_dir = 0.0 if self.position_size == 0 else math.copysign(1, self.position_size)
        account_state = np.array([
            self.balance,
            self.position_size,
            self.entry_price,
            pos_dir,
            self._unrealized_pnl(price),
        ], dtype=np.float32)

        market_state = np.array([
            self._current_funding(),
            self._current_long_short_ratio(),
        ], dtype=np.float32)

        return np.concatenate([tech_np, account_state, market_state])

    # ---------------------------------
    # Render / close
    # ---------------------------------
    def render(self):
        price = self._current_price()
        print(
            f"t={self._ptr} | price={price:.2f} | bal={self.balance:.2f} | pos={self.position_size:.4f} @ {self.entry_price:.2f} | unreal={self._unrealized_pnl(price):+.2f}"
        )

    def close(self):
        pass

In [16]:
import math
from pathlib import Path
from collections import deque

import torch
from torch.optim import AdamW
from torch.utils.tensorboard import SummaryWriter
from tqdm.notebook import tqdm

from torchrl.envs import GymWrapper
from torchrl.collectors import SyncDataCollector
from torchrl.data import ReplayBuffer, LazyTensorStorage
from torchrl.data.replay_buffers.samplers import SliceSampler
from torchrl.modules import DTActor
from tensordict.nn import TensorDictModule
from torchrl.objectives import DTLoss

In [17]:
# ================== 2. Training hyper-parameters ================== #
# Sequence length: used as the SliceSampler slice length and as the time axis when
# sampled batches are reshaped to (BATCH_SIZE, SEQ_LEN).
SEQ_LEN = 80
# Transformer width, depth and attention heads passed to DTActor's transformer_config.
EMB_DIM = 256
N_LAYER = 8
N_HEAD = 8

TOTAL_FRAMES = 204_800          # total environment frames gathered by the collector
FRAMES_PER_BATCH = 2_048        # frames gathered per collector iteration
REPLAY_SIZE = 120_000           # capacity of the replay-buffer storage

BASE_LR = 3e-4
# NOTE(review): cosine_lr compares its step argument against WARMUP_STEPS, while the
# collector yields only TOTAL_FRAMES / FRAMES_PER_BATCH = 100 batches. If that step
# counts optimizer updates, the schedule never leaves warm-up — confirm.
WARMUP_STEPS = 10_000
GRAD_ACCUM = 4                  # micro-batches accumulated per optimizer step (4 x 64 = 256 sequences)
BATCH_SIZE = 64                 # sequences per micro-batch (256 after accumulation)
GAMMA = 0.99                    # discount factor used for the return-to-go

In [18]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
scaler = torch.amp.GradScaler('cuda')
print(f"Using device: {device}")

Using device: cuda


In [19]:
# ================== 3. Environment factory & specs ================= #
def make_env():
    """Build a fresh ``BitcoinFuturesEnv`` on the notebook's data, wrapped for TorchRL."""
    return GymWrapper(BitcoinFuturesEnv(ohlcv_df, ls_series, funding_series))

# Probe one throwaway environment to read the flat observation and action widths.
tmp_env = make_env()
obs_spec = tmp_env.observation_spec["observation"]
obs_dim, act_dim = obs_spec.shape[-1], tmp_env.action_spec.shape[-1]  # 847, 1
tmp_env.close()

In [20]:
# ================== 4. Build the Decision Transformer actor ================= #
# Core network: sized from the environment specs probed above and the
# hyper-parameter cell (EMB_DIM / N_LAYER / N_HEAD).
dt_actor_core = DTActor(
    state_dim=obs_dim,
    action_dim=act_dim,
    transformer_config={
        "n_embd": EMB_DIM,
        "n_layer": N_LAYER,
        "n_head": N_HEAD,
        "n_inner": 4 * EMB_DIM,
        "activation": "relu",
        "n_positions": 1024,
        "resid_pdrop": 0.1,
        "attn_pdrop": 0.1,
    },
)

# TensorDict wrapper: reads observation, action and return_to_go from the input
# tensordict and writes the prediction back under "action".
actor_module = TensorDictModule(
    dt_actor_core,
    in_keys=["observation", "action", "return_to_go"],
    out_keys=["action"],
)

# Loss module wrapping the actor; its parameters are what the optimizer updates.
loss_module = DTLoss(actor_module, device=device)
loss_module.to(device)

DTLoss()

In [21]:
# ================== 5. Collector & replay buffer ================= #
# No policy is passed to the collector.
# NOTE(review): presumably torchrl then samples random actions from the action spec — confirm.
collector = SyncDataCollector(
    create_env_fn=make_env,
    total_frames=TOTAL_FRAMES,
    frames_per_batch=FRAMES_PER_BATCH,
    device=device,
)

# Replay buffer sampled in contiguous slices of SEQ_LEN steps.
# REPLAY_SIZE (120_000) is smaller than TOTAL_FRAMES (204_800).
# NOTE(review): presumably the oldest frames are overwritten once the storage is full — confirm.
replay = ReplayBuffer(
    storage=LazyTensorStorage(REPLAY_SIZE),
    sampler=SliceSampler(slice_len=SEQ_LEN, strict_length=True),
)

In [22]:
# ================== 6. Optimizer & LR scheduler ================= #
# AdamW over every parameter of the loss module; BASE_LR is the peak rate that
# cosine_lr scales.
optimizer = AdamW(loss_module.parameters(), lr=BASE_LR, weight_decay=1e-2)

def cosine_lr(step, warmup_steps=None, total_steps=None):
    """Return the LR multiplier for `step`: linear warm-up, then cosine decay.

    Args:
        step: Current schedule step (clamped to be non-negative).
        warmup_steps: Steps of linear warm-up from 0 to 1. Defaults to the
            notebook constant ``WARMUP_STEPS`` (looked up at call time).
        total_steps: Step at which the cosine decay reaches 0. Defaults to the
            notebook constant ``TOTAL_FRAMES`` (looked up at call time).

    Returns:
        A float in ``[0, 1]`` to multiply the base learning rate with.

    ``step``, ``warmup_steps`` and ``total_steps`` must be expressed in the
    same unit (e.g. all in optimizer updates).
    """
    if warmup_steps is None:
        warmup_steps = WARMUP_STEPS
    if total_steps is None:
        total_steps = TOTAL_FRAMES

    step = max(0, step)
    if step < warmup_steps:
        return step / warmup_steps

    # Clamp so the multiplier stays at 0 past the horizon instead of riding the
    # cosine back up once progress exceeds 1.
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(1.0, progress)
    return 0.5 * (1 + math.cos(math.pi * progress))

In [None]:
# ================== 7. Logging & progress bar ================= #
writer = SummaryWriter("runs/DT_crypto_mixed")
pbar = tqdm(total=TOTAL_FRAMES, desc="Env frames", unit="frame", miniters=1)

reward_window = deque(maxlen=20)   # most recent finished-episode returns
global_step, episode_idx = 0, 0
running_return = 0.0               # return of the episode currently in progress
use_amp = device.type == "cuda"    # autocast only helps on CUDA

# Per-position discount gamma^t for t = 0..SEQ_LEN-1 (loop invariant).
discount = GAMMA ** torch.arange(SEQ_LEN, device=device)  # [T]

# ================== 8. Training loop ================== #
print("开始训练（混合精度）...")
for batch in collector:
    replay.extend(batch)
    pbar.update(batch.numel())

    # ---- Environment returns ----
    # End-of-step flags and rewards live under "next"; the root "done" entry
    # describes the state *before* the step. make_env() applies no
    # reward-summing transform, so episode returns are accumulated here.
    # Assumes the batch is a single time-ordered stream (one environment).
    step_rewards = batch["next", "reward"].reshape(-1).tolist()
    step_dones = batch["next", "done"].reshape(-1).tolist()
    for step_reward, step_done in zip(step_rewards, step_dones):
        running_return += step_reward
        if step_done:
            writer.add_scalar("env/episode_return", running_return, episode_idx)
            reward_window.append(running_return)
            episode_idx += 1
            running_return = 0.0

    # Not enough data yet: keep collecting.
    if len(replay) < 10_000:
        continue

    # --------- Gradient accumulation --------- #
    optimizer.zero_grad(set_to_none=True)
    step_loss = 0.0  # mean loss over the GRAD_ACCUM micro-batches
    for _ in range(GRAD_ACCUM):
        # The slice sampler returns BATCH_SIZE contiguous slices flattened
        # into one dimension; fold them back into [B, T].
        flat = replay.sample(BATCH_SIZE * SEQ_LEN)
        seq = flat.reshape(BATCH_SIZE, SEQ_LEN).to(device)

        seq_reward = seq["next", "reward"].squeeze(-1)  # [B, T]
        seq.set("reward", seq_reward)                   # expose as seq["reward"]

        # Discounted return-to-go within the slice:
        #   rtg[t] = sum_{k >= t} gamma^(k - t) * r[k]
        # The reversed cumulative sum yields sum_{k >= t} gamma^k * r[k];
        # dividing by gamma^t re-bases the discount at position t.
        discounted_tail = torch.flip(
            torch.cumsum(torch.flip(seq_reward * discount, dims=[1]), dim=1),
            dims=[1],
        )
        seq["return_to_go"] = (discounted_tail / discount).unsqueeze(-1)  # [B, T, 1]

        with torch.amp.autocast(device_type=device.type, enabled=use_amp):
            td_loss = loss_module(seq)
            loss = td_loss["loss"] / GRAD_ACCUM

        scaler.scale(loss).backward()
        step_loss += loss.item()

    # ------ Learning-rate schedule ------
    # Applied *before* the optimizer step so that the very first update is
    # already subject to warm-up.
    # NOTE(review): global_step counts optimizer updates (at most one per
    # collected batch), while cosine_lr measures warm-up against WARMUP_STEPS
    # and decay against TOTAL_FRAMES by default - confirm the intended units.
    lr = BASE_LR * cosine_lr(global_step)
    for group in optimizer.param_groups:
        group["lr"] = lr

    # ------ Parameter update ------
    scaler.unscale_(optimizer)  # so clipping sees true gradient magnitudes
    torch.nn.utils.clip_grad_norm_(loss_module.parameters(), 0.5)
    scaler.step(optimizer)
    scaler.update()

    # ------ TensorBoard & progress bar ------
    writer.add_scalar("train/loss", step_loss, global_step)
    postfix = {"loss": f"{step_loss:.4f}", "lr": f"{lr:1.2e}"}
    if reward_window:
        avg_ret = sum(reward_window) / len(reward_window)
        postfix["avg_ret"] = f"{avg_ret:7.2f}"
    pbar.set_postfix(postfix)
    global_step += 1

pbar.close()

Env frames:   0%|          | 0/204800 [00:00<?, ?frame/s]

开始训练（混合精度）...


  torch.as_tensor(terminated),
  torch.as_tensor(terminated),
  torch.as_tensor(terminated),
  torch.as_tensor(terminated),
  torch.as_tensor(terminated),
  torch.as_tensor(terminated),
  torch.as_tensor(terminated),
  torch.as_tensor(terminated),
  torch.as_tensor(terminated),
  torch.as_tensor(terminated),
  torch.as_tensor(terminated),
  torch.as_tensor(terminated),
  torch.as_tensor(terminated),
  torch.as_tensor(terminated),
  torch.as_tensor(terminated),
  torch.as_tensor(terminated),
  torch.as_tensor(terminated),
  torch.as_tensor(terminated),
  torch.as_tensor(terminated),
  torch.as_tensor(terminated),
  torch.as_tensor(terminated),
  torch.as_tensor(terminated),
  torch.as_tensor(terminated),
  torch.as_tensor(terminated),
  torch.as_tensor(terminated),
  torch.as_tensor(terminated),
  torch.as_tensor(terminated),
  torch.as_tensor(terminated),
  torch.as_tensor(terminated),
  torch.as_tensor(terminated),
  torch.as_tensor(terminated),
  torch.as_tensor(terminated),
  torch.

In [None]:
# ================== 9. Save & evaluate ================= #
ckpt_path = Path("decision_transformer_crypto_amp.pt")
torch.save(actor_module.state_dict(), ckpt_path)
print(f"模型已保存：{ckpt_path.resolve()}")

# Return the policy is conditioned on at every evaluation step.
# NOTE(review): placeholder value - set it to a return on the same scale as
# the return_to_go values seen during training.
EVAL_TARGET_RETURN = 1.0

actor_module.eval()  # disable dropout for evaluation

# The actor consumes *sequences* of (observation, action, return_to_go), so a
# rolling context of the last SEQ_LEN steps is maintained here and fed to the
# underlying network with a leading batch dimension of 1.
obs_ctx = deque(maxlen=SEQ_LEN)
act_ctx = deque(maxlen=SEQ_LEN)
rtg_ctx = deque(maxlen=SEQ_LEN)

test_env = make_env()
try:
    td = test_env.reset()
    done, ep_ret = False, 0.0
    while not done:
        obs_ctx.append(td["observation"].to(device))
        # Placeholder for the not-yet-chosen current action; replaced below.
        act_ctx.append(torch.zeros(act_dim, device=device))
        rtg_ctx.append(torch.full((1,), float(EVAL_TARGET_RETURN), device=device))

        with torch.no_grad(), torch.amp.autocast(
            device_type=device.type, enabled=device.type == "cuda"
        ):
            pred = dt_actor_core(
                torch.stack(tuple(obs_ctx)).unsqueeze(0),  # [1, T, obs_dim]
                torch.stack(tuple(act_ctx)).unsqueeze(0),  # [1, T, act_dim]
                torch.stack(tuple(rtg_ctx)).unsqueeze(0),  # [1, T, 1]
            )

        # Prediction for the most recent time step, back in full precision.
        action = pred[0, -1].float()
        act_ctx[-1] = action

        # The environment is stepped with a tensordict carrying the action;
        # the transition outcome is returned under "next".
        td.set("action", action.cpu())
        td = test_env.step(td)["next"]
        ep_ret += td["reward"].item()
        done = td["done"].item()
finally:
    test_env.close()

print(f"评估回报：{ep_ret:.2f}")
writer.add_scalar("eval/episode_return", ep_ret, 0)
writer.close()