# RL²

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import math

from torch.utils.tensorboard import SummaryWriter

import numpy as np

import gymnasium as gym


### D E V I C E

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {device}')


Device: cuda


### L O G G I N G

In [3]:
writer = SummaryWriter(log_dir = './runs/ppo_meta_rl')


### H E L P E R

In [4]:
def safe_tensor(x):
    
    return x if torch.is_tensor(x) else torch.tensor(x, dtype = torch.float32).to(device)

def safe_stack(x):
    
    return torch.stack(x).to(device)


### M E T A - E N V

In [5]:
class Walker2DMetaEnv:
    def __init__(self, task_batch=5):
        self.base_env = gym.make("Walker2d-v5")
        self.task_batch = task_batch
        self.target_velocity = 1.0

    def sample_tasks(self, num_tasks):
        self.tasks = []
        for _ in range(num_tasks):
            gravity = np.random.uniform(5.0, 15.0)
            torso_mass = np.random.uniform(1.0, 5.0)
            target_velocity = np.random.uniform(0.5, 3.0)
            self.tasks.append((gravity, torso_mass, target_velocity))
            
        return self.tasks
    

    def set_task(self, task):
        
        raw_env = self.base_env.unwrapped
        raw_env.model.opt.gravity[-1] = -task[0]
        raw_env.model.body_mass[1] = task[1]
        self.target_velocity = task[2]

    def reset(self, task_idx=0):
        self.set_task(self.tasks[task_idx])
        obs, _ = self.base_env.reset()
        return obs

    def step(self, action):
        obs, reward, terminated, truncated, info = self.base_env.step(action)
        raw_env = self.base_env.unwrapped
        vel = raw_env.data.qvel[0]
        reward -= 0.5 * abs(vel - self.target_velocity)
        done = terminated or truncated
        return obs, reward, done, info
    
    def get_numbers(self):
        
        state_dim = self.base_env.observation_space.shape[0]
        action_dim = self.base_env.action_space.shape[0]
        max_action = self.base_env.action_space.high[0]
        rewards_dim = 1
        
        return state_dim, action_dim, max_action, rewards_dim

    def close(self):
        self.base_env.close()


### T E S T - E N V

In [6]:
env = Walker2DMetaEnv(task_batch=3)

env.sample_tasks(num_tasks = 4)

obs = env.reset(task_idx=0)

print("Initial obs shape:", np.shape(obs))

state_dim, action_dim, max_action, rewards_dim = env.get_numbers()

print(f'state dim: {state_dim} |'
      f' action dim: {action_dim} |'
      f' max action: {max_action} |'
      f' rewards dim: {rewards_dim}')


Initial obs shape: (17,)
state dim: 17 | action dim: 6 | max action: 1.0 | rewards dim: 1


### A S S E M B L Y

In [7]:
head_1 = 128
head_2 = 256
head_3 = 256
head_4 = 256
head_5 = 128

hidden_size = 128
hidden_size_2 = 256


### F E A T U R E 

In [8]:
class Feature_Extractor(nn.Module):
    
    def __init__(self, input_dim, output_dim, hidden_size = hidden_size, hidden_size_2 = hidden_size_2):
        super(Feature_Extractor, self).__init__()
        
        self.extract = nn.Sequential(
            
            nn.Linear(input_dim, hidden_size),
            nn.SiLU(),
            
            
            nn.LayerNorm(hidden_size),
            nn.Linear(hidden_size, hidden_size_2),
            nn.SiLU(),
            
            nn.LayerNorm(hidden_size_2),
            nn.Linear(hidden_size_2, hidden_size),
            nn.SiLU(),
            
            nn.Linear(hidden_size, output_dim),
            nn.SiLU()
        )
        
    def forward(self, x):
        
        return self.extract(x)


### R E C U R R E N T -  P O L I C Y

In [9]:
class recurrent_policy(nn.Module):
    
    def __init__(self, state_dim = state_dim, action_dim = action_dim, rewards_dim = rewards_dim, head_1 = head_1, head_2 = head_2, head_3 = head_3, head_4 = head_4, head_5 = head_5, max_action = max_action):
        super(recurrent_policy, self).__init__()
        
        # feature
        
        self.feature = Feature_Extractor(state_dim + action_dim + rewards_dim, head_1)
        
        # norm
        
        self.norm = nn.LayerNorm(head_1)
        
        # seperate LSTM Layers
        
        self.actor_lstm = nn.LSTM(head_1, head_2, num_layers = 2, batch_first = True)
        
        self.critic_lstm = nn.LSTM(head_1, head_2, num_layers = 2, batch_first = True)
        
        # mlp
        
        def create_mlp():
            
            process = nn.Sequential(
                
                nn.Linear(head_2, head_3),
                nn.SiLU(),
                
                nn.LayerNorm(head_3),
                nn.Linear(head_3, head_4),
                nn.SiLU(),
                
                nn.Linear(head_4, head_5),
                nn.SiLU()
            )
            
            return process
            
        # post feature
        
        self.actor_post_feature = create_mlp()
        self.critic_post_feature = create_mlp()
        
        # mu and log head
        
        self.mu = nn.Linear(head_5, action_dim)
        self.log_std = nn.Linear(head_5, action_dim)
        
        # critic head
        
        self.critic_head = nn.Linear(head_5, 1)
        
        # max action 
        
        self.max_action = max_action
        
        # Stabilization
        
        self.apply(self.init_weights)
        
        
    def init_weights(self, module):
            
        if isinstance(module, nn.Linear):
                
            nn.init.kaiming_normal_(module.weight, a = 0, nonlinearity = 'relu')
                
            if module.bias is not None:
                    
                nn.init.zeros_(module.bias)
                    
                    
        elif isinstance(module, nn.LSTM):
                
            for name, param in module.named_parameters():
                    
                if 'weights_ih' in name:
                        
                    nn.init.kaiming_uniform_(param, math.sqrt(5))
                        
                elif 'weights_hh' in name:
                        
                    nn.init.orthogonal_(param)
                        
                elif 'bias' in name:
                        
                    nn.init.zeros_(param)
        
    def forward(self, state, prev_action, prev_reward, actor_memory = None, critic_memory = None):
        
        # cat
        
        cat = torch.cat([state, prev_action, prev_reward], dim = -1)
        
        # feature 
        
        feature = self.feature(cat)
        
        # norm
        
        norm = self.norm(feature)
        
        # lstm layers
        
        if norm.dim() == 2:
            norm = norm.unsqueeze(1)
        
            actor_lstm_out, h_a = self.actor_lstm(norm, actor_memory)
            critic_lstm_out, h_c = self.critic_lstm(norm, critic_memory)
            
            actor_lstm_out = actor_lstm_out.squeeze(1)
            critic_lstm_out = critic_lstm_out.squeeze(1)
            
        else:
            
            actor_lstm_out, h_a = self.actor_lstm(norm, actor_memory)
            critic_lstm_out, h_c = self.critic_lstm(norm, critic_memory)
        
        # post feature
        
        actor_post_feature = self.actor_post_feature(actor_lstm_out)
        
        critic_post_feature = self.critic_post_feature(critic_lstm_out)
        
        # critic head
        
        critic_val = self.critic_head(critic_post_feature)
        
        # actor head
        
        mu = self.mu(actor_post_feature)
        log_std = self.log_std(actor_post_feature)
        
        log_std = torch.clamp(log_std, -10, 2)
        std = torch.exp(log_std)
        
        dist = torch.distributions.Normal(mu, std)
        
        z = dist.rsample()
        tanh_z = torch.tanh(z)
        action = tanh_z * self.max_action
        
        log_prob = dist.log_prob(z).sum(dim = -1, keepdim = True)
        
        squash = (1 - tanh_z.pow(2) + 1e-6).log().sum(dim = -1, keepdim = True)
        
        log_prob = log_prob - squash
        
        return action, log_prob, mu, log_std, critic_val, h_a, h_c
    
    
    def init_hidden(self, ):
        
        h_a = torch.zeros(2, 1, self.actor_lstm.hidden_size).to(next(self.parameters()))
        c_a = torch.zeros(2, 1, self.actor_lstm.hidden_size).to(next(self.parameters()))
        
        h_c = torch.zeros(2, 1, self.critic_lstm.hidden_size).to(next(self.parameters()))
        c_c = torch.zeros(2, 1, self.critic_lstm.hidden_size).to(next(self.parameters()))
        
        return (h_a, c_a), (h_c, c_c)


### S E T U P 

In [10]:
RECURRENT_NETWORK = recurrent_policy().to(device)

print(RECURRENT_NETWORK)


recurrent_policy(
  (feature): Feature_Extractor(
    (extract): Sequential(
      (0): Linear(in_features=24, out_features=128, bias=True)
      (1): SiLU()
      (2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (3): Linear(in_features=128, out_features=256, bias=True)
      (4): SiLU()
      (5): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (6): Linear(in_features=256, out_features=128, bias=True)
      (7): SiLU()
      (8): Linear(in_features=128, out_features=128, bias=True)
      (9): SiLU()
    )
  )
  (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (actor_lstm): LSTM(128, 256, num_layers=2, batch_first=True)
  (critic_lstm): LSTM(128, 256, num_layers=2, batch_first=True)
  (actor_post_feature): Sequential(
    (0): Linear(in_features=256, out_features=256, bias=True)
    (1): SiLU()
    (2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    (3): Linear(in_features=256, out_features=256, bias=True)
    (4): SiLU()
    (5): L

### B U F F E R

In [11]:
class roller_buffer:
    
    def __init__(self, max_episodes):
        
        self.max_episodes = max_episodes
        self.episodes = []
        self.current_episode = []
        
        
    def add(self, state, action, log_prob, reward, done, h, c, next_state):
        
        # convert to tensor
        
        state = safe_tensor(state)
        action = action.detach()
        log_prob = log_prob.detach()
        reward = safe_tensor(reward)
        done = safe_tensor(done)
        next_state = safe_tensor(next_state)
        
        h_a = h[0]
        c_a = h[1]
        h_c = c[0]
        c_c = c[1]

        self.current_episode.append({
            
            'states': state,
            'actions': action,
            'log_probs': log_prob,
            'rewards': reward,
            'dones': done,
            'h_a': h_a,
            'c_a': c_a,
            'h_c': h_c,
            'c_c': c_c,
            'next_states': next_state
        })
        
        if done.item() == 1:
            
            self.episodes.append(self.current_episode)
            self.current_episode = []
            
            if len(self.episodes) > self.max_episodes:
                
                self.episodes.pop(0)
        
    def sample(self, seq_length, batch_size):
        
        segments = []
        masks = []
        
        for _ in range(batch_size):
            
            ep = np.random.choice(self.episodes)
            
            if len(ep) < seq_length:
                
                padded, mask = self.pad_episode(ep, seq_length)
                segments.append(padded)
                masks.append(mask)
                
            else:
                
                start = np.random.randint(0, len(ep) - seq_length + 1)
                segment = ep[start: start + seq_length]
                mask = torch.ones(seq_length, dtype = torch.float32)
                segments.append(segment)
                masks.append(mask)       
                
        batch = {
            
            'states': safe_stack([safe_stack([s['states'] for s in seg]) for seg in segments]),
            'actions': safe_stack([safe_stack([s['actions'] for s in seg]) for seg in segments]),
            'log_probs': safe_stack([safe_stack([s['log_probs'] for s in seg]) for seg in segments]),
            'rewards': safe_stack([safe_stack([s['rewards'] for s in seg]) for seg in segments]),
            'dones': safe_stack([safe_stack([s['dones'] for s in seg]) for seg in segments]),
            'h_a': torch.stack([torch.stack([s['h_a'] for s in seg], dim=0) for seg in segments], dim=1),
            'c_a': torch.stack([torch.stack([s['c_a'] for s in seg], dim=0) for seg in segments], dim=1),
            'h_c': torch.stack([torch.stack([s['h_c'] for s in seg], dim=0) for seg in segments], dim=1),
            'c_c': torch.stack([torch.stack([s['c_c'] for s in seg], dim=0) for seg in segments], dim=1),
            'next_states': safe_stack([safe_stack([s['next_states'] for s in seg]) for seg in segments]),
            'masks': safe_stack(masks)
        }
        
        return batch
                
    def pad_episode(self, ep, seq_length):
        
        pad_length = seq_length - len(ep)
        
        last_step = ep[-1]
        
        pad_step = {}
        
        for k, v in last_step.items():
            if torch.is_tensor(v):
                pad_step[k] = torch.zeros_like(v)
            elif isinstance(v, tuple):  # handle LSTM hidden states
                pad_step[k] = tuple(torch.zeros_like(t) for t in v)
            else:
                pad_step[k] = v 
                
        mask = torch.cat([
            
            torch.ones(len(ep), dtype = torch.float32),
            torch.zeros(pad_length, dtype = torch.float32)
        ])
        
        return ep + [pad_step] * pad_length, mask
        
        
    def clear(self):
        
        ''' a logic to clear enough data to do no make the agent to suffocate and explode my ram'''
        
        self.episodes.clear()
        self.current_episode.clear()


### S E T U P


In [12]:
### setup

max_episodes = 10

buffer = roller_buffer(max_episodes)


### M E T A - E P I S O D E - R U N N E R

In [13]:
class meta_episode_runner:
    
    def __init__(self, max_episode_length, buffer = buffer, agent = RECURRENT_NETWORK, env = env):
        
        self.agent = agent
        self.buffer = buffer
        self.env = env
        self.max_episode_length = max_episode_length
        
        
    def run(self, num_tasks):
        
        tasks = self.env.sample_tasks(num_tasks)
        
        for task in tasks:
            
            self.env.set_task(task)
            obs = self.env.reset()
            obs = safe_tensor(obs)
            if obs.dim() == 1: obs = obs.unsqueeze(0)
            
            # Init LSTM states
            
            (h_a, c_a), (h_c, c_c) = self.agent.init_hidden()

            
            prev_action = torch.zeros(1, action_dim).to(device)
            prev_reward = torch.zeros(1, 1).to(device)
            
            # episode data
            
            for t in range(self.max_episode_length):
                
                with torch.no_grad():
                    
                    action, log_prob, _, _, _, next_h, next_c = self.agent(obs, prev_action, prev_reward, actor_memory = (h_a, c_a), critic_memory = (h_c, c_c))
                    
                action_np = action.cpu().numpy()[0]
                    
                next_obs, reward, done, _ = self.env.step(action_np)
                next_obs = safe_tensor(next_obs)
                if next_obs.dim() == 1: next_obs = next_obs.unsqueeze(0)
                h = (h_a, c_a) 
                c = (h_c, c_c)
                    
                self.buffer.add(obs.squeeze(0), action.squeeze(0), log_prob.squeeze(0), reward, done, h, c, next_obs.squeeze(0))
                    
                obs = next_obs
                h_a, c_a = next_h
                h_c, c_c = next_c
                     
                if done:
                        
                    break
                    

### S E T U P

In [14]:
max_episode_length = 512


META_RUNNER = meta_episode_runner(max_episode_length)


### O P T I M I Z E R 

In [15]:
### shared feature extractor

shared_lr = 1e-4
actor_lr = 3e-4
critic_lr = 5e-4
T_max = 10

# params

actor_param = list(RECURRENT_NETWORK.actor_lstm.parameters()) + \
              list(RECURRENT_NETWORK.actor_post_feature.parameters()) + \
              list(RECURRENT_NETWORK.mu.parameters()) + \
              list(RECURRENT_NETWORK.log_std.parameters())
              
critic_param = list(RECURRENT_NETWORK.critic_lstm.parameters()) + \
               list(RECURRENT_NETWORK.critic_post_feature.parameters()) + \
               list(RECURRENT_NETWORK.critic_head.parameters())
               
# optimizer

OPTIMIZER = optim.AdamW([
    
    {'params': RECURRENT_NETWORK.feature.parameters(), 'lr': shared_lr},
    {'params': actor_param, 'lr': actor_lr},
    {'params': critic_param, 'lr': critic_lr}
    
], weight_decay = 0)

# Scheduler

SCHEDULER = optim.lr_scheduler.CosineAnnealingLR(OPTIMIZER, T_max, eta_min = 1e-6)


### L O S S - F U N C

In [16]:
class loss_func:
    
    def __init__(self, gamma, gae_lam, entropy_coef, clip_epsilon, value_coef, RECURRENT_NETWORK = RECURRENT_NETWORK, buffer = buffer, OPTIMIZER = OPTIMIZER, SCHEDULER = SCHEDULER):
        
        # network
        
        self.network = RECURRENT_NETWORK
        
        # optimizer
        
        self.optimizer = OPTIMIZER
        self.scheduler = SCHEDULER
        
        # buffer
        
        self.buffer = buffer
        
        # hyper params
        
        self.gamma = gamma
        self.gae_lam = gae_lam
        self.entropy_coef = entropy_coef
        self.clip_epsilon = clip_epsilon
        self.value_coef = value_coef
        
    def compute_gae(self, rewards, dones, value, last_value):
        
        values = torch.cat([value, last_value], dim = 1).to(device)
        
        gae = 0
        advantages = []
        
        for step in reversed(range(rewards.shape[1])):
            
            delta = rewards[:, step] + self.gamma * (1 - dones[:, step]) * values[:, step + 1] - values[:, step]
            gae = delta + self.gamma * self.gae_lam * (1 - dones[:, step]) * gae
            
            advantages.insert(0, gae)
        
        advantages = safe_stack(advantages)
        
        returns = advantages + value
        
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-7)
        returns = (returns - returns.mean()) / (returns.std() + 1e-8)
        
        return advantages.detach(), returns.detach()
        
    def critic_loss(self, value, returns, mask):
        
        error = F.mse_loss(value, returns)
    
        if mask is not None:
            
            error = error * mask
            error = error.sum() / mask.sum()
            
        return error
    
    def compute_surrogate_loss(self, old_log_probs, log_probs, advantages):
        
        ratio = torch.exp(log_probs - old_log_probs)
        
        surr_1 = (ratio * advantages)
        surr_2 = torch.clamp(ratio, 1 - self.clip_epsilon, 1 + self.clip_epsilon) * advantages
        
        surrogate_loss = -torch.min(surr_1, surr_2).mean()
        
        return surrogate_loss
    
    def shape_corrector(self, rewards, dones, h_a, h_c, c_a, c_c):
        
        if rewards.dim() == 2: rewards = rewards.unsqueeze(2)
        if dones.dim() == 2: dones = dones.unsqueeze(2)
        if h_a.dim() == 5: h_a = h_a.squeeze(3)
        if h_c.dim() == 5: h_c = h_c.squeeze(3)
        if c_a.dim() == 5: c_a = c_a.squeeze(3)
        if c_c.dim() == 5: c_c = c_c.squeeze(3)
        
        h_a = h_a[:, 0].permute(1, 0, 2).contiguous()
        c_a = c_a[:, 0].permute(1, 0, 2).contiguous()
        h_c = h_c[:, 0].permute(1, 0, 2).contiguous()
        c_c = c_c[:, 0].permute(1, 0, 2).contiguous()
        
        return rewards, dones, h_a, h_c, c_a, c_c
    
    def update(self, seq_length, batch_size):
        
        # sample the batch 
        
        batch = self.buffer.sample(seq_length, batch_size)
        
        states = batch['states']
        actions = batch['actions']
        old_log_probs = batch['log_probs']
        rewards = batch['rewards']
        dones = batch['dones']
        h_a = batch['h_a']
        c_a = batch['c_a']
        h_c = batch['h_c']
        c_c = batch['c_c']
        masks = batch['masks']
        next_states = batch['next_states']
        
        # shape check
        
        rewards, dones, h_a, h_c, c_a, c_c = self.shape_corrector(rewards, dones, h_a, h_c, c_a, c_c)
        
        rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-7)
        
        final_state = next_states[:, -1:]
        final_prev_action = actions[:, -1:]
        
        final_prev_reward = rewards[:, -1:]
        
        h = (h_a, c_a)
        c = (h_c, c_c)
        
        # compute value, current log probs, last_value
        
        with torch.no_grad():
            
            _, _, _, _, last_value, _, _ = self.network(final_state, final_prev_action, final_prev_reward, h, c)
            
        _, log_probs, mu, log_std, values, _, _ = self.network(states, actions, rewards, h, c)
        
        # compute GAE
        
        advantages, returns = self.compute_gae(rewards, dones, values, last_value)
        
        # compute critic loss
        
        critic_loss = self.critic_loss(values, returns, masks)
        critic_loss = critic_loss * self.value_coef
        
        # compute policy loss
        
        policy_loss = self.compute_surrogate_loss(old_log_probs, log_probs, advantages)
        
        # entropy loss
        
        dist = torch.distributions.Normal(mu, torch.exp(log_std))
        entropy = dist.entropy().mean()
        
        total_loss = policy_loss- self.entropy_coef * entropy + critic_loss
        
        self.optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.network.parameters(), max_norm = 0.1)
        self.optimizer.step()
        self.scheduler.step()
        
        
        return total_loss.item(), policy_loss.item(), critic_loss.item()       
            

### S E T U P

In [17]:
# hyper params

gamma = 0.99
gae_lam = 0.95
entropy_coef = 0.01
value_coef = 0.25
clip_epsilon = 0.2

# setup

LOSS_FUNCTION = loss_func(gamma, gae_lam, entropy_coef, clip_epsilon, value_coef)


### T R A I N I N G

In [18]:
def train(epochs, mini_batch, num_tasks, seq_length, batch_size, RECURRENT_NETWORK = RECURRENT_NETWORK, LOSS_FUNCTION = LOSS_FUNCTION, META_RUNNER = META_RUNNER):
    
    RECURRENT_NETWORK.train()
    
    for epoch in range(epochs):
    
        META_RUNNER.run(num_tasks)
    
        total_policy_loss, total_value_loss, total_agent_loss = 0.0, 0.0, 0.0
        
        for _ in range(mini_batch):
            
            total_loss, policy_loss, critic_loss = LOSS_FUNCTION.update(seq_length, batch_size)
    
            total_policy_loss += policy_loss
            total_value_loss += critic_loss
            total_agent_loss += total_loss
            
        avg_total_loss = total_agent_loss / mini_batch
        avg_policy_loss = total_policy_loss / mini_batch
        avg_value_loss = total_value_loss / mini_batch
        
        writer.add_scalar('Agent loss', avg_total_loss, epoch)
        writer.add_scalar('Policy loss', avg_policy_loss, epoch)
        writer.add_scalar('Value loss', avg_value_loss, epoch)
        
        writer.flush()
        
        print(f'epoch: {epoch} | policy loss: {avg_policy_loss:.3f} | avg value loss: {avg_value_loss:.3f}')
        

### S E T U P

In [19]:
epochs = 10
mini_batch = 64
batch_size = 256
seq_length = 256
num_tasks = 4

train(epochs, mini_batch, num_tasks, seq_length, batch_size)


  ep = np.random.choice(self.episodes)


epoch: 0 | policy loss: 37921.911 | avg value loss: 0.357
epoch: 1 | policy loss: 0.151 | avg value loss: 0.258
epoch: 2 | policy loss: 0.119 | avg value loss: 0.250
epoch: 3 | policy loss: 0.105 | avg value loss: 0.249
epoch: 4 | policy loss: 0.066 | avg value loss: 0.250
epoch: 5 | policy loss: 0.061 | avg value loss: 0.251
epoch: 6 | policy loss: 0.055 | avg value loss: 0.256
epoch: 7 | policy loss: 0.065 | avg value loss: 0.213
epoch: 8 | policy loss: 0.064 | avg value loss: 0.242
epoch: 9 | policy loss: 0.051 | avg value loss: 0.218
