<a href="https://colab.research.google.com/github/akshatshah91/Game-AI/blob/master/TRPO.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup

In [None]:
!pip install procgen
import gym
import numpy as np
import torch
from torch.autograd import Variable
import torch.nn as nn
import scipy.optimize
import math
import random
from collections import namedtuple



### Rollout Memory (on-policy batch buffer)

In [None]:
Transition = namedtuple('Transition', ('state', 'action', 'mask', 'next_state',
                                       'reward'))
class Memory:
  """Append-only, unbounded list of ``Transition`` records for one batch."""

  def __init__(self):
    self.memory = []

  def push(self, *args):
    """Record one step given positionally as (state, action, mask, next_state, reward)."""
    record = Transition(*args)
    self.memory.append(record)

  def sample(self):
    """Return all stored steps transposed into one ``Transition`` of tuples.

    Field k of the result is the tuple of field k across every push, in push
    order. Raises TypeError when nothing has been pushed yet.
    """
    columns = list(zip(*self.memory))
    return Transition(*columns)

  def __len__(self):
    return len(self.memory)

# Neural Network

In [None]:
class Policy(nn.Module):
  """Categorical policy: a 2x64 tanh MLP followed by a softmax over actions.

  Args:
    num_inputs: size of the (flattened) observation vector.
    num_outputs: number of discrete actions.
  """

  def __init__(self, num_inputs, num_outputs):
    super(Policy, self).__init__()

    self.fc_layers = nn.Sequential(
        nn.Linear(num_inputs, 64),
        nn.Tanh(),
        nn.Linear(64, 64),
        nn.Tanh()
    )

    # Small output weights and a zero bias keep the initial logits small, so
    # the initial action distribution is close to uniform.
    self.action_mean = nn.Linear(64, num_outputs)
    self.action_mean.weight.data.mul_(0.1)
    self.action_mean.bias.data.mul_(0.0)

  def forward(self, x):
    """Return action probabilities (softmax over dim 1) for a 2-D batch ``x``."""
    x = self.fc_layers(x)

    action_prob = torch.softmax(self.action_mean(x), dim=1)

    return action_prob

  def select_action(self, x):
    """Sample one action for a single observation.

    ``x`` is one flattened observation (e.g. a 1-D numpy array). It is
    converted to the dtype and device of the network's own parameters, so
    this works on CPU as well as GPU. Returns a LongTensor of shape (1, 1).
    """
    param = next(self.parameters())
    x = torch.as_tensor(x, dtype=param.dtype, device=param.device).unsqueeze(0)
    action_prob = self.forward(x)
    action = action_prob.multinomial(1)
    return action

  def get_kl(self, x):
    """Per-sample KL(p_detached || p) of the current policy with itself.

    The value is zero, but because only one side is detached its second
    derivative w.r.t. the parameters is non-trivial. Shape (batch, 1).
    """
    action_prob1 = self.forward(x)
    action_prob0 = action_prob1.detach()
    kl = action_prob0 * (torch.log(action_prob0) - torch.log(action_prob1))
    return kl.sum(1, keepdim=True)

  def get_log_prob(self, x, actions):
    """Log-probability of each taken action; ``actions`` is 1-D, result is (batch, 1)."""
    action_prob = self.forward(x)
    return torch.log(action_prob.gather(1, actions.long().unsqueeze(1)))

  def get_fim(self, x):
    """Return (M, probs): ``M = 1 / probs`` flattened and detached, plus the probs.

    ``M`` serves as the diagonal metric in the Fisher-vector product.
    """
    action_prob = self.forward(x)
    M = action_prob.pow(-1).view(-1).detach()
    return M, action_prob

In [None]:
class Value(nn.Module):
  """State-value critic: a 2x64 tanh MLP with a single linear output unit.

  Args:
    num_inputs: size of the (flattened) observation vector.
  """

  def __init__(self, num_inputs):
    super(Value, self).__init__()
    hidden = 64
    self.fc_layers = nn.Sequential(
        nn.Linear(num_inputs, hidden), nn.Tanh(),
        nn.Linear(hidden, hidden), nn.Tanh(),
    )
    head = nn.Linear(hidden, 1)
    # Shrink the initial output weights and zero the bias so the first value
    # estimates are small.
    head.weight.data.mul_(0.1)
    head.bias.data.mul_(0.0)
    self.value_head = head

  def forward(self, x):
    """Map a (batch, num_inputs) tensor to value estimates of shape (batch, 1)."""
    return self.value_head(self.fc_layers(x))

In [None]:
def estimate_advantages(rewards, masks, values, gamma, tau, device):
  """Compute GAE advantages and value-function targets for a batch of steps.

  Steps are walked backwards in time with
      delta_i = r_i + gamma * V_{i+1} * m_i - V_i
      A_i     = delta_i + gamma * tau * A_{i+1} * m_i
  where V_{i+1} / A_{i+1} belong to the next stored step (0 after the last
  one) and m_i = masks[i]; a zero mask stops bootstrapping across an episode
  boundary.

  Args:
    rewards: per-step rewards, indexed along dim 0.
    masks: per-step continuation masks (0 where an episode ended).
    values: 2-D value estimates indexed as values[i, 0]; presumably (N, 1).
    gamma: discount factor.
    tau: GAE lambda.
    device: device the returned tensors are moved to.

  Returns:
    (advantages, returns), each (N, 1): advantages normalized to zero mean and
    unit std, and returns = values + unnormalized advantages.
  """
  cpu = torch.device('cpu')
  rewards, masks, values = [t.to(cpu) for t in (rewards, masks, values)]
  n = rewards.size(0)
  # Allocate in the promoted input dtype (float64 for the double-precision
  # networks used here) so the recursion does not drop precision.
  dtype = torch.promote_types(rewards.dtype, values.dtype)
  deltas = torch.zeros(n, 1, dtype=dtype)
  advantages = torch.zeros(n, 1, dtype=dtype)

  prev_value = 0
  prev_advantage = 0
  for i in reversed(range(n)):
    deltas[i] = rewards[i] + gamma * prev_value * masks[i] - values[i]
    advantages[i] = deltas[i] + gamma * tau * prev_advantage * masks[i]

    prev_value = values[i, 0]
    prev_advantage = advantages[i, 0]

  returns = values + advantages
  advantages = (advantages - advantages.mean()) / advantages.std()

  return advantages.to(device), returns.to(device)

# TRPO

### Utilities: flat parameter and gradient helpers

In [None]:
def to_device(device, *args):
  """Move each tensor in ``args`` to ``device``; return them as a list, in order."""
  return list(map(lambda tensor: tensor.to(device), args))

def get_flat_params_from(model):
  """Return every parameter of ``model`` flattened and concatenated into one 1-D tensor.

  Parameters appear in ``model.parameters()`` order.
  """
  return torch.cat([param.view(-1) for param in model.parameters()])


def set_flat_params_to(model, flat_params):
  """Copy consecutive slices of the 1-D ``flat_params`` into ``model``'s parameters.

  Slices are taken in ``model.parameters()`` order, each reshaped to its
  parameter's shape and copied in place (``copy_`` casts dtype/device).
  """
  offset = 0
  for param in model.parameters():
    count = param.numel()
    chunk = flat_params[offset:offset + count]
    param.data.copy_(chunk.view(param.size()))
    offset += count


def get_flat_grad_from(inputs, grad_grad=False):
  """Concatenate the gradients of ``inputs`` into one 1-D tensor.

  Args:
    inputs: iterable of parameters (e.g. ``model.parameters()``).
    grad_grad: if True, read ``param.grad.grad`` instead of ``param.grad``;
      every parameter must then have it set.

  Parameters whose ``.grad`` is None contribute zeros. The zeros use the
  parameter's own dtype and device, so they concatenate cleanly with real
  gradients (e.g. float64 CUDA grads).
  """
  grads = []
  for param in inputs:
    if grad_grad:
      grads.append(param.grad.grad.view(-1))
    elif param.grad is None:
      grads.append(torch.zeros_like(param).view(-1))
    else:
      grads.append(param.grad.view(-1))

  return torch.cat(grads)


def compute_flat_grad(output, inputs, filter_input_ids=None, retain_graph=False, create_graph=False):
  """Return d(output)/d(inputs) flattened and concatenated into one 1-D tensor.

  Args:
    output: scalar tensor to differentiate.
    inputs: iterable of tensors/parameters to differentiate with respect to.
    filter_input_ids: optional container of positions in ``inputs`` to skip;
      those positions contribute zeros (matching the input's dtype/device)
      instead of a gradient. Defaults to no filtering.
    retain_graph: keep the autograd graph after this call.
    create_graph: build a graph of the gradient itself (for higher-order
      derivatives); implies ``retain_graph``.

  Side effect: sets ``.grad`` to None on every differentiated input.
  """
  if filter_input_ids is None:
    filter_input_ids = set()
  if create_graph:
    retain_graph = True

  inputs = list(inputs)
  params = [param for i, param in enumerate(inputs) if i not in filter_input_ids]

  grads = torch.autograd.grad(output, params, retain_graph=retain_graph, create_graph=create_graph)

  # autograd.grad returns gradients only for ``params``, in order; walk them
  # alongside the full input list and pad filtered slots with zeros.
  grad_iter = iter(grads)
  out_grads = []
  for i, param in enumerate(inputs):
    if i in filter_input_ids:
      out_grads.append(torch.zeros(param.view(-1).shape, device=param.device, dtype=param.dtype))
    else:
      out_grads.append(next(grad_iter).view(-1))
  flat = torch.cat(out_grads)

  for param in params:
    param.grad = None
  return flat

### TRPO core routines: conjugate gradients, line search, update step

In [None]:
def conjugate_gradients(Avp, b, nsteps, residual_tol=1e-10):
  """Approximately solve ``A x = b`` with at most ``nsteps`` conjugate-gradient steps.

  Args:
    Avp: callable returning the product ``A @ v`` for a 1-D tensor ``v``;
      ``A`` is assumed symmetric positive-definite (as CG requires).
    b: right-hand side, 1-D tensor.
    nsteps: maximum number of iterations.
    residual_tol: stop once the squared residual norm drops below this.

  Returns:
    The solution estimate, with the same dtype and device as ``b``.
  """
  # Match b's dtype/device instead of a float32 tensor on a global device.
  x = torch.zeros_like(b)
  r = b.clone()
  p = b.clone()
  rdotr = torch.dot(r, r)
  for _ in range(nsteps):
    # Checked before each step so an (almost) zero b never divides by 0.
    if rdotr < residual_tol:
      break
    Ap = Avp(p)
    alpha = rdotr / torch.dot(p, Ap)
    x += alpha * p
    r -= alpha * Ap
    new_rdotr = torch.dot(r, r)
    beta = new_rdotr / rdotr
    p = r + beta * p
    rdotr = new_rdotr
  return x

In [None]:
def line_search(model, f, x, fullstep, expected_improve_full, max_backtracks=10, accept_ratio=0.1):
  """Backtracking line search from flat parameters ``x`` along ``fullstep``.

  Step fractions 1, 1/2, 1/4, ... (``max_backtracks`` of them) are tried in
  turn. For each one the model's parameters are overwritten and ``f(True)``
  (expected to return the loss as a 1-element tensor) is re-evaluated. The
  first step whose actual loss decrease exceeds ``accept_ratio`` times the
  predicted decrease ``expected_improve_full * stepfrac`` is accepted.

  Returns:
    ``(True, x_new)`` with the model left at ``x_new``, or ``(False, x)`` when
    no step is accepted, in which case the model's parameters are restored
    to ``x``.
  """
  fval = f(True).item()

  for stepfrac in [0.5 ** k for k in range(max_backtracks)]:
    x_new = x + stepfrac * fullstep
    set_flat_params_to(model, x_new)
    fval_new = f(True).item()
    actual_improve = fval - fval_new
    expected_improve = expected_improve_full * stepfrac
    ratio = actual_improve / expected_improve

    if ratio > accept_ratio:
      return True, x_new

  # Nothing accepted: undo the last rejected trial so the model matches x.
  set_flat_params_to(model, x)
  return False, x

In [None]:
def trpo_step(policy_net, value_net, states, actions, returns, advantages, max_kl, damping, l2_reg, use_fim=True):
  """One TRPO iteration: refit the critic, then take a trust-region policy step.

  Critic: the flattened parameters of ``value_net`` are optimized with SciPy's
  L-BFGS-B (maxiter=25) on mean squared error against ``returns`` plus the
  penalty ``l2_reg * sum(param ** 2)``.

  Actor: the surrogate loss ``-mean(advantages * exp(logp - logp_old))`` is
  decreased along the natural-gradient direction. That direction comes from
  10 conjugate-gradient iterations on the damped Fisher-vector product. The
  step is scaled so that ``0.5 * step^T F step == max_kl`` (F including the
  damping term), then shrunk/accepted by ``line_search``.

  Args:
    policy_net: policy network, updated in place.
    value_net: critic network, updated in place.
    states: batch of observations fed to both networks.
    actions: taken action indices (passed to ``get_log_prob``).
    returns: regression targets for ``value_net``; presumably shaped like its
      output, (N, 1) — TODO confirm.
    advantages: per-step weights of the surrogate loss.
    max_kl: trust-region size used to scale the step.
    damping: multiple of ``v`` added to each Fisher-vector product.
    l2_reg: weight-decay coefficient of the critic loss.
    use_fim: NOTE(review): accepted but unused; ``Fvp_fim`` is always used.

  Returns:
    The ``success`` flag from ``line_search`` (whether a step was accepted).
  """
  #update critic
  def get_value_loss(flat_params):
    # L-BFGS-B objective: load the candidate params (a numpy vector from
    # scipy), then return (loss as float, flat gradient as numpy array).
    set_flat_params_to(value_net, torch.tensor(flat_params))
    # Clear accumulated grads so backward() below reflects only this call.
    for param in value_net.parameters():
      if param.grad is not None:
        param.grad.data.fill_(0)
    values_pred = value_net(states)
    value_loss = (values_pred - returns).pow(2).mean()

    # weight decay
    for param in value_net.parameters():
      value_loss += param.pow(2).sum() * l2_reg
    value_loss.backward()
    return value_loss.item(), get_flat_grad_from(value_net.parameters()).cpu().numpy()

  # opt_info (optimizer diagnostics) is not used.
  flat_params, _, opt_info = scipy.optimize.fmin_l_bfgs_b(get_value_loss,
                                                          get_flat_params_from(value_net).detach().cpu().numpy(),
                                                          maxiter=25)
  set_flat_params_to(value_net, torch.tensor(flat_params))

  #update policy
  # Log-probs of the taken actions under the pre-update policy (constants).
  with torch.no_grad():
    fixed_log_probs = policy_net.get_log_prob(states, actions)
  
  #define the loss function for TRPO
  # volatile=True evaluates without building an autograd graph (line search).
  def get_loss(volatile=False):
    with torch.set_grad_enabled(not volatile):
      log_probs = policy_net.get_log_prob(states, actions)
      # Importance-weighted surrogate: ratio pi/pi_old times advantage.
      action_loss = -advantages * torch.exp(log_probs - fixed_log_probs)
      return action_loss.mean()

  #use fisher information matrix for Hessian*vector
  # Returns (1/N) * J^T M J v + damping * v, where J is the Jacobian of the
  # flattened action probabilities mu w.r.t. the policy parameters,
  # M = 1/mu (from get_fim) and N = states.shape[0].
  def Fvp_fim(v):
    M, mu = policy_net.get_fim(states)
    mu = mu.view(-1)
    filter_input_ids = set()

    # Double-backward trick: the param-gradient of sum(mu * t) is J^T t; its
    # dot with v, differentiated w.r.t. the dummy t, yields J v.
    t = torch.ones(mu.size(), requires_grad=True, device=mu.device)
    mu_t = (mu * t).sum()
    Jt = compute_flat_grad(mu_t, policy_net.parameters(), filter_input_ids=filter_input_ids, create_graph=True)
    Jtv = (Jt * v).sum()
    Jv = torch.autograd.grad(Jtv, t)[0]
    MJv = M * Jv.detach()
    # Param-gradient of sum(const(M J v) * mu) is J^T (M J v).
    mu_MJv = (MJv * mu).sum()
    JTMJv = compute_flat_grad(mu_MJv, policy_net.parameters(), filter_input_ids=filter_input_ids).detach()
    JTMJv /= states.shape[0]
    return JTMJv + v * damping

  Fvp = Fvp_fim

  loss = get_loss()
  grads = torch.autograd.grad(loss, policy_net.parameters())
  loss_grad = torch.cat([grad.view(-1) for grad in grads]).detach()
  # Natural-gradient direction: approximately solve F s = -g.
  stepdir = conjugate_gradients(Fvp, -loss_grad, 10).double()

  # shs = 0.5 * s^T F s; scaling s by sqrt(max_kl / shs) puts the quadratic
  # KL estimate of the full step exactly at max_kl (assumes shs > 0).
  shs = 0.5 * (stepdir.dot(Fvp(stepdir)))
  lm = math.sqrt(max_kl / shs)
  fullstep = stepdir * lm
  # First-order predicted decrease of the loss for the full step.
  expected_improve = -loss_grad.dot(fullstep)

  prev_params = get_flat_params_from(policy_net)
  success, new_params = line_search(policy_net, get_loss, prev_params, fullstep, expected_improve)
  set_flat_params_to(policy_net, new_params)

  return success

# Agent

In [None]:
class RunningStat(object):
    """Running mean and variance of fixed-shape arrays (Welford's online update).

    ``var`` is the unbiased sample variance once at least two samples have
    been pushed; with fewer samples it falls back to ``mean ** 2``.
    """

    def __init__(self, shape):
        self._n = 0                  # number of samples pushed
        self._M = np.zeros(shape)    # running mean
        self._S = np.zeros(shape)    # running sum of squared deviations

    def push(self, x):
        """Fold one sample (must have exactly ``self.shape``) into the statistics."""
        x = np.asarray(x)
        assert x.shape == self._M.shape
        self._n += 1
        if self._n == 1:
            self._M[...] = x
            return
        delta = x - self._M
        self._M[...] = self._M + delta / self._n
        self._S[...] = self._S + delta * (x - self._M)

    @property
    def n(self):
        return self._n

    @property
    def mean(self):
        return self._M

    @property
    def var(self):
        if self._n > 1:
            return self._S / (self._n - 1)
        return np.square(self._M)

    @property
    def std(self):
        return np.sqrt(self.var)

    @property
    def shape(self):
        return self._M.shape

class ZFilter:
    """Normalize inputs with running statistics: y = clip((x - mean) / std).

    The mean and std are estimated online from every input seen while
    ``update`` is True and ``fix`` is False. Demeaning, scaling and clipping
    to [-clip, clip] can each be turned off (``clip`` falsy disables it).
    """

    def __init__(self, shape, demean=True, destd=True, clip=10.0):
        self.demean = demean
        self.destd = destd
        self.clip = clip

        self.rs = RunningStat(shape)
        self.fix = False  # when True, statistics are frozen

    def __call__(self, x, update=True):
        should_update = update and not self.fix
        if should_update:
            self.rs.push(x)
        y = x
        if self.demean:
            y = y - self.rs.mean
        if self.destd:
            # Small epsilon avoids division by zero for constant inputs.
            y = y / (self.rs.std + 1e-8)
        return np.clip(y, -self.clip, self.clip) if self.clip else y

In [None]:
def update_params(batch):
  """Run one TRPO update from a batch of collected transitions.

  ``batch`` is a ``Transition`` whose fields are per-step sequences (as
  returned by ``Memory.sample``). Each field is stacked into a tensor on the
  module-level ``device``. Advantages and returns come from
  ``estimate_advantages`` with the module-level ``gamma``/``tau``. The update
  uses the module-level ``policy_net``, ``value_net``, ``max_kl``, ``damping``
  and ``l2_reg``.
  """
  def as_tensor(column):
    return torch.from_numpy(np.stack(column)).to(device)

  states = as_tensor(batch.state)
  actions = as_tensor(batch.action)
  rewards = as_tensor(batch.reward)
  masks = as_tensor(batch.mask)
  # Baseline values for GAE are constants, so no graph is needed.
  with torch.no_grad():
    values = value_net(states)

  advantages, returns = estimate_advantages(rewards, masks, values, gamma, tau, device)
  trpo_step(policy_net, value_net, states, actions, returns, advantages, max_kl, damping, l2_reg)

In [None]:
def run_training_loop(env, policy_net, value_net, num_episodes=1):
  """Collect on-policy batches from ``env`` and apply one TRPO update per batch.

  Despite its name, each of the ``num_episodes`` iterations is a whole batch.
  It rolls out complete episodes until at least ``batch_size`` environment
  steps have been gathered, then calls ``update_params`` on them.

  Observations are flattened and normalized by a running ``ZFilter`` (clipped
  to [-5, 5]). Reads the module-level ``num_inputs``, ``batch_size``,
  ``render`` and ``log_interval``. ``value_net`` is not used here directly;
  ``update_params`` reads the module-level ``value_net``.
  """
  running_state = ZFilter((num_inputs,), clip=5)

  for i_episode in range(num_episodes):
    memory = Memory()

    num_steps = 0
    reward_batch = 0
    episodes_in_batch = 0
    while num_steps < batch_size:
      state = running_state(np.ndarray.flatten(env.reset()))

      reward_sum = 0
      for t in range(10000):  # hard cap on episode length
        with torch.no_grad():
          action = policy_net.select_action(state).item()
        next_state, reward, done, _ = env.step(action)
        reward_sum += reward
        next_state = running_state(np.ndarray.flatten(next_state))

        # mask = 0 marks the last step of an episode (no bootstrapping past it).
        mask = 0 if done else 1
        memory.push(state, action, mask, next_state, reward)

        if render:
          env.render()
        if done:
          break

        state = next_state
      # t is the index of the last step taken, so the episode lasted t + 1 steps.
      num_steps += t + 1
      episodes_in_batch += 1
      reward_batch += reward_sum

    reward_batch /= episodes_in_batch
    batch = memory.sample()
    update_params(batch)

    if i_episode % log_interval == 0:
      print('Episode {}\tLast reward: {}\tAverage reward {:.2f}'.format(
          i_episode, reward_sum, reward_batch))

In [None]:
#@title Demo
gamma = 0.995 #@param {type:"number"}
tau = 0.97 #@param {type:"number"}
l2_reg = 1e-3 #@param {type:"number"}
max_kl = 1e-2 #@param {type:"number"}
damping = 1e-1 #@param {type:"number"}
batch_size = 15000 #@param {type:"slider", min:0, max:20000, step:100}
num_episodes = 100 #@param {type:"slider", min:0, max:100, step:1}
log_interval = 1 #@param {type:"number"}
env_name = "procgen:procgen-coinrun-v0" #@param {type:"string"}
render = False #@param {type: "boolean"}

seed = 4

env = gym.make(env_name)

# The MLPs consume flattened observations; use a plain int for layer sizes.
num_inputs = int(np.prod(env.observation_space.shape))
num_actions = env.action_space.n

torch.manual_seed(seed)
# Use the GPU when available, otherwise fall back to the CPU instead of failing.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# float64 networks, placed on the same device the update code moves batches to.
policy_net = Policy(num_inputs, num_actions).double().to(device)
value_net = Value(num_inputs).double().to(device)

run_training_loop(env, policy_net, value_net, num_episodes)


Episode 0	Last reward: 0.0	Average reward 0.83
Episode 1	Last reward: 0.0	Average reward 3.33
Episode 2	Last reward: 10.0	Average reward 3.33
Episode 3	Last reward: 0.0	Average reward 3.53
Episode 4	Last reward: 0.0	Average reward 4.17
Episode 5	Last reward: 0.0	Average reward 4.05
Episode 6	Last reward: 0.0	Average reward 5.32
Episode 7	Last reward: 0.0	Average reward 3.47
Episode 8	Last reward: 0.0	Average reward 3.93
Episode 9	Last reward: 0.0	Average reward 4.65
Episode 10	Last reward: 0.0	Average reward 5.00
Episode 11	Last reward: 10.0	Average reward 4.00
Episode 12	Last reward: 10.0	Average reward 5.36
Episode 13	Last reward: 0.0	Average reward 3.88
Episode 14	Last reward: 0.0	Average reward 3.61
Episode 15	Last reward: 0.0	Average reward 3.94
Episode 16	Last reward: 0.0	Average reward 4.35
Episode 17	Last reward: 0.0	Average reward 4.70
Episode 18	Last reward: 0.0	Average reward 4.53
Episode 19	Last reward: 0.0	Average reward 5.14
Episode 20	Last reward: 0.0	Average reward 3.54

KeyboardInterrupt: ignored