In [None]:
import gym
from gym import spaces
import matplotlib.pyplot as plt
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import cv2
from collections import deque
from torch.distributions import Categorical
import math as m
from torch.nn.utils.convert_parameters import vector_to_parameters
from IPython.display import clear_output
from gym.core import ObservationWrapper
from gym.spaces.box import Box


In [None]:

def create_atari_env(env_id):
    """Create the Gym environment `env_id` and wrap it with this notebook's wrappers.

    The wrappers are applied innermost-first in this order:
    AtariRescale84x84, NormalizedEnv, EpisodicLifeEnv, MaxAndSkipEnv.

    Returns the outermost wrapper.
    """
    wrapped = gym.make(env_id)
    for wrapper_cls in (AtariRescale84x84, NormalizedEnv, EpisodicLifeEnv, MaxAndSkipEnv):
        wrapped = wrapper_cls(wrapped)
    return wrapped




def process_frame84(frame):
    """Crop, resize and grey-scale a single frame.

    Steps: keep rows 34..193 and the first 160 columns, resize to 84x84 with
    OpenCV, average over axis 2 (presumably the colour channels), cast to
    float32 and scale by 1/255.

    Returns the resulting float32 array.
    """
    cropped = frame[34:194, :160]
    resized = cv2.resize(cropped, (84, 84))
    gray = resized.mean(2).astype(np.float32)
    gray *= (1.0 / 255.0)
    return gray



class EpisodicLifeEnv(gym.Wrapper):
    """Report the loss of a life as the end of an episode.

    The inner environment is only really reset when it reported `done`
    itself; after a mere life loss, `reset` advances it with one no-op step.
    Done by DeepMind for DQN and co. since it helps value estimation.
    """

    def __init__(self, env=None):
        super().__init__(env)
        self.lives = 0
        # True when the inner env signalled `done` on the most recent step.
        self.was_real_done = True
        # True when the most recent `reset` call reset the inner env.
        self.was_real_reset = False

    def step(self, action):
        """Step the inner env and force `done` when a life was just lost."""
        obs, reward, inner_done, info = self.env.step(action)
        self.was_real_done = inner_done
        remaining = self.env.unwrapped.ale.lives()
        # For Qbert we sometimes stay in the lives == 0 condition for a few
        # frames, so require remaining > 0: the terminal signal for the last
        # life is then left to the inner environment's own `done`.
        life_lost = 0 < remaining < self.lives
        # Always store the current count so bonus lives are tracked as well.
        self.lives = remaining
        return obs, reward, (True if life_lost else inner_done), info

    def reset(self):
        """Reset only when lives are exhausted.

        This way all states are still reachable even though lives are
        episodic, and the learner need not know about any of this.
        """
        if self.was_real_done:
            first_obs = self.env.reset()
            self.was_real_reset = True
        else:
            # No-op step to advance from the terminal / lost-life state.
            first_obs, _, _, _ = self.env.step(0)
            self.was_real_reset = False
        self.lives = self.env.unwrapped.ale.lives()
        return first_obs


class MaxAndSkipEnv(gym.Wrapper):
    """Repeat each action `skip` times and return a max-pooled observation.

    The returned observation is the element-wise maximum over the (at most
    two) most recent inner observations; rewards are summed over the repeats.
    """

    def __init__(self, env=None, skip=4):
        super().__init__(env)
        # Two most recent inner observations, used for the max pooling.
        self._obs_buffer = deque(maxlen=2)
        self._skip = skip

    def step(self, action):
        """Apply `action` up to `skip` times, stopping early on `done`."""
        done = None
        total_reward = 0.0
        for _ in range(self._skip):
            obs, reward, done, info = self.env.step(action)
            total_reward += reward
            self._obs_buffer.append(obs)
            if done:
                break

        pooled = np.stack(self._obs_buffer).max(axis=0)
        return pooled, total_reward, done, info

    def reset(self):
        """Empty the frame buffer, reset the inner env and buffer its first obs."""
        self._obs_buffer.clear()
        first_obs = self.env.reset()
        self._obs_buffer.append(first_obs)
        return first_obs
  

class AtariRescale84x84(gym.ObservationWrapper):
    """Observation wrapper that passes every frame through `process_frame84`.

    The advertised observation space is Box(0.0, 1.0, [1, 84, 84]); this
    wrapper itself returns `process_frame84`'s output unchanged and does not
    add a leading axis.
    """

    def __init__(self, env=None):
        super().__init__(env)
        self.observation_space = Box(low=0.0, high=1.0, shape=[1, 84, 84])

    def observation(self, observation):
        """Return the preprocessed version of `observation`."""
        processed = process_frame84(observation)
        return processed




class NormalizedEnv(gym.ObservationWrapper):
    """Observation wrapper that standardises frames with running statistics.

    Exponential moving averages (decay `alpha`) of each observation's mean and
    standard deviation are kept, bias-corrected by 1 - alpha**num_steps, and
    used to centre and scale the observation.
    """

    def __init__(self, env=None):
        super().__init__(env)
        self.state_mean = 0
        self.state_std = 0
        self.alpha = 0.9999
        self.num_steps = 0

    def observation(self, observation):
        """Update the running statistics and return the normalised observation
        with a new leading axis of length 1."""
        self.num_steps += 1
        decay = self.alpha
        fresh_weight = 1 - decay
        self.state_mean = self.state_mean * decay + observation.mean() * fresh_weight
        self.state_std = self.state_std * decay + observation.std() * fresh_weight

        # Bias correction for the zero-initialised moving averages.
        correction = 1 - decay ** self.num_steps
        centred = observation - self.state_mean / correction
        scaled = centred / (self.state_std / correction + 1e-8)
        return np.expand_dims(scaled, axis=0)


class NormalizedState:
    """Running observation normaliser that does not depend on gym.

    Keeps exponential moving averages (decay `alpha`) of each observation's
    mean and standard deviation, bias-corrected by 1 - alpha**num_steps.
    """

    def __init__(self):
        self.state_mean = 0
        self.state_std = 0
        self.alpha = 0.9999
        self.num_steps = 0

    def observation(self, observation):
        """Update the running statistics and return the normalised observation
        with a new leading axis of length 1."""
        self.num_steps += 1
        decay = self.alpha
        fresh_weight = 1 - decay
        self.state_mean = self.state_mean * decay + observation.mean() * fresh_weight
        self.state_std = self.state_std * decay + observation.std() * fresh_weight

        # Bias correction for the zero-initialised moving averages.
        correction = 1 - decay ** self.num_steps
        centred = observation - self.state_mean / correction
        scaled = centred / (self.state_std / correction + 1e-8)
        return np.expand_dims(scaled, axis=0)

In [None]:
# Build the wrapped environment once at module level; `main` below reads this
# global `env` directly rather than receiving it as an argument.
env_id = 'PongNoFrameskip-v4'
env = create_atari_env(env_id)
# Number of discrete actions. NOTE(review): not referenced by the other
# visible cells (the model is built from env.action_space.n directly).
num_actions = env.action_space.n

In [None]:
# Cuda device (falls back to the CPU when CUDA is not available)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Discount factor used by get_Q when accumulating returns.
gamma = 0.99
# conjugate_gradient stops once the residual norm drops below this value.
res_threshold = 1e-10
# Upper bound on the number of conjugate-gradient iterations.
cg_max_iters = 10
# KL budget: update_policy scales the step with sqrt(2 * delta / sHs) and
# update_theta accepts a step once the measured KL is <= 2 * delta.
delta =  0.01

In [None]:
class TrpoNet(nn.Module):
    """Convolutional policy network that maps 84x84 frames to action probabilities.

    Two Conv -> BatchNorm -> ReLU stages feed a two-layer fully connected head
    with `num_actions` outputs. With 84x84 inputs the convolutions yield
    32 x 9 x 9 = 2592 features, which is the size the head expects.
    """

    def __init__(self, num_actions):
        super().__init__()

        conv_stack = [
            nn.Conv2d(1, 16, 8, 4),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 32, 4, 2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
        ]
        self.conv_layers = nn.Sequential(*conv_stack)

        head = [
            nn.Linear(2592, 256),
            nn.ReLU(),
            nn.Linear(256, num_actions),
        ]
        self.fc_layers = nn.Sequential(*head)

        self.name = 'trpo'

    def forward(self, x):
        """Return one row of action probabilities per input frame.

        `x` is converted with torch.FloatTensor, reshaped to (-1, 1, 84, 84)
        and moved to the module-level `device`.
        """
        frames = torch.FloatTensor(x).view(-1, 1, 84, 84).to(device)
        features = self.conv_layers(frames)
        logits = self.fc_layers(features.view(-1, 2592))
        # Add a small constant so that no probability is exactly 0, then
        # re-normalise every row to sum to 1 again.
        probs = F.softmax(logits, dim=1) + 1e-6
        return F.normalize(probs, dim=1, p=1)

    def act(self, input):
        """Sample an action from the policy; returns (action tensor, probabilities)."""
        prob = self.forward(input)
        sampled = Categorical(prob).sample()
        return sampled, prob


In [None]:
class TrajectoryRecord:
    """Plain container for the transitions collected during one policy iteration.

    Four parallel lists are kept: rewards, actions, states and dones.
    """

    def __init__(self):
        self.reset()

    def push(self, reward, action, state, done):
        """Append one transition to the four parallel lists."""
        for store, item in (
            (self.rewards, reward),
            (self.actions, action),
            (self.states, state),
            (self.dones, done),
        ):
            store.append(item)

    def returnRecord(self):
        """Return the stored lists (not copies) as (rewards, states, actions, dones).

        Note that this order differs from the argument order of `push`.
        """
        return self.rewards, self.states, self.actions, self.dones

    def reset(self):
        """Start over with four fresh, empty lists."""
        self.rewards, self.actions, self.states, self.dones = [], [], [], []

In [None]:
def get_surrogate_loss(selected_action_prbs, q_values):
    '''
    Surrogate objective L = mean of pi(a_n|s_n) / pi_old(a_n|s_n) * Q_old(s_n, a_n).

    selected_action_prbs: probabilities of the actions actually taken, still
        attached to the autograd graph.
    q_values: one return estimate per entry of selected_action_prbs (anything
        torch.as_tensor accepts).

    pi_old is the same tensor detached from the graph, so the ratio equals 1
    in value while its gradient is the policy gradient. The Q tensor is
    created on the device of selected_action_prbs, so the function does not
    rely on a module-level device variable.
    '''
    old_prbs = selected_action_prbs.detach()
    q = torch.as_tensor(q_values, dtype=torch.float32, device=selected_action_prbs.device)
    return ((selected_action_prbs / old_prbs) * q).mean()

def get_Q(rewards, dones, discount=None):
    '''
    Compute the discounted return for every step of the recorded trajectory.

    rewards: per-step rewards, in time order.
    dones: per-step terminal flags; a true flag stops the return of that step
        from including anything that follows it.
    discount: discount factor; defaults to the module-level `gamma`.

    Returns a list with one return per step, in time order.
    '''
    if discount is None:
        discount = gamma
    R = 0
    Q = []
    # Walk backwards so that each return can reuse the one that follows it.
    for r, done in zip(reversed(rewards), reversed(dones)):
        R = r + discount * R * (1 - done)
        Q.append(R)
    # Appending and reversing once is linear; inserting at index 0 is quadratic.
    Q.reverse()
    return Q
def flat_parameters(param):
    '''
    Flatten an iterable of tensors of arbitrary shapes into a single 1-D tensor,
    concatenated in iteration order.
    '''
    pieces = []
    for tensor in param:
        pieces.append(tensor.contiguous().view(-1))
    return torch.cat(pieces)
def get_fisher_vector_product(x, prbs, model, damping = 1e-2):
    '''
    Fisher-vector product y = Hx + damping * x, used by conjugate gradient.

    H is the Hessian of get_kl(prbs) with respect to the model parameters. It
    is never formed explicitly: the gradient of the KL is contracted with x
    and the resulting scalar is differentiated a second time.

    x: flat vector with one entry per model parameter.
    prbs: action probabilities produced by `model`, attached to the graph.
    '''
    kl_value = get_kl(prbs)

    # First derivative of the KL; create_graph keeps it differentiable.
    model.zero_grad()
    first_order = torch.autograd.grad(
        kl_value, model.parameters(), create_graph=True, retain_graph=True
    )

    # Scalar (grad KL) . x
    grad_dot_x = (flat_parameters(first_order) * x).sum()

    # Differentiating that scalar gives H x.
    model.zero_grad()
    second_order = torch.autograd.grad(grad_dot_x, model.parameters(), retain_graph=True)
    return flat_parameters(second_order) + damping * x

In [None]:
def get_kl_compare(pi, pi_old):
    """Return mean(pi_old * log(pi_old / pi)) as a Python float.

    The mean is taken over every element of the tensors.
    """
    log_ratio = torch.log(pi_old / pi)
    weighted = pi_old * log_ratio
    return weighted.mean().item()

def get_kl(pi):
    '''
    input: action probabilities pi with gradients enabled.

    Returns mean(pi_old * log(pi_old / pi)) where pi_old is `pi` itself with
    gradient tracking disabled. The value is therefore 0; the result is used
    for its derivatives with respect to the parameters that produced `pi`.
    '''
    frozen = pi.data
    log_ratio = torch.log(frozen / pi)
    return (frozen * log_ratio).mean()

In [None]:
def conjugate_gradient(b, prbs, model):
    '''
    Approximately solve A x = b with the conjugate gradient method, where the
    product A p is supplied by get_fisher_vector_product(p, prbs, model).

    Reference algorithm (from Wikipedia):
    ---------------------------------------------------------------------------
    function x = conjgrad(A, b, x)
        r = b - A * x;
        p = r;
        rsold = r' * r;

        for i = 1:length(b)
            Ap = A * p;
            alpha = rsold / (p' * Ap);
            x = x + alpha * p;
            r = r - alpha * Ap;
            rsnew = r' * r;
            if sqrt(rsnew) < 1e-10
                  break;
            end
            p = r + (rsnew / rsold) * p;
            rsold = rsnew;
        end
    end
    ---------------------------------------------------------------------------

    b: flat right-hand-side vector.
    prbs, model: forwarded unchanged to get_fisher_vector_product.

    At most cg_max_iters iterations are run, stopping early once the residual
    norm falls below res_threshold. Returns x, on the same device and with the
    same dtype as b.
    '''
    # Start from x = 0, allocated to match b.
    x = torch.zeros_like(b)
    # b - A * x = b because x = 0
    r = b.clone()
    p = r.clone()
    rsold = r.dot(r)
    # A zero right-hand side has the exact solution x = 0. Iterating would
    # compute alpha = 0 / 0 and fill x with NaN.
    if rsold.item() == 0.0:
        return x
    for _ in range(cg_max_iters):
        Ap = get_fisher_vector_product(p, prbs, model)
        alpha = rsold / (p.dot(Ap))
        x = x + alpha * p
        r = r - alpha * Ap
        rsnew = r.dot(r)
        if m.sqrt(rsnew.item()) < res_threshold:
            break
        p = r + (rsnew / rsold) * p
        rsold = rsnew
    return x


In [None]:
def update_theta(theta, beta, s, model, old_prbs, states):
    '''
    Backtracking line search that writes new parameters into `model`.

    theta: flat vector of the current model parameters.
    beta: step size for the search direction.
    s: flat search direction.
    old_prbs: action probabilities of the policy before the update.
    states: batch of states used to evaluate the candidate policy.

    Candidates theta + clamp(beta_factor * beta * s, -40, 40) are tried with
    beta_factor = 1, 1/e, 1/e^2, ... for at most 10 attempts. The first
    candidate whose KL divergence to old_prbs is <= 2 * delta is kept. If the
    KL becomes NaN, or if no candidate satisfies the constraint, the original
    parameters are restored so that a failed search leaves the policy
    unchanged.
    '''
    beta_factor = 1
    beta_s = beta * s
    revert_theta = theta.clone()
    print('theta max', theta.max(), 'theta isNan', torch.isnan(theta).any())
    with torch.no_grad():
        for i in range(10):
            new_theta = theta + torch.clamp(beta_factor * beta_s, min=-40, max=40)
            print('new_theta max', new_theta.max(), 'new_theta isNan', torch.isnan(new_theta).any())
            vector_to_parameters(new_theta, model.parameters())
            beta_factor = beta_factor / m.e
            new_prbs = model(states)
            # Evaluate the KL once per candidate and reuse the value below.
            kl = get_kl_compare(new_prbs, old_prbs)
            print('kl divergence', i, kl)
            if m.isnan(kl):
                print('revert theta')
                vector_to_parameters(revert_theta, model.parameters())
                break
            if kl <= 2 * delta:
                break
        else:
            # Every candidate violated the constraint: do not keep the last
            # rejected parameters in the model.
            print('revert theta')
            vector_to_parameters(revert_theta, model.parameters())

In [None]:
def update_policy(record, model):
    """Perform one TRPO-style policy update from the recorded trajectory.

    record: object whose returnRecord() yields (rewards, states, actions, dones).
    model: policy network; calling it on `states` must return action
        probabilities with one row per state.

    Steps: build the surrogate loss from the taken actions' probabilities and
    the discounted returns, take its gradient g, solve H s = g with conjugate
    gradient, scale the step with beta = sqrt(2 * delta / sHs) and hand the
    result to update_theta for the line search.

    The update is skipped when sHs is not a positive finite number, because
    the square root would otherwise raise (negative value) or produce an
    infinite / NaN step.
    """
    rewards, states, actions, dones = record.returnRecord()
    prbs = model(states)
    actions = torch.LongTensor(actions).unsqueeze(1).to(device)
    selected_action_prbs = prbs.gather(1, actions).squeeze(1)
    q_values = get_Q(rewards, dones)

    L = get_surrogate_loss(selected_action_prbs, q_values)
    model.zero_grad()
    g_ = torch.autograd.grad(L, model.parameters(), retain_graph=True)
    g = flat_parameters(g_)
    s = conjugate_gradient(g, prbs, model)
    # Undamped curvature along s, used only for the step-size formula.
    Hs = get_fisher_vector_product(s, prbs, model, damping=0)
    sHs = s.dot(Hs).item()

    if not m.isfinite(sHs) or sHs <= 0:
        print('skip policy update, invalid sHs', sHs)
        return

    beta = m.sqrt(2 * delta / sHs)
    theta = flat_parameters(model.parameters())

    update_theta(theta, beta, s, model, prbs, states)

In [None]:
def policy_select_action(state, model):
    """Sample an action for `state` from the policy `model`.

    Returns (action as a Python int, probabilities returned by the model).
    """
    action_probs = model(state)
    sampled = Categorical(action_probs).sample()
    return int(sampled), action_probs

In [None]:
def main(model, num_policy_iters=5000, min_steps_per_iter=1000):
    """Training loop: collect experience with `model`, then update the policy.

    model: policy network passed to policy_select_action and update_policy.
    num_policy_iters: number of collect-then-update iterations.
    min_steps_per_iter: minimum number of environment steps collected per
        iteration; collection then continues until the running episode ends.

    The module-level `env` is used for interaction.

    Returns a list with the average episode reward of every policy iteration.
    """
    policy_iter_rewards = []
    record = TrajectoryRecord()
    for policy_iter in range(num_policy_iters):
        record.reset()
        state = env.reset()
        epoch_reward = 0
        policy_total_rewards = 0
        epochs = 0
        done = False
        step = 0
        # The loop only exits right after a step that reported done, so at
        # least one episode has been counted when the average is taken below.
        while step < min_steps_per_iter or not done:
            action, _ = policy_select_action(state, model)
            next_state, reward, done, _ = env.step(action)
            record.push(reward, action, state, done)
            epoch_reward += reward
            state = next_state
            step += 1
            if done:
                epochs += 1
                policy_total_rewards += epoch_reward
                epoch_reward = 0
                state = env.reset()
        policy_iter_rewards.append(policy_total_rewards / epochs)
        update_policy(record, model)
    return policy_iter_rewards

In [None]:
# Place the network on the same device that TrpoNet.forward moves its inputs
# to; a hard-coded .cuda() fails on machines without CUDA even though
# `device` falls back to the CPU.
model = TrpoNet(env.action_space.n).to(device)

# Bind main's result to a name so a later cell can inspect it instead of it
# being echoed as this cell's output.
training_history = main(model)