# **ProjectAthena**
## **Subsection: Q-Net**
### Q-Net: Optimizing the performance of off-policy reinforcement learning algorithms by replacing them with supervised, meta-learned networks.
## **Goal:**
#### - Outperform Q-Learning algorithms in training time and effectiveness
#### - If possible, generate a replacement for the Q-learning algorithm that generalizes to all programs
#### - If possible, generate a model so effective that it replaces the existing RL algorithms that already outperform Q-Learning

# **Policy Gradient (aka Reinforce)**

## **Imports and Global Variables**

In [None]:
!pip install cherry-rl

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting cherry-rl
  Downloading cherry-rl-0.1.4.tar.gz (40 kB)
[K     |████████████████████████████████| 40 kB 5.0 MB/s 
Building wheels for collected packages: cherry-rl
  Building wheel for cherry-rl (setup.py) ... [?25l[?25hdone
  Created wheel for cherry-rl: filename=cherry_rl-0.1.4-py3-none-any.whl size=57842 sha256=68b67780cff3f30d0ed56136518cd5ac273f5dcc2e70c492edf0c85ecfaffcf5
  Stored in directory: /root/.cache/pip/wheels/9e/34/30/42ee307e7214d2af33c5999c6379feac44f1734baa02c2d850
Successfully built cherry-rl
Installing collected packages: cherry-rl
Successfully installed cherry-rl-0.1.4


In [None]:
import cherry.distributions as distributions
from torch.distributions import Categorical
import torch.nn.functional as F
from cherry.td import discount
from cherry import normalize
import torch.nn as nn
import cherry as ch
import gym.spaces
import torch
import gym

In [None]:
env_name = "CartPole-v0"
render = False
gamma = 0.99  # discount factor; read as a global by update()

# Throwaway env used only to read the episode step limit and the
# observation/action spaces. Nothing steps it, so close it right away
# instead of leaving it open for the rest of the session.
sample_env = gym.make(env_name)
timestep_limit = sample_env.spec.max_episode_steps
obs_space = sample_env.observation_space
action_space = sample_env.action_space
sample_env.close()
print("Observation space:", obs_space)
print("Action space:", action_space)

obs_size = obs_space.low.size  # number of elements in one observation
action_size = action_space.n  # number of discrete actions

def update(replay, optimizer):
    """Run one REINFORCE policy-gradient step on a batch of transitions.

    Args:
        replay: cherry ExperienceReplay whose transitions carry a ``log_prob``
            entry (cached by ``get_action_value``).
        optimizer: torch optimizer over the policy's parameters.

    Returns:
        Tensor holding the sum of the raw (undiscounted, unnormalized)
        rewards collected in ``replay``; callers read it with ``.item()``.
    """
    raw_rewards = replay.reward()

    # Discounted returns (done flags separate episodes), normalized to reduce
    # the variance of the gradient estimate. Uses the module-level `gamma`.
    returns = ch.normalize(ch.discount(gamma, raw_rewards, replay.done()))

    # REINFORCE loss: -log pi(a|s) * G_t, summed over every transition.
    policy_loss = torch.stack(
        [-sars.log_prob * ret for sars, ret in zip(replay, returns)]
    ).sum()

    optimizer.zero_grad()
    policy_loss.backward()
    optimizer.step()

    # Report the reward actually collected. The normalized returns are
    # zero-mean by construction, so summing them would always report ~0.
    return raw_rewards.sum()

def get_action_value(state, policy):
    """Sample an action from ``policy``'s categorical output for ``state``.

    Args:
        state: Tensor fed to ``policy``.
        policy: Callable returning action probabilities for ``state``.

    Returns:
        (action, info): the sampled action tensor and a dict whose
        ``'log_prob'`` entry is that action's log-probability, kept so the
        policy-gradient update can use it later.
    """
    distribution = Categorical(policy(state))
    sampled = distribution.sample()
    return sampled, {'log_prob': distribution.log_prob(sampled)}

In [None]:
class PolicyNet(nn.Module):
    """Stochastic policy: maps a state vector to probabilities over actions.

    Args:
        env: Object exposing ``state_size`` and ``action_size`` (the
            cherry-wrapped env built in this notebook is passed here).
        hidden_size (int): Width of the two hidden layers (default 128).
    """

    def __init__(self, env, hidden_size=128):
        super().__init__()
        self.network = torch.nn.Sequential(
            nn.Linear(env.state_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, env.action_size),
            # Normalize over the last (action) axis so both batched
            # (N, state_size) and single (state_size,) inputs give valid
            # distributions; dim=1 would fail on a 1-D input.
            nn.Softmax(dim=-1),
        )

    def forward(self, x):
        """Return action probabilities of shape ``(*x.shape[:-1], action_size)``."""
        return self.network(x)

# Keep gym's TimeLimit wrapper (no `.env` unwrapping) so each episode is
# capped at the spec's max_episode_steps (`timestep_limit`); the unwrapped
# env has no step cap, so a good policy could make an episode run forever.
env = gym.make(env_name)
env = ch.envs.Torch(env)
env = ch.envs.Runner(env)
replay = ch.ExperienceReplay()
device = "cpu"

policy = PolicyNet(env)
optimizer = torch.optim.RAdam(policy.parameters())


def get_action(state):
    """Adapter with the (state) -> (action, info) signature env.run expects."""
    return get_action_value(state, policy)


epochs = 5

for e in range(epochs):
    # Collect 100 episodes with the current policy, then take one REINFORCE
    # step on the whole batch.
    replay = env.run(get_action, episodes=100, render=render)
    print(f"Epoch {e}  Total Reward: {round(update(replay, optimizer).item(), 2)}\ttimesteps: {len(replay)}")

env.close()

# **DQN**

## **Imports and Global Variables**

In [None]:
!pip install simple_rl

In [None]:
import cherry.distributions as distributions
from torch.distributions import Categorical
from collections import defaultdict
from prettytable import PrettyTable
import torch.nn.functional as F
from cherry.td import discount
from cherry import normalize
import torch.nn as nn
import cherry as ch
import numpy as np
import gym.spaces
import argparse
import random
import random
import tqdm
import torch
import time
import gym

In [None]:
# Other imports.
from simple_rl.agents.AgentClass import Agent

class QNet(nn.Module):
    """Multi-layer perceptron mapping a state vector to one Q-value per action.

    Intended as a learned stand-in for a tabular Q function.

    Args:
        input (int): Number of elements in the state/observation vector.
            (The name shadows the builtin; it is kept for compatibility.)
        output (int): Number of discrete actions, i.e. Q-values produced.
        hidden_size (int): Width of the two hidden layers (default 128).
    """

    def __init__(self, input, output, hidden_size=128):
        super().__init__()
        # Same layout as PolicyNet but with no final Softmax: Q-values are
        # unbounded return estimates, not a probability distribution.
        self.model = nn.Sequential(
            nn.Linear(input, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output),
        )

    def forward(self, x):
        """Return Q-values of shape ``(*x.shape[:-1], output)``."""
        return self.model(x)

class QLearningAgent(Agent):
    ''' Tabular Q-learning agent with epsilon-greedy ("uniform") or softmax exploration. '''

    def __init__(self, actions, name="Q-learning", alpha=0.1, gamma=0.99, epsilon=0.1, explore="uniform", anneal=False, custom_q_init=None, default_q=0):
        '''
        Args:
            actions (list): Contains strings denoting the actions.
            name (str): Denotes the name of the agent.
            alpha (float): Learning rate.
            gamma (float): Discount factor.
            epsilon (float): Exploration term.
            explore (str): One of {softmax, uniform}. Denotes explore policy.
            anneal (bool): If True, alpha and epsilon are decayed by _anneal().
            custom_q_init (defaultdict{state, defaultdict{action, float}}): a dictionary of dictionaries storing the initial q-values. Can be used for potential shaping (Wiewiora, 2003). It is copied, never modified in place.
            default_q (float): the value returned for every (state, action) pair not yet in the q-table [default 0]
        '''
        # Non-uniform exploration modes are appended to the name, e.g. "Q-learning-softmax".
        name_ext = "-" + explore if explore != "uniform" else ""
        Agent.__init__(self, name=name + name_ext, actions=actions, gamma=gamma)

        # Set/initialize parameters and other relevant classwide data.
        # The *_init values are the bases that _anneal() decays from and reset() restores.
        self.alpha, self.alpha_init = alpha, alpha
        self.epsilon, self.epsilon_init = epsilon, epsilon
        self.step_number = 0
        self.anneal = anneal
        self.default_q = default_q
        self.explore = explore
        self.custom_q_init = custom_q_init
        # TODO: a learned Q-network (see QNet) is meant to replace q_func eventually.
        # It is not wired in yet; nothing in this class reads self.qnet.
        self.qnet = None

        # Q Function: Key: state -> Val: dict (Key: action -> Val: q-value).
        self.q_func = self._initial_q_func()

    def _initial_q_func(self):
        '''
        Returns:
            (dict): A fresh Q table. When custom_q_init was given, a two-level copy of it
            (same state/action key objects, new dicts) so that Q updates never write into
            the caller's table and reset() really restores the initial values. Otherwise
            an empty nested defaultdict that yields default_q for unseen pairs.
        '''
        if self.custom_q_init:
            q_func = self.custom_q_init.copy()
            for state in q_func:
                # Replacing values of existing keys is safe while iterating.
                q_func[state] = q_func[state].copy()
            return q_func
        return defaultdict(lambda: defaultdict(lambda: self.default_q))

    def get_parameters(self):
        '''
        Returns:
            (dict) key=param_name (str) --> val=param_val (object).
        '''
        param_dict = defaultdict(int)

        param_dict["alpha"] = self.alpha
        param_dict["gamma"] = self.gamma
        param_dict["epsilon"] = self.epsilon_init
        param_dict["anneal"] = self.anneal
        param_dict["explore"] = self.explore

        return param_dict

    # --------------------------------
    # ---- CENTRAL ACTION METHODS ----
    # --------------------------------

    def act(self, state, reward, learning=True):
        '''
        Args:
            state (State)
            reward (float)
            learning (bool): If False, skip the Q update and annealing.
        Returns:
            (str)
        Summary:
            The central method called during each time step.
            Retrieves the action according to the current policy
            and performs updates given (s=self.prev_state,
            a=self.prev_action, r=reward, s'=state)
        '''
        if learning:
            # prev_state/prev_action are presumably initialized by Agent (not visible
            # here); update() treats a None prev_state as "first step, nothing to learn".
            self.update(self.prev_state, self.prev_action, reward, state)
        if self.explore == "softmax":
            # Softmax exploration
            action = self.soft_max_policy(state)
        else:
            # Uniform exploration
            action = self.epsilon_greedy_q_policy(state)

        self.prev_state = state
        self.prev_action = action
        self.step_number += 1

        # Anneal params.
        if learning and self.anneal:
            self._anneal()

        return action

    def epsilon_greedy_q_policy(self, state):
        '''
        Args:
            state (State)
        Returns:
            (str): action.
        '''
        # Policy: Epsilon of the time explore, otherwise, greedyQ.
        # Uses the stdlib `random` module (as _compute_max_qval_action_pair does) so
        # actions are returned as the original objects, not numpy scalars.
        if random.random() > self.epsilon:
            # Exploit.
            return self.get_max_q_action(state)
        # Explore.
        return random.choice(self.actions)

    def soft_max_policy(self, state):
        '''
        Args:
            state (State): Contains relevant state information.
        Returns:
            (str): action sampled with probabilities from get_action_distr(state).
        '''
        return random.choices(self.actions, weights=self.get_action_distr(state), k=1)[0]

    # ---------------------------------
    # ---- Q VALUES AND PARAMETERS ----
    # ---------------------------------

    def update(self, state, action, reward, next_state):
        '''
        Args:
            state (State)
            action (str)
            reward (float)
            next_state (State)
        Summary:
            Updates the internal Q Function according to the Bellman Equation. (Classic Q Learning update)
            Q(s,a) <- (1 - alpha) * Q(s,a) + alpha * (r + gamma * max_a' Q(s',a'))
        '''
        # If this is the first state, just return.
        if state is None:
            self.prev_state = next_state
            return

        # Update the Q Function.
        max_q_curr_state = self.get_max_q_value(next_state)
        prev_q_val = self.get_q_value(state, action)
        self.q_func[state][action] = (1 - self.alpha) * prev_q_val + self.alpha * (reward + self.gamma*max_q_curr_state)

    def _anneal(self):
        # Taken from "Note on learning rate schedules for stochastic optimization, by Darken and Moody (Yale)":
        # both rates shrink as step_number and episode_number grow.
        self.alpha = self.alpha_init / (1.0 +  (self.step_number / 1000.0)*(self.episode_number + 1) / 2000.0 )
        self.epsilon = self.epsilon_init / (1.0 + (self.step_number / 1000.0)*(self.episode_number + 1) / 2000.0 )

    def _compute_max_qval_action_pair(self, state):
        '''
        Args:
            state (State)
        Returns:
            (tuple) --> (float, str): where the float is the Qval, str is the action.
        '''
        # Grab random initial action in case all equal
        best_action = random.choice(self.actions)
        max_q_val = float("-inf")
        # Shuffle so ties between equal Q-values are broken randomly.
        shuffled_action_list = self.actions[:]
        random.shuffle(shuffled_action_list)

        # Find best action (action w/ current max predicted Q value)
        for action in shuffled_action_list:
            q_s_a = self.get_q_value(state, action)
            if q_s_a > max_q_val:
                max_q_val = q_s_a
                best_action = action

        return max_q_val, best_action

    def get_max_q_action(self, state):
        '''
        Args:
            state (State)
        Returns:
            (str): denoting the action with the max q value in the given @state.
        '''
        return self._compute_max_qval_action_pair(state)[1]

    def get_max_q_value(self, state):
        '''
        Args:
            state (State)
        Returns:
            (float): denoting the max q value in the given @state.
        '''
        return self._compute_max_qval_action_pair(state)[0]

    def get_value(self, state):
        '''
        Args:
            state (State)
        Returns:
            (float): V(state) = max_a Q(state, a).
        '''
        return self.get_max_q_value(state)

    def get_q_value(self, state, action):
        '''
        Args:
            state (State)
            action (str)
        Returns:
            (float): denoting the q value of the (@state, @action) pair.
        '''
        return self.q_func[state][action]

    def get_action_distr(self, state, beta=0.2):
        '''
        Args:
            state (State)
            beta (float): Softmax temperature parameter.
        Returns:
            (list of floats): The i-th float corresponds to the probability
            mass associated with the i-th action (indexing into self.actions)
        '''
        import math

        scaled = [beta * self.get_q_value(state, action) for action in self.actions]

        # Softmax distribution. Subtracting the largest exponent leaves the result
        # mathematically unchanged but prevents overflow for large beta * q.
        shift = max(scaled)
        exps = [math.exp(s - shift) for s in scaled]
        total = sum(exps)
        return [e / total for e in exps]

    def reset(self):
        '''
        Summary:
            Restores step/episode counters, the initial learning and exploration
            rates, and the initial Q table, then defers to Agent.reset.
        '''
        self.step_number = 0
        self.episode_number = 0
        # Undo any annealing so a fresh run starts from the configured rates.
        self.alpha = self.alpha_init
        self.epsilon = self.epsilon_init
        self.q_func = self._initial_q_func()
        Agent.reset(self)

    def end_of_episode(self):
        '''
        Summary:
            Anneals (if enabled), then defers to Agent.end_of_episode, which
            resets the agent's prior pointers.
        '''
        if self.anneal:
            self._anneal()
        Agent.end_of_episode(self)

    def print_v_func(self):
        '''
        Summary:
            Prints the V function.
        '''
        for state in self.q_func.keys():
            print(state, self.get_value(state))

    def print_q_func(self):
        '''
        Summary:
            Prints the Q function.
        '''
        if not self.q_func:
            print("Q Func empty!")
        else:
            for state, actiond in self.q_func.items():
                print(state)
                for action, q_val in actiond.items():
                    print("    ", action, q_val)