In [67]:
import gymnasium as gym
import numpy as np
from gymnasium import spaces


class GridEnv(gym.Env):
    def __init__(self, size=5):
        super().__init__()
        self.size = size

        # corresponding to "right", "up", "left", "down"
        self.action_space = spaces.Discrete(4)
        # Each location is encoded as an element of {0, ..., `size`-1}^2
        self.observation_space = gym.spaces.Dict(
            {
                "agent": gym.spaces.Box(0, size - 1, shape=(2,), dtype=np.int64),
                "target": gym.spaces.Box(0, size - 1, shape=(2,), dtype=np.int64),
            }
        )

        self._agent_location = np.array([-1, -1], dtype=int)
        self._target_location = np.array([-1, -1], dtype=int)
        self._action_to_direction = {
            0: np.array([1, 0]),  # right
            1: np.array([0, 1]),  # up
            2: np.array([-1, 0]),  # left
            3: np.array([0, -1]),  # down
        }

    def reset(self, seed: int | None = None, options: dict | None = None):
        super().reset(seed=seed)

        # Choose the agent's location uniformly at random
        self._agent_location = self.np_random.integers(0, self.size, size=2, dtype=int)

        # We will sample the target's location randomly until it does not coincide with the agent's location
        self._target_location = self._agent_location
        while np.array_equal(self._target_location, self._agent_location):
            self._target_location = self.np_random.integers(
                0, self.size, size=2, dtype=int
            )

        observation = self._get_obs()
        info = self._get_info()

        return observation, info

    def step(self, action: int):
        # Map the action (element of {0,1,2,3}) to the direction we walk in
        direction = self._action_to_direction[action]
        # We use `np.clip` to make sure we don't leave the grid bounds
        self._agent_location = np.clip(
            self._agent_location + direction, 0, self.size - 1
        )

        # An environment is completed if and only if the agent has reached the target
        terminated = np.array_equal(self._agent_location, self._target_location)
        truncated = False
        reward = (
            1 if terminated else 0
        )  # the agent is only reached at the end of the episode
        observation = self._get_obs()
        info = self._get_info()

        return observation, reward, terminated, truncated, info

    def _get_obs(self):
        return {"agent": self._agent_location, "target": self._target_location}

    def _get_info(self):
        return {
            "distance": np.linalg.norm(
                self._agent_location - self._target_location, ord=1
            )
        }

In [68]:
import torch
import numpy as np
from torch import nn
import random
import torch.nn.functional as F
import collections
from torch.optim.lr_scheduler import StepLR

In [69]:
"""
The Q-Network has as input a state s and outputs the state-action values q(s,a_1), ..., q(s,a_n) for all n actions.
"""


class QNetwork(nn.Module):
    def __init__(self, action_dim, state_dim, hidden_dim):
        super(QNetwork, self).__init__()

        self.fc_1 = nn.Linear(state_dim, hidden_dim)
        self.fc_2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc_3 = nn.Linear(hidden_dim, action_dim)

    def forward(self, inp):

        x1 = F.leaky_relu(self.fc_1(inp))
        x1 = F.leaky_relu(self.fc_2(x1))
        x1 = self.fc_3(x1)

        return x1

In [70]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Define the structure for storing transitions
from collections import deque, namedtuple


Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward', 'done'))

# Define the Replay Memory class
class ReplayMemory(object):
    """ Stores transitions and allows sampling batches. """
    def __init__(self, capacity: int):
        """
        Initialize the Replay Memory.

        Parameters:
        - capacity (int): Maximum number of transitions to store.
        """
        self.memory = deque([], maxlen=capacity)

    def push(self, *args):
        """
        Save a transition.

        Parameters:
        - *args: The transition elements (state, action, next_state, reward, done).
        """
        self.memory.append(Transition(*args))

    def sample(self, batch_size: int) -> list[Transition]:
        """
        Sample a random batch of transitions from memory.

        Parameters:
        - batch_size (int): The number of transitions to sample.

        Returns:
        - List[Transition]: A list containing the sampled transitions.
        """
        return random.sample(self.memory, batch_size)

    def __len__(self) -> int:
        """ Return the current size of the memory. """
        return len(self.memory)

In [None]:
from typing import Literal


def optimize_model(memory: ReplayMemory,
                   policy_net: nn.Module,
                   target_net: nn.Module,
                   optimizer,
                   batch_size: int,
                   gamma: float,
                   criterion: nn.Module = nn.SmoothL1Loss(), 
                   method:Literal["dqn", "ddqn"]="ddqn") -> float|None:
    """
    Performs one step of optimization on the policy network.

    Parameters:
    - memory (ReplayMemory): The replay memory containing past transitions.
    - policy_net (nn.Module): The main Q-network being optimized.
    - target_net (nn.Module): The target Q-network used for stable target computation.
    - optimizer (optim.Optimizer): The optimizer for updating the policy network.
    - batch_size (int): The number of transitions to sample for each optimization step.
    - gamma (float): The discount factor for future rewards.
    - criterion (nn.Module): The loss function to use (default: SmoothL1Loss).

    Returns:
    - Optional[float]: The loss value for the optimization step, or None if not enough samples.
    """
    # Ensure there are enough samples in memory to perform optimization
    if len(memory) < batch_size:
        return None

    # Sample a batch of transitions from replay memory
    transitions = memory.sample(batch_size)
    batch = Transition(*zip(*transitions))  # Unpack transitions into separate components

    # Identify non-final states (states that are not terminal)
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, batch.next_state)),
                                  device=device, dtype=torch.bool)

    # Stack non-final next states into a tensor

    # Stack current states, actions, rewards, and dones into tensors
    state_batch = torch.stack(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)
    done_batch = torch.cat(batch.done)

    # Compute Q(s_t, a) for the actions taken
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    # Compute V(s_{t+1}) for the next states using the target network
    next_state_values = torch.zeros(batch_size, device=device)
    with torch.no_grad():
        if any(non_final_mask):  # Only compute for non-final states
            non_final_next_states = torch.stack([s for s in batch.next_state if s is not None])
            if method == "dqn":
                next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0]
            elif method == "ddqn":
                max_actions = policy_net(non_final_next_states).max(1)[1].unsqueeze(1)
                next_state_values[non_final_mask] = target_net(non_final_next_states).gather(1, max_actions).squeeze(1)


    # Compute the expected Q values using the Bellman equation
    expected_state_action_values = (next_state_values * gamma) + reward_batch

    # Compute the loss between predicted and expected Q values
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    # Perform backpropagation and optimization
    optimizer.zero_grad()  # Clear previous gradients
    loss.backward()  # Compute gradients
    torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)  # Clip gradients to prevent explosion
    optimizer.step()  # Update the policy network

    return loss.item()  # Return the loss value for logging

In [None]:
from tqdm import tqdm


def select_action(model, num_action, state, eps):
    state_tensor = torch.Tensor(state["agent"]).to(device)
    with torch.no_grad():
        values = model(state_tensor)

    # select a random action wih probability eps
    if random.random() <= eps:
        action = np.random.randint(0, num_action)
    else:
        action = np.argmax(values.cpu().numpy())

    return action


def train(batch_size, current, target, optim, memory: ReplayMemory, gamma):

    states, actions, next_states, rewards, is_done = memory.sample(batch_size)

    q_values = current(states)

    next_q_values = current(next_states)
    next_q_state_values = target(next_states)

    q_value = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
    next_q_value = next_q_state_values.gather(1, torch.max(next_q_values, 1)[1].unsqueeze(1)).squeeze(1)
    expected_q_value = rewards + gamma * next_q_value * (1 - is_done)

    loss = (q_value - expected_q_value.detach()).pow(2).mean()

    optim.zero_grad()
    loss.backward()
    optim.step()


def evaluate(Qmodel, env, repeats, max_step):
    """
    Runs a greedy policy with respect to the current Q-Network for "repeats" many episodes. Returns the average
    episode reward.
    """
    Qmodel.eval()
    perform = 0
    for _ in tqdm(range(repeats), "Evaluate"):
        state, info = env.reset()
        for _ in range(max_step):
            agent_state = torch.Tensor(state["agent"]).to(device)
            with torch.no_grad():
                values = Qmodel(agent_state)
            action = np.argmax(values.cpu().numpy())
            state, reward, term, trun, info = env.step(action)
            perform += reward
            
            if term or trun:
                break

    Qmodel.train()
    return perform/repeats


def update_parameters(current_model, target_model):
    target_model.load_state_dict(current_model.state_dict())



In [73]:
def main(
    gamma=0.99,
    lr=1e-3,
    min_episodes=20,
    eps=1,
    eps_decay=0.998,
    eps_min=0.01,
    update_step=10,
    batch_size=64,
    update_repeats=50,
    num_episodes=3000,
    seed=42,
    max_memory_size=5000,
    lr_gamma=1,
    lr_step=100,
    measure_step=100,
    measure_repeats=100,
    hidden_dim=64,
    cnn=False,
    horizon=100,
    render=True,
    render_step=50,
):
    """
    Remark: Convergence is slow. Wait until around episode 2500 to see good performance.

    :param gamma: reward discount factor
    :param lr: learning rate for the Q-Network
    :param min_episodes: we wait "min_episodes" many episodes in order to aggregate enough data before starting to train
    :param eps: probability to take a random action during training
    :param eps_decay: after every episode "eps" is multiplied by "eps_decay" to reduces exploration over time
    :param eps_min: minimal value of "eps"
    :param update_step: after "update_step" many episodes the Q-Network is trained "update_repeats" many times with a
    batch of size "batch_size" from the memory.
    :param batch_size: see above
    :param update_repeats: see above
    :param num_episodes: the number of episodes played in total
    :param seed: random seed for reproducibility
    :param max_memory_size: size of the replay memory
    :param lr_gamma: learning rate decay for the Q-Network
    :param lr_step: every "lr_step" episodes we decay the learning rate
    :param measure_step: every "measure_step" episode the performance is measured
    :param measure_repeats: the amount of episodes played in to asses performance
    :param hidden_dim: hidden dimensions for the Q_network
    :param env_name: name of the gym environment
    :param cnn: set to "True" when using environments with image observations like "Pong-v0"
    :param horizon: number of steps taken in the environment before terminating the episode (prevents very long episodes)
    :param render: if "True" renders the environment every "render_step" episodes
    :param render_step: see above
    :return: the trained Q-Network and the measured performances
    """
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    env = GridEnv()

    Q_1 = QNetwork(
        action_dim=env.action_space.n,
        state_dim=env.observation_space["agent"].shape[0],
        hidden_dim=hidden_dim,
    ).to(device)
    Q_2 = QNetwork(
        action_dim=env.action_space.n,
        state_dim=env.observation_space["agent"].shape[0],
        hidden_dim=hidden_dim,
    ).to(device)
    # transfer parameters from Q_1 to Q_2
    update_parameters(Q_1, Q_2)

    # we only train Q_1
    for param in Q_2.parameters():
        param.requires_grad = False

    optimizer = torch.optim.Adam(Q_1.parameters(), lr=lr)
    scheduler = StepLR(optimizer, step_size=lr_step, gamma=lr_gamma)
    print("Initialized network components")

    memory = Memory(max_memory_size)
    performance = []

    for episode in range(num_episodes):
        # display the performance
        if (episode % measure_step == 0) and episode >= min_episodes:
            performance.append([episode, evaluate(Q_1, env, measure_repeats, 100)])
            # performance.append([episode, 0])
            print("Episode: ", episode)
            print("rewards: ", performance[-1][1])
            print("lr: ", scheduler.get_last_lr()[0])
            print("eps: ", eps)

        state, info = env.reset()
        memory.state.append(state)

        done = False
        i = 0
        while not done:
            i += 1
            action = select_action(Q_2, env, state, eps)
            state, reward, terminated, truncated, info = env.step(action)
            done = terminated or truncated

            if i > horizon:
                done = True

            # render the environment if render == True
            if render and episode % render_step == 0:
                try:
                    env.render()
                except NotImplementedError:
                    pass

            # save state, action, reward sequence
            memory.update(state, action, reward, done)

        has_trained = False
        if episode >= min_episodes and episode % update_step == 0:
            for _ in range(update_repeats):
                train(batch_size, Q_1, Q_2, optimizer, memory, gamma)
                has_trained = True

            # transfer new parameter from Q_1 to Q_2
            update_parameters(Q_1, Q_2)

        # update learning rate and eps
        if has_trained:
            scheduler.step()
            eps = max(eps * eps_decay, eps_min)

    return Q_1, performance


main()

Initialized network components


Evaluate: 100%|██████████| 100/100 [00:00<00:00, 213.75it/s]


Episode:  100
rewards:  0.05
lr:  0.001
eps:  0.9841115531182099


Evaluate: 100%|██████████| 100/100 [00:00<00:00, 206.67it/s]


Episode:  200
rewards:  0.03
lr:  0.001
eps:  0.9646055206870082


Evaluate: 100%|██████████| 100/100 [00:00<00:00, 256.43it/s]


Episode:  300
rewards:  0.2
lr:  0.001
eps:  0.9454861164790007


Evaluate: 100%|██████████| 100/100 [00:00<00:00, 218.78it/s]


Episode:  400
rewards:  0.08
lr:  0.001
eps:  0.9267456771529368


Evaluate: 100%|██████████| 100/100 [00:00<00:00, 248.20it/s]


Episode:  500
rewards:  0.19
lr:  0.001
eps:  0.9083766912623201


Evaluate: 100%|██████████| 100/100 [00:00<00:00, 200.90it/s]


Episode:  600
rewards:  0.06
lr:  0.001
eps:  0.8903717962447101


Evaluate: 100%|██████████| 100/100 [00:00<00:00, 215.63it/s]


Episode:  700
rewards:  0.11
lr:  0.001
eps:  0.8727237754706968


Evaluate: 100%|██████████| 100/100 [00:00<00:00, 254.34it/s]


Episode:  800
rewards:  0.2
lr:  0.001
eps:  0.8554255553513692


Evaluate: 100%|██████████| 100/100 [00:00<00:00, 210.74it/s]


Episode:  900
rewards:  0.09
lr:  0.001
eps:  0.838470202503115


Evaluate: 100%|██████████| 100/100 [00:00<00:00, 207.27it/s]


Episode:  1000
rewards:  0.07
lr:  0.001
eps:  0.8218509209686186


Evaluate: 100%|██████████| 100/100 [00:00<00:00, 220.48it/s]


Episode:  1100
rewards:  0.1
lr:  0.001
eps:  0.805561049492939


Evaluate: 100%|██████████| 100/100 [00:00<00:00, 220.83it/s]


Episode:  1200
rewards:  0.11
lr:  0.001
eps:  0.7895940588535815


Evaluate: 100%|██████████| 100/100 [00:00<00:00, 219.57it/s]


Episode:  1300
rewards:  0.08
lr:  0.001
eps:  0.7739435492434863


Evaluate: 100%|██████████| 100/100 [00:00<00:00, 211.15it/s]


Episode:  1400
rewards:  0.05
lr:  0.001
eps:  0.7586032477058929


Evaluate: 100%|██████████| 100/100 [00:00<00:00, 237.72it/s]


Episode:  1500
rewards:  0.19
lr:  0.001
eps:  0.743567005620044


Evaluate: 100%|██████████| 100/100 [00:00<00:00, 249.90it/s]


Episode:  1600
rewards:  0.23
lr:  0.001
eps:  0.7288287962367284


Evaluate: 100%|██████████| 100/100 [00:00<00:00, 249.44it/s]


Episode:  1700
rewards:  0.19
lr:  0.001
eps:  0.7143827122626694


Evaluate: 100%|██████████| 100/100 [00:00<00:00, 231.04it/s]


Episode:  1800
rewards:  0.1
lr:  0.001
eps:  0.7002229634927942


Evaluate: 100%|██████████| 100/100 [00:00<00:00, 231.11it/s]


Episode:  1900
rewards:  0.12
lr:  0.001
eps:  0.6863438744894339


Evaluate: 100%|██████████| 100/100 [00:00<00:00, 222.18it/s]


Episode:  2000
rewards:  0.09
lr:  0.001
eps:  0.6727398823075239


Evaluate: 100%|██████████| 100/100 [00:00<00:00, 229.81it/s]


Episode:  2100
rewards:  0.11
lr:  0.001
eps:  0.6594055342648917


Evaluate: 100%|██████████| 100/100 [00:00<00:00, 214.29it/s]


Episode:  2200
rewards:  0.05
lr:  0.001
eps:  0.6463354857567426


Evaluate: 100%|██████████| 100/100 [00:00<00:00, 206.91it/s]


Episode:  2300
rewards:  0.07
lr:  0.001
eps:  0.6335244981134615


Evaluate: 100%|██████████| 100/100 [00:00<00:00, 230.32it/s]


Episode:  2400
rewards:  0.13
lr:  0.001
eps:  0.6209674365008767


Evaluate: 100%|██████████| 100/100 [00:00<00:00, 229.44it/s]


Episode:  2500
rewards:  0.18
lr:  0.001
eps:  0.6086592678621416


Evaluate: 100%|██████████| 100/100 [00:00<00:00, 225.59it/s]


Episode:  2600
rewards:  0.1
lr:  0.001
eps:  0.5965950589004119


Evaluate: 100%|██████████| 100/100 [00:00<00:00, 221.23it/s]


Episode:  2700
rewards:  0.09
lr:  0.001
eps:  0.5847699741015057


Evaluate: 100%|██████████| 100/100 [00:00<00:00, 221.41it/s]


Episode:  2800
rewards:  0.12
lr:  0.001
eps:  0.5731792737957582


Evaluate: 100%|██████████| 100/100 [00:00<00:00, 206.64it/s]


Episode:  2900
rewards:  0.05
lr:  0.001
eps:  0.5618183122582913


(QNetwork(
   (fc_1): Linear(in_features=2, out_features=64, bias=True)
   (fc_2): Linear(in_features=64, out_features=64, bias=True)
   (fc_3): Linear(in_features=64, out_features=4, bias=True)
 ),
 [[100, 0.05],
  [200, 0.03],
  [300, 0.2],
  [400, 0.08],
  [500, 0.19],
  [600, 0.06],
  [700, 0.11],
  [800, 0.2],
  [900, 0.09],
  [1000, 0.07],
  [1100, 0.1],
  [1200, 0.11],
  [1300, 0.08],
  [1400, 0.05],
  [1500, 0.19],
  [1600, 0.23],
  [1700, 0.19],
  [1800, 0.1],
  [1900, 0.12],
  [2000, 0.09],
  [2100, 0.11],
  [2200, 0.05],
  [2300, 0.07],
  [2400, 0.13],
  [2500, 0.18],
  [2600, 0.1],
  [2700, 0.09],
  [2800, 0.12],
  [2900, 0.05]])