# Deep $Q$-learning

In [10]:
import random
from collections import deque
from copy import deepcopy

import gymnasium as gym
import torch
from torch import nn, optim

# Seed both random number generators used in this notebook (PyTorch and
# Python's `random` module) with the same value for reproducibility.
torch.manual_seed(75)
random.seed(75)

In [None]:
class DeepQNetwork(nn.Module):
    """Fully connected network that maps a state vector to one value per action.

    The architecture is ``Linear -> activation`` for every entry of
    ``hidden_dims`` followed by a final ``Linear`` layer with no activation.

    Args:
        input_dim: Number of input features.
        hidden_dims: Width of each hidden layer, in order. Must be non-empty.
        output_dim: Number of outputs.
        activation_fn: Zero-argument factory (typically an ``nn.Module``
            subclass such as ``nn.ReLU``) that is called once per hidden
            layer to create a fresh activation module.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dims: list[int],
        output_dim: int,
        activation_fn: type[nn.Module],
    ):
        super().__init__()
        assert len(hidden_dims) > 0, "Must have at least one hidden layer."

        self.layers = nn.ModuleList()

        # Pair consecutive sizes so that *every* hidden layer is created,
        # including the last one: input -> h0 -> h1 -> ... -> h[-1].
        layer_dims = [input_dim, *hidden_dims]
        for in_features, out_features in zip(layer_dims, layer_dims[1:]):
            self.layers.append(
                nn.Linear(in_features=in_features, out_features=out_features)
            )
            self.layers.append(activation_fn())

        # Output layer: no activation, so outputs are unbounded.
        self.layers.append(
            nn.Linear(in_features=hidden_dims[-1], out_features=output_dim)
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply all layers in order and return the final layer's output."""
        output = input
        for layer in self.layers:
            output = layer(output)

        return output


class DQNTrainer:
    """Deep Q-learning trainer with a replay buffer and a target network.

    Transitions are stored in ``replay_buffer`` as
    ``(state, action, reward, next_state, terminated)`` tuples.

    Args:
        qnet: Online Q-network; given a batch of states it must return one
            Q-value per action for each state.
        num_actions: Number of discrete actions.
        buffer_size: Maximum number of transitions kept in the replay buffer.
        batch_size: Number of transitions sampled for each gradient step.
        gamma: Discount factor.
        epsilon_decay: Multiplicative factor applied to ``epsilon`` after
            every episode.
        epsilon_min: Lower bound for ``epsilon``.
        network_update_interval: Number of gradient steps between copies of
            the online weights into the target network.
        learning_rate: Learning rate of the Adam optimizer.
    """

    def __init__(
        self,
        qnet: nn.Module,
        num_actions: int,
        buffer_size: int,
        batch_size: int,
        gamma: float,
        epsilon_decay: float,
        epsilon_min: float,
        network_update_interval: int,
        learning_rate: float
    ):
        # initialize replay buffer/memory
        self.replay_buffer = deque([], maxlen=buffer_size)
        self.batch_size = batch_size
        self.qnet = qnet
        # The target network starts as an independent copy of the online one.
        self.target_qnet = deepcopy(qnet)
        self.optimizer = optim.Adam(self.qnet.parameters(), lr=learning_rate)

        self.gamma = gamma
        self.epsilon = 1.0  # start fully exploratory
        self.epsilon_decay = epsilon_decay
        self.epsilon_min = epsilon_min
        self.num_actions = num_actions
        self.network_update_interval = network_update_interval

        self.loss_func = nn.MSELoss()

    @staticmethod
    def _as_state_batch(states) -> torch.Tensor:
        """Stack a list of array-like states into one float32 batch tensor."""
        return torch.stack(
            [torch.as_tensor(state, dtype=torch.float32) for state in states]
        )

    def select_action(self, state: torch.Tensor) -> int:
        """Pick an action epsilon-greedily.

        Args:
            state: Current state. Tensors and other array-likes accepted by
                ``torch.as_tensor`` are converted to float32.

        Returns:
            The chosen action index as a plain ``int``.
        """
        if self.epsilon > random.uniform(0, 1):
            # select random action from [0, 1, ..., num_actions - 1]
            return random.randrange(self.num_actions)

        # Acting needs no gradients.
        with torch.no_grad():
            q_values = self.qnet(torch.as_tensor(state, dtype=torch.float32))
        return int(torch.argmax(q_values).item())

    def compute_loss(self, minibatch: "list[tuple] | None" = None) -> torch.Tensor:
        """Return the loss between predicted Q(s, a) and the TD targets.

        Args:
            minibatch: Transitions to evaluate. When omitted, up to
                ``batch_size`` transitions are sampled uniformly without
                replacement from the replay buffer.

        Returns:
            Scalar loss tensor attached to the online network's graph.

        Raises:
            ValueError: If there are no transitions to compute a loss on.
        """
        if minibatch is None:
            sample_size = min(self.batch_size, len(self.replay_buffer))
            minibatch = random.sample(self.replay_buffer, sample_size)
        minibatch = list(minibatch)
        if not minibatch:
            raise ValueError("Cannot compute a loss on an empty minibatch.")

        states = self._as_state_batch([transition[0] for transition in minibatch])
        actions = torch.tensor(
            [int(transition[1]) for transition in minibatch], dtype=torch.int64
        )

        # Keep only the Q-value of the action that was actually taken.
        predicted = self.qnet(states).gather(1, actions.unsqueeze(1)).squeeze(1)
        targets = self.get_labels(minibatch)
        return self.loss_func(predicted, targets)

    def get_labels(self, minibatch: "list[tuple] | None" = None) -> torch.Tensor:
        """Compute the TD targets for a set of transitions.

        The target is ``reward`` when the next state is terminal and
        ``reward + gamma * max_a target_qnet(next_state)[a]`` otherwise.

        Args:
            minibatch: Transitions to label. When omitted, every transition
                currently in the replay buffer is labelled.

        Returns:
            1-D float32 tensor with one target per transition; it carries no
            gradient.
        """
        if minibatch is None:
            transitions = list(self.replay_buffer)
        else:
            transitions = list(minibatch)
        if not transitions:
            return torch.empty(0, dtype=torch.float32)

        rewards = torch.tensor(
            [float(transition[2]) for transition in transitions],
            dtype=torch.float32,
        )
        next_states = self._as_state_batch(
            [transition[3] for transition in transitions]
        )
        is_terminal = torch.tensor(
            [bool(transition[4]) for transition in transitions], dtype=torch.bool
        )

        # Targets are constants with respect to the online network.
        with torch.no_grad():
            max_next_q = self.target_qnet(next_states).max(dim=1).values

        # `where` selects instead of multiplying by a mask, so a non-finite
        # bootstrap value for a terminal transition cannot leak into its label.
        return torch.where(is_terminal, rewards, rewards + self.gamma * max_next_q)

    def train(self, num_episodes: int, max_timesteps: int):
        """Run deep Q-learning on the ``LunarLander-v3`` environment.

        Args:
            num_episodes: Number of episodes to run.
            max_timesteps: Maximum number of environment steps per episode.
        """
        env = gym.make("LunarLander-v3", render_mode=None)
        gradient_steps = 0

        try:
            for _episode in range(num_episodes):
                # reset env at start of episode and get starting state
                state, _info = env.reset()

                for _t in range(max_timesteps):
                    action = self.select_action(state)

                    new_state, reward, terminated, truncated, _info = env.step(action)
                    # Only `terminated` is stored: a truncated episode did not
                    # reach a terminal state, so its target still bootstraps.
                    self.replay_buffer.append(
                        (state, action, reward, new_state, terminated)
                    )

                    if len(self.replay_buffer) >= self.batch_size:
                        # sample uniformly without replacement
                        minibatch = random.sample(self.replay_buffer, self.batch_size)

                        self.optimizer.zero_grad()
                        loss = self.compute_loss(minibatch)
                        loss.backward()
                        self.optimizer.step()
                        gradient_steps += 1

                        # Count gradient steps across episodes so the interval
                        # is not restarted at the beginning of each episode.
                        if gradient_steps % self.network_update_interval == 0:
                            self.target_qnet.load_state_dict(
                                self.qnet.state_dict()
                            )  # \hat{Q} = Q

                    if terminated or truncated:
                        # The environment must be reset before stepping again.
                        break

                    state = new_state

                # Shift from exploration towards exploitation over episodes.
                self.epsilon = max(
                    self.epsilon_min, self.epsilon * self.epsilon_decay
                )
        finally:
            env.close()

In [None]:
# Build a network with five hidden layer sizes, naming every constructor argument.
qnet = DeepQNetwork(
    input_dim=200,
    hidden_dims=[100, 50, 25, 10, 5],
    output_dim=2,
    activation_fn=nn.ReLU,
)


DeepQNetwork(
  (layers): ModuleList(
    (0): Linear(in_features=200, out_features=100, bias=True)
    (1): Tanh()
    (2): Linear(in_features=100, out_features=50, bias=True)
    (3): Tanh()
    (4): Linear(in_features=50, out_features=25, bias=True)
    (5): Tanh()
    (6): Linear(in_features=25, out_features=10, bias=True)
    (7): Tanh()
    (8): Linear(in_features=5, out_features=2, bias=True)
  )
)


In [7]:
# Create the LunarLander-v3 environment with rendering disabled.
env = gym.make("LunarLander-v3", render_mode=None)

In [9]:
# Check how many items `reset()` returns; the training loop unpacks it as
# `state, _ = env.reset()`.
len(env.reset())

2

In [6]:
# Smoke test: reset the environment, then take one randomly sampled action.
env.reset()
action = env.action_space.sample()

# step (transition) through the environment with the action
# receiving the next observation, reward and if the episode has terminated or truncated
observation, reward, terminated, truncated, info = env.step(action)

# Show the raw values returned by a single step.
print(observation, reward, terminated, truncated, info)

[-0.01465969  1.3903913  -0.73513913 -0.4691289   0.01479567  0.12495516
  0.          0.        ] 0.02433743160219251 False False {}
