<a href="https://colab.research.google.com/github/LondonNode/Pearl-tutorials/blob/main/3_Updaters.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install pearll

# Introduction

This notebook is a tutorial for the `updaters` module within Pearl. This module is responsible for updating the models approximating the policy or value functions. There are three base classes specifying the three types of updater:
- `BaseActorUpdater`
- `BaseCriticUpdater`
- `BaseEvolutionUpdater`

All updaters follow the same design: an `__init__` method configures the updater, and a `__call__` method runs it and returns an `UpdaterLog` object.

In [2]:
from pearll.common.type_aliases import UpdaterLog


loss = 0
entropy = 0
divergence = 0
log = UpdaterLog(
  loss=loss, # loss function value
  divergence=divergence, # divergence or distance measurement (e.g. KL divergence)
  entropy=entropy, # entropy of a policy or other distribution parameter
)

print(log)

UpdaterLog(loss=0, divergence=0, entropy=0)
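
The shared design can be illustrated without Pearl itself. The sketch below is not Pearl's code: `SketchLog` and `SketchUpdater` are hypothetical stand-ins that only mirror the pattern described above, with configuration in `__init__`, the work in `__call__`, and a log object as the return value. It assumes that log fields an updater does not fill default to `None`.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class SketchLog:
    """Hypothetical stand-in for UpdaterLog; unused fields stay None."""

    loss: Optional[float] = None
    divergence: Optional[float] = None
    entropy: Optional[float] = None


class SketchUpdater:
    """Hypothetical updater: configure once in __init__, run with __call__."""

    def __init__(self, loss_coeff: float = 1.0) -> None:
        self.loss_coeff = loss_coeff

    def __call__(self, parameter: float, target: float) -> SketchLog:
        # Weighted squared error between the current parameter and the target.
        loss = self.loss_coeff * (parameter - target) ** 2
        return SketchLog(loss=loss)


sketch_updater = SketchUpdater(loss_coeff=0.5)
sketch_log = sketch_updater(parameter=1.0, target=3.0)
print(sketch_log)  # SketchLog(loss=2.0, divergence=None, entropy=None)
```

Keeping configuration in `__init__` means the same updater instance can be called repeatedly inside a training loop with fresh data each time.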


# Actor Updaters

The actor updaters are responsible for updating the actor models via backpropagation. The update formula differs substantially between algorithms, so each one requires its own implementation.

Currently implemented updaters:
- `PolicyGradient`: REINFORCE algorithm (e.g. A2C)
- `ProximalPolicyClip`: PPO policy loss
- `DeterministicPolicyGradient`: DDPG policy loss
- `SoftPolicyGradient`: SAC policy loss
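
To see why the actor losses need separate implementations, compare the textbook forms of two of them. This is a minimal sketch in plain Python, not Pearl's implementation: the function names, the `clip_range` parameter, and the toy numbers are all assumptions made for illustration.

```python
import math
from typing import List


def policy_gradient_loss(log_probs: List[float], advantages: List[float]) -> float:
    """Negative mean of the advantage-weighted log-probabilities."""
    terms = [lp * adv for lp, adv in zip(log_probs, advantages)]
    return -sum(terms) / len(terms)


def clipped_surrogate_loss(
    log_probs: List[float],
    old_log_probs: List[float],
    advantages: List[float],
    clip_range: float = 0.2,
) -> float:
    """PPO-style clipped objective, negated so that it can be minimized."""
    terms = []
    for lp, old_lp, adv in zip(log_probs, old_log_probs, advantages):
        # Probability ratio between the new and old policies.
        ratio = math.exp(lp - old_lp)
        clipped_ratio = max(1.0 - clip_range, min(ratio, 1.0 + clip_range))
        # Take the more pessimistic of the clipped and unclipped terms.
        terms.append(min(ratio * adv, clipped_ratio * adv))
    return -sum(terms) / len(terms)


log_probs = [math.log(0.5), math.log(0.9)]
old_log_probs = [math.log(0.25), math.log(0.9)]
advantages = [1.0, -1.0]

pg_loss = policy_gradient_loss(log_probs, advantages)
ppo_loss = clipped_surrogate_loss(log_probs, old_log_probs, advantages)
print(pg_loss, ppo_loss)
```

The two losses take different inputs (the clipped objective also needs the old policy's log-probabilities), which is one reason a single generic actor updater is hard to write.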

In [3]:
from pearll.models import ActorCritic, Actor
from pearll.updaters import BaseActorUpdater
from pearll.common.type_aliases import UpdaterLog

import torch as T
from typing import Type, Union


class YourActorUpdater(BaseActorUpdater):
  def __init__(
      self,
      optimizer_class: Type[T.optim.Optimizer] = T.optim.Adam, # alter optimizer class
      max_grad: float = 0, # clip gradients
  ) -> None:
    super().__init__(optimizer_class, max_grad)

  # __call__ is an abstract method in the base class
  def __call__(
      self,
      model: Union[ActorCritic, Actor],
  ) -> UpdaterLog:
    # Useful inbuilt method to get the parameters that need to be updated.
    actor_parameters = self._get_model_parameters(model=model)
    optimizer = self.optimizer_class(actor_parameters)
    loss = 0  # placeholder: compute your actor loss from the model here
    # Useful inbuilt method to run an optimization step
    self.run_optimizer(optimizer=optimizer, loss=loss, actor_parameters=actor_parameters)
    return UpdaterLog(loss=loss)
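
The `max_grad` argument above is described only as "clip gradients". One common way to do this is to rescale the gradient when its L2 norm exceeds a threshold. The sketch below shows that idea in plain Python. It is an assumption that Pearl clips by norm rather than by value, and that `max_grad=0` disables clipping; `clip_by_norm` is a hypothetical helper, not part of Pearl.

```python
import math
from typing import List


def clip_by_norm(gradients: List[float], max_grad: float) -> List[float]:
    """Rescale gradients so that their L2 norm does not exceed max_grad.

    Assumption: max_grad <= 0 means clipping is disabled.
    """
    if max_grad <= 0:
        return list(gradients)
    norm = math.sqrt(sum(g * g for g in gradients))
    if norm <= max_grad:
        return list(gradients)
    scale = max_grad / norm
    return [g * scale for g in gradients]


gradients = [3.0, 4.0]  # L2 norm is 5
clipped = clip_by_norm(gradients, max_grad=0.5)
clipped_norm = math.sqrt(sum(g * g for g in clipped))
print(clipped, clipped_norm)
```

Clipping keeps the direction of the gradient and only limits the step length, which helps when a single batch produces an unusually large gradient.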

# Critic Updaters

The critic updaters are responsible for updating the critic models via backpropagation. These implementations generalize across algorithms better than the actor updaters do, because critics are usually updated in the same way: by regressing towards a value function or a Q-function.

Currently implemented updaters:
- `ValueRegression`
- `ContinuousQRegression`
- `DiscreteQRegression`

In [4]:
# The BaseCriticUpdater class has the same features as the BaseActorUpdater

from pearll.models import ActorCritic, Critic
from pearll.updaters import BaseCriticUpdater
from pearll.common.type_aliases import UpdaterLog

import torch as T
from typing import Type, Union


class YourCriticUpdater(BaseCriticUpdater):
  def __init__(
      self,
      optimizer_class: Type[T.optim.Optimizer] = T.optim.Adam, # alter optimizer class
      max_grad: float = 0, # clip gradients
  ) -> None:
    super().__init__(optimizer_class, max_grad)

  def __call__(
      self,
      model: Union[ActorCritic, Critic],
  ) -> UpdaterLog:
    critic_parameters = self._get_model_parameters(model=model)
    optimizer = self.optimizer_class(critic_parameters)
    loss = 0  # placeholder: compute your critic loss from the model here
    self.run_optimizer(optimizer=optimizer, loss=loss, critic_parameters=critic_parameters)
    return UpdaterLog(loss=loss)

In [5]:
# The implemented updaters all share the same pattern, so let's review by going
# through the ValueRegression updater

from pearll.updaters.critics import ValueRegression
from pearll.models import Critic
from pearll.models.encoders import IdentityEncoder
from pearll.models.torsos import MLP
from pearll.models.heads import ValueHead

import torch as T
import numpy as np


model = Critic(
    encoder = IdentityEncoder(),
    torso = MLP(layer_sizes=[5, 5]),
    head = ValueHead(input_shape=5)
)

updater = ValueRegression(
    loss_class = T.nn.MSELoss(), # alter the regression loss function
    optimizer_class = T.optim.Adam,
    max_grad=0.5,
)

updater(
    model=model,
    observations = np.random.rand(1, 5),
    returns = T.rand(size=(1, 1)),
    learning_rate = 0.001, # alter learning rate of optimizer
    loss_coeff = 1, # weight of the critic loss in relation to the actor loss
)

UpdaterLog(loss=0.5480731129646301, divergence=None, entropy=None)
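
As a rough picture of what a value regression update does, the sketch below takes gradient steps on a squared error between a linear value estimate and a target return. It is not the `ValueRegression` implementation: the linear model, the hand-derived gradient, and the function name `value_regression_step` are assumptions chosen to keep the example dependency-free. Only the roles of `learning_rate` and `loss_coeff` follow the comments in the cell above.

```python
from typing import List, Tuple


def value_regression_step(
    weights: List[float],
    observation: List[float],
    target_return: float,
    learning_rate: float = 0.001,
    loss_coeff: float = 1.0,
) -> Tuple[List[float], float]:
    """One gradient-descent step on loss_coeff * (prediction - return) ** 2."""
    prediction = sum(w * x for w, x in zip(weights, observation))
    error = prediction - target_return
    loss = loss_coeff * error ** 2
    # d(loss)/d(w_i) = 2 * loss_coeff * error * x_i
    new_weights = [
        w - learning_rate * 2.0 * loss_coeff * error * x
        for w, x in zip(weights, observation)
    ]
    return new_weights, loss


value_weights = [0.0, 0.0]
observation = [1.0, 2.0]
losses = []
for _ in range(3):
    value_weights, step_loss = value_regression_step(
        value_weights, observation, target_return=1.0, learning_rate=0.05
    )
    losses.append(step_loss)
print(losses)
```

Each step moves the prediction closer to the target return, so the logged loss shrinks from one call to the next.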

# Evolution Updaters

The evolution updaters are responsible for updating either the actor or critic models via random search. These implementations are somewhat generalizable across algorithms.

Currently implemented updaters:
- `NoisyGradientAscent`
- `GeneticUpdater`

In [6]:
from pearll.models import ActorCritic
from pearll.updaters import BaseEvolutionUpdater
from pearll.common.type_aliases import UpdaterLog


class YourEvolutionUpdater(BaseEvolutionUpdater):
  # The population_type parameter specifies whether to update the actor
  # or critic populations in the ActorCritic model.
  def __init__(self, model: ActorCritic, population_type: str = "actor") -> None:
    super().__init__(model, population_type)

  # Abstract method needs to be implemented
  def __call__(self) -> UpdaterLog:
    population = self.model.numpy_actors()
    # Useful inbuilt method will update the actor or critic population to the
    # state passed as input.
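    self.update_networks(population=population)
    # Return a log to match the declared return type; fill in divergence and
    # entropy if your updater computes them.
    return UpdaterLog(loss=None)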

In [7]:
from pearll.updaters.evolution import NoisyGradientAscent
from pearll.models import Dummy, ActorCritic
from pearll.settings import PopulationSettings

from gym.spaces import Box
import numpy as np


space = Box(-100, 100, shape=(1,))
settings = PopulationSettings(actor_population_size=10, actor_distribution="normal")

actor = Dummy(space)
critic = Dummy(space)
model = ActorCritic(actor, critic, settings)

updater = NoisyGradientAscent(model=model, population_type="actor")
# The NoisyGradientAscent shifts the population in some optimization_direction
# where step size is controlled by the learning_rate.
log = updater(learning_rate=1, optimization_direction=np.array(1))
print(log)
print("divergence = KL divergence between old and new populations")
print("entropy = entropy of new population distribution")
# The updater also keeps the samples drawn from the standard normal
# distribution that were used to generate the new population.
samples = updater.normal_dist
print(f"\nMap of samples from standard normal distribution (left) to the resultant model parameters (right): \n{np.concatenate((samples, model.numpy_actors()), axis=-1)}")

UpdaterLog(loss=None, divergence=tensor(0.5000), entropy=tensor(1.4189))
divergence = KL divergence between old and new populations
entropy = entropy of new population distribution

Map of samples from standard normal distribution (left) to the resultant model parameters (right): 
[[ 2.32818566e-01 -2.82520924e+01]
 [-2.85612791e-01 -2.87705238e+01]
 [ 9.91035826e-02 -2.83858074e+01]
 [-8.65469413e-01 -2.93503804e+01]
 [-1.35573335e-02 -2.84984683e+01]
 [ 1.71631997e-01 -2.83132790e+01]
 [-1.12415523e+00 -2.96090662e+01]
 [-8.46786684e-01 -2.93316976e+01]
 [ 1.71824444e+00 -2.67666665e+01]
 [-1.73978490e+00 -3.02246959e+01]]
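
The logged values above match the closed-form expressions for a normal distribution. In the printed map, every parameter equals its sample plus the same constant, which suggests a unit standard deviation. The sketch below assumes exactly that, plus a mean shifted by `learning_rate * optimization_direction = 1`. These assumptions are inferred from the output, not taken from Pearl's source, and the helper functions are written only for this check.

```python
import math


def gaussian_kl(mean_p: float, std_p: float, mean_q: float, std_q: float) -> float:
    """KL(p || q) for two univariate normal distributions."""
    return (
        math.log(std_q / std_p)
        + (std_p ** 2 + (mean_p - mean_q) ** 2) / (2.0 * std_q ** 2)
        - 0.5
    )


def gaussian_entropy(std: float) -> float:
    """Differential entropy of a univariate normal distribution."""
    return 0.5 * math.log(2.0 * math.pi * math.e * std ** 2)


old_mean, std = 0.0, 1.0  # the starting mean is arbitrary; only the shift matters
learning_rate, optimization_direction = 1.0, 1.0
new_mean = old_mean + learning_rate * optimization_direction

kl = gaussian_kl(old_mean, std, new_mean, std)
entropy = gaussian_entropy(std)
print(round(kl, 4), round(entropy, 4))  # 0.5 1.4189
```

With equal standard deviations the KL divergence reduces to the squared mean shift divided by twice the variance, so a larger `learning_rate` would show up directly as a larger logged divergence.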


In [8]:
from pearll.updaters.evolution import GeneticUpdater
from pearll.models import Dummy, ActorCritic
from pearll.settings import PopulationSettings, Settings, MutationSettings
from pearll.signal_processing import selection_operators, crossover_operators, mutation_operators

from gym.spaces import Box
import numpy as np


space = Box(-100, 100, shape=(1,))
settings = PopulationSettings(actor_population_size=10, actor_distribution="uniform")
rewards = np.random.uniform(size=10)

actor = Dummy(space)
critic = Dummy(space)
model = ActorCritic(actor, critic, settings)

updater = GeneticUpdater(model=model, population_type="actor")
# The GeneticUpdater uses the genetic algorithm approach to update a population.
# Selection operators, crossover operators and mutation operators are
# implemented in the signal_processing module.
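# In this case divergence is the average distance between the new and old
# populations, and entropy is the distance between the max and min of the
# new population.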
log = updater(
    rewards=rewards, 
    selection_operator=selection_operators.roulette_selection,
    crossover_operator=crossover_operators.one_point_crossover,
    mutation_operator=mutation_operators.uniform_mutation
)
print(log)
print("divergence = average distance between new and old population")
print("entropy = distance between max and min of new population")

UpdaterLog(loss=None, divergence=61.14169680740664, entropy=181.91257)
divergence = average distance between new and old population
entropy = distance between max and min of new population
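
For readers unfamiliar with genetic algorithms, the sketch below walks through one generation of selection, crossover, and mutation in plain Python. The functions here are re-implemented for illustration and do not share signatures with Pearl's operators in `signal_processing`. The genome length, mutation `rate`, and the pairing of neighbouring parents are all assumptions.

```python
import random
from typing import List

Genome = List[float]


def roulette_selection(
    population: List[Genome], rewards: List[float], rng: random.Random
) -> List[Genome]:
    """Sample parents with probability proportional to (non-negative) reward."""
    chosen = rng.choices(population, weights=rewards, k=len(population))
    return [list(genome) for genome in chosen]


def one_point_crossover(parent_a: Genome, parent_b: Genome, rng: random.Random) -> Genome:
    """Take the head of one parent and the tail of the other."""
    point = rng.randint(1, len(parent_a) - 1)
    return parent_a[:point] + parent_b[point:]


def uniform_mutation(
    genome: Genome, low: float, high: float, rate: float, rng: random.Random
) -> Genome:
    """Replace each gene with a fresh uniform sample with probability `rate`."""
    return [rng.uniform(low, high) if rng.random() < rate else gene for gene in genome]


def genetic_step(
    population: List[Genome],
    rewards: List[float],
    low: float = -100.0,
    high: float = 100.0,
    rate: float = 0.1,
    seed: int = 0,
) -> List[Genome]:
    """Run selection, crossover and mutation once to produce a new population."""
    rng = random.Random(seed)
    parents = roulette_selection(population, rewards, rng)
    children = []
    for i, parent in enumerate(parents):
        partner = parents[(i + 1) % len(parents)]
        child = one_point_crossover(parent, partner, rng)
        children.append(uniform_mutation(child, low, high, rate, rng))
    return children


genome_length = 4
data_rng = random.Random(42)
old_population = [
    [data_rng.uniform(-100, 100) for _ in range(genome_length)] for _ in range(10)
]
toy_rewards = [data_rng.random() for _ in range(10)]
new_population = genetic_step(old_population, toy_rewards)

# Log-style summaries following the descriptions printed above.
avg_distance = sum(
    abs(new - old)
    for new_genome, old_genome in zip(new_population, old_population)
    for new, old in zip(new_genome, old_genome)
) / (len(old_population) * genome_length)
spread = max(max(g) for g in new_population) - min(min(g) for g in new_population)
print(avg_distance, spread)
```

Roulette selection favours high-reward individuals without excluding the rest, which preserves some diversity for crossover and mutation to work with.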
