In [17]:
from typing import Optional, Tuple
from abc import ABC, abstractmethod
from dataclasses import dataclass

import numpy as np
import torch
from torch import nn
from torch.distributions import Normal

import gym

In [21]:
class DiagonalGaussian(nn.Module):
    """Gaussian policy with an observation-dependent mean and a learned, shared std.

    The mean comes from ``mean_action_net`` (Linear -> activation -> Linear ->
    activation). The per-dimension log standard deviation is a free parameter
    initialised to -0.5 and independent of the observation.
    """

    def __init__(
        self, obs_dim: int, hidden_dim: int, action_dim: int, activation
    ) -> None:
        super().__init__()
        # NOTE: despite its name this parameter holds the *log standard deviation*
        # of each action dimension (it is exponentiated and used as ``scale``).
        # The name is kept so existing state_dict keys stay valid. It is float32
        # to match nn.Linear's default dtype; a float64 parameter would mix dtypes
        # inside Normal and promote log-probs/samples to float64.
        self.covariance_matrix = nn.Parameter(
            torch.full((action_dim,), -0.5, dtype=torch.float32)
        )
        self.mean_action_net = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            activation,
            nn.Linear(hidden_dim, action_dim),
            # NOTE(review): an output activation constrains the Gaussian mean
            # (e.g. Softmax puts it on the simplex); confirm this is intended.
            activation,
        )

    def _distribution(self, observation):
        """Return ``Normal(mean, std)`` for ``observation``.

        ``std`` is ``exp`` of the learned log-std parameter.
        """
        mean_act = self.mean_action_net(observation)
        std = torch.exp(self.covariance_matrix)
        return Normal(mean_act, std)

    def _log_probs_from_dist(self, policy_dist, action):
        """Return the joint log-probability of ``action``.

        The per-dimension log-probs are summed over the last axis.
        """
        return policy_dist.log_prob(action).sum(dim=-1)

    def forward(self, observation, action=None):
        """Build the policy distribution and optionally score an action.

        Returns:
            ``(policy_dist, logp_act)``. ``logp_act`` is None unless ``action``
            is given.
        """
        policy_dist = self._distribution(observation)
        logp_act = None
        if action is not None:
            logp_act = self._log_probs_from_dist(policy_dist, action)
        return policy_dist, logp_act


class ValueFunctionLearner(nn.Module):
    """Feed-forward critic: observation -> value estimate(s).

    Layout is Linear -> activation -> Linear -> activation. The output width is
    ``action_dim`` rather than 1, and the result is returned without squeezing.
    """

    def __init__(
        self, obs_dim: int, hidden_dim: int, action_dim: int, activation
    ) -> None:
        super().__init__()
        layers = [
            nn.Linear(obs_dim, hidden_dim),
            activation,
            nn.Linear(hidden_dim, action_dim),
            activation,
        ]
        self.v_net = nn.Sequential(*layers)

    def forward(self, observation):
        """Return the raw output of ``v_net`` for ``observation``."""
        estimate = self.v_net(observation)
        return estimate


class Agent:
    """Actor-critic pair: a ``DiagonalGaussian`` policy and a ``ValueFunctionLearner``.

    ``Agent`` is a plain object, not an ``nn.Module``. Reach the trainable
    parameters through ``agent.policy`` and ``agent.value_func``.
    """

    def __init__(
        self,
        obs_dim: int,
        action_dim: int,
        hidden_dim: int = 32,
        activation=nn.Softmax(dim=-1),
    ) -> None:
        super().__init__()
        # NOTE(review): the same ``activation`` module instance is placed in both
        # networks (and the default instance is shared by every Agent). That is
        # only safe for parameter-free activations such as Softmax/Tanh/ReLU.
        self.policy = DiagonalGaussian(obs_dim, hidden_dim, action_dim, activation)
        self.value_func = ValueFunctionLearner(
            obs_dim, hidden_dim, action_dim, activation
        )

    def step(self, obs: torch.Tensor):
        """Sample an action for ``obs`` and evaluate it, without tracking gradients.

        Args:
            obs: observation(s). Tensors are used as-is. Anything else (e.g. the
                numpy arrays gym returns) is converted to a float32 tensor.

        Returns:
            ``(action, value, logp)`` as numpy arrays: the sampled action, the
            value network's output for ``obs``, and the log-probability of the
            sampled action under the current policy.
        """
        if not isinstance(obs, torch.Tensor):
            obs = torch.as_tensor(obs, dtype=torch.float32)
        with torch.no_grad():
            policy_dist = self.policy._distribution(obs)
            action = policy_dist.sample()
            logp_action = self.policy._log_probs_from_dist(policy_dist, action)
            value = self.value_func(obs)
        return action.numpy(), value.numpy(), logp_action.numpy()

    def act(self, obs: torch.Tensor):
        """Return only the sampled action for ``obs`` (see ``step``)."""
        return self.step(obs)[0]


In [29]:
class Advantage(ABC):
    """Interface for advantage estimators.

    Implementations receive a trajectory's rewards and value estimates and
    return ``(advantages, value_targets)``.
    """

    @abstractmethod
    def estimate(
        self, rewards: np.ndarray, values: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Compute ``(advantages, value_targets)`` for one trajectory.

        Args:
            rewards: per-step rewards.
            values: per-step value estimates for the same steps.
        """
        raise NotImplementedError
class ReturnEstimator(ABC):
    """Interface for turning a sequence of rewards into an array of returns.

    Concrete subclasses decide the exact form of the returns.
    """

    @abstractmethod
    def get_return(self, rewards: np.ndarray, gamma: float) -> np.ndarray:
        """Compute returns for ``rewards`` using discount factor ``gamma``."""
        raise NotImplementedError()

@dataclass
class DiscountReturn(ReturnEstimator):
    """Discounted reward-to-go.

    Computes ``out[t] = sum_{k >= t} gamma**(k - t) * rewards[k]``. This is the
    form GAE needs both for advantages (discounting TD residuals by
    ``gamma * lambda``) and for value-function targets.
    """

    def get_return(self, rewards: np.ndarray, gamma: float = 0.99) -> np.ndarray:
        """Discounted cumulative sum of ``rewards`` taken from each step onward.

        Args:
            rewards: sequence of per-step rewards. Extra trailing dimensions are
                discounted independently along axis 0.
            gamma: discount factor.

        Returns:
            float array with the same shape as ``rewards`` (empty in, empty out).
        """
        rewards = np.asarray(rewards, dtype=float)
        returns = np.zeros_like(rewards)
        running = np.zeros(rewards.shape[1:], dtype=float)
        # Sweep backwards so each step reuses the return of the step after it.
        for t in range(len(rewards) - 1, -1, -1):
            running = rewards[t] + gamma * running
            returns[t] = running
        return returns
        
@dataclass
class GAE(Advantage):
    """Generalized Advantage Estimation.

    TD residuals ``delta_t = r_t + gamma * V_{t+1} - V_t`` are discounted by
    ``gamma * lamda`` through ``return_estimator`` to give advantages. The
    rewards discounted by ``gamma`` give the value-function targets.

    Attributes:
        return_estimator: maps a sequence and a discount to per-step returns.
            This is expected to be a discounted reward-to-go.
        lamda: the GAE lambda (spelled ``lamda`` because ``lambda`` is a keyword).
        gamma: reward discount factor.
    """

    return_estimator: ReturnEstimator
    lamda: float = 0.5
    gamma: float = 0.99

    def estimate(
        self, rewards: np.ndarray, values: np.ndarray, last_val: float = 0.0
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Compute ``(advantages, value_targets)`` for one trajectory.

        Args:
            rewards: per-step rewards, length T.
            values: per-step value estimates. After flattening, these must
                contain exactly one value per reward (length T).
            last_val: bootstrap value appended after the final step. Use 0 for a
                terminal state, or ``V(s_T)`` when the trajectory was cut off.

        Returns:
            ``(advantages, value_targets)``, both of length T.

        Raises:
            ValueError: if ``values`` does not hold exactly one value per reward.
        """
        rewards = np.asarray(rewards, dtype=float).ravel()
        values = np.asarray(values, dtype=float).ravel()
        if values.shape != rewards.shape:
            raise ValueError(
                f"GAE expects one value per reward; got {values.size} values "
                f"for {rewards.size} rewards"
            )
        # Append the bootstrap so V_{t+1} (and the reward-to-go tail) exists
        # for the last step.
        rew = np.append(rewards, last_val)
        val = np.append(values, last_val)
        deltas = rew[:-1] + self.gamma * val[1:] - val[:-1]
        adv = self.return_estimator.get_return(deltas, self.lamda * self.gamma)
        ret = self.return_estimator.get_return(rew, self.gamma)[:-1]  # value function targets
        return adv, ret


class TrajectoryReplayBuffer:
    """A buffer class for storing trajectory data.

    Holds ``buf_size`` steps of actions, value estimates, rewards, a per-step
    scalar stored as ``mean_act``, and (optionally) observations. It computes
    advantages and return targets from these with the configured estimators.
    """

    def __init__(
        self,
        advantage: Advantage,
        return_estimator: ReturnEstimator,
        obs_dim: int,
        act_dim: int,
        val_dim: int,
        buf_size: int,
    ) -> None:
        """Allocate zeroed storage and instantiate the estimators.

        Args:
            advantage: an ``Advantage`` *class*. It is called here with the
                return-estimator instance as its only argument.
            return_estimator: a ``ReturnEstimator`` *class*, instantiated with no
                arguments.
            obs_dim, act_dim, val_dim: widths of the stored observation, action
                and value rows.
            buf_size: number of steps the buffer holds.
        """
        self._buf_size = buf_size
        self._ret_estimator = return_estimator()
        self._adv_estimator = advantage(self._ret_estimator)
        self._obs = np.zeros((buf_size, obs_dim), dtype=float)
        self._act = np.zeros((buf_size, act_dim), dtype=float)
        self._val = np.zeros((buf_size, val_dim), dtype=float)
        self._adv = np.zeros(buf_size, dtype=float)
        # NOTE(review): callers presumably pass Agent.step's third output here,
        # which is an action log-probability rather than a mean action. Confirm.
        self._mean_act = np.zeros(buf_size, dtype=float)
        self._rewards = np.zeros(buf_size, dtype=float)

    def store(
        self,
        idx: int,
        action: np.ndarray,
        value: np.ndarray,
        reward: float,
        mean_act: float,
        obs: Optional[np.ndarray] = None,
    ) -> None:
        """Write one step's data at position ``idx``.

        Args:
            idx: step index, ``0 <= idx < buf_size``.
            action, value, reward, mean_act: data for that step.
            obs: optional observation for that step. It is stored only when given.

        Raises:
            IndexError: if ``idx`` is outside the buffer. This is an explicit
                check rather than an ``assert``, so it survives ``python -O``
                and also rejects negative indices.
        """
        if not 0 <= idx < self._buf_size:
            raise IndexError(
                f"index {idx} out of range for buffer of size {self._buf_size}"
            )
        self._act[idx] = action
        self._val[idx] = value
        self._rewards[idx] = reward
        self._mean_act[idx] = mean_act
        if obs is not None:
            self._obs[idx] = obs

    def compute_advantage(self):
        """Run the advantage estimator over the stored rewards and values.

        The advantages are cached in ``self._adv``.

        Returns:
            ``(advantages, value_targets)`` as produced by the estimator.
        """
        self._adv, value_targets = self._adv_estimator.estimate(self._rewards, self._val)
        return self._adv, value_targets

    def get_trajectories(self):
        """Return the stored value estimates as float32 tensors under key ``"V"``."""
        data = dict(V=self._val)
        return {k: torch.as_tensor(v, dtype=torch.float32) for k, v in data.items()}

    def expected_returns(self, arr: np.ndarray, gamma: Optional[float] = None) -> np.ndarray:
        """Return the per-step returns of ``arr`` computed by the return estimator.

        Assumes ``get_return`` yields a reward-to-go, where element ``t`` depends
        only on ``arr[t:]``. Under that assumption a single call over the whole
        array gives every step's return.

        Args:
            arr: per-step rewards.
            gamma: discount passed to the estimator. If None, the estimator's
                own default is used.

        Returns:
            float array of per-step returns.
        """
        if gamma is None:
            returns = self._ret_estimator.get_return(arr)
        else:
            returns = self._ret_estimator.get_return(arr, gamma)
        return np.asarray(returns, dtype=float)

In [28]:
### Setting up Hyperparameters ###
SEED = 0
ENV_ID = "FetchReach-v1"

episodes = 1
episode_len = 1

# Derive network sizes from the environment's (flattened) spaces, so the agent
# matches what env.reset()/env.step() produce and accept. Hard-coded sizes
# silently drift out of sync with the env.
_probe_env = gym.wrappers.FlattenObservation(gym.make(ENV_ID))
obs_dim = _probe_env.observation_space.shape[0]
act_dim = _probe_env.action_space.shape[0]
_probe_env.close()
# ValueFunctionLearner outputs ``action_dim`` values per observation, so the
# buffer's value rows must have the same width.
val_dim = act_dim
hidden_dim = 32
buf_size = episode_len

# Seed before building the networks so their initial weights are reproducible.
np.random.seed(SEED)
torch.manual_seed(SEED)

### Init Agent ###
trajectory_buffer = TrajectoryReplayBuffer(
    GAE,
    DiscountReturn,
    obs_dim=obs_dim,
    act_dim=act_dim,
    val_dim=val_dim,
    buf_size=buf_size
)
agent = Agent(
    obs_dim=obs_dim,
    action_dim=act_dim,
    hidden_dim=hidden_dim,
    activation=nn.Softmax(dim=-1)
)
### End Init Agent ###

In [None]:
### Training Loop ###
env = gym.make("FetchReach-v1")
env = gym.wrappers.FlattenObservation(env)
try:
    for _ in range(episodes):
        obs = env.reset()
        for t in range(episode_len):
            env.render()
            # The agent's networks are torch modules; gym returns numpy arrays.
            a = agent.act(torch.as_tensor(obs, dtype=torch.float32))
            obs, reward, done, info = env.step(a)

            if done:
                # t is 0-based, so t + 1 steps have been taken at this point.
                print(f"Episode finished after {t + 1} timesteps")
                break
finally:
    # Release the renderer/simulator even if a step raises.
    env.close()