# In [None]:
"""
LSTM-SAC: LSTM-based Soft Actor-Critic for Cancer Chemotherapy (POMDP)
=======================================================================

This is a single-file, self-contained implementation of Recurrent SAC using
LSTM-based recurrent networks, following the patterns from the repository.

Uses the same RNNBase/ContextualModel patterns as the repository but with
layer_type='lstm' instead of 'gru'.

Features:
- LSTM-based recurrent actor and critic networks (using torch.nn.LSTM)
- Repository-style RNNHidden container with LSTM (h,c) tuple support
- Full-trajectory replay buffer with burn-in support
- Automatic entropy tuning
- Cancer chemotherapy (AhnChemoEnv) environment included

Requirements: torch, numpy, gymnasium, scipy, matplotlib, pandas (seaborn is optional)

Author: LSTM-SAC Implementation following repository patterns
"""

import copy
import math
import os
import random
import time
from collections import namedtuple, OrderedDict
from typing import Dict, List, Optional, Tuple, Union, Any, TypeVar
from abc import abstractmethod

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import gymnasium as gym
from gymnasium import spaces
from scipy.integrate import solve_ivp
import matplotlib.pyplot as plt
import pandas as pd

try:
    import seaborn as sns
    HAS_SEABORN = True
except ImportError:
    HAS_SEABORN = False


# ==============================================================================
# SECTION 1: AhnChemoEnv Environment (Cancer Chemotherapy)
# ==============================================================================

# Generic placeholder type for observations; nothing in the visible part of this file references it.
ObsType = TypeVar("ObsType")


def to_scalar(x):
    """Return ``x`` as a built-in float, unwrapping a numpy array via ``item()``."""
    value = x.item() if isinstance(x, np.ndarray) else x
    return float(value)


class BaseSimulator:
    """Interface for simulators that create and advance a state dictionary.

    The class does not derive from ``abc.ABC``, so the ``abstractmethod``
    markers do not block instantiation; the methods raise when called instead.
    """

    def __init__(self):
        pass

    @abstractmethod
    def activate(self) -> OrderedDict[str, np.ndarray]:
        """Produce the initial state; concrete simulators override this."""
        raise NotImplementedError

    @abstractmethod
    def update(self, action: Union[dict, float], state: dict) -> OrderedDict[str, np.ndarray]:
        """Advance ``state`` under ``action``; concrete simulators override this."""
        raise NotImplementedError


class BaseReward:
    """Interface for reward functions; subclasses implement ``count_reward``."""

    def __init__(self):
        pass

    @abstractmethod
    def count_reward(self, *args, **kwargs) -> float:
        """Compute a scalar reward; concrete reward classes override this."""
        raise NotImplementedError


def uniform_random(mean, width, absolute=False):
    """Draw uniformly from an interval centred on ``mean``.

    With ``absolute`` the half-width of the interval is ``width``; otherwise
    it is ``mean * width``. A list or numpy array of means yields a list with
    one independent draw per entry, taken in order.
    """
    def draw(center):
        half_width = width if absolute else center * width
        return float(np.random.uniform(center - half_width, center + half_width))

    if not isinstance(mean, (list, np.ndarray)):
        return draw(mean)
    return [draw(center) for center in mean]


class AhnReward(BaseReward):
    """Reward function for cancer chemotherapy optimization."""
    def __init__(self):
        super().__init__()

    def count_reward(self, state, init_state, action, terminated) -> float:
        """Score one step of treatment.

        Args:
            state: mapping with at least the keys "N" and "T".
            init_state: mapping with the episode's initial "N" and "T".
            action: dose applied on this step.
            terminated: whether the episode ended in failure on this step.

        Returns:
            -100.0 when ``terminated`` is true; otherwise
            ``N / N0 - T / T0 - action`` converted to a float.
        """
        # The penalty overrides everything else, so skip the arithmetic entirely.
        if terminated:
            return -100.0
        normal_ratio = state["N"] / init_state["N"]
        tumor_ratio = state["T"] / init_state["T"]
        return to_scalar(normal_ratio - tumor_ratio - action)


class AhnODE(BaseSimulator):
    """ODE simulator for cancer cell dynamics with chemotherapy.

    The state has four components, N, T, I and B, each stored as a
    one-element float32 array. ``activate`` samples the model parameters and
    an initial state; ``update`` integrates the system over one time interval.
    """
    def __init__(self, state_noise, pkpd_noise):
        super().__init__()
        self.state_noise = state_noise
        self.pkpd_noise = pkpd_noise
        self.cur_time = 0
        self.time_interv = 0.25

    def activate(self) -> OrderedDict[str, np.ndarray]:
        """Sample model parameters and the initial state, and reset the clock."""
        param_width = self.pkpd_noise * 0.5
        # (attribute, nominal value) pairs; the order fixes the sequence of random draws.
        nominal_params = (
            ("r2", 1.), ("b2", 1.), ("c4", 1.), ("a3", 0.1),
            ("r1", 1.5), ("b1", 1.), ("c2", 0.5), ("c3", 1.), ("a2", 0.3),
            ("s", 0.33), ("rho", 0.01), ("alpha", 0.3),
            ("c1", 1.), ("d1", 0.2), ("a1", 0.2),
            ("d2", 1.),
        )
        for attr, nominal in nominal_params:
            setattr(self, attr, uniform_random(nominal, param_width))

        state_width = self.state_noise * 0.5
        init_state = OrderedDict()
        for key, nominal in (("N", 0.9), ("T", 0.2), ("I", 0.005)):
            init_state[key] = np.array([uniform_random(nominal, state_width)], dtype=np.float32)
        # B always starts at zero.
        init_state["B"] = np.array([0.0], dtype=np.float32)

        self.cur_time = 0
        return init_state

    def update(self, action, state):
        """Integrate the ODE system over one ``time_interv`` and return the new state.

        ``action`` is converted to a scalar and enters the B equation as the
        input ``u``. Each component of the result is clamped at zero from below.
        """
        def odes_fn(t, variables, u):
            N, T, I, B = variables
            drug_effect = 1 - np.exp(-B)
            dNdt = self.r2 * N * (1 - self.b2 * N) - self.c4 * T * N - self.a3 * drug_effect * N
            dTdt = self.r1 * T * (1 - self.b1 * T) - self.c2 * I * T - self.c3 * T * N - self.a2 * drug_effect * T
            dIdt = self.s + self.rho * I * T / (self.alpha + T) - self.c1 * I * T - self.d1 * I - self.a1 * drug_effect * I
            dBdt = -self.d2 * B + u
            derivatives = [dNdt, dTdt, dIdt, dBdt]
            if self.state_noise > 0:
                # Multiplicative Gaussian perturbation, one factor per derivative.
                noise = np.random.normal(0, self.state_noise, 4)
                derivatives = [d + d * eps for d, eps in zip(derivatives, noise)]
            return derivatives

        y0 = np.array((state["N"], state["T"], state["I"], state["B"])).flatten()
        dose = to_scalar(action)
        t_end = self.cur_time + self.time_interv
        solution = solve_ivp(odes_fn, (self.cur_time, t_end), y0, args=(dose,))
        self.cur_time += self.time_interv

        next_state = OrderedDict()
        for row, key in enumerate(("N", "T", "I", "B")):
            # Take the value at the end of the interval, clamped to be non-negative.
            next_state[key] = np.array([max(0, solution.y[row, -1])], dtype=np.float32)
        return next_state


class AhnChemoEnv(gym.Env):
    """Cancer chemotherapy environment with partial, noisy observations (POMDP).

    The simulator state holds N, T, I and B. The agent only receives noisy
    readings of T, I and B, clipped to the observation space, and acts with a
    single dose value that is clipped to [0, 1].
    """
    def __init__(self, max_t=600, obs_noise=0.2, state_noise=0.5, pkpd_noise=0.1,
                 missing_rate=0.0, **kwargs):
        super().__init__()
        self.action_space = spaces.Box(low=0.0, high=1.0, shape=(1,), dtype=np.float32)
        self.observation_space = spaces.Box(low=0.0, high=2.0, shape=(3,), dtype=np.float32)
        self.state_map = {"N": "Normal Cells", "T": "Tumor Cells", "I": "Immune Cells", "B": "Drug Concentration"}

        self.Simulator = AhnODE(state_noise=state_noise, pkpd_noise=pkpd_noise)
        self.Reward = AhnReward()

        self.obs_noise = obs_noise
        self.missing_rate = missing_rate
        self.max_t = max_t
        self.t = 0

    def _state_info(self, state):
        """Build the ``info`` dict: state values as floats, keyed by readable names."""
        return {"state": {self.state_map[k]: to_scalar(v) for k, v in state.items()}}

    def _state2obs(self, state, enable_missing):
        """Turn a full state into a noisy (T, I, B) observation.

        When ``enable_missing`` is set, the previous observation is returned
        instead with probability ``missing_rate``.
        """
        noisy = np.array([state[key] for key in ("T", "I", "B")]).flatten()
        # Multiplicative uniform noise, applied in place.
        noisy += self.obs_noise * noisy * np.random.uniform(-0.5, 0.5, size=noisy.shape)
        noisy = np.clip(noisy, self.observation_space.low, self.observation_space.high).astype(np.float32)
        if enable_missing:
            if np.random.uniform(0, 1) < self.missing_rate:
                return self.last_obs
        self.last_obs = noisy
        return noisy

    def reset(self, seed=None, **kwargs):
        """Start a new episode and return ``(observation, info)``.

        A non-None ``seed`` also seeds numpy's global random generator.
        """
        super().reset(seed=seed)
        if seed is not None:
            np.random.seed(seed)
        self.t = 0
        self.init_state = self.Simulator.activate()
        self.cur_state = self.init_state
        # The first observation is never replaced by a stale one.
        first_obs = self._state2obs(self.init_state, False)
        return first_obs, self._state_info(self.init_state)

    def step(self, action):
        """Apply one dose and return ``(obs, reward, terminated, truncated, info)``.

        The episode terminates when N falls below 70% of its initial value and
        is truncated once ``max_t`` steps have been taken. The reward is
        evaluated on the state held before this transition.
        """
        dose = np.clip(to_scalar(action), 0.0, 1.0)
        next_state = self.Simulator.update(action=np.array([dose]), state=self.cur_state)
        next_obs = self._state2obs(next_state, True)

        survival_floor = to_scalar(self.init_state["N"]) * 0.7
        terminated = to_scalar(next_state["N"]) < survival_floor
        truncated = self.t + 1 >= self.max_t
        reward = self.Reward.count_reward(self.cur_state, self.init_state, dose, terminated)

        self.t += 1
        self.cur_state = next_state
        return next_obs, reward, terminated, truncated, self._state_info(next_state)


def create_AhnChemoEnv_setting1():
    """Create the "Setting 1" AhnChemoEnv: observation noise 0.5, no state or PK/PD noise."""
    setting = {"obs_noise": 0.5, "state_noise": 0.0, "pkpd_noise": 0.0}
    return AhnChemoEnv(**setting)


# ==============================================================================
# SECTION 2: RNNHidden Container (From Repository - with LSTM support)
# ==============================================================================

class RNNHidden:
    """
    Ordered container holding one hidden state per recurrent layer.

    An entry is a single tensor (GRU and others) or an ``(h, c)`` tuple of
    tensors (LSTM). Supports indexing, slicing, concatenation with ``+`` and
    detaching from the autograd graph.
    """

    def __init__(self, rnn_num: int, rnn_types: List[str],
                 device: torch.device = torch.device('cpu'), batch_first: bool = False):
        self._rnn_num = rnn_num
        # Copied so later changes to the caller's list do not leak in.
        self._rnn_types = copy.deepcopy(rnn_types)
        self._device = device
        self._batch_first = batch_first
        self._data: List[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = []

    def append(self, hidden_state: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]) -> None:
        """Add the next layer's hidden state; at most ``rnn_num`` entries are accepted."""
        assert len(self._data) < self._rnn_num
        self._data.append(hidden_state)

    def __getitem__(self, key) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor], "RNNHidden"]:
        # An integer key returns the raw entry; a slice returns a new container.
        if not isinstance(key, slice):
            return self._data[key]
        picked = self._data[key]
        subset = RNNHidden(len(picked), self._rnn_types[key], self._device, self._batch_first)
        subset._data = picked
        return subset

    def __setitem__(self, key, value):
        self._data[key] = value

    def __len__(self) -> int:
        return len(self._data)

    def __add__(self, other: "RNNHidden") -> "RNNHidden":
        # Adding None is a no-op and returns this very object.
        if other is None:
            return self
        merged = RNNHidden(
            self._rnn_num + other._rnn_num,
            self._rnn_types + other._rnn_types,
            self._device,
            self._batch_first,
        )
        merged._data = self._data + other._data
        return merged

    @property
    def device(self) -> torch.device:
        return self._device

    def detach(self) -> "RNNHidden":
        """Return a new container whose tensors are detached from the graph."""
        def _detach_entry(entry):
            if isinstance(entry, tuple):  # LSTM (h, c)
                return (entry[0].detach(), entry[1].detach())
            return entry.detach()  # GRU and others

        detached = RNNHidden(self._rnn_num, self._rnn_types, self._device, self._batch_first)
        # Pairing with the type list keeps the entry count bounded by it.
        detached._data = [_detach_entry(entry) for entry, _ in zip(self._data, self._rnn_types)]
        return detached

    @staticmethod
    def init_hidden_by_type(rnn_type: str, batch_size: int, unit_num: int,
                            device: torch.device) -> Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
        """Create an all-zero hidden state of shape (1, batch_size, unit_num).

        'lstm' yields an ``(h, c)`` tuple; any other type yields one tensor.
        """
        shape = (1, batch_size, unit_num)
        if rnn_type != 'lstm':
            return torch.zeros(shape, device=device)
        return (torch.zeros(shape, device=device), torch.zeros(shape, device=device))


# ==============================================================================
# SECTION 3: RNNBase (From Repository - with LSTM layer support)
# ==============================================================================

class RNNBase(nn.Module):
    """
    Stack of fully connected and recurrent layers applied in sequence.

    Each entry of ``layer_types`` selects the layer built at that position:
    'fc' (nn.Linear), 'gru' (nn.GRU) or 'lstm' (nn.LSTM). Recurrent layers are
    created with ``batch_first=True`` and each one owns a slot in the
    RNNHidden container passed to and returned from ``forward``.
    """

    ACTIVATION_DICT = {
        'tanh': nn.Tanh, 'relu': nn.ReLU, 'sigmoid': nn.Sigmoid,
        'leaky_relu': nn.LeakyReLU, 'linear': nn.Identity, 'elu': nn.ELU, 'gelu': nn.GELU,
    }

    LAYER_DICT = {
        'fc': nn.Linear,
        'lstm': nn.LSTM,  # LSTM outputs (output, (h_n, c_n))
        'gru': nn.GRU,    # GRU outputs (output, h_n)
    }

    def __init__(self, input_size: int, output_size: int, hidden_sizes: List[int],
                 activations: List[str], layer_types: List[str]):
        """Build the layer stack.

        Args:
            input_size: feature size of the input.
            output_size: feature size of the last layer's output.
            hidden_sizes: output sizes of all layers except the last.
            activations: one key of ``ACTIVATION_DICT`` per layer.
            layer_types: one of 'fc', 'gru' or 'lstm' per layer.

        Raises:
            ValueError: if ``layer_types`` contains an unsupported entry.
        """
        super().__init__()
        assert len(activations) == len(hidden_sizes) + 1
        assert len(activations) == len(layer_types)

        self.layer_types = layer_types
        self.layers = nn.ModuleList()
        self.activations = nn.ModuleList()
        self.rnn_hidden_sizes = []
        self.rnn_layer_types = []
        self.rnn_num = 0
        self.input_size = input_size

        sizes = [input_size] + hidden_sizes + [output_size]
        for i, (in_s, out_s) in enumerate(zip(sizes[:-1], sizes[1:])):
            ltype = layer_types[i]
            if ltype == 'fc':
                self.layers.append(nn.Linear(in_s, out_s))
            elif ltype in ['gru', 'lstm']:
                self.layers.append(self.LAYER_DICT[ltype](in_s, out_s, batch_first=True))
                self.rnn_hidden_sizes.append(out_s)
                self.rnn_layer_types.append(ltype)
                self.rnn_num += 1
            else:
                # Skipping the layer would leave `layers` shorter than
                # `activations`, and forward() would pair them up wrongly.
                raise ValueError(
                    f"Unsupported layer type {ltype!r} at position {i}; "
                    "expected 'fc', 'gru' or 'lstm'"
                )
            self.activations.append(self.ACTIVATION_DICT[activations[i]]())

        self._init_weights()

    def _init_weights(self):
        """Initialise weights: Xavier for linear and input-to-hidden, orthogonal for hidden-to-hidden, zero biases."""
        for layer in self.layers:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)
                if layer.bias is not None:
                    nn.init.constant_(layer.bias, 0)
            elif isinstance(layer, (nn.LSTM, nn.GRU)):
                for name, param in layer.named_parameters():
                    if 'weight_ih' in name:
                        nn.init.xavier_uniform_(param)
                    elif 'weight_hh' in name:
                        nn.init.orthogonal_(param)
                    elif 'bias' in name:
                        nn.init.zeros_(param)

    def make_init_state(self, batch_size: int, device: torch.device) -> RNNHidden:
        """Create initial hidden state for all RNN layers."""
        hidden = RNNHidden(self.rnn_num, self.rnn_layer_types, device)
        for size, rnn_type in zip(self.rnn_hidden_sizes, self.rnn_layer_types):
            hidden.append(RNNHidden.init_hidden_by_type(rnn_type, batch_size, size, device))
        return hidden

    def forward(self, x: torch.Tensor, hidden: Optional[RNNHidden] = None,
                return_full: bool = False) -> Tuple[torch.Tensor, RNNHidden, Optional[RNNHidden]]:
        """Run the stack on ``x``.

        Args:
            x: input tensor. A 2-D input gets a length-1 axis inserted at
                dim 1 when the stack has recurrent layers; that axis is
                removed again from the output.
            hidden: hidden states for the recurrent layers, in layer order.
                When None, a fresh state is created with ``x.shape[0]`` as
                the batch size.
            return_full: also collect each recurrent layer's raw output
                (taken before its activation) in a third container.

        Returns:
            ``(output, out_hidden, full_hidden)``; ``full_hidden`` is None
            unless ``return_full`` is set.
        """
        if hidden is None:
            hidden = self.make_init_state(x.shape[0], x.device)

        x_dim = len(x.shape)
        if x_dim == 2 and self.rnn_num > 0:
            x = x.unsqueeze(1)  # Add sequence dimension

        out_hidden = RNNHidden(self.rnn_num, self.rnn_layer_types, x.device)
        full_hidden = RNNHidden(self.rnn_num, self.rnn_layer_types, x.device, batch_first=True) if return_full else None

        rnn_idx = 0
        for layer, activation, ltype in zip(self.layers, self.activations, self.layer_types):
            if ltype in ['gru', 'lstm']:
                h = hidden[rnn_idx]
                if ltype == 'lstm':
                    # LSTM: hidden is (h, c) tuple
                    x, (h_n, c_n) = layer(x, h)
                    out_hidden.append((h_n, c_n))
                else:
                    # GRU: hidden is single tensor
                    x, h_n = layer(x, h)
                    out_hidden.append(h_n)
                if return_full:
                    full_hidden.append(x)
                rnn_idx += 1
            else:
                x = layer(x)
            x = activation(x)

        if x_dim == 2 and self.rnn_num > 0:
            x = x.squeeze(1)

        return x, out_hidden, full_hidden


# ==============================================================================
# SECTION 4: Contextual Model (Embedding + Universal Network) - From Repository
# ==============================================================================

class ContextualModel:
    """Two-stage model: embedding network processes history, universal network outputs action/value.

    This is a plain Python class, not an ``nn.Module``; it wraps two RNNBase
    networks and re-implements the handful of module-style helpers it needs
    (``parameters``, ``to``, ``train``, ``eval``, ``state_dict`` and so on).
    Subclasses are expected to provide ``forward``, which ``__call__`` invokes.
    """

    def __init__(self, embed_in: int, embed_out: int, embed_hidden: List[int],
                 embed_acts: List[str], embed_types: List[str],
                 uni_in: int, uni_out: int, uni_hidden: List[int],
                 uni_acts: List[str], uni_types: List[str], name: str = 'ContextualModel'):
        self.name = name
        self.embedding_net = RNNBase(embed_in, embed_out, embed_hidden, embed_acts, embed_types)
        # The universal network consumes its own input concatenated with the embedding.
        self.uni_net = RNNBase(embed_out + uni_in, uni_out, uni_hidden, uni_acts, uni_types)
        self.rnn_num = self.embedding_net.rnn_num + self.uni_net.rnn_num
        self.device = torch.device('cpu')
        self._modules = {'embedding': self.embedding_net, 'universal': self.uni_net}

    def parameters(self, recurse: bool = True) -> List[torch.Tensor]:
        """Return the parameters of both sub-networks as one list."""
        return list(self.embedding_net.parameters(recurse)) + list(self.uni_net.parameters(recurse))

    def to(self, device: torch.device):
        """Move both sub-networks to ``device`` and return ``self``."""
        # Accept device strings such as 'cuda' while keeping self.device a torch.device.
        if isinstance(device, str):
            device = torch.device(device)
        if device != self.device:
            self.device = device
            self.embedding_net.to(device)
            self.uni_net.to(device)
        return self

    def train(self, mode: bool = True):
        """Set both sub-networks to training (or evaluation) mode and return ``self``."""
        self.embedding_net.train(mode)
        self.uni_net.train(mode)
        return self

    def eval(self):
        """Set both sub-networks to evaluation mode and return ``self``."""
        return self.train(False)

    def make_init_state(self, batch_size: int, device: torch.device) -> RNNHidden:
        """Create initial hidden states: embedding layers first, then universal layers."""
        return self.embedding_net.make_init_state(batch_size, device) + \
               self.uni_net.make_init_state(batch_size, device)

    def meta_forward(self, embed_input: torch.Tensor, uni_input: torch.Tensor,
                     hidden: Optional[RNNHidden] = None, detach_embed: bool = False
                     ) -> Tuple[torch.Tensor, RNNHidden, torch.Tensor]:
        """Run the embedding network, then the universal network.

        Args:
            embed_input: input of the embedding network.
            uni_input: extra input of the universal network; it is placed
                before the embedding in the concatenation.
            hidden: combined hidden state (embedding entries first). When
                None, a fresh state is created from ``embed_input``.
            detach_embed: detach the embedding so no gradient reaches the
                embedding network through the universal network.

        Returns:
            ``(universal_output, new_hidden, embedding_output)``.
        """
        if hidden is None:
            hidden = self.make_init_state(embed_input.shape[0], embed_input.device)

        embed_hidden = hidden[:self.embedding_net.rnn_num] if len(hidden) > 0 else None
        uni_hidden = hidden[self.embedding_net.rnn_num:] if len(hidden) > 0 else None

        embed_out, embed_h, _ = self.embedding_net.forward(embed_input, embed_hidden, False)

        if detach_embed:
            embed_out = embed_out.detach()

        # If the embedding has one more axis than uni_input, repeat uni_input along dim 1.
        if len(embed_out.shape) - len(uni_input.shape) == 1:
            uni_input = uni_input.unsqueeze(1).expand(embed_out.shape[0], embed_out.shape[1], -1)

        combined = torch.cat([uni_input, embed_out], dim=-1)
        uni_out, uni_h, _ = self.uni_net.forward(combined, uni_hidden, False)

        return uni_out, embed_h + uni_h, embed_out

    def state_dict(self):
        """Return the sub-network state dicts keyed by 'embedding' and 'universal'."""
        return {k: v.state_dict() for k, v in self._modules.items()}

    def load_state_dict(self, state_dict):
        """Load sub-network weights from a dict produced by ``state_dict``."""
        for k, v in self._modules.items():
            v.load_state_dict(state_dict[k])

    def save(self, path: str):
        """Write one '<name>-<key>.pt' file per sub-network into ``path``."""
        os.makedirs(path, exist_ok=True)
        for k, v in self._modules.items():
            torch.save(v.state_dict(), os.path.join(path, f'{self.name}-{k}.pt'))

    def load(self, path: str, map_location=None):
        """Load the files written by ``save`` from ``path``."""
        # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
        for k, v in self._modules.items():
            v.load_state_dict(torch.load(os.path.join(path, f'{self.name}-{k}.pt'),
                                         map_location=map_location))

    def __call__(self, *args, **kwargs):
        """Make ContextualModel callable like nn.Module."""
        return self.forward(*args, **kwargs)


# ==============================================================================
# SECTION 5: LSTM-SAC Policy Network (Actor) - Using Repository Pattern with LSTM
# ==============================================================================

class RecurrentSACPolicy(ContextualModel):
    """
    Recurrent SAC policy (actor) with LSTM-based context encoding.

    The embedding network is a stack of LSTM layers fed with the state
    (plus the previous action when ``use_last_action`` is set). The universal
    network is fully connected and outputs the mean and log standard
    deviation of a Gaussian, which is squashed through tanh.
    """

    MAX_LOG_STD = 2.0
    MIN_LOG_STD = -20.0

    def __init__(self, state_dim: int, action_dim: int, embed_dim: int = 64,
                 embed_hidden: Optional[List[int]] = None, uni_hidden: Optional[List[int]] = None,
                 use_last_action: bool = True):
        """Build the actor.

        Args:
            state_dim: size of the state/observation vector.
            action_dim: size of the action vector.
            embed_dim: output size of the embedding network.
            embed_hidden: hidden sizes of the embedding network; None means ``[64]``.
            uni_hidden: hidden sizes of the universal network; None means ``[256, 256]``.
            use_last_action: feed the previous action to the embedding network.
        """
        # None sentinels avoid sharing one mutable default list between
        # instances; copying also accepts any sequence of ints.
        embed_hidden = [64] if embed_hidden is None else list(embed_hidden)
        uni_hidden = [256, 256] if uni_hidden is None else list(uni_hidden)

        self.state_dim = state_dim
        self.action_dim = action_dim
        self.use_last_action = use_last_action

        embed_in = state_dim
        if use_last_action:
            embed_in += action_dim

        # Use LSTM instead of GRU
        embed_acts = ['relu'] * len(embed_hidden) + ['relu']
        embed_types = ['lstm'] * (len(embed_hidden) + 1)  # LSTM layers
        uni_acts = ['relu'] * len(uni_hidden) + ['linear']
        uni_types = ['fc'] * (len(uni_hidden) + 1)

        # The universal output holds mean and log-std, hence action_dim * 2.
        super().__init__(embed_in, embed_dim, embed_hidden, embed_acts, embed_types,
                        state_dim, action_dim * 2, uni_hidden, uni_acts, uni_types,
                        name='RecurrentSACPolicy')

    def get_embed_input(self, state: torch.Tensor, last_action: torch.Tensor) -> torch.Tensor:
        """Concatenate the state with the previous action when it is in use."""
        inputs = [state]
        if self.use_last_action:
            inputs.append(last_action)
        return torch.cat(inputs, dim=-1)

    def forward(self, state: torch.Tensor, last_action: torch.Tensor,
                hidden: Optional[RNNHidden] = None
                ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, RNNHidden]:
        """Sample an action.

        Returns:
            ``(action_mean, action_sample, log_prob, hidden)`` where
            ``action_mean`` is tanh of the Gaussian mean, ``action_sample`` is
            tanh of a reparameterised sample, ``log_prob`` is the log-density
            of that squashed sample summed over the last axis, and ``hidden``
            is the updated recurrent state.
        """
        embed_input = self.get_embed_input(state, last_action)
        output, hidden, embed_out = self.meta_forward(embed_input, state, hidden)

        mean, log_std = output.chunk(2, dim=-1)
        log_std = torch.clamp(log_std, self.MIN_LOG_STD, self.MAX_LOG_STD)
        std = log_std.exp()

        # Reparameterisation: sample = mean + std * eps with eps ~ N(0, 1).
        noise = torch.randn_like(mean)
        sample = mean + noise * std

        # Gaussian log-density of the pre-squash sample ...
        log_prob = (-0.5 * noise.pow(2) - log_std - 0.5 * np.log(2 * np.pi)).sum(-1, keepdim=True)
        # ... minus log(1 - tanh(sample)^2), written in a numerically stable form.
        log_prob -= (2 * (np.log(2) - sample - F.softplus(-2 * sample))).sum(-1, keepdim=True)

        action_mean = torch.tanh(mean)
        action_sample = torch.tanh(sample)

        return action_mean, action_sample, log_prob, hidden


# ==============================================================================
# SECTION 6: LSTM-SAC Value Network (Critic) - Using Repository Pattern with LSTM
# ==============================================================================

class RecurrentSACValue(ContextualModel):
    """
    Recurrent SAC Q-value network (critic) with LSTM-based context encoding.

    The embedding network is a stack of LSTM layers fed with the state
    (plus the previous action when ``use_last_action`` is set). The universal
    network is fully connected, takes the state and the evaluated action
    together with the embedding, and outputs a single Q-value.
    """

    def __init__(self, state_dim: int, action_dim: int, embed_dim: int = 64,
                 embed_hidden: Optional[List[int]] = None, uni_hidden: Optional[List[int]] = None,
                 use_last_action: bool = True):
        """Build the critic.

        Args:
            state_dim: size of the state/observation vector.
            action_dim: size of the action vector.
            embed_dim: output size of the embedding network.
            embed_hidden: hidden sizes of the embedding network; None means ``[64]``.
            uni_hidden: hidden sizes of the universal network; None means ``[256, 256]``.
            use_last_action: feed the previous action to the embedding network.
        """
        # None sentinels avoid sharing one mutable default list between
        # instances; copying also accepts any sequence of ints.
        embed_hidden = [64] if embed_hidden is None else list(embed_hidden)
        uni_hidden = [256, 256] if uni_hidden is None else list(uni_hidden)

        self.state_dim = state_dim
        self.action_dim = action_dim
        self.use_last_action = use_last_action

        embed_in = state_dim
        if use_last_action:
            embed_in += action_dim

        # Use LSTM instead of GRU
        embed_acts = ['relu'] * len(embed_hidden) + ['relu']
        embed_types = ['lstm'] * (len(embed_hidden) + 1)  # LSTM layers
        uni_acts = ['relu'] * len(uni_hidden) + ['linear']
        uni_types = ['fc'] * (len(uni_hidden) + 1)

        super().__init__(embed_in, embed_dim, embed_hidden, embed_acts, embed_types,
                        state_dim + action_dim, 1, uni_hidden, uni_acts, uni_types,
                        name='RecurrentSACValue')

    def get_embed_input(self, state: torch.Tensor, last_action: torch.Tensor) -> torch.Tensor:
        """Concatenate the state with the previous action when it is in use."""
        inputs = [state]
        if self.use_last_action:
            inputs.append(last_action)
        return torch.cat(inputs, dim=-1)

    def forward(self, state: torch.Tensor, last_action: torch.Tensor, action: torch.Tensor,
                hidden: Optional[RNNHidden] = None, detach_embed: bool = False
                ) -> Tuple[torch.Tensor, RNNHidden]:
        """Evaluate Q(state, action) given the recurrent context.

        Args:
            state: current state/observation.
            last_action: previous action, used for the embedding input.
            action: action whose value is estimated.
            hidden: recurrent state; a fresh one is created when None.
            detach_embed: block gradients from flowing into the embedding network.

        Returns:
            ``(q_value, hidden)`` with the updated recurrent state.
        """
        embed_input = self.get_embed_input(state, last_action)
        uni_input = torch.cat([state, action], dim=-1)
        q_value, hidden, embed_out = self.meta_forward(embed_input, uni_input, hidden, detach_embed)
        return q_value, hidden


# ==============================================================================
# SECTION 7: Trajectory Replay Buffer
# ==============================================================================

# One environment step as stored in the replay buffer.
Transition = namedtuple(
    'Transition',
    ('state', 'last_action', 'action', 'next_state', 'reward', 'done', 'timeout'),
)


class SequenceReplayBuffer:
    """Ring buffer of whole trajectories, sampled as fixed-length chunks.

    Every transition is flattened into one row; ``dim_info`` records which
    columns belong to which Transition field. The storage array is allocated
    lazily when the first trajectory is completed.
    """

    def __init__(self, max_trajectories: int = 1000, max_length: int = 1000):
        self.max_trajectories = max_trajectories
        self.max_length = max_length
        self.buffer: Optional[np.ndarray] = None
        self.dim_info: Optional[List[Tuple[int, int]]] = None
        self.lengths = [0] * max_trajectories
        self.current_trajectory: List[Transition] = []
        self.ptr = 0
        self.size = 0
        self.total_steps = 0

    def _init_buffer(self, transition: Transition):
        """Derive the column layout from a sample transition and allocate storage."""
        self.dim_info = spans = []
        offset = 0
        for value in transition:
            # Arrays take their last-axis size; everything else takes one column.
            width = value.shape[-1] if isinstance(value, np.ndarray) else 1
            spans.append((offset, offset + width))
            offset += width
        self.buffer = np.zeros((self.max_trajectories, self.max_length, offset))

    def _transition_to_array(self, t: Transition) -> np.ndarray:
        """Flatten a transition into a single 1-D row."""
        parts = [
            value.flatten() if isinstance(value, np.ndarray) else np.array([value])
            for value in t
        ]
        return np.concatenate(parts)

    def _array_to_dict(self, data: np.ndarray) -> Dict[str, np.ndarray]:
        """Split the last axis of ``data`` back into named Transition fields."""
        return {
            field: data[..., lo:hi]
            for field, (lo, hi) in zip(Transition._fields, self.dim_info)
        }

    def push(self, transition: Transition):
        """Append a step to the open trajectory; store it once ``done`` is truthy."""
        self.current_trajectory.append(transition)
        if transition.done:
            self._complete_trajectory()

    def _complete_trajectory(self):
        """Write the open trajectory into the next slot, overwriting the oldest one."""
        episode = self.current_trajectory
        if not episode:
            return
        if self.buffer is None:
            self._init_buffer(episode[0])

        slot = self.ptr
        # Steps beyond max_length are dropped.
        kept = min(len(episode), self.max_length)
        self.buffer[slot] = 0
        for step, transition in enumerate(episode[:kept]):
            self.buffer[slot, step] = self._transition_to_array(transition)

        # Replace the overwritten trajectory's contribution to the step count.
        self.total_steps += kept - self.lengths[slot]
        self.lengths[slot] = kept

        self.ptr = (slot + 1) % self.max_trajectories
        self.size = min(self.size + 1, self.max_trajectories)
        self.current_trajectory = []

    def sample_chunks(self, batch_size: int, chunk_len: int = 64) -> Tuple[Dict[str, np.ndarray], np.ndarray, int]:
        """Sample ``batch_size`` chunks of ``chunk_len`` steps.

        Trajectories are drawn uniformly with replacement. A trajectory no
        longer than ``chunk_len`` is used from its start and zero-padded;
        a longer one contributes a random contiguous window.

        Returns:
            ``(fields, masks, valid_steps)`` where ``fields`` maps each
            Transition field to its slice of the batch, ``masks`` marks real
            (1) versus padded (0) steps, and ``valid_steps`` is the mask sum.

        Raises:
            ValueError: if no trajectory has been stored yet.
        """
        if self.size == 0:
            raise ValueError("Buffer is empty")

        batch_chunks = []
        batch_masks = []
        for _ in range(batch_size):
            traj_idx = np.random.randint(0, self.size)
            traj_len = self.lengths[traj_idx]

            if traj_len > chunk_len:
                start = np.random.randint(0, traj_len - chunk_len + 1)
                chunk = self.buffer[traj_idx, start:start + chunk_len]
                mask = np.ones((chunk_len, 1))
            else:
                chunk = self.buffer[traj_idx, :traj_len]
                missing = chunk_len - traj_len
                if missing > 0:
                    chunk = np.vstack([chunk, np.zeros((missing, chunk.shape[-1]))])
                mask = np.zeros((chunk_len, 1))
                mask[:traj_len] = 1

            batch_chunks.append(chunk)
            batch_masks.append(mask)

        stacked_masks = np.stack(batch_masks, axis=0)
        fields = self._array_to_dict(np.stack(batch_chunks, axis=0))
        return fields, stacked_masks, int(stacked_masks.sum())

    def __len__(self) -> int:
        return self.size


# ==============================================================================
# SECTION 8: Utility Functions
# ==============================================================================

def normalize_action(action: np.ndarray, action_space) -> np.ndarray:
    """Rescale ``action`` from the space's [low, high] range to [-1, 1]."""
    low, high = action_space.low, action_space.high
    unit = (action - low) / (high - low)
    return unit * 2 - 1

def denormalize_action(action: np.ndarray, action_space) -> np.ndarray:
    """Rescale ``action`` from [-1, 1] to the space's [low, high] range."""
    low, high = action_space.low, action_space.high
    unit = (action + 1) / 2
    return unit * (high - low) + low

def to_torch(data: np.ndarray, device: torch.device) -> torch.Tensor:
    """Convert a numpy array to a float32 tensor on ``device``."""
    tensor = torch.from_numpy(data)
    return tensor.float().to(device)

def to_numpy(tensor: torch.Tensor) -> np.ndarray:
    """Detach ``tensor`` from the graph, move it to the CPU and return it as numpy."""
    detached = tensor.detach()
    return detached.cpu().numpy()


# ==============================================================================
# SECTION 9: LSTM-SAC Agent
# ==============================================================================

class LSTMSAC:
    """
    LSTM-based Soft Actor-Critic Agent.

    Uses the repository's ContextualModel pattern with LSTM layers: one
    recurrent policy, two recurrent critics and a target copy of each critic.
    Each gradient step samples fixed-length chunks from a trajectory replay
    buffer, runs the first ``burn_in`` steps of every chunk without gradients
    to warm up the recurrent states, and computes the SAC losses on the
    remaining steps.

    Actions handled by the agent (policy outputs, ``last_action`` inputs and
    replay-buffer entries) are kept in the normalized convention that
    ``denormalize_action`` maps onto the environment's action bounds.
    """

    def __init__(self, env, seed: int = 0, gamma: float = 0.99, tau: float = 0.002,
                 alpha: float = 0.2, lr: float = 1e-4, buffer_size: int = 1000,
                 batch_size: int = 128, embed_dim: int = 64,
                 embed_hidden: Optional[List[int]] = None,
                 uni_hidden: Optional[List[int]] = None,
                 auto_alpha: bool = True, warmup_steps: int = 1000,
                 update_interval: int = 4, max_traj_len: int = 600,
                 chunk_len: int = 96, burn_in: int = 32, save_best_path: str = './models/best'):
        """Build the networks, optimizers and replay buffer.

        Args:
            env: Training environment; must expose ``observation_space``,
                ``action_space`` and the reset/step API.
            seed: Seed for ``random``, numpy and torch, also passed to the
                first environment reset in ``train``.
            gamma: Discount factor.
            tau: Interpolation factor of the soft target update.
            alpha: Initial entropy temperature.
            lr: Learning rate shared by all Adam optimizers.
            buffer_size: First argument of ``SequenceReplayBuffer``.
            batch_size: Number of chunks sampled per update.
            embed_dim: Passed through to the recurrent networks.
            embed_hidden: Passed through to the recurrent networks;
                ``None`` selects ``[64]``.
            uni_hidden: Passed through to the recurrent networks;
                ``None`` selects ``[256, 256]``.
            auto_alpha: Whether ``update`` also optimizes the temperature.
            warmup_steps: Steps of random actions before learning starts.
            update_interval: Environment steps between gradient updates.
            max_traj_len: Upper bound on stored trajectory length (further
                capped by ``env.max_t`` when that attribute exists).
            chunk_len: Length of each sampled training chunk.
            burn_in: Leading steps of a chunk used only to warm up hidden states.
            save_best_path: Directory for the best-evaluation checkpoint.
        """
        # Fresh lists per instance instead of shared mutable default arguments.
        if embed_hidden is None:
            embed_hidden = [64]
        if uni_hidden is None:
            uni_hidden = [256, 256]

        self._set_seed(seed)
        self.seed = seed

        self.env = env
        # NOTE(review): evaluation always uses setting 1, independent of `env`.
        self.eval_env = create_AhnChemoEnv_setting1()

        self.state_dim = env.observation_space.shape[0]
        self.action_dim = env.action_space.shape[0]
        self.max_traj_len = min(max_traj_len, getattr(env, 'max_t', max_traj_len))

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {self.device}")
        print(f"LSTM config: embed_dim={embed_dim}, embed_hidden={embed_hidden}")

        # Hyperparameters
        self.gamma = gamma
        self.tau = tau
        self.batch_size = batch_size
        self.warmup_steps = warmup_steps
        self.update_interval = update_interval
        self.auto_alpha = auto_alpha
        self.chunk_len = chunk_len
        self.burn_in = burn_in

        def build_critic():
            critic = RecurrentSACValue(
                self.state_dim, self.action_dim, embed_dim, embed_hidden, uni_hidden
            )
            critic.to(self.device)
            return critic

        # Networks using repository pattern with LSTM. Construction order is
        # policy, q1, q2, q1_target, q2_target so seeded initialization is stable.
        self.policy = RecurrentSACPolicy(
            self.state_dim, self.action_dim, embed_dim, embed_hidden, uni_hidden
        )
        self.policy.to(self.device)

        self.q1 = build_critic()
        self.q2 = build_critic()

        self.q1_target = build_critic()
        self.q1_target.load_state_dict(self.q1.state_dict())

        self.q2_target = build_critic()
        self.q2_target.load_state_dict(self.q2.state_dict())

        # Entropy
        self.target_entropy = -self.action_dim
        self.log_alpha = torch.tensor(np.log(alpha), device=self.device, requires_grad=True)

        # Optimizers
        self.policy_optim = torch.optim.Adam(self.policy.parameters(), lr=lr)
        self.q1_optim = torch.optim.Adam(self.q1.parameters(), lr=lr)
        self.q2_optim = torch.optim.Adam(self.q2.parameters(), lr=lr)
        self.alpha_optim = torch.optim.Adam([self.log_alpha], lr=lr)

        # Replay buffer
        self.buffer = SequenceReplayBuffer(buffer_size, self.max_traj_len)

        # Tracking
        self.total_steps = 0
        self.episodes = 0
        self.best_eval_reward = -float('inf')
        self.save_best_path = save_best_path
        self.reward_history: List[Dict] = []

    def _set_seed(self, seed: int):
        """Seed Python's ``random``, numpy and torch (including CUDA when available)."""
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

    def _reset_env(self, env, seed=None):
        """Reset ``env`` and return only the observation.

        Tries the ``(obs, info)`` reset API first and falls back to a plain
        ``env.reset()`` when that raises ``TypeError``.
        """
        try:
            if seed is not None:
                obs, _ = env.reset(seed=seed)
            else:
                obs, _ = env.reset()
        except TypeError:
            result = env.reset()
            obs = result[0] if isinstance(result, tuple) else result
        return obs

    def _step_env(self, env, action):
        """Step ``env`` and return ``(obs, reward, done, info)``.

        Accepts both the 5-tuple ``(obs, reward, terminated, truncated, info)``
        and the 4-tuple ``(obs, reward, done, info)`` step results. For the
        5-tuple form, a truncation without termination is recorded under
        ``info['TimeLimit.truncated']`` (unless the environment already set
        that key) so callers can tell time-limit cut-offs from true terminal
        states even though only a single ``done`` flag is returned.
        """
        result = env.step(action)
        if len(result) == 5:
            obs, reward, terminated, truncated, info = result
            done = terminated or truncated
            if isinstance(info, dict) and 'TimeLimit.truncated' not in info:
                # Copy so the environment's own info dict is not mutated.
                info = dict(info)
                info['TimeLimit.truncated'] = bool(truncated and not terminated)
        else:
            obs, reward, done, info = result
        return obs, reward, done, info

    @property
    def alpha(self) -> torch.Tensor:
        """Current entropy temperature, ``exp(log_alpha)``."""
        return self.log_alpha.exp()

    def select_action(self, state: np.ndarray, last_action: np.ndarray,
                      hidden: RNNHidden, deterministic: bool = False
                      ) -> Tuple[np.ndarray, RNNHidden]:
        """Run the policy for one step without gradients.

        Args:
            state: Current observation.
            last_action: Action taken at the previous step.
            hidden: Recurrent state carried over from the previous step.
            deterministic: Return the policy's first output (used as the mean
                action) instead of its second output (the sampled action).

        Returns:
            ``(action, new_hidden)`` with the action as a numpy array whose
            leading batch axis has been squeezed away.
        """
        with torch.no_grad():
            state_t = to_torch(state.reshape(1, -1), self.device)
            last_action_t = to_torch(last_action.reshape(1, -1), self.device)

            action_mean, action_sample, _, new_hidden = self.policy(
                state_t, last_action_t, hidden
            )

            action = action_mean if deterministic else action_sample
            return to_numpy(action).squeeze(0), new_hidden

    def _masked_mean(self, data: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Mask-weighted mean of ``data``; the denominator is clamped to at least 1."""
        return (data * mask).sum() / mask.sum().clamp(min=1)

    def _soft_update(self, target: ContextualModel, source: ContextualModel):
        """Move each ``target`` parameter a fraction ``tau`` toward ``source``."""
        with torch.no_grad():
            for tp, sp in zip(target.parameters(), source.parameters()):
                tp.data.copy_(self.tau * sp.data + (1 - self.tau) * tp.data)

    def update(self) -> Dict[str, float]:
        """Perform one SAC gradient step on a batch of sampled chunks.

        Returns:
            Dict with ``q1_loss``, ``q2_loss``, ``policy_loss`` and ``alpha``.

        Raises:
            ValueError: If ``chunk_len`` is not larger than ``burn_in``.
        """
        burn_in = self.burn_in
        if self.chunk_len - burn_in <= 0:
            raise ValueError(f"chunk_len ({self.chunk_len}) must be > burn_in ({self.burn_in})")

        data, mask, _ = self.buffer.sample_chunks(self.batch_size, chunk_len=self.chunk_len)

        state = to_torch(data['state'], self.device)
        last_action = to_torch(data['last_action'], self.device)
        action = to_torch(data['action'], self.device)
        next_state = to_torch(data['next_state'], self.device)
        reward = to_torch(data['reward'], self.device)
        done = to_torch(data['done'], self.device)
        timeout = to_torch(data['timeout'], self.device)
        mask_t = to_torch(mask.copy(), self.device)

        # Burn-in phase: warm up every network's recurrent state without gradients.
        with torch.no_grad():
            n_trajs = state.shape[0]

            policy_h = self.policy.make_init_state(n_trajs, self.device)
            q1_h = self.q1.make_init_state(n_trajs, self.device)
            q2_h = self.q2.make_init_state(n_trajs, self.device)
            q1t_h = self.q1_target.make_init_state(n_trajs, self.device)
            q2t_h = self.q2_target.make_init_state(n_trajs, self.device)

            if burn_in > 0:
                state_burn = state[:, :burn_in]
                last_action_burn = last_action[:, :burn_in]
                action_burn = action[:, :burn_in]

                _, _, _, policy_h = self.policy(state_burn, last_action_burn, policy_h)
                _, q1_h = self.q1(state_burn, last_action_burn, action_burn, q1_h)
                _, q2_h = self.q2(state_burn, last_action_burn, action_burn, q2_h)
                _, q1t_h = self.q1_target(state_burn, last_action_burn, action_burn, q1t_h)
                _, q2t_h = self.q2_target(state_burn, last_action_burn, action_burn, q2t_h)

        # Training phase
        state_train = state[:, burn_in:]
        last_action_train = last_action[:, burn_in:]
        action_train = action[:, burn_in:]
        next_state_train = next_state[:, burn_in:]
        reward_train = reward[:, burn_in:]
        done_train = done[:, burn_in:]
        timeout_train = timeout[:, burn_in:]
        mask_train = mask_t[:, burn_in:]

        # Time-limit cut-offs are not true terminals: keep bootstrapping through them.
        done_train = done_train * (1 - timeout_train)
        alpha = self.alpha.detach()

        # Target Q
        # NOTE(review): the next-state pass starts from the post-burn-in hidden
        # state, i.e. it skips the first training-window observation — confirm
        # this approximation is intended.
        with torch.no_grad():
            _, next_action, next_log_prob, _ = self.policy(next_state_train, action_train, policy_h.detach())
            q1_next, _ = self.q1_target(next_state_train, action_train, next_action, q1t_h)
            q2_next, _ = self.q2_target(next_state_train, action_train, next_action, q2t_h)
            q_next = torch.min(q1_next, q2_next) - alpha * next_log_prob
            target_q = reward_train + (1 - done_train) * self.gamma * q_next

        # Critic update
        q1_pred, _ = self.q1(state_train, last_action_train, action_train, q1_h)
        q2_pred, _ = self.q2(state_train, last_action_train, action_train, q2_h)

        q1_loss = self._masked_mean((q1_pred - target_q).pow(2), mask_train)
        q2_loss = self._masked_mean((q2_pred - target_q).pow(2), mask_train)

        self.q1_optim.zero_grad()
        q1_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.q1.parameters(), max_norm=10.0)
        self.q1_optim.step()

        self.q2_optim.zero_grad()
        q2_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.q2.parameters(), max_norm=10.0)
        self.q2_optim.step()

        # Actor update
        _, new_action, log_prob, _ = self.policy(state_train, last_action_train, policy_h.detach())
        q1_pi, _ = self.q1(state_train, last_action_train, new_action, q1_h.detach(), detach_embed=True)
        q2_pi, _ = self.q2(state_train, last_action_train, new_action, q2_h.detach(), detach_embed=True)
        q_pi = torch.min(q1_pi, q2_pi)

        policy_loss = self._masked_mean(alpha * log_prob - q_pi, mask_train)

        self.policy_optim.zero_grad()
        policy_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.policy.parameters(), max_norm=10.0)
        self.policy_optim.step()

        # Alpha update
        if self.auto_alpha:
            alpha_loss = -self._masked_mean(
                self.log_alpha * (log_prob + self.target_entropy).detach(), mask_train
            )
            self.alpha_optim.zero_grad()
            alpha_loss.backward()
            torch.nn.utils.clip_grad_norm_([self.log_alpha], max_norm=10.0)
            self.alpha_optim.step()

            # Clamp alpha to prevent collapse
            with torch.no_grad():
                self.log_alpha.clamp_(min=np.log(0.12))

        # Soft update targets
        self._soft_update(self.q1_target, self.q1)
        self._soft_update(self.q2_target, self.q2)

        return {
            'q1_loss': q1_loss.item(),
            'q2_loss': q2_loss.item(),
            'policy_loss': policy_loss.item(),
            'alpha': self.alpha.item(),
        }

    def train(self, total_steps: int = 50000, eval_interval: int = 5000,
              log_interval: int = 1000) -> Tuple[List[float], List[int]]:
        """Interact with the environment and learn for ``total_steps`` steps.

        Args:
            total_steps: Number of environment steps to run.
            eval_interval: Steps between evaluations / best-model checkpoints.
            log_interval: Steps between progress logs and reward CSV dumps
                (only emitted once warm-up is over).

        Returns:
            ``(ep_rewards, ep_lengths)`` for every episode completed.
        """
        print(f"Starting LSTM-SAC training for {total_steps} steps...")

        ep_rewards = []
        ep_lengths = []
        cumulative_steps = 0  # running total of completed-episode steps
        start_time = time.time()

        action_space = self.env.action_space
        obs = self._reset_env(self.env, self.seed)
        last_action = np.zeros(self.action_dim)
        hidden = self.policy.make_init_state(1, self.device)
        ep_reward = 0.0
        ep_length = 0
        losses = {'alpha': self.alpha.item()}

        for step in range(1, total_steps + 1):
            self.total_steps = step

            if step < self.warmup_steps:
                # The space samples in environment units; execute that sample
                # as-is and store its normalized form so the buffer and
                # `last_action` use the same convention as policy outputs.
                env_action = action_space.sample()
                action = normalize_action(env_action, action_space)
            else:
                action, hidden = self.select_action(obs, last_action, hidden)
                env_action = denormalize_action(action, action_space)

            action = action.copy() if isinstance(action, np.ndarray) else np.array([action])

            next_obs, reward, done, info = self._step_env(self.env, env_action)
            timeout = info.get('TimeLimit.truncated', False) if isinstance(info, dict) else False

            transition = Transition(
                state=obs.copy(),
                last_action=last_action.copy(),
                action=action,
                next_state=next_obs.copy(),
                reward=np.array([reward]),
                done=np.array([float(done)]),
                timeout=np.array([float(timeout)])
            )
            self.buffer.push(transition)

            ep_reward += reward
            ep_length += 1
            last_action = action.copy()
            obs = next_obs

            if done:
                ep_rewards.append(ep_reward)
                ep_lengths.append(ep_length)
                self.episodes += 1

                cumulative_steps += ep_length
                self.reward_history.append({
                    'episode': self.episodes,
                    'timestep': cumulative_steps,
                    'reward': ep_reward,
                    'length': ep_length
                })

                obs = self._reset_env(self.env)
                last_action = np.zeros(self.action_dim)
                hidden = self.policy.make_init_state(1, self.device)
                ep_reward = 0.0
                ep_length = 0

            if step >= self.warmup_steps and len(self.buffer) > 0:
                if step % self.update_interval == 0:
                    losses = self.update()

                if step % log_interval == 0:
                    avg_reward = np.mean(ep_rewards[-10:]) if ep_rewards else 0
                    elapsed = time.time() - start_time
                    print(f"Step {step}/{total_steps} | Eps: {self.episodes} | "
                          f"Reward: {avg_reward:.1f} | Alpha: {losses['alpha']:.4f} | "
                          f"Time: {elapsed:.0f}s")

                    csv_path = os.path.join("./results", "reward.csv")
                    self.save_rewards_csv(csv_path)

            if step % eval_interval == 0:
                eval_reward = self.evaluate(n_episodes=5)
                print(f"[EVAL] Step {step} | Mean Reward: {eval_reward:.2f}")

                # Save best model checkpoint
                if eval_reward > self.best_eval_reward:
                    self.best_eval_reward = eval_reward
                    self.save(self.save_best_path)
                    print(f"[BEST] New best model saved! Reward: {eval_reward:.2f}")

        print(f"\nTraining completed in {time.time() - start_time:.1f}s")
        return ep_rewards, ep_lengths

    def evaluate(self, n_episodes: int = 10) -> float:
        """Run deterministic episodes on ``eval_env`` and return the mean episode reward.

        The policy is switched to eval mode for the rollouts and is always
        switched back to train mode, even if a rollout raises.
        """
        self.policy.eval()
        total_reward = 0.0

        try:
            for _ in range(n_episodes):
                obs = self._reset_env(self.eval_env)
                last_action = np.zeros(self.action_dim)
                hidden = self.policy.make_init_state(1, self.device)
                done = False

                while not done:
                    action, hidden = self.select_action(obs, last_action, hidden, deterministic=True)
                    env_action = denormalize_action(action, self.eval_env.action_space)
                    obs, reward, done, _ = self._step_env(self.eval_env, env_action)
                    total_reward += reward
                    last_action = action.copy() if isinstance(action, np.ndarray) else np.array([action])
        finally:
            self.policy.train()

        return total_reward / n_episodes

    def save(self, path: str):
        """Save the policy, both critics and ``log_alpha`` into directory ``path``."""
        os.makedirs(path, exist_ok=True)
        # NOTE(review): all three models are saved into the same directory;
        # confirm each model's save() writes a distinct file name.
        self.policy.save(path)
        self.q1.save(path)
        self.q2.save(path)
        torch.save({'log_alpha': self.log_alpha.detach().clone()}, os.path.join(path, 'alpha.pt'))
        print(f"Model saved to {path}")

    def load(self, path: str):
        """Restore the policy, critics and ``log_alpha`` from directory ``path``.

        Target critics are not part of the checkpoint, so they are reset to
        the freshly loaded critics. ``log_alpha`` is overwritten in place:
        ``alpha_optim`` holds a reference to that exact tensor, and rebinding
        the attribute would leave the optimizer updating a stale tensor.
        """
        self.policy.load(path, map_location=self.device)
        self.q1.load(path, map_location=self.device)
        self.q2.load(path, map_location=self.device)
        self.q1_target.load_state_dict(self.q1.state_dict())
        self.q2_target.load_state_dict(self.q2.state_dict())

        checkpoint = torch.load(os.path.join(path, 'alpha.pt'), map_location=self.device)
        with torch.no_grad():
            self.log_alpha.copy_(checkpoint['log_alpha'])
        print(f"Model loaded from {path}")

    def save_rewards_csv(self, path: str):
        """Write ``reward_history`` to ``path`` as CSV; does nothing when the history is empty."""
        if not self.reward_history:
            return
        os.makedirs(os.path.dirname(path) or '.', exist_ok=True)
        df = pd.DataFrame(self.reward_history)
        df.to_csv(path, index=False)
        print(f"Rewards saved to: {path}")


# ==============================================================================
# SECTION 10: Plotting Function
# ==============================================================================

def plot_training_reward_curve(ep_rewards: List[float], experiment_name: str,
                               ep_lengths: Optional[List[int]] = None,
                               window_size: int = 10, save_path: Optional[str] = None):
    """Plot training reward curve with rolling statistics.

    Draws two panels (reward vs. episode and reward vs. cumulative timesteps),
    each with the raw rewards, a rolling mean and a +/- one rolling-std band,
    then prints summary statistics.

    Args:
        ep_rewards: Total reward of each episode, in order.
        experiment_name: Used in the panel titles.
        ep_lengths: Length of each episode; when missing or empty, every
            episode is assumed to last 600 steps.
        window_size: Rolling-window size for the mean/std curves.
        save_path: When given, the figure is also written to this file
            (parent directories are created).

    Returns:
        The DataFrame holding the per-episode data and rolling statistics.
        For empty ``ep_rewards`` an empty frame with the same columns is
        returned and nothing is plotted.
    """
    print(f"\n--- Plotting Training Reward Curve ---")

    n_episodes = len(ep_rewards)
    if n_episodes == 0:
        # Nothing to plot; the summary below would fail on an empty frame.
        print("No completed episodes to plot.")
        return pd.DataFrame(columns=[
            'episode', 'r', 'l', 'timesteps', 'reward_mean', 'reward_std',
            'lower_bound', 'upper_bound',
        ])

    # len() instead of truthiness so numpy arrays are accepted as well.
    has_lengths = ep_lengths is not None and len(ep_lengths) > 0

    results_df = pd.DataFrame({
        'episode': list(range(1, n_episodes + 1)),
        'r': ep_rewards,
        'l': ep_lengths if has_lengths else [600] * n_episodes
    })
    results_df['timesteps'] = results_df['l'].cumsum()
    rolling = results_df['r'].rolling(window=window_size, min_periods=1)
    results_df['reward_mean'] = rolling.mean()
    # The std of a single sample is NaN; show it as a zero-width band.
    results_df['reward_std'] = rolling.std().fillna(0)
    results_df['lower_bound'] = results_df['reward_mean'] - results_df['reward_std']
    results_df['upper_bound'] = results_df['reward_mean'] + results_df['reward_std']

    if HAS_SEABORN:
        sns.set(style="darkgrid", font_scale=1.2)

    fig, axes = plt.subplots(1, 2, figsize=(16, 6))
    color = sns.color_palette()[0] if HAS_SEABORN else 'blue'

    # Same drawing recipe for both panels; only the x-axis differs.
    panels = (
        (axes[0], 'episode', 'Reward Over Episodes', 'Episode'),
        (axes[1], 'timesteps', 'Reward Over Timesteps', 'Timesteps'),
    )
    for ax, x_col, title, x_label in panels:
        ax.fill_between(results_df[x_col], results_df['lower_bound'],
                        results_df['upper_bound'], alpha=0.2, color=color)
        ax.plot(results_df[x_col], results_df['r'], alpha=0.15, color='gray')
        ax.plot(results_df[x_col], results_df['reward_mean'], linewidth=2, color=color)
        ax.set_title(f'{experiment_name} - {title}')
        ax.set_xlabel(x_label)
        ax.set_ylabel('Reward')
        ax.grid(True, linestyle='--', alpha=0.6)

    plt.tight_layout()

    if save_path:
        os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Plot saved to: {save_path}")

    plt.show()

    print(f"Total Episodes: {len(results_df)}")
    print(f"Total Timesteps: {results_df['timesteps'].iloc[-1]:,}")
    print(f"Best Reward: {results_df['r'].max():.2f}")
    print(f"Mean Reward: {results_df['r'].mean():.2f} ± {results_df['r'].std():.2f}")

    return results_df


# ==============================================================================
# SECTION 11: Main Entry Point
# ==============================================================================

if __name__ == '__main__':
    # Script entry point: train an agent on setting 1, then save the model,
    # the reward log and the training curve, and run a final evaluation.
    banner = "=" * 60
    print(banner)
    print("LSTM-SAC for Cancer Chemotherapy Environment")
    print("(Using Repository RNNBase/ContextualModel Pattern with LSTM)")
    print(banner)

    TOTAL_TIMESTEPS = 400000
    EXPERIMENT_NAME = "LSTM_SAC_ENV1"
    RESULTS_DIR = "./results"
    MODELS_DIR = "./models"

    for output_dir in (RESULTS_DIR, MODELS_DIR):
        os.makedirs(output_dir, exist_ok=True)

    env = create_AhnChemoEnv_setting1()

    # Long-run configuration. "previously" refers to the earlier, shorter-run
    # setup that this one replaces.
    agent_config = dict(
        seed=42,
        gamma=0.99,
        # Stability: slower target tracking for long runs (previously 0.005).
        tau=0.002,
        # Learning rate kept at 1e-4; 3e-4 is a common SAC choice but should
        # only be tried if training remains stable.
        lr=1e-4,
        # Larger batch to reduce gradient variance (previously 128).
        batch_size=256,
        # Embedding sizes unchanged; wider universal network for the
        # 400k-step run (previously [256, 256]).
        embed_dim=64,
        embed_hidden=[64],
        uni_hidden=[512, 512],
        # Longer chunks expose more history (previously 96), with a longer
        # burn-in for hidden-state initialization (previously 32).
        chunk_len=256,
        burn_in=40,
        # Gather more data before learning starts (previously 1000) and
        # update more often (previously every 4 steps).
        warmup_steps=5000,
        update_interval=2,
        max_traj_len=600,
    )
    agent = LSTMSAC(env=env, **agent_config)

    rewards, lengths = agent.train(
        total_steps=TOTAL_TIMESTEPS,
        eval_interval=5000,
        log_interval=1000,
    )

    # Persist the final model and the per-episode reward log.
    model_path = os.path.join(MODELS_DIR, EXPERIMENT_NAME)
    agent.save(model_path)

    csv_path = os.path.join(RESULTS_DIR, "reward.csv")
    agent.save_rewards_csv(csv_path)

    plot_path = os.path.join(RESULTS_DIR, f"{EXPERIMENT_NAME}_training_curve.png")
    plot_training_reward_curve(
        rewards,
        EXPERIMENT_NAME,
        ep_lengths=lengths,
        window_size=10,
        save_path=plot_path,
    )

    print("\nFinal Evaluation:")
    final_reward = agent.evaluate(n_episodes=10)
    print(f"Mean Reward over 10 episodes: {final_reward:.2f}")


LSTM-SAC for Cancer Chemotherapy Environment
(Using Repository RNNBase/ContextualModel Pattern with LSTM)
Using device: cuda
LSTM config: embed_dim=64, embed_hidden=[64]
Starting LSTM-SAC training for 400000 steps...
Step 5000/400000 | Eps: 8 | Reward: -13.7 | Alpha: 0.2000 | Time: 5s
Rewards saved to: ./results/reward.csv
[EVAL] Step 5000 | Mean Reward: 65.88
Model saved to ./models/best
[BEST] New best model saved! Reward: 65.88
Step 6000/400000 | Eps: 10 | Reward: -47.1 | Alpha: 0.1915 | Time: 141s
Rewards saved to: ./results/reward.csv
Step 7000/400000 | Eps: 13 | Reward: -175.3 | Alpha: 0.1831 | Time: 276s
Rewards saved to: ./results/reward.csv
Step 8000/400000 | Eps: 15 | Reward: -237.0 | Alpha: 0.1738 | Time: 412s
Rewards saved to: ./results/reward.csv
Step 9000/400000 | Eps: 18 | Reward: -387.9 | Alpha: 0.1648 | Time: 547s
Rewards saved to: ./results/reward.csv
Step 10000/400000 | Eps: 23 | Reward: -379.9 | Alpha: 0.1573 | Time: 682s
Rewards saved to: ./results/reward.csv
[EVAL