In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.distributions import Categorical
from torch.utils.data import Dataset
import pandas as pd
import numpy as np
import random
import os

# Colab only: mount Google Drive at /content/drive so that the data directory and
# checkpoint path under '/content/drive/My Drive/ttrade/' used by the training
# script are reachable.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
def angle_encoding(dt_index,weights=None):
    """Encode each timestamp as a single scalar phase in [0, 1].

    The timestamp is projected onto four periodic cycles: hour, day, week and
    an average (365.25-day) year. Each cycle contributes a unit vector
    ``(cos θ, sin θ)`` scaled by its weight. The direction of the weighted sum
    (``arctan2`` in ``[-π, π]``) is mapped linearly to ``[0, 1]``.

    Args:
        dt_index (pd.DatetimeIndex): timestamps to encode. May be naive or
            tz-aware; tz-aware values are measured from the Unix epoch in UTC.
        weights (dict | None): per-cycle weights keyed by ``'hour'``,
            ``'day'``, ``'week'`` and ``'year'``. Missing keys (or ``None`` /
            an empty dict) fall back to the defaults
            ``{'hour': 0.1, 'day': 0.4, 'week': 0.3, 'year': 0.2}``.

    Returns:
        pd.Series: float phase per timestamp, indexed by ``dt_index`` and
        named ``'time'``.
    """
    default_weights = {'hour': 0.1, 'day': 0.4, 'week': 0.3, 'year': 0.2}
    weights = {**default_weights, **(weights or {})}
    # Whole seconds since the Unix epoch. Computed via timedelta arithmetic so it
    # does not depend on the index's storage resolution (ns / us / s) or timezone.
    epoch = pd.Timestamp(0, tz=dt_index.tz)
    timestamps = np.asarray((dt_index - epoch) // pd.Timedelta(seconds=1))
    periods = {
        'hour': 3600,
        'day': 86400,
        'week': 604800,  # 7 * 86400 seconds (was mistyped as 603800)
        'year': 365.25 * 86400,
    }
    sin_sum = 0.0
    cos_sum = 0.0
    for name, period in periods.items():
        angle = (timestamps % period) / period * 2 * np.pi
        sin_sum = sin_sum + weights[name] * np.sin(angle)
        cos_sum = cos_sum + weights[name] * np.cos(angle)
    final_angle = (np.arctan2(sin_sum, cos_sum) + np.pi) / (2 * np.pi)
    return pd.Series(final_angle, index=dt_index, name='time')

class MultiTimeDataset(Dataset):
    """Sliding-window, multi-timeframe dataset built from ``f'{path}{tick}.csv'``.

    The CSV's first column is parsed as a datetime index. For every timeframe
    key ``k`` in ``input_dims`` (a string holding an integer, e.g. ``'15'``):

    * ``x[k]`` holds the columns whose names start with ``k``.
    * ``x['1']`` is always set to the columns matching no timeframe prefix,
      excluding ``'close'``.
    * Sample ``idx`` gathers ``input_dims[k]`` rows spaced ``int(k)`` rows
      apart. The window ends at row ``idx + start``, where ``start`` is the
      first row without any NaN plus the longest look-back over all
      timeframes.
    * The label of sample ``idx`` is ``close`` at that same end row.

    Every row also carries a scalar time phase from ``angle_encoding``.

    Args:
        path (str): directory prefix, concatenated directly with ``tick``.
        tick (str): ticker / file stem.
        input_dims (dict[str, int]): timeframe key -> window length.
        batch_size (int): batch size used by ``get_batch`` / ``iter_batch``.
        device: torch device for all stored tensors.
    """
    def __init__(self,path,tick,input_dims,batch_size=64,device='cpu'):
        super().__init__()
        self.timeframes = list(input_dims.keys())
        file_name = f'{path}{tick}.csv'
        df = pd.read_csv(file_name, parse_dates=[0], index_col=[0])
        prefixes = tuple(self.timeframes)

        def to_tensor(values):
            return torch.tensor(values, device=device, dtype=torch.float32)

        # Per-timeframe feature matrices [num_rows, num_features]. Key '1' is
        # filled from the un-prefixed columns instead, because a '1' prefix
        # would also match '15', '1440', ...
        self.x = {
            k: to_tensor(df[[col for col in df.columns if col.startswith(k)]].values)
            for k in self.timeframes if k != '1'
        }
        non_tf_cols = [col for col in df.columns
                       if not col.startswith(prefixes) and col != 'close']
        self.x['1'] = to_tensor(df[non_tf_cols].values)
        self.times = to_tensor(angle_encoding(df.index).values)

        # Relative row offsets of each window, ending at 0
        # (e.g. k='15', v=3 -> [-30, -15, 0]).
        time_indices = {
            k: np.arange(-int(k) * (v - 1), int(k), int(k))
            for k, v in input_dims.items()
        }
        lookback = int(abs(min(arr.min() for arr in time_indices.values())))
        complete_rows = np.flatnonzero(~df.isna().any(axis=1).to_numpy())
        if complete_rows.size == 0:
            raise ValueError(f'{file_name}: every row contains NaN')
        start = int(complete_rows[0]) + lookback
        # Absolute row indices of sample 0's windows; sample idx adds idx.
        self.indices = {k: arr + start for k, arr in time_indices.items()}
        self.len = len(df) - start
        self.y = df.close.iloc[start:]
        # Tensor copy of the labels, returned by get_batch / iter_batch.
        self.labels = to_tensor(self.y.to_numpy())
        self._precompute_first_batch_indices(batch_size)

    def __len__(self):
        return self.len

    def __getitem__(self,idx):
        """Return ``({tf: (features [v, F], times [v])}, close_label)`` for one sample."""
        return {
            tf: (
                self.x[tf][self.indices[tf]+idx],
                self.times[self.indices[tf]+idx]
            )
            for tf in self.timeframes
        },self.y.iloc[idx]

    def _precompute_first_batch_indices(self,batch_size):
        """Cache the row-index matrix of batch 0 for every timeframe.

        ``first_batch_indices[tf][i]`` holds the window rows of sample ``i``.
        Batch ``b`` reuses the matrix shifted by ``b * batch_size``.
        """
        self.batch_size = batch_size
        total_samples = len(self)
        self.num_batches = (total_samples + batch_size - 1) // batch_size
        first_batch_size = min(batch_size, total_samples)
        offsets = np.arange(first_batch_size)[:, None]
        self.first_batch_indices = {
            tf: (self.indices[tf][None, :] + offsets).astype(np.int32)
            for tf in self.timeframes
        }

    def _get_batch_indices(self,batch_idx):
        """Return the ``[start, end)`` sample range of batch ``batch_idx``."""
        start_idx = batch_idx * self.batch_size
        end_idx = min(start_idx + self.batch_size, len(self))
        return start_idx, end_idx

    def _prepare_batch(self, batch_idx):
        """Gather one batch.

        Returns:
            tuple: ``(batch_data, labels)``. ``batch_data[tf]`` is
            ``(features [b, v, F], times [b, v])``; ``labels`` is a float32
            tensor of shape ``[b]``. The last batch may be shorter than
            ``batch_size``.

        Raises:
            IndexError: if ``batch_idx`` is outside ``[0, num_batches)``.
        """
        if not 0 <= batch_idx < self.num_batches:
            raise IndexError("Batch index out of range")
        start_idx, end_idx = self._get_batch_indices(batch_idx)
        current_batch_size = end_idx - start_idx
        batch_data = {}
        for tf in self.timeframes:
            rows = torch.as_tensor(
                self.first_batch_indices[tf][:current_batch_size] + start_idx,
                dtype=torch.long, device=self.times.device,
            )
            batch_data[tf] = (self.x[tf][rows], self.times[rows])
        return batch_data, self.labels[start_idx:end_idx]

    def iter_batch(self):
        """Yield every batch in order (see ``_prepare_batch``)."""
        for batch_idx in range(self.num_batches):
            yield self._prepare_batch(batch_idx)

    def get_batch(self,batch_idx):
        """Return batch ``batch_idx`` (see ``_prepare_batch``)."""
        return self._prepare_batch(batch_idx)

In [None]:
class FourierTimeEmbedding(nn.Module):
    """Embed scalar time values with Fourier features followed by a small MLP.

    Each time value ``t`` is multiplied by ``num_bands`` frequencies spaced
    linearly in ``[0, 1]`` and by ``2π``. The sines and cosines of these angles
    are concatenated (``2 * num_bands`` features) and passed through
    Linear -> GELU -> Linear -> LayerNorm.

    Args:
        embed_dim (int): output embedding dimension.
        num_bands (int): number of frequency bands.
    """
    def __init__(self,embed_dim=128,num_bands=32):
        super().__init__()
        self.num_bands = num_bands
        self.embed_dim = embed_dim
        self.fc1 = nn.Linear(2*num_bands,embed_dim)
        self.fc2 = nn.Linear(embed_dim,embed_dim)
        self.activation = nn.GELU()
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self,t):
        """Embed time values.

        Args:
            t (Tensor): time values shaped ``[batch, seq]``. A trailing
                singleton axis (``[batch, seq, 1]``) is also accepted. Other
                leading shapes broadcast as well.

        Returns:
            Tensor: ``[..., embed_dim]``, with the leading shape of ``t``
            (minus a squeezed trailing singleton axis).
        """
        if t.dim() == 3 and t.size(-1) == 1:
            t = t.squeeze(-1)
        coeffs = torch.linspace(0, 1, self.num_bands, device=t.device)
        # [..., 1] * [num_bands] -> [..., num_bands]; same values as einsum('bs,n->bsn').
        embed = t.unsqueeze(-1) * coeffs * 2 * torch.pi
        embed = torch.cat([torch.sin(embed), torch.cos(embed)], dim=-1)
        embed = self.fc1(embed)
        embed = self.activation(embed)
        embed = self.fc2(embed)
        return self.norm(embed)

class ExecutionHyridModule(nn.Module):
    """CNN + bidirectional-GRU sequence encoder.

    A Fourier time embedding is appended to the per-step features. Two 1-D
    convolutions (kernel 5 then 3, each followed by GELU and BatchNorm) extract
    local patterns. A 2-layer bidirectional GRU then models the sequence.
    Outputs are projected to 128 features, with dropout and LayerNorm.
    EnhancedMultiTimeframeModel uses this branch for timeframe keys <= 20.

    Args:
        input_dim (int): input features per step.
        time_dim (int): time-embedding dimension.
        hidden_size (int): conv channels and GRU hidden size (per direction).
    """
    def __init__(self,input_dim,time_dim=128,hidden_size=64):
        super().__init__()
        self.time_embed = FourierTimeEmbedding(time_dim)
        conv_channels_in = input_dim + time_dim
        self.conv = nn.Sequential(
            nn.Conv1d(conv_channels_in, hidden_size, kernel_size=5, padding=2),
            nn.GELU(),
            nn.BatchNorm1d(hidden_size),
            nn.Conv1d(hidden_size, hidden_size, kernel_size=3, padding=1),
            nn.GELU(),
            nn.BatchNorm1d(hidden_size),
        )
        self.gru = nn.GRU(
            hidden_size,
            hidden_size,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
            dropout=0.3,
        )
        # Bidirectional GRU doubles the feature width.
        self.proj = nn.Linear(2 * hidden_size, 128)
        self.dropout = nn.Dropout(0.2)
        self.norm = nn.LayerNorm(128)

    def forward(self,x,t):
        """Encode ``x`` ([B, T, F]) with time values ``t``; returns [B, T, 128]."""
        stacked = torch.cat((x, self.time_embed(t)), dim=-1)
        # Conv1d wants channels first: [B, C, T] in, then back to [B, T, C].
        local_feats = self.conv(stacked.transpose(1, 2)).transpose(1, 2)
        sequence, _ = self.gru(local_feats)
        projected = self.dropout(self.proj(sequence))
        return self.norm(projected)

class MultiScaleLSTM(nn.Module):
    """Parallel-LSTM encoder with self-attention over the concatenated outputs.

    A Fourier time embedding is appended to the input features. The result is
    fed to ``len(scales)`` independent 2-layer LSTMs, each with dropout on its
    output. The outputs are concatenated along the feature axis, passed
    through 8-head self-attention, projected to 128 features and layer-normed.
    EnhancedMultiTimeframeModel uses this branch for timeframe keys in (20, 100].

    NOTE(review): only ``len(scales)`` is used. Every LSTM sees the same
    full-resolution sequence; the scale values themselves are never applied.

    Args:
        input_dim (int): input features per step.
        time_dim (int): time-embedding dimension.
        scales (Sequence[int]): one LSTM branch per entry (see note above).
        hidden_size (int): hidden size of each LSTM. ``hidden_size * len(scales)``
            must be divisible by 8, the attention head count.
    """
    def __init__(self,input_dim,time_dim=128,scales=(5,10,20),hidden_size=64):
        super().__init__()
        self.time_embed = FourierTimeEmbedding(time_dim)
        self.lstms = nn.ModuleList([
            nn.LSTM(
                input_size=input_dim+time_dim,
                hidden_size=hidden_size,
                num_layers=2,
                dropout=0.3,
                batch_first=True)
            for _ in scales
        ])
        self.dropouts = nn.ModuleList([nn.Dropout(0.2) for _ in scales])
        concat_dim = hidden_size * len(scales)
        self.attn = nn.MultiheadAttention(concat_dim,8,batch_first=True)
        self.proj = nn.Linear(concat_dim,128)
        self.norm = nn.LayerNorm(128)

    def forward(self,x,t):
        """Encode ``x`` ([B, T, F]) with time values ``t``; returns [B, T, 128]."""
        t_emb = self.time_embed(t)
        x_in = torch.cat([x,t_emb],dim=-1)
        outputs = []
        for lstm, dropout in zip(self.lstms, self.dropouts):
            out,_ = lstm(x_in)
            outputs.append(dropout(out))
        concat = torch.cat(outputs,dim=-1)
        attn_out,_ = self.attn(concat,concat,concat)
        return self.norm(self.proj(attn_out))

class HierarchicalTransformer(nn.Module):
    """Transformer-encoder sequence encoder.

    A Fourier time embedding is appended to the per-step features. The result
    is linearly projected to a 128-wide model dimension and run through a stack
    of TransformerEncoder layers (GELU, dropout 0.2, feed-forward 512,
    batch-first), then layer-normed.
    EnhancedMultiTimeframeModel uses this branch for timeframe keys > 100.

    Args:
        input_dim (int): input features per step.
        time_dim (int): time-embedding dimension.
        nhead (int): attention heads per layer (must divide 128).
        num_layers (int): number of encoder layers.
    """
    def __init__(self,input_dim,time_dim=128,nhead=8,num_layers=4):
        super().__init__()
        model_width = 128
        self.time_embed = FourierTimeEmbedding(time_dim)
        self.input_proj = nn.Linear(input_dim + time_dim, model_width)
        layer = nn.TransformerEncoderLayer(
            d_model=model_width,
            nhead=nhead,
            dim_feedforward=4 * model_width,
            dropout=0.2,
            activation='gelu',
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers)
        self.norm = nn.LayerNorm(model_width)

    def forward(self,x,t):
        """Encode ``x`` ([B, T, F]) with time values ``t``; returns [B, T, 128]."""
        tokens = self.input_proj(torch.cat((x, self.time_embed(t)), dim=-1))
        encoded = self.encoder(tokens)
        return self.norm(encoded)

class CrossModalAttention(nn.Module):
    """Fuse per-timeframe feature sequences with per-timeframe cross-attention.

    Steps:

    1. Each timeframe's sequence ``[B, T_k, D]`` is resampled along the time
       axis to a common length of ``embed_dim`` steps. This uses a Linear layer
       over the time dimension (``T_k -> embed_dim``), then GELU, then a
       LayerNorm that also runs over that time axis.
    2. At every (batch, step) position, timeframe ``k`` attends over all ``K``
       timeframe vectors at that position, using its own vector as the query.
       A residual connection and LayerNorm follow.
    3. The K results are scaled by learned softmax weights, concatenated, and
       passed through dropout and a final LayerNorm.

    Args:
        input_dims (dict): timeframe key -> input sequence length ``T_k``.
            The order of keys must match the order of ``features`` in ``forward``.
        embed_dim (int): feature width ``D`` of the incoming features (required
            by the attention layers). Also used as the common resampled
            sequence length.
        heads (int): attention heads; must divide ``embed_dim``.
    """
    def __init__(self, input_dims, embed_dim=128, heads=8):
        super().__init__()
        self.timeframes = input_dims.keys()
        num_timeframes = len(self.timeframes)
        # Linear over the *time* axis: maps sequence length T_k to embed_dim.
        self.projections = nn.ModuleList([
            nn.Sequential(
                nn.Linear(dim,embed_dim),
                nn.GELU(),
                nn.LayerNorm(embed_dim)
            )
            for tf, dim in input_dims.items()
        ])
        self.attentions = nn.ModuleList([
            nn.MultiheadAttention(embed_dim, heads, dropout=0.2, batch_first=True)
            for _ in range(num_timeframes)
        ])
        self.norms = nn.ModuleList([
            nn.LayerNorm(embed_dim)
            for _ in range(num_timeframes)
        ])
        self.timeframe_weights = nn.Parameter(torch.ones(num_timeframes)/num_timeframes)
        self.final_norm = nn.LayerNorm(embed_dim * num_timeframes)
        self.final_dropout = nn.Dropout(0.2)

    def forward(self, features):
        """Fuse ``features``, a list of K tensors ``[B, T_k, D]``.

        Returns:
            Tensor: ``[B, T, K * D]`` where ``T`` is the resampled length
            (``embed_dim``).
        """
        # Resample each timeframe along time: [B, T_k, D] -> [B, T, D].
        resampled = [proj(feat.transpose(1,2)).transpose(1,2)
                     for proj,feat in zip(self.projections,features)]
        # Stack as [B, T, K, D] so that flattening (B, T) yields, per row, the
        # K timeframe vectors of the same batch element and time step.
        # (Stacking as [B, K, T, D] and calling .view(B*T, K, D) would mix the
        # K and T axes.)
        context = torch.stack(resampled, dim=2)
        B, T, K, D = context.shape
        context = context.reshape(B * T, K, D)
        attn_outs = []
        for i, (attn, norm) in enumerate(zip(self.attentions, self.norms)):
            # Timeframe i's vector is the query; all K timeframes are keys/values.
            query = context[:, i:i+1]
            out, _ = attn(query, context, context)
            attn_outs.append(norm(out + query))
        weights = F.softmax(self.timeframe_weights,dim=0)
        # [B*T, K, 1, D] scaled per timeframe, then rows back to [B, T, K*D].
        combined = torch.stack(attn_outs,dim=1)*weights.view(1,K,1,1)
        fused = combined.reshape(B,T,K*D)
        return self.final_norm(self.final_dropout(fused))

class EnhancedMultiTimeframeModel(nn.Module):
    """Actor-critic network over several timeframes.

    Components:

    * One encoder per timeframe key, chosen by the key's integer value:
      ``<= 20`` -> ExecutionHyridModule, ``<= 100`` -> MultiScaleLSTM,
      otherwise HierarchicalTransformer.
    * A CrossModalAttention fusion layer.
    * Two MLP heads on the fused features: action logits (actor) and a scalar
      value (critic).

    Both heads are applied at every fused time step.

    Args:
        feature_dims (dict): timeframe key -> input features per step.
        input_dims (dict): timeframe key -> sequence length, passed to the fusion layer.
        action_dim (int): number of discrete actions.
    """
    def __init__(self, feature_dims, input_dims, action_dim):
        super().__init__()
        self.timeframes = list(feature_dims.keys())

        self.modules_dict = nn.ModuleDict()
        for tf, dim in feature_dims.items():
            self.modules_dict[tf] = self._make_encoder(tf, dim)

        # Every encoder emits 128 features per step.
        self.fusion = CrossModalAttention(input_dims, embed_dim=128)
        fused_width = 128 * len(self.timeframes)

        self.actor = self._make_head(fused_width, action_dim)
        self.critic = self._make_head(fused_width, 1)

    @staticmethod
    def _make_encoder(tf, dim):
        """Pick the encoder class for timeframe key ``tf`` by its integer value."""
        period = int(tf)
        if period <= 20:
            return ExecutionHyridModule(dim)
        if period <= 100:
            return MultiScaleLSTM(dim)
        return HierarchicalTransformer(dim)

    @staticmethod
    def _make_head(in_width, out_width):
        """MLP head: in -> 512 -> 256 -> out with GELU, LayerNorm and one dropout."""
        return nn.Sequential(
            nn.Linear(in_width, 512),
            nn.GELU(),
            nn.LayerNorm(512),
            nn.Dropout(0.2),
            nn.Linear(512, 256),
            nn.GELU(),
            nn.LayerNorm(256),
            nn.Linear(256, out_width),
        )

    def forward(self, inputs):
        """Return ``(Categorical over actions per step, value per step)``.

        ``inputs[tf]`` must be a ``(data, time)`` pair for every timeframe key.
        """
        encoded = []
        for tf in self.timeframes:
            seq, stamps = inputs[tf]
            encoded.append(self.modules_dict[tf](seq, stamps))
        fused = self.fusion(encoded)
        return Categorical(logits=self.actor(fused)), self.critic(fused)

    def get_action(self,inputs,deterministic=False,mode='last'):
        """Pick actions (argmax if ``deterministic`` else sampled).

        ``mode='last'`` returns everything at the last step. ``mode='mean'``
        returns the last-step action with step-averaged log-prob, entropy and
        value. Any other mode returns the full per-step tensors.
        """
        dist, value = self.forward(inputs)
        action = dist.probs.argmax(dim=-1) if deterministic else dist.sample()
        log_prob = dist.log_prob(action)
        entropy = dist.entropy()
        if mode == 'last':
            return action[:, -1], log_prob[:, -1], entropy[:, -1], value[:, -1]
        if mode == 'mean':
            return (action[:, -1], log_prob.mean(dim=1), entropy.mean(dim=1),
                    value.mean(dim=1, keepdim=True))
        return action, log_prob, entropy, value

    def get_logprob(self,inputs,action,mode='last'):
        """Evaluate per-sample ``action`` (shape ``[B]``) at every step.

        Returns ``(log_prob, entropy, value)`` reduced per ``mode``, as in
        ``get_action``.
        """
        dist, value = self.forward(inputs)
        batch, steps, _ = dist.probs.shape
        per_step_action = action.unsqueeze(1).expand(batch, steps)
        log_prob = dist.log_prob(per_step_action)
        entropy = dist.entropy()
        if mode == 'last':
            return log_prob[:, -1], entropy[:, -1], value[:, -1]
        if mode == 'mean':
            return log_prob.mean(dim=1), entropy.mean(dim=1), value.mean(dim=1, keepdim=True)
        return log_prob, entropy, value


In [None]:
# --- PPO hyper-parameters ---------------------------------------------------
# Return / advantage estimation
GAMMA = 0.99            # reward discount factor
GAE_LAMBDA = 0.95       # GAE smoothing factor (lambda)

# Loss terms
CLIP_EPSILON = 0.2      # PPO probability-ratio clipping range
CRITIC_DISCOUNT = 0.5   # weight of the value (MSE) loss in the total loss
ENTROPY_BETA = 0.01     # weight of the entropy bonus

# Optimisation
LEARNING_RATE = 3e-4    # Adam learning rate
PPO_EPOCHS = 30         # passes over the stored rollout per update
MINI_BATCH_SIZE = 64    # samples per batch
MAX_GRAD_NORM = 0.5     # gradient-norm clipping threshold

class PPOMemory:
    """Rollout buffer holding one entry per environment step in parallel lists."""

    def __init__(self):
        self.states_index, self.actions, self.rewards = [], [], []
        self.values, self.log_probs, self.dones = [], [], []

    def _buffers(self):
        """All parallel lists, in ``push`` argument order."""
        return (self.states_index, self.actions, self.rewards,
                self.values, self.log_probs, self.dones)

    def push(self,state_index,action,reward,value,log_prob,done):
        """Append one step; each argument goes to its same-named list."""
        entry = (state_index, action, reward, value, log_prob, done)
        for buffer, item in zip(self._buffers(), entry):
            buffer.append(item)

    def get(self):
        """Return ``(states_index, actions, rewards, values, log_probs, dones)``.

        ``states_index`` is the live list itself, not a copy. The tensor fields
        are stacked along a new leading step axis. ``dones`` becomes a tensor on
        the device of the first stored action.
        """
        actions, rewards, values, log_probs = (
            torch.stack(seq)
            for seq in (self.actions, self.rewards, self.values, self.log_probs)
        )
        dones = torch.tensor(self.dones, device=self.actions[0].device)
        return self.states_index, actions, rewards, values, log_probs, dones

    def clear(self):
        """Empty every buffer in place."""
        for buffer in self._buffers():
            buffer.clear()

class PPOAgent:
    """PPO trainer around EnhancedMultiTimeframeModel, Adam and a PPOMemory buffer.

    Args:
        feature_dims (dict): forwarded to EnhancedMultiTimeframeModel.
        input_dims (dict): forwarded to EnhancedMultiTimeframeModel.
        n_actions (int): number of discrete actions.
        device (torch.device): device of the policy and GAE buffers.
    """
    def __init__(self, feature_dims, input_dims, n_actions, device):
        self.policy = EnhancedMultiTimeframeModel(feature_dims, input_dims, n_actions).to(device)
        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=LEARNING_RATE)
        self.memory = PPOMemory()
        self.mode = "last"
        self.device = device
        self.scaler = torch.amp.GradScaler()

    def compute_gae(self, next_value, rewards, values, dones):
        """Generalized Advantage Estimation over the stored rollout.

        Args:
            next_value (Tensor): bootstrap value for the state after the last
                stored step. Any shape with ``B`` elements (e.g. ``[B, 1]``).
            rewards (Tensor): ``[N, B]`` rewards (N stored steps, B samples each).
                Not modified.
            values (Tensor): ``[N, B]`` or ``[N, B, 1]`` value estimates.
            dones (Tensor): ``[N]`` termination flags (0/1).

        Returns:
            tuple[Tensor, Tensor]: ``(returns, advantages)``, both ``[N, B]``
            float32, with ``returns = advantages + values``.
        """
        num_steps = len(rewards)
        # Append the bootstrap value as row N so values[step + 1] is always valid.
        values = torch.cat([
            values.reshape(num_steps, -1),
            next_value.reshape(1, -1),
        ]).to(torch.float32)
        returns = torch.zeros_like(rewards, device=self.device, dtype=torch.float32)
        advantages = torch.zeros_like(rewards, device=self.device, dtype=torch.float32)
        gae = 0
        for step in reversed(range(num_steps)):
            not_done = 1 - dones[step]
            delta = rewards[step] + GAMMA * values[step + 1] * not_done - values[step]
            gae = delta + GAMMA * GAE_LAMBDA * not_done * gae
            advantages[step] = gae
            returns[step] = gae + values[step]
        return returns, advantages

    def update(self,next_value,dataset):
        """Run PPO_EPOCHS clipped-objective passes over the stored rollout, then clear it.

        Args:
            next_value (Tensor): bootstrap value after the last stored step
                (see ``compute_gae``).
            dataset: object whose ``get_batch(batch_idx)`` returns
                ``(states, labels)``; used to re-evaluate each stored step.
        """
        device_type = 'cuda' if self.device.type == 'cuda' else 'cpu'
        states_indexs, actions, rewards, values, old_log_probs, dones = self.memory.get()
        returns, advantages = self.compute_gae(next_value,rewards,values,dones)
        returns = returns.detach()
        advantages = advantages.detach()
        old_log_probs = old_log_probs.detach()
        # Normalise advantages over the whole rollout.
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
        for _ in range(PPO_EPOCHS):
            # `pos` indexes the stored tensors (0..N-1); `batch_idx` is the
            # dataset batch that produced that step. They differ once the
            # memory has been cleared by an earlier update.
            for pos, batch_idx in enumerate(states_indexs):
                # NOTE(review): autocast with dtype=float32 likely has no
                # mixed-precision effect on CUDA/CPU; confirm intent.
                with torch.amp.autocast(device_type=device_type,dtype=torch.float32):
                    mb_states,_ = dataset.get_batch(batch_idx)
                    mb_actions = actions[pos]
                    mb_returns = returns[pos].reshape(-1)
                    mb_advantages = advantages[pos].reshape(-1)
                    mb_old_log_prob = old_log_probs[pos]

                    new_log_prob,entropy,new_values = self.policy.get_logprob(mb_states, mb_actions, self.mode)
                    ratio = torch.exp(new_log_prob - mb_old_log_prob)
                    surr1 = ratio * mb_advantages
                    surr2 = torch.clamp(ratio, 1.0 - CLIP_EPSILON, 1.0 + CLIP_EPSILON) * mb_advantages
                    actor_loss = -torch.min(surr1,surr2).mean()
                    critic_loss = F.mse_loss(new_values.squeeze(-1), mb_returns, reduction='mean')
                    loss = actor_loss + CRITIC_DISCOUNT * critic_loss - ENTROPY_BETA * entropy.mean()

                self.optimizer.zero_grad()
                self.scaler.scale(loss).backward()
                # Unscale before clipping so the norm is measured on true gradients.
                self.scaler.unscale_(self.optimizer)
                nn.utils.clip_grad_norm_(self.policy.parameters(),MAX_GRAD_NORM)

                self.scaler.step(self.optimizer)
                self.scaler.update()

        self.memory.clear()

    def save(self,path):
        """Write the policy's ``state_dict`` to ``path``."""
        torch.save(self.policy.state_dict(),path)


In [None]:
class TradingEnv:
    """Trading environment that steps batch by batch through a randomly chosen MultiTimeDataset.

    NOTE(review): ``step`` cannot run as written. It reads ``self.device`` and
    ``self.indices``, which this class never assigns. It uses the bare names
    ``position`` and ``n``, which are not defined in ``step``. It returns
    ``done``, which is only bound when the episode-end condition is true.
    """
    def __init__(self,path,ticks,input_dims,batch_size,max_step,device='cpu'):
        """Build one MultiTimeDataset per ticker and reset to batch 0.

        Args:
            path: directory prefix for the per-ticker CSV files.
            ticks: ticker names; one dataset is built per entry.
            input_dims: timeframe key -> window length, forwarded to MultiTimeDataset.
            batch_size: samples per batch, forwarded to MultiTimeDataset.
            max_step: number of ``step`` calls after which ``step`` signals done.
            device: forwarded to MultiTimeDataset (not stored on ``self``).
        """
        self.max_step = max_step
        self.datasets = [MultiTimeDataset(path,tick,input_dims,batch_size,device) for tick in ticks]
        self.reset(0)
    def reset(self,batch_index):
        """Pick a random dataset, clear episode state and return ``get_batch(batch_index)``.

        Returns:
            ``(states, labels)`` as returned by the dataset's ``get_batch``.
        """
        self.dataset = self.datasets[random.randint(0,len(self.datasets)-1)]
        # Multiplied into per-sample positions in step().
        # NOTE(review): False == 0, so every position computed there is 0 until this changes.
        self.position = False
        # Set from the first label on the first step() call.
        self.current_price = None
        # Fee subtracted in step() at change points whose position is -1.
        self.transaction_cost = 0.0005
        self.n_step = 0
        # Running product of (1 + reward); step() ends the episode when it drops below 0.8.
        self.total_reward = 1
        # Last cumulative reward, carried into the next step's cumulative sum.
        self.last_reward = 0
        return self.dataset.get_batch(batch_index)
    def step(self,actions,labels,batch_index):
        """Score one batch of actions and load batch ``batch_index`` as the next state.

        Args:
            actions: per-sample actions for the current batch, indexed and
                compared like a 1-D tensor. Value 1 is excluded by the
                ``actions != 1`` mask. ``actions[0]`` may be overwritten in place.
            labels: per-sample prices for the current batch (assumed 1-D
                tensor; TODO confirm).
            batch_index: batch of ``self.dataset`` returned as the next state.

        Returns:
            ``(states, labels, rewards, done)``. ``states, labels`` come from
            ``self.dataset.get_batch(batch_index)``. ``rewards`` is the running
            cumulative sum of per-sample rewards, continued from the previous call.
        """
        if not self.current_price:
            self.current_price = labels[0]
        # NOTE(review): `position` is not defined in this method; probably meant self.position.
        actions[0] = 1 if actions[0] == position + 1 else actions[0]

        # Samples whose action is not 1.
        mask = actions != 1

        if mask.any():
            # NOTE(review): self.indices is never assigned in this class
            # (presumably torch.arange over the batch; confirm).
            non_zero_idxs = self.indices[mask]
            if len(non_zero_idxs) > 1:
                # Change points: the first masked sample plus every masked
                # sample whose action differs from the previous masked one.
                diff = torch.diff(actions[mask])
                change_mask = diff != 0
                if change_mask.any():
                    changes = torch.nonzero(change_mask).squeeze(-1)
                    final_idxs = torch.cat([non_zero_idxs[[0]], non_zero_idxs[changes + 1]])
                else:
                    final_idxs = non_zero_idxs[[0]]
            else:
                final_idxs = non_zero_idxs
        else:
            # NOTE(review): self.device is never assigned in this class.
            final_idxs = torch.tensor([], device=self.device, dtype=torch.long)

        # Prices recorded only at change points (0 elsewhere).
        current_prices = torch.zeros_like(labels,device=self.device)
        current_prices[final_idxs] = labels[final_idxs]

        # For each sample, the index of the latest change point at or before it
        # (-1 if none), assuming self.indices enumerates the batch.
        cum_max_indices = torch.cummax(
            torch.where(current_prices != 0,
                        self.indices, torch.tensor(-1, device=self.device)
            ), dim=0)[0]
        # Price at that change point, or self.current_price before the first one.
        current_prices = torch.where(cum_max_indices >= 0, current_prices[cum_max_indices], self.current_price)
        # NOTE(review): this local is never used; self.current_price is not updated.
        current_price = current_prices[-1]
        # Shift by one so each sample is compared with the price set before it.
        current_prices = torch.roll(current_prices, shifts=1)
        current_prices[0] = current_prices[1]

        # Sign flips once per change point strictly before each sample.
        positions = self.position * (-1) ** (self.indices.unsqueeze(1) > final_idxs.unsqueeze(0)).sum(dim=1)
        # NOTE(review): `n` is undefined here, and comparing the tensor
        # `final_idxs` in a conditional is ambiguous when it has more than one element.
        self.position = positions[-1] if final_idxs == n else positions[-1]*-1

        # Relative price change signed by position; fee applied at change points where position == -1.
        profit = (labels - current_prices) / labels * positions
        profit[final_idxs] = torch.where(positions[final_idxs] == -1, profit[final_idxs] - self.transaction_cost, profit[final_idxs])

        # Full profit at change points, 1% of it elsewhere.
        rewards = torch.where(torch.isin(self.indices,final_idxs),profit,profit*0.01)

        self.n_step += 1
        self.total_reward *= torch.prod(rewards + 1).item()
        # NOTE(review): `done` is only assigned in this branch; otherwise the
        # return below raises UnboundLocalError.
        if (self.total_reward < 0.8) or (self.n_step >= self.max_step):
            done = 1
        rewards[0] += self.last_reward
        rewards = torch.cumsum(rewards,dim=0)
        self.last_reward = rewards[-1].item()
        states, labels = self.dataset.get_batch(batch_index)
        return states, labels, rewards, done


In [None]:
# --- Training script ---------------------------------------------------------
path = "/content/drive/My Drive/ttrade/data/250325/"
ticks = ["KRW-BTC","KRW-ETC","KRW-DOGE"]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Input features per step for each timeframe key.
feature_dims = {
    '1': 4, '15':6, '60':3, '240':6, '1440':8
}

# Window length (rows sampled) for each timeframe key.
input_dims = {
    '1':140,'15':15,'60':15,'240':15,'1440':10
}
param_path = '/content/drive/My Drive/ttrade/pth/ppo-MT1.pth'

max_steps = 2048
env = TradingEnv(path,ticks,input_dims,MINI_BATCH_SIZE,max_steps,device)
agent = PPOAgent(feature_dims, input_dims, 3, device)
if os.path.exists(param_path):
    agent.policy.load_state_dict(torch.load(param_path,weights_only=True,map_location=device))

num_episodes = 30
update_interval = 256
for episode in range(num_episodes):
    batch_index = 0
    state, label = env.reset(batch_index)
    episode_reward = 0
    # env.step() loads batch `batch_index + 1` as the next state, and all stored
    # steps must have the same batch size for PPOMemory's torch.stack. So stay
    # within the dataset's full batches.
    full_batches = len(env.dataset) // env.dataset.batch_size
    n_steps = min(max_steps, full_batches - 1 - batch_index)
    if n_steps <= 0:
        print(f"Episode: {episode+1}, skipped: dataset has too few full batches")
        continue

    for step in range(n_steps):
        with torch.no_grad():
            action,log_prob,_,value = agent.policy.get_action(state,deterministic=(random.random() < 0.4))
        next_batch_index = batch_index + 1
        next_state,next_label,reward,done = env.step(action,label,next_batch_index)
        # Keywords guard against mixing up value/log_prob (push takes value first).
        agent.memory.push(state_index=batch_index, action=action, reward=reward,
                          value=value, log_prob=log_prob, done=done)
        state = next_state
        label = next_label
        batch_index = next_batch_index
        episode_reward += reward

        is_last_step = step == n_steps - 1
        if (step + 1) % update_interval == 0 or done or is_last_step:
            # Bootstrap from the next state unless the episode terminated.
            if done:
                next_value = torch.zeros_like(value)
            else:
                with torch.no_grad(), torch.amp.autocast(device_type=device.type):
                    _,_,_,next_value = agent.policy.get_action(state, deterministic=True)
            agent.update(next_value, env.dataset)

        if done:
            break
    agent.save(param_path)
    torch.cuda.empty_cache()
    print(f"Episode: {episode+1}, Reward: {episode_reward.float().mean()}, batch: {step}")