In [None]:
# Environment definition
import numpy as np
import pandas as pd
import torch
class Environment:
    """Step-by-step portfolio allocation environment over a close-price table.

    Each step moves one row forward in ``df_close_kr``. The reward is the
    weighted one-row return of the KR columns minus a turnover-proportional
    cost. The state handed to the agent is the concatenation of

    * the current KR prices (one per column),
    * whatever ``predictor.get_pred_break`` returns for the US data
      (presumably a 1-D array-like; its length is decided by the predictor),
    * the weights used on the previous step.

    The US table is forward-filled onto the KR index and shifted by one row
    before it is passed to the predictor.
    """

    def __init__(self, df_close_us: pd.DataFrame, df_close_kr: pd.DataFrame, predictor, device,
                 seq_window=30, bol_window=20, cost_rate=0.0005,
                 normalize_action=True):
        """Store the data, the predictor and the episode settings.

        Args:
            df_close_us: US close prices; rows containing NaN are dropped.
            df_close_kr: KR close prices; rows containing NaN are dropped.
                Its columns define the tradable assets.
            predictor: Object exposing ``to(device)`` and
                ``get_pred_break(df, t)``.
            device: Device passed to ``predictor.to`` and used for the state
                tensor.
            seq_window: Earliest usable row index is ``seq_window - 1``.
            bol_window: Stored on the instance; not used inside this class.
            cost_rate: Cost charged per unit of turnover.
            normalize_action: When true (default) actions are treated as
                logits and passed through a softmax; when false the action
                is used as the weight vector exactly as given.
        """
        self.df_us = df_close_us.dropna().copy()
        self.df_kr = df_close_kr.dropna().copy()
        self.predictor = predictor.to(device)
        self.device = device
        self.seq_window = seq_window
        self.bol_window = bol_window
        self.cost_rate = cost_rate
        self.normalize_action = normalize_action
        self.tickers = list(self.df_kr.columns)
        self.N = len(self.tickers)
        self.t = None
        self.prev_w = None
        # Filled lazily by _aligned_us(); the alignment does not depend on t.
        self._us_aligned = None

    def reset(self, start_idx=None):
        """Start an episode and return the initial state.

        Args:
            start_idx: Row index to start from. Values below
                ``seq_window - 1`` (or ``None``) are raised to
                ``seq_window - 1``.

        Returns:
            The state tensor for the starting row.
        """
        first_valid = self.seq_window - 1
        self.t = first_valid if start_idx is None else max(start_idx, first_valid)
        # Every episode starts from an equal-weight portfolio.
        self.prev_w = np.ones(self.N, dtype=np.float32) / self.N
        return self._get_state()

    def step(self, action):
        """Apply ``action`` for one row and return ``(state, reward, done, info)``.

        Args:
            action: Tensor or array-like with one entry per asset. It is
                flattened to 1-D before use.

        Returns:
            ``(next_state, reward, done, info)``. When no further row exists
            the result is ``(None, 0.0, True, {"reason": "end_of_data"})`` and
            the environment is left unchanged.

        Raises:
            RuntimeError: If called before ``reset()``.
            ValueError: If the flattened action does not have one entry per
                asset.
        """
        if self.t is None or self.prev_w is None:
            raise RuntimeError("reset() must be called before step()")

        if isinstance(action, torch.Tensor):
            w = action.detach().cpu().numpy().astype(np.float32).reshape(-1)
        else:
            w = np.asarray(action, dtype=np.float32).reshape(-1)
        if w.shape[0] != self.N:
            raise ValueError(
                f"action has {w.shape[0]} entries but the environment has {self.N} assets"
            )

        if self.normalize_action:
            # Softmax; subtracting the max keeps exp() from overflowing.
            expw = np.exp(w - np.max(w))
            w = expw / (np.sum(expw) + 1e-12)

        turnover = float(np.sum(np.abs(w - self.prev_w)))
        cost = self.cost_rate * turnover

        t0 = self.t
        t1 = self.t + 1
        if t1 >= len(self.df_kr):
            return None, 0.0, True, {"reason": "end_of_data"}

        p0 = self.df_kr.iloc[t0].values.astype(np.float32)
        p1 = self.df_kr.iloc[t1].values.astype(np.float32)
        # Small epsilon guards against a zero price in the denominator.
        asset_ret = (p1 / (p0 + 1e-12)) - 1.0
        port_ret = float(np.dot(w, asset_ret))
        reward = port_ret - cost

        self.prev_w = w
        self.t = t1
        next_state = self._get_state()
        info = {"t": self.t, "port_ret": port_ret, "cost": cost, "turnover": turnover}
        done = False
        return next_state, reward, done, info

    def _aligned_us(self):
        """Return the US table forward-filled onto the KR index and shifted by one row."""
        cached = getattr(self, "_us_aligned", None)
        if cached is None:
            # Computed once: neither input changes between steps.
            cached = self.df_us.reindex(self.df_kr.index, method="ffill").shift(1)
            self._us_aligned = cached
        return cached

    def _get_state(self):
        """Build the state tensor for the current row ``self.t``."""
        t = self.t
        df_upto_t_us = self._aligned_us().iloc[:t+1]
        df_upto_t_kr = self.df_kr.iloc[:t+1]
        cur_prices = df_upto_t_kr.iloc[-1].values.astype(np.float32)
        pred_break = self.predictor.get_pred_break(df_upto_t_us, t)
        prev_w = self.prev_w.astype(np.float32)
        state = torch.tensor(
            np.concatenate([cur_prices, pred_break, prev_w], axis=0),
            dtype=torch.float32,
            device=self.device
        )
        return state