In [45]:
from __future__ import annotations

import argparse
import csv
import math
import os
import random
from abc import ABC, abstractmethod
from collections import deque
from datetime import datetime
from enum import Enum
from time import time, sleep
from typing import Callable, Union, Tuple, Optional, Dict

import joblib
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn.functional as F
import wandb
from tensordict import TensorDict
from torch import nn
from torch import optim
from torch.nn import Module
from torchrl.data import BoundedTensorSpec, CompositeSpec, UnboundedContinuousTensorSpec
from torchrl.envs import EnvBase


In [47]:
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding (Vaswani et al., 2017).

    Adds a precomputed sin/cos table to a batch-first input.

    Args:
        d_model: Feature size of the input; the table has the same width. Odd values
            are supported (the last column then only has a sine term).
        max_len: Longest sequence the table covers. Default: 5000.
    """

    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there are d_model // 2 odd columns but ceil(d_model / 2)
        # frequencies, so only the first d_model // 2 frequencies feed the cosines.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Stored as [max_len, 1, d_model]; the layout is kept for state_dict compatibility.
        pe = pe.unsqueeze(0).transpose(0, 1)
        # A buffer (not a parameter): follows .to(device) and is saved, but is not trained.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` plus the positional encoding.

        Args:
            x: Tensor shaped [batch_size, seq_len, d_model] with seq_len <= max_len.

        Returns:
            Tensor with the same shape as ``x``.
        """
        seq_len = x.size(1)
        # [seq_len, 1, d_model] -> [1, seq_len, d_model]; broadcasting over the batch
        # dimension avoids materialising one copy of the table per batch element.
        pe = self.pe[:seq_len].transpose(0, 1)
        return x + pe

In [19]:
class TransformerBlock(nn.Module):
    """Post-norm transformer encoder block with a 1x1-convolution feed-forward.

    Inputs and outputs use the sequence-first layout [seq_len, batch, d_model], the
    default layout of ``nn.MultiheadAttention``.

    Args:
        d_model: Embedding size.
        head_size: Accepted for interface compatibility; not used by this block.
        num_heads: Number of attention heads (must divide ``d_model``).
        ff_dim: Hidden width of the feed-forward sub-layer.
        dropout: Dropout inside the attention and on the feed-forward output.
    """

    def __init__(self, d_model, head_size, num_heads, ff_dim, dropout=0):
        super(TransformerBlock, self).__init__()
        # Sub-modules are created in a fixed order so seeded initialisation and
        # state_dict keys stay stable.
        self.attention = nn.MultiheadAttention(
            embed_dim=d_model, num_heads=num_heads, dropout=dropout
        )
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=ff_dim, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=ff_dim, out_channels=d_model, kernel_size=1)
        self.layernorm1 = nn.LayerNorm(d_model)
        self.layernorm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def _feed_forward(self, hidden):
        """Position-wise conv1 -> ReLU -> conv2 on a [L, N, E] tensor, returning [L, N, E]."""
        # Conv1d wants channels in the middle: [L, N, E] -> [N, E, L]
        channels_first = hidden.permute(1, 2, 0)
        widened = F.relu(self.conv1(channels_first))
        narrowed = self.conv2(widened)
        # [N, E, L] -> [L, N, E]
        return narrowed.permute(2, 0, 1)

    def forward(self, query, key, value):
        """Attention + residual + LayerNorm, then feed-forward + residual + LayerNorm.

        Args:
            query, key, value: Tensors shaped [seq_len, batch, d_model].

        Returns:
            Tensor shaped like ``query``.
        """
        attended = self.attention(query, key, value)[0]
        normed = self.layernorm1(query + attended)
        return self.layernorm2(normed + self.dropout(self._feed_forward(normed)))

In [20]:
class TransformerModel(nn.Module):
    """Transformer regressor for univariate sequences.

    Input is batch-first [batch, seq_len, 1] with seq_len <= ``num_input_samples``;
    output is [batch, 1, 1].

    Args:
        num_input_samples: Maximum sequence length (size of the positional table).
        d_model: Model width.
        head_size: Forwarded to ``TransformerBlock``.
        num_heads: Attention heads per block.
        ff_dim: Feed-forward hidden width per block.
        dropout: Dropout rate used in each block.
        num_transformers: Number of stacked blocks.
    """

    def __init__(self, num_input_samples, d_model, head_size, num_heads, ff_dim, dropout=0, num_transformers=10):
        super(TransformerModel, self).__init__()
        self.d_model = d_model
        self.input_projection = nn.Linear(1, d_model)
        self.pos_encoding = PositionalEncoding(d_model, num_input_samples)
        self.transformers = nn.ModuleList([
            TransformerBlock(d_model, head_size, num_heads, ff_dim, dropout)
            for _ in range(num_transformers)
        ])
        # Not used by forward (which pools with .mean); kept so attribute access keeps working.
        self.global_avg_pooling = nn.AdaptiveAvgPool1d(1)
        self.output_layer = nn.Linear(d_model, 1)

    def forward(self, x):
        """Predict one value per sequence from a [batch, seq_len, 1] input."""
        x = self.input_projection(x)  # [B, L, 1] -> [B, L, d_model]
        # PositionalEncoding.forward already returns x + pe; assigning (rather than
        # adding to x again) avoids counting the input twice.
        x = self.pos_encoding(x)

        # TransformerBlock uses nn.MultiheadAttention in its default sequence-first
        # layout, so attend over the sequence axis, not the batch axis.
        x = x.transpose(0, 1)  # [L, B, d_model]
        for transformer in self.transformers:
            x = transformer(x, x, x)
        x = x.transpose(0, 1)  # back to [B, L, d_model]

        x = x.mean(dim=1, keepdim=True)  # average over the sequence -> [B, 1, d_model]
        return self.output_layer(x)  # [B, 1, 1]

In [46]:
class ObservationEmbeddingRepresentation(nn.Module):
    """Apply an embedding module independently to every step of a sequence.

    The wrapped module receives the input with batch and sequence dims merged, and
    its output is reshaped back to [batch, seq, embed].

    Args:
        observation_embedding: Module mapping [batch * seq, ...] to [batch * seq, embed].
    """

    def __init__(
        self,
        observation_embedding: nn.Module,
    ):
        super().__init__()
        self.observation_embedding = observation_embedding

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        """Embed ``obs`` of shape [batch, seq, ...] into [batch, seq, embed]."""
        batch, seq = obs.size(0), obs.size(1)
        # Merge batch and seq so the wrapped module sees an ordinary batch.
        flat_obs = torch.flatten(obs, start_dim=0, end_dim=1)
        obs_embed = self.observation_embedding(flat_obs)
        return obs_embed.reshape(batch, seq, obs_embed.size(-1))

    @staticmethod
    def make_action_representation(
        num_actions: int,
        action_dim: int,
    ) -> "ObservationEmbeddingRepresentation":
        """Embedding for integer actions: [batch, seq, 1] -> [batch, seq, action_dim].

        Args:
            num_actions: Number of distinct action indices.
            action_dim:  Size of each action's embedding vector.
        """
        # The string annotation avoids a NameError while the class is still being defined.
        embed = nn.Sequential(
            nn.Embedding(num_actions, action_dim), nn.Flatten(start_dim=-2)
        )
        return ObservationEmbeddingRepresentation(observation_embedding=embed)

    @staticmethod
    def make_continuous_representation(
        obs_dim: int, outer_embed_size: int
    ) -> "ObservationEmbeddingRepresentation":
        """
        For use in continuous observation environments. Projects the observation to the
            specified dimensionality for use in the network.

        Args:
            obs_dim:          The length of the observation vector (assuming 1d)
            outer_embed_size: The length of the resulting embedding vector
        """
        embedding = nn.Linear(obs_dim, outer_embed_size)
        return ObservationEmbeddingRepresentation(observation_embedding=embedding)

In [12]:
def compute_flattened_size(
    height: int, width: int, kernels: list, paddings: list, strides: list
) -> int:
    """Return ``height * width`` after passing a feature map through a conv stack.

    Layer ``i`` uses ``kernels[i]``, ``paddings[i]`` and ``strides[i]`` on both spatial
    dims (via ``update_size``); ``kernels`` determines the number of layers.
    """
    h, w = height, width
    for layer, kernel in enumerate(kernels):
        pad, step = paddings[layer], strides[layer]
        h = update_size(h, kernel, pad, step)
        w = update_size(w, kernel, pad, step)
    return int(h * w)

In [13]:
def update_size(component: int, kernel: int, padding: int, stride: int) -> int:
    """Output length of one spatial dimension after a conv/pool layer (no dilation).

    Computes ``floor((component - kernel + 2 * padding) / stride) + 1``.
    """
    span = component - kernel + 2 * padding
    steps = math.floor(span / stride)
    return steps + 1

In [14]:
class ActionEmbeddingRepresentation(nn.Module):
    """Learned embedding for discrete actions.

    An input whose trailing dimension has size k becomes k * action_dim features
    (e.g. [batch, seq, 1] -> [batch, seq, action_dim]).

    Args:
        num_actions: Number of distinct action indices.
        action_dim: Size of each action's embedding vector.
    """

    def __init__(self, num_actions: int, action_dim: int):
        super().__init__()
        lookup = nn.Embedding(num_actions, action_dim)
        # Merge the trailing [..., k, action_dim] into [..., k * action_dim].
        merge_last_two = nn.Flatten(start_dim=-2)
        self.embedding = nn.Sequential(lookup, merge_last_two)

    def forward(self, action: torch.Tensor):
        """Look up and flatten the embeddings for integer ``action`` indices."""
        embedded = self.embedding(action)
        return embedded

In [17]:
class DqnAgent:
    """Double-DQN agent that acts on a single observation from its context.

    Holds a policy network, a target network that is hard-updated every
    ``target_update_frequency`` training steps, an Adam optimizer, a replay buffer and
    separate train/eval contexts.

    Args:
        network_factory: Zero-argument callable returning a fresh Q-network. It is
            called twice (policy and target). This class does not move the networks
            to ``device``; the factory is expected to do so.
        buffer_size: Replay buffer capacity.
        device: Device that observations/batches are moved to.
        env_obs_length: Length of a single observation.
        max_env_steps: Maximum episode length (forwarded to the replay buffer).
        obs_mask: Value forwarded to the replay buffer and contexts (presumably the
            padding value for empty observation slots).
        num_actions: Number of discrete actions.
        is_discrete_env: If True, observations are handled as integers
            (np.int_ / torch.long), otherwise as float32.
        learning_rate: Adam learning rate.
        batch_size: Minibatch size used by ``train``.
        context_len: Number of observations kept in the context.
        gamma: Discount factor.
        grad_norm_clip: Maximum gradient L2 norm; a non-finite norm raises.
        target_update_frequency: Training steps between hard target updates.
        **kwargs: Ignored.
    """

    def __init__(
        self,
        network_factory: Callable[[], nn.Module],
        buffer_size: int,
        device: torch.device,
        env_obs_length: int,
        max_env_steps: int,
        obs_mask: Union[int, float],
        num_actions: int,
        is_discrete_env: bool,
        learning_rate: float = 0.0003,
        batch_size: int = 32,
        context_len: int = 1,
        gamma: float = 0.99,
        grad_norm_clip: float = 1.0,
        target_update_frequency: int = 10_000,
        **kwargs,
    ):
        self.context_len = context_len
        self.env_obs_length = env_obs_length
        # Initialize networks
        self.policy_network = network_factory()
        self.target_network = network_factory()
        # Start with identical parameters in both networks
        self.target_update()
        self.target_network.eval()

        # Discrete environments can store observations as integers instead of floats
        if is_discrete_env:
            self.obs_context_type = np.int_
            self.obs_tensor_type = torch.long
        else:
            self.obs_context_type = np.float32
            self.obs_tensor_type = torch.float32

        # PyTorch config
        self.device = device

        # `optim` is not imported in this notebook; use the fully qualified name.
        self.optimizer = torch.optim.Adam(
            self.policy_network.parameters(), lr=learning_rate
        )

        self.replay_buffer = ReplayBuffer(
            buffer_size,
            env_obs_length=env_obs_length,
            obs_mask=obs_mask,
            max_episode_steps=max_env_steps,
            context_len=context_len,
        )

        # Hyperparameters
        self.batch_size = batch_size
        self.gamma = gamma
        self.grad_norm_clip = grad_norm_clip
        self.target_update_frequency = target_update_frequency

        # Logging: running averages over the last 100 values
        self.num_train_steps = 0
        self.td_errors = RunningAverage(100)
        self.grad_norms = RunningAverage(100)
        self.qvalue_max = RunningAverage(100)
        self.target_max = RunningAverage(100)
        self.qvalue_mean = RunningAverage(100)
        self.target_mean = RunningAverage(100)
        self.qvalue_min = RunningAverage(100)
        self.target_min = RunningAverage(100)

        self.num_actions = num_actions
        self.train_mode = TrainMode.TRAIN
        self.obs_mask = obs_mask

        self.train_context = Context(
            context_len, obs_mask, self.num_actions, env_obs_length
        )
        self.eval_context = Context(
            context_len, obs_mask, self.num_actions, env_obs_length
        )

    @property
    def context(self) -> "Context":
        """The context for the current mode (train or eval)."""
        if self.train_mode == TrainMode.TRAIN:
            return self.train_context
        elif self.train_mode == TrainMode.EVAL:
            return self.eval_context

    def eval_on(self) -> None:
        """Switch to evaluation mode (eval context, policy network in eval())."""
        self.train_mode = TrainMode.EVAL
        self.policy_network.eval()

    def eval_off(self) -> None:
        """Switch to training mode (train context, policy network in train())."""
        self.train_mode = TrainMode.TRAIN
        self.policy_network.train()

    @torch.no_grad()
    def get_action(self, epsilon=0.0) -> int:
        """Use policy_network to get an e-greedy action given the current obs."""
        if RNG.rng.random() < epsilon:
            # Cast the NumPy integer so both branches return a plain Python int.
            return int(RNG.rng.integers(self.num_actions))
        q_values = self.policy_network(
            torch.as_tensor(
                self.context.obs[min(self.context.timestep, self.context_len - 1)],
                dtype=self.obs_tensor_type,
                device=self.device,
            )
            .unsqueeze(0)
            .unsqueeze(0)
        )
        return torch.argmax(q_values).item()

    def observe(self, obs, action, reward, done) -> None:
        """Store the transition in the replay buffer (train mode only)."""
        if self.train_mode == TrainMode.TRAIN:
            self.replay_buffer.store(obs, action, reward, done)

    def context_reset(self, obs: np.ndarray) -> None:
        """Reset the current context with ``obs``; in train mode also store it."""
        self.context.reset(obs)
        if self.train_mode == TrainMode.TRAIN:
            self.replay_buffer.store_obs(obs)

    def train(self) -> None:
        """Perform one gradient step of the network"""
        if not self.replay_buffer.can_sample(self.batch_size):
            return

        self.eval_off()
        obss, actions, rewards, next_obss, _, dones, _ = self.replay_buffer.sample(
            self.batch_size
        )

        # We pull obss/next_obss as [batch-size x 1 x obs-dim]
        obss = torch.as_tensor(obss, dtype=self.obs_tensor_type, device=self.device)
        next_obss = torch.as_tensor(
            next_obss, dtype=self.obs_tensor_type, device=self.device
        )
        # Actions: [batch-size x 1 x 1], used directly as the gather index on dim 2
        actions = torch.as_tensor(actions, dtype=torch.long, device=self.device)
        # Rewards/Dones are [batch-size x 1 x 1] which we want to be [batch-size]
        rewards = torch.as_tensor(
            rewards, dtype=torch.float32, device=self.device
        ).squeeze()
        dones = torch.as_tensor(dones, dtype=torch.long, device=self.device).squeeze()

        # q_values is [batch-size x 1 x n-actions]; gather + squeeze -> [batch-size]
        q_values = self.policy_network(obss)
        q_values = q_values.gather(2, actions).squeeze()

        with torch.no_grad():
            # Double DQN: the policy network picks the next action, the target
            # network evaluates it.
            next_obs_qs = self.policy_network(next_obss)
            argmax = torch.argmax(next_obs_qs, axis=-1).unsqueeze(-1)
            next_obs_q_values = (
                self.target_network(next_obss).gather(2, argmax).squeeze()
            )

            # Bellman target; terminal transitions do not bootstrap
            targets = rewards + (1 - dones) * (next_obs_q_values * self.gamma)

        self.qvalue_max.add(q_values.max().item())
        self.qvalue_mean.add(q_values.mean().item())
        self.qvalue_min.add(q_values.min().item())

        self.target_max.add(targets.max().item())
        self.target_mean.add(targets.mean().item())
        self.target_min.add(targets.min().item())

        # Optimization step
        loss = F.mse_loss(q_values, targets)
        self.td_errors.add(loss.item())
        self.optimizer.zero_grad(set_to_none=True)
        loss.backward()
        norm = torch.nn.utils.clip_grad_norm_(
            self.policy_network.parameters(),
            self.grad_norm_clip,
            error_if_nonfinite=True,
        )
        self.grad_norms.add(norm.item())
        self.optimizer.step()
        self.num_train_steps += 1

        if self.num_train_steps % self.target_update_frequency == 0:
            self.target_update()

    def target_update(self) -> None:
        """Hard update where we copy the network parameters from the policy network to the target network"""
        self.target_network.load_state_dict(self.policy_network.state_dict())

    def save_mini_checkpoint(self, checkpoint_dir: str, wandb_id: str) -> None:
        """Save only the training step and wandb run id."""
        torch.save(
            {"step": self.num_train_steps, "wandb_id": wandb_id},
            checkpoint_dir + "_mini_checkpoint.pt",
        )

    @staticmethod
    def load_mini_checkpoint(checkpoint_dir: str) -> dict:
        """Load the dict written by ``save_mini_checkpoint``."""
        return torch.load(checkpoint_dir + "_mini_checkpoint.pt")

    def save_checkpoint(
        self,
        checkpoint_dir: str,
        wandb_id: str,
        episode_successes: "RunningAverage",
        episode_rewards: "RunningAverage",
        episode_lengths: "RunningAverage",
        eps: "LinearAnneal",
    ) -> None:
        """Save networks, optimizer, stats, RNG states and the replay buffer contents.

        ``checkpoint_dir`` is used as a filename prefix, not as a directory.
        """
        self.save_mini_checkpoint(checkpoint_dir=checkpoint_dir, wandb_id=wandb_id)
        torch.save(
            {
                "step": self.num_train_steps,
                "wandb_id": wandb_id,
                # Replay Buffer: Don't keep the observation index saved
                "replay_buffer_pos": [self.replay_buffer.pos[0], 0],
                # Neural Net
                "policy_net_state_dict": self.policy_network.state_dict(),
                "target_net_state_dict": self.target_network.state_dict(),
                "optimizer_state_dict": self.optimizer.state_dict(),
                "epsilon": eps.val,
                # Results
                "episode_successes": episode_successes,
                "episode_rewards": episode_rewards,
                "episode_lengths": episode_lengths,
                # Losses
                "td_errors": self.td_errors,
                "grad_norms": self.grad_norms,
                "qvalue_max": self.qvalue_max,
                "qvalue_mean": self.qvalue_mean,
                "qvalue_min": self.qvalue_min,
                "target_max": self.target_max,
                "target_mean": self.target_mean,
                "target_min": self.target_min,
                # RNG states
                "random_rng_state": random.getstate(),
                "rng_bit_generator_state": RNG.rng.bit_generator.state,
                "numpy_rng_state": np.random.get_state(),
                "torch_rng_state": torch.get_rng_state(),
                # NOTE(review): without CUDA the CPU state is stored under this key.
                "torch_cuda_rng_state": torch.cuda.get_rng_state()
                if torch.cuda.is_available()
                else torch.get_rng_state(),
            },
            checkpoint_dir + "_checkpoint.pt",
        )
        joblib.dump(self.replay_buffer.obss, checkpoint_dir + "buffer_obss.sav")
        joblib.dump(self.replay_buffer.actions, checkpoint_dir + "buffer_actions.sav")
        joblib.dump(self.replay_buffer.rewards, checkpoint_dir + "buffer_rewards.sav")
        joblib.dump(self.replay_buffer.dones, checkpoint_dir + "buffer_dones.sav")
        joblib.dump(
            self.replay_buffer.episode_lengths, checkpoint_dir + "buffer_eplens.sav"
        )

    def load_checkpoint(
        self, checkpoint_dir: str
    ) -> Tuple[str, "RunningAverage", "RunningAverage", "RunningAverage", float]:
        """Restore everything written by ``save_checkpoint``.

        Returns:
            (wandb_id, episode_successes, episode_rewards, episode_lengths, epsilon)
        """
        # The checkpoint holds arbitrary Python objects (running averages, RNG states),
        # so it cannot be loaded with weights_only=True (the default since PyTorch 2.6).
        # This unpickles the file: only load checkpoints you produced yourself.
        checkpoint = torch.load(checkpoint_dir + "_checkpoint.pt", weights_only=False)

        self.num_train_steps = checkpoint["step"]
        # Replay Buffer
        self.replay_buffer.pos = checkpoint["replay_buffer_pos"]
        self.replay_buffer.obss = joblib.load(checkpoint_dir + "buffer_obss.sav")
        self.replay_buffer.actions = joblib.load(checkpoint_dir + "buffer_actions.sav")
        self.replay_buffer.rewards = joblib.load(checkpoint_dir + "buffer_rewards.sav")
        self.replay_buffer.dones = joblib.load(checkpoint_dir + "buffer_dones.sav")
        self.replay_buffer.episode_lengths = joblib.load(
            checkpoint_dir + "buffer_eplens.sav"
        )
        # Neural Net
        self.policy_network.load_state_dict(checkpoint["policy_net_state_dict"])
        self.target_network.load_state_dict(checkpoint["target_net_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        # Losses
        self.td_errors = checkpoint["td_errors"]
        self.grad_norms = checkpoint["grad_norms"]
        self.qvalue_max = checkpoint["qvalue_max"]
        self.qvalue_mean = checkpoint["qvalue_mean"]
        self.qvalue_min = checkpoint["qvalue_min"]
        self.target_max = checkpoint["target_max"]
        self.target_mean = checkpoint["target_mean"]
        self.target_min = checkpoint["target_min"]
        # RNG states
        random.setstate(checkpoint["random_rng_state"])
        RNG.rng.bit_generator.state = checkpoint["rng_bit_generator_state"]
        np.random.set_state(checkpoint["numpy_rng_state"])
        torch.set_rng_state(checkpoint["torch_rng_state"])
        if torch.cuda.is_available():
            # NOTE(review): a checkpoint saved without CUDA stores the CPU state here.
            torch.cuda.set_rng_state(checkpoint["torch_cuda_rng_state"])

        # Results
        episode_successes = checkpoint["episode_successes"]
        episode_rewards = checkpoint["episode_rewards"]
        episode_lengths = checkpoint["episode_lengths"]
        # Exploration value
        epsilon = checkpoint["epsilon"]

        return (
            checkpoint["wandb_id"],
            episode_successes,
            episode_rewards,
            episode_lengths,
            epsilon,
        )

In [48]:
class DTQN(nn.Module):
    """Deep Transformer Q-Network for partially observable reinforcement learning.

    Args:
        obs_dim:            The length of the observation vector.
        num_actions:        The number of possible environments actions.
        embed_per_obs_dim:  Used for discrete observation space. Length of the embed for each
            element in the observation dimension. NOTE(review): not read by this class.
        action_dim:         The number of features to give the action (`0` disables the
            action embedding).
        inner_embed_size:   The dimensionality of the network. Referred to as d_k by the
            original transformer.
        num_heads:          The number of heads to use in the MultiHeadAttention.
        num_layers:         The number of transformer blocks to use.
        history_len:        The maximum number of observations to take in.
        dropout:            Dropout percentage. Default: `0.0`
        identity:           Whether or not to use identity map reordering. Default: `False`
        pos:                The kind of position encodings to use. `0` uses no position encodings, `1` uses
            learned position encodings, and `sin` uses sinusoidal encodings. Default: `1`.
            NOTE(review): not read by `__init__` in this version.
        discrete:           Whether or not the environment has discrete observations. Default: `False`.
            Stored only; the observation embedding is always the linear (continuous) one.
        vocab_sizes:        If discrete env only. Represents the number of observations in the
            environment. NOTE(review): not read by this class.
        bag_size:           Number of observations in the persistent-memory bag; `0` disables
            the bag. Default: `0`
        **kwargs:           Ignored. NOTE(review): a `gate` option (`res`/`gru`) is documented
            upstream but lands here; the gates come from `attn_gate`/`mlp_gate`, which are
            not defined in this class — TODO wire them up.
    """

    def __init__(
        self,
        obs_dim: int,
        num_actions: int,
        embed_per_obs_dim: int,
        action_dim: int,
        inner_embed_size: int,
        num_heads: int,
        num_layers: int,
        history_len: int,
        dropout: float = 0.0,
        identity: bool = False,
        pos: Union[str, int] = 1,
        discrete: bool = False,
        vocab_sizes: Optional[Union[np.ndarray, int]] = None,
        bag_size: int = 0,
        **kwargs,
    ):
        super().__init__()
        self.obs_dim = obs_dim
        self.discrete = discrete
        # Input Embedding: the model width is split between action and observation features
        obs_output_dim = inner_embed_size - action_dim
        if action_dim > 0:
            self.action_embedding = ActionEmbeddingRepresentation(
                num_actions=num_actions, action_dim=action_dim
            )
        else:
            self.action_embedding = None
        self.obs_embedding = (
            ObservationEmbeddingRepresentation.make_continuous_representation(
                obs_dim=obs_dim, outer_embed_size=obs_output_dim
            )
        )

        # NOTE(review): this stores `PositionEncoding` itself (no instance is built here),
        # and forward() calls it with no arguments every pass. Presumably an instance
        # built from `pos`, `history_len` and `inner_embed_size` was intended — TODO confirm.
        pos_function_map = PositionEncoding
        self.position_embedding = pos_function_map

        self.dropout = nn.Dropout(dropout)

        if identity:
            transformer_block = TransformerIdentityLayer
        else:
            transformer_block = TransformerLayer
        # NOTE(review): `attn_gate` and `mlp_gate` are not defined in this class.
        self.transformer_layers = nn.Sequential(
            *[
                transformer_block(
                    num_heads,
                    inner_embed_size,
                    history_len,
                    dropout,
                    attn_gate,
                    mlp_gate,
                )
                for _ in range(num_layers)
            ]
        )

        self.bag_size = bag_size
        # Attention weights of the most recent bag-attention call (None until then).
        self.bag_attn_weights = None
        if bag_size > 0:
            self.bag_attention = nn.MultiheadAttention(
                inner_embed_size,
                num_heads,
                dropout=dropout,
                batch_first=True,
            )
            # Input is working memory concatenated with persistent memory -> 2x width
            self.ffn = nn.Sequential(
                nn.Linear(inner_embed_size * 2, inner_embed_size),
                nn.ReLU(),
                nn.Linear(inner_embed_size, num_actions),
            )
        else:
            self.ffn = nn.Sequential(
                nn.Linear(inner_embed_size, inner_embed_size),
                nn.ReLU(),
                nn.Linear(inner_embed_size, num_actions),
            )

        self.history_len = history_len
        # NOTE(review): `torch_utils` is not imported in the visible notebook.
        self.apply(torch_utils.init_weights)

    def forward(
        self,
        obss: torch.Tensor,
        actions: Optional[torch.Tensor] = None,
        bag_obss: Optional[torch.Tensor] = None,
        bag_actions: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        obss    is  batch x seq_len  x obs_dim
        actions is  batch x seq_len  x       1   (required when action_dim > 0)
        bag     is  batch x bag_size x obs_dim   (used only when bag_size > 0)

        Returns Q-values of shape batch x seq_len x num_actions.
        """
        history_len = obss.size(1)
        assert (
            history_len <= self.history_len
        ), "Cannot forward, history is longer than expected."

        # If the observations are images, obs_dim is the dimensions of the image
        obs_dim = obss.size()[2:] if len(obss.size()) > 3 else obss.size(2)
        assert (
            obs_dim == self.obs_dim
        ), f"Obs dim is incorrect. Expected {self.obs_dim} got {obs_dim}"

        token_embeddings = self.obs_embedding(obss)

        # Just to keep shapes correct if we choose to disable including actions
        if self.action_embedding is not None:
            # [batch x seq_len x 1] -> [batch x seq_len x action_embed]
            action_embed = self.action_embedding(actions)

            if history_len > 1:
                # Shift so each observation is paired with the action taken before it
                action_embed = torch.roll(action_embed, 1, 1)
                # First observation in the sequence doesn't have a previous action, so zero the features
                action_embed[:, 0, :] = 0.0
            token_embeddings = torch.concat([action_embed, token_embeddings], dim=-1)

        # [batch x seq_len x model_embed] -> [batch x seq_len x model_embed]
        working_memory = self.transformer_layers(
            self.dropout(
                token_embeddings + self.position_embedding()[:, :history_len, :]
            )
        )

        if self.bag_size > 0:
            # [batch x bag_size x action_embed] + [batch x bag_size x obs_embed] -> [batch x bag_size x model_embed]
            if self.action_embedding is not None:
                bag_embeddings = torch.concat(
                    [self.action_embedding(bag_actions), self.obs_embedding(bag_obss)],
                    dim=-1,
                )
            else:
                bag_embeddings = self.obs_embedding(bag_obss)
            # [batch x seq_len x model_embed] x [batch x bag_size x model_embed] -> [batch x seq_len x model_embed]
            persistent_memory, self.bag_attn_weights = self.bag_attention(
                working_memory, bag_embeddings, bag_embeddings
            )
            # Backward-compatible alias: the weights used to be stored only under this name.
            self.attn_weights = self.bag_attn_weights
            output = self.ffn(torch.concat([working_memory, persistent_memory], dim=-1))
        else:
            output = self.ffn(working_memory)

        return output[:, -history_len:, :]

In [49]:
class TrainMode(Enum):
    """Which phase an agent is in.

    Agents switch between their train and eval contexts based on this value, and only
    write transitions to the replay buffer while in TRAIN.
    """

    TRAIN = 1  # learning: train context/bag, transitions are stored
    EVAL = 2  # evaluation: eval context/bag, nothing is stored

In [18]:
class DtqnAgent(DqnAgent):
    """Agent for sequence Q-networks such as DTQN.

    The network sees a history of observations and actions and produces Q-values for
    every timestep. Optionally, it also sees a "bag" of older observations that were
    evicted from the context.

    Extra args (the rest are as in ``DqnAgent``):
        history: Number of most recent timesteps per sampled sequence that contribute
            to the loss. ``0`` keeps all of them, since ``x[:, -0:]`` is the full slice.
        bag_size: Capacity of the bag; ``0`` disables it.
    """

    # noinspection PyTypeChecker
    def __init__(
        self,
        network_factory: Callable[[], nn.Module],
        buffer_size: int,
        device: torch.device,
        env_obs_length: int,
        max_env_steps: int,
        obs_mask: Union[int, float],
        num_actions: int,
        is_discrete_env: bool,
        learning_rate: float = 0.0003,
        batch_size: int = 32,
        context_len: int = 50,
        gamma: float = 0.99,
        grad_norm_clip: float = 1.0,
        target_update_frequency: int = 10_000,
        history: int = 50,
        bag_size: int = 0,
        **kwargs,
    ):
        super().__init__(
            network_factory,
            buffer_size,
            device,
            env_obs_length,
            max_env_steps,
            obs_mask,
            num_actions,
            is_discrete_env,
            learning_rate,
            batch_size,
            context_len,
            gamma,
            grad_norm_clip,
            target_update_frequency,
        )
        self.history = history
        self.train_context = Context(
            context_len,
            obs_mask,
            num_actions,
            env_obs_length,
        )
        self.eval_context = Context(
            context_len,
            obs_mask,
            num_actions,
            env_obs_length,
        )
        self.train_bag = Bag(bag_size, obs_mask, env_obs_length)
        self.eval_bag = Bag(bag_size, obs_mask, env_obs_length)

    @property
    def bag(self) -> "Bag":
        """The bag for the current mode (train or eval)."""
        if self.train_mode == TrainMode.TRAIN:
            return self.train_bag
        elif self.train_mode == TrainMode.EVAL:
            return self.eval_bag

    @torch.no_grad()
    def get_action(self, epsilon: float = 0.0) -> int:
        """Epsilon-greedy action from the Q-values of the latest context timestep."""
        if RNG.rng.random() < epsilon:
            # Cast the NumPy integer so both branches return a plain Python int.
            return int(RNG.rng.integers(self.num_actions))
        # Truncate the context of observations and actions to remove padding if it exists
        visible = min(self.context.max_length, self.context.timestep + 1)
        context_obs_tensor = torch.as_tensor(
            self.context.obs[:visible],
            dtype=self.obs_tensor_type,
            device=self.device,
        ).unsqueeze(0)
        context_action_tensor = torch.as_tensor(
            self.context.action[:visible],
            dtype=torch.long,
            device=self.device,
        ).unsqueeze(0)
        # Always include the full bag, even if it has padding TODO:
        bag_obs_tensor = torch.as_tensor(
            self.bag.obss, dtype=self.obs_tensor_type, device=self.device
        ).unsqueeze(0)
        bag_action_tensor = torch.as_tensor(
            self.bag.actions, dtype=torch.long, device=self.device
        ).unsqueeze(0)

        q_values = self.policy_network(
            context_obs_tensor, context_action_tensor, bag_obs_tensor, bag_action_tensor
        )

        # Greedy action for the last timestep
        return torch.argmax(q_values[:, -1, :]).item()

    def context_reset(self, obs: np.ndarray) -> None:
        """Reset context (and bag, if enabled); in train mode also store ``obs``."""
        self.context.reset(obs)
        if self.train_mode == TrainMode.TRAIN:
            self.replay_buffer.store_obs(obs)
        if self.bag.size > 0:
            self.bag.reset()

    def observe(self, obs: np.ndarray, action: int, reward: float, done: bool) -> None:
        """Add an observation to the context. If the context would evict an observation to make room,
        attempt to put the observation in the bag, which may require evicting something else from the bag.

        If we're in train mode, then we also add the transition to our replay buffer."""
        evicted_obs, evicted_action = self.context.add_transition(
            obs, action, reward, done
        )
        # If there is an evicted obs, we need to decide if it should go in the bag or not
        if self.bag.size > 0 and evicted_obs is not None:
            # A falsy return from Bag.add means the bag is already full
            if not self.bag.add(evicted_obs, evicted_action):
                self._choose_full_bag(evicted_obs, evicted_action)

        if self.train_mode == TrainMode.TRAIN:
            self.replay_buffer.store(obs, action, reward, done, self.context.timestep)

    @torch.no_grad()
    def _choose_full_bag(self, evicted_obs, evicted_action) -> None:
        """Pick the bag variant (evicted obs swapped into slot i, or bag unchanged)
        with the highest mean-over-time max Q-value. Inference only, so no autograd."""
        num_candidates = self.bag.size + 1
        possible_bag_obss = np.tile(self.bag.obss, (num_candidates, 1, 1))
        possible_bag_actions = np.tile(self.bag.actions, (num_candidates, 1, 1))
        # Candidate i < bag.size replaces slot i; the last candidate is the current bag.
        for i in range(self.bag.size):
            possible_bag_obss[i, i] = evicted_obs
            possible_bag_actions[i, i] = evicted_action
        tiled_context = np.tile(self.context.obs, (num_candidates, 1, 1))
        tiled_actions = np.tile(self.context.action, (num_candidates, 1, 1))
        q_values = self.policy_network(
            torch.as_tensor(
                tiled_context, dtype=self.obs_tensor_type, device=self.device
            ),
            torch.as_tensor(tiled_actions, dtype=torch.long, device=self.device),
            torch.as_tensor(
                possible_bag_obss,
                dtype=self.obs_tensor_type,
                device=self.device,
            ),
            torch.as_tensor(
                possible_bag_actions, dtype=torch.long, device=self.device
            ),
        )

        # [candidates x seq x actions] -> max over actions -> mean over seq -> best candidate.
        # Convert to a Python int before indexing the NumPy arrays.
        bag_idx = int(torch.argmax(torch.mean(torch.max(q_values, 2)[0], 1)).item())
        self.bag.obss = possible_bag_obss[bag_idx]
        self.bag.actions = possible_bag_actions[bag_idx]

    def train(self) -> None:
        """Perform one Double-DQN gradient step on a batch of sampled sequences."""
        if not self.replay_buffer.can_sample(self.batch_size):
            return
        self.eval_off()
        if self.bag.size > 0:
            (
                obss,
                actions,
                rewards,
                next_obss,
                next_actions,
                dones,
                episode_lengths,
                bag_obss,
                bag_actions,
            ) = self.replay_buffer.sample_with_bag(self.batch_size, self.bag)
            # Bags: [batch-size x bag-size x obs-dim]
            bag_obss = torch.as_tensor(
                bag_obss, dtype=self.obs_tensor_type, device=self.device
            )
            bag_actions = torch.as_tensor(
                bag_actions, dtype=torch.long, device=self.device
            )
        else:
            (
                obss,
                actions,
                rewards,
                next_obss,
                next_actions,
                dones,
                episode_lengths,
            ) = self.replay_buffer.sample(self.batch_size)
            bag_obss = None
            bag_actions = None

        # Obss and Next obss: [batch-size x hist-len x obs-dim]
        obss = torch.as_tensor(obss, dtype=self.obs_tensor_type, device=self.device)
        next_obss = torch.as_tensor(
            next_obss, dtype=self.obs_tensor_type, device=self.device
        )
        # Actions: [batch-size x hist-len x 1]
        actions = torch.as_tensor(actions, dtype=torch.long, device=self.device)
        next_actions = torch.as_tensor(
            next_actions, dtype=torch.long, device=self.device
        )
        # Rewards: [batch-size x hist-len x 1]
        rewards = torch.as_tensor(rewards, dtype=torch.float32, device=self.device)
        # Dones: [batch-size x hist-len x 1]
        dones = torch.as_tensor(dones, dtype=torch.long, device=self.device)

        # [batch-size x hist-len x n-actions] -> gather -> [batch-size x hist-len x 1]
        # -> [batch-size x hist-len]. squeeze(-1) (rather than squeeze()) keeps the
        # batch and history dims when batch_size == 1 or hist-len == 1, so the
        # [:, -history:] slicing below stays valid.
        q_values = self.policy_network(obss, actions, bag_obss, bag_actions)
        q_values = q_values.gather(2, actions).squeeze(-1)

        with torch.no_grad():
            # Double DQN: the policy network picks the next action, the target network
            # evaluates it. Always computed; previously `if self.history:` left
            # next_obs_q_values unbound when history == 0.
            argmax = torch.argmax(
                self.policy_network(next_obss, next_actions, bag_obss, bag_actions),
                dim=2,
            ).unsqueeze(-1)
            next_obs_q_values = self.target_network(
                next_obss, next_actions, bag_obss, bag_actions
            )
            next_obs_q_values = next_obs_q_values.gather(2, argmax).squeeze(-1)

            # Bellman target; rewards/dones are reshaped to [batch-size x hist-len]
            target_shape = next_obs_q_values.shape
            targets = rewards.reshape(target_shape) + (
                1 - dones.reshape(target_shape)
            ) * (next_obs_q_values * self.gamma)

        q_values = q_values[:, -self.history :]
        targets = targets[:, -self.history :]
        # Calculate loss
        loss = F.mse_loss(q_values, targets)
        # Log Losses
        self.qvalue_max.add(q_values.max().item())
        self.qvalue_mean.add(q_values.mean().item())
        self.qvalue_min.add(q_values.min().item())

        self.target_max.add(targets.max().item())
        self.target_mean.add(targets.mean().item())
        self.target_min.add(targets.min().item())

        self.td_errors.add(loss.item())
        # Optimization step
        self.optimizer.zero_grad(set_to_none=True)
        loss.backward()
        norm = torch.nn.utils.clip_grad_norm_(
            self.policy_network.parameters(),
            self.grad_norm_clip,
            error_if_nonfinite=True,
        )
        # Logging
        self.grad_norms.add(norm.item())

        self.optimizer.step()
        self.num_train_steps += 1

        if self.num_train_steps % self.target_update_frequency == 0:
            self.target_update()

# Utils

In [22]:
# Model name -> network class. Both registered names map to the same DTQN class.
MODEL_MAP = {name: DTQN for name in ("DTQN", "DTQN-bag")}

In [23]:
# Model name -> agent class. Both registered names use DtqnAgent.
AGENT_MAP = {name: DtqnAgent for name in ("DTQN", "DTQN-bag")}

In [24]:
def get_agent(
    model_str: str,
    envs: Tuple[gym.Env],
    embed_per_obs_dim: int,
    action_dim: int,
    inner_embed: int,
    buffer_size: int,
    device: torch.device,
    learning_rate: float,
    batch_size: int,
    context_len: int,
    max_env_steps: int,
    history: int,
    target_update_frequency: int,
    gamma: float,
    num_heads: int = 1,
    num_layers: int = 1,
    dropout: float = 0.0,
    identity: bool = False,
    gate: str = "res",
    pos: str = "learned",
    bag_size: int = 0,
):
    """Function to create the agent. This will also set up the policy and target networks that the agent needs.
    Arguments:
        model_str: str, the name of the Q-function model we are going to use (a key of MODEL_MAP and AGENT_MAP).
        envs: Tuple[gym.Env], a list of environments the agent will train and evaluate on. They must all have the same observation and action space.
        embed_per_obs_dim: int, the number of features to give each dimension of the observation. This is only used for discrete domains.
        action_dim: int, the number of features to give each action.
        inner_embed: int, the size of the main transformer model.
        buffer_size: int, the number of transitions to store in the replay buffer.
        device: torch.device, the device to use for training.
        learning_rate: float, the learning rate for the ADAM optimiser.
        batch_size: int, the batch size to use for training.
        context_len: int, the maximum sequence length to use as input to the network.
        max_env_steps: int, the maximum number of steps allowed in the environment before timeout. Pass a value <= 0 to infer it from the environments.
        history: int, the number of Q-values to use during training for each sample. Clipped to [1, context_len].
        target_update_frequency: int, the number of training steps between (hard) target network update.
        gamma: float, the discount factor.
        -DTQN-Specific-
        num_heads: int, the number of heads to use in the MultiHeadAttention.
        num_layers: int, the number of transformer blocks to use.
        dropout: float, the dropout percentage to use.
        identity: bool, whether or not to use identity map reordering.
        gate: str, which combine step to use (residual skip connection or GRU)
        pos: str, which type of position encoding to use ("learned", "sin", or "none")
        bag_size: int, the size of the persistent memory bag

    Returns:
        the agent we created with all those arguments, complete with replay buffer, context, policy and target network.

    Raises:
        ValueError: if max_env_steps <= 0 and no environment advertises a step limit.
    """
    # All envs must have the same observation shape; the helpers are the ones defined in this notebook
    env_obs_length = get_env_obs_length(envs[0])
    env_obs_mask = get_env_obs_mask(envs[0])
    if max_env_steps <= 0:
        # Infer the limit from whichever envs advertise one (get_env_max_steps returns None otherwise)
        known_limits = [
            limit
            for limit in (get_env_max_steps(env) for env in envs)
            if limit is not None
        ]
        if not known_limits:
            raise ValueError(
                "max_env_steps must be positive when no environment advertises a step limit"
            )
        max_env_steps = max(known_limits)
    if isinstance(env_obs_mask, np.ndarray):
        obs_vocab_size = env_obs_mask.max() + 1
    else:
        obs_vocab_size = env_obs_mask + 1
    # Environments without a gym-style ``observation_space`` (e.g. the TorchRL-based
    # DTQNEnvironment, which declares an unbounded continuous observation_spec) are continuous.
    observation_space = getattr(envs[0], "observation_space", None)
    is_discrete_env = observation_space is not None and isinstance(
        observation_space,
        (gym.spaces.Discrete, gym.spaces.MultiDiscrete, gym.spaces.MultiBinary),
    )
    # Keep the history between 1 and context length
    if history < 1 or history > context_len:
        clipped_history = int(np.clip(history, 1, context_len))
        print(
            f"History must be 1 <= history <= context_len, but history is {history} and context len is {context_len}. Clipping history to {clipped_history}..."
        )
        history = clipped_history
    # All envs must share same action space
    # NOTE(review): assumes a discrete gym-style action space exposing ``.n``; a TorchRL env
    # only exposes ``action_spec`` — confirm which envs reach this function.
    num_actions = envs[0].action_space.n

    def make_dtqn(network_cls):
        """Creates a zero-argument factory that builds a DTQN network on ``device``."""
        return lambda: network_cls(
            env_obs_length,
            num_actions,
            embed_per_obs_dim,
            action_dim,
            inner_embed,
            num_heads,
            num_layers,
            context_len,
            dropout=dropout,
            gate=gate,
            identity=identity,
            pos=pos,
            discrete=is_discrete_env,
            vocab_sizes=obs_vocab_size,
            target_update_frequency=target_update_frequency,
            bag_size=bag_size,
        ).to(device)

    network_factory = make_dtqn(MODEL_MAP[model_str])
    return AGENT_MAP[model_str](
        network_factory,
        buffer_size,
        device,
        env_obs_length,
        max_env_steps,
        env_obs_mask,
        num_actions,
        is_discrete_env,
        learning_rate=learning_rate,
        batch_size=batch_size,
        gamma=gamma,
        context_len=context_len,
        embed_size=inner_embed,
        history=history,
        target_update_frequency=target_update_frequency,
        bag_size=bag_size,
    )

In [25]:
class Bag:
    """A Dataclass dedicated to storing important observations that would have fallen out of the agent's context

    Holds up to ``bag_size`` (observation, action) pairs in preallocated arrays. Unused
    observation rows hold ``obs_mask`` and unused action rows hold 0.

    Args:
        bag_size: Size of bag
        obs_mask: The mask to use to indicate the observation is padding
        obs_length: shape of an observation (an int for flat vectors, a tuple for images)
    """

    def __init__(self, bag_size: int, obs_mask: Union[int, float], obs_length: int):
        self.size = bag_size
        self.obs_mask = obs_mask
        self.obs_length = obs_length
        # Current position in bag (= number of pairs stored so far)
        self.pos = 0

        self.obss, self.actions = self.make_empty_bag()

    def reset(self) -> None:
        """Empty the bag."""
        self.pos = 0
        self.obss, self.actions = self.make_empty_bag()

    def add(self, obs: np.ndarray, action: int) -> bool:
        """Store one (obs, action) pair.

        Returns:
            True if the pair was stored, False if the bag was already full (pair rejected).
        """
        if self.is_full:
            # Reject adding the observation-action
            return False
        self.obss[self.pos] = obs
        self.actions[self.pos] = action
        self.pos += 1
        return True

    def export(self) -> Tuple[np.ndarray, np.ndarray]:
        """Return the filled part of the bag as (obss[:pos], actions[:pos])."""
        return self.obss[: self.pos], self.actions[: self.pos]

    def make_empty_bag(self) -> Tuple[np.ndarray, np.ndarray]:
        """Allocate padded observation and action arrays.

        Flat observations use float32 (like the replay buffer's non-image storage): without
        an explicit dtype an integer ``obs_mask`` would make the array integer-typed and
        silently truncate continuous observations written into it.
        """
        if isinstance(self.obs_length, tuple):
            # Image
            obss = np.full((self.size, *self.obs_length), self.obs_mask)
        else:
            # Non-Image
            obss = np.full((self.size, self.obs_length), self.obs_mask, dtype=np.float32)
        actions = np.full((self.size, 1), 0)
        return obss, actions

    @property
    def is_full(self) -> bool:
        """True once ``bag_size`` pairs have been stored."""
        return self.pos >= self.size

In [26]:
# noinspection PyAttributeOutsideInit
class Context:
    """A Dataclass dedicated to storing the agent's history (up to the previous `max_length` transitions)

    Args:
        context_length: The maximum number of transitions to store
        obs_mask: The mask to use for observations not yet seen
        num_actions: The number of possible actions we can take in the environment
        env_obs_length: The dimension of the observations (an int for flat vectors, a tuple for images)
        init_hidden: The initial value of the hidden states (used for RNNs)
    """

    def __init__(
        self,
        context_length: int,
        obs_mask: int,
        num_actions: int,
        env_obs_length: int,
        init_hidden: Tuple[torch.Tensor] = None,
    ):
        self.max_length = context_length
        self.env_obs_length = env_obs_length
        self.num_actions = num_actions
        self.obs_mask = obs_mask
        self.reward_mask = 0.0
        self.done_mask = True
        self.timestep = 0
        self.init_hidden = init_hidden

    def reset(self, obs: np.ndarray):
        """Resets to a fresh context whose first entry is ``obs``.

        Unfilled observation rows hold ``obs_mask``; actions are filled with random
        integers from ``RNG.rng``; rewards hold ``reward_mask`` and dones ``done_mask``.
        """
        # Account for images
        if isinstance(self.env_obs_length, tuple):
            self.obs = np.full(
                [self.max_length, *self.env_obs_length],
                self.obs_mask,
                dtype=np.uint8,
            )
        else:
            # float32 so continuous observations are not truncated (an integer obs_mask
            # would otherwise make this array integer-typed)
            self.obs = np.full(
                [self.max_length, self.env_obs_length], self.obs_mask, dtype=np.float32
            )
        # Initial observation
        self.obs[0] = obs

        self.action = RNG.rng.integers(self.num_actions, size=(self.max_length, 1))
        # Explicit float dtype: full_like(self.action) alone inherits the integer dtype of
        # the actions, which would truncate real-valued rewards.
        self.reward = np.full_like(self.action, self.reward_mask, dtype=np.float32)
        self.done = np.full_like(self.reward, self.done_mask, dtype=np.int32)
        self.hidden = self.init_hidden
        self.timestep = 0

    def add_transition(
        self, o: np.ndarray, a: int, r: float, done: bool
    ) -> Tuple[Union[np.ndarray, None], Union[int, None]]:
        """Add an entire transition. If the context is full, evict the oldest transition

        Returns:
            (evicted_obs, evicted_action) when the context was full, else (None, None).
        """
        self.timestep += 1
        self.obs = self.roll(self.obs)
        self.action = self.roll(self.action)
        self.reward = self.roll(self.reward)
        self.done = self.roll(self.done)

        t = min(self.timestep, self.max_length - 1)

        # If we are going to evict an observation, we need to return it for possibly adding to the bag.
        # After rolling, the oldest entry sits at index t (the last slot).
        evicted_obs = None
        evicted_action = None
        if self.is_full:
            evicted_obs = self.obs[t].copy()
            evicted_action = self.action[t]

        self.obs[t] = o
        self.action[t] = np.array([a])
        self.reward[t] = np.array([r])
        self.done[t] = np.array([done])

        return evicted_obs, evicted_action

    def export(
        self,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Export the context: (latest obs, action, reward, done) as indexed below.

        NOTE(review): once the context is full (timestep >= max_length), ``current_timestep + 1``
        equals ``max_length`` and indexing ``self.obs`` raises IndexError — confirm callers only
        export before the context fills up.
        """
        current_timestep = min(self.timestep, self.max_length) - 1
        return (
            self.obs[current_timestep + 1],
            self.action[current_timestep],
            self.reward[current_timestep],
            self.done[current_timestep],
        )

    def roll(self, arr: np.ndarray) -> np.ndarray:
        """Utility function to help with insertions at the end of the array. If the context is full, we replace the first element with the new element, then 'roll' the new element to the end of the array"""
        return np.roll(arr, -1, axis=0) if self.timestep >= self.max_length else arr

    @property
    def is_full(self) -> bool:
        """True once at least ``max_length`` transitions have been added."""
        return self.timestep >= self.max_length

    @staticmethod
    def context_like(context):
        """Creates a new context to mimic the supplied context"""
        return Context(
            context.max_length,
            context.obs_mask,
            context.num_actions,
            context.env_obs_length,
            init_hidden=context.init_hidden,
        )

In [27]:
class DTQNEnvironment(EnvBase):
    """Linear state-space system wrapped as a TorchRL environment.

    ``_step`` advances the state with one forward-Euler step of size ``dt``::

        x <- x + dt * (A @ x + B @ u)
        y  = C @ x + D @ u
        reward = -(ref - y) ** 2      (elementwise)

    The observation is the full state ``x`` flattened to a vector, and ``_reset``
    always returns the state to zero. ``_step`` always reports ``done=False``.

    Args:
        A, B, C, D: system matrices (presumably NumPy arrays). From the products in
            ``_step``, A is (n, n), B is (n, m), C is (p, n) and D is (p, m), where
            n = A.shape[0] is the state size and m = B.shape[1] the action size.
        dt: Euler integration step size.
        ref: reference value the output ``y`` should track (default 1).
        device: device on which observation/reward tensors are created (default "cpu").
    """

    def __init__(self, A, B, C, D, dt, ref=1, device="cpu"):
        super(DTQNEnvironment, self).__init__() # call the constructor of the base class
        
        # custom property initialization - unique to this environment
        # NOTE(review): this replaces whatever `dtype` attribute torchrl's EnvBase defines with a
        # NumPy dtype (used below for the NumPy state buffer) — confirm no conflict with torchrl.
        self.dtype = np.float32

        self.A, self.B, self.C, self.D, self.dt, self.ref = A, B, C, D, dt, ref
        # device on which outgoing tensors are created
        # NOTE(review): EnvBase.__init__() above was called without a device — confirm that
        # assigning it afterwards is supported by the installed torchrl version.
        self.device = device

        self.state_size = self.A.shape[0] # n: number of states (rows of A)
        self.action_size = self.B.shape[1] # m: number of inputs (columns of B)

        # column-vector state, shape (n, 1)
        self.state = np.zeros((self.state_size, 1), dtype=self.dtype)
        
        # specs - needs to be initialized
        self.action_spec = BoundedTensorSpec(minimum=-1, maximum=1, shape=torch.Size([self.action_size])) # each action component bounded to [-1, 1]

        # The observation is the full state vector, so no separate state spec is declared.
        observation_spec = UnboundedContinuousTensorSpec(shape=torch.Size([self.state_size])) # unlimited observation space
        self.observation_spec = CompositeSpec(observation=observation_spec) # keyed by "observation", the entry _reset/_step write

        self.reward_spec = UnboundedContinuousTensorSpec(shape=torch.Size([1])) # unlimited reward space (the reward -(ref - y)**2 is always <= 0)

    def _reset(self, tensordict, **kwargs):
        """Reset the state to zero and return it as the "observation" entry of a new TensorDict.

        ``tensordict`` and ``kwargs`` are not used.
        """
        
        # init new state and pack it up in a tensordict
        
        out_tensordict = TensorDict({}, batch_size=torch.Size())

        self.state = np.zeros((self.state_size, 1), dtype=self.dtype)
        out_tensordict.set("observation", torch.tensor(self.state.flatten(), device=self.device))

        return out_tensordict

    def _step(self, tensordict):
        """Apply ``tensordict["action"]`` for one Euler step and return observation, reward and done."""
        # NOTE(review): the original author marked this method "needs to be changed"; what exactly is not stated.
        action = tensordict["action"]
        # to a NumPy column vector of shape (m, 1); assumes the action tensor has m elements
        action = action.cpu().numpy().reshape((self.action_size, 1))

        # forward-Euler update of the state, in place
        self.state += self.dt * (self.A @ self.state + self.B @ action)

        y = self.C @ self.state + self.D @ action

        error = (self.ref - y) ** 2

        reward = -error

        # NOTE(review): reward has shape (p, 1) while reward_spec is [1]; presumably the system has
        # a single output (p == 1) — confirm. done is always False, so the env never terminates itself.
        out_tensordict = TensorDict({"observation": torch.tensor(self.state.astype(self.dtype).flatten(), device=self.device),
                                     "reward": torch.tensor(reward.astype(np.float32), device=self.device),
                                     "done": False}, batch_size=torch.Size())

        return out_tensordict

    def _set_seed(self, seed):
        """No-op: neither _reset nor _step draws random numbers."""
        pass

In [28]:
def get_env_obs_length(env: "DTQNEnvironment") -> int:
    """Gets the length of the (flat) observations in an environment.

    The environments in this notebook observe their full state vector, so the observation
    length is ``env.state_size``.

    Raises:
        NotImplementedError: if ``env`` does not expose an integer ``state_size``.
    """
    state_size = getattr(env, "state_size", None)
    if isinstance(state_size, (int, np.integer)):
        return int(state_size)
    raise NotImplementedError(f"We do not yet support {type(env).__name__}")

In [29]:
def get_env_obs_mask(
    env: "DTQNEnvironment", mask_value: float = -5
) -> Union[int, np.ndarray]:
    """Gets the value used to pad observations the agent has not seen yet.

    For discrete domains this would be the number of possible observations, but only the
    continuous state-space environments of this notebook (those exposing ``state_size``)
    are supported; for them ``mask_value`` is returned. ``mask_value`` must be lower than
    the lowest possible observation (while still being finite) so the network knows it
    is padding rather than a real observation.

    Args:
        env: the environment to inspect.
        mask_value: padding value for continuous observations (default -5).

    Raises:
        NotImplementedError: for environments that do not expose ``state_size``.
    """
    # NOTE(review): DTQNEnvironment's state is unbounded; confirm it never falls below mask_value.
    if getattr(env, "state_size", None) is not None:
        return mask_value
    raise NotImplementedError(f"We do not yet support {type(env).__name__}")

In [30]:
def get_env_max_steps(env: DTQNEnvironment) -> Union[int, None]:
    """Return the episode step limit advertised by ``env``, or None if it has none.

    ``env._max_episode_steps`` is checked first, then ``env.max_episode_steps``.
    """
    missing = object()
    for attribute in ("_max_episode_steps", "max_episode_steps"):
        limit = getattr(env, attribute, missing)
        if limit is not missing:
            return limit
    return None

In [32]:
class EpsilonAnneal(ABC):
    """Interface for exploration-rate (epsilon) schedules.

    The schedules in this notebook (Constant, LinearAnneal) keep the current value in
    ``val`` and update it in place whenever ``anneal`` is called.
    """

    @abstractmethod
    def anneal(self):
        """Advance the schedule by one step."""

In [33]:
class Constant(EpsilonAnneal):
    """Schedule whose value never changes: ``val`` stays equal to ``start``."""

    def __init__(self, start):
        self.val = start

    def anneal(self):
        """No-op; a constant schedule has nothing to update."""
        return None

In [34]:
class LinearAnneal(EpsilonAnneal):
    """Linear Annealing Schedule.

    Every call to ``anneal`` lowers ``val`` by the same amount, ``(start - end) / duration``,
    so ``val`` reaches ``end`` after ``duration`` calls (up to floating-point rounding) and
    then stays there. Subtracting a fraction of the *remaining* gap instead would decay
    geometrically and never actually reach ``end``.

    Args:
        start:      The initial value of epsilon.
        end:        The final value of epsilon.
        duration:   The number of anneals from start value to end value. A value <= 0
                    moves straight to ``end`` on the first anneal.
    """

    def __init__(self, start: float, end: float, duration: int):
        self.val = start
        self.min = end
        self.duration = duration
        # Fixed per-anneal decrement; infinite for a non-positive duration so the first
        # anneal clamps straight to ``end`` instead of dividing by zero.
        self.step = (start - end) / duration if duration > 0 else float("inf")

    def anneal(self):
        """Lower ``val`` by one linear step, never going below ``end``."""
        self.val = max(self.min, self.val - self.step)

In [35]:
class RunningAverage:
    """Mean over the most recent ``size`` values added (a sliding-window average).

    The window total is maintained incrementally, so ``mean`` does not re-sum the window.
    """

    def __init__(self, size):
        self.size = size
        self.q = deque()
        self.sum = 0

    def add(self, val):
        """Append ``val``; once the window holds more than ``size`` items, drop the oldest."""
        self.q.append(val)
        self.sum += val
        if len(self.q) <= self.size:
            return
        oldest = self.q.popleft()
        self.sum -= oldest

    def mean(self):
        """Average of the values currently in the window (0 when nothing has been added)."""
        count = len(self.q) or 1  # avoid dividing by zero on an empty window
        return self.sum / count

In [36]:
def timestamp():
    """Current local time formatted as month name, day, and 24-hour clock, e.g. "March 05, 14:07:09"."""
    now = datetime.now()
    return f"{now:%B %d, %H:%M:%S}"

In [37]:
def wandb_init(config, group_keys, **kwargs) -> None:
    """Start a Weights & Biases run for ``config``.

    The run goes to project ``config["project_name"]`` and to a group named by joining
    ``key=value`` (in config order) for every config key listed in ``group_keys``, so runs
    sharing those values are grouped together. Extra keyword arguments go to ``wandb.init``.
    """
    project = config["project_name"]
    group = "_".join(
        f"{name}={value}" for name, value in config.items() if name in group_keys
    )
    wandb.init(project=project, group=group, config=config, **kwargs)

In [38]:
class CSVLogger:
    """Logger to write results to a CSV. The log function matches that of Weights and Biases.

    Two files are maintained: ``<path>_results.csv`` (per-environment evaluation metrics)
    and ``<path>_losses.csv`` (training diagnostics). A header row is written only when a
    file does not exist yet, so resuming from a checkpoint appends instead of overwriting.

    Args:
        path: path prefix for the csv results files
        args: parsed arguments; only ``args.envs`` (the environment names) is read
    """

    # Per-environment metrics, in column order, for the results file.
    _ENV_METRICS = ("SuccessRate", "EpisodeLength", "Return")
    # (column header, key in the ``results`` dict passed to ``log``) for the losses file.
    _LOSS_COLUMNS = (
        ("TD Error", "losses/TD_Error"),
        ("Grad Norm", "losses/Grad_Norm"),
        ("Max Q Value", "losses/Max_Q_Value"),
        ("Mean Q Value", "losses/Mean_Q_Value"),
        ("Min Q Value", "losses/Min_Q_Value"),
        ("Max Target Value", "losses/Max_Target_Value"),
        ("Mean Target Value", "losses/Mean_Target_Value"),
        ("Min Target Value", "losses/Min_Target_Value"),
    )

    def __init__(self, path: str, args: argparse.Namespace):
        self.results_path = path + "_results.csv"
        self.losses_path = path + "_losses.csv"
        self.envs = args.envs
        # If we have a checkpoint, we don't want to overwrite
        if not os.path.exists(self.results_path):
            head_row = ["Hours", "Step"] + [
                f"{env}/{metric}" for env in self.envs for metric in self._ENV_METRICS
            ]
            self._write_row(self.results_path, head_row, mode="w")
        if not os.path.exists(self.losses_path):
            head_row = ["Hours", "Step"] + [header for header, _ in self._LOSS_COLUMNS]
            self._write_row(self.losses_path, head_row, mode="w")

    def log(self, results: Dict[str, str], step: int):
        """Append one row to each CSV file.

        Args:
            results: metrics keyed as in the W&B log call, e.g. ``"losses/hours"``,
                ``"<env>/Return"``, ``"losses/TD_Error"``.
            step: the training step to record.

        Raises:
            KeyError: if a required metric is missing; nothing is written in that case.
        """
        hours = results["losses/hours"]
        results_row = [hours, step] + [
            results[f"{env}/{metric}"] for env in self.envs for metric in self._ENV_METRICS
        ]
        losses_row = [hours, step] + [results[key] for _, key in self._LOSS_COLUMNS]
        self._write_row(self.results_path, results_row)
        self._write_row(self.losses_path, losses_row)

    @staticmethod
    def _write_row(path: str, row: list, mode: str = "a") -> None:
        """Write a single CSV row to ``path`` (append by default).

        ``newline=""`` is required by the csv module; without it, platforms that translate
        newlines (e.g. Windows) get a blank line after every row.
        """
        with open(path, mode, newline="") as file:
            csv.writer(file).writerow(row)

In [39]:
def get_logger(
    policy_path: str, args: argparse.Namespace, wandb_kwargs: Dict[str, str]
):
    """Return the metrics logger for this run.

    With ``args.disable_wandb`` set, a CSVLogger writing files prefixed by ``policy_path``
    is returned. Otherwise a Weights & Biases run is started (grouped by the
    hyper-parameters listed below) and the ``wandb`` module itself serves as the logger.
    """
    if args.disable_wandb:
        return CSVLogger(policy_path, args)

    # Runs that share these hyper-parameter values are grouped together in W&B
    group_keys = [
        "model",
        "obs_embed",
        "a_embed",
        "in_embed",
        "context",
        "layers",
        "bag_size",
        "gate",
        "identity",
        "history",
        "pos",
    ]
    wandb_init(vars(args), group_keys, **wandb_kwargs)
    return wandb

In [40]:
class RNG:
    """Class-level holder for the shared NumPy random ``Generator``.

    ``RNG.rng`` stays ``None`` until ``set_global_seed`` stores a seeded PCG64 generator
    in it; ``Context.reset`` draws its padding actions from ``RNG.rng``.
    """

    rng: np.random.Generator = None  # populated by set_global_seed

In [41]:
def set_global_seed(seed: int, *args: "gym.Env") -> None:
    """Sets seed for PyTorch, NumPy, random, the given environments, and ``RNG.rng``.

    Args:
        seed: The random seed to use.
        args: The environment(s) to seed. Environments with a ``seed`` method (gym style)
            are seeded via ``env.seed(seed=seed)``; otherwise ``env.set_seed(seed)`` is
            used (presumably TorchRL's ``EnvBase.set_seed``, which DTQNEnvironment inherits).
            ``observation_space``/``action_space`` are seeded when the env has them.

    Raises:
        AttributeError: if an environment offers neither ``seed`` nor ``set_seed``.
    """
    random.seed(seed)
    # Derive per-library seeds from the master seed. The bounds must be ints:
    # random.randint rejects float arguments such as 1e6 on Python >= 3.12.
    tseed = random.randint(1, 1_000_000)
    npseed = random.randint(1, 1_000_000)
    ospyseed = random.randint(1, 1_000_000)
    torch.manual_seed(tseed)
    np.random.seed(npseed)
    for env in args:
        if hasattr(env, "seed"):
            env.seed(seed=seed)
        elif hasattr(env, "set_seed"):
            env.set_seed(seed)
        else:
            raise AttributeError(
                f"Cannot seed {type(env).__name__}: it has neither seed() nor set_seed()"
            )
        for space_name in ("observation_space", "action_space"):
            space = getattr(env, space_name, None)
            if space is not None:
                space.seed(seed=seed)
    os.environ["PYTHONHASHSEED"] = str(ospyseed)
    RNG.rng = np.random.Generator(np.random.PCG64(seed=seed))

In [42]:
#torch_utils
import torch.nn as nn

In [43]:
def init_weights(module):
    """Initialise ``module``'s own parameters in place; meant for ``model.apply(init_weights)``.

    - Linear / Embedding: weight ~ N(0, 0.02^2); Linear bias (if any) set to 0.
    - MultiheadAttention: input projection(s) and ``out_proj.weight`` ~ N(0, 0.02^2),
      input projection bias (if any) set to 0. When kdim/vdim differ from embed_dim the
      module keeps separate q/k/v weights instead of ``in_proj_weight``; those are used.
    - LayerNorm: weight 1, bias 0 (skipped when the module has no affine parameters).
    Other module types are left untouched.
    """
    if isinstance(module, (nn.Linear, nn.Embedding)):
        nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if isinstance(module, nn.Linear) and module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.MultiheadAttention):
        if module.in_proj_weight is not None:
            nn.init.normal_(module.in_proj_weight, mean=0.0, std=0.02)
        else:
            # Separate projections (kdim/vdim != embed_dim): in_proj_weight is None
            for proj_weight in (
                module.q_proj_weight,
                module.k_proj_weight,
                module.v_proj_weight,
            ):
                nn.init.normal_(proj_weight, mean=0.0, std=0.02)
        nn.init.normal_(module.out_proj.weight, mean=0.0, std=0.02)
        # in_proj_bias is None when the module was built with bias=False
        if module.in_proj_bias is not None:
            nn.init.zeros_(module.in_proj_bias)
    elif isinstance(module, nn.LayerNorm):
        # weight/bias are None when elementwise_affine=False
        if module.bias is not None:
            nn.init.zeros_(module.bias)
        if module.weight is not None:
            nn.init.ones_(module.weight)

In [44]:
class ReplayBuffer:
    """
    FIFO Replay Buffer which stores contexts of length ``context_len`` rather than single
        transitions

    Storage is organised per episode: index ``i`` along the first axis of every array is one
    episode slot. Observation and action slots have ``max_episode_steps + 1`` entries so the
    first observation (``store_obs``) and the final next-observation of an episode are kept
    together. After ``max_size`` episodes, the oldest slot is cleansed and reused.

    Args:
        buffer_size: The number of transitions to store in the replay buffer
        env_obs_length: The size (length) of the environment's observation, or its shape
            (tuple) for image domains
        obs_mask: The value used to pad observations that have not been written
        max_episode_steps: The maximum number of transitions in one episode
        context_len: The number of transitions that will be stored as an agent's context. Default: 1

    Raises:
        ValueError: if ``buffer_size`` cannot hold a single episode.
    """

    def __init__(
        self,
        buffer_size: int,
        env_obs_length: Union[int, Tuple],
        obs_mask: int,
        max_episode_steps: int,
        context_len: Optional[int] = 1,
    ):
        self.max_size = buffer_size // max_episode_steps
        if self.max_size < 1:
            # Otherwise every store would fail later with a modulo-by-zero
            raise ValueError(
                f"buffer_size ({buffer_size}) must be at least max_episode_steps ({max_episode_steps})"
            )
        self.context_len = context_len
        self.env_obs_length = env_obs_length
        self.max_episode_steps = max_episode_steps
        self.obs_mask = obs_mask
        # [episode counter, step within the current episode]
        self.pos = [0, 0]

        # +1 keeps the first and last obs of an episode together
        self.obss = np.full(
            [self.max_size, max_episode_steps + 1, *self._obs_shape(env_obs_length)],
            obs_mask,
            dtype=self._obs_dtype(env_obs_length),
        )
        # Need the +1 so we have space to roll for the first observation
        self.actions = np.zeros(
            [self.max_size, max_episode_steps + 1, 1],
            dtype=np.uint8,
        )
        self.rewards = np.zeros(
            [self.max_size, max_episode_steps, 1],
            dtype=np.float32,
        )
        self.dones = np.ones(
            [self.max_size, max_episode_steps, 1],
            dtype=np.bool_,
        )
        # int32, not uint8: episodes may exceed 255 steps, and unsigned arithmetic such as
        # ``length - context_len`` would wrap around for short episodes.
        self.episode_lengths = np.zeros([self.max_size], dtype=np.int32)

    @staticmethod
    def _obs_shape(obs_length) -> Tuple[int, ...]:
        """Per-observation shape: the tuple itself for images, ``(length,)`` otherwise."""
        return tuple(obs_length) if isinstance(obs_length, tuple) else (obs_length,)

    @staticmethod
    def _obs_dtype(obs_length):
        """uint8 pixels for image domains, float32 vectors otherwise."""
        return np.uint8 if isinstance(obs_length, tuple) else np.float32

    def store(
        self,
        obs: np.ndarray,
        action: np.ndarray,
        reward: np.ndarray,
        done: np.ndarray,
        episode_length: Optional[int] = 0,
    ) -> None:
        """Store one transition of the current episode; ``obs`` is the next observation."""
        episode_idx = self.pos[0] % self.max_size
        obs_idx = self.pos[1]
        self.obss[episode_idx, obs_idx + 1] = obs
        self.actions[episode_idx, obs_idx] = action
        self.rewards[episode_idx, obs_idx] = reward
        self.dones[episode_idx, obs_idx] = done
        self.episode_lengths[episode_idx] = episode_length
        self.pos = [self.pos[0], self.pos[1] + 1]

    def store_obs(self, obs: np.ndarray) -> None:
        """Use this at the beginning of the episode to store the first obs"""
        episode_idx = self.pos[0] % self.max_size
        self.cleanse_episode(episode_idx)
        self.obss[episode_idx, 0] = obs

    def can_sample(self, batch_size: int) -> bool:
        """True once more than ``batch_size`` episodes have been started."""
        return batch_size < self.pos[0]

    def flush(self):
        """Finish the current episode; the next ``store_obs`` starts a new slot."""
        self.pos = [self.pos[0] + 1, 0]

    def cleanse_episode(self, episode_idx: int) -> None:
        """Cleanse slot ``episode_idx`` of any previous data (masked obs, zero actions/rewards, done=True)."""
        self.obss[episode_idx] = self.obs_mask
        self.actions[episode_idx] = 0
        self.rewards[episode_idx] = 0.0
        self.dones[episode_idx] = True
        self.episode_lengths[episode_idx] = 0

    def _sample_windows(self, batch_size: int):
        """Pick episodes and context windows; shared by ``sample`` and ``sample_with_bag``.

        Returns:
            (episode_idxes [batch, 1], transition_starts [batch], transitions [batch, context_len])
        """
        # Exclude the slot currently being written; after wrap-around it is a partial episode
        current = self.pos[0] % self.max_size
        valid_episodes = [
            i for i in range(min(self.pos[0], self.max_size)) if i != current
        ]
        episode_idxes = np.array(
            [[random.choice(valid_episodes)] for _ in range(batch_size)]
        )
        transition_starts = np.array(
            [
                random.randint(
                    0, max(0, int(self.episode_lengths[idx[0]]) - self.context_len)
                )
                for idx in episode_idxes
            ]
        )
        transitions = np.array(
            [range(start, start + self.context_len) for start in transition_starts]
        )
        return episode_idxes, transition_starts, transitions

    def sample(
        self, batch_size: int
    ) -> Tuple[
        np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray
    ]:
        """Sample ``batch_size`` windows of ``context_len`` consecutive transitions.

        Returns:
            (obss, actions, rewards, next_obss, next_actions, dones, episode_lengths); the
            first six are indexed [batch, context_len, ...], episode_lengths is [batch, 1]
            and clipped to ``context_len``.
        """
        episode_idxes, _, transitions = self._sample_windows(batch_size)
        return (
            self.obss[episode_idxes, transitions],
            self.actions[episode_idxes, transitions],
            self.rewards[episode_idxes, transitions],
            self.obss[episode_idxes, 1 + transitions],
            self.actions[episode_idxes, 1 + transitions],
            self.dones[episode_idxes, transitions],
            np.clip(self.episode_lengths[episode_idxes], 0, self.context_len),
        )

    def sample_with_bag(
        self, batch_size: int, sample_bag: "Bag"
    ) -> Tuple[
        np.ndarray,
        np.ndarray,
        np.ndarray,
        np.ndarray,
        np.ndarray,
        np.ndarray,
        np.ndarray,
        np.ndarray,
        np.ndarray,
    ]:
        """Like ``sample``, plus a bag of (observation, action) pairs from before each window.

        If fewer than ``sample_bag.size`` positions precede a window's start, all of them are
        taken; otherwise ``sample_bag.size`` distinct positions are drawn at random. The
        observation and action in a bag row always come from the same position. Unused rows
        keep ``sample_bag.obs_mask`` / action 0.

        Returns:
            (obss, actions, rewards, next_obss, next_actions, dones, episode_lengths,
            bag_obss, bag_actions); episode_lengths is not clipped here.
        """
        episode_idxes, transition_starts, transitions = self._sample_windows(batch_size)

        # Create `batch_size` replica bags, typed like the stored observations so
        # continuous values are not truncated to integers
        bag_obss = np.full(
            [batch_size, sample_bag.size, *self._obs_shape(sample_bag.obs_length)],
            sample_bag.obs_mask,
            dtype=self.obss.dtype,
        )
        bag_actions = np.full(
            [batch_size, sample_bag.size, 1],
            0,
        )

        # Sample from the bag with observations that won't be in the main context
        for bag_idx in range(batch_size):
            episode_idx = episode_idxes[bag_idx, 0]
            start = int(transition_starts[bag_idx])
            if start < sample_bag.size:
                # Possible bag is smaller than max bag size, so take all of it
                positions = np.arange(start)
            else:
                # Draw positions once so every observation stays paired with its action
                positions = np.array(
                    random.sample(range(start), k=sample_bag.size), dtype=np.int64
                )
            bag_obss[bag_idx, : len(positions)] = self.obss[episode_idx, positions]
            bag_actions[bag_idx, : len(positions)] = self.actions[episode_idx, positions]
        return (
            self.obss[episode_idxes, transitions],
            self.actions[episode_idxes, transitions],
            self.rewards[episode_idxes, transitions],
            self.obss[episode_idxes, 1 + transitions],
            self.actions[episode_idxes, 1 + transitions],
            self.dones[episode_idxes, transitions],
            self.episode_lengths[episode_idxes],
            bag_obss,
            bag_actions,
        )

# Run Functions

In [None]:
def evaluate(
    agent,
    eval_env: Env,
    eval_episodes: int,
    render: Optional[bool] = None,
):
    """Run ``eval_episodes`` greedy (epsilon = 0) episodes and summarise them.

    Arguments:
        agent:          the agent to evaluate; switched to eval mode for the rollouts and
                        back to train mode afterwards.
        eval_env:       gym.Env, the environment to use for the evaluation (4-tuple step API).
        eval_episodes:  int, the number of episodes to run.
        render:         bool, whether or not to render the timesteps for enjoy mode
                        (renders each step and pauses 0.5 s).

    Returns:
        mean_success:           float, number of successes divided by number of episodes.
                                An episode succeeds if the final ``info["is_success"]`` is
                                truthy or its return is positive.
        mean_return:            float, the average return per episode.
        mean_episode_length:    float, the average episode length (agent context timesteps).
    """
    # Networks in eval mode (turns off dropout, etc.) for the duration of the rollouts
    agent.eval_on()

    successes = 0
    return_sum = 0
    step_sum = 0

    for _ in range(eval_episodes):
        agent.context_reset(eval_env.reset())
        ep_reward = 0
        finished = False
        if render:
            eval_env.render()
            sleep(0.5)
        while not finished:
            chosen_action = agent.get_action(epsilon=0.0)
            next_obs, step_reward, finished, info = eval_env.step(chosen_action)
            agent.observe(next_obs, chosen_action, step_reward, finished)
            ep_reward += step_reward
            if render:
                eval_env.render()
                if finished:
                    print(f"Episode terminated. Episode reward: {ep_reward}")
                sleep(0.5)
        return_sum += ep_reward
        step_sum += agent.context.timestep
        if info.get("is_success", False) or ep_reward > 0:
            successes += 1

    # Back to train mode
    agent.eval_off()
    denominator = max(eval_episodes, 1)  # avoid dividing by zero when no episodes ran
    return successes / denominator, return_sum / denominator, step_sum / denominator

In [None]:
def train(
    agent,
    envs: Tuple[Env],
    eval_envs: Tuple[Env],
    env_strs: Tuple[str],
    total_steps: int,
    eps: epsilon_anneal.EpsilonAnneal,
    eval_frequency: int,
    eval_episodes: int,
    policy_path: str,
    save_policy: bool,
    logger,
    mean_success_rate: RunningAverage,
    mean_episode_length: RunningAverage,
    mean_reward: RunningAverage,
    time_remaining: Optional[float],
    verbose: bool = False,
    save_frequency: int = 50_000,
) -> None:
    """Train the agent, periodically evaluating, logging, and saving.

    Each timestep takes one environment step with the epsilon-greedy policy, runs one
    `agent.train()` update, and anneals epsilon. On episode end the replay buffer is
    flushed and a new environment is sampled from `envs`.

    Arguments:
        agent:              the agent to train.
        envs:               Tuple[gym.Env], the list of envs to train on.
        eval_envs:          Tuple[gym.Env], the list of envs to evaluate with.
        env_strs:           Tuple[str], the list of environment names (paired with `eval_envs`).
        total_steps:        int, the total number of timesteps to train.
        eps:                EpsilonAnneal, the schedule to set for epsilon throughout training.
        eval_frequency:     int, the number of training steps between evaluation periods.
        eval_episodes:      int, the number of episodes to evaluate on for each eval period.
        policy_path:        str, the path to store the policy and checkpoints at.
        save_policy:        bool, whether to save the policy weights every `save_frequency` steps.
        logger:             the logger to use (either wandb or csv).
        mean_success_rate:  RunningAverage, the success rate over several evaluation periods.
        mean_episode_length:RunningAverage, the episode length over several evaluation periods.
        mean_reward:        RunningAverage, the episodic return over several evaluation periods.
        time_remaining:     float or None, seconds this call may run before it saves a checkpoint
                            and returns; None disables the time limit.
        verbose:            bool, whether or not to print one line per evaluation environment.
        save_frequency:     int, timesteps between policy-weight saves when `save_policy` is set.
    """
    start_time = time()
    # Turn on train mode
    agent.eval_off()
    # Choose an environment at the start and on every episode reset.
    env = RNG.rng.choice(envs)
    agent.context_reset(env.reset())

    # Resumes from `agent.num_train_steps` so a restored checkpoint continues where it stopped.
    for timestep in range(agent.num_train_steps, total_steps):
        done = step(agent, env, eps)

        if done:
            agent.replay_buffer.flush()
            env = RNG.rng.choice(envs)
            agent.context_reset(env.reset())
        agent.train()
        eps.anneal()

        if timestep % eval_frequency == 0:
            hours = (time() - start_time) / 3600
            # Log training values
            log_vals = {
                "losses/TD_Error": agent.td_errors.mean(),
                "losses/Grad_Norm": agent.grad_norms.mean(),
                "losses/Max_Q_Value": agent.qvalue_max.mean(),
                "losses/Mean_Q_Value": agent.qvalue_mean.mean(),
                "losses/Min_Q_Value": agent.qvalue_min.mean(),
                "losses/Max_Target_Value": agent.target_max.mean(),
                "losses/Mean_Target_Value": agent.target_mean.mean(),
                "losses/Min_Target_Value": agent.target_min.mean(),
                "losses/hours": hours,
            }
            # Perform an evaluation for each of the eval environments and add to our log
            for env_str, eval_env in zip(env_strs, eval_envs):
                sr, ret, length = evaluate(agent, eval_env, eval_episodes)

                log_vals.update(
                    {
                        f"{env_str}/SuccessRate": sr,
                        f"{env_str}/Return": ret,
                        f"{env_str}/EpisodeLength": length,
                    }
                )
                # Report every evaluation environment (not only the last one) as it completes.
                if verbose:
                    print(
                        f"[ {timestamp()} ] Training Steps: {timestep}, Env: {env_str}, Success Rate: {sr:.2f}, Return: {ret:.2f}, Episode Length: {length:.2f}, Hours: {hours:.2f}"
                    )
            # NOTE(review): the running averages are only forwarded to `save_checkpoint`
            # below; they are never updated with these evaluation results.

            # Commit the log values.
            logger.log(
                log_vals,
                step=timestep,
            )

        if save_policy and timestep % save_frequency == 0:
            torch.save(agent.policy_network.state_dict(), policy_path)

        # `is not None` so that a remaining budget of exactly 0 still counts as a limit.
        if time_remaining is not None and time() - start_time >= time_remaining:
            print(
                f"Reached time limit. Saving checkpoint with {agent.num_train_steps} steps completed."
            )

            agent.save_checkpoint(
                policy_path,
                wandb.run.id if logger == wandb else None,
                mean_success_rate,
                mean_reward,
                mean_episode_length,
                eps,
            )
            return

In [None]:
def step(agent, env: Env, eps: "Union[float, epsilon_anneal.EpsilonAnneal]") -> bool:
    """Use the agent's policy to get the next action, take it, and then record the result.

    Arguments:
        agent:  the agent to use.
        env:    gym.Env (4-tuple `step` API: obs, reward, done, info).
        eps:    the epsilon for the epsilon-greedy policy, either an annealing schedule
                exposing the current value as `.val` (e.g. EpsilonAnneal) or a plain float.

    Returns:
        done: bool, whether or not the episode has finished (including time-limit truncation).
    """
    # Accept a schedule object (read its current `.val`) or a bare float epsilon.
    epsilon = getattr(eps, "val", eps)
    action = agent.get_action(epsilon=epsilon)
    next_obs, reward, done, info = env.step(action)

    # OpenAI Gym TimeLimit truncation is not a true terminal state, so don't store it in
    # the buffer as done; the caller still receives `done` to end the episode.
    buffer_done = False if info.get("TimeLimit.truncated", False) else done

    agent.observe(next_obs, action, reward, buffer_done)
    return done

In [None]:
def prepopulate(agent, prepop_steps: int, envs: Tuple[Env]) -> None:
    """Fill the agent's replay buffer with experience from random actions before training.

    An environment is drawn from `envs` at the start of every episode and each episode is
    played to completion, so the number of stored transitions can exceed `prepop_steps`.

    Arguments:
        agent:          the agent whose replay buffer is filled.
        prepop_steps:   int, the minimum number of timesteps to collect.
        envs:           Tuple[gym.Env], the environments to sample from (discrete action spaces).
    """
    steps_taken = 0
    while steps_taken < prepop_steps:
        current_env = RNG.rng.choice(envs)
        agent.context_reset(current_env.reset())
        episode_over = False
        while not episode_over:
            random_action = RNG.rng.integers(current_env.action_space.n)
            next_obs, reward, episode_over, info = current_env.step(random_action)
            # A Gym TimeLimit truncation is not a real terminal state, so it is
            # recorded in the buffer as "not done".
            truncated = info.get("TimeLimit.truncated", False)
            agent.observe(next_obs, random_action, reward, False if truncated else episode_over)
            steps_taken += 1
        agent.replay_buffer.flush()

Args
- obs_embed
- envs
- num_steps
- model
- a_embed
- in_embed
- buf_size
- lr
- batch
- context
- max_episode_steps
- history
- tuf
- discount
- heads
- layers
- dropout
- identity
- pos
- bag_size
- eval_frequency
- eval_episodes
- save_policy
- verbose

In [None]:
def get_args(argv: Optional[list] = None):
    """Build the experiment's command-line parser and parse the arguments.

    Arguments:
        argv:   optional list of argument strings to parse. When None (the default) the
                process's `sys.argv[1:]` is parsed, as before. Inside a Jupyter kernel,
                `sys.argv` holds the kernel launcher's flags (e.g. `-f <connection file>`)
                which this parser rejects, so pass an explicit list there (e.g. `get_args([])`).

    Returns:
        argparse.Namespace with one attribute per option below (dashes become underscores).
        `envs` is always a list of environment ids.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--project-name",
        type=str,
        default="DTQN-test",
        help="The project name (for wandb) or directory name (for local logging) to store the results.",
    )
    parser.add_argument(
        "--disable-wandb",
        action="store_true",
        help="Use `--disable-wandb` to log locally.",
    )
    parser.add_argument(
        "--time-limit",
        type=float,
        default=None,
        help="Time limit allowed for job. Useful for some cluster jobs such as slurm.",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="DTQN",
        choices=list(MODEL_MAP.keys()),
        help="Network model to use.",
    )
    # The default must be a list: with nargs="+" a supplied value is a list, and downstream
    # code iterates over `args.envs` (a bare string would be iterated character by character).
    parser.add_argument(
        "--envs",
        type=str,
        nargs="+",
        default=["DiscreteCarFlag-v0"],
        help="Domain to use. You can supply multiple domains, but they must have the same observation and action space. With multiple environments, the agent will sample a new one on each episode reset for conducting policy rollouts and collection experience. During evaluation, it will perform the same evaluation for each domain (Note: this may significantly slow down your run! Consider increasing the eval-frequency or reducing the eval-episodes).",
    )
    parser.add_argument(
        "--num-steps",
        type=int,
        default=2_000_000,
        help="Number of steps to train the agent.",
    )
    parser.add_argument(
        "--tuf",
        type=int,
        default=10_000,
        help="How many steps between each (hard) target network update.",
    )
    parser.add_argument(
        "--lr", type=float, default=3e-4, help="Learning rate for the optimizer."
    )
    parser.add_argument("--batch", type=int, default=32, help="Batch size.")
    parser.add_argument(
        "--buf-size",
        type=int,
        default=500_000,
        help="Number of timesteps to store in replay buffer. Note that we store the max length episodes given by the environment, so episodes that take longer will be padded at the end. This does not affect training but may affect the number of real observations in the buffer.",
    )
    parser.add_argument(
        "--eval-frequency",
        type=int,
        default=5_000,
        help="How many training timesteps between agent evaluations.",
    )
    parser.add_argument(
        "--eval-episodes",
        type=int,
        default=10,
        help="Number of episodes for each evaluation period.",
    )
    parser.add_argument(
        "--device", type=str, default="cuda", help="Pytorch device to use."
    )
    parser.add_argument(
        "--context",
        type=int,
        default=50,
        help="For DRQN and DTQN, the context length to use to train the network.",
    )
    parser.add_argument(
        "--obs-embed",
        type=int,
        default=8,
        help="For discrete observation domains only. The number of features to give each observation.",
    )
    parser.add_argument(
        "--a-embed",
        type=int,
        default=0,
        help="The number of features to give each action. A value of 0 will prevent the policy from using the previous action.",
    )
    parser.add_argument(
        "--in-embed",
        type=int,
        default=128,
        help="The dimensionality of the network. In the transformer, this is referred to as `d_model`.",
    )
    parser.add_argument(
        "--max-episode-steps",
        type=int,
        default=-1,
        help="The maximum number of steps allowed in the environment. If `env` has a `max_episode_steps`, this will be inferred. Otherwise, this argument must be supplied.",
    )
    parser.add_argument("--seed", type=int, default=1, help="The random seed to use.")
    parser.add_argument(
        "--save-policy",
        action="store_true",
        help="Use this to save the policy so you can load it later for rendering.",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Print out evaluation results as they come in to the console.",
    )
    parser.add_argument(
        "--render",
        action="store_true",
        help="Enjoy mode (NOTE: must have a trained policy saved).",
    )
    parser.add_argument(
        "--history",
        type=int,
        default=50,
        help="This is how many (intermediate) Q-values we use to train for each context. To turn off intermediate Q-value prediction, set `--history 1`. To use the entire context, set history equal to the context length.",
    )
    # DTQN-Specific
    parser.add_argument(
        "--heads",
        type=int,
        default=8,
        help="Number of heads to use for the transformer.",
    )
    parser.add_argument(
        "--layers",
        type=int,
        default=2,
        help="Number of transformer blocks to use for the transformer.",
    )
    parser.add_argument(
        "--dropout", type=float, default=0.0, help="Dropout probability."
    )
    parser.add_argument("--discount", type=float, default=0.99, help="Discount factor.")
    parser.add_argument(
        "--identity",
        action="store_true",
        help="Whether or not to use identity map reordering.",
    )
    parser.add_argument(
        "--pos",
        default="learned",
        choices=["learned", "sin", "none"],
        help="The type of positional encodings to use.",
    )
    parser.add_argument(
        "--bag-size", type=int, default=0, help="The size of the persistent memory bag."
    )
    # For slurm
    parser.add_argument(
        "--slurm-job-id",
        default=0,
        type=str,
        help="The `$SLURM_JOB_ID` assigned to this job.",
    )

    # argv=None makes argparse fall back to sys.argv[1:], preserving the original behaviour.
    return parser.parse_args(argv)

In [None]:
# Notebook configuration. Each name mirrors the `get_args` command-line option with the same
# name (dashes -> underscores); all values are collected into `args` at the end of this cell.
obs_embed = 8  # For discrete observation domains only. The number of features to give each observation
# Domain(s) to use, given as a LIST of environment ids. You can supply multiple domains, but they must
# have the same observation and action space. With multiple environments, the agent will sample a new
# one on each episode reset for conducting policy rollouts and collecting experience. During evaluation,
# it will perform the same evaluation for each domain (Note: this may significantly slow down your run!
# Consider increasing the eval-frequency or reducing the eval-episodes)
envs = ["DiscreteCarFlag-v0"]  # TODO: replace with the id(s) of the custom env
num_steps = 2_000_000  # Number of steps to train the agent
model = "DTQN"  # Network model to use
a_embed = 0  # The number of features to give each action. A value of 0 will prevent the policy from using the previous action
in_embed = 128  # The dimensionality of the network (`d_model` in the transformer); must be divisible by `heads`
buf_size = 500_000  # Number of timesteps to store in replay buffer.
# Note that we store the max length episodes given by the environment, so episodes that take longer will be padded at the end.
# This does not affect training but may affect the number of real observations in the buffer
lr = 3e-4  # Learning rate for the optimizer
batch = 32  # Batch size
context = 50  # For DRQN and DTQN, the context length to use to train the network
max_episode_steps = -1  # The maximum number of steps allowed in the environment.
# If `env` has a `max_episode_steps`, this will be inferred. Otherwise, this argument must be supplied.
history = 50  # This is how many (intermediate) Q-values we use to train for each context.
# To turn off intermediate Q-value prediction, set history = 1. To use the entire context,
# set history equal to the context length.
tuf = 10_000  # How many steps between each (hard) target network update
discount = 0.99  # Discount factor
heads = 8  # Number of heads to use for the transformer
layers = 2  # Number of transformer blocks to use for the transformer
dropout = 0.0  # Dropout probability
identity = True  # Whether or not to use identity map reordering
pos = "learned"  # The type of positional encoding to use: one of "learned", "sin", "none"
bag_size = 0  # The size of the persistent memory bag
eval_frequency = 5_000  # How many training timesteps between agent evaluations
eval_episodes = 10  # Number of episodes for each evaluation period.
save_policy = True  # Use this to save the policy so you can load it later for rendering.
verbose = True  # Print out evaluation results as they come in to the console

# Remaining `get_args` options, set to the parser defaults (device falls back to CPU without CUDA).
project_name = "DTQN-test"  # wandb project name, or directory name for local logging
disable_wandb = False  # True to log locally instead of to wandb
time_limit = None  # Hours allowed for the job (None = no limit)
device = "cuda" if torch.cuda.is_available() else "cpu"
seed = 1
render = False  # Enjoy mode (requires a saved policy)
slurm_job_id = 0

# nn.MultiheadAttention requires embed_dim to be divisible by num_heads.
assert in_embed % heads == 0, f"in_embed ({in_embed}) must be divisible by heads ({heads})"
assert pos in ("learned", "sin", "none"), f"unknown positional encoding: {pos!r}"

args = argparse.Namespace(
    project_name=project_name,
    disable_wandb=disable_wandb,
    time_limit=time_limit,
    model=model,
    envs=envs,
    num_steps=num_steps,
    tuf=tuf,
    lr=lr,
    batch=batch,
    buf_size=buf_size,
    eval_frequency=eval_frequency,
    eval_episodes=eval_episodes,
    device=device,
    context=context,
    obs_embed=obs_embed,
    a_embed=a_embed,
    in_embed=in_embed,
    max_episode_steps=max_episode_steps,
    seed=seed,
    save_policy=save_policy,
    verbose=verbose,
    render=render,
    history=history,
    heads=heads,
    layers=layers,
    dropout=dropout,
    discount=discount,
    identity=identity,
    pos=pos,
    bag_size=bag_size,
    slurm_job_id=slurm_job_id,
)

In [None]:
def run_experiment(args):
    """Build the environments, agent, and logger from `args`, then train (or render) the agent.

    Cases:
      * `args.render`: load the saved policy weights from the policy path, run the
        evaluation loop in render (enjoy) mode, and return without training.
      * "<policy_path>_mini_checkpoint.pt" exists: if it records at least `args.num_steps`
        completed steps, remove "<policy_path>_checkpoint.pt" (if present) and return;
        otherwise restore the agent, epsilon value and running averages from the full
        checkpoint and ask `get_logger` to resume the stored wandb run.
      * Otherwise: prepopulate the replay buffer with 50,000 random steps and train from scratch.

    After `train` returns, a mini checkpoint is saved so following runs know we are finished.

    Arguments:
        args: namespace with the fields produced by `get_args` (e.g. an `argparse.Namespace`).
    """
    start_time = time()
    # `envs` may arrive as a bare string (e.g. a string default or a hand-built namespace);
    # wrap it so we iterate over environment ids rather than their characters.
    env_strs = [args.envs] if isinstance(args.envs, str) else list(args.envs)

    # Create envs, set seed, create RL agent
    envs = []
    eval_envs = []
    for env_str in env_strs:
        envs.append(env_processing.make_env(env_str))
        eval_envs.append(env_processing.make_env(env_str))
    device = torch.device(args.device)
    set_global_seed(args.seed, *(envs + eval_envs))

    # Presumably a linear epsilon decay 1.0 -> 0.1 over the first 10% of training steps
    # (assumes LinearAnneal(start, end, duration) — TODO confirm).
    eps = epsilon_anneal.LinearAnneal(1.0, 0.1, args.num_steps // 10)

    agent = get_agent(
        args.model,
        envs,
        args.obs_embed,
        args.a_embed,
        args.in_embed,
        args.buf_size,
        device,
        args.lr,
        args.batch,
        args.context,
        args.max_episode_steps,
        args.history,
        args.tuf,
        args.discount,
        # DTQN specific
        args.heads,
        args.layers,
        args.dropout,
        args.identity,
        args.pos,
        args.bag_size,
    )

    print(
        f"[ {timestamp()} ] Creating {args.model} with {sum(p.numel() for p in agent.policy_network.parameters())} parameters"
    )

    # Create logging dir
    policy_save_dir = os.path.join(
        os.getcwd(), "policies", args.project_name, *env_strs
    )
    os.makedirs(policy_save_dir, exist_ok=True)
    policy_path = os.path.join(
        policy_save_dir,
        f"model={args.model}_envs={','.join(env_strs)}_obs_embed={args.obs_embed}_a_embed={args.a_embed}_in_embed={args.in_embed}_context={args.context}_heads={args.heads}_layers={args.layers}_"
        f"batch={args.batch}_identity={args.identity}_history={args.history}_pos={args.pos}_bag={args.bag_size}_seed={args.seed}",
    )

    # Enjoy mode: only watch the saved policy; never fall through into training.
    if args.render:
        agent.policy_network.load_state_dict(
            torch.load(policy_path, map_location="cpu")
        )
        evaluate(agent, eval_envs[0], 1_000_000, render=True)
        return

    # If there is already a saved checkpoint, load it and resume training if more steps are needed
    # Or exit early if we have already finished training.
    if os.path.exists(policy_path + "_mini_checkpoint.pt"):
        steps_completed = agent.load_mini_checkpoint(policy_path)["step"]
        print(
            f"Found a mini checkpoint that completed {steps_completed} training steps."
        )
        if steps_completed >= args.num_steps:
            print(f"Removing checkpoint and exiting...")
            if os.path.exists(policy_path + "_checkpoint.pt"):
                os.remove(policy_path + "_checkpoint.pt")
            # Return rather than exit(0): exit() would shut down a Jupyter kernel.
            return
        (
            wandb_id,
            mean_success_rate,
            mean_reward,
            mean_episode_length,
            eps_val,
        ) = agent.load_checkpoint(policy_path)
        eps.val = eps_val
        wandb_kwargs = {"resume": "must", "id": wandb_id}
    # Begin training from scratch
    else:
        wandb_kwargs = {"resume": None}
        # Prepopulate the replay buffer
        prepopulate(agent, 50_000, envs)
        mean_success_rate = RunningAverage(10)
        mean_reward = RunningAverage(10)
        mean_episode_length = RunningAverage(10)

    # Logging setup
    logger = get_logger(policy_path, args, wandb_kwargs)

    time_remaining = (
        args.time_limit * 3600 - (time() - start_time) if args.time_limit else None
    )

    # Keyword arguments for the running averages: `train` declares them in the order
    # (success, episode_length, reward), which differs from the (success, reward,
    # episode_length) order used by the checkpoint API; positional passing swapped them.
    train(
        agent,
        envs,
        eval_envs,
        env_strs,
        args.num_steps,
        eps,
        args.eval_frequency,
        args.eval_episodes,
        policy_path,
        args.save_policy,
        logger,
        mean_success_rate=mean_success_rate,
        mean_episode_length=mean_episode_length,
        mean_reward=mean_reward,
        time_remaining=time_remaining,
        verbose=args.verbose,
    )

    # Save a small checkpoint if we finish training to let following runs know we are finished
    agent.save_mini_checkpoint(
        checkpoint_dir=policy_path, wandb_id=wandb.run.id if logger == wandb else None
    )

In [None]:
import sys

# `get_args` parses `sys.argv`. Inside a Jupyter kernel that list holds the kernel launcher's
# own flags (e.g. `-f <connection-file>`), which argparse rejects as unrecognized arguments.
# Parse only the program name plus any explicit overrides, then restore `sys.argv`.
cli_overrides = []  # e.g. ["--envs", "DiscreteCarFlag-v0", "--verbose", "--save-policy"]
_kernel_argv = sys.argv
sys.argv = _kernel_argv[:1] + cli_overrides
try:
    experiment_args = get_args()
finally:
    sys.argv = _kernel_argv

run_experiment(experiment_args)