In [4]:
# stock_trading_env.py
from __future__ import annotations

from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional, Any

import numpy as np
import pandas as pd

import gymnasium as gym
from gymnasium import spaces

import yfinance as yF

# 실험에 쓰는 설정을 한 눈에 보기 위한 데이터클래스
@dataclass
class EnvConfig:
    tickers: List[str]
    start_date: str
    end_date: str

    # Observation config
    window: int = 20  # how many past days to include
    use_volume: bool = True

    # Trading config
    max_shares_per_day: int = 10   # Q: max shares to buy/sell per ticker per day
    initial_balance_min: float = 10_000.0
    initial_balance_max: float = 100_000.0

    # Reward config
    reward_scale: float = 1.0      # can scale daily returns

    # Fees (optional)
    commission_rate: float = 0.0   # e.g., 0.001 for 0.1% per trade notional

    # Numerical stability
    eps: float = 1e-12

# cfg = EnvConfig(tickers = ["AAPL"], start_date="2020-01-01", end_date="2024-01-01")
# print(cfg)


class StockTradingEnv(gym.Env):
    """
    Daily stock trading environment (Gymnasium).

    - One step = one trading day.
    - The agent observes info at day t (end-of-day), chooses action a_t.
    - The action is executed at day t+1 close price (Close[t+1]).
    - Supports N=1 (single ticker) and N>1 (multiple tickers).

    Action (default):
      MultiDiscrete([3]*N) where each element is:
        0 = SELL, 1 = HOLD, 2 = BUY

    Observation (Dict):
      {
        "market": Box(shape=(N, market_dim)),
        "agent":  Box(shape=(1 + N,), )  # cash_ratio + position_ratio for each ticker
      }
    """

    metadata = {"render_modes": ["human"]}

    def __init__(
        self,
        config: EnvConfig,
        data: Optional[Dict[str, pd.DataFrame]] = None,
        render_mode: Optional[str] = None,
        seed: Optional[int] = None,
    ):
        super().__init__()
        self.cfg = config
        self.render_mode = render_mode

        self._rng = np.random.RandomState(seed)

        # Load data: dict[ticker] -> DataFrame with at least ["Close", "Volume"]
        self.data = data if data is not None else self._download_data()

        self.tickers = list(self.cfg.tickers)
        self.n = len(self.tickers)

        # Align and build arrays
        self._build_arrays()

        # Time index: we need t+1, and also window history.
        self.t0 = self.cfg.window  # first valid decision day index
        self.t = self.t0

        # Portfolio state
        self.cash: float = 0.0
        self.shares: np.ndarray = np.zeros(self.n, dtype=np.int64)

        # --- Define action space ---
        # One action per ticker: {SELL, HOLD, BUY}
        self.action_space = spaces.MultiDiscrete([3] * self.n)

        # --- Define observation space ---
        # market features per ticker: returns(window) + (optional) vol_change(window)
        self.market_dim = self.cfg.window * (2 if self.cfg.use_volume else 1)

        self.observation_space = spaces.Dict(
            {
                "market": spaces.Box(
                    low=-np.inf,
                    high=np.inf,
                    shape=(self.n, self.market_dim),
                    dtype=np.float32,
                ),
                "agent": spaces.Box(
                    low=0.0,
                    high=1.0,
                    shape=(1 + self.n,),  # cash_ratio + position_ratio per ticker
                    dtype=np.float32,
                ),
            }
        )

    # -----------------------
    # Data utilities
    # -----------------------
    def _download_data(self) -> Dict[str, pd.DataFrame]:
        if yF is None:
            raise ImportError("yfinance is not installed. Install it or pass `data=` directly.")

        out: Dict[str, pd.DataFrame] = {}
        for tk in self.cfg.tickers:
            df = yF.download(
                tk,
                start=self.cfg.start_date,
                end=self.cfg.end_date,
                auto_adjust=False,
                progress=False,
            )
            # Expect columns include: Open, High, Low, Close, Volume
            if "Close" not in df.columns:
                raise ValueError(f"{tk}: downloaded data missing 'Close' column.")
            if self.cfg.use_volume and "Volume" not in df.columns:
                raise ValueError(f"{tk}: downloaded data missing 'Volume' column (use_volume=True).")

            df = df.dropna().copy()
            out[tk] = df
        return out

    def _build_arrays(self) -> None:
        """Align all tickers on common dates and create numpy arrays."""
        # Find common index intersection
        idx = None
        for tk in self.tickers:
            df = self.data[tk]
            idx = df.index if idx is None else idx.intersection(df.index)

        if idx is None or len(idx) < (self.cfg.window + 2):
            raise ValueError("Not enough aligned data for the chosen window and period.")

        # Reindex all to common idx
        closes = []
        volumes = []
        for tk in self.tickers:
            df = self.data[tk].reindex(idx).dropna()
            closes.append(df["Close"].to_numpy(dtype=np.float64))
            if self.cfg.use_volume:
                volumes.append(df["Volume"].to_numpy(dtype=np.float64))

        self.dates = idx
        self.close = np.stack(closes, axis=1)  # shape (T, N)
        self.volume = np.stack(volumes, axis=1) if self.cfg.use_volume else None
        self.T = self.close.shape[0]

        # Precompute market features (returns, volume change)
        self.ret = self._pct_change(self.close)  # shape (T, N), ret[0]=0
        if self.cfg.use_volume:
            self.vchg = self._pct_change(self.volume)  # shape (T, N)
        else:
            self.vchg = None

    @staticmethod
    def _pct_change(x: np.ndarray) -> np.ndarray:
        """Percent change along time axis. First row is 0."""
        out = np.zeros_like(x, dtype=np.float64)
        out[1:] = (x[1:] - x[:-1]) / np.maximum(x[:-1], 1e-12)
        return out

    # -----------------------
    # Gym API
    # -----------------------
    def reset(self, seed: Optional[int] = None, options: Optional[dict] = None) -> Tuple[Dict[str, np.ndarray], Dict]:
        if seed is not None:
            self._rng = np.random.RandomState(seed)

        self.t = self.t0
        self.shares = np.zeros(self.n, dtype=np.int64)

        # Random initial balance (as required)
        self.cash = float(self._rng.uniform(self.cfg.initial_balance_min, self.cfg.initial_balance_max))

        obs = self._get_obs()
        info = self._get_info()
        return obs, info

    def step(self, action: np.ndarray) -> Tuple[Dict[str, np.ndarray], float, bool, bool, Dict]:
        """
        Executes action decided at day t, using next day's close (t+1) as execution price.
        Returns obs at (t+1), reward for transition t->t+1, done flags, and info.
        """
        action = np.asarray(action, dtype=np.int64)
        if action.shape != (self.n,):
            raise ValueError(f"Expected action shape {(self.n,)}, got {action.shape}")

        # If at the last day, cannot move to t+1
        if self.t >= self.T - 1:
            obs = self._get_obs()
            return obs, 0.0, True, False, self._get_info()

        # Portfolio value at current day t (using close[t])
        price_t = self.close[self.t]              # shape (N,)
        V_before = self._portfolio_value(price_t)

        # Execution at next day close
        exec_price = self.close[self.t + 1]       # Close[t+1]
        self._execute_trades(action, exec_price)

        # Advance time
        self.t += 1

        # Portfolio value after execution at new day t (which is old t+1)
        price_after = self.close[self.t]
        V_after = self._portfolio_value(price_after)

        # Reward: daily portfolio return
        reward = ((V_after - V_before) / max(V_before, self.cfg.eps)) * self.cfg.reward_scale

        # Termination when we reach end of data
        terminated = (self.t >= self.T - 1)
        truncated = False

        obs = self._get_obs()
        info = self._get_info()
        return obs, float(reward), bool(terminated), bool(truncated), info

    # -----------------------
    # Core mechanics
    # -----------------------
    def _portfolio_value(self, price: np.ndarray) -> float:
        stock_value = float(np.dot(self.shares, price))
        return float(self.cash + stock_value)

    def _execute_trades(self, action: np.ndarray, price: np.ndarray) -> None:
        """
        Apply per-ticker trades with constraints:
          - sell qty <= shares
          - buy qty constrained by available cash
        Uses max_shares_per_day as fixed per-ticker trade size.
        """
        Q = int(self.cfg.max_shares_per_day)
        fee_rate = float(self.cfg.commission_rate)

        # Process sells first (common practical choice)
        for i in range(self.n):
            if action[i] == 0:  # SELL
                qty = min(Q, int(self.shares[i]))
                if qty > 0:
                    notional = qty * float(price[i])
                    fee = fee_rate * notional
                    self.cash += (notional - fee)
                    self.shares[i] -= qty

        # Then process buys
        for i in range(self.n):
            if action[i] == 2:  # BUY
                # max affordable qty given cash
                unit_price = float(price[i])
                if unit_price <= 0:
                    continue
                # include fee in affordability
                # cash >= qty*price*(1+fee_rate)
                max_afford = int(self.cash / (unit_price * (1.0 + fee_rate) + self.cfg.eps))
                qty = min(Q, max_afford)
                if qty > 0:
                    notional = qty * unit_price
                    fee = fee_rate * notional
                    self.cash -= (notional + fee)
                    self.shares[i] += qty

        # Safety clamp (numerical)
        self.cash = max(0.0, float(self.cash))

    def _get_obs(self) -> Dict[str, np.ndarray]:
        """
        Observation at current time self.t uses history up to self.t (inclusive),
        with window length cfg.window.
        """
        w = self.cfg.window
        t = self.t

        # Market features: last w returns (and optionally vol changes)
        rets = self.ret[t - w + 1 : t + 1]  # shape (w, N)
        rets = rets.T  # (N, w)

        if self.cfg.use_volume:
            vch = self.vchg[t - w + 1 : t + 1].T  # (N, w)
            market = np.concatenate([rets, vch], axis=1)  # (N, 2w)
        else:
            market = rets  # (N, w)

        market = market.astype(np.float32)

        # Agent features
        price_t = self.close[t]
        V = self._portfolio_value(price_t)
        cash_ratio = float(self.cash / max(V, self.cfg.eps))
        stock_values = self.shares.astype(np.float64) * price_t
        position_ratio = (stock_values / max(V, self.cfg.eps)).astype(np.float32)  # (N,)

        agent = np.concatenate([[cash_ratio], position_ratio], axis=0).astype(np.float32)

        return {"market": market, "agent": agent}

    def _get_info(self) -> Dict[str, Any]:
        price_t = self.close[self.t]
        info = {
            "t": int(self.t),
            "date": str(self.dates[self.t].date()) if hasattr(self.dates[self.t], "date") else str(self.dates[self.t]),
            "cash": float(self.cash),
            "shares": self.shares.copy(),
            "price": price_t.copy(),
            "portfolio_value": float(self._portfolio_value(price_t)),
        }
        return info

    def render(self):
        if self.render_mode != "human":
            return
        info = self._get_info()
        print(
            f"[{info['date']}] t={info['t']}  value={info['portfolio_value']:.2f}  "
            f"cash={info['cash']:.2f}  shares={info['shares']}"
        )


# -----------------------
# Quick sanity test (run this file)
# -----------------------
if __name__ == "__main__":
    cfg = EnvConfig(
        tickers=["AAPL"],               # start with single ticker
        start_date="2020-01-01",
        end_date="2024-01-01",
        window=20,
        max_shares_per_day=10,
        initial_balance_min=10_000,
        initial_balance_max=50_000,
        commission_rate=0.0,
    )

    env = StockTradingEnv(cfg, seed=0, render_mode="human")
    obs, info = env.reset(seed=0)
    done = False

    while not done:
        action = env.action_space.sample()  # random actions
        obs, reward, terminated, truncated, info = env.step(action)
        env.render()
        done = terminated or truncated

    print("Episode finished.")


  stock_value = float(np.dot(self.shares, price))


ValueError: all the input arrays must have same number of dimensions, but the array at index 0 has 1 dimension(s) and the array at index 1 has 2 dimension(s)