In [1]:
import math
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque

# Pick the best available accelerator: Apple MPS first, then CUDA, else CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

################################################################################
# 1) A simple StockTrading environment (single-stock, discrete action).
################################################################################
class StockTradingEnv:
    """
    A simplified environment for stock trading based on a single price time series.
    
    State:  A window of the most recent price data (e.g., last N prices).
    Action: Discrete {0,1,2} corresponding to [SELL, HOLD, BUY].
    Reward: Change in (unrealized/realized) PnL whenever we switch positions or close out.
    
    This example does not include transaction costs, slippage, or risk controls.
    """
    def __init__(self, prices, window_size=30, initial_capital=10000, max_steps=1000):
        """
        :param prices:        (np.ndarray) 1D array of daily (or t-step) prices.
        :param window_size:   (int) number of past observations for state.
        :param initial_capital: (float) initial capital for PnL calculations (not fully used here).
        :param max_steps:     (int) maximum number of steps in an episode.
        """
        self.prices = prices
        self.window_size = window_size
        self.initial_capital = initial_capital
        self.max_steps = min(max_steps, len(prices) - window_size - 1)
        
        # Internal states
        self.current_step = None
        self.done = None
        self.position = None  # +1 long, 0 flat, -1 short
        self.capital = None
        self.last_price = None

        self.reset()

    def reset(self):
        """
        Resets the environment to a starting state.
        """
        self.current_step = 0
        self.done = False
        self.position = 0
        self.capital = self.initial_capital
        self.last_price = self.prices[self.window_size - 1]
        return self._get_observation()

    def _get_observation(self):
        """
        Returns the current state: the last `window_size` prices.
        Shape: (window_size,)
        """
        start = self.current_step
        end = self.current_step + self.window_size
        window_prices = self.prices[start:end]
        return window_prices

    def step(self, action):
        """
        Executes the chosen action.
        :param action: int in {0, 1, 2}, mapped to [-1, 0, +1].
        :return: (next_state, reward, done, info)
        """
        if self.done:
            return self._get_observation(), 0.0, True, {}

        # Map discrete action into [-1, 0, +1]
        if action == 0:
            new_position = -1
        elif action == 1:
            new_position = 0
        else:
            new_position = 1

        current_price = self.prices[self.current_step + self.window_size - 1]
        reward = 0.0

        # If we had a position, realize PnL from last_price to current_price
        if self.position != 0:
            reward += (current_price - self.last_price) * self.position

        # Update position
        self.position = new_position
        self.last_price = current_price

        # Move forward
        self.current_step += 1
        if self.current_step >= self.max_steps:
            self.done = True

        # If done, close out any position at the final price
        if self.done and self.position != 0:
            final_price = self.prices[self.current_step + self.window_size - 1]
            reward += (final_price - self.last_price) * self.position

        # Prepare for next observation
        next_state = self._get_observation()
        return next_state, reward, self.done, {}


################################################################################
# 2) CNN Q-Network (1D).
################################################################################
class CNNQNetwork(nn.Module):
    """
    A 1D CNN that outputs Q-values for each of `num_actions` actions (default 3: SELL, HOLD, BUY).
    Input shape: (batch_size, 1, window_size)
    Output shape: (batch_size, num_actions)

    :raises ValueError: if `window_size` is too short to survive both convolutions.
    """
    def __init__(self, window_size=30, num_actions=3):
        super().__init__()
        kernel_size = 5
        # Each unpadded, stride-1 convolution shortens the sequence by kernel_size - 1.
        conv_out_len = window_size - 2 * (kernel_size - 1)
        if conv_out_len < 1:
            raise ValueError(
                f"window_size must be at least {2 * (kernel_size - 1) + 1}, got {window_size}"
            )
        self.conv1 = nn.Conv1d(in_channels=1, out_channels=16, kernel_size=kernel_size, stride=1)
        self.conv2 = nn.Conv1d(in_channels=16, out_channels=32, kernel_size=kernel_size, stride=1)
        # Flattened conv features: 32 channels * remaining sequence length.
        self.fc1 = nn.Linear(32 * conv_out_len, 64)
        self.fc2 = nn.Linear(64, num_actions)

    def forward(self, x):
        # x: (batch_size, 1, window_size)
        x = torch.relu(self.conv1(x))        # => (batch_size, 16, window_size-4)
        x = torch.relu(self.conv2(x))        # => (batch_size, 32, window_size-8)
        x = torch.flatten(x, start_dim=1)    # => (batch_size, 32*(window_size-8))
        x = torch.relu(self.fc1(x))
        return self.fc2(x)                   # => (batch_size, num_actions)

################################################################################
# 3) Replay Buffer for DQN.
################################################################################
class ReplayBuffer:
    """Fixed-capacity FIFO store of (state, action, reward, next_state, done) transitions."""

    def __init__(self, capacity=10000):
        self.capacity = capacity
        # A bounded deque discards the oldest transition once `capacity` is reached.
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Store one transition, evicting the oldest one when full."""
        transition = (state, action, reward, next_state, done)
        self.buffer.append(transition)

    def sample(self, batch_size):
        """
        Draw `batch_size` distinct stored transitions uniformly at random.

        Returns five tuples (states, actions, rewards, next_states, dones), each of
        length `batch_size`; random.sample raises ValueError if fewer are stored.
        """
        minibatch = random.sample(self.buffer, batch_size)
        (states, actions, rewards,
         next_states, dones) = zip(*minibatch)
        return states, actions, rewards, next_states, dones

    def __len__(self):
        return len(self.buffer)

################################################################################
# 4) Utility functions for training.
################################################################################
def get_epsilon(it, max_it, min_epsilon=0.01, max_epsilon=1.0):
    """
    Linearly decay epsilon from max_epsilon to min_epsilon over max_it iterations.

    :param it:          (int) current iteration; 0 yields max_epsilon.
    :param max_it:      (int) iteration at which min_epsilon is reached. A value <= 0
                        means there is no decay window, and min_epsilon is returned.
    :param min_epsilon: (float) floor for epsilon.
    :param max_epsilon: (float) starting epsilon.
    :return: (float) epsilon, never below min_epsilon.
    """
    if max_it <= 0:
        # Avoid ZeroDivisionError (and a sign-flipped slope for negative max_it).
        return min_epsilon
    slope = -(max_epsilon - min_epsilon) / max_it
    return max(min_epsilon, max_epsilon + slope * it)

def process_state_cnn(state):
    """
    Wrap one observation as a single-sample, single-channel batch for the CNN.

    A length-W observation becomes a float32 CPU tensor of shape (1, 1, W): two
    singleton dimensions are prepended. Callers move it to their device.
    """
    as_float = torch.FloatTensor(state)
    return as_float[None, None]

################################################################################
# 5) The DQN training loop.
################################################################################
def train_dqn(env, num_episodes=100, window_size=30, gamma=0.99,
              lr=1e-3, batch_size=32, max_steps_per_episode=1000,
              target_update_interval=1000, device=None):
    """
    Train a CNN Q-network on `env` with DQN (experience replay + target network).

    :param env:                    environment exposing reset() and
                                   step(action) -> (next_state, reward, done, info).
    :param num_episodes:           (int) number of training episodes.
    :param window_size:            (int) observation length; must match the env's windows.
    :param gamma:                  (float) discount factor.
    :param lr:                     (float) Adam learning rate.
    :param batch_size:             (int) replay minibatch size; updates start once the
                                   buffer holds more than this many transitions.
    :param max_steps_per_episode:  (int) step cap per episode; epsilon decays linearly
                                   over num_episodes * max_steps_per_episode steps.
    :param target_update_interval: (int or None) copy the online weights into the target
                                   network every this many environment steps. None or
                                   <= 0 bootstraps from the online network itself.
    :param device:                 (torch.device, str or None) training device; defaults
                                   to CUDA when available, else CPU.
    :return: (q_net, episode_rewards) - the trained online network and a list of
             per-episode total rewards.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(device)

    q_net = CNNQNetwork(window_size=window_size, num_actions=3).to(device)
    optimizer = optim.Adam(q_net.parameters(), lr=lr)
    loss_fn = nn.MSELoss()

    # A periodically-synced target network keeps the TD target from chasing the
    # network being updated, which is the standard DQN stabilizer.
    use_target = target_update_interval is not None and target_update_interval > 0
    if use_target:
        target_net = CNNQNetwork(window_size=window_size, num_actions=3).to(device)
        target_net.load_state_dict(q_net.state_dict())
        target_net.eval()
    else:
        target_net = q_net

    replay_buffer = ReplayBuffer(capacity=10000)
    episode_rewards = []

    max_iterations = num_episodes * max_steps_per_episode
    iteration = 0
    epsilon = 1.0  # defined up front so the episode log works even for 0-step episodes

    for episode in range(num_episodes):
        state = env.reset()
        episode_reward = 0.0

        for _step in range(max_steps_per_episode):
            iteration += 1
            epsilon = get_epsilon(iteration, max_iterations)

            # Epsilon-greedy action selection
            if random.random() < epsilon:
                action = random.choice([0, 1, 2])  # SELL=0, HOLD=1, BUY=2
            else:
                with torch.no_grad():
                    q_values = q_net(process_state_cnn(state).to(device))  # (1, 3)
                action = q_values.argmax(dim=1).item()

            next_state, reward, done, _info = env.step(action)

            replay_buffer.push(state, action, reward, next_state, done)
            state = next_state
            episode_reward += reward

            # Sample from replay and train
            if len(replay_buffer) > batch_size:
                _dqn_update(q_net, target_net, optimizer, loss_fn,
                            replay_buffer.sample(batch_size), gamma, device)

            if use_target and iteration % target_update_interval == 0:
                target_net.load_state_dict(q_net.state_dict())

            if done:
                break

        episode_rewards.append(episode_reward)
        print(f"Episode {episode+1}/{num_episodes}, Reward: {episode_reward:.2f}, Epsilon: {epsilon:.3f}")

    return q_net, episode_rewards


def _dqn_update(q_net, target_net, optimizer, loss_fn, batch, gamma, device):
    """Take one optimizer step on the one-step TD loss for a replay minibatch."""
    states, actions, rewards, next_states, dones = batch

    # Stack observations once into (batch_size, 1, window_size) float32 tensors.
    states_t = torch.as_tensor(np.stack(states), dtype=torch.float32, device=device).unsqueeze(1)
    next_states_t = torch.as_tensor(np.stack(next_states), dtype=torch.float32, device=device).unsqueeze(1)
    actions_t = torch.as_tensor(actions, dtype=torch.long, device=device)
    rewards_t = torch.as_tensor(rewards, dtype=torch.float32, device=device)
    dones_t = torch.as_tensor(dones, dtype=torch.float32, device=device)

    # current Q(s, a)
    q_chosen = q_net(states_t).gather(1, actions_t.unsqueeze(1)).squeeze(1)

    # target Q = r + gamma * max_a' Q_target(s', a'), with no bootstrap past terminal states
    with torch.no_grad():
        q_next_max = target_net(next_states_t).max(dim=1).values
        q_target = rewards_t + gamma * q_next_max * (1.0 - dones_t)

    loss = loss_fn(q_chosen, q_target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

################################################################################
# 6) Running a trained agent (for inference).
################################################################################
def run_trained_agent(env, q_net, max_steps=1000):
    """
    Run a trained agent (q_net) greedily on env (no exploration) for one episode.

    :param env:       environment exposing reset() and
                      step(action) -> (next_state, reward, done, info).
    :param q_net:     Q-network mapping a (1, 1, window_size) tensor to (1, num_actions).
    :param max_steps: (int) hard cap on the number of steps taken.
    :return: (float) total reward collected.
    """
    # Send inputs to wherever the network's weights live. Picking a device on its
    # own (e.g. "cuda if available") crashes for a net trained/loaded on another device.
    first_param = next(q_net.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    state = env.reset()
    total_reward = 0.0
    done = False
    steps = 0

    while not done and steps < max_steps:
        with torch.no_grad():
            q_values = q_net(process_state_cnn(state).to(device))  # (1, 3)
        action = q_values.argmax(dim=1).item()

        state, reward, done, _ = env.step(action)
        total_reward += reward
        steps += 1

    return total_reward

################################################################################
# 7) Main - put it all together.
################################################################################
if __name__ == "__main__":
    # Seed the RNGs the script draws from so runs are reproducible.
    SEED = 0
    random.seed(SEED)
    torch.manual_seed(SEED)

    # 1) Generate synthetic prices (Geometric Brownian Motion) as an example
    def generate_synthetic_prices(T=2000, s0=100, mu=0.0005, sigma=0.01):
        """Simulate T GBM prices starting at s0 with unit time steps (float32 array)."""
        prices = [s0]
        for _ in range(1, T):
            prices.append(
                prices[-1] * math.exp((mu - 0.5 * sigma**2) + sigma * random.gauss(0, 1))
            )
        return np.array(prices, dtype=np.float32)

    prices_array = generate_synthetic_prices(T=3000, s0=100)

    # 2) Create train / held-out test environments. Evaluating on the training
    #    series would only measure in-sample fit. The test env starts one window
    #    before the split so its first observation uses only past prices and its
    #    first reward is the first unseen price move.
    window_size = 30
    split = 2000
    train_env = StockTradingEnv(prices_array[:split], window_size=window_size,
                                initial_capital=10000, max_steps=1000)
    test_env = StockTradingEnv(prices_array[split - window_size:], window_size=window_size,
                               initial_capital=10000, max_steps=1000)

    # 3) Train DQN
    trained_qnet, rewards_history = train_dqn(train_env,
                                              num_episodes=50,
                                              window_size=window_size,
                                              gamma=0.99,
                                              lr=1e-3,
                                              batch_size=32,
                                              max_steps_per_episode=1000)
    print("\nTraining complete!\n")

    # 4) Test / run the trained agent on prices it never trained on
    test_reward = run_trained_agent(test_env, trained_qnet)
    print(f"Test reward with the trained policy: {test_reward:.2f}")


Episode 1/50, Reward: 32.52, Epsilon: 0.980
Episode 2/50, Reward: 21.32, Epsilon: 0.960
Episode 3/50, Reward: 16.96, Epsilon: 0.941
Episode 4/50, Reward: -4.50, Epsilon: 0.921
Episode 5/50, Reward: 92.08, Epsilon: 0.901
Episode 6/50, Reward: 51.25, Epsilon: 0.881
Episode 7/50, Reward: 2.92, Epsilon: 0.861
Episode 8/50, Reward: 48.86, Epsilon: 0.842
Episode 9/50, Reward: 2.17, Epsilon: 0.822
Episode 10/50, Reward: 17.43, Epsilon: 0.802
Episode 11/50, Reward: -69.11, Epsilon: 0.782
Episode 12/50, Reward: -4.97, Epsilon: 0.762
Episode 13/50, Reward: -50.35, Epsilon: 0.743
Episode 14/50, Reward: -25.98, Epsilon: 0.723
Episode 15/50, Reward: -24.96, Epsilon: 0.703
Episode 16/50, Reward: -55.85, Epsilon: 0.683
Episode 17/50, Reward: 30.28, Epsilon: 0.663
Episode 18/50, Reward: -54.07, Epsilon: 0.644
Episode 19/50, Reward: 8.82, Epsilon: 0.624
Episode 20/50, Reward: -51.01, Epsilon: 0.604
Episode 21/50, Reward: 41.60, Epsilon: 0.584
Episode 22/50, Reward: 90.74, Epsilon: 0.564
Episode 23/50, 