# In [None]:
import ray
import json
from _jsonnet import evaluate_file
import numpy as np
import os

# Resolve the project root from this file's location. ``__file__`` is not
# defined when the code runs inside a notebook kernel (this file is a notebook
# export), so fall back to the current working directory in that case.
try:
    ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
except NameError:
    ROOT_DIR = os.getcwd()
# ROOT_DIR = '/content/drive/My Drive/Colab Notebooks'
# Output/input locations used by the rest of the project, all under ROOT_DIR.
RESULT_DIR = f'{ROOT_DIR}/results'
CONFIG_DIR = f'{ROOT_DIR}/configs'
TENSORBOARD_DIR = f'{ROOT_DIR}/tensorboards'

"""
:class:`~data_pulling.utilities.registrable.Registrable` is a "mixin" for endowing
any base class with a named registry for its subclasses and a decorator
for registering them. It is adapted from the allennlp codebase:
https://github.com/allenai/allennlp/blob/master/allennlp/common/registrable.py
"""

import logging
from collections import defaultdict
from typing import TypeVar, Type, Dict, List

# Module-level logger and the generic type variable used by the registry
# classmethods below.
logger = logging.getLogger(__name__)

T = TypeVar('T')


class Registrable:
    """
    Mixin that gives a base class a named registry of its subclasses.

    Decorate a subclass with ``@BaseClass.register(name)`` to add it to the
    registry of ``BaseClass``. ``BaseClass.by_name(name)`` then returns the
    registered class (the class object itself, not an instance) and
    ``BaseClass.list_available()`` returns the registered names.

    ``BaseClass.default_implementation`` may be set to a registered name; when
    set, that name is listed first by ``list_available()``.

    Registration happens when the decorated class definition is executed, so
    every subclass must have been imported/defined before it is looked up.
    """
    # One shared mapping: base class -> {registered name -> subclass}.
    _registry: Dict[Type, Dict[str, Type]] = defaultdict(dict)
    default_implementation: str = None

    @classmethod
    def register(cls: Type[T], name: str):
        """Return a class decorator that registers its target under ``name``.

        :param name: key under which the decorated class is stored for ``cls``
        :raises ConfigurationError: (at decoration time) if ``name`` is
            already registered for ``cls``
        """
        known = Registrable._registry[cls]

        def _store(subclass: Type[T]):
            # Refuse to silently overwrite an existing registration.
            if name in known:
                raise ConfigurationError(
                    "Cannot register %s as %s; name already in use for %s" % (
                        name, cls.__name__, known[name].__name__))
            known[name] = subclass
            return subclass

        return _store

    @classmethod
    def by_name(cls: Type[T], name: str) -> Type[T]:
        """Return the class registered for ``cls`` under ``name``.

        :raises ConfigurationError: if ``name`` is not registered for ``cls``
        """
        logger.debug(f"instantiating registered subclass {name} of {cls}")
        candidates = Registrable._registry[cls]
        if name in candidates:
            return candidates[name]
        raise ConfigurationError(
            "%s is not a registered name for %s" % (name, cls.__name__))

    @classmethod
    def list_available(cls) -> List[str]:
        """List the registered names, with the default first if one is set.

        :raises ConfigurationError: if ``default_implementation`` is set but
            not registered
        """
        names = list(Registrable._registry[cls])
        preferred = cls.default_implementation

        if preferred is None:
            return names
        if preferred not in names:
            raise ConfigurationError(
                "Default implementation %s is not registered" % preferred)
        return [preferred, *(n for n in names if n != preferred)]


class ConfigurationError(Exception):
    """Raised when a registry lookup or registration is misconfigured.

    :param message: human-readable description of the problem; kept on
        ``self.message`` and rendered via ``repr`` by ``__str__``
    """

    def __init__(self, message):
        # Forward the message to Exception so ``args`` is populated. Without
        # it the exception cannot be unpickled (the default reduce protocol
        # re-calls the constructor with ``args``), which breaks passing the
        # error across process boundaries.
        super(ConfigurationError, self).__init__(message)
        self.message = message

    def __str__(self):
        return repr(self.message)


import numpy as np
from enum import Enum
from typing import NamedTuple, Union


class RLAlgorithm(Enum):
    """Names of reinforcement-learning algorithms, keyed to string values."""
    VI = 'value_iteration'
    PI = 'policy_iteration'
    QLearning = 'q_learning'
    MONTE_CARLO = 'monte_carlo'
    EXPECTED_SARSA = 'expected_sarsa'
    SARSA = 'sarsa'


class TargetUpdate(Enum):
    """Target-network update modes.

    The string values are compared against ``agent_cfg['update_type']`` by
    the agents in this file.
    """
    HARD = 'hard'
    SOFT = 'soft'


class ReplayType(Enum):
    """Replay buffer kinds.

    The values equal the class names of the replay buffers in this file;
    agents compare them against a buffer's ``replay_type`` attribute, which
    is set from the buffer's class name.
    """
    EXPERIENCE_REPLAY = 'ExperienceReplay'
    PRIORITIZED_EXPERIENCE_REPLAY = 'PrioritizedExperienceReplay'


class Transition(NamedTuple):
    """One environment step: (state, action, reward, next state, done).

    The replay buffers in this file also reuse this tuple to return a sampled
    batch, in which case each field holds a batched ``torch`` tensor instead
    of a single value.
    """
    # store state, action, reward, next_state, done as Transition tuple
    s0: np.ndarray  # observation before the action
    a: Union[int, str]  # Action
    r: np.ndarray  # reward received for the step
    s1: np.ndarray  # observation after the action
    done: bool = False  # episode-termination flag for the step



import torch
from collections import defaultdict
from torch.distributions import Categorical
import numpy as np
from gym import Env


def nested_d():
    """Build an auto-vivifying dict that nests to any depth.

    Every missing key yields another ``nested_d()``, so chains such as
    ``d['a']['b']['c'] = 1`` work without creating the intermediate levels.
    """
    tree = defaultdict(nested_d)
    return tree


# Thanks Vlad!
def torch_argmax_mask(q: torch.Tensor, dim: int):
    """ Returns a random tie-breaking argmax mask

    Exactly one position along ``dim`` is marked ``True`` for every slice of
    ``q``; when several entries share the maximum, one of them is picked at
    random (using torch's global RNG).

    :param q: floating-point vector or matrix of values
    :param dim: dimension to take the argmax over (0 or 1)
    :return: boolean tensor with the same shape as ``q``
    :raises NotImplementedError: if ``dim`` is neither 0 nor 1

    Example:
        >>> import torch
        >>> torch.manual_seed(1337)
        >>> q = torch.ones(3, 2)
        >>> torch_argmax_mask(q, 1)
        # tensor([[False,  True],
        #         [ True, False],
        #         [ True, False]])
        >>> torch_argmax_mask(q, 1)
        # tensor([[False,  True],
        #         [False,  True],
        #         [ True, False]])
    """
    if dim not in (0, 1):
        # ``NotImplemented`` is a non-callable constant; the exception class
        # is ``NotImplementedError``.
        raise NotImplementedError("Only vectors and matrices are supported")

    rand = torch.rand_like(q)
    # keepdim=True lets the per-slice maximum broadcast back against ``q``
    # for both supported dimensions.
    is_max = q == q.max(dim, keepdim=True)[0]
    # Non-maximal entries get score 0, maximal entries a random score; the
    # largest score along ``dim`` is the randomly chosen winner.
    scores = rand * is_max
    mask = scores == scores.max(dim, keepdim=True)[0]
    # Sanity check: one winner per slice taken along ``dim``.
    assert int(mask.sum()) == q.numel() // q.shape[dim]
    return mask


def get_epsilon_dist(eps: float, env: Env, observation: torch.Tensor,
                     model: torch.nn.Module) -> Categorical:
    """Build the epsilon-greedy action distribution for ``observation``.

    The greedy action (ties broken at random by ``torch_argmax_mask``) gets
    probability ``1 - eps``; every other entry gets
    ``eps / (env.action_space.n - 1)``.

    :param eps: exploration rate
    :param env: environment whose ``action_space.n`` gives the action count
    :param observation: input passed to ``model``
    :param model: network returning the q-values for ``observation``
    :return: ``Categorical`` distribution over the actions
    """
    q_values = model(observation)
    explore_prob = eps / (env.action_space.n - 1)
    action_probs = torch.empty_like(q_values).fill_(explore_prob)
    # The action axis is taken to be the last dimension of the model output.
    greedy = torch_argmax_mask(q_values, len(q_values.shape) - 1)
    action_probs[greedy] = 1 - eps
    return Categorical(probs=action_probs)


def get_epsilon(eps_start: float, eps_final: float, eps_decay: float,
                t: int) -> float:
    """Exponentially decayed exploration rate at step ``t``.

    Starts at ``eps_start`` for ``t == 0`` and approaches ``eps_final`` as
    ``t`` grows; ``eps_decay`` is the decay time constant in steps.
    """
    decay_factor = np.exp(-1.0 * t / eps_decay)
    span = eps_start - eps_final
    return eps_final + span * decay_factor


def soft_update(value_net: torch.nn.Module, target_net: torch.nn.Module,
                tau: float):
    """Blend the value network's weights into the target network in place.

    Each target parameter becomes ``tau * value + (1 - tau) * target``.
    Parameters that are the very same object in both networks are skipped.

    :param value_net: network providing the new weights
    :param target_net: network updated in place
    :param tau: interpolation factor applied to the value network's weights
    """
    pairs = zip(target_net.parameters(), value_net.parameters())
    for target_p, value_p in pairs:
        if target_p is not value_p:
            blended = tau * value_p.data + (1.0 - tau) * target_p.data
            target_p.data.copy_(blended)


def hard_update(value_net: torch.nn.Module, target_net: torch.nn.Module):
    """Copy every value-network parameter into the target network in place.

    Parameters that are the very same object in both networks are skipped.

    :param value_net: network providing the weights
    :param target_net: network overwritten in place
    """
    pairs = zip(target_net.parameters(), value_net.parameters())
    for target_p, value_p in pairs:
        if target_p is not value_p:
            target_p.data.copy_(value_p.data)


import logging
from typing import List
from termcolor import colored


class ProjectLogger:
    """Coloured wrapper around a named ``logging.Logger``.

    Each level method colours the message with ``termcolor.colored`` and
    forwards it to the underlying logger, which is exposed as ``self.log``.
    """

    def __init__(self,
                 log_file: str = None,
                 level: int = logging.DEBUG,
                 printing: bool = True, attrs: List[str] = None,
                 name: str = 'project_logger',
                 ):
        """ Basic logger that can write to a file on disk or to stderr.
        :param log_file: name of the file to log to
        :param level: logging verbosity level
        :param printing: flag for whether to log to stderr
        :param attrs: default ``termcolor`` attributes applied to every
            message unless a call supplies its own
        :param name: name of the underlying ``logging`` logger; instances
            created with the same name share one logger and its handlers
        """
        root_logger = logging.getLogger(name)
        root_logger.setLevel(level)
        self.printing = printing
        self.attrs = attrs

        # Set up writing to a file. Loggers are shared by name, so skip the
        # handler if this file is already attached; otherwise every extra
        # instance would duplicate each line in the file.
        if log_file:
            target = os.path.abspath(log_file)
            already_attached = any(
                isinstance(hdlr, logging.FileHandler)
                and hdlr.baseFilename == target
                for hdlr in root_logger.handlers)
            if not already_attached:
                file_handler = logging.FileHandler(log_file, mode='a')
                file_formatter = logging.Formatter(
                    '%(levelname)s: %(asctime)s %(message)s',
                    # 12-hour clock with AM/PM (``%I`` is the hour directive).
                    datefmt='%m/%d/%Y %I:%M:%S %p'
                )
                file_handler.setFormatter(file_formatter)
                root_logger.addHandler(file_handler)

        # Set up printing to stderr, at most one console handler per logger.
        # FileHandler subclasses StreamHandler, hence the explicit exclusion.
        has_console = any(
            isinstance(hdlr, logging.StreamHandler)
            and not isinstance(hdlr, logging.FileHandler)
            for hdlr in root_logger.handlers)
        if printing and not has_console:
            console_handler = logging.StreamHandler()
            console_handler.setFormatter(logging.Formatter("%(message)s"))
            root_logger.addHandler(console_handler)

        self.log = root_logger

    def debug(self, msg, color='grey', attrs: List[str] = None):
        """Log ``msg`` at DEBUG level in ``color``."""
        self.log.debug(colored(msg, color, attrs=attrs or self.attrs))

    def info(self, msg, color='green', attrs: List[str] = None):
        """Log ``msg`` at INFO level in ``color``."""
        self.log.info(colored(msg, color, attrs=attrs or self.attrs))

    def warning(self, msg, color='blue', attrs: List[str] = None):
        """Log ``msg`` at WARNING level in ``color``."""
        self.log.warning(colored(msg, color, attrs=attrs or self.attrs))

    def error(self, msg, color='magenta', attrs: List[str] = None):
        """Log ``msg`` at ERROR level in ``color``."""
        self.log.error(colored(msg, color, attrs=attrs or self.attrs))

    def critical(self, msg, color='red', attrs: List[str] = None):
        """Log ``msg`` at CRITICAL level in ``color``."""
        self.log.critical(colored(msg, color, attrs=attrs or self.attrs))


import pickle
import os
from typing import Dict
from collections import defaultdict
import plotly
import plotly.graph_objects as go
from torch.utils.tensorboard import SummaryWriter


def plot_episodic_results(idx: int, seed: int, writer: SummaryWriter,
                          episode_result: defaultdict):
    """Write one episode's collected results to TensorBoard.

    Every entry of ``episode_result`` maps a tag to a sequence of values; the
    k-th value is written at global step ``idx + k``. Tags containing
    ``'net_params'`` hold iterables of ``(name, parameter)`` pairs and are
    written as histograms of the parameter and of its gradient; all other
    tags are written as scalars.

    :param idx: global step of the first value in each sequence
    :param seed: appended to the histogram tags
    :param writer: TensorBoard writer receiving the data
    :param episode_result: tag -> sequence of values collected in the episode
    """
    for key, series in episode_result.items():
        for step, entry in enumerate(series, start=idx):
            if 'net_params' not in key:
                writer.add_scalar(tag=key, scalar_value=entry,
                                  global_step=step)
                continue
            for tag, value in entry:
                base = tag.replace('.', '/')
                weights_tag = f"{base}/{str(seed)}"
                writer.add_histogram(weights_tag, value.data.cpu().numpy(),
                                     step)
                grads_tag = f"{base}/grad/{str(seed)}"
                writer.add_histogram(grads_tag,
                                     value.grad.data.cpu().numpy(), step)



import random
from collections import deque
from typing import List, Dict, Tuple
import torch
import numpy as np

class Replay(Registrable):
    """Registry base class for replay buffers.

    Concrete buffers register themselves with ``@Replay.register(name)`` and
    implement ``from_params``.
    """

    @classmethod
    def build(cls, type: str, params: Dict):
        """Instantiate the replay buffer registered under ``type``.

        :param type: registered name of the buffer class
        :param params: parameters handed to that class's ``from_params``
        :raises ConfigurationError: if ``type`` is not registered
        """
        replay = cls.by_name(type)
        return replay.from_params(params)

    @classmethod
    def from_params(cls, params: Dict):
        """Build an instance from ``params``; subclasses must override.

        :raises NotImplementedError: always, in this base class
        """
        # ``cls`` is already the class, so its name is ``cls.__name__``
        # (``cls.__class__.name`` raised AttributeError instead).
        raise NotImplementedError(
            f'from_params not implemented in {cls.__name__}')


# TODO: initial version, not optimized with tree structures used in paper
@Replay.register('ExperienceReplay')
class ExperienceReplay(Replay, Registrable):
    """Fixed-capacity FIFO replay buffer with uniform sampling.

    When ``n_step > 0`` the most recent ``n_step`` transitions are also kept
    in a sliding window and, once that window is full, each push stores the
    aggregated n-step transition instead of the raw one.
    """

    def __init__(self,
                 capacity: int,
                 n_step: int,
                 gamma: float):
        """

        :param capacity: maximum number of transition tuple stored in replay
        :param n_step: n step used for replay
        :param gamma: discount factor when computing td-error
        """
        # The class name doubles as the replay type that agents dispatch on.
        self.replay_type = self.__class__.__name__
        self.capacity = capacity
        self.memory = deque(maxlen=self.capacity)
        self.n_step = n_step
        # Always keep gamma so the attribute exists whatever n_step is.
        self.gamma = gamma
        if self.n_step > 0:
            self.n_step_memory = deque(maxlen=self.n_step)

    @classmethod
    def from_params(cls, params: Dict):
        """Build the buffer from a dict of constructor keyword arguments."""
        return cls(**params)

    def push(self, transition: Transition):
        """push Transition into memory for batch sampling and n_step
        computations"""
        if self.n_step > 0:
            self.n_step_memory.append(transition)
            # NOTE(review): until the window is full the raw one-step
            # transition is stored, and the window is not reset at episode
            # ends -- confirm this is intended.
            if len(self.n_step_memory) == self.n_step:
                transition = self.generate_n_step_q()
        self.memory.append(transition)

    # TODO: try running with this update
    def generate_n_step_q(self) -> Transition:
        """Collapse the n-step window into a single transition.

        The result starts at the oldest transition's state/action, ends at the
        newest transition's next state/done flag, and carries the discounted
        sum of the rewards in between.
        """
        transitions = self.n_step_memory
        next_observation, done = transitions[-1][-2:]
        reward = 0
        for idx, transition in enumerate(transitions):
            reward += (self.gamma ** idx) * transition.r
        # for i in range(len(transitions) - 1):
        #     reward = self.gamma * reward * (1 - transitions[i].done) + \
        #              transitions[i].r
        #     next_observation, done = (transitions[i].s1, transitions[i].done) \
        #         if transitions[i].done else (next_observation, done)
        observation, action = transitions[0][:2]
        return Transition(s0=observation, a=action, r=reward,
                          s1=next_observation, done=done)

    def sample(self, batch_size: int) -> Transition:
        """Uniformly sample a batch and return it as batched torch tensors.

        At most ``batch_size`` transitions are drawn without replacement
        (fewer if the buffer holds fewer). Observations are concatenated
        along dimension 0.

        :return: a ``Transition`` whose fields are tensors over the batch
        """
        count = min(batch_size, len(self.memory))
        batch = random.sample(self.memory, count)
        observation, action, reward, next_observation, done = zip(*batch)
        observation = torch.cat(tuple(torch.FloatTensor(observation)), dim=0)
        action = torch.LongTensor(action)
        reward = torch.FloatTensor(reward)
        next_observation = torch.cat(tuple(torch.FloatTensor(next_observation)),
                                     dim=0)
        done = torch.FloatTensor(done)
        return Transition(s0=observation, a=action, r=reward,
                          s1=next_observation, done=done)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.memory)


@Replay.register('PrioritizedExperienceReplay')
class PrioritizedExperienceReplay(ExperienceReplay, Registrable):
    """Replay buffer that samples transitions in proportion to a priority.

    Priorities live in a flat array (no sum-tree); new transitions get the
    current maximum priority and ``update_priorities`` overwrites the
    priorities of the most recently sampled batch.
    """

    def __init__(self,
                 capacity: int,
                 n_step: int,
                 alpha: float,
                 beta: float,
                 beta_inc: float,
                 gamma: float,
                 non_zero_variant: float):

        """

        :param capacity: maximum number of transition tuple stored in replay
        :param n_step: n step used for replay
        :param gamma: discount rate for calculating td-error
        :param alpha: 0 for no prioritization, 1 for full prioritization
        :param beta: initial exponent of the importance-sampling weights
        :param beta_inc: number of ``sample`` calls over which ``beta`` is
            raised linearly to 1
        :param non_zero_variant: small constant to ensure non-zero probabilities
        """
        # try with alpha=0.6, beta=0.4, beta_inc=100~network update frequency
        super().__init__(capacity=capacity, n_step=n_step, gamma=gamma)
        # Tolerant comparison: an exact ``==`` rejects valid pairs such as
        # 0.7 + 0.3 because of floating-point rounding.
        assert np.isclose(alpha + beta, 1.0), \
            f'alpha + beta must equal 1.0, got {alpha + beta}'
        self.alpha = alpha
        self.beta = beta
        self.beta_inc = (1 - beta) / beta_inc
        self.non_zero_variant = non_zero_variant
        self.priorities = np.zeros([self.capacity], dtype=np.float32)
        # Next write position in the ring buffer.
        self.idx = 0
        # A plain list (not the parent's deque) so slots can be overwritten
        # by index while staying aligned with ``self.priorities``.
        self.memory = []

    @classmethod
    def from_params(cls, params: Dict):
        """Build the buffer from a dict of constructor keyword arguments."""
        return cls(**params)

    def push(self, transition: Transition):
        """Store ``transition`` with the current maximum priority."""
        max_prior = np.max(self.priorities) if self.memory else 1.0

        # n_step computation
        if self.n_step > 0:
            self.n_step_memory.append(transition)
            if len(self.n_step_memory) == self.n_step:
                transition = self.generate_n_step_q()

        # unlike ExperienceReplay which updates based on FIFO
        # update from start of queue
        if len(self.memory) < self.capacity:
            self.memory.append(transition)
        else:
            self.memory[self.idx] = transition
        self.priorities[self.idx] = max_prior
        self.idx = (self.idx + 1) % self.capacity

    def sample(self, batch_size: int) -> Tuple[Transition, np.ndarray]:
        """use absolute td-error to favor model to optimize

        Draws ``batch_size`` indices (with replacement) with probability
        proportional to ``priority ** alpha`` and remembers them in
        ``self.indices`` for ``update_priorities``.

        :return: ``(batch, weights)`` where ``batch`` is a ``Transition`` of
            batched tensors and ``weights`` are the importance-sampling
            weights normalised by their maximum
        """
        # Only the filled part of the priority array takes part in sampling.
        probs = self.priorities[:len(self.memory)]
        # probs = abs(td-error), use probabilities
        scaled = probs ** self.alpha
        probs = scaled / np.sum(scaled)
        self.indices = np.random.choice(len(self.memory), batch_size, p=probs)
        if self.beta < 1:
            # Anneal towards 1 without overshooting it.
            self.beta = min(1.0, self.beta + self.beta_inc)

        # samples a batch from memory and concatenates the dimensions of
        # observations and convert to torch
        batch = [self.memory[idx] for idx in self.indices]
        observation, action, reward, next_observation, done = zip(*batch)
        observation = torch.cat(tuple(torch.FloatTensor(observation)), dim=0)
        action = torch.LongTensor(action)
        reward = torch.FloatTensor(reward)
        next_observation = torch.cat(tuple(torch.FloatTensor(next_observation)),
                                     dim=0)
        done = torch.FloatTensor(done)

        # need weights to compute loss
        weights = (len(self.memory) * probs[self.indices]) ** -self.beta
        weights = np.array(weights / np.max(weights), dtype=np.float32)
        return Transition(s0=observation, a=action, r=reward,
                          s1=next_observation, done=done), weights

    def update_priorities(self, losses: np.array):
        """update absolute td-error to compute probabilities

        :param losses: new priorities, aligned with the indices drawn by the
            latest ``sample`` call
        """
        for idx, priority in zip(self.indices, losses):
            self.priorities[idx] = priority



from typing import Dict, Generator, List, Tuple
import torch
import torch.nn as nn


class TorchModel(nn.Module, Registrable):
    """Registry base class for the torch networks in this file.

    Concrete networks register with ``@TorchModel.register(name)`` and are
    created through ``TorchModel.build``.
    """

    @classmethod
    def build(cls, type: str, params: Dict):
        """Create the model registered under ``type`` from ``params``."""
        return cls.by_name(type).from_params(params)

    @classmethod
    def from_params(cls, params: Dict):
        """Instantiate ``cls`` using ``params`` as keyword arguments."""
        return cls(**params)

    def forward(self, *input):
        """Forward pass; concrete models must override this."""
        raise NotImplementedError()

    def to_device(self, device):
        """Move the module to ``device`` and record it on ``self.device``."""
        self.to(device)
        self.device = device


@TorchModel.register('LinearFCBody')
class LinearFCBody(TorchModel):
    """Fully connected network with two hidden layers.

    Maps a ``state_dim`` input to ``action_dim`` outputs through
    ``fc1 -> fc2 -> fc3``; the two hidden layers are each followed by
    ``gate``.
    """

    def __init__(self,
                 seed: int,
                 state_dim: int,
                 action_dim: int,
                 hidden_units: List = None,
                 gate: Type[nn.Module] = nn.ReLU):
        """
        :param seed: passed to ``torch.manual_seed`` before the layers are
            created
        :param state_dim: size of the input features
        :param action_dim: size of the output
        :param hidden_units: sizes of the two hidden layers; only the first
            two entries are used. Defaults to ``[64, 64]``
        :param gate: activation module class instantiated after each hidden
            layer
        """
        super(LinearFCBody, self).__init__()
        # Default resolved here instead of a shared mutable default argument.
        if hidden_units is None:
            hidden_units = [64, 64]
        # torch.manual_seed seeds the global RNG and returns its Generator.
        self.seed = torch.manual_seed(seed)
        self.state_dim = state_dim
        self.action_dim = action_dim
        hidden_unit_1, hidden_unit_2 = hidden_units[0], hidden_units[1]
        self.fc1 = nn.Sequential(
            nn.Linear(self.state_dim, hidden_unit_1),
            # nn.BatchNorm1d(hidden_size1),
            gate()
        )
        self.fc2 = nn.Sequential(
            nn.Linear(hidden_unit_1, hidden_unit_2),
            # nn.BatchNorm1d(hidden_size2),
            gate()
        )
        self.fc3 = nn.Linear(hidden_unit_2, self.action_dim)

    @classmethod
    def from_params(cls, params: Dict):
        """Instantiate the model using ``params`` as keyword arguments."""
        return cls(**params)

    def forward(self, x):
        """Run ``x`` through the three layers and return the output."""
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x


@TorchModel.register('BasicRNN')
class BasicRNN(TorchModel):
    """LSTM followed by a linear layer applied to the last time step."""

    def __init__(self,
                 seed: int,
                 input_dim: int,
                 hidden_dim: int,
                 num_layers: int,
                 output_dim: int,
                 dropout: float = 0,
                 bidirectional: bool = False):
        """
        :param seed: passed to ``torch.manual_seed`` before the layers are
            created
        :param input_dim: number of features per time step
        :param hidden_dim: LSTM hidden size
        :param num_layers: number of stacked LSTM layers
        :param output_dim: size of the final linear output
        :param dropout: dropout probability handed to ``nn.LSTM``
        :param bidirectional: whether the LSTM runs in both directions
        """
        super().__init__()
        self.seed = torch.manual_seed(seed)
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.num_directions = 2 if bidirectional else 1
        # ``dropout`` was accepted but never forwarded; the default of 0
        # keeps the previous behaviour.
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True,
                            dropout=dropout, bidirectional=bidirectional)
        self.fc = nn.Linear(hidden_dim * self.num_directions, output_dim)

    @classmethod
    def from_params(cls, params: Dict):
        """Instantiate the model using ``params`` as keyword arguments."""
        return cls(**params)

    def forward(self, x):
        """Return the linear projection of the LSTM output at the last step.

        ``x`` is indexed as (batch, sequence, features) because the LSTM is
        built with ``batch_first=True``.
        """
        # Set initial hidden and cell states on the input's own device/dtype,
        # so the model also works before ``to_device`` has set ``self.device``.
        state_shape = (self.num_layers * self.num_directions,
                       x.size(0), self.hidden_dim)
        h0 = torch.zeros(state_shape, dtype=x.dtype, device=x.device)
        c0 = torch.zeros(state_shape, dtype=x.dtype, device=x.device)

        # Forward propagate LSTM
        # out: tensor of shape (batch_size, seq_length, hidden_dim)
        out, _ = self.lstm(x, (h0, c0))
        out = self.fc(out[:, -1, :])
        return out


@TorchModel.register('Inception')
class Inception(TorchModel):
    """Three parallel convolution branches concatenated along dimension 2.

    Branches: 1x1 conv -> (3, 1) conv, 1x1 conv -> (5, 1) conv, and
    (3, 1) max-pool -> 1x1 conv, each producing 32 channels.
    """

    def __init__(self,
                 in_channels: int,
                 gate: nn.ReLU):
        """
        :param in_channels: number of channels of the input
        :param gate: accepted for configuration compatibility; currently
            unused because the activation lines are commented out
        """
        super().__init__()
        width = 32

        def pointwise_then_temporal(span: int) -> nn.Sequential:
            # 1x1 convolution followed by a (span, 1) convolution.
            return nn.Sequential(
                nn.Conv2d(in_channels, width, kernel_size=(1, 1)),
                # gate(),
                nn.Conv2d(width, width, kernel_size=(span, 1)), )

        self.conv1 = pointwise_then_temporal(3)
        self.conv2 = pointwise_then_temporal(5)
        self.conv3 = nn.Sequential(
            nn.MaxPool2d(kernel_size=(3, 1)),
            # gate(),
            nn.Conv2d(in_channels, width, kernel_size=(1, 1)), )

    @classmethod
    def from_params(cls, params: Dict):
        """Instantiate the model using ``params`` as keyword arguments."""
        return cls(**params)

    def forward(self, x):
        """Apply the three branches to ``x`` and concatenate the results."""
        # (batch_size, out_channels, timesteps, features),
        # timesteps=timesteps-kernel+1
        branches = [self.conv1(x), self.conv2(x), self.conv3(x)]
        # concatenate by timesteps
        return torch.cat(branches, 2)



from gym import Env
from typing import Dict


class Agent(Registrable):
    """Registry base class for agents.

    Concrete agents register with ``@Agent.register(name)`` and implement
    ``from_params``.
    """

    @classmethod
    def build(cls, type: str, env: Env, params: Dict):
        """Create the agent registered under ``type`` for ``env``."""
        agent_cls = cls.by_name(type)
        return agent_cls.from_params(env, params)

    @classmethod
    def from_params(cls, env: Env, params: Dict):
        """Build an agent from ``params``; subclasses must override."""
        message = f'from_params not implemented in {cls.__name__}'
        raise NotImplementedError(message)


import time
import numpy as np
from typing import Dict, Generator, Tuple
from gym import Env
import torch.optim as optim
import torch.nn as nn
import torch
# import torchviz


@Agent.register('DeepTDAgent')
class DeepTDAgent(Agent, Registrable):
    """Base class for replay-based deep temporal-difference agents.

    Holds the value/target networks, the replay buffer, the optimizer and the
    epsilon-greedy interaction loop. Subclasses implement ``train()``, which
    must set ``self.loss``, ``self.q`` and ``self.q_`` (read by
    ``training_update``).
    """

    def __init__(self,
                 env: Env,
                 agent_cfg: Dict):
        """

        :param env: gym environment used for Experiment
        :param agent_cfg: config file for given agent
        """
        super().__init__()
        self.env = env
        # ``epochs`` counts training updates, ``total_steps`` environment steps.
        self.epochs, self.total_steps = 0, 0
        self.episodic_result = dict()
        self.episodic_result['Training/Q-Loss'] = []
        self.episodic_result['Training/Mean-Q-Value-Action'] = []
        self.episodic_result['Training/Mean-Q-Value-Opposite-Action'] = []
        self.episodic_result['value_net_params'] = []

        # specs for RL agent
        self.eps = agent_cfg['eps']
        # Always store the flag: ``training_update`` reads it even when decay
        # is disabled.
        self.use_eps_decay = agent_cfg['use_eps_decay']
        if self.use_eps_decay:
            self.eps_decay = agent_cfg['eps_decay']
            self.eps_min = agent_cfg['eps_min']
        self.gamma = agent_cfg['gamma']
        self.update_type = agent_cfg['update_type']
        if self.update_type == TargetUpdate.SOFT.value:
            self.tau = agent_cfg['tau']
        self.update_freq = agent_cfg['update_freq']
        self.warm_up_freq = agent_cfg['warm_up_freq']
        self.use_grad_clipping = agent_cfg['use_grad_clipping']
        self.grad_clipping = agent_cfg['grad_clipping']
        self.lr = agent_cfg['lr']
        self.batch_size = agent_cfg['batch_size']
        self.seed = agent_cfg['seed']
        # Snapshot of the attributes set so far.
        self.params = vars(self).copy()

        # details experience replay
        self.replay_buffer = Replay.build(
            type=agent_cfg['experience_replay']['type'],
            params=agent_cfg['experience_replay']['params'])

        # details for the NN model
        # NOTE(review): the seed is written to agent_cfg['model'], but only
        # agent_cfg['model']['params'] is passed to the model -- confirm the
        # config supplies the seed inside 'params'.
        agent_cfg['model']['seed'] = self.seed
        use_cuda = agent_cfg['use_gpu'] and torch.cuda.is_available()
        self.device = torch.device("cuda" if use_cuda else "cpu")
        self.value_net = TorchModel.build(type=agent_cfg['model']['type'],
                                          params=agent_cfg['model']['params'])
        self.target_net = TorchModel.build(type=agent_cfg['model']['type'],
                                           params=agent_cfg['model']['params'])
        self.value_net.to_device(self.device)
        self.target_net.to_device(self.device)
        # Start with identical weights; the target net is never trained
        # directly.
        self.target_net.load_state_dict(self.value_net.state_dict())
        self.target_net.eval()
        self.optimizer = optim.Adam(self.value_net.parameters(), self.lr)

        # Huber loss acts like the mean squared error when the error is small,
        # but like the mean absolute error when the error is large
        # this makes it more robust to outliers when the estimates of Q
        # are very noisy. It is calculated over a batch of transitions
        # sampled from the replay memory:
        self.loss_func = nn.SmoothL1Loss()

    @classmethod
    def from_params(cls, env: Env, params: Dict):
        """Build the agent from ``params``.

        NOTE(review): ``params`` is unpacked as keyword arguments here, so it
        must look like ``{'agent_cfg': {...}}``; the subclasses in this file
        pass ``params`` directly as ``agent_cfg`` instead -- confirm which
        form is intended.
        """
        return cls(env, **params)

    @torch.no_grad()
    def get_action(self, observation) -> int:
        """either take greedy action or explore with epsilon rate"""
        if np.random.random() < self.eps:
            # return self.env.action_space.sample()
            return np.random.randint(self.env.action_space.n)
        else:
            state = torch.FloatTensor(observation).to(self.device)
            # Greedy action for the first item of the batch.
            return self.value_net(state).max(1)[1].data[0].item()
        # observation = torch.FloatTensor(observation).to(self.device)
        # dist = get_epsilon_dist(eps=self.eps, env=self.env,
        #                         model=self.value_net, observation=observation)
        # return dist.sample().item()

    def train(self):
        """Run one optimisation step; implemented by subclasses."""
        raise NotImplementedError('DeepTD Agent requires a train() method')

    def training_update(self):
        """backprop loss, target network, epsilon, and result updates"""
        self.epochs += 1
        self.train()

        # hard update
        if self.update_type == TargetUpdate.HARD.value:
            # Sync every ``update_freq`` training updates. The operands were
            # previously swapped (``update_freq % epochs``), which stopped all
            # syncing once ``epochs`` exceeded ``update_freq``. A frequency of
            # 1 or less syncs on every update.
            if self.update_freq <= 1 or self.epochs % self.update_freq == 0:
                self.target_net.load_state_dict(
                    self.value_net.state_dict())
        elif self.update_type == TargetUpdate.SOFT.value:
            soft_update(value_net=self.value_net, target_net=self.target_net,
                        tau=self.tau)

        if self.use_eps_decay:
            # self.eps = get_epsilon(eps_start=self.eps, eps_final=self.eps_min,
            #                        eps_decay=self.eps_decay, t=1)
            if self.eps >= self.eps_min:
                self.eps *= self.eps_decay

        # save episodic results
        self.episodic_result['Training/Q-Loss'].append(
            self.loss.detach().cpu().numpy())
        self.episodic_result['Training/Mean-Q-Value-Action'].append(
            np.mean(self.q.detach().cpu().numpy()))
        self.episodic_result[
            'Training/Mean-Q-Value-Opposite-Action'].append(
            np.mean(self.q_.detach().cpu().numpy()))
        # self.episodic_result['value_net_params'].append(
        #     self.value_net.named_parameters())

    def learn(self, num_steps: int) -> Generator:
        """use agent to interact with environment by making actions based on
        optimal policy to obtain cumulative rewards

        Runs one episode of at most ``num_steps`` steps, pushing every
        transition to the replay buffer and training once ``total_steps``
        reaches ``warm_up_freq``. Yields a single dict of episode statistics.
        """
        cr, t = 0, 0
        done = False
        start = time.time()
        observation = self.env.reset()
        # Add a leading batch dimension of 1.
        observation = np.expand_dims(observation, 0)
        action = self.get_action(observation=observation)

        while not done and t < num_steps:
            self.total_steps += 1
            next_observation, reward, done, info = self.env.step(action)
            next_observation = np.expand_dims(next_observation, 0)
            cr += reward

            # store into experience replay buffer and sample batch of
            # transitions to estimate the q-values and train on losses
            transition = Transition(s0=observation, a=action, r=reward,
                                    s1=next_observation, done=done)
            self.replay_buffer.push(transition)

            # train policy network and update target network
            # update epsilon decay, more starting exploration
            if self.total_steps >= self.warm_up_freq:
                self.training_update()

            observation = next_observation
            action = self.get_action(observation=observation)
            t += 1

        # TODO: pause training and use eval with generator, need to add eval!
        # NOTE(review): 'mean_q_loss' is the raw list of losses, not a mean,
        # and the ``[-t:]`` windows assume one training update per step --
        # confirm what the consumer expects.
        yield {
            'cum_reward': cr,
            'time_to_solve': t,
            'mean_q_loss': self.episodic_result['Training/Q-Loss'][-t:],
            'mean_action_q_value': np.mean(self.episodic_result[
                                       'Training/Mean-Q-Value-Action'][-t:]),
            'mean_opposite_action_q_value': np.mean(self.episodic_result[
                                       'Training/Mean-Q-Value-Opposite-Action'][
                                   -t:]),
            'episode_time': time.time() - start,
        }


@Agent.register('DQNAgent')
class DQNAgent(DeepTDAgent, Registrable):
    """DQN agent with optional double-DQN targets and prioritized replay."""

    def __init__(self, env: Env, agent_cfg: Dict):
        """
        :param env: gym environment used for Experiment
        :param agent_cfg: agent config; additionally reads ``use_double``
        """
        super().__init__(env, agent_cfg)
        self.use_double = agent_cfg['use_double']

    @classmethod
    def from_params(cls, env: Env, params: Dict):
        """Build the agent, using ``params`` as its ``agent_cfg``."""
        return cls(env, params)

    def train(self):
        """Sample a batch, take one optimisation step on the value network.

        Sets ``self.loss``, ``self.q`` and ``self.q_`` for the caller and,
        with prioritized replay, refreshes the sampled priorities.
        """
        # handle different replay types
        prioritized = (self.replay_buffer.replay_type ==
                       ReplayType.PRIORITIZED_EXPERIENCE_REPLAY.value)
        if prioritized:
            batch, weights = self.replay_buffer.sample(
                batch_size=self.batch_size)
            weights = torch.as_tensor(weights, dtype=torch.float32,
                                      device=self.device)
        else:
            batch = self.replay_buffer.sample(batch_size=self.batch_size)

        # The buffer returns CPU tensors; the networks may live on the GPU.
        batch = Transition(*(field.to(self.device) for field in batch))

        # Bootstrap target; no gradient is needed through either network here.
        with torch.no_grad():
            target_next = self.target_net(batch.s1)
            if self.use_double:
                # Double DQN: the value net picks the action, the target net
                # evaluates it. (These two branches were previously swapped.)
                next_q_actions = torch.max(self.value_net(batch.s1), dim=1)[1]
                next_q = target_next.gather(
                    1, next_q_actions.unsqueeze(1)).squeeze(1)
            else:
                # Vanilla DQN: the target net both picks and evaluates.
                next_q = target_next.max(1)[0]
            expected_q = batch.r + self.gamma * (1 - batch.done) * next_q

        # expected Q and Q using value net
        q_values = self.value_net(batch.s0)
        self.q = q_values.gather(1, batch.a.unsqueeze(1)).squeeze(1)
        with torch.no_grad():
            # Logged for diagnostics only.
            # NOTE(review): ``1 - a`` assumes exactly two actions -- confirm.
            self.q_ = q_values.gather(1, 1 - batch.a.unsqueeze(1)).squeeze(1)

        if prioritized:
            # The importance weights must scale each sample's loss, so the
            # per-sample Huber loss is needed; ``self.loss_func`` averages
            # over the batch before the weights could be applied.
            per_sample = nn.functional.smooth_l1_loss(
                self.q, expected_q, reduction='none')
            self.loss = (per_sample * weights).mean()
        else:
            self.loss = self.loss_func(expected_q, self.q).mean()
        # torchviz.make_dot(self.loss).render('loss')
        self.optimizer.zero_grad()
        self.loss.backward()

        # gradient clipping to avoid loss divergence based on DeepMind's DQN in
        # 2015, where the author clipped the gradient within [-1, 1]
        if self.use_grad_clipping:
            nn.utils.clip_grad_norm_(self.value_net.parameters(),
                                     self.grad_clipping)

        # update replay buffer..
        if prioritized:
            # less memory used
            with torch.no_grad():
                abs_td_error = torch.abs(expected_q - self.q).cpu().numpy() + \
                               self.replay_buffer.non_zero_variant
            self.replay_buffer.update_priorities(losses=abs_td_error)
        self.optimizer.step()


@Agent.register('DeepSarsaAgent')
class DeepSarsaAgent(DeepTDAgent, Registrable):
    """Deep SARSA agent trained on transitions drawn from a replay buffer.

    ``value_net`` is the network being optimised; ``target_net`` only
    supplies the bootstrap value used in the TD target.
    """

    def __init__(self, env: Env, agent_cfg: Dict):
        super().__init__(env, agent_cfg)

    @classmethod
    def from_params(cls, env: Env, params: Dict):
        """Construct the agent from an environment and its config dict."""
        return cls(env, params)

    def train(self):
        """Perform one gradient step on a batch sampled from the buffer.

        Side effects: sets ``self.q`` (Q-values of the taken actions),
        ``self.q_`` (Q-values of the complementary action), ``self.loss``
        (scalar loss), steps ``self.optimizer`` and, for prioritized replay,
        pushes the new absolute TD errors back into the buffer.

        Raises:
            ValueError: if the buffer reports a replay type that is neither
                plain nor prioritized experience replay.
        """
        replay_type = self.replay_buffer.replay_type
        prioritized = (
            replay_type == ReplayType.PRIORITIZED_EXPERIENCE_REPLAY.value)
        if prioritized:
            # prioritized replay also returns per-sample weights
            batch, weights = self.replay_buffer.sample(
                batch_size=self.batch_size)
        elif replay_type == ReplayType.EXPERIENCE_REPLAY.value:
            batch = self.replay_buffer.sample(batch_size=self.batch_size)
        else:
            raise ValueError(f'Unsupported replay type: {replay_type!r}')

        actions = batch.a.unsqueeze(1)

        # The bootstrap value is never differentiated (the TD target below is
        # built under no_grad), so skip building an autograd graph for the
        # target network's forward pass.
        # NOTE(review): SARSA normally bootstraps on the action taken in s1;
        # this indexes with batch.a (the action stored for s0). Confirm
        # whether the buffer stores the next action and use it here if so.
        with torch.no_grad():
            next_q = self.target_net(batch.s1).gather(1, actions).squeeze(1)

        # expected Q and Q using value net
        q_values = self.value_net(batch.s0)
        self.q = q_values.gather(1, actions).squeeze(1)
        with torch.no_grad():
            # assumes a two-action space (indices 0/1) -- TODO confirm
            self.q_ = q_values.gather(1, 1 - actions).squeeze(1)
            expected_q = batch.r + self.gamma * (1 - batch.done) * next_q

        self.loss = self.loss_func(expected_q, self.q)
        if prioritized:
            # Build the weights on the loss' device/dtype; a bare
            # torch.FloatTensor always lives on the CPU.
            sample_weights = torch.as_tensor(weights, dtype=torch.float32,
                                             device=self.loss.device)
            self.loss = self.loss * sample_weights
        # torchviz.make_dot(self.loss).render('loss')
        self.loss = self.loss.mean()
        self.optimizer.zero_grad()
        self.loss.backward()

        # gradient clipping to avoid loss divergence: rescales the gradients
        # so that their global norm does not exceed ``self.grad_clipping``
        if self.use_grad_clipping:
            nn.utils.clip_grad_norm_(self.value_net.parameters(),
                                     self.grad_clipping)

        # update replay buffer priorities with the fresh TD errors
        if prioritized:
            # no_grad: less memory used, and allows the numpy conversion
            with torch.no_grad():
                abs_td_error = torch.abs(expected_q - self.q).cpu().numpy() + \
                               self.replay_buffer.non_zero_variant
            self.replay_buffer.update_priorities(losses=abs_td_error)
        self.optimizer.step()


@Agent.register('DeepExpectedSarsaAgent')
class DeepExpectedSarsaAgent(DeepTDAgent, Registrable):
    """Deep Expected-SARSA agent trained from a replay buffer.

    The bootstrap value is the expectation of ``target_net``'s Q-values in
    the next state under the distribution returned by ``get_epsilon_dist``.
    """

    def __init__(self, env: Env, agent_cfg: Dict):
        super().__init__(env, agent_cfg)

    @classmethod
    def from_params(cls, env: Env, params: Dict):
        """Construct the agent from an environment and its config dict."""
        return cls(env, params)

    def train(self):
        """Perform one gradient step on a batch sampled from the buffer.

        Side effects: sets ``self.q`` (Q-values of the taken actions),
        ``self.q_`` (Q-values of the complementary action), ``self.loss``
        (scalar loss), steps ``self.optimizer`` and, for prioritized replay,
        pushes the new absolute TD errors back into the buffer.

        Raises:
            ValueError: if the buffer reports a replay type that is neither
                plain nor prioritized experience replay.
        """
        replay_type = self.replay_buffer.replay_type
        prioritized = (
            replay_type == ReplayType.PRIORITIZED_EXPERIENCE_REPLAY.value)
        if prioritized:
            # prioritized replay also returns per-sample weights
            batch, weights = self.replay_buffer.sample(
                batch_size=self.batch_size)
        elif replay_type == ReplayType.EXPERIENCE_REPLAY.value:
            batch = self.replay_buffer.sample(batch_size=self.batch_size)
        else:
            raise ValueError(f'Unsupported replay type: {replay_type!r}')

        # NOTE(review): this call could probably also run under no_grad;
        # left outside because the helper's internals are not visible here.
        prob_dist = get_epsilon_dist(eps=self.eps, env=self.env,
                                     model=self.value_net, observation=batch.s1)
        # The bootstrap value is never differentiated (the TD target below is
        # built under no_grad), so skip building an autograd graph for the
        # target network's forward pass.
        with torch.no_grad():
            next_q = torch.sum(self.target_net(batch.s1) * prob_dist.probs,
                               dim=1)

        # expected Q and Q using value net
        actions = batch.a.unsqueeze(1)
        q_values = self.value_net(batch.s0)
        self.q = q_values.gather(1, actions).squeeze(1)
        with torch.no_grad():
            # assumes a two-action space (indices 0/1) -- TODO confirm
            self.q_ = q_values.gather(1, 1 - actions).squeeze(1)
            expected_q = batch.r + self.gamma * (1 - batch.done) * next_q

        self.loss = self.loss_func(expected_q, self.q)
        if prioritized:
            # Build the weights on the loss' device/dtype; a bare
            # torch.FloatTensor always lives on the CPU.
            sample_weights = torch.as_tensor(weights, dtype=torch.float32,
                                             device=self.loss.device)
            self.loss = self.loss * sample_weights
        # torchviz.make_dot(self.loss).render('loss')
        self.loss = self.loss.mean()
        self.optimizer.zero_grad()
        self.loss.backward()

        # gradient clipping to avoid loss divergence: rescales the gradients
        # so that their global norm does not exceed ``self.grad_clipping``
        if self.use_grad_clipping:
            nn.utils.clip_grad_norm_(self.value_net.parameters(),
                                     self.grad_clipping)

        # update replay buffer priorities with the fresh TD errors
        if prioritized:
            # no_grad: less memory used, and allows the numpy conversion
            with torch.no_grad():
                abs_td_error = torch.abs(expected_q - self.q).cpu().numpy() + \
                               self.replay_buffer.non_zero_variant
            self.replay_buffer.update_priorities(losses=abs_td_error)
        self.optimizer.step()


from typing import Dict, List
from collections import defaultdict
from datetime import datetime

class Experiment(Registrable):
    """Registrable base class holding the shared state of an experiment.

    Concrete experiments register themselves under a name and are created
    through :meth:`build`, which looks the class up by that name.
    """

    def __init__(self,
                 logger: ProjectLogger,
                 env_names: List,
                 agents: List,
                 seeds: List,
                 experiment_cfg: dict,
                 agent_cfg: dict,
                 *args, **kwargs):
        """Store the experiment inputs; extra ``args``/``kwargs`` are ignored.

        The passed ``experiment_cfg`` dict is kept by reference and gets a
        ``'date'`` entry (today's date, ``YYYY-MM-DD``) written into it.
        """
        self.logger = logger
        self.env_names, self.agents, self.seeds = env_names, agents, seeds
        self.experiment_cfg, self.agent_cfg = experiment_cfg, agent_cfg

        today_stamp = datetime.today().strftime('%Y-%m-%d')
        self.experiment_cfg['date'] = today_stamp

    @classmethod
    def build(cls, type: str, logger: ProjectLogger, params: Dict):
        """Instantiate the experiment registered under the name ``type``."""
        return cls.by_name(type).from_params(logger, params)

    @classmethod
    def from_params(cls, logger: ProjectLogger, params: Dict):
        """Create an instance, expanding ``params`` as keyword arguments."""
        instance = cls(logger, **params)
        return instance

    def generate_metrics(self, results: List) -> defaultdict:
        """generate whatever metrics needed for the experiment

        Subclasses must override this; the base version always raises.
        """
        raise NotImplementedError('Experiment must generate metrics!')


import ray
import gym
import numpy as np
import pickle
from collections import defaultdict, deque
from typing import Dict, List, Tuple

from torch.utils.tensorboard import SummaryWriter


@Experiment.register('DeepTDExperiment')
class DeepTDExperiment(Experiment):
    """Sweep deep TD agents over environments, replay capacities and seeds.

    For every (environment, agent, capacity) combination one remote task per
    seed is launched; each task sweeps the configured learning rates.
    """

    def __init__(self,
                 logger: ProjectLogger,
                 *args, **kwargs):
        # ``logger`` is forwarded positionally: combining ``logger=logger``
        # with a non-empty ``*args`` would bind the parameter twice and raise
        # a TypeError.
        super().__init__(logger, *args, **kwargs)
        self.replay_buffer_capacities = self.experiment_cfg[
            'replay_buffer_capacities']
        self.lrs = self.experiment_cfg['lrs']

    @classmethod
    def from_params(cls, logger: ProjectLogger, params: Dict):
        """Create an instance, expanding ``params`` as keyword arguments."""
        return cls(logger, **params)

    def run(self) -> defaultdict:
        """Run every environment / agent / replay-capacity combination.

        For each combination the seeds are executed in parallel as ray
        tasks, their results are folded into the nested ``output`` mapping
        and the mapping is pickled to ``experiment_cfg['experiment_path']``
        (rewritten after every combination, so partial results survive).

        Returns:
            The nested metrics mapping built by :meth:`generate_metrics`.
        """
        output = defaultdict(nested_d)
        for env_name in self.env_names:
            for agent in self.agents:
                for capacity in self.replay_buffer_capacities:
                    # the capacity under test is written into the shared
                    # agent config before it is shipped to the workers
                    self.agent_cfg[agent]['experience_replay']['params'][
                        'capacity'] = capacity
                    pending = [DeepTDExperiment._inner_run.remote(
                        agent_cfg=self.agent_cfg,
                        experiment_cfg=self.experiment_cfg,
                        env_name=env_name, seed=seed, agent_name=agent) for seed
                        in self.seeds]
                    results = ray.get(pending)
                    output = self.generate_metrics(env_name=env_name,
                                                   agent=agent,
                                                   capacity=capacity,
                                                   results=results,
                                                   output=output)
                    with open(self.experiment_cfg['experiment_path'],
                              'wb') as file:
                        pickle.dump(output, file)
                self.logger.info(
                    f'Finished running experiments for {env_name} | {agent}')
        return output

    @staticmethod
    @ray.remote
    def _inner_run(agent_cfg: dict, experiment_cfg: dict,
                   env_name: str, seed: int = 1, agent_name: str = 'sarsa') -> \
            Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Train one agent type on one environment for a single seed.

        Args:
            agent_cfg: mapping from agent name to that agent's config dict;
                the entry for ``agent_name`` gets ``'lr'`` and ``'seed'``
                written into it.
            experiment_cfg: must provide ``'lrs'``, ``'runs'``, ``'episodes'``
                and ``'steps'``.
            env_name: id passed to ``gym.make``.
            seed: seed for numpy and for the agent config.
            agent_name: registry name of the agent to build.

        Returns:
            ``(cum_reward, time_to_solve, mean_q_loss)``, each averaged over
            the runs and therefore of shape ``(len(lrs), episodes)``.
        """
        # seed and result initialization
        np.random.seed(seed)
        lrs = experiment_cfg['lrs']
        shape = (len(lrs), experiment_cfg['runs'], experiment_cfg['episodes'])
        mean_q_loss = np.zeros(shape)
        cum_reward = np.zeros(shape)
        # episodes default to the step budget
        time_to_solve = np.full(shape, experiment_cfg['steps'], dtype=float)
        env = gym.make(env_name)

        # try/finally so the environment is released even if training fails
        try:
            # O(lrs * runs * episodes * max(test_rng * steps, steps))
            for i_lr, lr in enumerate(lrs):
                # create agent, set the learning rate, seed..
                agent_config = agent_cfg[agent_name]
                agent_config['lr'] = lr
                agent_config['seed'] = seed
                # NOTE(review): the agent is built once per learning rate and
                # reused across runs, so later runs continue from the earlier
                # runs' state -- confirm this is intended.
                agent = Agent.build(type=agent_name, env=env,
                                    params=agent_config)

                # go through runs, in order to further average, and episodes
                for r in range(experiment_cfg['runs']):
                    for i_episode in range(experiment_cfg['episodes']):
                        # a fresh generator per episode; only its first
                        # yielded result is consumed
                        episode_result = next(agent.learn(
                            num_steps=experiment_cfg['steps']))
                        cum_reward[i_lr, r, i_episode] = episode_result[
                            'cum_reward']
                        time_to_solve[i_lr, r, i_episode] = episode_result[
                            'time_to_solve']
                        mean_q_loss[i_lr, r, i_episode] = episode_result[
                            'mean_q_loss']

                        header = f"lr {agent_config['lr']} | run {r} | " \
                            f"episode {i_episode} | eps {agent.eps} "
                        details = ''.join(
                            f"| {k} {v} " for k, v in episode_result.items())
                        print(header + details)
        finally:
            env.close()

        # generates learning rates * episodes (averaged over the runs);
        # mean_q_loss is averaged too so all three outputs share one shape
        cum_reward = np.mean(cum_reward, axis=1)
        time_to_solve = np.mean(time_to_solve, axis=1)
        mean_q_loss = np.mean(mean_q_loss, axis=1)
        return cum_reward, time_to_solve, mean_q_loss

    def generate_metrics(self, env_name: str, agent: str, capacity: int,
                         results: List, output: defaultdict) -> defaultdict:
        """Aggregate per-seed results into ``output`` for one combination.

        Args:
            env_name, agent, capacity: keys under which metrics are stored.
            results: one ``(cum_reward, time_to_solve, mean_q_loss)`` tuple
                per seed, as returned by ``_inner_run``.
            output: nested mapping updated in place.

        Returns:
            ``output``, with
            ``output[env_name][agent][capacity][metric][lr]`` filled for
            every learning rate in ``experiment_cfg['lrs']``.
        """
        # split the per-seed tuples once instead of once per metric and lr
        cum_rewards = [seed_result[0] for seed_result in results]
        timesteps = [seed_result[1] for seed_result in results]
        q_losses = [seed_result[2] for seed_result in results]

        # statistics over the seeds (axis 0); first remaining axis is the lr
        mean_q_loss = np.mean(q_losses, axis=0)
        mean_rewards = np.mean(cum_rewards, axis=0)
        std_rewards = np.std(cum_rewards, axis=0)
        max_rewards = np.max(cum_rewards, axis=0)
        mean_steps = np.mean(timesteps, axis=0)
        min_steps = np.min(timesteps, axis=0)
        max_steps = np.max(timesteps, axis=0)

        for idx, lr in enumerate(self.experiment_cfg['lrs']):
            metrics = output[env_name][agent][capacity]
            metrics['mean_q_loss'][lr] = mean_q_loss[idx]
            metrics['mean_cum_rewards'][lr] = mean_rewards[idx]
            metrics['std_cum_rewards'][lr] = std_rewards[idx]
            # one-standard-deviation band around the mean reward
            metrics['upper_std_cum_rewards'][lr] = \
                mean_rewards[idx] + std_rewards[idx]
            metrics['lower_std_cum_rewards'][lr] = \
                mean_rewards[idx] - std_rewards[idx]
            metrics['max_cum_rewards'][lr] = max_rewards[idx]
            metrics['mean_timesteps'][lr] = mean_steps[idx]
            metrics['min_timesteps'][lr] = min_steps[idx]
            metrics['max_timesteps'][lr] = max_steps[idx]
        return output


import numpy as np
import matplotlib.pyplot as plt


def plot_lr_reward(output: Dict, max_cap: int = 500, max_lr: float = 0.05):
    """Draw and save two figures per agent from the experiment metrics.

    Plot #1 shows, per replay capacity, the mean of the last 10 entries of
    ``mean_cum_rewards`` against the learning rate. Plot #2 shows, for the
    single ``(max_cap, max_lr)`` setting, every metric whose name contains
    ``'mean'``, ``'upper'`` or ``'lower'`` against the episode index.

    Files are written to the working directory as ``{agent}_plot_1.png`` and
    ``{agent}_plot_2.png``; the environment name is not part of the file
    name, so a later environment overwrites an earlier one's files.

    Args:
        output: nested mapping
            ``output[env_name][agent][capacity][metric][lr]``.
        max_cap: replay capacity shown in plot #2.
        max_lr: learning rate shown in plot #2.
    """
    for env_name, agents in output.items():
        for agent, capacities in agents.items():

            plt.rcParams.update({'font.size': 18})
            figure = plt.figure(figsize=(10, 16))
            ax = figure.add_subplot(111)
            ax.title.set_text(f'{agent} Plot #1')
            ax.set_ylabel('Average Reward of last 10 episodes')
            ax.set_xlabel(r'$\alpha$')

            for capacity, metrics in capacities.items():
                if 'mean_cum_rewards' in metrics:
                    lrs = metrics['mean_cum_rewards']
                    ax.plot(list(lrs),
                            [np.mean(v[-10:]) for v in lrs.values()],
                            label=f'capacity={capacity}')

            ax.grid(linestyle='--')
            ax.legend(loc='upper left')
            # save before show: on interactive backends the figure may be
            # gone once show() returns, which would leave a blank image
            figure.savefig(f'{agent}_plot_1.png')
            plt.show()
            # close (not just clear) so figures do not pile up in memory
            plt.close(figure)

    # plot the learning curves of one chosen capacity / learning rate
    for env_name, agents in output.items():
        for agent, capacities in agents.items():
            plt.rcParams.update({'font.size': 18})
            figure = plt.figure(figsize=(16, 10))
            ax = figure.add_subplot(111)
            ax.title.set_text(f'{agent} Plot #2')
            ax.set_ylabel('Reward')
            ax.set_xlabel('Episode')
            for capacity, metrics in capacities.items():
                if capacity != max_cap:
                    continue
                # NOTE(review): the 'mean' filter also matches non-reward
                # metrics such as 'mean_timesteps' -- confirm this is wanted
                for metric, lrs in metrics.items():
                    if 'mean' in metric or 'upper' in metric \
                            or 'lower' in metric:
                        series = lrs[max_lr]
                        ax.plot(np.arange(len(series)), series, label=metric)
            ax.grid(linestyle='--')
            ax.legend(loc='upper left')
            figure.savefig(f'{agent}_plot_2.png')
            plt.show()
            plt.close(figure)




if __name__ == "__main__":
    # Script entry point: load the jsonnet config, start ray, build the
    # experiment registered under cfg['experiment_name'] and run it.

    # logger (level 10 is presumably logging.DEBUG -- confirm in ProjectLogger)
    logger = ProjectLogger(level=10)

    # load initial configs for params
    cfg = evaluate_file(f'{CONFIG_DIR}/n_step_td_config.jsonnet')
    logger.info(f'Using the following config: \n{cfg}')
    cfg = json.loads(cfg)

    ray.init(
        # local_mode=True,
        ignore_reinit_error=True,
    )

    # specs for the experiment & agent
    experiment_cfg, agent_cfg = cfg['experiment_cfg'], cfg['agent_cfg']

    # The experiment pickles its results into RESULT_DIR after the first
    # batch of runs; create the directory up front so that write cannot fail
    # on a missing folder once the compute has already been spent.
    os.makedirs(RESULT_DIR, exist_ok=True)

    experiment_name = cfg['experiment_name']
    experiment_cfg['experiment_path'] = \
        f"{RESULT_DIR}/{experiment_name}_experiments.pickle"
    experiment_cfg['hyperparams_path'] = \
        f"{RESULT_DIR}/{experiment_name}_experiments_hyperparameters.pickle"
    experiment_cfg['tensorboard_path'] = \
        f"{TENSORBOARD_DIR}/{experiment_name}/trainer"

    # ten distinct seeds drawn from [0, 99999); not seeded, so they differ
    # between invocations
    seeds = np.random.choice(99999, 10, replace=False)
    # seeds = [1337]
    params = {'env_names': cfg['env_names'], 'agents': cfg['agents'],
              'seeds': seeds,
              'experiment_cfg': experiment_cfg, 'agent_cfg': agent_cfg}

    experiment = Experiment.build(type=experiment_name, logger=logger,
                                  params=params)

    # run dp experiments
    output = experiment.run()
    logger.info('Finished running experiments')

    # To re-plot previously saved results instead of re-running:
    # with open(f'{RESULT_DIR}/experiments.pickle', 'rb') as file:
    #     output = pickle.load(file)
    #
    # cfg = evaluate_file(f'{CONFIG_DIR}/n_step_td_config.jsonnet')
    # cfg = json.loads(cfg)
    #
    # # specs for the experiment
    # experiment_cfg = cfg['experiment_cfg']
    # plot_lr_reward(output)