In [None]:
# System packages for (free-)mujoco-py: OpenGL / GLEW / OSMesa development
# libraries and software-properties-common, plus patchelf in a separate call.
!apt-get install -y \
    libgl1-mesa-dev \
    libgl1-mesa-glx \
    libglew-dev \
    libosmesa6-dev \
    software-properties-common

!apt-get install -y patchelf
# MuJoCo Python bindings; presumably required by the HalfCheetah meta-env used below.
!pip install free-mujoco-py

Reading package lists... Done
Building dependency tree       
Reading state information... Done
libglew-dev is already the newest version (2.0.0-5).
libgl1-mesa-dev is already the newest version (20.0.8-0ubuntu1~18.04.1).
libgl1-mesa-glx is already the newest version (20.0.8-0ubuntu1~18.04.1).
libosmesa6-dev is already the newest version (20.0.8-0ubuntu1~18.04.1).
software-properties-common is already the newest version (0.96.24.32.18).
The following package was automatically installed and is no longer required:
  libnvidia-common-460
Use 'apt autoremove' to remove it.
0 upgraded, 0 newly installed, 0 to remove and 19 not upgraded.
Reading package lists... Done
Building dependency tree       
Reading state information... Done
patchelf is already the newest version (0.9-1).
The following package was automatically installed and is no longer required:
  libnvidia-common-460
Use 'apt autoremove' to remove it.
0 upgraded, 0 newly installed, 0 to remove and 19 not upgraded.
Looking in indexe

## MAML-PPO with cherry and learn2learn

In [None]:
# %pip installs into the kernel's own environment. -q trims the log but, unlike
# redirecting everything to /dev/null, still surfaces errors if the install fails.
%pip install -q cherry-rl learn2learn

In [None]:
import random
import math
import time

from copy import deepcopy

import cherry as ch
import gym
import numpy as np
import torch
from cherry.algorithms import a2c, ppo
from cherry.models.robotics import LinearValue
from tqdm import tqdm

import learn2learn as l2l

import torch as th
import torch.nn as nn
from torch import autograd
from torch.distributions.kl import kl_divergence
from torch.nn.utils import parameters_to_vector, vector_to_parameters
from torch.distributions import Normal, Categorical
from torch.utils.tensorboard import SummaryWriter

In [None]:
# Sanity check: the meta-env id resolves (presumably registered by the
# learn2learn import above) and reset() yields an initial observation.
env_id = 'HalfCheetahForwardBackward-v1'
env = gym.make(env_id)
env.reset()

array([ 0.00295027,  0.00932332, -0.09128403, -0.06178263,  0.07330716,
       -0.09275652,  0.02119905, -0.03270184,  0.11129588, -0.0150656 ,
       -0.05936605,  0.03230394,  0.06749459,  0.17624147, -0.08633809,
        0.081727  ,  0.01158038,  0.03684237,  0.        ,  0.7029503 ],
      dtype=float32)

In [None]:
EPSILON = 1e-6  # lower bound on the policy's standard deviation (keeps log-probs finite)


def linear_init(module):
    """Xavier-uniform initialise an ``nn.Linear`` in place; other modules pass through.

    Args:
        module: any ``nn.Module``. Only ``nn.Linear`` instances are modified.

    Returns:
        The same ``module`` object, so the call can be used inline when building layers.
    """
    if isinstance(module, nn.Linear):
        nn.init.xavier_uniform_(module.weight)
        # nn.Linear(..., bias=False) has ``bias is None``; zeroing it unconditionally
        # raised AttributeError.
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    return module


class DiagNormalPolicy(nn.Module):
    """Gaussian policy: MLP mean and a learned, state-independent diagonal std.

    The mean is an MLP ``input_size -> hiddens... -> output_size``. The standard
    deviation is ``exp(sigma)``, where ``sigma`` is one learned parameter per
    action dimension, clamped from below at ``log(EPSILON)``.

    Args:
        input_size: dimensionality of the state vector.
        output_size: dimensionality of the action vector.
        hiddens: hidden-layer widths; defaults to ``[100, 100]``. Must be non-empty.
        activation: ``'relu'``, ``'tanh'``, or an ``nn.Module`` subclass that is
            instantiated after every hidden layer.
        device: device that incoming states are moved to before the forward pass.

    Raises:
        ValueError: if ``hiddens`` is empty or ``activation`` is an unknown string.
    """

    _ACTIVATIONS = {'relu': nn.ReLU, 'tanh': nn.Tanh}

    def __init__(self, input_size, output_size, hiddens=None, activation='relu', device='cpu'):
        super(DiagNormalPolicy, self).__init__()
        self.device = device
        if hiddens is None:
            hiddens = [100, 100]
        if not hiddens:
            raise ValueError('hiddens must contain at least one layer width')
        if isinstance(activation, str):
            if activation not in self._ACTIVATIONS:
                raise ValueError(
                    f"unknown activation {activation!r}; expected one of "
                    f"{sorted(self._ACTIVATIONS)} or an nn.Module subclass")
            activation = self._ACTIVATIONS[activation]
        layers = [linear_init(nn.Linear(input_size, hiddens[0])), activation()]
        for i, o in zip(hiddens[:-1], hiddens[1:]):
            layers.append(linear_init(nn.Linear(i, o)))
            layers.append(activation())
        layers.append(linear_init(nn.Linear(hiddens[-1], output_size)))
        self.mean = nn.Sequential(*layers)
        # Log-std starts at log(1) = 0, i.e. unit standard deviation per action dim.
        self.sigma = nn.Parameter(torch.full((output_size,), math.log(1.0)))

    def density(self, state):
        """Return the action distribution ``Normal(mean(state), exp(clamped sigma))``."""
        state = state.to(self.device, non_blocking=True)
        loc = self.mean(state)
        scale = torch.exp(torch.clamp(self.sigma, min=math.log(EPSILON)))
        return Normal(loc=loc, scale=scale)

    def log_prob(self, state, action):
        """Per-sample log-density of ``action``, AVERAGED (not summed) over action dims.

        Assumes batched inputs, e.g. state (batch, state_dim); returns (batch, 1).
        """
        density = self.density(state)
        return density.log_prob(action).mean(dim=1, keepdim=True)

    def forward(self, state):
        """Sample an action; ``Normal.sample`` is not reparameterised, so no grad flows to it."""
        density = self.density(state)
        action = density.sample()
        return action

In [None]:
def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Orthogonally initialise ``layer.weight`` (gain ``std``) and fill its bias.

    Args:
        layer: module with a ``weight`` tensor and an optional ``bias`` (e.g.
            ``nn.Linear``); modified in place.
        std: gain passed to ``torch.nn.init.orthogonal_``. The default sqrt(2)
            equals ``torch.nn.init.calculate_gain('relu')``.
        bias_const: value written into every bias entry.

    Returns:
        The same ``layer``, so the call can be used inline.
    """
    torch.nn.init.orthogonal_(layer.weight, std)
    # Layers built with bias=False have ``bias is None``; skip instead of crashing.
    if getattr(layer, 'bias', None) is not None:
        torch.nn.init.constant_(layer.bias, bias_const)
    return layer

In [None]:
class Actor(nn.Module):
    """MLP actor: two ReLU hidden layers feeding cherry's ``ActionDistribution``.

    The module owns its Adam optimizer as ``self.optimizer``.

    Args:
        env: environment whose ``observation_space.shape[0]`` and
            ``action_space.shape[0]`` size the input and output layers; it is
            also handed to ``ActionDistribution``.
        lr: Adam learning rate.
        hidden_size: width of both hidden layers.
    """

    def __init__(self, env, lr, hidden_size=100):
        super().__init__()
        obs_dim = env.observation_space.shape[0]
        act_dim = env.action_space.shape[0]
        self.input_size, self.actor_output_size = obs_dim, act_dim

        # Keep the l1 -> l2 -> output construction order: it fixes parameter order
        # and the sequence of random draws made by the orthogonal initialiser.
        self.l1 = layer_init(nn.Linear(obs_dim, hidden_size))
        self.l2 = layer_init(nn.Linear(hidden_size, hidden_size))
        # Gain 0.01 keeps the initial outputs fed to the distribution small.
        self.output = layer_init(nn.Linear(hidden_size, act_dim), std=0.01)
        self.activation = nn.ReLU()
        self.distribution = ch.distributions.ActionDistribution(env)

        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr, eps=1e-5)

    def forward(self, x):
        """Map a batch of states to the distribution built by ``self.distribution``."""
        for hidden_layer in (self.l1, self.l2):
            x = self.activation(hidden_layer(x))
        return self.distribution(self.output(x))

In [None]:
class Critic(nn.Module):
    """State-value network: two ReLU hidden layers and a single-output linear head.

    The module owns its Adam optimizer as ``self.optimizer``.

    Args:
        env: environment whose ``observation_space.shape[0]`` sizes the input layer.
        lr: Adam learning rate.
        hidden_size: width of both hidden layers.
    """

    def __init__(self, env, lr=0.01, hidden_size=32):
        super().__init__()
        obs_dim = env.observation_space.shape[0]
        self.input_size = obs_dim
        self.critic_output_size = 1

        # Keep the l1 -> l2 -> critic_head construction order: it fixes parameter
        # order and the sequence of random draws made by the orthogonal initialiser.
        self.l1 = layer_init(nn.Linear(obs_dim, hidden_size))
        self.l2 = layer_init(nn.Linear(hidden_size, hidden_size))
        self.critic_head = layer_init(nn.Linear(hidden_size, self.critic_output_size), std=1.)
        self.activation = nn.ReLU()

        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr, eps=1e-5)

    def forward(self, x):
        """Return value estimates for state batch ``x`` (last dimension of size 1)."""
        features = x
        for hidden_layer in (self.l1, self.l2):
            features = self.activation(hidden_layer(features))
        return self.critic_head(features)

In [None]:
class MAMLPPO():
    """MAML with a PPO-clipped meta-objective on a learn2learn meta-environment.

    Inner loop: ``adapt_steps`` first-order policy-gradient (A2C-style) updates per
    task, each on freshly collected episodes. Outer loop: ``ppo_steps`` epochs of the
    PPO clipped surrogate on the post-adaptation (validation) episodes. Each epoch is
    differentiated through a second-order re-adaptation of the meta-policy.

    Args:
        env_name: gym id of a meta-env exposing ``sample_tasks`` / ``set_task``.
        actor_class, critic_class, actor_args, critic_args: accepted for API
            compatibility but currently unused. The policy is a ``DiagNormalPolicy``
            and the baseline a cherry ``LinearValue``.
        adapt_lr: inner-loop step size.
        meta_lr: Adam learning rate for the meta-policy.
        adapt_steps: inner-loop updates per task.
        ppo_steps: PPO epochs per meta-iteration.
        adapt_batch_size: episodes collected per inner step (and for validation).
        meta_batch_size: tasks sampled per meta-iteration.
        gamma: discount factor.
        tau: GAE lambda.
        policy_clip: PPO ratio clip.
        value_clip: stored but currently unused.
        num_workers: number of parallel environments in the vectorised env.
        seed: seed for ``random``, numpy, torch (and CUDA when available) and the env.
        device: explicit device string; otherwise CUDA when available, else CPU.
        name: run-name prefix for TensorBoard.
        tensorboard_log: TensorBoard root directory, or ``None`` to disable logging.
    """

    def __init__(self, env_name,
                 actor_class=Actor, critic_class=Critic,
                 actor_args=None, critic_args=None,
                 adapt_lr=1e-1, meta_lr=1e-2,
                 adapt_steps=3, ppo_steps=5,
                 adapt_batch_size=20, meta_batch_size=20,
                 gamma=0.99, tau=1.0,
                 policy_clip=0.2, value_clip=None,
                 num_workers=10,
                 seed=42,
                 device=None, name="MAMLPPO", tensorboard_log="./logs"):
        # actor_args / critic_args default to None rather than a shared mutable dict().

        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

        if torch.cuda.is_available():
            self.device = torch.device("cuda")
            torch.cuda.manual_seed(seed)
        else:
            self.device = torch.device("cpu")
        if device:
            self.device = torch.device(device)
        print("Running on: " + str(self.device))

        def make_env():
            env = gym.make(env_name)
            env = ch.envs.ActionSpaceScaler(env)
            return env

        env = l2l.gym.AsyncVectorEnv([make_env for _ in range(num_workers)])
        env.seed(seed)
        env.set_task(env.sample_tasks(1)[0])
        self.env = ch.envs.Torch(env)

        self.gamma = gamma
        self.tau = tau
        self.adapt_lr = adapt_lr
        self.meta_lr = meta_lr
        self.adapt_steps = adapt_steps
        self.adapt_batch_size = adapt_batch_size
        self.meta_batch_size = meta_batch_size
        self.policy_clip = policy_clip
        self.value_clip = value_clip
        self.ppo_steps = ppo_steps
        self.global_iteration = 0

        self.policy = DiagNormalPolicy(self.env.state_size, self.env.action_size, device=self.device)
        # NOTE(review): cherry's LinearValue appears to take (input_size, reg=...), in
        # which case action_size is being passed as the ridge regulariser — confirm intent.
        self.baseline = LinearValue(self.env.state_size, self.env.action_size)

        self.policy.to(self.device)
        self.baseline.to(self.device)

        self.optimizer = torch.optim.Adam(self.policy.parameters(), meta_lr)

        if tensorboard_log is not None:
            self.run_name = name + "_" + str(int(time.time()))
            self.writer = SummaryWriter(f"{tensorboard_log}/{self.run_name}")
        else:
            self.writer = None

    def save(self, path="./"):
        """Write baseline and policy state dicts to ``path``/baseline.pt and ``path``/policy.pt."""
        torch.save(self.baseline.state_dict(), path + "/baseline.pt")
        torch.save(self.policy.state_dict(), path + "/policy.pt")

    def load(self, path="./"):
        """Load state dicts written by :meth:`save`.

        ``map_location`` lets a checkpoint saved on a GPU machine load on a CPU one
        (and vice versa).
        """
        self.baseline.load_state_dict(torch.load(path + "/baseline.pt", map_location=self.device))
        self.policy.load_state_dict(torch.load(path + "/policy.pt", map_location=self.device))

    def collect_steps(self, policy, n_episodes):
        """Roll out ``policy`` for ``n_episodes`` episodes and annotate each transition.

        The shared linear baseline is refit on this batch's discounted returns before
        it is used for GAE. Each transition then gets three attributes:

        - ``returns``: the GAE return.
        - ``advantage``: the normalised advantage.
        - ``log_prob``: the detached log-prob of its action under ``policy``.

        Returns:
            The annotated cherry replay, on ``self.device``.
        """
        self.env.reset()
        runner = ch.envs.Runner(self.env)
        replay = runner.run(policy, episodes=n_episodes).to(self.device)

        states = replay.state()
        actions = replay.action()
        rewards = replay.reward()
        dones = replay.done()

        # The baseline is shared across tasks. Without refitting first, GAE would
        # use values fit to the previous batch, which is often a different task.
        self.baseline.fit(states, ch.td.discount(self.gamma, rewards, dones))

        with torch.no_grad():
            values = self.baseline(states)
            next_state_value = self.baseline(replay[-1].next_state)
            log_probs = policy.density(states).log_prob(actions).mean(dim=1, keepdim=True)

        advantages = ch.generalized_advantage(self.gamma,
                                              self.tau,
                                              rewards,
                                              dones,
                                              values,
                                              next_state_value)
        returns = advantages + values
        advantages = ch.normalize(advantages, epsilon=1e-8)

        for i, sars in enumerate(replay):
            sars.returns = returns[i]
            sars.advantage = advantages[i]
            sars.log_prob = log_probs[i]
        return replay

    def maml_a2c_loss(self, train_episodes, learner):
        """Vanilla policy-gradient loss of ``learner`` on the replay's stored advantages."""
        density = learner.density(train_episodes.state())
        log_probs = density.log_prob(train_episodes.action()).mean(dim=1, keepdim=True)
        return a2c.policy_loss(log_probs, train_episodes.advantage())

    def fast_adapt(self, clone, train_episodes, first_order=False):
        """One MAML inner step on ``clone``; returns the updated module.

        When ``first_order`` is False the graph is kept (``create_graph=True``), so
        later losses can be differentiated through this update.
        """
        second_order = not first_order
        loss = self.maml_a2c_loss(train_episodes, clone)
        gradients = autograd.grad(loss,
                                  clone.parameters(),
                                  retain_graph=second_order,
                                  create_graph=second_order)
        return l2l.algorithms.maml.maml_update(clone, self.adapt_lr, gradients)

    def meta_loss(self, iteration_replays, iteration_policies, policy):
        """Task-averaged PPO clipped surrogate of ``policy`` after re-adaptation.

        For each task, ``policy`` is cloned and re-adapted on that task's training
        replays with second-order gradients. The loss is therefore differentiable
        with respect to the meta-parameters. The ratio's denominator comes from the
        adapted policy that collected the validation episodes.
        """
        mean_loss = 0.0
        for task_replays, old_policy in tqdm(zip(iteration_replays, iteration_policies),
                                             total=len(iteration_replays),
                                             desc='Surrogate Loss',
                                             leave=False):
            train_replays = task_replays[:-1]
            valid_episodes = task_replays[-1]
            new_policy = l2l.clone_module(policy)

            # Re-adapt the meta-policy, keeping the graph for the meta-gradient.
            for train_episodes in train_replays:
                new_policy = self.fast_adapt(new_policy, train_episodes, first_order=False)

            states = valid_episodes.state()
            actions = valid_episodes.action()
            advantages = valid_episodes.advantage()

            # The behaviour policy is a constant in the PPO ratio; no graph needed.
            with torch.no_grad():
                old_log_probs = old_policy.density(states).log_prob(actions).mean(dim=1, keepdim=True)
            new_log_probs = new_policy.density(states).log_prob(actions).mean(dim=1, keepdim=True)
            mean_loss += ppo.policy_loss(new_log_probs,
                                         old_log_probs,
                                         advantages,
                                         clip=self.policy_clip)
        mean_loss /= len(iteration_replays)
        return mean_loss

    def meta_optimize(self, iteration_replays, iteration_policies):
        """Run ``ppo_steps`` meta-gradient steps and log the last loss to TensorBoard."""
        loss = None
        for _ in range(self.ppo_steps):
            loss = self.meta_loss(iteration_replays, iteration_policies, self.policy)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        # ``loss`` stays None when ppo_steps == 0. Log a plain float, not a graph tensor.
        if self.writer is not None and loss is not None:
            self.writer.add_scalar("loss", loss.item(), self.global_iteration)

    def train(self, num_iterations=100):
        """Run ``num_iterations`` meta-iterations: collect and adapt per task, then meta-update.

        Prints (and logs to TensorBoard, if enabled) ``adaptation_reward``. This is
        the validation-episode reward sum divided by ``adapt_batch_size``, averaged
        over the ``meta_batch_size`` tasks.
        """
        for iteration in range(num_iterations):
            self.global_iteration += 1
            iteration_reward = 0.0
            iteration_replays = []
            iteration_policies = []

            for task_config in tqdm(self.env.sample_tasks(self.meta_batch_size), leave=False, desc='Data'):
                clone = deepcopy(self.policy)
                self.env.set_task(task_config)
                task_replay = []

                # Fast Adapt (first-order; only used to collect data)
                for step in range(self.adapt_steps):
                    train_episodes = self.collect_steps(clone, n_episodes=self.adapt_batch_size)
                    clone = self.fast_adapt(clone, train_episodes, first_order=True)
                    task_replay.append(train_episodes)

                # Validation episodes from the adapted policy
                valid_episodes = self.collect_steps(clone, n_episodes=self.adapt_batch_size)
                task_replay.append(valid_episodes)
                iteration_reward += valid_episodes.reward().sum().item() / self.adapt_batch_size
                iteration_replays.append(task_replay)
                iteration_policies.append(clone)

            # Print statistics
            print('\nIteration', self.global_iteration)
            adaptation_reward = iteration_reward / self.meta_batch_size
            print('adaptation_reward', adaptation_reward)

            if self.writer is not None:
                self.writer.add_scalar("adaptation_reward", adaptation_reward, self.global_iteration)

            self.meta_optimize(iteration_replays, iteration_policies)


In [None]:
# Build the meta-learner (vectorised envs, policy, baseline, TensorBoard writer).
aa = MAMLPPO(env_name='HalfCheetahForwardBackward-v1')

Running on: cuda




In [None]:
# Meta-train; each iteration prints the post-adaptation reward.
aa.train(num_iterations=100)

torch.linalg.lstsq has reversed arguments and does not return the QR decomposition in the returned tuple (although it returns other information about the problem).
To get the qr decomposition consider using torch.linalg.qr.
The returned solution in torch.lstsq stored the residuals of the solution in the last m - n columns of the returned value whenever m > n. In torch.linalg.lstsq, the residuals in the field 'residuals' of the returned named tuple.
The unpacking of the solution, as in
X, _ = torch.lstsq(B, A).solution[:A.size(1)]
should be replaced with
X = torch.linalg.lstsq(A, B).solution (Triggered internally at  ../aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp:3276.)
  coeffs, _ = th.lstsq(b, A)



Iteration 1
adaptation_reward -15.239832963943481





Iteration 2
adaptation_reward -21.80439588546753





Iteration 3
adaptation_reward -14.98247666358948





Iteration 4
adaptation_reward -26.02806373596191





Iteration 5
adaptation_reward -40.9422484588623





Iteration 6
adaptation_reward -21.61346605300903





Iteration 7
adaptation_reward -22.95705913543701





Iteration 8
adaptation_reward -19.712455844879152





Iteration 9
adaptation_reward -19.058130903244017





Iteration 10
adaptation_reward -16.015074186325073





Iteration 11
adaptation_reward -19.549395561218265





Iteration 12
adaptation_reward -25.23745170593262





Iteration 13
adaptation_reward -16.58354446411133





Iteration 14
adaptation_reward -20.103928909301754





Iteration 15
adaptation_reward -30.201771316528323





Iteration 16
adaptation_reward -21.809322128295896





Iteration 17
adaptation_reward -20.62111457824707





Iteration 18
adaptation_reward -18.594097137451172





Iteration 19
adaptation_reward -15.111018524169916





Iteration 20
adaptation_reward -53.71457153320313





Iteration 21
adaptation_reward -23.496156616210936





Iteration 22
adaptation_reward -32.101380615234376





Iteration 23
adaptation_reward -20.220051879882813





Iteration 24
adaptation_reward -22.165815582275393





Iteration 25
adaptation_reward -8.222628173828124





Iteration 26
adaptation_reward -37.30268112182617





Iteration 27
adaptation_reward -14.722551269531252





Iteration 28
adaptation_reward -32.08169250488281





Iteration 29
adaptation_reward -33.24332565307617





Iteration 30
adaptation_reward -33.2961939239502





Iteration 31
adaptation_reward -22.374862670898438





Iteration 32
adaptation_reward -22.475312881469726





Iteration 33
adaptation_reward -9.784945907592775





Iteration 34
adaptation_reward -44.5146508026123





Iteration 35
adaptation_reward -37.16360359191894





Iteration 36
adaptation_reward -16.57937843322754





Iteration 37
adaptation_reward -15.947536010742189





Iteration 38
adaptation_reward -23.439432907104496





Iteration 39
adaptation_reward -25.786630191802978





Iteration 40
adaptation_reward -4.652285308837891





Iteration 41
adaptation_reward -16.95811584472656





Iteration 42
adaptation_reward -13.229454650878903





Iteration 43
adaptation_reward -9.597256927490236





Iteration 44
adaptation_reward -14.462917175292969





Iteration 45
adaptation_reward -15.42728607177734





Iteration 46
adaptation_reward -24.236006469726554





Iteration 47
adaptation_reward -16.12834014892578





Iteration 48
adaptation_reward -4.841907653808594





Iteration 49
adaptation_reward -65.41198593139649





Iteration 50
adaptation_reward -13.75320785522461





Iteration 51
adaptation_reward -20.569122009277343





Iteration 52
adaptation_reward -0.9126600646972678


Data:  20%|██        | 4/20 [00:21<01:25,  5.37s/it]