### P E A R L: P R O B A B I L I S T I C - E M B E D D I N G - F O R - A C T O R - C R I T I C - R L

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import numpy as np
import gymnasium as gym
import random

from torch.utils.tensorboard import SummaryWriter


### L O G G I N G

In [2]:
writer = SummaryWriter(log_dir = './runs/PEARL')


### D E V I C E 

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {device}')


Device: cuda


### H E L P E R 

In [4]:
def safe_tensor(x):
    
    return x if torch.is_tensor(x) else torch.tensor(x, dtype = torch.float32).to(device)

def safe_stack(x):
    
    return torch.stack(x).to(device)


### M E T A - E N V


In [5]:
class Meta_Walker_2d:
    
    def __init__(self):
        
        self.base_env = gym.make('Walker2d-v5')
        self.target_velocity = 1.0
        
    def sample_tasks(self, num_tasks):
        
        self.tasks = []
        
        for _ in range(num_tasks):
            
            gravity = np.random.uniform(5.0, 15.0)
            torso_mass = np.random.uniform(1.0, 5.0)
            target_velocity = np.random.uniform(0.5, 3.0)
            
            self.tasks.append((gravity, torso_mass, target_velocity))
            
        return self.tasks
    
    def set_task(self, task):
        
        raw_env = self.base_env.unwrapped
        raw_env.model.opt.gravity[-1] = -task[0]
        raw_env.model.body_mass[1] = task[1]
        self.target_velocity = task[2]
        
    def reset(self, task_idx = 0):
        
        self.set_task(self.tasks[task_idx])
        obs, _ = self.base_env.reset()
        
        return obs
    
    def step(self, action):
        
        next_obs, reward, terminated, truncated, info = self.base_env.step(action)
        done = terminated | truncated
        raw_env = self.base_env.unwrapped
        vel = raw_env.data.qvel[0]
        reward -= 0.5 * abs(vel - self.target_velocity)
        
        return next_obs, reward, done, info
    
    def get_number(self):
        
        state_dim = self.base_env.observation_space.shape[0]
        action_dim = self.base_env.action_space.shape[0]
        max_action = self.base_env.action_space.high[0]
        reward_dim = 1
        
        return state_dim, action_dim, max_action, reward_dim
    
    def close(self):
        
        self.base_env.close()


### T E S T

In [6]:
env = Meta_Walker_2d()

env.sample_tasks(num_tasks = 4)

obs = env.reset()

print(f'obs shape: {obs.shape}')

state_dim, action_dim, max_action, reward_dim = env.get_number()

print(f'state dim: {state_dim} | action dim: {action_dim} | max action: {max_action}')


obs shape: (17,)
state dim: 17 | action dim: 6 | max action: 1.0


### A S S E M B L Y

In [7]:
head_1 = 64
head_2 = 128
head_3 = 128
head_4 = 64

latent_dim = 64


### P O S T E R I O R - q [ z | c ]

In [8]:
class pearl_encoder(nn.Module):
    
    def __init__(self, state_dim = state_dim, action_dim = action_dim, reward_dim = reward_dim, head_1 = head_1, head_2 = head_2, head_3 = head_3, head_4 = head_4, latent_dim = latent_dim):
        super(pearl_encoder, self).__init__()
        
        # find the posterior
        
        self.encoder_net = nn.Sequential(
            
            nn.Linear(2 * state_dim + action_dim + reward_dim, head_1),
            nn.SiLU(),
            
            nn.LayerNorm(head_1),
            nn.Linear(head_1, head_2),
            nn.SiLU(),
            
            nn.LayerNorm(head_2),
            nn.Linear(head_2, head_3),
            nn.SiLU(),
            
            nn.LayerNorm(head_3),
            nn.Linear(head_3, head_4),
            nn.SiLU()
        )
        
        self.mu = nn.Linear(head_4, latent_dim)
        self.log_std = nn.Linear(head_4, latent_dim)
        
        self.apply(self.init_weights)
        
    def init_weights(self, module):
            
        if isinstance(module, nn.Linear):
                
            nn.init.kaiming_normal_(module.weight, a = 0, nonlinearity = 'relu')
            
            if module.bias is not None:
                
                nn.init.zeros_(module.bias)
                
    def forward(self, state, action, reward, next_state, sample = False, reduce = True):
        
        # concat
        
        cat = torch.cat([state, action, reward, next_state], dim = -1)
        
        # encoder net
        
        x = self.encoder_net(cat)
        
        # mu and log std head
        
        mu = self.mu(x)
        log_std = self.log_std(x)
        log_std = torch.clamp(log_std, -5 , 2)
        
        std = torch.exp(log_std)
        
        eps = torch.randn_like(std)
        
        if sample:
        
            z = mu + std * eps
            
        else :
            
            z = mu
            
        if reduce:
            
            z = z.mean(dim = 0, keepdim = True)
            mu = mu.mean(dim = 0, keepdim = True)
            std = std.mean(dim = 0, keepdim = True)
        
        return z, mu, log_std


### S E T U P 

In [9]:
PEARL_ENCODER = pearl_encoder().to(device)

print(PEARL_ENCODER)


pearl_encoder(
  (encoder_net): Sequential(
    (0): Linear(in_features=41, out_features=64, bias=True)
    (1): SiLU()
    (2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    (3): Linear(in_features=64, out_features=128, bias=True)
    (4): SiLU()
    (5): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (6): Linear(in_features=128, out_features=128, bias=True)
    (7): SiLU()
    (8): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (9): Linear(in_features=128, out_features=64, bias=True)
    (10): SiLU()
  )
  (mu): Linear(in_features=64, out_features=64, bias=True)
  (log_std): Linear(in_features=64, out_features=64, bias=True)
)


### P E A R L 

In [10]:
class actor_critic(nn.Module):
    
    def __init__(self, state_dim = state_dim, action_dim = action_dim, latent_dim = latent_dim, head_1 = head_1, head_2 = head_2, head_3 = head_3, head_4 = head_4, max_action = max_action):
        super(actor_critic, self).__init__()
        
        # max action
        
        self.max_action = max_action
        
        # create mlp
        
        def create_mlp(input_dim):
            
            process = nn.Sequential(
                
            nn.Linear(input_dim, head_1),
            nn.SiLU(),
                
            nn.LayerNorm(head_1),
            nn.Linear(head_1, head_2),
            nn.SiLU(),
                
            nn.LayerNorm(head_2),
            nn.Linear(head_2, head_3),
            nn.SiLU(),
                
            nn.LayerNorm(head_3),
            nn.Linear(head_3, head_4),
            nn.SiLU()
            ) 
        
            return process
        
        # actor and critic heads
        
        self.mu = nn.Linear(head_4, action_dim)
        self.log_std = nn.Linear(head_4, action_dim)
        
        self.critic_head = nn.Linear(head_4, 1)
        self.critic_head_2 = nn.Linear(head_4, 1)
        
        # specific actor and critic mlp
        
        self.actor_mlp = create_mlp(input_dim = state_dim + latent_dim)
        self.critic_mlp = create_mlp(input_dim = state_dim + action_dim + latent_dim)
        self.critic_mlp_2 = create_mlp(input_dim = state_dim + action_dim + latent_dim)
        
        # apply initialization
        
        self.apply(self.init_weights)
        
    def init_weights(self, module):
            
        if isinstance(module, nn.Linear):
                
            nn.init.kaiming_normal_(module.weight, a = 0, nonlinearity = 'relu')
            
            if module.bias is not None:
                
                nn.init.zeros_(module.bias)        
                
    def actor_forward(self, state, latent_z, deterministic = False):
        
        # cat
        
        actor_ca = torch.cat([state, latent_z], dim = -1)
        
        # pass to mlp
        
        actor_pass = self.actor_mlp(actor_ca)
        
        # mu and log head
        
        mu = self.mu(actor_pass)
        if deterministic: return mu

        log_std = self.log_std(actor_pass)
        log_std = torch.clamp(log_std, -5, 2)
        std = torch.exp(log_std)
        
        dist = torch.distributions.Normal(mu, std)
        sampled_z = dist.rsample()
        log_prob = dist.log_prob(sampled_z)
        tanh_Z =torch.tanh(sampled_z)
        action = tanh_Z * self.max_action 
        
        squash = (1 - tanh_Z.pow(2) + 1e-6).log()
        log_prob = log_prob - squash
        log_prob = log_prob.sum(dim = -1, keepdim = True)
        
        return action, log_prob, mu, log_std
            
    def critic_forward(self, state, action, latent_z):
        
        # cat
        
        critic_ca = torch.cat([state, action, latent_z], dim = -1)
        
        # pass to mlp
        
        critic_pass = self.critic_mlp(critic_ca)
        critic_pass_2 = self.critic_mlp_2(critic_ca)
        
        # pass mlp output to heads
        
        critic_val = self.critic_head(critic_pass)
        critic_val_2 = self.critic_head_2(critic_pass_2)
        
        
        return critic_val, critic_val_2
    
    def forward(self, state, latent_z):
        
        action, log_prob, mu, log_std = self.actor_forward(state, latent_z)
        
        critic_val, critic_val_2 = self.critic_forward(state, action, latent_z)
        
        return action, log_prob, mu, log_std, critic_val, critic_val_2
        

### S E T U P

In [11]:
PEARL = actor_critic().to(device)

print(PEARL)


actor_critic(
  (mu): Linear(in_features=64, out_features=6, bias=True)
  (log_std): Linear(in_features=64, out_features=6, bias=True)
  (critic_head): Linear(in_features=64, out_features=1, bias=True)
  (critic_head_2): Linear(in_features=64, out_features=1, bias=True)
  (actor_mlp): Sequential(
    (0): Linear(in_features=81, out_features=64, bias=True)
    (1): SiLU()
    (2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    (3): Linear(in_features=64, out_features=128, bias=True)
    (4): SiLU()
    (5): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (6): Linear(in_features=128, out_features=128, bias=True)
    (7): SiLU()
    (8): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (9): Linear(in_features=128, out_features=64, bias=True)
    (10): SiLU()
  )
  (critic_mlp): Sequential(
    (0): Linear(in_features=87, out_features=64, bias=True)
    (1): SiLU()
    (2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    (3): Linear(in_features=64, out

### T A R G E T - C R I T I C

In [12]:
class target_critic(nn.Module):
    
    def __init__(self, state_dim = state_dim, action_dim = action_dim, latent_dim = latent_dim, head_1 = head_1, head_2 = head_2, head_3 = head_3, head_4 = head_4, max_action = max_action):
        super(target_critic, self).__init__()

 
        def create_mlp(input_dim):
                    
            process = nn.Sequential(
                        
                nn.Linear(input_dim, head_1),
                nn.SiLU(),
                        
                nn.LayerNorm(head_1),
                nn.Linear(head_1, head_2),
                nn.SiLU(),
                        
                nn.LayerNorm(head_2),
                nn.Linear(head_2, head_3),
                nn.SiLU(),
                        
                nn.LayerNorm(head_3),
                nn.Linear(head_3, head_4),
                nn.SiLU()
            ) 
            
            return process
        
        # critic heads
        
        self.critic_head_1 = nn.Linear(head_4, 1)
        self.critic_head_2 = nn.Linear(head_4, 1)
        
        # initiate mlp
        
        self.critic_1 = create_mlp(input_dim = state_dim + action_dim + latent_dim)
        self.critic_2 = create_mlp(input_dim = state_dim + action_dim + latent_dim)
        
        # Initialize
        
        self.apply(self.init_weights)
        
    def init_weights(self, module):
            
        if isinstance(module, nn.Linear):
                
            nn.init.kaiming_normal_(module.weight, a = 0, nonlinearity = 'relu')
            
            if module.bias is not None:
                
                nn.init.zeros_(module.bias)
                
    def forward(self, state, action, latent_z):
        
        critic_ca = torch.cat([state, action, latent_z], dim = -1)
        
        pass_1 = self.critic_1(critic_ca)
        pass_2 = self.critic_2(critic_ca)
        
        q1 = self.critic_head_1(pass_1)
        q2 = self.critic_head_2(pass_2)
        
        return q1, q2


### S E T U P 

In [13]:
# Setup

TARGET_CRITIC = target_critic().to(device)

# Initialize to CRITIC 

TARGET_CRITIC.critic_1.load_state_dict(PEARL.critic_mlp.state_dict())
TARGET_CRITIC.critic_2.load_state_dict(PEARL.critic_mlp_2.state_dict())
TARGET_CRITIC.critic_head_1.load_state_dict(PEARL.critic_head.state_dict())
TARGET_CRITIC.critic_head_2.load_state_dict(PEARL.critic_head_2.state_dict())


print(TARGET_CRITIC)


target_critic(
  (critic_head_1): Linear(in_features=64, out_features=1, bias=True)
  (critic_head_2): Linear(in_features=64, out_features=1, bias=True)
  (critic_1): Sequential(
    (0): Linear(in_features=87, out_features=64, bias=True)
    (1): SiLU()
    (2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    (3): Linear(in_features=64, out_features=128, bias=True)
    (4): SiLU()
    (5): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (6): Linear(in_features=128, out_features=128, bias=True)
    (7): SiLU()
    (8): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (9): Linear(in_features=128, out_features=64, bias=True)
    (10): SiLU()
  )
  (critic_2): Sequential(
    (0): Linear(in_features=87, out_features=64, bias=True)
    (1): SiLU()
    (2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    (3): Linear(in_features=64, out_features=128, bias=True)
    (4): SiLU()
    (5): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (6): Linear(

### O P T I M I Z E R 

In [14]:
# Lr

encoder_lr = 1e-4
actor_lr = 3e-4
critic_lr = 5e-4

T_max = 100

# Shared optimizer

actor_params = list(PEARL.actor_mlp.parameters()) + \
               list(PEARL.mu.parameters()) + \
               list(PEARL.log_std.parameters()) 
               
               
critic_params = list(PEARL.critic_mlp.parameters()) + \
                list(PEARL.critic_mlp_2.parameters()) + \
                list(PEARL.critic_head.parameters()) + \
                list(PEARL.critic_head_2.parameters())
                
# Optimizer
                
PEARL_OPTIMIZER = optim.AdamW([    
                
    {'params': critic_params, 'lr': critic_lr, 'weight_decay': 1e-6},
    {'params': PEARL_ENCODER.parameters(), 'lr': encoder_lr, 'weight_decay': 1e-6},
    {'params': actor_params, 'lr': actor_lr, 'weight_decay': 0},
    
])

# PEARL SCHEDULER

PEARL_SCHEDULER = optim.lr_scheduler.CosineAnnealingLR(PEARL_OPTIMIZER, T_max, eta_min = 1e-5)
    

### B U F F E R

In [15]:
class meta_buffer:
    
    def __init__(self, max_episodes, env = env):
        
        self.env = env
        self.max_episodes = max_episodes
        self.current_episode = []
        self.episodes = []
        
    def add(self, state, action, log_probs, reward, done, next_state):
        
        # convert to tensor
        
        state = safe_tensor(state)
        action = safe_tensor(action)
        reward = safe_tensor(reward)
        done = safe_tensor(done)
        next_state = safe_tensor(next_state)
        log_probs = safe_tensor(log_probs)
        
        # add them
        
        self.current_episode.append({
            
            'states': state,
            'actions': action,
            'log_probs': log_probs,
            'rewards': reward,
            'dones': done,
            'next_states': next_state
        })
        
        if done.item() == 1:
        
            self.episodes.append(self.current_episode)
            
            self.current_episode = []
            
        if len(self.episodes) > self.max_episodes:
                
            self.episodes.pop(0)
            
    def sample(self, batch_size, fixed_length):
        
        #for _ in range(batch_size):
        
        segments = []
        masks = []
        
        sampled_episodes = random.sample(self.episodes, k=min(batch_size, len(self.episodes)))
        
        for ep in sampled_episodes:
        
            if len(ep) >= fixed_length:
                
                seg = ep[:fixed_length]
                mask = torch.ones(fixed_length, dtype=torch.float32)
                
            else:
                
                seg, mask = self.pad_episode(ep, fixed_length)
                
            segments.append(seg)
            masks.append(mask)
        
        
        def stack_field(x):
        
            return torch.stack([torch.stack([step[x] for step in seg]) for seg in segments]).to(device)
            
        batch = {
                
                'states': stack_field('states'),
                'actions': stack_field('actions'),
                'log_probs': stack_field('log_probs'),
                'rewards': stack_field('rewards'),
                'dones': stack_field('dones'),
                'next_states': stack_field('next_states'),
                'mask': torch.stack(masks).to(device)
        }
            
        return batch
    
    def pad_episode(self, ep, fixed_length):
        
        pad_length = fixed_length - len(ep)
        last = ep[-1]
        
        pad_step = {}
        
        for k,v in last.items():
            
            if torch.is_tensor(v):
                
                pad_step[k] = torch.zeros_like(v)
                
            else:
                
                pad_step[k] = v
    
        mask = torch.cat([
            
            torch.ones(len(ep), dtype = torch.float32),
            torch.zeros(pad_length, dtype = torch.float32)
        ])

        return ep + [pad_step] * pad_length, mask
    
    def clear(self):
        
        self.episodes.clear()
                
            
        

### S E T U P 

In [16]:
max_episodes = 500

buffer = meta_buffer(max_episodes)


### M E T A - E P I S O D E - R U N N E R 

In [17]:
class meta_episode_runner:
    
    def __init__(self, max_episode_length, env = env, buffer = buffer, model = PEARL, encoder = PEARL_ENCODER):
        
        self.env = env
        self.max_episode_length = max_episode_length
        self.buffer = buffer
        self.model = model
        self.encoder = encoder
        
    def run(self, num_tasks):
        
        tasks = self.env.sample_tasks(num_tasks)
        
        for task in tasks:
            
            self.env.set_task(task)
            
            obs = self.env.reset()
            obs = safe_tensor(obs)
            if obs.dim() == 1: obs = obs.unsqueeze(0)
            
            context = []
            episode_reward = 0.0

            for _ in range(self.max_episode_length):
                
                latent_z = self.get_latent_from_context(context)
                    
                # Get action from policy
                
                action, log_prob, _, _ = self.model.actor_forward(obs, latent_z)
                
                # step the env
                
                action_np = action.detach().cpu().numpy()[0]
                
                next_obs, reward, done, _ = self.env.step(action_np)
                
                episode_reward += reward.item()
                
                # correct shapes
                
                next_obs, reward, done = self.shape_corrector(next_obs, reward, done)
                
                # save to buffer
                
                self.buffer.add(obs.squeeze(0), action.squeeze(0), log_prob.squeeze(0), reward.squeeze(0), done.squeeze(0), next_obs.squeeze(0))

                # save to context
                
                context.append({
                    
                    'states': obs,
                    'actions': action,
                    'rewards': reward,
                    'next_states': next_obs    
                })
                
                obs = next_obs
                
                if done:
                    
                    break
               
            # log reward    
                
            #print(f"Task done. Total reward: {episode_reward}")
                
    def build_context_tensor(self, context):
    
        return {
            
            'states': torch.cat([e['states'] for e in context], dim = 0),
            'actions': torch.cat([e['actions'] for e in context], dim = 0),
            'rewards': torch.cat([e['rewards'] for e in context], dim = 0),
            'next_states': torch.cat([e['next_states'] for e in context], dim = 0)
        }
        
    def get_latent_from_context(self, context):
        
        if context:
                    
            ctx = self.build_context_tensor(context)
                    
            latent_z, _, _ = self.encoder(ctx['states'], ctx['actions'], ctx['rewards'], ctx['next_states'])
                    
        else:
                    
            latent_z = torch.zeros((1, latent_dim)).to(device)

        return latent_z
    
    def shape_corrector(self, next_obs, reward, done):
        
        next_obs = safe_tensor(next_obs).unsqueeze(0)
        reward = safe_tensor(reward)
        done = safe_tensor(done)
        reward = reward.view(-1, 1)
        done = done.view(-1, 1)
        
        return next_obs, reward, done


### S E T U P 

In [18]:
max_episode_length = 512

META_RUNNER = meta_episode_runner(max_episode_length)


### L O S S - F U N C

In [19]:
class loss_func:
    
    def __init__(self, gamma, tau, entropy_scalar, kl_coef, action_dim = action_dim, PEARL = PEARL, PEARL_ENCODER = PEARL_ENCODER, PEARL_OPTIMIZER = PEARL_OPTIMIZER, PEARL_SCHEDULER = PEARL_SCHEDULER, TARGET_CRITIC = TARGET_CRITIC, buffer = buffer):
        
        # network
        
        self.encoder = PEARL_ENCODER
        self.pearl = PEARL
        self.target_critic = TARGET_CRITIC
        
        # optimizer and scheduler
        
        self.pearl_optimizer = PEARL_OPTIMIZER
        self.pearl_scheduler = PEARL_SCHEDULER
        
        # hyper params
        
        self.gamma = gamma
        self.tau = tau
        self.entropy_scalar = entropy_scalar
        self.kl_coef = kl_coef
    
        
        # buffer
        
        self.buffer = buffer
        
        # auto alpha
        
        self.log_alpha = torch.nn.Parameter(torch.tensor(np.log(0.2), device = device, requires_grad = True))
        self.target_entropy = - action_dim * self.entropy_scalar
        self.alpha_optimizer = optim.AdamW([self.log_alpha], lr = 3e-4, weight_decay = 0)
        
    def compute_alpha(self):
        
        alpha = self.log_alpha.exp().detach()
        alpha = alpha.clamp(min = 1e-3)
        
        return alpha
    
    def soft_update(self, target, source):
        
        with torch.no_grad():
        
            for param, target_param in zip(source.parameters(), target.parameters()):
                
                target_param.data.copy_(param.data * self.tau + (1 - self.tau) * target_param.data)
                
    def critic_loss(self, q1, q2, target):
        
        loss_1 = F.mse_loss(q1, target)
        loss_2 = F.mse_loss(q2, target)
        
        critic_loss = loss_1 + loss_2
        
        return critic_loss
    
    def actor_loss(self, detached_q1, detached_q2, alpha, old_log_probs):
        
        q_pi = torch.min(detached_q1, detached_q2)
        
        actor_loss = - (alpha * old_log_probs - q_pi).mean()
        
        return actor_loss
    
    def alpha_loss(self, detached_old_log_probs):
        
        alpha_loss = (self.log_alpha.exp() * (detached_old_log_probs - self.target_entropy)).mean()
        
        return alpha_loss
    
    def update(self, batch_size, fixed_length):
        
        batch = self.buffer.sample(batch_size, fixed_length)
        
        # sample
        
        states = batch['states']
        actions = batch['actions']
        rewards = batch['rewards'] 
        dones = batch['dones']
        next_states = batch['next_states']
        
        # compute target
        
        latent_z_actor, _, _ = self.encoder.forward(states, actions, rewards, next_states, reduce = False)
        
        latent_z_critic = latent_z_actor.detach().clone()
        
        # ''' i think we should also detach these transitions and rely on the gradients of the latent z or sepearate the paths clearly'''
                              
        with torch.no_grad():
            
            alpha = self.compute_alpha()
            
            next_action, next_log_probs, _, _ = self.pearl.actor_forward(next_states, latent_z_critic)
            
            target_1, target_2 = self.target_critic.forward(next_states, next_action, latent_z_critic)
            
            target_val = torch.min(target_1, target_2)
            
            target = rewards + self.gamma * (1 - dones) * (target_val - alpha * next_log_probs)
            
        # compute current vals
            
        q1, q2 = self.pearl.critic_forward(states, actions.detach(), latent_z_critic)
            
        # cal critic loss
        
        critic_loss = self.critic_loss(q1 = q1, q2 = q2, target = target)       
        
        # cal actor loss
        
        new_actions, log_probs, _, _ = self.pearl.actor_forward(states, latent_z_actor.detach())
        
        with torch.no_grad():
            
            actor_alpha = self.compute_alpha()
        
            detached_q1, detached_q2 = self.pearl.critic_forward(states, new_actions, latent_z_actor)
        
        actor_loss = self.actor_loss(detached_q1, detached_q2, actor_alpha, log_probs)
        
        # total loss
        
        pearl_loss = actor_loss + critic_loss
                
        # total agent loss
        
        self.pearl_optimizer.zero_grad()
        pearl_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.pearl.parameters(), max_norm = 0.5)
        self.pearl_optimizer.step()
        self.pearl_scheduler.step()
        
        # cal alpha loss
        
        with torch.no_grad():
            
            _, log_probs_alpha, _, _ = self.pearl.actor_forward(states, latent_z_actor.detach())
                
        alpha_loss = self.alpha_loss(log_probs_alpha)
        
        # update alpha
        
        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        torch.nn.utils.clip_grad_norm_([self.log_alpha], max_norm = 0.5)
        self.alpha_optimizer.step()
        
        # soft update
        
        self.soft_update(target = self.target_critic.critic_1, source = self.pearl.critic_mlp)
        self.soft_update(self.target_critic.critic_2, self.pearl.critic_mlp_2)
        self.soft_update(self.target_critic.critic_head_1, self.pearl.critic_head)
        self.soft_update(self.target_critic.critic_head_2, self.pearl.critic_head_2) 
        
        return actor_loss.item(), critic_loss.item(), alpha.item()


### S E T U P

In [20]:
# hyper params

gamma = 0.99
tau = 0.005
entropy_scalar = 1.5
kl_coef = 0.01

# setup

LOSS_FUNCTION = loss_func(gamma, tau, entropy_scalar, kl_coef)


### T R A I N I N G

In [None]:
from os import write


def train_model(num_tasks, epochs, mini_batch, batch_size, fixed_length, META_RUNNER = META_RUNNER, LOSS_FUNCTION = LOSS_FUNCTION, buffer = buffer):
    
    PEARL.train()
    PEARL_ENCODER.train()
    
    for epoch in range(epochs):
        
        total_actor_loss, total_critic_loss = 0.0, 0.0
        
        for _ in range(mini_batch):
            
            buffer.clear()
            
            META_RUNNER.run(num_tasks)
            
            actor_loss, critic_loss, alpha = LOSS_FUNCTION.update(batch_size, fixed_length)
            
            total_actor_loss += actor_loss
            total_critic_loss += critic_loss
            
        avg_actor_loss = total_actor_loss / mini_batch
        avg_critic_loss = total_critic_loss / mini_batch
        
        writer.add_scalar('Actor loss', avg_actor_loss, epoch)
        writer.add_scalar('Critic loss', avg_critic_loss, epoch)
        
        writer.flush()
        
        
        print(f'epoch: {epoch} | avg actor loss: {avg_actor_loss:.3f} | avg critic loss: {avg_critic_loss:.3f}')


In [22]:
train_model(num_tasks = 10, epochs = 10, mini_batch = 64, batch_size = 256, fixed_length = 256)


epoch: 0 | avg actor loss: -11.518952599726617 | avg critic loss: 20.772311560809612
epoch: 1 | avg actor loss: -16.364063054323196 | avg critic loss: 10.611696988344193
epoch: 2 | avg actor loss: -18.82680258154869 | avg critic loss: 7.727895721793175
epoch: 3 | avg actor loss: -22.1043761074543 | avg critic loss: 3.7576239332556725
epoch: 4 | avg actor loss: -23.777189791202545 | avg critic loss: 2.1902577728033066
epoch: 5 | avg actor loss: -25.13528737425804 | avg critic loss: 1.835844200104475
epoch: 6 | avg actor loss: -26.913690745830536 | avg critic loss: 1.4302853904664516
epoch: 7 | avg actor loss: -28.622118592262268 | avg critic loss: 1.1147063923999667
epoch: 8 | avg actor loss: -29.77080574631691 | avg critic loss: 0.9762753564864397
epoch: 9 | avg actor loss: -31.36554816365242 | avg critic loss: 0.8422545026987791
