# PEARL TRAINING CODE

In [63]:
import copy
import numpy as np

import torch
from torch import nn as nn
import torch.optim as optim
import torch.nn.functional as F
from rlkit.torch.torch_rl_algorithm import TorchTrainer
from collections import OrderedDict
from itertools import chain
from torch.distributions import kl_divergence

In [64]:
def _product_of_gaussians(mus, sigmas_squared):
    """Fuse a stack of Gaussian factors into one Gaussian.

    The factors are laid out along dim 0 of both inputs and that dimension
    is reduced away.

    Args:
        mus: tensor of factor means.
        sigmas_squared: tensor of factor variances paired with ``mus``;
            entries below 1e-7 are clamped up to 1e-7 before use.

    Returns:
        Tuple ``(mu, sigma_squared)``: mean and variance of the product.
    """
    variances = torch.clamp(sigmas_squared, min=1e-7)
    # Precisions (inverse variances) add up under a product of Gaussians.
    total_precision = torch.reciprocal(variances).sum(dim=0)
    sigma_squared = 1. / total_precision
    # The product mean is the precision-weighted sum of the factor means.
    precision_weighted_means = (mus / variances).sum(dim=0)
    mu = sigma_squared * precision_weighted_means
    return mu, sigma_squared


class PEARLAgent:
    """Wraps a context encoder and maps context batches to latent posteriors.

    Instance attributes set by ``__init__``:
        latent_dim: size of the latent task variable.
        context_encoder: network applied to flattened context.
        context_encoder_rp: deep copy of ``context_encoder`` taken at
            construction time and moved to ``ptu.device``.
        reward_predictor, obs_keys, use_next_obs_in_context: stored as given;
            not read by the methods of this class.
        latent_prior: ``Normal`` with zero mean and unit scale over
            ``latent_dim`` dimensions.
    """

    def __init__(self,
                 latent_dim,
                 context_encoder,
                 reward_predictor,
                 obs_keys=None,
                 use_next_obs_in_context=False,
                 ):
        # Plain configuration values.
        self.latent_dim = latent_dim
        self.obs_keys = obs_keys
        self.use_next_obs_in_context = use_next_obs_in_context
        self.reward_predictor = reward_predictor

        # The live encoder plus an independent copy of it; the copy is what
        # ``latent_posterior(..., use_encoder_copy=True)`` evaluates.
        self.context_encoder = context_encoder
        self.context_encoder_rp = copy.deepcopy(context_encoder)
        self.context_encoder_rp.to(ptu.device)

        prior_loc = ptu.zeros(latent_dim)
        prior_scale = ptu.ones(latent_dim)
        self.latent_prior = torch.distributions.Normal(prior_loc, prior_scale)

    def latent_posterior(self, context, squeeze=False, use_encoder_copy=False):
        """Encode ``context`` and return the posterior over the latent.

        Args:
            context: tensor or numpy array (arrays are converted with
                ``ptu.from_numpy``). Its first two dimensions are flattened
                together before encoding and restored afterwards.
            squeeze: if True, ``squeeze(dim=0)`` is applied to the resulting
                means and variances.
            use_encoder_copy: if True, evaluate ``context_encoder_rp``
                instead of ``context_encoder``.

        Returns:
            ``torch.distributions.Normal`` with one row of parameters per
            entry along the first dimension of ``context`` (before the
            optional squeeze).
        """
        if isinstance(context, np.ndarray):
            context = ptu.from_numpy(context)
        encoder = self.context_encoder_rp if use_encoder_copy else self.context_encoder

        n_outer, n_inner = context.size(0), context.size(1)
        encoded = encoder(context.view(n_outer * n_inner, -1))
        encoded = encoded.view(n_outer, -1, encoder.output_size)

        # First ``latent_dim`` outputs are means; the remainder become
        # variances after a softplus.
        mu = encoded[..., :self.latent_dim]
        sigma_squared = F.softplus(encoded[..., self.latent_dim:])

        # Combine the per-sample Gaussian factors separately for every entry
        # along the first dimension.
        means, variances = [], []
        for factor_mu, factor_var in zip(torch.unbind(mu), torch.unbind(sigma_squared)):
            fused_mu, fused_var = _product_of_gaussians(factor_mu, factor_var)
            means.append(fused_mu)
            variances.append(fused_var)
        z_means = torch.stack(means)
        z_vars = torch.stack(variances)

        if squeeze:
            z_means = z_means.squeeze(dim=0)
            z_vars = z_vars.squeeze(dim=0)

        # Normal is parameterised by standard deviation, hence the sqrt.
        return torch.distributions.Normal(z_means, torch.sqrt(z_vars))

In [88]:
class PearlTrainer(TorchTrainer):
    """Optimises the PEARL context encoder, optionally with a reward decoder.

    Each call to ``train_from_torch`` takes one optimiser step on
    ``kl_weight * KL(posterior || prior) + reward-prediction MSE``; the MSE
    term is omitted when ``train_context_decoder`` is False.
    """

    def __init__(
        self,
        agent: PEARLAgent,
        env,
        latent_dim,
        context_encoder,
        reward_predictor,
        context_decoder,

        train_context_decoder=True,
        context_lr=1e-3,
        kl_lambda=1.0,
        optimizer_class=optim.Adam,

        discount=0.99,
        reward_scale=1.0,

        kl_annealing=False,
        kl_annealing_x0=50000,
        kl_annealing_k=0.00005,
        kl_annealing_start=20000
    ):
        """Store the networks and build the context optimiser.

        Args:
            agent: provides ``latent_posterior`` and ``latent_prior``.
            env: stored on the instance; not otherwise used by this class.
            latent_dim: stored on the instance; not otherwise used here.
            context_encoder: network whose parameters are always optimised.
            reward_predictor: stored and reported through ``networks`` only.
            context_decoder: called as ``context_decoder(obs, actions, z)``
                to predict rewards; its parameters are optimised when
                ``train_context_decoder`` is True.
            train_context_decoder: include the reward-prediction loss and
                the decoder parameters in the optimisation.
            context_lr: learning rate passed to ``optimizer_class``.
            kl_lambda: KL weight; with annealing it is the value the
                schedule approaches.
            optimizer_class: called as ``optimizer_class(params, lr=...)``.
            discount: stored; not used by the loss in this class.
            reward_scale: stored; not used by the loss in this class.
            kl_annealing: enable the logistic KL-weight schedule.
            kl_annealing_x0: number of steps after ``kl_annealing_start`` at
                which the weight reaches ``kl_lambda / 2``.
            kl_annealing_k: slope of the logistic schedule.
            kl_annealing_start: the KL weight is 0 before this many steps.
        """
        # Run the base-class initialiser so whatever state it sets up exists
        # (TorchTrainer is defined outside this file).
        super().__init__()

        self.train_context_decoder = train_context_decoder
        self.train_encoder_decoder = True
        self.reward_scale = reward_scale
        self.discount = discount

        self.agent = agent
        self.context_encoder = context_encoder
        self.context_decoder = context_decoder
        self.reward_predictor = reward_predictor

        self.kl_annealing = kl_annealing
        self.kl_annealing_x0 = kl_annealing_x0
        self.kl_annealing_k = kl_annealing_k
        self.kl_annealing_start = kl_annealing_start

        self.env = env
        self.latent_dim = latent_dim

        self.kl_lambda = kl_lambda
        if train_context_decoder:
            # Encoder and decoder share a single optimiser.
            trainable_params = chain(
                self.context_encoder.parameters(),
                self.context_decoder.parameters(),
            )
        else:
            trainable_params = self.context_encoder.parameters()
        self.context_optimizer = optimizer_class(
            trainable_params,
            lr=context_lr,
        )

        self.eval_statistics = OrderedDict()
        self._n_train_steps_total = 0
        self._need_to_update_eval_statistics = True

    def _kl_weight(self):
        """Return the KL coefficient for the current training step."""
        if not self.kl_annealing:
            return self.kl_lambda
        if self._n_train_steps_total < self.kl_annealing_start:
            return 0
        # Logistic ramp centred ``kl_annealing_x0`` steps after the start.
        steps_past_midpoint = (
            self._n_train_steps_total
            - self.kl_annealing_start
            - self.kl_annealing_x0
        )
        return self.kl_lambda * float(
            1 / (1 + np.exp(-self.kl_annealing_k * steps_past_midpoint))
        )

    def train_from_torch(self, batch, learn_task_z=True):
        """Take one optimiser step on the context (and reward) loss.

        Args:
            batch: mapping with the keys ``'rewards'``, ``'observations'``,
                ``'actions'`` and ``'context'``. ``'observations'`` must be
                3-dimensional; its first two dimensions are flattened
                together, as are those of actions and rewards.
            learn_task_z: if True, the latent sample fed to the decoder
                keeps its gradient and the posterior comes from the agent's
                live encoder. If False, the sample is detached and the
                posterior comes from the agent's encoder copy.
        """
        rewards = batch['rewards']
        obs = batch['observations']
        actions = batch['actions']
        context = batch['context']

        # Flatten out the task dimension.
        t, b, _ = obs.size()
        obs = obs.view(t * b, -1)
        actions = actions.view(t * b, -1)
        unscaled_rewards_flat = rewards.view(t * b, 1)

        # Sample one latent per task and tile it so every flattened
        # transition is paired with its task's latent.
        # NOTE(review): with learn_task_z=False the posterior comes from
        # agent.context_encoder_rp, whose parameters are not handed to
        # context_optimizer here - confirm that is intended.
        p_z = self.agent.latent_posterior(context, use_encoder_copy=not learn_task_z)
        task_z_with_grad = p_z.rsample()
        task_z_with_grad = torch.cat(
            [z.repeat(b, 1) for z in task_z_with_grad], dim=0
        )

        # Context encoder loss: KL to the prior, averaged over dim 0 and
        # summed over the remaining dimensions.
        kl_div = kl_divergence(p_z, self.agent.latent_prior).mean(dim=0).sum()
        kl_lambda = self._kl_weight()
        kl_loss = kl_lambda * kl_div

        if self.train_context_decoder:
            # TODO: change to use a distribution
            task_z = task_z_with_grad if learn_task_z else task_z_with_grad.detach()
            reward_pred = self.context_decoder(obs, actions, task_z)
            reward_prediction_loss = ((reward_pred - unscaled_rewards_flat)**2).mean()
            context_loss = kl_loss + reward_prediction_loss
        else:
            context_loss = kl_loss
            reward_prediction_loss = ptu.zeros(1)

        # Update networks.
        self.context_optimizer.zero_grad()
        context_loss.backward()
        self.context_optimizer.step()

        # Save some statistics for eval. The flag is cleared here, so the
        # statistics are refreshed only on the first batch after a caller
        # sets ``_need_to_update_eval_statistics`` back to True.
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            if self.kl_annealing:
                self.eval_statistics['task_embedding/kl_lambda'] = (
                    kl_lambda
                )
            self.eval_statistics['task_embedding/kl_divergence'] = (
                ptu.get_numpy(kl_div)
            )
            self.eval_statistics['task_embedding/kl_loss'] = (
                ptu.get_numpy(kl_loss)
            )
            self.eval_statistics['task_embedding/reward_prediction_loss'] = (
                ptu.get_numpy(reward_prediction_loss)
            )
            self.eval_statistics['task_embedding/context_loss'] = (
                ptu.get_numpy(context_loss)
            )

        self._n_train_steps_total += 1

    @property
    def networks(self):
        """List of the encoder, decoder and reward predictor networks."""
        return [
            self.context_encoder,
            self.context_decoder,
            self.reward_predictor,
        ]

# RUN PEARL

In [66]:
import numpy as np
import torch

import rlkit.torch.pytorch_util as ptu
from rlkit.data_management.multitask_replay_buffer import ObsDictMultiTaskReplayBuffer
from rlkit.misc.roboverse_utils import add_data_to_buffer_multitask_v2, get_buffer_size_multitask
from rlkit.torch.sac.policies import GaussianCNNPolicy
from rlkit.torch.networks.cnn import CNN, ConcatCNN
from rlkit.torch.core import np_to_pytorch_batch
import roboverse

import matplotlib.pyplot as plt
import os
# Restrict CUDA to a single hard-coded device (enumerated in PCI bus order)
# and then turn on GPU mode in ``ptu``.
gpu_id = 3
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": str(gpu_id),
})
ptu.set_gpu_mode(True)

In [106]:
# ---------------------------------------------------------------- file paths
# Offline dataset used for training.
BUFFER = (
    '/nfs/kun1/users/avi/scripted_sim_datasets/'
    'may26_Widow250PickPlaceMetaTrainMultiObjectMultiContainer-v0_16K_save_all_noise_0.1_2021-05-26T16-12-04/'
    'may26_Widow250PickPlaceMetaTrainMultiObjectMultiContainer-v0_16K_save_all_noise_0.1_2021-05-26T16-12-04_16000.npy'
)
# Held-out dataset used for the validation buffer.
VALIDATION_BUFFER = (
    '/nfs/kun1/users/jonathan/minibullet_data/'
    'jun_14_validation_Widow250PickPlaceMetaTrainMultiObjectMultiContainer-v0_1K_save_all_noise_0.1_2021-06-14T11-27-40/'
    'jun_14_validation_Widow250PickPlaceMetaTrainMultiObjectMultiContainer-v0_1K_save_all_noise_0.1_2021-06-14T11-27-40_1024.npy'
)
# Environment id handed to ``roboverse.make``.
ENV = 'Widow250PickPlaceMetaTestMultiObjectMultiContainer-v0'

# -------------------------------------------------------------- agent kwargs
LATENT_DIM = 5                    # size of the latent task variable
USE_NEXT_OBS_IN_CONTEXT = False   # not referenced by the cells visible here
_DEBUG_DO_NOT_SQRT = False        # not referenced by the cells visible here

# ------------------------------------------------------------ context kwargs
META_BATCH_SIZE = 4               # task indices drawn per training step
TASK_EMBEDDING_BATCH_SIZE = 64    # size argument passed to sample_context

# ------------------------------------------------------------ trainer kwargs
NUM_BATCHES = 200 * 1000          # iterations of the training loop
TRAIN_TASKS = np.arange(32)       # task ids the training loop samples from
BATCH_SIZE = 128                  # size argument passed to sample_batch
LOGGING_PERIOD = 1000             # print statistics every this many steps
KL_ANNEALING_X0 = 20 * 1000       # midpoint offset of the KL schedule
KL_ANNEALING_K = 0.05 / 1000      # slope of the KL schedule
KL_ANNEALING_START = 0            # step at which KL annealing begins

In [74]:
# Build the environment. In this cell it is only queried for the sizes of
# its 'state' observation space and its action space; the same object is
# later passed to the replay buffers and to the trainer.
expl_env = roboverse.make(ENV, transpose_image=True)
state_observation_dim = expl_env.observation_space.spaces['state'].low.size
action_dim = expl_env.action_space.low.size
reward_dim = 1
# The encoder emits 2 * LATENT_DIM values: the first half is read as means,
# the second half as pre-softplus variances.
context_encoder_output_dim = LATENT_DIM * 2

# Settings common to both networks; the entries that differ are overwritten
# in place before each ConcatCNN is constructed.
cnn_params = {
    'input_width': 48,
    'input_height': 48,
    'input_channels': 3,
    'kernel_sizes': [3, 3, 3],
    'n_channels': [16, 16, 16],
    'strides': [1, 1, 1],
    'hidden_sizes': [1024, 512, 256],
    'paddings': [1, 1, 1],
    'pool_type': 'max2d',
    'pool_sizes': [2, 2, 1],  # the one at the end means no pool
    'pool_strides': [2, 2, 1],
    'pool_paddings': [0, 0, 0],
    'image_augmentation': True,
    'image_augmentation_padding': 4,
}

# Context encoder: extra inputs are state, action and reward.
cnn_params['added_fc_input_size'] = state_observation_dim + action_dim + reward_dim
cnn_params['output_size'] = context_encoder_output_dim
cnn_params['hidden_sizes'] = [256, 256]
context_encoder = ConcatCNN(**cnn_params)
context_encoder.to(ptu.device)

# Context decoder: extra inputs are state, action and the latent; it
# produces a single output and is built with image augmentation disabled.
cnn_params['added_fc_input_size'] = state_observation_dim + action_dim + LATENT_DIM
cnn_params['output_size'] = 1
cnn_params['hidden_sizes'] = [256, 256]
cnn_params['image_augmentation'] = False
context_decoder = ConcatCNN(**cnn_params)
context_decoder.to(ptu.device)

ConcatCNN(
  (hidden_activation): ReLU()
  (conv_layers): ModuleList(
    (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (conv_norm_layers): ModuleList()
  (pool_layers): ModuleList(
    (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc_layers): ModuleList(
    (0): Linear(in_features=2327, out_features=256, bias=True)
    (1): Linear(in_features=256, out_features=256, bias=True)
  )
  (fc_norm_layers): ModuleList()
  (last_fc): Linear(in_features=256, out_features=1, bias=True)
)

In [71]:
# Load the training dataset from BUFFER and copy it into a multi-task
# replay buffer covering task indices 0..31.
train_task_indices = list(range(32))
observation_keys = ['image', 'state']
# SECURITY NOTE: allow_pickle=True lets np.load unpickle arbitrary objects
# from the file; only load dataset files from a trusted source.
with open(BUFFER, 'rb') as fl:
    data = np.load(fl, allow_pickle=True)
# presumably the total number of transitions contained in `data` - confirm
# against get_buffer_size_multitask.
num_transitions = get_buffer_size_multitask(data)
# Capacity passed to the buffer: the reported size plus 10.
max_replay_buffer_size = num_transitions + 10

replay_buffer = ObsDictMultiTaskReplayBuffer(
    max_replay_buffer_size,
    expl_env,
    train_task_indices,
    use_next_obs_in_context=False,
    sparse_rewards=False,
    observation_keys=observation_keys
)
add_data_to_buffer_multitask_v2(data, replay_buffer, observation_keys)

In [107]:
# Load the validation dataset from VALIDATION_BUFFER into its own replay
# buffer. This rebinds the globals `train_task_indices`, `observation_keys`,
# `data`, `num_transitions` and `max_replay_buffer_size`.
train_task_indices = list(range(32))
observation_keys = ['image', 'state']
# SECURITY NOTE: allow_pickle=True lets np.load unpickle arbitrary objects
# from the file; only load dataset files from a trusted source.
with open(VALIDATION_BUFFER, 'rb') as fl:
    data = np.load(fl, allow_pickle=True)
# presumably the total number of transitions contained in `data` - confirm
# against get_buffer_size_multitask.
num_transitions = get_buffer_size_multitask(data)
# Capacity passed to the buffer: the reported size plus 10.
max_replay_buffer_size = num_transitions + 10

validation_buffer = ObsDictMultiTaskReplayBuffer(
    max_replay_buffer_size,
    expl_env,
    train_task_indices,
    use_next_obs_in_context=False,
    sparse_rewards=False,
    observation_keys=observation_keys
)
add_data_to_buffer_multitask_v2(data, validation_buffer, observation_keys)

In [89]:
# The agent wraps the context encoder; its reward_predictor argument is
# passed as None.
agent = PEARLAgent(
    LATENT_DIM,
    context_encoder,
    None,
    obs_keys=observation_keys,
)

# The trainer receives `context_decoder` for both its reward_predictor and
# context_decoder arguments, and runs with KL annealing enabled using the
# schedule constants defined in the configuration cell.
trainer = PearlTrainer(
    agent=agent,
    env=expl_env,
    latent_dim=LATENT_DIM,
    reward_predictor=context_decoder,
    context_encoder=context_encoder,
    context_decoder=context_decoder,
    kl_annealing=True,
    kl_annealing_x0=KL_ANNEALING_X0,
    kl_annealing_k=KL_ANNEALING_K,
    kl_annealing_start=KL_ANNEALING_START
)

In [90]:
# Main training loop: one optimiser step per iteration, with statistics
# refreshed and printed every LOGGING_PERIOD iterations.
for i in range(NUM_BATCHES):
    if not i % LOGGING_PERIOD:
        # Ask the trainer to recompute its eval statistics on this step.
        trainer._need_to_update_eval_statistics = True

    # Draw META_BATCH_SIZE task ids (np.random.choice samples with
    # replacement unless told otherwise).
    task_indices = np.random.choice(TRAIN_TASKS, META_BATCH_SIZE)

    # Transitions for the reward-prediction loss, converted to torch.
    train_data = np_to_pytorch_batch(
        replay_buffer.sample_batch(task_indices, BATCH_SIZE)
    )
    # Context for the same tasks, drawn separately and attached as-is.
    train_data['context'] = replay_buffer.sample_context(
        task_indices, TASK_EMBEDDING_BATCH_SIZE
    )

    trainer.train_from_torch(train_data, learn_task_z=True)
    if not i % LOGGING_PERIOD:
        print(trainer.eval_statistics)

OrderedDict([('task_embedding/kl_lambda', 0.2689414213699951), ('task_embedding/kl_divergence', array(7.4598436, dtype=float32)), ('task_embedding/kl_loss', array(2.006261, dtype=float32)), ('task_embedding/reward_prediction_loss', array(0.00091146, dtype=float32)), ('task_embedding/context_loss', array(2.0071726, dtype=float32))])
OrderedDict([('task_embedding/kl_lambda', 0.2788848219771369), ('task_embedding/kl_divergence', array(6.292349, dtype=float32)), ('task_embedding/kl_loss', array(1.7548406, dtype=float32)), ('task_embedding/reward_prediction_loss', array(0.00094373, dtype=float32)), ('task_embedding/context_loss', array(1.7557844, dtype=float32))])
OrderedDict([('task_embedding/kl_lambda', 0.289050497374996), ('task_embedding/kl_divergence', array(5.4834385, dtype=float32)), ('task_embedding/kl_loss', array(1.5849906, dtype=float32)), ('task_embedding/reward_prediction_loss', array(0.00047288, dtype=float32)), ('task_embedding/context_loss', array(1.5854635, dtype=float32))])

OrderedDict([('task_embedding/kl_lambda', 0.5621765008857981), ('task_embedding/kl_divergence', array(0.9137697, dtype=float32)), ('task_embedding/kl_loss', array(0.5136999, dtype=float32)), ('task_embedding/reward_prediction_loss', array(0.00072458, dtype=float32)), ('task_embedding/context_loss', array(0.51442444, dtype=float32))])
OrderedDict([('task_embedding/kl_lambda', 0.574442516811659), ('task_embedding/kl_divergence', array(0.85190517, dtype=float32)), ('task_embedding/kl_loss', array(0.48937052, dtype=float32)), ('task_embedding/reward_prediction_loss', array(0.00145546, dtype=float32)), ('task_embedding/context_loss', array(0.49082598, dtype=float32))])
OrderedDict([('task_embedding/kl_lambda', 0.5866175789173301), ('task_embedding/kl_divergence', array(0.793808, dtype=float32)), ('task_embedding/kl_loss', array(0.46566173, dtype=float32)), ('task_embedding/reward_prediction_loss', array(0.00280774, dtype=float32)), ('task_embedding/context_loss', array(0.46846947, dtype=float32))])

OrderedDict([('task_embedding/kl_lambda', 0.8175744761936437), ('task_embedding/kl_divergence', array(0.10187259, dtype=float32)), ('task_embedding/kl_loss', array(0.08328843, dtype=float32)), ('task_embedding/reward_prediction_loss', array(0.00015292, dtype=float32)), ('task_embedding/context_loss', array(0.08344135, dtype=float32))])
OrderedDict([('task_embedding/kl_lambda', 0.8249137318359602), ('task_embedding/kl_divergence', array(0.08961202, dtype=float32)), ('task_embedding/kl_loss', array(0.07392219, dtype=float32)), ('task_embedding/reward_prediction_loss', array(0.0035808, dtype=float32)), ('task_embedding/context_loss', array(0.07750299, dtype=float32))])
OrderedDict([('task_embedding/kl_lambda', 0.8320183851339245), ('task_embedding/kl_divergence', array(0.07831882, dtype=float32)), ('task_embedding/kl_loss', array(0.0651627, dtype=float32)), ('task_embedding/reward_prediction_loss', array(0.00047984, dtype=float32)), ('task_embedding/context_loss', array(0.06564254, dtype=float32))])

KeyboardInterrupt: 

In [93]:
# Write the trained weights to the current working directory. The trainer
# was constructed with `context_decoder` as its reward_predictor, so
# 'decoder.pt' holds the context decoder's state dict.
torch.save(trainer.context_encoder.state_dict(), 'encoder.pt')
torch.save(trainer.reward_predictor.state_dict(), 'decoder.pt')

In [117]:
# Spot-check the reward decoder on 10 validation transitions from task 0.
train_data = np_to_pytorch_batch(validation_buffer.sample_batch([0], 10))

obs, nobs = train_data['observations'], train_data['next_observations']
actions, rewards = train_data['actions'], train_data['rewards']

# The context for task 0 is drawn from the training buffer (replay_buffer),
# not from validation_buffer.
contexts = replay_buffer.sample_context([0], TASK_EMBEDDING_BATCH_SIZE)
task_embeddings = agent.latent_posterior(contexts)
task_z = task_embeddings.rsample()

# Flatten the leading (task, batch) dimensions and tile each task's latent
# so every transition is paired with a latent row.
t, b, _ = obs.size()
obs = obs.view(t * b, -1)
nobs = nobs.view(t * b, -1)
actions = actions.view(t * b, -1)
task_z = torch.cat([z.repeat(b, 1) for z in task_z], dim=0)
reward_pred = context_decoder(obs, actions, task_z)

# Rounded predictions first, then the ground-truth rewards (the latter are
# printed in their original, unflattened shape).
print(torch.round(reward_pred))
print(rewards)



tensor([[-0.],
        [1.],
        [0.],
        [-0.],
        [0.],
        [1.],
        [0.],
        [-0.],
        [-0.],
        [0.]], device='cuda:0', grad_fn=<RoundBackward>)
tensor([[[0.],
         [0.],
         [0.],
         [1.],
         [0.],
         [0.],
         [1.],
         [0.],
         [0.],
         [1.]]], device='cuda:0')
