## Метод Option-Critic

Статья по [Option-Critic архитектуре](https://ojs.aaai.org/index.php/AAAI/article/download/10916/10775).

Общая архитектура агента:

![image.png](https://d3i71xaburhd42.cloudfront.net/15b26d8cb35d7e795c8832fe08794224ee1e9f84/3-Figure1-1.png)


Общий вид алгоритма:

![image.png](https://campusai.github.io/_papers/the_option_critic_architecture/algo1optioncritic.png)

Good implementation [link](https://github.com/lweitkamp/option-critic-pytorch/blob/master/option_critic.py)! 

In [1]:
try:
    import google.colab
    COLAB = True
except ModuleNotFoundError:
    COLAB = False
    pass

if COLAB:
    !apt install swig
    !pip -q install "gymnasium[classic-control, atari, accept-rom-license, box2d]"
    !pip -q install piglet
    !pip -q install imageio_ffmpeg
    !pip -q install moviepy==1.0.3

Reading package lists... Done
Building dependency tree       
Reading state information... Done
The following additional packages will be installed:
  swig4.0
Suggested packages:
  swig-doc swig-examples swig4.0-examples swig4.0-doc
The following NEW packages will be installed:
  swig swig4.0
0 upgraded, 2 newly installed, 0 to remove and 34 not upgraded.
Need to get 1,086 kB of archives.
After this operation, 5,413 kB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu focal/universe amd64 swig4.0 amd64 4.0.1-5build1 [1,081 kB]
Get:2 http://archive.ubuntu.com/ubuntu focal/universe amd64 swig all 4.0.1-5build1 [5,528 B]
Fetched 1,086 kB in 1s (1,323 kB/s)
Selecting previously unselected package swig4.0.
(Reading database ... 122545 files and directories currently installed.)
Preparing to unpack .../swig4.0_4.0.1-5build1_amd64.deb ...
Unpacking swig4.0 (4.0.1-5build1) ...
Selecting previously unselected package swig.
Preparing to unpack .../swig_4.0.1-5build1_a

In [2]:
import torch
import torch.nn as nn
from torch.distributions import Categorical, Bernoulli
import gymnasium as gym
import numpy as np

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

### Основной цикл

In [38]:
def print_mean_reward(step, episode_rewards):
    if not episode_rewards:
        return

    t = min(50, len(episode_rewards))    
    mean_reward = sum(episode_rewards[-t:]) / t
    print(f"step: {str(step).zfill(6)}, mean reward: {mean_reward:.2f}")
    return mean_reward


def to_tensor(x, dtype=np.float32):
    if isinstance(x, torch.Tensor):
        return x
    x = np.asarray(x, dtype=dtype)
    x = torch.from_numpy(x).to(device)
    return x


def to_np(t):
    return t.detach().cpu().numpy()


def softmax(x: np.ndarray, temp=1.) -> np.ndarray:
    """Computes softmax values for a vector `x` with a given temperature."""
    temp = np.clip(temp, 1e-5, 1e+3)
    e_x = np.exp((x - np.max(x, axis=-1)) / temp)
    return e_x / e_x.sum(axis=-1)


def run(
        env: gym.Env, hidden_size: int, n_options: int,
        softmax_temp: float, ac_lr: float, cr_lr: float, gamma: float,
        termination_regularizer: float, entropy_weight: float,
        max_episodes: int, replay_buffer_size: int, update_schedule: int,
        batch_size: int, critic_batch_size: int, critic_updates_per_actor: int,
        seed: int, print_schedule: int, success_return: float
):
    # Инициализируйте агента `agent`, когда сделаете саму реализацию агента ниже по заданию.
    ####### Здесь ваш код ########
    try:
        state_dim = env.observation_space.shape[0]
    except IndexError:
        state_dim = env.observation_space.n
    action_dim = env.action_space.n

    agent =  OptionCriticAgent(state_dim, hidden_size, action_dim, n_options, softmax_temp, 
                                gamma, ac_lr, cr_lr, termination_regularizer, entropy_weight,
                                replay_buffer_size, seed)
    ##############################

    episode_rewards = []
    step, rollout_step = 0, 0
    for i_episode in range(1, max_episodes + 1):
        cumulative_reward = 0
        terminated = False
        state, _ = env.reset()
        
        while not terminated:
            step += 1

            action, option = agent.act(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            
            agent.append_to_replay_buffer(state, action, option, reward, next_state, terminated)
            state = next_state
            cumulative_reward += reward
            terminated |= truncated
            if step % update_schedule == 0:
                agent.update(batch_size, critic_batch_size, critic_updates_per_actor)

        episode_rewards.append(cumulative_reward)
        
        # выполняем обновление
        if i_episode % print_schedule == 0:
            mean_reward = print_mean_reward(step, episode_rewards) 
            if mean_reward >= success_return:
                print('Accepted!')
                return agent
            episode_rewards = []
            
    return agent

In [39]:
from collections import deque, namedtuple
from operator import attrgetter


Transition = namedtuple('Transition', ['state', 'action', 'option', 'reward', 'next_state', 'done'])

class ReplayBuffer:
    def __init__(self, size, seed):
        self.buffer = deque(maxlen=size)
        self._rng = np.random.default_rng(seed)
    
    def append(self, state, action, option, reward, next_state, done):
        sample = Transition(state, action, option, reward, next_state, done)
        self.buffer.append(sample)
        
    def get_last_n_samples(self, n_samples):
        # Get last `n_samples` samples from replay buffer
        buffer_size = len(self.buffer)
        if buffer_size < n_samples:
            return None

        ####### Здесь ваш код ########
        indices = np.arange(buffer_size-n_samples, buffer_size)
        ##############################
        return self._get_batch(indices)

    def sample_batch(self, n_samples):
        buffer_size = len(self.buffer)
        if buffer_size < n_samples:
            return None
        
        ####### Здесь ваш код ########
        indices = np.random.choice(len(self.buffer), n_samples, replace=False)
        ##############################

        return self._get_batch(indices)

    def _get_batch(self, indices):
        states, actions, options, rewards, next_states, dones = [], [], [], [], [], []
        for i in indices:
            s, a, o, r, n_s, done = self.buffer[i]
            states.append(s)
            actions.append(a)
            options.append(o)
            rewards.append(r)
            next_states.append(n_s)
            dones.append(done)

        batch = (
            np.array(states), np.array(actions), np.array(options), 
            np.array(rewards), np.array(next_states), np.array(dones)
        )
        return batch
    
    def __len__(self):
        return len(self.buffer)

In [40]:
class MLPModel(nn.Module):
    def __init__(self, state_dim, hidden_dim):
        super().__init__()

        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh()
        )

    def forward(self, state):
        state = to_tensor(state)
        return self.net(state)

class OptionCriticNet(nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim, num_options, softmax_temp):
        super(OptionCriticNet, self).__init__()

        self.body = MLPModel(state_dim, hidden_dim)
        self.q = nn.Linear(hidden_dim, num_options)
        self.pi = nn.Linear(hidden_dim, num_options * action_dim)
        self.beta = nn.Linear(hidden_dim, num_options) #options terminantion
        
        self.num_options = num_options
        self.action_dim = action_dim
        self.softmax_temp = softmax_temp


    @torch.no_grad()   
    def predict_option_temination(self, state, option):
        beta = self.beta(state)[:, option].sigmoid() 
        option_termination = bool(Bernoulli(beta).sample().item())
        next_option = self.q_values(state).argmax(-1).item()
        return option_termination, next_option

    def forward(self, state):
        x = self.body(state)

        ####### Здесь ваш код ########
        q_values = self.q(x)
        beta = self.beta(x).sigmoid() 
        pi = self.pi(x).view(-1, self.num_options, self.action_dim).squeeze(0)

        pi = (pi / self.softmax_temp).softmax(-1)
        log_pi = (pi / self.softmax_temp).log_softmax(-1)    

        ##############################
        
        return {
            'q': q_values,
            'beta': beta,
            'pi': pi,
            'log_pi': log_pi,
        }

class OptionCriticAgent:
    def __init__(
        self, state_dim, hidden_dim, action_dim, num_options, softmax_temp, 
        gamma, ac_lr, cr_lr, termination_regularizer, entropy_weight,
        replay_buffer_size, seed
    ):
        self.network = OptionCriticNet(state_dim, hidden_dim, action_dim, num_options, softmax_temp).to(device)
        self.cr_optimizer = torch.optim.Adam(self.network.parameters(), lr=ac_lr)
        self.ac_optimizer = torch.optim.Adam(self.network.parameters(), lr=cr_lr)

        self.n_options = num_options
        self.n_actions = action_dim
        self.gamma = gamma
        self.softmax_temp = softmax_temp
        self.termination_regularizer = termination_regularizer
        self.entropy_weight = entropy_weight
        
        self.replay_buffer = ReplayBuffer(replay_buffer_size, seed)

        self._rng = np.random.default_rng(seed)
        self.option = None

    @torch.no_grad()    
    def act(self, state):
        ####### Здесь ваш код ########
        preds = self.network(to_tensor(state))
        
        self.option = self.sample_option(
            q_options=preds['q'],
            beta_options=preds['beta']
        )
        action_dist = Categorical(probs=preds['pi'][self.option])
        action = action_dist.sample().item()

        ##############################
        option = self.option
        return action, option
    
    def update(self, batch_size, critic_batch_size, critic_updates_per_actor):
        if len(self.replay_buffer) < batch_size:
            return False
        
        self.update_actor(batch_size)
        self.update_critic(critic_batch_size, critic_updates_per_actor)
        return True

    def update_actor(self, batch_size):
        batch = self.replay_buffer.get_last_n_samples(batch_size)
        if not batch:
            return
        
        states, actions, options, rewards, next_states, is_done = batch
        
        ####### Здесь ваш код ########
        s = to_tensor(states)                       
        a = to_tensor(actions, int).long().unsqueeze(-1)          
        o = to_tensor(options, int).long().unsqueeze(-1)       
        r = to_tensor(rewards)                       
        s_n = to_tensor(next_states)                
        not_done = 1 - to_tensor(is_done, int)
        
        preds = self.network(s)
        # q: torch.Size([200, 4])
        # beta: Size([200, 4])
        # pi: Size([200, 4, 2])
        # log_pi: Size([200, 4, 2])

        options_for_actions = o.unsqueeze(-1).expand(*o.size(), self.n_actions)
        pi_o = preds['pi'].gather(1, options_for_actions).squeeze(1)
        log_pi_o = preds['log_pi'].gather(1, options_for_actions).squeeze(1)
        log_pi_o_for_actions = log_pi_o.gather(1, a).squeeze(-1)
        beta_o = preds['beta'].gather(1, o).squeeze(-1)
        
        q_values = preds['q'].detach()
        v_s = q_values.max(-1)[0]

        q_under_options = q_values.gather(1, o).squeeze(-1)
        beta_advantage = q_under_options - v_s
        beta_advantage  = (beta_advantage - beta_advantage.mean()) / (beta_advantage.std() + 1e-7)

        with torch.no_grad():
            preds_next = self.network(s_n)
            beta_o_prime = preds_next['beta'].gather(1, o).squeeze(-1)
            q_values_next_options = preds_next['q'].gather(1, o).squeeze(-1)
            v_s_prime = preds_next['q'].max(-1)[0]

        utility_upon_arrival =  ((1 - beta_o_prime) * q_values_next_options + beta_o_prime * v_s_prime)
        gt = r + not_done * self.gamma * utility_upon_arrival
        entropy = -self.entropy_weight * (log_pi_o * pi_o).sum(-1)

        beta_loss = (beta_o * beta_advantage + self.termination_regularizer) * not_done
        policy_loss = -log_pi_o_for_actions * (gt.detach() - q_under_options) + entropy
        
        self.ac_optimizer.zero_grad()
        loss = (policy_loss + beta_loss).mean()
        loss.backward()
        self.ac_optimizer.step()

        ##############################
    
    def update_critic(self, batch_size, critic_updates_per_actor):
        # ограничивает сверху количество эпох для буфера небольшого размера
        critic_updates_per_actor = min(
            critic_updates_per_actor, 
            5 * len(self.replay_buffer.buffer) // batch_size
        )
        
        for _ in range(critic_updates_per_actor):
            self.update_critic_step(batch_size)


    def update_critic_step(self, batch_size):
        batch = self.replay_buffer.sample_batch(batch_size)
        if not batch:
            return
        
        states, actions, options, rewards, next_states, is_done = batch
        
        # Реализуйте шаг обучения критика
        ####### Здесь ваш код ########
        s = to_tensor(states)                       
        a = to_tensor(actions, int).long()          
        o = to_tensor(options, int).long().unsqueeze(-1)          
        r = to_tensor(rewards)                       
        s_n = to_tensor(next_states)                
        not_done = 1 - to_tensor(is_done, int)               
        
        preds = self.network(s)
        
        q_under_options = preds['q'].gather(1, o).squeeze(-1)

        with torch.no_grad():
            preds_next = self.network(s_n)
            beta_o_prime = preds_next['beta'].gather(1, o).squeeze(-1)
            q_values_next_options = preds_next['q'].gather(1, o).squeeze(-1)
            v_s_prime = preds_next['q'].max(-1)[0]

        utility_upon_arrival = ((1 - beta_o_prime) * q_values_next_options + beta_o_prime * v_s_prime)
        td_target = r + not_done * self.gamma * utility_upon_arrival

        td_error = q_under_options - td_target.detach()

        loss = td_error.pow(2).mean()

        self.cr_optimizer.zero_grad()
        loss.backward()
        self.cr_optimizer.step()

        ##############################

    def sample_option(self, q_options: torch.Tensor, beta_options: torch.Tensor):
        # Реализуйте выбор опции
        ####### Здесь ваш код ########
        if self.option is not None:
             option_termination = bool(Bernoulli(beta_options[self.option]).sample().item())
             if not option_termination:
                return self.option
        
        option = Categorical(logits=q_options).sample().item()
        
        ##############################
        return option
    
    def append_to_replay_buffer(self, s, a, o, r, next_s, done):
        self.replay_buffer.append(s, a, o, r, next_s, done)

In [41]:
from gymnasium.wrappers.time_limit import TimeLimit
env_name = "CartPole-v1"

# env_name = "LunarLander-v2"

agent = run(
    env = TimeLimit(gym.make(env_name), 1000),
    max_episodes = 1000,  # количество эпизодов обучения
    hidden_size = 64,  # кол-во переменных в скрытых слоях
    n_options = 4,
    update_schedule = 10,
    batch_size = 200, 
    softmax_temp = 0.01,  # softmax temperature
    ac_lr = 0.001, # actor learning rate
    cr_lr = 0.0002, # critic learning rate
    termination_regularizer = 0.01, # punish early termination
    entropy_weight = 0.5,  # punish intra-option policy over-determination
    gamma = 0.995,  # дисконтирующий множитель,
    replay_buffer_size = 5000,
    critic_batch_size = 64,
    critic_updates_per_actor = 32,
    seed = 1337,
    print_schedule = 10,
    success_return = 200
)

step: 000119, mean reward: 11.90
step: 000245, mean reward: 12.60
step: 000368, mean reward: 12.30
step: 000688, mean reward: 32.00
step: 002380, mean reward: 169.20
step: 006096, mean reward: 371.60
Accepted!


In [None]:
env_name = "LunarLander-v2"

agent = run(
    env = TimeLimit(gym.make(env_name), 1000),
    max_episodes = 1000,  # количество эпизодов обучения
    hidden_size = 128,  # кол-во переменных в скрытых слоях
    n_options = 8,
    update_schedule = 10,
    batch_size = 200, 
    softmax_temp = 1.00,  # softmax temperature
    ac_lr = 0.001, # actor learning rate
    cr_lr = 0.0002, # critic learning rate
    termination_regularizer = 0.1, # punish early termination
    entropy_weight = 0.5,  # punish intra-option policy over-determination
    gamma = 0.99,  # дисконтирующий множитель,
    replay_buffer_size = 50000,
    critic_batch_size = 64,
    critic_updates_per_actor = 64,
    seed = 1337,
    print_schedule = 10,
    success_return = 200
)

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

def print_options(env, agent: OptionCriticAgent, option, seed=42):
    @torch.no_grad()
    def calc_sample(sample):
        prediction = agent.network(sample)

        q = prediction['q']
        beta = prediction['beta']
        pi = to_np(prediction['pi'])
        log_pi = to_np(prediction['log_pi'])
        
        option = agent.sample_option(q_options=q, beta_options=beta)
        return q, beta, pi, log_pi, option
    
    rng = np.random.default_rng()

    states_dim = env.observation_space.shape[0]
    xs_0 = np.linspace(-2.4, 2.4, num=50)
    xs_1 = np.linspace(-2.4, 2.4, num=50)
    xs_2 = np.linspace(-2.1, 2.1, num=4)
    xs_3 = np.linspace(-2.4, 2.4, num=4)
    
    sample = [
        rng.choice(xs)
        for xs in [xs_0, xs_1, xs_2, xs_3]
    ]
    
    n_plots = xs_2.size * xs_3.size
    cols = 4
    rows = (n_plots - 1) // cols + 1
    
    plt.figure(figsize=(20, 10))
    i_plot = 1
    for x2 in xs_2:
        for x3 in xs_3:
            beta_map = np.zeros((xs_0.size, xs_1.size))
            q_map = np.zeros_like(beta_map)            
            
            for i, x0 in enumerate(xs_0):
                for j, x1 in enumerate(xs_1):
                    obs = np.array([x0, x1, x2, x3])
                    q, beta, pi, log_pi, option = calc_sample(obs)
                    if option > 0:
                        beta_map[i, j] = 0
                        q_map[i, j] = 0
                    else:
                        beta_map[i, j] = beta[option]
                        q_map[i, j] = q[option] / 200
            
            plt.subplot(rows, 2 * cols, i_plot)
            plt.imshow(q_map, vmin=0, vmax=1)
            plt.subplot(rows, 2 * cols, i_plot + 1)
            plt.imshow(beta_map, vmin=0, vmax=1)
            i_plot += 2
    plt.show()

for i in range(agent.n_options):
    print_options(
        env=gym.make(env_name),
        agent=agent,
        option=i
    )
    print('===========')