In [None]:
!apt-get install -y \
    libgl1-mesa-dev \
    libgl1-mesa-glx \
    libglew-dev \
    libosmesa6-dev \
    software-properties-common

!apt-get install -y patchelf
!pip install free-mujoco-py

Reading package lists... Done
Building dependency tree       
Reading state information... Done
libgl1-mesa-dev is already the newest version (20.0.8-0ubuntu1~18.04.1).
libgl1-mesa-dev set to manually installed.
software-properties-common is already the newest version (0.96.24.32.18).
The following package was automatically installed and is no longer required:
  libnvidia-common-460
Use 'apt autoremove' to remove it.
Suggested packages:
  glew-utils
The following NEW packages will be installed:
  libgl1-mesa-glx libglew-dev libglew2.0 libosmesa6 libosmesa6-dev
0 upgraded, 5 newly installed, 0 to remove and 62 not upgraded.
Need to get 2,916 kB of archives.
After this operation, 12.6 MB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu bionic-updates/main amd64 libgl1-mesa-glx amd64 20.0.8-0ubuntu1~18.04.1 [5,532 B]
Get:2 http://archive.ubuntu.com/ubuntu bionic/universe amd64 libglew2.0 amd64 2.0.0-5 [140 kB]
Get:3 http://archive.ubuntu.com/ubuntu bionic/univ

## MAML-PPO with cherry and learn2learn

In [None]:
!pip install cherry-rl learn2learn &> /dev/null

In [None]:
import random
import math

from copy import deepcopy

import cherry as ch
import gym
import numpy as np
import torch
from cherry.algorithms import a2c, trpo, ppo
from cherry.models.robotics import LinearValue
from tqdm import tqdm

import learn2learn as l2l

import torch as th
import torch.nn as nn
from torch import autograd
from torch.distributions.kl import kl_divergence
from torch.nn.utils import parameters_to_vector, vector_to_parameters
from torch.distributions import Normal, Categorical


In [None]:
def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Orthogonally initialise ``layer.weight`` in place and fill its bias.

    Args:
        layer: module with a ``weight`` tensor (e.g. ``nn.Linear``) and an
            optional ``bias``; layers built with ``bias=False`` are supported.
        std: gain passed to ``torch.nn.init.orthogonal_``.
        bias_const: value written into every bias element (if a bias exists).

    Returns:
        The same ``layer`` object, so the call can wrap a constructor inline.
    """
    torch.nn.init.orthogonal_(layer.weight, std)
    # bias=False layers expose ``bias = None``; constant_ would fail on it.
    if getattr(layer, "bias", None) is not None:
        torch.nn.init.constant_(layer.bias, bias_const)
    return layer

In [None]:
class Actor(nn.Module):
    """Two-hidden-layer ReLU MLP whose output parameterises a cherry action distribution.

    Args:
        env: environment; ``observation_space.shape[0]`` sets the input width and
            ``action_space.shape[0]`` the output width.
        lr: learning rate of the Adam optimizer stored as ``self.optimizer``.
        hidden_size: width of both hidden layers.
    """

    def __init__(self, env, lr, hidden_size=100):
        super().__init__()
        obs_dim = env.observation_space.shape[0]
        act_dim = env.action_space.shape[0]
        self.input_size = obs_dim
        self.actor_output_size = act_dim

        # Keep this construction order: each layer's random init draws from the
        # global torch RNG, so reordering would change the initial weights.
        self.l1 = layer_init(nn.Linear(obs_dim, hidden_size))
        self.l2 = layer_init(nn.Linear(hidden_size, hidden_size))
        # Small gain keeps the initial network outputs close to zero.
        self.output = layer_init(nn.Linear(hidden_size, act_dim), std=0.01)
        self.activation = nn.ReLU()
        self.distribution = ch.distributions.ActionDistribution(env)

        # Built last so it sees every parameter registered above.
        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr, eps=1e-5)

    def forward(self, x):
        """Map observations ``x`` to the action distribution object."""
        hidden = x
        for layer in (self.l1, self.l2):
            hidden = self.activation(layer(hidden))
        return self.distribution(self.output(hidden))

In [None]:
class Critic(nn.Module):
    """Two-hidden-layer ReLU MLP producing a single state-value estimate.

    Args:
        env: environment; ``observation_space.shape[0]`` sets the input width.
        lr: learning rate of the Adam optimizer stored as ``self.optimizer``.
        hidden_size: width of both hidden layers.
    """

    def __init__(self, env, lr=0.01, hidden_size=32):
        super().__init__()
        obs_dim = env.observation_space.shape[0]
        self.input_size = obs_dim
        self.critic_output_size = 1

        # Layers are created (and randomly initialised) in a fixed order.
        self.l1 = layer_init(nn.Linear(obs_dim, hidden_size))
        self.l2 = layer_init(nn.Linear(hidden_size, hidden_size))
        self.critic_head = layer_init(nn.Linear(hidden_size, self.critic_output_size), std=1.)
        self.activation = nn.ReLU()

        # Built last so it sees every parameter registered above.
        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr, eps=1e-5)

    def forward(self, x):
        """Return value estimates (last dimension of size 1) for observations ``x``."""
        features = self.l1(x)
        features = self.activation(features)
        features = self.activation(self.l2(features))
        return self.critic_head(features)

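Reference sketch (not executed) of garage's MAML outer loop, kept for comparison. It pastes the per-epoch loop and the meta-loss helper together, so the `return torch.stack(losses).mean()` in the middle belongs to the helper. The lines after it continue the epoch loop.

```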
for _ in trainer.step_epochs():

    tasks = self._task_sampler.sample(self._meta_batch_size)
    theta = dict(self._policy.named_parameters())
    for i, env_up in enumerate(tasks):
        clone = self.policy.clone()
        for j in range(self._num_grad_updates + 1):
            episodes = trainer.obtain_episodes(trainer.step_itr,
                                                env_update=env_up)
            batch_samples = self._process_samples(episodes)
            all_samples[i].append(batch_samples)

            # The last iteration does only sampling but no adapting
            if j < self._num_grad_updates:
                # A grad need to be kept for the next grad update
                # Except for the last grad update
                require_grad = j < self._num_grad_updates - 1

                self._adapt(batch_samples, set_grad=require_grad)
                #################################################
                loss = self._inner_algo._compute_loss(*batch_samples[1:])
                self._inner_optimizer.set_grads_none()
                loss.backward(create_graph=set_grad)
                self._inner_optimizer.step()

        all_params.append(dict(self._policy.named_parameters()))
        # Restore to pre-updated policy
        update_module_params(self._policy, theta)


    meta_objective = self._compute_meta_loss(all_samples, all_params)
    ##########
    theta = dict(self._policy.named_parameters())
    old_theta = dict(self._old_policy.named_parameters())

    losses = []
    for task_samples, task_params in zip(all_samples, all_params):
        for i in range(self._num_grad_updates):
            require_grad = i < self._num_grad_updates - 1 or set_grad
            self._adapt(task_samples[i], set_grad=require_grad)

        update_module_params(self._old_policy, task_params)
        with torch.set_grad_enabled(set_grad):
            # pylint: disable=protected-access
            last_update = task_samples[-1]
            loss = self._inner_algo._compute_loss(*last_update[1:])
        losses.append(loss)

        update_module_params(self._policy, theta)
        update_module_params(self._old_policy, old_theta)

    return torch.stack(losses).mean()

    zero_optim_grads(self._meta_optimizer)
    meta_objective.backward()
    self._meta_optimize(all_samples, all_params)

    trainer.step_itr += 1
```

In [None]:
class MAMLPPO:
    """MAML meta-training that uses PPO's clipped surrogate as its objective.

    Each meta-iteration samples ``meta_batch_size`` tasks. For every task, the
    meta policy is cloned and adapted on freshly collected episodes. The adapted
    clone is then evaluated with the PPO loss on a second batch of episodes. The
    averaged validation loss is back-propagated through the adaptation steps into
    the meta policy.

    Args:
        env: gym environment exposing ``sample_tasks`` and ``set_task``.
        actor_class: policy constructor, called as
            ``actor_class(env, lr=meta_lr, **actor_args)``.
        critic_class: baseline constructor, called as
            ``critic_class(env, **critic_args)``.
        actor_args, critic_args: extra keyword arguments (``None`` = none).
        adapt_lr: inner-loop step size given to ``l2l.algorithms.MAML``.
        meta_lr: Adam learning rate of the meta optimizer.
        adapt_steps: collect-then-adapt rounds per task.
        ppo_steps: PPO gradient steps per adaptation round.
        adapt_batch_size: episodes collected per adaptation/validation batch.
        meta_batch_size: tasks per meta-iteration.
        gamma, tau: discount and GAE lambda.
        policy_clip: PPO ratio clip.
        value_clip: optional value-loss clip (falsy -> unclipped A2C value loss).
        seed: seed for Python/NumPy/torch RNGs (and the env when it supports
            ``seed``); ``None`` leaves the RNGs unseeded.
        device: torch device for the networks.
        name: run name, stored as ``self.name``.
        tensorboard_log: currently unused.
    """

    def __init__(self, env,
                 actor_class=Actor, critic_class=Critic,
                 actor_args=None, critic_args=None,
                 adapt_lr=1e-2, meta_lr=1e-2,
                 adapt_steps=1, ppo_steps=5,
                 adapt_batch_size=20, meta_batch_size=20,
                 gamma=0.99, tau=1.0,
                 policy_clip=0.2, value_clip=None,
                 seed=42,
                 device='cpu', name="MAML-PPO", tensorboard_log=None):

        self.device = torch.device(device)
        self.name = name

        # Previously ``seed`` was accepted but ignored. Seed every global RNG the
        # loop draws from (weight init, action sampling, and presumably the
        # env's task sampler). Pass seed=None to opt out.
        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            if self.device.type == 'cuda':
                torch.cuda.manual_seed(seed)

        env = ch.envs.ActionSpaceScaler(env)
        if seed is not None and hasattr(env, 'seed'):
            env.seed(seed)
        env.set_task(env.sample_tasks(1)[0])
        self.env = ch.envs.Torch(env)

        self.gamma = gamma
        self.tau = tau
        self.adapt_lr = adapt_lr
        self.meta_lr = meta_lr
        self.adapt_steps = adapt_steps
        self.adapt_batch_size = adapt_batch_size
        self.meta_batch_size = meta_batch_size
        self.policy_clip = policy_clip
        self.value_clip = value_clip
        self.ppo_steps = ppo_steps

        # Honour the injectable constructors (they used to be ignored).
        policy = actor_class(env, lr=meta_lr, **(actor_args or {})).to(self.device)
        self.policy = l2l.algorithms.MAML(policy, lr=adapt_lr)
        self.meta_optimizer = torch.optim.Adam(self.policy.parameters(), lr=meta_lr)
        self.baseline = critic_class(env, **(critic_args or {})).to(self.device)

    def collect_steps(self, policy, n_episodes):
        """Roll out ``n_episodes`` on the current task and fit the baseline once.

        Each stored transition gets ``returns`` (GAE advantage + baseline value)
        and a normalised ``advantage``. After annotation, the baseline takes one
        optimizer step on the value loss.

        Returns:
            The filled ``ch.ExperienceReplay``.
        """
        replay = ch.ExperienceReplay(device=self.device)
        for _ in range(n_episodes):
            state = self.env.reset()

            while True:
                with torch.no_grad():
                    mass = policy(state)
                action = mass.sample()
                log_prob = mass.log_prob(action).mean(dim=1, keepdim=True)
                next_state, reward, done, _ = self.env.step(action)

                replay.append(state,
                              action,
                              reward,
                              next_state,
                              done,
                              log_prob=log_prob)

                if done:
                    break

                state = next_state

        with torch.no_grad():
            next_state_value = self.baseline(replay[-1].next_state)
        values = self.baseline(replay.state())

        advantages = ch.generalized_advantage(self.gamma,
                                              self.tau,
                                              replay.reward(),
                                              replay.done(),
                                              values.detach(),
                                              next_state_value)
        returns = advantages + values.detach()
        advantages = ch.normalize(advantages, epsilon=1e-8)

        for i, sars in enumerate(replay):
            sars.returns = returns[i]
            sars.advantage = advantages[i]

        if self.value_clip:
            # No per-transition 'value' is stored in the replay (replay.value()
            # would raise). The pre-update predictions are the old values.
            value_loss = ppo.state_value_loss(values,
                                              values.detach(),
                                              returns,
                                              clip=self.value_clip)
        else:
            value_loss = a2c.state_value_loss(values, returns)

        self.baseline.optimizer.zero_grad()
        value_loss.backward()
        self.baseline.optimizer.step()

        return replay

    def _policy_loss(self, policy, episodes):
        """PPO clipped-surrogate loss of ``policy`` on ``episodes`` (to minimise)."""
        density = policy(episodes.state())
        new_log_probs = density.log_prob(episodes.action()).mean(dim=1, keepdim=True)
        return ppo.policy_loss(new_log_probs,
                               episodes.log_prob(),
                               episodes.advantage(),
                               clip=self.policy_clip)

    def fast_adapt(self, clone, train_episodes):
        """Take ``ppo_steps`` differentiable MAML steps of ``clone`` on ``train_episodes``."""
        for _ in range(self.ppo_steps):
            clone.adapt(self._policy_loss(clone, train_episodes))

    def train(self, num_iterations=100):
        """Run ``num_iterations`` meta-updates and print the mean adapted return of each."""
        for iteration in range(num_iterations):
            iteration_reward = 0.0
            iter_loss = 0.0

            for task_config in tqdm(self.env.sample_tasks(self.meta_batch_size), leave=False, desc='Data'):
                clone = self.policy.clone()
                self.env.set_task(task_config)

                # Fast adapt
                for _ in range(self.adapt_steps):
                    train_episodes = self.collect_steps(clone, n_episodes=self.adapt_batch_size)
                    self.fast_adapt(clone, train_episodes)

                # Validation loss of the adapted clone
                valid_episodes = self.collect_steps(clone, n_episodes=self.adapt_batch_size)
                iteration_reward += valid_episodes.reward().sum().item() / self.adapt_batch_size
                iter_loss += self._policy_loss(clone, valid_episodes)

            print('\nIteration', iteration)
            adaptation_reward = iteration_reward / self.meta_batch_size
            print('adaptation_reward', adaptation_reward)

            av_loss = iter_loss / self.meta_batch_size

            # Clear the previous iteration's meta-gradients. Without this, .grad
            # accumulates across iterations and the meta-update eventually diverges.
            self.meta_optimizer.zero_grad()
            av_loss.backward()
            self.meta_optimizer.step()


In [None]:
# Meta-RL environment; MAMLPPO relies on it exposing sample_tasks()/set_task().
# NOTE(review): the ID is presumably registered by the learn2learn import — confirm.
env = gym.make('Particles2D-v1')

In [None]:
# Build the MAML-PPO meta-learner with its default hyper-parameters.
aa = MAMLPPO(env)



In [None]:
# 100 meta-iterations; each prints the per-episode validation return averaged over tasks.
aa.train(100)




Iteration 0
adaptation_reward -72.14244812011718





Iteration 1
adaptation_reward -64.69322326660156





Iteration 2
adaptation_reward -60.82671463012696





Iteration 3
adaptation_reward -51.2676643371582





Iteration 4
adaptation_reward -46.76594894409181





Iteration 5
adaptation_reward -44.23233413696289





Iteration 6
adaptation_reward -43.03200332641602





Iteration 7
adaptation_reward -46.263327941894524





Iteration 8
adaptation_reward -42.709817886352546





Iteration 9
adaptation_reward -38.209377212524416





Iteration 10
adaptation_reward -34.073232421875





Iteration 11
adaptation_reward -35.51377508163452





Iteration 12
adaptation_reward -39.245884017944334





Iteration 13
adaptation_reward -44.06808357238769





Iteration 14
adaptation_reward -31.109482975006095





Iteration 15
adaptation_reward -44.82811918258666





Iteration 16
adaptation_reward -49.77313026428222





Iteration 17
adaptation_reward -55.34247619628907





Iteration 18
adaptation_reward -51.14251829147337





Iteration 19
adaptation_reward -57.62670774459838





Iteration 20
adaptation_reward -53.029380874633794





Iteration 21
adaptation_reward -55.511291732788074





Iteration 22
adaptation_reward -57.05236518859863





Iteration 23
adaptation_reward -55.101924133300784





Iteration 24
adaptation_reward -54.09776824951172





Iteration 25
adaptation_reward -47.6008623123169





Iteration 26
adaptation_reward -48.87730895996094





Iteration 27
adaptation_reward -46.724959716796874





Iteration 28
adaptation_reward -53.530391159057615





Iteration 29
adaptation_reward -43.14779594421387





Iteration 30
adaptation_reward -38.936241912841794





Iteration 31
adaptation_reward -46.90168914794922





Iteration 32
adaptation_reward -41.86559139251709





Iteration 33
adaptation_reward -44.11037166595459





Iteration 34
adaptation_reward -37.805352020263676





Iteration 35
adaptation_reward -36.87970939636231





Iteration 36
adaptation_reward -44.461333007812506





Iteration 37
adaptation_reward -38.69464485168457





Iteration 38
adaptation_reward -48.24568592071533





Iteration 39
adaptation_reward -54.646986465454106





Iteration 40
adaptation_reward -58.7583726501465





Iteration 41
adaptation_reward -68.95524559020997





Iteration 42
adaptation_reward -66.85666732788087





Iteration 43
adaptation_reward -81.63729476928712





Iteration 44
adaptation_reward -116.36980072021485





Iteration 45
adaptation_reward -148.4190576171875





Iteration 46
adaptation_reward -529.273818359375





Iteration 47
adaptation_reward -545.0820068359374





Iteration 48
adaptation_reward -521.9780322265626





Iteration 49
adaptation_reward -534.7763208007813





Iteration 50
adaptation_reward -534.2482788085938





Iteration 51
adaptation_reward -548.0662548828125





Iteration 52
adaptation_reward -436.07673583984376





Iteration 53
adaptation_reward -311.8321166992188





Iteration 54
adaptation_reward -357.8904663085938





Iteration 55
adaptation_reward -485.66236083984387





Iteration 56
adaptation_reward -627.297911376953





Iteration 57
adaptation_reward -660.2115942382814





Iteration 58
adaptation_reward -668.045947265625





Iteration 59
adaptation_reward -670.2789208984377





Iteration 60
adaptation_reward -671.3915136718749





Iteration 61
adaptation_reward -666.0001757812499





Iteration 62
adaptation_reward -675.0311645507815





Iteration 63
adaptation_reward -667.0026855468749





Iteration 64
adaptation_reward -634.3038525390625





Iteration 65
adaptation_reward -618.6332006835938





Iteration 66
adaptation_reward -599.4000756835937





Iteration 67
adaptation_reward -583.4839428710937





Iteration 68
adaptation_reward -560.2422021484374





Iteration 69
adaptation_reward -543.1413183593748





Iteration 70
adaptation_reward -536.8820166015624





Iteration 71
adaptation_reward -535.0161376953126





Iteration 72
adaptation_reward -519.8163281249999





Iteration 73
adaptation_reward -523.6207202148438





Iteration 74
adaptation_reward -525.7329833984375





Iteration 75
adaptation_reward -522.8726367187501





Iteration 76
adaptation_reward -519.711171875





Iteration 77
adaptation_reward -525.5798388671875





Iteration 78
adaptation_reward -531.2727221679687





Iteration 79
adaptation_reward -521.0558715820313





Iteration 80
adaptation_reward -509.23310791015626





Iteration 81
adaptation_reward -508.65889160156246





Iteration 82
adaptation_reward -520.7678686523439





Iteration 83
adaptation_reward -509.2552709960938





Iteration 84
adaptation_reward -500.5760253906251





Iteration 85
adaptation_reward -516.3190454101561





Iteration 86
adaptation_reward -507.88580078125017





Iteration 87
adaptation_reward -513.3900317382812





Iteration 88
adaptation_reward -495.6379013293981





Iteration 89
adaptation_reward -521.7659228515626





Iteration 90
adaptation_reward -526.4571997070312





Iteration 91
adaptation_reward -517.4612158203125





Iteration 92
adaptation_reward -514.4419921875





Iteration 93
adaptation_reward -512.3902294921875





Iteration 94
adaptation_reward -503.57342041015636





Iteration 95
adaptation_reward -514.4265747070312





Iteration 96
adaptation_reward -501.15360351562504





Iteration 97
adaptation_reward -506.7809838867187





Iteration 98
adaptation_reward -510.89814453125





Iteration 99
adaptation_reward -515.21685546875
