## Environment

### Basic Env

In [None]:
import gymnasium as gym
from gymnasium import spaces
import numpy as np

'''
States:
0 -> exploring
1 -> found AGI
2 -> gave up

Actions (BasicEnv defines only two):
0 -> continue exploring
1 -> retreat (give up)
'''

class BasicEnv(gym.Env):
    """Three-state toy environment for the AGI-search problem.

    States: 0 = exploring, 1 = found AGI (terminal), 2 = gave up (terminal).
    Actions available in state 0:
      * 0 -- keep exploring: reach state 1 with probability 0.02 (reward +100),
        otherwise stay in state 0 (reward -1);
      * 1 -- retreat: move to state 2 with certainty (reward +10).
    """

    def __init__(self):
        super().__init__()
        self.start = 0
        self.state = self.start

        # Two actions: 0 = explore, 1 = retreat
        self.action_space = spaces.Discrete(2)

        # Three states: exploring, found AGI, gave up
        self.observation_space = spaces.Discrete(3)

        # prob_transition[state][action] -> list of possible outcomes; the
        # "prob" fields within each list sum to 1.
        self.prob_transition = {
            0: {  # State 0: "exploring"
                0: [{"state": 1, "prob": 0.02, "reward": 100},  # Explore -> Found AGI
                    {"state": 0, "prob": 0.98, "reward": -1}],  # Explore -> Stay exploring
                1: [{"state": 2, "prob": 1, "reward": 10}],  # Retreat -> Give up
            },
            1: {},  # State 1: "discovered AGI" (terminal, no outgoing transitions)
            2: {},  # State 2: "gave up on AGI" (terminal, no outgoing transitions)
        }

        self.terminal_states = [1, 2]
        self.max_steps = None  # no step limit is enforced by step()
        self.current_step = 0

    def step(self, action):
        """Take an action; return (state, reward, terminated, truncated, info)."""
        assert self.observation_space.contains(self.state), f"Invalid state: {self.state}"
        assert self.action_space.contains(action), f"Invalid action: {action}"

        if self.state in self.terminal_states:
            # Terminal states have no outgoing transitions (looking them up
            # would raise KeyError); stay put with zero reward.
            self.current_step += 1
            return self.state, 0, True, False, {}

        transitions = self.prob_transition[self.state][action]
        probabilities = [t["prob"] for t in transitions]
        # Sample with the env's own RNG so reset(seed=...) makes runs reproducible.
        chosen = transitions[self.np_random.choice(len(transitions), p=probabilities)]

        self.state = chosen["state"]
        self.current_step += 1
        done = self.state in self.terminal_states
        return self.state, chosen["reward"], done, False, {}

    def reset(self, seed=None, options=None):
        """Reset to the start state; return (state, info)."""
        # Gymnasium convention: seeds self.np_random when a seed is given.
        super().reset(seed=seed)
        self.state = self.start
        self.current_step = 0
        return self.state, {}

    def render(self):
        """Print the current state index."""
        print("State:", self.state)

    def close(self):
        pass

if __name__ == "__main__":
    # Smoke test: roll out a single episode with uniformly random actions,
    # printing the state after every step.
    env = BasicEnv()
    obs, _ = env.reset()

    done = False
    while not done:
        action = env.action_space.sample()
        obs, reward, done, _, _ = env.step(action)
        env.render()

    env.close()



### Hiring Env

In [None]:
import gym
from gym import spaces
import numpy as np

class HiringAGIEnv(gym.Env):
    """
    Hiring Environment for AGI Discovery.

    The agent manages a fixed-size team of workers (dicts with a ``skill`` and
    a ``salary``) and sees one job candidate at a time.

    Actions (``Discrete(2 + team_size)``):
      * 0 -- retreat: phase becomes 2 (gave up), reward +10, episode ends;
      * 1 -- search: with probability ``sigmoid(alpha * (team_skill - s0))``
        the phase becomes 1 (AGI found), reward +100, episode ends;
      * 2 + i -- hire the candidate in place of worker ``i``; the reward is the
        skill gained (candidate skill minus the replaced worker's skill) and a
        new candidate is drawn.
    Every step additionally costs ``5 * total_salary / team_size``.
    """

    def decode_action(self, action):
        """Map a flat action index to ``{"action_type", "fire_index"}``."""
        if action == 0:
            return {"action_type": 0, "fire_index": None}
        elif action == 1:
            return {"action_type": 1, "fire_index": None}
        else:
            return {"action_type": 2, "fire_index": action - 2}

    def __init__(self, team_size=5, s0=3.0, alpha=1.5):
        super().__init__()

        # Constants
        self.team_size = team_size
        # Upper bound on team skill for generated workers (skills clipped to [0, 1])
        self.max_skill = team_size
        self.s0 = s0  # AGI threshold skill
        self.alpha = alpha  # Sharpness of AGI probability curve

        # Action space: [0=Retreat, 1=Search, 2+i=Hire candidate and fire worker i]
        self.action_space = spaces.Discrete(2 + self.team_size)
        # Observation space (structured, not flattened for readability)
        self.observation_space = spaces.Dict({
            "phase": spaces.Discrete(3),
            "team_skills": spaces.Box(low=0.0, high=1.0, shape=(team_size,), dtype=np.float32),
            "team_salaries": spaces.Box(low=0.0, high=1.0, shape=(team_size,), dtype=np.float32),
            "candidate": spaces.Box(low=0.0, high=1.0, shape=(2,), dtype=np.float32),  # [skill, salary]
            "total_salary": spaces.Box(low=0.0, high=team_size, shape=(), dtype=np.float32)
        })

        self.workers = []

        self.reset()

    def reset(self, seed=None, options=None):
        """Start a new episode; return (observation, info).

        ``options["workers"]`` (optional) supplies the initial team as a list
        of ``team_size`` dicts with ``skill`` and ``salary`` keys.

        Raises:
            ValueError: if the supplied team does not have ``team_size`` workers.
        """
        super().reset(seed=seed)
        self.phase = 0  # 0 = exploring, 1 = AGI found, 2 = gave up
        workers = None if options is None else options.get("workers")
        if workers is None:
            # Clear the old team first: _generate_worker biases skill by the
            # current team, and the new team must not depend on last episode.
            self.workers = []
            self.workers = [self._generate_worker() for _ in range(self.team_size)]
        else:
            if len(workers) != self.team_size:
                raise ValueError(
                    f"Expected {self.team_size} workers, got {len(workers)}"
                )
            # Copy so later hires never mutate the caller's list or dicts.
            self.workers = [dict(w) for w in workers]
        self._update_stats()
        self.candidate = self._generate_worker()
        self.current_step = 0
        return self._get_obs(), {}

    def _current_skill(self):
        """Sum of the current team's skills."""
        return sum(w['skill'] for w in self.workers)

    def _generate_worker(self, skill_bias=0.2, skill_uncertainty=0.8, skill_mean=0.5, salary_noise=0.8, salary=None):
        """Draw a random worker ``{"skill": ..., "salary": ...}``.

        Skill ~ Normal(skill_mean + skill_bias * team_skill / max_skill,
        skill_uncertainty), clipped to [0, 1]. Unless *salary* is given,
        salary = skill + Normal(0, salary_noise), clipped to [0, 1].
        """
        # The env's own RNG, so reset(seed=...) makes episodes reproducible.
        rng = self.np_random
        base_skill = np.clip(
            rng.normal(
                loc=skill_mean + skill_bias * (self._current_skill() / self.max_skill),
                scale=skill_uncertainty
            ), 0, 1
        )

        if salary is None:
            # Add noise to salary, making it only loosely correlated with skill
            salary = np.clip(
                base_skill + rng.normal(loc=0.0, scale=salary_noise),
                0.0,
                1.0
            )

        return {"skill": base_skill, "salary": salary}

    def _update_stats(self):
        """Refresh cached team skill and total salary after the team changes."""
        self.current_skill = sum(w['skill'] for w in self.workers)
        self.total_salary = sum(w['salary'] for w in self.workers)

    def _agi_probability(self):
        """Logistic probability that a search succeeds given the team's skill."""
        return 1 / (1 + np.exp(-self.alpha * (self.current_skill - self.s0)))

    def _get_obs(self):
        """Build the Dict observation described by ``observation_space``."""
        return {
            "phase": self.phase,
            "team_skills": np.array([w['skill'] for w in self.workers], dtype=np.float32),
            "team_salaries": np.array([w['salary'] for w in self.workers], dtype=np.float32),
            "candidate": np.array([self.candidate['skill'], self.candidate['salary']], dtype=np.float32),
            "total_salary": np.float32(self.total_salary)
        }

    def step(self, action):
        """Apply *action*; return (obs, reward, terminated, truncated, info).

        Raises:
            ValueError: if *action* is not in ``action_space``.
        """
        if not self.action_space.contains(action):
            raise ValueError(f"Invalid action: {action}")
        action = self.decode_action(action)
        reward = 0
        done = False
        if self.phase == 0:
            if action["action_type"] == 0:  # Retreat
                self.phase = 2
                reward = 10
                done = True

            elif action["action_type"] == 1:  # Search
                if self.np_random.random() < self._agi_probability():
                    self.phase = 1
                    reward = 100
                    done = True

            elif action["action_type"] == 2:  # Hire
                index = action["fire_index"]
                old_worker = self.workers[index]
                self.workers[index] = self.candidate
                # Skill gained by the swap; the old worker is read before overwrite.
                reward = self.candidate['skill'] - old_worker['skill']
                self._update_stats()
                self.candidate = self._generate_worker()

        else:
            # Episode already over: keep signalling termination.
            done = True

        # Time cost: salary penalty
        reward -= 5 * self.total_salary / self.team_size

        self.current_step += 1
        return self._get_obs(), reward, done, False, {}

    def render(self):
        """Print the phase, team stats, each worker and the current candidate."""
        print(f"Phase: {self.phase} | Step: {self.current_step} | Team Skill: {self.current_skill:.2f} | Total Salary: {self.total_salary:.2f}")
        for i, w in enumerate(self.workers):
            print(f"  Worker {i}: Skill={w['skill']:.2f}, Salary={w['salary']:.2f}")
        print(f"Candidate: Skill={self.candidate['skill']:.2f}, Salary={self.candidate['salary']:.2f}")

    def close(self):
        pass


if __name__ == "__main__":
    # Smoke test: one episode with a uniformly random action type each step;
    # hires replace a uniformly random worker.
    env = HiringAGIEnv(s0=3.0)
    obs, _ = env.reset()
    done = False

    while not done:
        env.render()
        # randint's upper bound is exclusive: 3 covers 0=retreat, 1=search, 2=hire.
        action_type = np.random.randint(0, 3)
        fire_index = np.random.randint(0, env.team_size)
        # Flat encoding used by HiringAGIEnv.decode_action: hire i -> 2 + i.
        action = action_type + (fire_index if action_type == 2 else 0)
        obs, reward, done, _, _ = env.step(action)
        print(f"Action: {action}, Reward: {reward:.2f}\n")

    env.close()


In [None]:
import torch.nn as nn
import torch.nn.functional as F

class DictObsPreprocessor(nn.Module):
    """Flatten a HiringAGIEnv Dict observation into one float32 tensor.

    Concatenated along the last dimension: one-hot ``phase`` (3 columns),
    ``team_skills``, ``team_salaries``, ``candidate`` ([skill, salary]) and
    ``total_salary`` (1 column). A single unbatched observation yields shape
    ``[1, 3 + len(team_skills) + len(team_salaries) + 2 + 1]``.
    """

    def __init__(self, team_size):
        super().__init__()
        # Kept for reference; forward() derives widths from the observation itself.
        self.team_size = team_size

    def forward(self, obs_dict):
        # This cell only imports torch.nn / torch.nn.functional, which do not
        # bind the name ``torch``; import it here so the module is self-contained.
        import torch

        # Phase -> one-hot; as_tensor also accepts ints, NumPy ints and tensors.
        phase = torch.as_tensor(obs_dict["phase"], dtype=torch.long)
        if phase.dim() == 0:
            phase = phase.unsqueeze(0)
        phase_one_hot = F.one_hot(phase, num_classes=3).float()  # shape: [1, 3]

        def ensure_2d(x, dtype=torch.float32):
            # as_tensor avoids the copy + warning torch.tensor() gives for tensors.
            x = torch.as_tensor(x, dtype=dtype)
            if x.dim() == 1:
                x = x.unsqueeze(0)  # convert [D] -> [1, D]
            return x

        team_skills = ensure_2d(obs_dict["team_skills"])
        team_salaries = ensure_2d(obs_dict["team_salaries"])
        candidate = ensure_2d(obs_dict["candidate"])
        # Scalar (or 1-D batch) -> one column.
        total_salary = torch.as_tensor(
            obs_dict["total_salary"], dtype=torch.float32
        ).reshape(-1, 1)

        return torch.cat([
            phase_one_hot,
            team_skills,
            team_salaries,
            candidate,
            total_salary
        ], dim=-1)


## Q-Learning Agent (for Basic Env)

### Agent

In [None]:
from collections import defaultdict
import gymnasium as gym
import numpy as np


class QLearningAgent:
    """Tabular epsilon-greedy Q-learning agent.

    Q-values are kept in a ``defaultdict`` keyed by observation, so unseen
    observations start with an all-zero row of length ``env.action_space.n``.
    """

    def __init__(
        self,
        env: "gym.Env",
        learning_rate: float,
        initial_epsilon: float,
        epsilon_decay: float,
        final_epsilon: float,
        discount_factor: float = 0.95,
    ):
        """Initialize a Reinforcement Learning agent with an empty dictionary
        of state-action values (q_values), a learning rate and an epsilon.

        Args:
            env: The training environment (its ``action_space`` must provide
                ``n`` and ``sample()``)
            learning_rate: The learning rate
            initial_epsilon: The initial epsilon value
            epsilon_decay: Amount subtracted from epsilon per decay_epsilon() call
            final_epsilon: The floor epsilon is clamped to
            discount_factor: The discount factor for computing the Q-value
        """
        self.env = env
        self.q_values = defaultdict(lambda: np.zeros(env.action_space.n))

        self.lr = learning_rate
        self.discount_factor = discount_factor

        self.epsilon = initial_epsilon
        self.epsilon_decay = epsilon_decay
        self.final_epsilon = final_epsilon

        self.training_error = []

    def get_action(self, obs: int) -> int:
        """
        Returns the best action with probability (1 - epsilon)
        otherwise a random action with probability epsilon to ensure exploration.
        """
        # with probability epsilon return a random action to explore the environment
        if np.random.random() < self.epsilon:
            return self.env.action_space.sample()
        # with probability (1 - epsilon) act greedily (exploit)
        return int(np.argmax(self.q_values[obs]))

    def update(
        self,
        obs: int,
        action: int,
        reward: float,
        terminated: bool,
        next_obs: int,
    ):
        """Apply one Q-learning update and record its temporal-difference error.

        Q(s, a) += lr * (r + gamma * max_a' Q(s', a') - Q(s, a)), where the
        bootstrap term is zeroed when *terminated* is true.
        """
        future_q_value = (not terminated) * np.max(self.q_values[next_obs])
        temporal_difference = (
            reward + self.discount_factor * future_q_value - self.q_values[obs][action]
        )

        self.q_values[obs][action] = (
            self.q_values[obs][action] + self.lr * temporal_difference
        )
        self.training_error.append(temporal_difference)

    def decay_epsilon(self):
        """Linearly decrease epsilon by ``epsilon_decay``, never below ``final_epsilon``."""
        self.epsilon = max(self.final_epsilon, self.epsilon - self.epsilon_decay)


### Training

In [None]:
# --- Q-learning hyperparameters ---
learning_rate = 0.005
n_episodes = 100_000
start_epsilon = 1.0
final_epsilon = 0.1
# Linear schedule that would hit 0 after half the episodes; decay_epsilon()
# clamps at final_epsilon, so exploration bottoms out slightly earlier.
epsilon_decay = start_epsilon / (n_episodes / 2)

# --- Environment (wrapped to record per-episode returns and lengths) ---
env = BasicEnv()
env = gym.wrappers.RecordEpisodeStatistics(env, buffer_length=n_episodes)

# --- Agent ---
agent = QLearningAgent(
    env=env,
    learning_rate=learning_rate,
    initial_epsilon=start_epsilon,
    epsilon_decay=epsilon_decay,
    final_epsilon=final_epsilon,
)


In [None]:
from tqdm import tqdm

for episode in tqdm(range(n_episodes)):
    obs, info = env.reset()
    done = False

    # Play one episode, learning from each transition as soon as it happens.
    while not done:
        action = agent.get_action(obs)
        next_obs, reward, terminated, truncated, info = env.step(action)
        agent.update(obs, action, reward, terminated, next_obs)

        done = terminated or truncated
        obs = next_obs

    # Anneal exploration once per finished episode.
    agent.decay_epsilon()


### Visualizations

In [None]:
from matplotlib import pyplot as plt

def get_moving_avgs(arr, window, convolution_mode):
    """Return the *window*-length moving average of *arr* (flattened first).

    *convolution_mode* is forwarded to ``numpy.convolve`` ("valid", "same"
    or "full"); the convolution sum is divided by *window*.
    """
    flat = np.asarray(arr).flatten()
    kernel = np.ones(window)
    summed = np.convolve(flat, kernel, mode=convolution_mode)
    return summed / window

# Smooth over a 5000-episode window (the training error, which is logged per
# update step, is smoothed over 5000 steps).
rolling_length = 5000
fig, axs = plt.subplots(ncols=3, figsize=(12, 5))

axs[0].set_title("Episode rewards")
reward_moving_average = get_moving_avgs(
    env.return_queue,
    rolling_length,
    "valid"
)
axs[0].plot(range(len(reward_moving_average)), reward_moving_average)

axs[1].set_title("Episode lengths")
length_moving_average = get_moving_avgs(
    env.length_queue,
    rolling_length,
    "valid"
)
axs[1].plot(range(len(length_moving_average)), length_moving_average)

axs[2].set_title("Training Error")
# "valid" (like the other panels) avoids the artificial ramps toward zero that
# "same" mode's zero padding produces at both ends of the curve.
training_error_moving_average = get_moving_avgs(
    agent.training_error,
    rolling_length,
    "valid"
)
axs[2].plot(range(len(training_error_moving_average)), training_error_moving_average)
plt.tight_layout()
plt.show()


In [None]:
import numpy as np
import matplotlib.pyplot as plt

def visualize_policy(agent, env):
    """Scatter-plot the greedy action the agent's Q-table picks in each state.

    States without a Q-value row are shown as action 0 (the argmax of zeros).
    The agent's ``q_values`` are read without being modified.
    """
    n_actions = env.action_space.n
    policy = {}
    for state in range(env.observation_space.n):
        # .get() avoids the defaultdict inserting a zero row as a side effect
        # of plotting.
        q_row = agent.q_values.get(state)
        if q_row is None:
            q_row = np.zeros(n_actions)
        policy[state] = np.argmax(q_row)

    # Create a plot to visualize the policy
    plt.figure(figsize=(6, 4))
    for state, action in policy.items():
        plt.scatter(state, action, marker='o', s=100, label=f'State {state}: Action {action}')

    plt.xlabel("State")
    plt.ylabel("Action")
    plt.title("Optimal Policy")
    plt.xticks(range(env.observation_space.n))
    plt.yticks(range(n_actions))
    plt.legend()
    plt.grid(True)
    plt.show()

visualize_policy(agent, env)


## Deep Q Learning (for Hiring Env)

### Agent

In [None]:
import gymnasium as gym
import math
import random
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [None]:
# One environment step as stored in the replay buffer.
Transition = namedtuple("Transition", "state action next_state reward")


class ReplayMemory:
    """Bounded FIFO buffer of Transitions; the oldest entries are evicted first."""

    def __init__(self, capacity):
        self.memory = deque(maxlen=capacity)

    def __len__(self):
        return len(self.memory)

    def push(self, *args):
        """Save a transition"""
        transition = Transition(*args)
        self.memory.append(transition)

    def sample(self, batch_size):
        """Return ``batch_size`` distinct transitions drawn uniformly at random."""
        return random.sample(self.memory, batch_size)

In [None]:
class DQN(nn.Module):
    """Three-layer MLP mapping an observation vector to one Q-value per action.

    The input's last dimension is ``n_observations``; the output's last
    dimension is ``n_actions``. Used both for single-state action selection
    and for minibatches during optimisation.
    """

    def __init__(self, n_observations, n_actions):
        super().__init__()
        hidden = 128
        self.layer1 = nn.Linear(n_observations, hidden)
        self.layer2 = nn.Linear(hidden, hidden)
        self.layer3 = nn.Linear(hidden, n_actions)

    def forward(self, x):
        # Two ReLU hidden layers, then a linear head (raw Q-values, no activation).
        for hidden_layer in (self.layer1, self.layer2):
            x = F.relu(hidden_layer(x))
        return self.layer3(x)

### Training

In [None]:
import gymnasium as gym
import math
import random
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

env = HiringAGIEnv(
    team_size=5,
    s0=4,
    alpha=1,
)

# Take the team size from the env so the preprocessor and env cannot disagree.
preprocessor = DictObsPreprocessor(team_size=env.team_size)

# set up matplotlib: inline (notebook) backends redraw via IPython.display
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

plt.ion()

# Use the best available accelerator: CUDA, then Apple MPS, then CPU.
device = torch.device(
    "cuda" if torch.cuda.is_available() else
    "mps" if torch.backends.mps.is_available() else
    "cpu"
)

In [None]:
# --- DQN hyperparameters ------------------------------------------------------
BATCH_SIZE = 128   # transitions sampled from the replay buffer per optimisation step
GAMMA = 0.99       # discount factor
EPS_START = 0.9    # initial epsilon for epsilon-greedy exploration
EPS_END = 0.05     # value epsilon decays towards
EPS_DECAY = 1000   # time constant of the exponential epsilon decay (higher = slower)
TAU = 0.005        # soft-update rate of the target network
LR = 1e-4          # learning rate of the AdamW optimizer

# --- Network sizes, derived from the environment ------------------------------
n_actions = env.action_space.n
state, info = env.reset()
state = preprocessor(state)
n_observations = state.shape[-1]  # width of the flattened observation
print(n_actions)

# --- Networks, optimizer and replay buffer ------------------------------------
policy_net = DQN(n_observations, n_actions).to(device)
target_net = DQN(n_observations, n_actions).to(device)
target_net.load_state_dict(policy_net.state_dict())  # target starts as an exact copy

optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)
memory = ReplayMemory(10000)

steps_done = 0  # global step counter that drives epsilon decay in select_action


def select_action(state):
    """Epsilon-greedy action choice with exponentially decaying epsilon.

    Returns a ``[1, 1]`` long tensor holding the chosen action index and
    increments the global ``steps_done`` on every call.
    """
    global steps_done
    draw = random.random()
    decay = math.exp(-1. * steps_done / EPS_DECAY)
    eps_threshold = EPS_END + (EPS_START - EPS_END) * decay
    steps_done += 1

    if draw <= eps_threshold:
        # Explore: uniform random action from the env's action space.
        return torch.tensor([[env.action_space.sample()]], device=device, dtype=torch.long)

    # Exploit: index of the largest Q-value per row of policy_net's output.
    with torch.no_grad():
        return policy_net(state).max(1).indices.view(1, 1)


episode_durations = []  # length of each finished training episode


def plot_durations(show_result=False):
    """Plot episode durations and their 100-episode moving average.

    During training (``show_result=False``) the figure is cleared and redrawn
    in place; with ``show_result=True`` the final figure is titled 'Result'.
    """
    plt.figure(1)
    durations = torch.tensor(episode_durations, dtype=torch.float)
    if show_result:
        plt.title('Result')
    else:
        plt.clf()
        plt.title('Training...')
    plt.xlabel('Episode')
    plt.ylabel('Duration')
    plt.plot(durations.numpy())

    window = 100
    if len(durations) >= window:
        # unfold() yields every length-100 window; pad the front with zeros so
        # the averaged curve lines up with the episode axis.
        rolling = durations.unfold(0, window, 1).mean(1).view(-1)
        rolling = torch.cat((torch.zeros(window - 1), rolling))
        plt.plot(rolling.numpy())

    plt.pause(0.001)  # pause a bit so that plots are updated
    if is_ipython:
        display.display(plt.gcf())
        if not show_result:
            display.clear_output(wait=True)

In [None]:
def optimize_model():
    """Run one DQN optimisation step on a random minibatch from ``memory``.

    Does nothing until the buffer holds at least BATCH_SIZE transitions.
    Targets are r + GAMMA * max_a' Q_target(s', a'), with the bootstrap term
    zero for terminal transitions (stored with ``next_state=None``).
    """
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    # Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for
    # detailed explanation). This converts batch-array of Transitions
    # to Transition of batch-arrays.
    batch = Transition(*zip(*transitions))

    # Mask of transitions whose next state is non-terminal
    non_final_mask = torch.tensor(
        [s is not None for s in batch.next_state], device=device, dtype=torch.bool
    )
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
    # columns of actions taken.
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    # V(s_{t+1}) from the "older" target_net; 0 for terminal next states.
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    # torch.cat([]) raises, so only query the target net when the minibatch
    # contains at least one non-terminal next state.
    if non_final_mask.any():
        non_final_next_states = torch.cat(
            [s for s in batch.next_state if s is not None]
        )
        with torch.no_grad():
            next_state_values[non_final_mask] = target_net(non_final_next_states).max(1).values
    # Compute the expected Q values
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch
    # Rewards can arrive as float64 (the env computes them with NumPy floats);
    # match the network's dtype so the loss and its backward see one dtype.
    expected_state_action_values = expected_state_action_values.to(state_action_values.dtype)

    # Compute Huber loss
    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    # In-place gradient clipping
    torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
    optimizer.step()

### Training Loop

In [None]:
if torch.cuda.is_available() or torch.backends.mps.is_available():
    num_episodes = 20_000
else:
    num_episodes = 1000

# HiringAGIEnv has no step limit of its own, and an episode in which the agent
# only hires never terminates. Cap episode length and treat hitting the cap as
# a truncation: the transition keeps its next_state, so it still bootstraps.
MAX_EPISODE_STEPS = 500

for i_episode in range(num_episodes):
    # Initialize the environment and get its state
    state, info = env.reset()
    # preprocessor already returns a float tensor; .to() avoids the copy and
    # warning that torch.tensor(tensor) produces.
    state = preprocessor(state).to(device=device, dtype=torch.float32)
    for t in count():
        action = select_action(state)
        observation, reward, terminated, truncated, _ = env.step(action.item())
        # float32 to match the network's output (env rewards may be float64)
        reward = torch.tensor([reward], device=device, dtype=torch.float32)
        truncated = truncated or (t + 1 >= MAX_EPISODE_STEPS)
        done = terminated or truncated

        if terminated:
            next_state = None
        else:
            next_state = preprocessor(observation).to(device=device, dtype=torch.float32)

        # Store the transition in memory
        memory.push(state, action, next_state, reward)

        # Move to the next state
        state = next_state

        # Perform one step of the optimization (on the policy network)
        optimize_model()

        # Soft update of the target network's weights
        # θ′ ← τ θ + (1 −τ )θ′
        target_net_state_dict = target_net.state_dict()
        policy_net_state_dict = policy_net.state_dict()
        for key in policy_net_state_dict:
            target_net_state_dict[key] = policy_net_state_dict[key]*TAU + target_net_state_dict[key]*(1-TAU)
        target_net.load_state_dict(target_net_state_dict)

        if done:
            episode_durations.append(t + 1)
            # plot_durations()
            break

print('Complete')
plot_durations(show_result=True)
plt.ioff()
plt.show()

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import matplotlib.pyplot as plt
import numpy as np
import torch

def run_and_plot_episode(policy_net, env, preprocessor, device, max_steps=500):
    """Roll out one greedy episode of *policy_net* on *env* and plot it.

    At every step the observation is flattened by *preprocessor* and the
    argmax action of *policy_net* is taken. The team's mean skill, the
    candidate's skill, the total salary, the action and the phase are recorded
    before each step. The rollout is cut off after *max_steps* steps if it has
    not ended by then.

    Prints the total reward and the initial/final team skills, then shows a
    three-panel figure: skills, total salary, and phase coloured by action.
    """
    obs, _ = env.reset()
    done = False
    steps = 0

    avg_team_skill_trace = []
    candidate_skill_trace = []
    total_salary_trace = []
    action_trace = []
    phase_trace = []
    initial_team = obs["team_skills"]
    total_reward = 0
    while not done:
        # preprocessor returns a float tensor; .to() avoids torch.tensor(tensor)'s copy warning
        state = preprocessor(obs).to(device=device, dtype=torch.float32)
        if state.dim() == 1:
            state = state.unsqueeze(0)  # Add batch dimension if needed

        # Select the greedy action
        with torch.no_grad():
            action = policy_net(state).argmax(dim=1).item()

        # Log values (pre-step)
        avg_team_skill_trace.append(np.mean(obs["team_skills"]))
        candidate_skill_trace.append(obs["candidate"][0])
        total_salary_trace.append(obs["total_salary"])
        action_trace.append(action)
        phase_trace.append(obs["phase"])

        # Step environment
        obs, reward, done, _, _ = env.step(action)
        steps += 1
        total_reward += reward
        if not done and steps >= max_steps:
            print(f"Episode cut off at {max_steps} steps")
            break

    print(f"Total reward: {total_reward}")
    print(f"Final average skill: {np.mean(obs['team_skills'])}")
    print(f"Initial skills: {initial_team}")
    print(f"Final skills: {obs['team_skills']}")

    # --- Plot ---
    x = np.arange(len(action_trace))
    n_actions = env.action_space.n

    plt.figure(figsize=(12, 6))

    plt.subplot(3, 1, 1)
    plt.plot(x, avg_team_skill_trace, label="Avg Team Skill", linewidth=2)
    plt.plot(x, candidate_skill_trace, label="Candidate Skill", linewidth=2, linestyle='--')
    plt.ylabel("Skill")
    plt.title("Episode Rollout")
    plt.grid(True)
    plt.legend()

    plt.subplot(3, 1, 2)
    plt.plot(x, total_salary_trace, color="orange", label="Total Salary")
    plt.ylabel("Salary")
    plt.grid(True)
    plt.legend()

    plt.subplot(3, 1, 3)
    # Fix the colour scale to the full action range so a colour means the same
    # action in every rollout, even when some actions are never taken.
    scatter = plt.scatter(x, phase_trace, c=action_trace, cmap="tab10", s=60, vmin=0, vmax=n_actions - 1)
    plt.xlabel("Step")
    plt.ylabel("Phase")
    plt.grid(True)

    # Discrete colorbar
    cbar = plt.colorbar(scatter, ticks=range(n_actions))
    cbar.set_label("Action Index")
    cbar.ax.set_yticklabels([str(i) for i in range(n_actions)])

    plt.tight_layout()
    plt.show()


run_and_plot_episode(policy_net, env, preprocessor, device)