In [1]:
import math
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque

# Device selection: prefer Apple MPS, then CUDA, and fall back to the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

################################################################################
# 1) StockTrading environment
################################################################################
class StockTradingEnv:
    """Single-asset trading environment over a fixed price series.

    An observation is a sliding window of ``window_size`` consecutive prices.
    An action selects the position held over the next price move:
    ``0`` -> short (-1), ``2`` -> long (+1), anything else -> flat (0).
    """

    def __init__(self, prices, window_size=30, initial_capital=10000, max_steps=1000):
        """Store the configuration and reset to the start of the series.

        Args:
            prices: Indexable, sliceable price sequence (e.g. list or ndarray).
            window_size: Number of prices in each observation.
            initial_capital: Starting capital; stored on reset but not otherwise
                used by this class.
            max_steps: Upper bound on steps per episode; clamped so that every
                step still has a following price to settle against.

        Raises:
            ValueError: If ``window_size`` is below 1 or ``prices`` has fewer
                than ``window_size + 1`` entries (no transition is possible).
        """
        if window_size < 1:
            raise ValueError("window_size must be at least 1")
        if len(prices) < window_size + 1:
            raise ValueError(
                "prices must contain at least window_size + 1 entries "
                f"(got {len(prices)} for window_size={window_size})"
            )
        self.prices = prices
        self.window_size = window_size
        self.initial_capital = initial_capital
        self.max_steps = min(max_steps, len(prices) - window_size - 1)
        self.reset()

    def reset(self):
        """Rewind to the first window, go flat, and return the first observation."""
        self.current_step = 0
        self.done = False
        self.position = 0
        self.capital = self.initial_capital
        # Most recent price visible to the agent (last element of the window).
        self.last_price = self.prices[self.window_size - 1]
        return self._get_observation()

    def _get_observation(self):
        """Return the ``window_size`` prices starting at ``current_step``."""
        start = self.current_step
        end = self.current_step + self.window_size
        return self.prices[start:end]

    def step(self, action):
        """Take ``action`` and advance one price tick.

        The reward is the profit or loss of the position chosen by ``action``
        over the move from the latest observed price to the next one, so it is
        attributable to the action passed in.

        Returns:
            Tuple ``(observation, reward, done, info)``; ``info`` is always an
            empty dict. Once the episode is done, further calls return a zero
            reward and ``done=True`` without advancing.
        """
        if self.done:
            return self._get_observation(), 0.0, True, {}

        new_position = -1 if action == 0 else (1 if action == 2 else 0)

        # Latest price in the current window and the price that follows it.
        price_index = self.current_step + self.window_size - 1
        current_price = self.prices[price_index]
        next_price = self.prices[price_index + 1]

        if new_position != 0:
            reward = float((next_price - current_price) * new_position)
        else:
            reward = 0.0

        self.position = new_position
        self.last_price = next_price
        self.current_step += 1

        if self.current_step >= self.max_steps:
            self.done = True

        return self._get_observation(), reward, self.done, {}

################################################################################
# 2) LSTM-Based Q-Network
################################################################################
class LSTMQNetwork(nn.Module):
    """LSTM over a univariate sequence followed by a linear Q-value head."""

    def __init__(self, window_size=30, hidden_dim=64, num_layers=2, num_actions=3):
        """Build the LSTM and the output layer.

        Args:
            window_size: Accepted for interface compatibility; not used to
                size any layer.
            hidden_dim: LSTM hidden size.
            num_layers: Number of stacked LSTM layers.
            num_actions: Number of Q-values produced per input sequence.
        """
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=1,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            # Inter-layer dropout only exists between stacked layers; passing a
            # non-zero value with a single layer makes nn.LSTM emit a warning.
            dropout=0.1 if num_layers > 1 else 0.0,
        )
        self.fc = nn.Linear(hidden_dim, num_actions)

    def forward(self, x):
        """Return Q-values of shape ``(batch, num_actions)``.

        Accepts ``(seq,)``, ``(batch, seq)`` or ``(batch, seq, 1)`` tensors; the
        missing batch and feature axes are added as needed.
        """
        if x.dim() == 1:
            # A single un-batched window becomes a batch of one.
            x = x.unsqueeze(0)
        if x.dim() == 2:
            x = x.unsqueeze(-1)
        lstm_out, _ = self.lstm(x)
        # Use the top layer's output at the final time step as the summary.
        last_hidden = lstm_out[:, -1, :]
        q_vals = self.fc(last_hidden)
        return q_vals

################################################################################
# 3) Replay Buffer
################################################################################
class ReplayBuffer:
    """Bounded FIFO store of ``(state, action, reward, next_state, done)`` tuples."""

    def __init__(self, capacity=10000):
        self.capacity = capacity
        # A deque with maxlen drops the oldest transition once it is full.
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Append one transition to the buffer."""
        transition = (state, action, reward, next_state, done)
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions and return them column-wise.

        Returns:
            Five tuples: states, actions, rewards, next_states, dones.
        """
        drawn = random.sample(self.buffer, batch_size)
        columns = zip(*drawn)
        states, actions, rewards, next_states, dones = columns
        return states, actions, rewards, next_states, dones

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)

################################################################################
# 4) Utility Functions
################################################################################
def get_epsilon(it, max_it, min_epsilon=0.01, max_epsilon=1.0):
    """Linearly anneal epsilon from ``max_epsilon`` down to ``min_epsilon``.

    Args:
        it: Current iteration count.
        max_it: Iteration at which the schedule reaches ``min_epsilon``.
        min_epsilon: Floor value; returned once the line drops below it.
        max_epsilon: Value at ``it == 0``.

    Returns:
        The epsilon for iteration ``it``, never below ``min_epsilon``.
    """
    if max_it <= 0:
        # A schedule with no length is already fully decayed; this also avoids
        # dividing by zero below.
        return min_epsilon
    slope = -(max_epsilon - min_epsilon) / max_it
    return max(min_epsilon, max_epsilon + slope * it)

def process_state(state):
    """Convert one observation into a batched tensor on the module-level ``device``.

    Args:
        state: A tensor, ndarray, or other numeric sequence accepted by
            ``torch.tensor``.

    Returns:
        The observation on ``device`` with a leading batch axis of size 1.
        Non-tensor inputs are converted to float32; tensors keep their dtype.
    """
    if not torch.is_tensor(state):
        # torch.tensor copies, so the result never aliases the caller's array.
        state = torch.tensor(state, dtype=torch.float32)
    return state.to(device).unsqueeze(0)

################################################################################
# 5) Training Loop
################################################################################
def _dqn_update(q_net, optimizer, loss_fn, batch, gamma):
    """Run one gradient step of the Q-learning regression on a sampled batch."""
    states_b, actions_b, rewards_b, next_states_b, dones_b = batch

    states_b_t = torch.cat([process_state(s) for s in states_b])
    next_states_b_t = torch.cat([process_state(ns) for ns in next_states_b])
    actions_b_t = torch.tensor(actions_b, dtype=torch.long, device=device)
    rewards_b_t = torch.tensor(rewards_b, dtype=torch.float32, device=device)
    dones_b_t = torch.tensor(dones_b, dtype=torch.float32, device=device)

    # Q(s, a) for the actions that were actually taken.
    q_values_b = q_net(states_b_t)
    q_values_chosen = q_values_b.gather(1, actions_b_t.unsqueeze(1)).squeeze(1)

    # Bootstrapped target from the same network; terminal transitions
    # (done == 1) contribute only their immediate reward.
    with torch.no_grad():
        q_next_max = q_net(next_states_b_t).max(dim=1)[0]
        q_target = rewards_b_t + gamma * q_next_max * (1 - dones_b_t)

    loss = loss_fn(q_values_chosen, q_target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


def train_dqn(env, num_episodes=100, window_size=30, gamma=0.99, lr=1e-3, batch_size=32, max_steps_per_episode=1000):
    """Train an ``LSTMQNetwork`` on ``env`` with epsilon-greedy DQN and replay.

    Args:
        env: Environment exposing ``reset()`` and ``step(action)`` that returns
            ``(next_state, reward, done, info)``.
        num_episodes: Number of training episodes.
        window_size: Forwarded to ``LSTMQNetwork``.
        gamma: Discount factor used in the bootstrapped target.
        lr: Adam learning rate.
        batch_size: Number of transitions sampled per gradient step.
        max_steps_per_episode: Cap on environment steps per episode; together
            with ``num_episodes`` it sets the length of the epsilon schedule.

    Returns:
        Tuple ``(q_net, episode_rewards)`` with the trained network and the
        summed reward of each episode.
    """
    q_net = LSTMQNetwork(window_size=window_size, hidden_dim=64, num_layers=2, num_actions=3).to(device)
    optimizer = optim.Adam(q_net.parameters(), lr=lr)
    # Built once; the loss module is stateless so there is no need to
    # re-create it on every update.
    loss_fn = nn.MSELoss()
    replay_buffer = ReplayBuffer(capacity=10000)
    episode_rewards = []

    max_iterations = num_episodes * max_steps_per_episode
    iteration = 0

    for episode in range(num_episodes):
        state = env.reset()
        episode_reward = 0.0

        for step in range(max_steps_per_episode):
            iteration += 1
            epsilon = get_epsilon(iteration, max_iterations)

            # Epsilon-greedy action selection.
            if random.random() < epsilon:
                action = random.choice([0, 1, 2])
            else:
                s_t = process_state(state)
                with torch.no_grad():
                    q_values = q_net(s_t)
                    action = q_values.argmax(dim=1).item()

            next_state, reward, done, _ = env.step(action)
            replay_buffer.push(state, action, reward, next_state, done)
            state = next_state
            episode_reward += reward

            # Start learning as soon as one full batch can be sampled.
            if len(replay_buffer) >= batch_size:
                batch = replay_buffer.sample(batch_size)
                _dqn_update(q_net, optimizer, loss_fn, batch, gamma)

            if done:
                break

        episode_rewards.append(episode_reward)
        print(f"Episode {episode+1}/{num_episodes}, Reward: {episode_reward:.2f}, Eps: {epsilon:.3f}")

    return q_net, episode_rewards

################################################################################
# 6) Test Trained Agent
################################################################################
def run_trained_agent(env, q_net, max_steps=1000):
    """Roll out the greedy policy of ``q_net`` on ``env`` and sum the rewards.

    Args:
        env: Environment exposing ``reset()`` and ``step(action)`` that returns
            ``(next_state, reward, done, info)``.
        q_net: Network mapping a batched state to per-action Q-values.
        max_steps: Hard cap on the number of steps taken, so the rollout ends
            even if the environment never reports ``done``.

    Returns:
        The total reward collected during the rollout.
    """
    state = env.reset()
    total_reward = 0.0
    done = False
    steps_taken = 0

    # Evaluate in eval mode so dropout does not perturb the greedy policy,
    # then restore whatever mode the caller had set.
    was_training = q_net.training
    q_net.eval()
    try:
        while not done and steps_taken < max_steps:
            s_t = process_state(state)
            with torch.no_grad():
                q_values = q_net(s_t)
                action = q_values.argmax(dim=1).item()

            next_state, reward, done, _ = env.step(action)
            total_reward += reward
            state = next_state
            steps_taken += 1
    finally:
        q_net.train(was_training)

    return total_reward

################################################################################
# 7) Main Function
################################################################################
if __name__ == "__main__":
    def generate_synthetic_prices(T=3000, s0=100, mu=0.0005, sigma=0.01):
        """Simulate a geometric-Brownian-motion price path of length T (float32)."""
        drift = mu - 0.5 * sigma**2
        path = [s0]
        for _ in range(T - 1):
            shock = sigma * random.gauss(0, 1)
            path.append(path[-1] * math.exp(drift + shock))
        return np.array(path, dtype=np.float32)

    # Build the synthetic market and the environment that wraps it.
    prices_array = generate_synthetic_prices(T=3000, s0=100)
    window_size = 30
    env = StockTradingEnv(
        prices_array,
        window_size=window_size,
        initial_capital=10000,
        max_steps=1000,
    )

    # Train the agent, then evaluate the greedy policy on the same environment.
    trained_qnet, rewards_history = train_dqn(
        env,
        num_episodes=50,
        window_size=window_size,
        gamma=0.99,
        lr=1e-4,
        batch_size=32,
        max_steps_per_episode=1000,
    )
    print("\nTraining complete!\n")

    test_reward = run_trained_agent(env, trained_qnet)
    print(f"Test reward with trained LSTM policy: {test_reward:.2f}")


Using device: mps
Episode 1/50, Reward: -14.39, Eps: 0.980
Episode 2/50, Reward: -15.27, Eps: 0.960
Episode 3/50, Reward: -2.62, Eps: 0.941
Episode 4/50, Reward: 28.90, Eps: 0.921
Episode 5/50, Reward: 78.06, Eps: 0.901
Episode 6/50, Reward: 27.74, Eps: 0.881
Episode 7/50, Reward: 31.92, Eps: 0.861
Episode 8/50, Reward: -31.73, Eps: 0.842
Episode 9/50, Reward: -31.07, Eps: 0.822
Episode 10/50, Reward: 33.55, Eps: 0.802
Episode 11/50, Reward: -1.77, Eps: 0.782
Episode 12/50, Reward: 46.01, Eps: 0.762
Episode 13/50, Reward: -19.34, Eps: 0.743
Episode 14/50, Reward: -11.05, Eps: 0.723
Episode 15/50, Reward: -37.24, Eps: 0.703
Episode 16/50, Reward: -3.09, Eps: 0.683
Episode 17/50, Reward: 25.56, Eps: 0.663
Episode 18/50, Reward: -16.61, Eps: 0.644
Episode 19/50, Reward: -7.14, Eps: 0.624
Episode 20/50, Reward: -14.82, Eps: 0.604
Episode 21/50, Reward: -24.89, Eps: 0.584
Episode 22/50, Reward: 22.75, Eps: 0.564
Episode 23/50, Reward: -6.66, Eps: 0.545
Episode 24/50, Reward: 20.56, Eps: 0.5