## Configurations for Colab

In [None]:
import sys
IN_COLAB = "google.colab" in sys.modules

if IN_COLAB:
    !apt install python-opengl
    !apt install ffmpeg
    !apt install xvfb
    !pip install PyVirtualDisplay==3.0
    !pip install gym
    from pyvirtualdisplay import Display

    # Start virtual display
    dis = Display(visible=0, size=(400, 400))
    dis.start()

# 03. Prioritized Experience Replay (PER)

[T. Schaul et al., "Prioritized Experience Replay." arXiv preprint arXiv:1511.05952, 2015.](https://arxiv.org/pdf/1511.05952.pdf)

Using a replay memory leads to design choices at two levels: which experiences to store, and which experiences to replay (and how to do so). This paper addresses only the latter: making the most effective use of the replay memory for learning, assuming that its contents are outside of our control.

The central component of prioritized replay is the criterion by which the importance of each transition is measured. A reasonable approach is to use the magnitude of a transition’s TD error $\delta$, which indicates how ‘surprising’ or unexpected the transition is. The algorithm stores the last encountered TD error along with each transition in the replay memory, and the transition with the largest absolute TD error is replayed from the memory. A Q-learning update is applied to this transition, which updates the weights in proportion to the TD error. Note that new transitions arrive without a known TD error, so they are given maximal priority to guarantee that every experience is seen at least once (see the *store* method).

There are two ways to use the TD error: 1. greedy TD-error prioritization, and 2. stochastic prioritization. Greedy TD-error prioritization has a severe drawback: it focuses on a small subset of the experience. Errors shrink slowly, especially when using function approximation, so the transitions with initially high errors are replayed frequently. This lack of diversity makes the system prone to over-fitting. To overcome this issue, we use a stochastic sampling method that interpolates between pure greedy prioritization and uniform random sampling.

$$
P(i) = \frac{p_i^{\alpha}}{\sum_k p_k^{\alpha}}
$$

where $p_i > 0$ is the priority of transition $i$. The exponent $\alpha$ determines how much prioritization is used, with $\alpha = 0$ corresponding to the uniform case. In practice, a small positive constant $\epsilon$ is added so that every transition can be sampled, even one whose TD error is zero: $p_i = |\delta_i| + \epsilon$.
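The sampling rule above can be sketched in a few lines of NumPy. This is a standalone illustration, not the notebook's buffer: the function name `sampling_probabilities` and the example TD errors are made up for the demo. It draws from the full distribution with `rng.choice`, which costs O(N) per batch; the buffer below uses a segment tree instead.

```python
import numpy as np


def sampling_probabilities(td_errors: np.ndarray, alpha: float, eps: float = 1e-6) -> np.ndarray:
    """Return P(i) = p_i**alpha / sum_k p_k**alpha, with p_i = |delta_i| + eps."""
    priorities = np.abs(td_errors) + eps  # eps keeps p_i > 0 even when delta_i == 0
    scaled = priorities ** alpha          # alpha = 0 -> all ones -> uniform sampling
    return scaled / scaled.sum()


td_errors = np.array([0.0, 0.5, 1.0, 2.0])
print(sampling_probabilities(td_errors, alpha=0.0))  # [0.25 0.25 0.25 0.25]
print(sampling_probabilities(td_errors, alpha=1.0))  # proportional to |delta| + eps

# draw a mini-batch of indices according to P(i) with alpha = 0.6
rng = np.random.default_rng(0)
probs = sampling_probabilities(td_errors, alpha=0.6)
batch = rng.choice(len(td_errors), size=8, p=probs)
print(batch)
```

With $\alpha = 0$ every transition is equally likely; as $\alpha$ grows toward 1 the distribution concentrates on large-error transitions while still giving the zero-error one a small nonzero probability.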

One more thing. Recall one of the main ideas of DQN: to remove correlation between observations, it samples uniformly at random from the replay buffer. Prioritized replay introduces bias because it samples transitions in proportion to their TD error rather than uniformly. We can correct this bias by using importance-sampling (IS) weights

$$
w_i = \big( \frac{1}{N} \cdot \frac{1}{P(i)} \big)^\beta
$$

that fully compensate for the non-uniform probabilities $P(i)$ if $\beta = 1$. These weights can be folded into the Q-learning update by using $w_i\delta_i$ instead of $\delta_i$. For stability, the weights are normalized by $1/\max_i w_i$ so that they only scale the update downwards (see `_calculate_weight`). In typical reinforcement learning scenarios, unbiased updates matter most near convergence at the end of training. We therefore anneal the amount of importance-sampling correction over time by defining a schedule on the exponent $\beta$ that reaches 1 only at the end of learning.
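The weight formula, including the max-normalization that `_calculate_weight` applies, can be checked numerically. This is a minimal NumPy sketch; `importance_weights` and the example probabilities are illustrative and not part of the notebook's code.

```python
import numpy as np


def importance_weights(probs: np.ndarray, beta: float) -> np.ndarray:
    """Return w_i = (1/N * 1/P(i))**beta, normalized by max_i w_i."""
    n = len(probs)
    weights = (n * probs) ** (-beta)  # identical to (1 / (n * P(i)))**beta
    return weights / weights.max()    # largest weight becomes 1, others scale updates down


probs = np.array([0.1, 0.2, 0.3, 0.4])
w = importance_weights(probs, beta=1.0)
print(w)  # ≈ [1.0, 0.5, 0.333, 0.25]

# with beta = 1, w_i * P(i) is constant: the non-uniform sampling is fully undone
print(w * probs)

# the weighted TD error w_i * delta_i replaces delta_i in the update
td_errors = np.array([0.5, -1.0, 2.0, 0.1])
weighted_td = w * td_errors
print(weighted_td)
```

Rarely sampled transitions (small $P(i)$) get the largest weights; with $\beta = 0$ all weights collapse to 1 and no correction is applied.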

In [None]:
import os
import random
from typing import Dict, List, Tuple

import gym
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from IPython.display import clear_output

if IN_COLAB and not os.path.exists("segment_tree.py"):
    # download segment tree module
    !wget https://raw.githubusercontent.com/curt-park/rainbow-is-all-you-need/master/segment_tree.py

from segment_tree import MinSegmentTree, SumSegmentTree

## Replay buffer

Please see *01.dqn.ipynb* for a detailed description.

In [None]:
class ReplayBuffer:
    """A simple numpy replay buffer."""

    def __init__(self, obs_dim: int, size: int, batch_size: int = 32):
        self.obs_buf = np.zeros([size, obs_dim], dtype=np.float32)
        self.next_obs_buf = np.zeros([size, obs_dim], dtype=np.float32)
        self.acts_buf = np.zeros([size], dtype=np.float32)
        self.rews_buf = np.zeros([size], dtype=np.float32)
        self.done_buf = np.zeros(size, dtype=np.float32)
        self.max_size, self.batch_size = size, batch_size
        self.ptr, self.size = 0, 0

    def store(
        self,
        obs: np.ndarray,
        act: np.ndarray,
        rew: float,
        next_obs: np.ndarray,
        done: bool,
    ):
        self.obs_buf[self.ptr] = obs
        self.next_obs_buf[self.ptr] = next_obs
        self.acts_buf[self.ptr] = act
        self.rews_buf[self.ptr] = rew
        self.done_buf[self.ptr] = done
        self.ptr = (self.ptr + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def sample_batch(self) -> Dict[str, np.ndarray]:
        idxs = np.random.choice(self.size, size=self.batch_size, replace=False)
        return dict(obs=self.obs_buf[idxs],
                    next_obs=self.next_obs_buf[idxs],
                    acts=self.acts_buf[idxs],
                    rews=self.rews_buf[idxs],
                    done=self.done_buf[idxs])

    def __len__(self) -> int:
        return self.size

## Prioritized Replay Buffer

The key data structure in PER's implementation is the *segment tree*. It stores the priorities of transitions so that both updating a priority and sampling a transition in proportion to its priority take $O(\log N)$ time. We recommend understanding how it works before moving on. Here are some references:

- In Korean: https://mrsyee.github.io/rl/2019/01/25/PER-sumtree/
- In English: https://www.geeksforgeeks.org/segment-tree-set-1-sum-of-given-range/
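To make the idea concrete, here is a minimal, self-contained sum tree in plain Python. `SimpleSumTree` is a hypothetical class written for illustration, not the `segment_tree.py` module downloaded above, though it mirrors the two operations the buffer relies on: assigning a priority to a leaf (`tree[idx] = p`) and `retrieve(upperbound)`, which walks down from the root to the leaf whose cumulative-priority interval contains `upperbound`. `stratified_sample` follows the same idea as `_sample_proportional` below: it splits the total priority mass into `batch_size` equal segments and draws one value uniformly from each.

```python
import random
from typing import List


class SimpleSumTree:
    """Array-backed binary sum tree (capacity assumed to be a power of two)."""

    def __init__(self, capacity: int):
        assert capacity > 0 and capacity & (capacity - 1) == 0, "capacity must be a power of 2"
        self.capacity = capacity
        # node 1 is the root; leaves live at positions [capacity, 2 * capacity)
        self.tree = [0.0] * (2 * capacity)

    def __setitem__(self, idx: int, value: float) -> None:
        """Set leaf idx to value and refresh the sums on the path to the root: O(log N)."""
        pos = idx + self.capacity
        self.tree[pos] = value
        pos //= 2
        while pos >= 1:
            self.tree[pos] = self.tree[2 * pos] + self.tree[2 * pos + 1]
            pos //= 2

    def __getitem__(self, idx: int) -> float:
        return self.tree[idx + self.capacity]

    def total(self) -> float:
        return self.tree[1]

    def retrieve(self, upperbound: float) -> int:
        """Return the leaf whose cumulative-sum interval contains upperbound: O(log N)."""
        pos = 1
        while pos < self.capacity:
            left = 2 * pos
            if upperbound < self.tree[left]:
                pos = left                     # target lies in the left subtree
            else:
                upperbound -= self.tree[left]  # skip the left subtree's mass
                pos = left + 1
        return pos - self.capacity


def stratified_sample(tree: SimpleSumTree, batch_size: int) -> List[int]:
    """Split [0, total) into batch_size equal segments and draw one index per segment."""
    segment = tree.total() / batch_size
    return [tree.retrieve(random.uniform(segment * i, segment * (i + 1)))
            for i in range(batch_size)]


tree = SimpleSumTree(4)
for i, p in enumerate([1.0, 2.0, 3.0, 4.0]):
    tree[i] = p
print(tree.total())        # 10.0
print(tree.retrieve(0.5))  # 0  (leaf 0 covers [0, 1))
print(tree.retrieve(3.0))  # 2  (leaf 2 covers [3, 6))
print(stratified_sample(tree, batch_size=2))
```

Both operations touch one node per tree level, which is why updating and sampling cost $O(\log N)$ instead of the $O(N)$ needed to recompute the full distribution.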

In [None]:
class PrioritizedReplayBuffer(ReplayBuffer):
    """Prioritized Replay buffer.

    Attributes:
        max_priority (float): max priority seen so far (given to new transitions)
        tree_ptr (int): next index of tree
        alpha (float): alpha parameter for prioritized replay buffer
        sum_tree (SumSegmentTree): sum tree of priorities ** alpha, used for sampling
        min_tree (MinSegmentTree): min tree of priorities ** alpha, used to get the max weight

    """

    def __init__(
        self,
        obs_dim: int,
        size: int,
        batch_size: int = 32,
        alpha: float = 0.6
    ):
        """Initialization."""
        assert alpha >= 0

        super(PrioritizedReplayBuffer, self).__init__(obs_dim, size, batch_size)
        self.max_priority, self.tree_ptr = 1.0, 0
        self.alpha = alpha

        # capacity must be positive and a power of 2.
        tree_capacity = 1
        while tree_capacity < self.max_size:
            tree_capacity *= 2

        self.sum_tree = SumSegmentTree(tree_capacity)
        self.min_tree = MinSegmentTree(tree_capacity)

    def store(
        self,
        obs: np.ndarray,
        act: int,
        rew: float,
        next_obs: np.ndarray,
        done: bool
    ):
        """Store experience and priority."""
        super().store(obs, act, rew, next_obs, done)

        self.sum_tree[self.tree_ptr] = self.max_priority ** self.alpha
        self.min_tree[self.tree_ptr] = self.max_priority ** self.alpha
        self.tree_ptr = (self.tree_ptr + 1) % self.max_size

    def sample_batch(self, beta: float = 0.4) -> Dict[str, np.ndarray]:
        """Sample a batch of experiences."""
        assert len(self) >= self.batch_size
        assert beta > 0

        indices = self._sample_proportional()

        obs = self.obs_buf[indices]
        next_obs = self.next_obs_buf[indices]
        acts = self.acts_buf[indices]
        rews = self.rews_buf[indices]
        done = self.done_buf[indices]
        weights = np.array([self._calculate_weight(i, beta) for i in indices])

        return dict(
            obs=obs,
            next_obs=next_obs,
            acts=acts,
            rews=rews,
            done=done,
            weights=weights,
            indices=indices,
        )

    def update_priorities(self, indices: List[int], priorities: np.ndarray):
        """Update priorities of sampled transitions."""
        assert len(indices) == len(priorities)

        for idx, priority in zip(indices, priorities):
            assert priority > 0
            assert 0 <= idx < len(self)

            self.sum_tree[idx] = priority ** self.alpha
            self.min_tree[idx] = priority ** self.alpha

            self.max_priority = max(self.max_priority, priority)

    def _sample_proportional(self) -> List[int]:
        """Sample indices based on proportions."""
        indices = []
        p_total = self.sum_tree.sum(0, len(self) - 1)
        segment = p_total / self.batch_size

        for i in range(self.batch_size):
            a = segment * i
            b = segment * (i + 1)
            upperbound = random.uniform(a, b)
            idx = self.sum_tree.retrieve(upperbound)
            indices.append(idx)

        return indices

    def _calculate_weight(self, idx: int, beta: float):
        """Calculate the weight of the experience at idx."""
        # get max weight
        p_min = self.min_tree.min() / self.sum_tree.sum()
        max_weight = (p_min * len(self)) ** (-beta)

        # calculate weights
        p_sample = self.sum_tree[idx] / self.sum_tree.sum()
        weight = (p_sample * len(self)) ** (-beta)
        weight = weight / max_weight

        return weight

## Network

We use a simple network with three fully connected layers (512 hidden units each) and two ReLU non-linearities.

In [None]:
class Network(nn.Module):
    def __init__(self, in_dim: int, out_dim: int):
        """Initialization."""
        super(Network, self).__init__()

        self.layers = nn.Sequential(
            nn.Linear(in_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, out_dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward method implementation."""
        return self.layers(x)

## DQN + PER Agent

Here is a summary of the DQNAgent class.

| Method                | Note                                                                   |
| ---                   | ---                                                                    |
|`select_action`        | select an action from the input state.                                 |
|`step`                 | take an action and return the response of the env.                     |
|`_compute_dqn_loss`    | return the element-wise DQN loss.                                      |
|`update_model`         | update the model by gradient descent.                                  |
|`_target_hard_update`  | hard update from the local model to the target model.                  |
|`train`                | train the agent for up to num_episodes episodes or max_frames steps.   |
|`test`                 | run one evaluation episode (at most 50 steps) and return its score.    |
|`_plot`                | plot the training progress.                                            |


All differences from pure DQN are marked with `PER` comments.

#### __init__

Here, we use PrioritizedReplayBuffer instead of ReplayBuffer, and hold two more parameters, beta and prior_eps, which are used to calculate the importance-sampling weights and the new priorities, respectively.

#### compute_dqn_loss & update_model

`_compute_dqn_loss` returns the loss of each sample separately, so that the importance-sampling weights can be applied before averaging. After updating the network, the priorities of all sampled transitions are updated with their new losses.
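The weighting and priority update can be illustrated without PyTorch. This is a minimal NumPy sketch, assuming the Huber loss with threshold 1 that `F.smooth_l1_loss` uses by default; `smooth_l1` and `per_loss_and_priorities` are hypothetical helpers, and the Q-values are made up.

```python
import numpy as np


def smooth_l1(x: np.ndarray) -> np.ndarray:
    """Element-wise Huber loss with threshold 1, like F.smooth_l1_loss(..., reduction="none")."""
    abs_x = np.abs(x)
    return np.where(abs_x < 1.0, 0.5 * x ** 2, abs_x - 0.5)


def per_loss_and_priorities(q_pred, q_target, weights, prior_eps=1e-6):
    elementwise_loss = smooth_l1(q_pred - q_target)  # one loss per sampled transition
    loss = np.mean(elementwise_loss * weights)       # IS weights applied before averaging
    new_priorities = elementwise_loss + prior_eps    # fed back to update_priorities()
    return loss, new_priorities


q_pred = np.array([1.0, 3.0])
q_target = np.array([1.5, 0.0])
weights = np.array([1.0, 0.5])
loss, priorities = per_loss_and_priorities(q_pred, q_target, weights)
print(loss)        # 0.6875
print(priorities)
```

Averaging first and weighting afterwards would not work: once the losses are reduced to a single number, there is no per-sample value left to scale by $w_i$ or to store as a new priority.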

#### train

beta is increased linearly from its initial value toward 1 at every training step, reaching 1 at max_frames.
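As a hedged sketch, a linear schedule is easiest to express as a pure function of the step count. `linear_beta` is a hypothetical helper written for illustration; the starting value 0.6 mirrors the agent's default `beta`. Computing $\beta$ from a fixed starting value matters: feeding the already-updated $\beta$ back into $\beta + \text{fraction}\cdot(1 - \beta)$ on every step compounds, so $\beta$ approaches 1 far faster than linearly.

```python
def linear_beta(step: int, max_steps: int, beta_start: float) -> float:
    """Linearly anneal beta from beta_start (step 0) to 1.0 (step >= max_steps)."""
    fraction = min(step / max_steps, 1.0)
    return beta_start + fraction * (1.0 - beta_start)


max_steps = 1000
for step in (0, 250, 500, 1000, 2000):
    print(step, linear_beta(step, max_steps, beta_start=0.6))

# contrast: updating beta in place compounds the increments
compounded_beta = 0.6
for step in range(max_steps // 10):
    compounded_beta += min(step / max_steps, 1.0) * (1.0 - compounded_beta)
print(compounded_beta)                                          # already close to 1
print(linear_beta(max_steps // 10, max_steps, beta_start=0.6))  # about 0.64
```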

In [None]:
class DQNAgent:
    """DQN Agent interacting with environment.

    Attributes:
        env (gym.Env): OpenAI Gym environment
        memory (PrioritizedReplayBuffer): replay memory to store transitions
        batch_size (int): batch size for sampling
        epsilon (float): parameter for epsilon greedy policy
        epsilon_decay (float): step size to decrease epsilon
        seed (int): random seed passed to env.reset
        max_epsilon (float): max value of epsilon
        min_epsilon (float): min value of epsilon
        target_update (int): period for target model's hard update
        gamma (float): discount factor
        dqn (Network): model to train and select actions
        dqn_target (Network): target model to update
        optimizer (torch.optim): optimizer for training dqn
        transition (list): transition information including
                           state, action, reward, next_state, done
        beta (float): determines how much importance sampling is used
        prior_eps (float): guarantees every transition can be sampled
    """

    def __init__(
        self,
        env: gym.Env,
        memory_size: int,
        batch_size: int,
        target_update: int,
        epsilon_decay: float,
        seed: int,
        max_epsilon: float = 1.0,
        min_epsilon: float = 0.1,
        gamma: float = 0.99,
        # PER parameters
        alpha: float = 0.2,
        beta: float = 0.6,
        prior_eps: float = 1e-6,
    ):
        """Initialization.

        Args:
            env (gym.Env): OpenAI Gym environment
            memory_size (int): length of memory
            batch_size (int): batch size for sampling
            target_update (int): period for target model's hard update
            epsilon_decay (float): step size to decrease epsilon
            seed (int): random seed passed to env.reset
            max_epsilon (float): max value of epsilon
            min_epsilon (float): min value of epsilon
            gamma (float): discount factor
            alpha (float): determines how much prioritization is used
            beta (float): determines how much importance sampling is used
            prior_eps (float): guarantees every transition can be sampled
        """
        obs_dim = env.observation_space.shape[0]
        action_dim = env.action_space.n

        self.env = env

        self.batch_size = batch_size
        self.epsilon = max_epsilon
        self.epsilon_decay = epsilon_decay
        self.seed = seed
        self.max_epsilon = max_epsilon
        self.min_epsilon = min_epsilon
        self.target_update = target_update
        self.gamma = gamma

        # device: cpu / gpu
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        print(self.device)

        # PER
        # In DQN, We used "ReplayBuffer(obs_dim, memory_size, batch_size)"
        self.beta = beta
        self.prior_eps = prior_eps
        self.memory = PrioritizedReplayBuffer(
            obs_dim, memory_size, batch_size, alpha
        )

        # networks: dqn, dqn_target
        self.dqn = Network(obs_dim, action_dim).to(self.device)
        self.dqn_target = Network(obs_dim, action_dim).to(self.device)
        self.dqn_target.load_state_dict(self.dqn.state_dict())
        self.dqn_target.eval()

        # optimizer
        self.optimizer = optim.Adam(self.dqn.parameters())

        # transition to store in memory
        self.transition = list()

        # mode: train / test
        self.is_test = False

    def select_action(self, state: np.ndarray) -> np.ndarray:
        """Select an action from the input state."""
        # epsilon greedy policy
        if self.epsilon > np.random.random():
            selected_action = self.env.action_space.sample()
        else:
            selected_action = self.dqn(
                torch.FloatTensor(state).to(self.device)
            ).argmax()
            selected_action = selected_action.detach().cpu().numpy()

        if not self.is_test:
            self.transition = [state, selected_action]

        return selected_action

    def step(self, action: np.ndarray) -> Tuple[np.ndarray, np.float64, bool]:
        """Take an action and return the response of the env."""
        next_state, reward, done = self.env.step(action)

        if not self.is_test:
            self.transition += [reward, next_state, done]
            self.memory.store(*self.transition)

        return next_state, reward, done

    def update_model(self) -> float:
        """Update the model by gradient descent and return the loss value."""
        # PER needs beta to calculate weights
        samples = self.memory.sample_batch(self.beta)
        weights = torch.FloatTensor(
            samples["weights"].reshape(-1, 1)
        ).to(self.device)
        indices = samples["indices"]

        # PER: importance sampling before average
        elementwise_loss = self._compute_dqn_loss(samples)
        loss = torch.mean(elementwise_loss * weights)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # PER: update priorities
        # the element-wise Huber loss is used as the priority instead of |delta|;
        # flatten the (batch, 1) array so that each priority is a scalar
        loss_for_prior = elementwise_loss.detach().cpu().numpy().flatten()
        new_priorities = loss_for_prior + self.prior_eps
        self.memory.update_priorities(indices, new_priorities)

        return loss.item()

    def train(self, num_episodes: int, optimal_return: float, max_frames: int):
        """Train the agent.

        Runs until num_episodes episodes or max_frames steps have elapsed and
        returns the per-episode scores and the step at which the agent converged
        (or the final step count if it never converged).
        """
        self.is_test = False

        state, _ = self.env.reset(seed=self.seed)
        update_cnt = 0
        epsilons = []
        losses = []
        scores = np.zeros(num_episodes)
        score = 0

        episode_count = 0
        step_count = 0
        early_stop_buffer = np.ones(10, dtype=np.float64) * -1000
        converged = False

        # PER: beta is annealed linearly from its initial value to 1
        beta_start = self.beta

        while episode_count < num_episodes and step_count < max_frames:
            action = self.select_action(state)
            next_state, reward, done = self.step(action)

            state = next_state
            score += reward

            # PER: increase beta (recomputed from beta_start so the schedule stays linear)
            fraction = min(step_count / max_frames, 1.0)
            self.beta = beta_start + fraction * (1.0 - beta_start)

            # if episode ends
            if done:
                scores[episode_count % num_episodes] += score
                # evaluate the current policy; test() switches to test mode and
                # moves the env, so switch back and reset before training resumes
                early_stop_buffer[episode_count % 10] = self.test()
                self.is_test = False
                state, _ = self.env.reset(seed=self.seed)
                score = 0
                if episode_count % 200 == 0:
                    print(f'Episode: {episode_count} | Step Count: {step_count} | Average Score (Last 10): {np.mean(early_stop_buffer)}')
                if (not converged) and (np.mean(early_stop_buffer) == optimal_return):
                    print('Agent has converged to the optimal solution...')
                    converged_steps = step_count
                    converged = True
                episode_count += 1

            # if training is ready
            if len(self.memory) >= self.batch_size:
                loss = self.update_model()
                losses.append(loss)
                update_cnt += 1

                # linearly decrease epsilon
                self.epsilon = max(
                    self.min_epsilon, self.epsilon - (
                        self.max_epsilon - self.min_epsilon
                    ) * self.epsilon_decay
                )
                epsilons.append(self.epsilon)

                # if hard update is needed
                if update_cnt % self.target_update == 0:
                    self._target_hard_update()

            step_count += 1

        if not converged:
            converged_steps = step_count

        self.env.close()

        return scores, converged_steps

    def test(self) -> float:
        """Run one evaluation episode (at most 50 steps) and return its score."""
        self.is_test = True

        state, _ = self.env.reset(seed=self.seed)
        done = False
        score = 0
        steps = 0

        while not done and steps < 50:
            action = self.select_action(state)
            next_state, reward, done = self.step(action)

            state = next_state
            score += reward
            steps += 1
        self.env.close()

        return score

    def _compute_dqn_loss(self, samples: Dict[str, np.ndarray]) -> torch.Tensor:
        """Return dqn loss."""
        device = self.device  # for shortening the following lines
        state = torch.FloatTensor(samples["obs"]).to(device)
        next_state = torch.FloatTensor(samples["next_obs"]).to(device)
        action = torch.LongTensor(samples["acts"].reshape(-1, 1)).to(device)
        reward = torch.FloatTensor(samples["rews"].reshape(-1, 1)).to(device)
        done = torch.FloatTensor(samples["done"].reshape(-1, 1)).to(device)

        # G_t   = r + gamma * v(s_{t+1})  if state != Terminal
        #       = r                       otherwise
        curr_q_value = self.dqn(state).gather(1, action)
        next_q_value = self.dqn_target(
            next_state
        ).max(dim=1, keepdim=True)[0].detach()
        mask = 1 - done
        target = (reward + self.gamma * next_q_value * mask).to(self.device)

        # calculate element-wise dqn loss
        elementwise_loss = F.smooth_l1_loss(curr_q_value, target, reduction="none")

        return elementwise_loss

    def _target_hard_update(self):
        """Hard update: target <- local."""
        self.dqn_target.load_state_dict(self.dqn.state_dict())

    def _plot(
        self,
        scores,
        step_count,
        ep_num
    ):
        """Plot the training progresses."""
        clear_output(True)
        plt.figure(figsize=(20, 5))
        plt.subplot(131)
        plt.title(f'Current Step Count: {step_count} | Current Ep: {ep_num}')
        plt.plot(scores)
        plt.show()

## Environment

In [None]:
# 0: normal tile
# 1: orange tile
# 2: soft switch
# 3: hard switch
# 4: goal
# 5: transport switch
# 8: block
# 9: none

# Level 1:
level_one_env = np.array(
    [
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9],
        [9, 9, 9, 9, 9, 0, 0, 0, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9],
        [9, 9, 9, 9, 9, 0, 0, 0, 0, 0, 0, 9, 9, 9, 9, 9, 9, 9, 9, 9],
        [9, 9, 9, 9, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9, 9, 9, 9, 9, 9],
        [9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9, 9, 9, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 0, 0, 4, 0, 0, 9, 9, 9, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 0, 0, 0, 9, 9, 9, 9, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9],
    ]
)

# Level 2:
level_two_env = np.array(
    [
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 9, 9, 0, 0, 0, 9, 9, 9],
        [9, 9, 0, 0, 0, 0, 9, 9, 0, 0, 3, 0, 9, 9, 0, 4, 0, 9, 9, 9],
        [9, 9, 0, 0, 2, 0, 9, 9, 0, 0, 0, 0, 9, 9, 0, 0, 0, 9, 9, 9],
        [9, 9, 0, 0, 0, 0, 9, 9, 0, 0, 0, 0, 9, 9, 0, 0, 0, 9, 9, 9],
        [9, 9, 0, 0, 0, 0, 9, 9, 0, 0, 0, 0, 9, 9, 0, 0, 0, 9, 9, 9],
        [9, 9, 0, 0, 0, 0, 9, 9, 0, 0, 0, 0, 9, 9, 9, 9, 9, 9, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9],
    ]
)

level_two_soft_switches = np.array(
    [{"switch_location": (4, 4), "toggle_tiles": [(6, 6), (6, 7)], "mode": "toggle"}]
)


level_two_hard_switches = np.array(
    [{"switch_location": (3, 10), "toggle_tiles": [(6, 12), (6, 13)]}]
)

# Level 3:
level_three_env = np.array(
    [
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 0, 0, 0, 9, 9, 9, 9, 9],
        [9, 9, 0, 0, 0, 0, 9, 9, 0, 0, 0, 9, 9, 0, 0, 9, 9, 9, 9, 9],
        [9, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9, 9, 0, 0, 0, 0, 9, 9, 9],
        [9, 9, 0, 0, 0, 0, 9, 9, 9, 9, 9, 9, 9, 0, 0, 4, 0, 9, 9, 9],
        [9, 9, 0, 0, 0, 0, 9, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 9, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 0, 0, 0, 9, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9],
    ]
)

# Level 4:
level_four_env = np.array(
    [
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9],
        [9, 9, 9, 9, 9, 9, 1, 1, 1, 1, 1, 1, 1, 9, 9, 9, 9, 9, 9, 9],
        [9, 9, 9, 9, 9, 9, 1, 1, 1, 1, 1, 1, 1, 9, 9, 9, 9, 9, 9, 9],
        [9, 9, 9, 0, 0, 0, 0, 9, 9, 9, 9, 9, 0, 0, 0, 9, 9, 9, 9, 9],
        [9, 9, 9, 0, 0, 0, 9, 9, 9, 9, 9, 9, 9, 0, 0, 9, 9, 9, 9, 9],
        [9, 9, 9, 0, 0, 0, 9, 9, 9, 9, 9, 9, 9, 0, 0, 9, 9, 9, 9, 9],
        [9, 9, 9, 0, 0, 0, 9, 9, 0, 0, 0, 0, 1, 1, 1, 1, 1, 9, 9, 9],
        [9, 9, 9, 0, 0, 0, 9, 9, 0, 0, 0, 0, 1, 1, 1, 1, 1, 9, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 9, 0, 4, 0, 9, 9, 1, 1, 0, 1, 9, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 9, 0, 0, 0, 9, 9, 1, 1, 1, 1, 9, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9],
    ]
)

# Level 5:
level_five_env = np.array(
    [
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 9, 9, 9],
        [9, 9, 9, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 9, 9, 9],
        [9, 9, 9, 0, 0, 0, 0, 9, 9, 9, 9, 9, 9, 9, 0, 0, 0, 9, 9, 9],
        [9, 9, 9, 0, 0, 2, 0, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9],
        [9, 9, 9, 0, 0, 0, 0, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9],
        [9, 9, 9, 9, 9, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 9, 9, 9, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 2, 9, 9, 9],
        [9, 9, 0, 0, 0, 9, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 0, 9, 9, 9],
        [9, 9, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9, 9, 9, 9, 9],
        [9, 9, 0, 0, 0, 0, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9],
    ]
)

level_five_soft_switches = np.array(
    [
        {
            "switch_location": (2, 10),
            "toggle_tiles": [(2, 7), (2, 8)],
            "mode": "toggle",
        },
        {
            "switch_location": (7, 16),
            "toggle_tiles": [(9, 7), (9, 8)],
            "mode": "toggle",
        },
        {"switch_location": (6, 8), "toggle_tiles": [(9, 7), (9, 8)], "mode": "off"},
        {"switch_location": (4, 5), "toggle_tiles": [(9, 7), (9, 8)], "mode": "on"},
    ]
)

# Level 6:
level_six_env = np.array(
    [
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 0, 0, 9, 9, 9, 9, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 9, 0, 9, 9, 0, 0, 0, 9, 9, 9, 9, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 9, 0, 9, 9, 0, 0, 0, 0, 0, 9, 9, 9, 9],
        [9, 9, 9, 0, 0, 0, 0, 0, 0, 9, 9, 9, 9, 9, 0, 0, 0, 0, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 0, 0, 0, 9, 9, 9, 9, 0, 0, 4, 0, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 0, 0, 0, 9, 9, 9, 9, 9, 0, 0, 0, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 0, 9, 9, 0, 0, 9, 9, 9, 9, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 0, 9, 9, 9, 9, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 0, 9, 9, 9, 9, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 0, 0, 0, 9, 9, 9, 9, 9, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9],
    ]
)

# Level 7:
level_seven_env = np.array(
    [
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 9, 9, 9, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 9, 9, 9, 9, 9],
        [9, 9, 9, 0, 0, 0, 9, 9, 9, 9, 9, 0, 9, 9, 0, 0, 0, 0, 9, 9],
        [9, 9, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9, 9, 9, 0, 4, 0, 9, 9],
        [9, 9, 9, 0, 0, 0, 9, 9, 9, 9, 0, 0, 3, 9, 9, 0, 0, 0, 9, 9],
        [9, 9, 9, 0, 0, 0, 9, 9, 9, 9, 0, 0, 0, 9, 9, 0, 0, 0, 9, 9],
        [9, 9, 9, 9, 0, 0, 9, 9, 9, 9, 0, 9, 9, 9, 9, 9, 9, 9, 9, 9],
        [9, 9, 9, 9, 9, 0, 0, 0, 0, 0, 0, 9, 9, 9, 9, 9, 9, 9, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9],
    ]
)

level_seven_hard_switches = np.array(
    [{"switch_location": (5, 12), "toggle_tiles": [(7, 6)]}]
)


# Level 8:
level_eight_env = np.array(
    [
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 0, 0, 0, 9, 9, 9, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 0, 0, 0, 9, 9, 9, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 0, 0, 0, 9, 9, 9, 9, 9],
        [9, 9, 9, 0, 0, 0, 0, 0, 0, 9, 9, 9, 0, 0, 0, 0, 0, 0, 9, 9],
        [9, 9, 9, 0, 0, 0, 0, 5, 0, 9, 9, 9, 0, 0, 0, 0, 4, 0, 9, 9],
        [9, 9, 9, 0, 0, 0, 0, 0, 0, 9, 9, 9, 0, 0, 0, 0, 0, 0, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 0, 0, 0, 9, 9, 9, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 0, 0, 0, 9, 9, 9, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 0, 0, 0, 9, 9, 9, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9],
    ]
)

level_eight_teleport_switches = np.array(
    [{"switch_location": (6, 7), "split_positions": [(3, 13), (9, 13)]}]
)


# Level 9:
level_nine_env = np.array(
    [
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9],
        [9, 9, 9, 0, 0, 0, 0, 9, 9, 9, 0, 9, 9, 9, 0, 0, 0, 0, 9, 9],
        [9, 9, 9, 0, 0, 0, 0, 9, 9, 9, 0, 9, 9, 9, 0, 0, 5, 0, 9, 9],
        [9, 9, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 0, 4, 0, 9, 9, 9, 9, 9, 9, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 0, 0, 0, 9, 9, 9, 9, 9, 9, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9],
    ]
)

level_nine_teleport_switches = np.array(
    [{"switch_location": (4, 16), "split_positions": [(4, 15), (4, 5)]}]
)


# Level 10:
level_ten_env = np.array(
    [
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9],
        [9, 9, 9, 0, 0, 0, 9, 9, 9, 9, 9, 0, 0, 0, 0, 0, 0, 9, 9, 9],
        [9, 9, 9, 0, 4, 0, 9, 9, 0, 9, 9, 0, 0, 0, 0, 5, 0, 9, 9, 9],
        [9, 9, 9, 0, 0, 0, 9, 9, 9, 9, 9, 0, 0, 0, 0, 9, 9, 9, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 0, 0, 0, 9, 9, 9, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 0, 0, 9, 9, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 0, 9, 9, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 0, 9, 9, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 0, 0, 9, 9, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 0, 9, 9, 0, 0, 9, 9, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 0, 2, 9, 9, 0, 0, 0, 3, 0, 9, 9, 9, 9],
        [9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9],
    ]
)

level_ten_teleport_switches = np.array(
    [{"switch_location": (2, 15), "split_positions": [(2, 15), (2, 12)]}]
)

level_ten_hard_switches = np.array(
    [{"switch_location": (10, 14), "toggle_tiles": [(2, 9), (2, 10), (3, 15), (4, 15)]}]
)

level_ten_soft_switches = np.array(
    [{"switch_location": (10, 8), "toggle_tiles": [(2, 6), (2, 7)], "mode": "toggle"}]
)


In [None]:
class Block:

    def __init__(self, r1, c1, r2, c2):
        self._r1 = r1
        self._r2 = r2
        self._c1 = c1
        self._c2 = c2

        self._focus_block = 0

    def set_coords(self, r1, c1, r2, c2):
        self._r1 = r1
        self._r2 = r2
        self._c1 = c1
        self._c2 = c2

    def get_coords(self):
        return self._r1, self._c1, self._r2, self._c2

    def is_upright(self):
        return self._r1 == self._r2 and self._c1 == self._c2

    def is_wide(self):
        return self._r1 == self._r2 and self._c1 != self._c2

    def move_up(self):
        match self._focus_block:
            case 0:
                # vertical
                if self.is_upright():
                    self._r1 -= 1
                    self._r2 -= 2

                # flat and wide
                elif self.is_wide():
                    self._r1 -= 1
                    self._r2 -= 1

                # flat and long
                else:
                    min_r = min(self._r1, self._r2)
                    self._r1 = min_r - 1
                    self._r2 = min_r - 1

            case 1:
                self._r1 -= 1

            case 2:
                self._r2 -= 1

    def move_down(self):
        match self._focus_block:
            case 0:
                # vertical
                if self.is_upright():
                    self._r1 += 1
                    self._r2 += 2

                # flat and wide
                elif self.is_wide():
                    self._r1 += 1
                    self._r2 += 1

                # flat and long
                else:
                    max_r = max(self._r1, self._r2)
                    self._r1 = max_r + 1
                    self._r2 = max_r + 1

            case 1:
                self._r1 += 1
            case 2:
                self._r2 += 1

    def move_right(self):
        match self._focus_block:
            case 0:
                # vertical
                if self.is_upright():
                    self._c1 += 1
                    self._c2 += 2

                # flat and wide
                elif self.is_wide():
                    max_c = max(self._c1, self._c2)
                    self._c1 = max_c + 1
                    self._c2 = max_c + 1

                # flat and long
                else:
                    self._c1 += 1
                    self._c2 += 1

            case 1:
                self._c1 += 1
            case 2:
                self._c2 += 1

    def move_left(self):
        match self._focus_block:
            case 0:
                # vertical
                if self.is_upright():
                    self._c1 -= 1
                    self._c2 -= 2

                # flat and wide
                elif self.is_wide():
                    min_c = min(self._c1, self._c2)
                    self._c1 = min_c - 1
                    self._c2 = min_c - 1

                # flat and long
                else:
                    self._c1 -= 1
                    self._c2 -= 1
            case 1:
                self._c1 -= 1
            case 2:
                self._c2 -= 1

    def toggle_focus(self):
        # switch control between the two halves of a split block;
        # a joined block (focus 0) has nothing to toggle
        if self._focus_block == 1:
            self._focus_block = 2
        elif self._focus_block == 2:
            self._focus_block = 1

    def set_focus(self, focus):
        self._focus_block = focus

    def get_focus(self):
        return self._focus_block

    def join_single_blocks(self):
        if self._focus_block == 1 or self._focus_block == 2:
            if abs(self._r1 - self._r2) == 1 and (self._c1 == self._c2):
                self.set_focus(0)
            elif abs(self._c1 - self._c2) == 1 and (self._r1 == self._r2):
                self.set_focus(0)
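The four `move_*` methods encode one rule for the joined block (focus 0), split into three cases: an upright block tips over onto the two cells next to it, a block lying along the direction of motion stands up just beyond its leading half, and a block lying across the direction of motion slides one cell. The standalone function `roll` below is a hypothetical restatement of that rule for checking coordinates by hand; it is not part of the notebook and, like the trace, ignores the grid (it never checks whether the block falls off).

```python
def roll(r1, c1, r2, c2, dr, dc):
    """Roll a joined 1x1x2 block one step in direction (dr, dc).

    Hypothetical restatement of Block.move_* for focus 0.
    (dr, dc) is (-1, 0) up, (1, 0) down, (0, -1) left or (0, 1) right.
    """
    # Upright: tip over onto the two cells next to the footprint
    if (r1, c1) == (r2, c2):
        return r1 + dr, c1 + dc, r2 + 2 * dr, c2 + 2 * dc

    # Lying along the direction of motion: stand up just beyond the leading half
    lying_along = (dr != 0 and c1 == c2) or (dc != 0 and r1 == r2)
    if lying_along:
        pick = max if (dr > 0 or dc > 0) else min
        r, c = pick(r1, r2) + dr, pick(c1, c2) + dc
        return r, c, r, c

    # Lying across the direction of motion: slide both halves by one cell
    return r1 + dr, c1 + dc, r2 + dr, c2 + dc


# Trace right, right, down starting upright at level 1's start position (3, 6)
coords = (3, 6, 3, 6)
for move in [(0, 1), (0, 1), (1, 0)]:
    coords = roll(*coords, *move)
    print(coords)
# (3, 7, 3, 8)
# (3, 9, 3, 9)
# (4, 9, 5, 9)
```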

In [None]:
class Level(gym.Env):
    metadata = {"render_modes": [], "render_fps": 0}

    def __init__(
        self,
        start_pos: tuple,
        base_env: np.ndarray,
        soft_switches=np.array([]),
        hard_switches=np.array([]),
        teleport_switches=np.array([]),
        render_mode=None,
    ):
        self._r_start = start_pos[0]
        self._c_start = start_pos[1]

        self._block = Block(self._r_start, self._c_start, self._r_start, self._c_start)

        self._base_env = base_env
        # working copy that the switches modify; reset() restores it from _base_env
        self._current_env = np.copy(base_env)

        self._soft_switches = soft_switches
        self._hard_switches = hard_switches
        self._teleport_switches = teleport_switches

        self._actions = {
            0: self._block.move_right,
            1: self._block.move_up,
            2: self._block.move_left,
            3: self._block.move_down,
            4: self._block.toggle_focus,
        }

        self.observation_space = np.append(base_env.ravel(), 0)
        self.action_space = gym.spaces.Discrete(5)

    def step(self, action):
        # if the block is split, check if single blocks are adjacent, and join together
        self._block.join_single_blocks()

        # update the agent's coords by passing it the action
        self._perform_action(action)

        # read the block's new coordinates
        r1, c1, r2, c2 = self._block.get_coords()

        # -1 per step; 0 and done once both halves are on the goal tile
        reward, done = self._is_done(r1, c1, r2, c2)

        # only check for environment changes if the action is not "switch focus"
        if action != 4:
            # block is over empty space (9) -> reset to the start
            self._move_to_start(r1, c1, r2, c2)
            self._activate_teleport_switch(r1, c1, r2, c2)
            self._toggle_soft_switches(r1, c1, r2, c2)
            self._toggle_hard_switches(r1, c1, r2, c2)
        else:
            # switching focus is penalized more heavily than a regular move
            reward = -5

        state = self._format_environment()

        return state, reward, done

    def reset(self, seed=None):
        # seed is accepted for API compatibility; Level itself has no randomness

        # put the block back on the start tile, upright and joined
        self._block.set_coords(
            self._r_start, self._c_start, self._r_start, self._c_start
        )
        self._block.set_focus(0)

        # reset the environment (important to undo any obstacle interactions)
        self._current_env = np.copy(self._base_env)

        # flattened grid with the block marked as 8, followed by the focus flag
        return self._format_environment(), False

    def _move_to_start(self, r1, c1, r2, c2):
        if self._current_env[r1, c1] == 9 or self._current_env[r2, c2] == 9:
            self.reset(42)

    def _is_done(self, r1, c1, r2, c2):
        # check if the agent is on the goal -> set done to True and reward to 0

        # reward is -1 and done is False unless the agent hit the goal
        reward = -1
        done = False

        if self._current_env[r1, c1] == 4 and self._current_env[r2, c2] == 4:
            reward = 0
            done = True
        # elif self._current_env[r1, c1] == 9 or self._current_env[r2, c2] == 9:
        #   reward = -1000
        #   done = True

        return reward, done

    def _format_environment(self):
        # place the agent in the environment using its position
        r1, c1, r2, c2 = self._block.get_coords()
        state = np.copy(self._current_env)
        state[r1, c1] = 8
        state[r2, c2] = 8

        state = state.ravel()
        state = np.append(state, self._block.get_focus())

        return state

    def _toggle_soft_switches(self, r1, c1, r2, c2):
        # soft (circle) switch: pressed by any part of the block -> change the bridge
        for c in self._soft_switches:
            switch_location = c["switch_location"]
            toggle_tiles = c["toggle_tiles"]
            mode = c["mode"]

            if (r1 == switch_location[0] and c1 == switch_location[1]) or (
                r2 == switch_location[0] and c2 == switch_location[1]
            ):
                if mode == "toggle":
                    # flip the bridge based on the current state of its first tile
                    if self._current_env[toggle_tiles[0][0], toggle_tiles[0][1]] == 0:
                        for t in toggle_tiles:
                            self._current_env[t[0], t[1]] = 9
                    else:
                        for t in toggle_tiles:
                            self._current_env[t[0], t[1]] = 0
                elif mode == "on":
                    for t in toggle_tiles:
                        self._current_env[t[0], t[1]] = 0
                elif mode == "off":
                    for t in toggle_tiles:
                        self._current_env[t[0], t[1]] = 9

    def _toggle_hard_switches(self, r1, c1, r2, c2):
        # hard (X) switch: only an upright block presses it -> flip the bridge
        for c in self._hard_switches:
            switch_location = c["switch_location"]
            toggle_tiles = c["toggle_tiles"]

            if (r1 == switch_location[0] and c1 == switch_location[1]) and (
                r2 == switch_location[0] and c2 == switch_location[1]
            ):
                if self._current_env[toggle_tiles[0][0], toggle_tiles[0][1]] == 0:
                    for t in toggle_tiles:
                        self._current_env[t[0], t[1]] = 9
                else:
                    for t in toggle_tiles:
                        self._current_env[t[0], t[1]] = 0

    def _activate_teleport_switch(self, r1, c1, r2, c2):
        # check if block is on teleport switch -> split block into two single blocks
        for t in self._teleport_switches:
            switch_location = t["switch_location"]
            split_positions = t["split_positions"]


            if (r1 == switch_location[0] and c1 == switch_location[1]) and (
                r2 == switch_location[0] and c2 == switch_location[1]
            ):

                single_block_one = split_positions[0]
                single_block_two = split_positions[1]

                r1 = single_block_one[0]
                c1 = single_block_one[1]

                r2 = single_block_two[0]
                c2 = single_block_two[1]

                self._block.set_focus(1)
                self._block.set_coords(r1, c1, r2, c2)

    def _handle_orange_tile(self, r1, c1, r2, c2):
        # check if block is vertical
        if (r1, c1) == (r2, c2):
            # check if tile is orange tile
            if self._current_env[r1, c1] == 1:
                # tile disappears/block falls through grid

                self._block.set_coords(
                    self._r_start, self._c_start, self._r_start, self._c_start
                )

        # nothing happens if block is not vertical on an orange tile

    def _perform_action(self, action):
        # Get the corresponding method from 'actions' and call it
        action_method = self._actions.get(int(action))
        if action_method:
            action_method()

        else:
            print("Invalid action")

    def get_state(self):
        r1, c1, r2, c2 = self._block.get_coords()
        state = np.copy(self._current_env)
        state[r1, c1] = 8
        state[r2, c2] = 8

        return state

    def get_block(self):
        return self._block
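`Level` hands the agent a flat vector: the current grid with the block's cells overwritten by 8, flattened row-major with `ravel`, followed by one extra entry for the focus flag. For a 12 x 20 grid such as level 10 that is 12 * 20 + 1 = 241 numbers, which matches the length of `observation_space` as built by `np.append(base_env.ravel(), 0)`. The sketch below mirrors `_format_environment` as a standalone function and adds a hypothetical `decode_state` inverse that can help when inspecting stored transitions. Both names are illustrative and not part of the notebook.

```python
import numpy as np

BLOCK = 8  # code Level uses to mark the block's cells in the observation


def encode_state(grid, coords, focus):
    """Mirror of Level._format_environment: mark the block, flatten, append focus."""
    r1, c1, r2, c2 = coords
    state = np.copy(grid)
    state[r1, c1] = BLOCK
    state[r2, c2] = BLOCK
    return np.append(state.ravel(), focus)


def decode_state(state, shape):
    """Hypothetical inverse for debugging: split the vector into grid and focus."""
    return state[:-1].reshape(shape), int(state[-1])


grid = np.full((12, 20), 9)  # toy all-void grid with the same shape as level 10
grid[2, 11:17] = 0           # a strip of floor
obs = encode_state(grid, (2, 12, 2, 13), focus=0)
print(obs.shape)             # (241,)

board, focus = decode_state(obs, grid.shape)
print(np.argwhere(board == BLOCK).tolist(), focus)  # [[2, 12], [2, 13]] 0
```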

In [None]:
level = 1

if level == 1:
    env = Level(start_pos=(3, 6), base_env=level_one_env)
    optimal_return = -6

elif level == 2:
    env = Level(
        start_pos=(6, 3),
        base_env=level_two_env,
        soft_switches=level_two_soft_switches,
        hard_switches=level_two_hard_switches,
    )
    optimal_return = -16

elif level == 3:
    env = Level(start_pos=(4, 3), base_env=level_three_env)
    optimal_return = -18

elif level == 4:
    env = Level(start_pos=(6, 4), base_env=level_four_env)
    optimal_return = -27

elif level == 5:
    env = Level(
        start_pos=(2, 15),
        base_env=level_five_env,
        soft_switches=level_five_soft_switches,
    )
    optimal_return = -32

elif level == 6:
    env = Level(
        start_pos=(4, 3),
        base_env=level_six_env,
    )
    optimal_return = -34

elif level == 7:
    env = Level(
        start_pos=(4, 4),
        base_env=level_seven_env,
        hard_switches=level_seven_hard_switches,
    )
    optimal_return = -43

elif level == 8:
    env = Level(
        start_pos=(6, 4),
        base_env=level_eight_env,
        teleport_switches=level_eight_teleport_switches,
    )
    optimal_return = -10

elif level == 9:
    env = Level(
        start_pos=(4, 4),
        base_env=level_nine_env,
        teleport_switches=level_nine_teleport_switches
    )
    optimal_return = -28

elif level == 10:
    env = Level(
        start_pos=(2, 12),
        base_env=level_ten_env,
        soft_switches=level_ten_soft_switches,
        hard_switches=level_ten_hard_switches,
        teleport_switches=level_ten_teleport_switches,
    )
    optimal_return = -61
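Each `optimal_return` is an undiscounted return (the agent is built with `gamma=1.0`) under the rewards in `Level.step`: -1 for each move, 0 for the move that lands on the goal, and -5 for each focus switch (action 4). A solution of k moves with no focus switches therefore scores -(k - 1), so level 1's -6 corresponds to a 7-move solution under this scheme. The hypothetical helper `episode_return` below reproduces that arithmetic for an action sequence, assuming the last action reaches the goal and the block never falls off.

```python
TOGGLE_FOCUS = 4  # action index for "switch focus" in Level._actions


def episode_return(actions, gamma=1.0):
    """Return of an action sequence whose last action reaches the goal.

    Rewards follow Level.step: -5 for a focus switch, -1 for any other move,
    and 0 for the final, goal-reaching move. Assumes the block never falls off.
    """
    total = 0.0
    for t, a in enumerate(actions):
        if a == TOGGLE_FOCUS:
            reward = -5
        elif t == len(actions) - 1:
            reward = 0
        else:
            reward = -1
        total += gamma ** t * reward
    return total


print(episode_return([0, 0, 3, 0, 0, 1, 0]))  # -6.0: any 7-move solution without focus switches
print(episode_return([0, 4, 1, 0]))           # -7.0: one focus switch costs 5
```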

## Set random seed

In [None]:
seed = 777

def seed_torch(seed):
    torch.manual_seed(seed)
    if torch.backends.cudnn.enabled:
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

np.random.seed(seed)
random.seed(seed)
seed_torch(seed)
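The seeds are set once, before the trial loop, so the trials do not replay one identical random stream: each trial continues from the generator state the previous one left behind, while the run as a whole stays reproducible from `seed = 777`. Whether `DQNAgent` re-seeds anything from its own `seed` argument is not visible in this section, so treat this as an assumption. The sketch below illustrates the idea with NumPy and `random` only (PyTorch is left out so it runs anywhere); `run_trials` is a hypothetical stand-in for the training loop.

```python
import random

import numpy as np


def run_trials(seed, num_trials=3):
    """Seed once, then draw one pair of numbers per 'trial' (stand-in for training)."""
    np.random.seed(seed)
    random.seed(seed)
    return [(np.random.rand(), random.random()) for _ in range(num_trials)]


first = run_trials(777)
second = run_trials(777)

print(first == second)                 # True: the whole run is reproducible
print(len(set(first)) == len(first))   # True: each trial sees different draws
```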

## Initialize

In [None]:
# parameters
num_episodes = 2000
memory_size = 1000
batch_size = 128
target_update = 100
epsilon_decay = 1 / 100
num_trials = 10
max_frames = 500000  # caps total environment steps in case the agent does not finish all num_episodes episodes
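`memory_size` and `batch_size` size the prioritized replay buffer. In the proportional variant of prioritized replay (Schaul et al., cited at the top of this notebook), transition i is sampled with probability P(i) = p_i^α / Σ_k p_k^α, where p_i = |δ_i| + ε, and each sampled transition's loss is weighted by w_i = (N · P(i))^(-β), normalized so the largest weight in the batch is 1 and updates are only ever scaled down. The sketch below does this with plain NumPy in O(N) per batch. The buffer the agent actually uses is not shown in this section and may use a sum-tree instead, and α = 0.6, β = 0.4, ε = 1e-6 are illustrative values, not this notebook's settings.

```python
import numpy as np


def sample_prioritized(td_errors, batch_size, alpha=0.6, beta=0.4, eps=1e-6, rng=None):
    """Proportional prioritized sampling with importance-sampling weights.

    Returns (indices, weights). O(N) per call; a sum-tree makes sampling O(log N).
    """
    rng = np.random.default_rng() if rng is None else rng
    priorities = (np.abs(td_errors) + eps) ** alpha
    probs = priorities / priorities.sum()  # P(i) = p_i^alpha / sum_k p_k^alpha
    indices = rng.choice(len(td_errors), size=batch_size, p=probs)

    n = len(td_errors)
    weights = (n * probs[indices]) ** (-beta)  # w_i = (N * P(i))^-beta
    weights /= weights.max()                   # largest weight in the batch becomes 1
    return indices, weights


rng = np.random.default_rng(0)
td = np.array([0.0, 0.1, 0.5, 2.0])  # toy TD errors
idx, w = sample_prioritized(td, batch_size=1000, rng=rng)
print(np.bincount(idx, minlength=len(td)))  # the transition with TD error 2.0 is drawn most often
print(w.max(), w.min() > 0)
```

Rarely sampled transitions receive the largest weights, which is how the importance-sampling correction offsets the bias introduced by non-uniform sampling.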

## Train

In [None]:
total_scores = []
total_steps = []

for trial in range(num_trials):
    print(f'Starting trial {trial}...')
    agent = DQNAgent(env, memory_size, batch_size, target_update, epsilon_decay, seed, gamma=1.0)
    scores, steps = agent.train(num_episodes, optimal_return, max_frames)
    total_scores.append(scores)
    total_steps.append(steps)

avg_scores = np.mean(total_scores, axis=0)
avg_steps = np.mean(total_steps)
print(avg_scores)
print(avg_steps)
import os

os.makedirs('../results', exist_ok=True)  # np.save does not create missing directories
np.save('../results/per_dqn', np.concatenate((avg_scores, [avg_steps])))
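The results file packs two things into one array: the per-episode scores averaged over trials, followed by a single trailing entry holding the mean step count. Note that `np.mean(total_scores, axis=0)` only works if every trial returns a score list of the same length. `np.save` appends `.npy` to the name, so the file on disk is `../results/per_dqn.npy`. A minimal sketch of reading it back is below; `load_results` is a hypothetical helper, and the script writes toy numbers (not real results) to a temporary directory so it runs on its own.

```python
import os
import tempfile

import numpy as np


def load_results(path):
    """Split a saved results array into (avg_scores, avg_steps)."""
    data = np.load(path)
    return data[:-1], float(data[-1])


# Toy stand-ins for the notebook's avg_scores and avg_steps (not real results)
toy_scores = np.array([-40.0, -25.5, -12.0, -6.0])
toy_steps = 1234.5

with tempfile.TemporaryDirectory() as tmp:
    base = os.path.join(tmp, "per_dqn")
    np.save(base, np.concatenate((toy_scores, [toy_steps])))  # writes per_dqn.npy
    scores, steps = load_results(base + ".npy")

print(scores, steps)
```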