# **R2D2**

In [1]:

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import gymnasium as gym

import math

import numpy as np

import random

from torch.utils.tensorboard import SummaryWriter


## **DEVICE**

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(f'Device: {device}')


Device: cuda


#### **LOGGING**

In [3]:
writer = SummaryWriter(log_dir = './runs/R2D2')


### **HELPER**

In [4]:
def safe_tensor(x):
    
    return x if torch.is_tensor(x) else torch.tensor(x, dtype = torch.float32).to(device)


### **META - ENV**

In [5]:
class meta_env:
    
    def __init__(self, env_name):
        
        self.base_env = gym.make(env_name)
        
    def sample_tasks(self, num_tasks):
        
        self.tasks = []
        
        for _ in range(num_tasks):
            
            gravity = np.random.uniform(5.0, 15.0)
            torso_mass = np.random.uniform(1.0, 5.0)
            target_velocity = np.random.uniform(0.5, 3.0)
            
            self.tasks.append((gravity, torso_mass, target_velocity))
            
        return self.tasks
    
    def set_task(self, task):
        
        self.base_env.unwrapped.model.opt.gravity[-1] = -task[0]
        self.base_env.unwrapped.model.body_mass[1] = task[1]
        self.target_velocity = task[2]
        
    def reset(self):
        
        obs = self.base_env.reset()
        
        if isinstance(obs, tuple):
            
            obs = obs[0]
            
        return obs
    
    def step(self, action):
        
        next_obs, reward, termination, timeout, info = self.base_env.step(action)
        
        vel = self.base_env.unwrapped.data.qvel[0]
        reward -= 0.5 * abs(vel - self.target_velocity)
        
        done = termination or timeout
        
        return next_obs, reward, done, info
    
    def get_number(self):
        
        state_dim = self.base_env.observation_space.shape[0]
        action_dim = self.base_env.action_space.shape[0]
        max_action = self.base_env.action_space.high[0]
        
        reward_dim = 1
        
        return state_dim, action_dim, max_action, reward_dim
    
    def close(self):
        
        self.base_env.close()


### **SET UP**

In [6]:
META_ENV = meta_env('Walker2d-v5')

tasks = META_ENV.sample_tasks(3)

for task in tasks:
    
    META_ENV.set_task(task)
    obs = META_ENV.reset()
    
state_dim, action_dim, max_action, reward_dim = META_ENV.get_number()
    
print(f'obs: {obs.shape}')
print()
print(f'state dim: {state_dim} | action dim: {action_dim} | max action: {max_action}')  
print()
print(f'Tasks: {tasks}')


obs: (17,)

state dim: 17 | action dim: 6 | max action: 1.0

Tasks: [(6.520299342867645, 1.729326276002701, 0.9123206413768365), (13.534885322830993, 1.5328676850417442, 0.941549114309512), (5.750121944926288, 2.1808702862732456, 1.6562976715008144)]


### **ASSEMBLY**

In [7]:

head_1 = 128
head_2 = 256
head_3 = 256
head_4 = 128

hidden_size = 64
hidden_size_2 = 128
hidden_size_3 = 256


### **HYPER X**

In [8]:
class hyper_x(nn.Module):
    
    def __init__(self, state_dim = state_dim, hidden_size = hidden_size, hidden_size_2 = hidden_size_2):
        super(hyper_x, self).__init__()
        
        # input dim
        
        input_dim = state_dim
        
        # hyper x
        
        self.hyper = nn.Sequential(
            
            nn.Linear(input_dim, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.SiLU(),
            
            nn.Linear(hidden_size, hidden_size_2),
            nn.LayerNorm(hidden_size_2),
            nn.SiLU(),
            
            nn.Linear(hidden_size_2, hidden_size_3),
            nn.LayerNorm(hidden_size_3),
            nn.SiLU(),
            
            
            nn.Linear(hidden_size_3, hidden_size_3),
            nn.LayerNorm(hidden_size_3),
            nn.SiLU(),
            
            nn.Linear(hidden_size_3, hidden_size_2),
            nn.LayerNorm(hidden_size_2),
            nn.SiLU()
        )
        
    def forward(self, state):
        
        hyper = self.hyper(state)
        
        return hyper
    

### **R2D2 DESIGN**

In [9]:
class r2d2(nn.Module):
    
    def __init__(self, head_1 = head_1, head_2 = head_2, head_3 = head_3, head_4 = head_4, max_action = max_action, action_dim = action_dim, hidden_size_2 = hidden_size_2):
        super(r2d2, self).__init__()
        
        # max action
        
        self.max_action = max_action
        
        # critic input
        
        critic_input = state_dim
        
        self.critic_pre_process = nn.Linear(critic_input, hidden_size_2)
        self.critic_pre_norm = nn.LayerNorm(hidden_size_2)
        
        # hyper x
        
        self.hyper = hyper_x()
        
        # norm
        
        self.norm = nn.LayerNorm(hidden_size_2)
        
        # process more
        
        self.process = nn.Linear(hidden_size_2, head_1)
        self.process_norm = nn.LayerNorm(head_1)
        
        # actor lstm
        
        self.actor_lstm = nn.LSTM(head_1, head_1, num_layers = 2, batch_first = True)
        
        # critic lstm
        
        self.critic_lstm = nn.LSTM(head_1, head_1, num_layers = 2, batch_first = True)
        
        # post norm 
        
        self.post_lstm_actor_norm = nn.LayerNorm(head_1)
        self.post_lstm_critic_norm = nn.LayerNorm(head_1)
        
        # actor mlp 
        
        self.actor_mlp = nn.Sequential(
            
            nn.Linear(head_1, head_2),
            nn.SiLU(),
            
            nn.Linear(head_2, head_3),
            nn.LayerNorm(head_3),
            nn.SiLU(),
            
            nn.Linear(head_3, head_4),
            nn.LayerNorm(head_4),
            nn.SiLU()
        )
        
        # critic mlp
        
        self.critic_mlp = nn.Sequential(
            
            nn.Linear(head_1, head_2),
            nn.SiLU(),
            
            nn.Linear(head_2, head_3),
            nn.LayerNorm(head_3),
            nn.SiLU(),
            
            nn.Linear(head_3, head_4),
            nn.LayerNorm(head_4),
            nn.SiLU()
        )
        
        # actor mu and log std head
        
        self.mu = nn.Linear(head_4, action_dim)
        self.log_std = nn.Linear(head_4, action_dim)
        
        # critic head
        
        self.critic_head = nn.Linear(head_4, 1)
        
        # normalization
        
        self.apply(self.init_weights)
        
        
    def init_weights(self, m):
        
        if isinstance(m, nn.Linear):
            
            nn.init.orthogonal_(m.weight)
            
            if m.bias is not None:
                
                nn.init.zeros_(m.bias)
                
        elif isinstance(m, nn.LSTM):
                
            for name, param in m.named_parameters():
                    
                if 'weight_ih' in name:
                        
                    nn.init.kaiming_uniform_(param, math.sqrt(5))
                        
                elif 'weight_hh' in name:
                        
                    nn.init.orthogonal_(param)
                        
                elif 'bias' in name:
                        
                    nn.init.zeros_(param)
                    
                    n = param.size(0)
                    start, end = n // 4, n // 2
                    param.data[start:end].fill_(1.0)
                    
    def actor_forward(self, state, actor_hidden_memory = None):
        
        # state -> hyper
        
        hyper = self.hyper(state)
        
        # norm
        
        hyper_norm = self.norm(hyper)
        
        # process
        
        process = self.process(hyper_norm)
        
        # process norm
        
        process_norm = self.process_norm(process)
        
        # lstm input
        
        if process_norm.dim() == 2:
            
            process_norm = process_norm.unsqueeze(1)
        
            actor_lstm_out, h_a = self.actor_lstm(process_norm, actor_hidden_memory)
            
            actor_lstm_out = actor_lstm_out.squeeze(1)
            
        else:
            
            actor_lstm_out, h_a = self.actor_lstm(process_norm, actor_hidden_memory)
            
        # post norm
        
        post_lstm_norm = self.post_lstm_actor_norm(actor_lstm_out)
            
        # actor mlp
        
        actor_mlp = self.actor_mlp(post_lstm_norm)
        
        # mu and log head
        
        mu = self.mu(actor_mlp)
        log_std = self.log_std(actor_mlp)
        log_std = torch.clamp(log_std, -20, 2)
        std = torch.exp(log_std)
        
        # action
        
        dist = torch.distributions.Normal(mu, std)
        z = dist.rsample()
        tanh_z = torch.tanh(z)
        action = tanh_z * self.max_action
        
        # log prob
        
        log_prob = dist.log_prob(z)
        squash = torch.log(1 - tanh_z.pow(2) + 1e-6)
        log_prob = log_prob - squash
        log_prob = log_prob.sum(dim = -1, keepdim = True)
        
        # entropy
        
        entropy = dist.entropy().sum(dim = -1)
        
        return action, log_prob, entropy, h_a
    
    def critic_forward(self, state, critic_hidden_memory):
        
        # pre process
        
        pre_process = self.critic_pre_process(state)
        
        # pre norm
        
        pre_norm = self.critic_pre_norm(pre_process)
        
        # critic lstm
        
        if pre_norm.dim() == 2: 
            
            pre_norm = pre_norm.unsqueeze(1)

            critic_lstm_out, h_c = self.critic_lstm(pre_norm, critic_hidden_memory)
            
            critic_lstm_out = critic_lstm_out.squeeze(1)
            
        else:
            
            critic_lstm_out, h_c = self.critic_lstm(pre_norm, critic_hidden_memory)
            
        # critic post norm
        
        critic_post_norm = self.post_lstm_critic_norm(critic_lstm_out)
        
        # critic mlp
        
        critic_mlp = self.critic_mlp(critic_post_norm)
        
        # critic head
        
        critic_head = self.critic_head(critic_mlp)
        
        return critic_head, h_c   
    
    def init_hidden(self):
        
        h_a = torch.zeros(self.actor_lstm.num_layers, 1, self.actor_lstm.hidden_size).to(next(self.parameters()))
        c_a = torch.zeros(self.actor_lstm.num_layers, 1, self.actor_lstm.hidden_size).to(next(self.parameters()))
        
        h_c = torch.zeros(self.actor_lstm.num_layers, 1, self.critic_lstm.hidden_size).to(next(self.parameters()))
        c_c = torch.zeros(self.actor_lstm.num_layers, 1, self.critic_lstm.hidden_size).to(next(self.parameters()))
        
        return (h_a, c_a), (h_c, c_c)     
    

### **SET UP**

In [10]:
# r2d2

R2D2_NETWORK = r2d2().to(device)

print(R2D2_NETWORK)


r2d2(
  (critic_pre_process): Linear(in_features=17, out_features=128, bias=True)
  (critic_pre_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (hyper): hyper_x(
    (hyper): Sequential(
      (0): Linear(in_features=17, out_features=64, bias=True)
      (1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (2): SiLU()
      (3): Linear(in_features=64, out_features=128, bias=True)
      (4): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (5): SiLU()
      (6): Linear(in_features=128, out_features=256, bias=True)
      (7): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (8): SiLU()
      (9): Linear(in_features=256, out_features=256, bias=True)
      (10): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (11): SiLU()
      (12): Linear(in_features=256, out_features=128, bias=True)
      (13): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (14): SiLU()
    )
  )
  (norm): LayerNorm((128,), eps=1e-05, elementwise_a

### **OPTIMIZER**

In [11]:
# lr

hyper_lr = 1e-5
actor_lr = 1e-4
critic_lr = 3e-4

# T max

warmup = 5
T_max = 10

# param

actor_param = list(R2D2_NETWORK.norm.parameters()) + \
              list(R2D2_NETWORK.process.parameters()) + \
              list(R2D2_NETWORK.process_norm.parameters()) + \
              list(R2D2_NETWORK.actor_lstm.parameters()) + \
              list(R2D2_NETWORK.post_lstm_actor_norm.parameters()) + \
              list(R2D2_NETWORK.actor_mlp.parameters()) + \
              list(R2D2_NETWORK.mu.parameters()) + \
              list(R2D2_NETWORK.log_std.parameters()) 
              
hyper_param = list(R2D2_NETWORK.hyper.parameters())

critic_param = list(R2D2_NETWORK.critic_pre_process.parameters()) + \
               list(R2D2_NETWORK.critic_pre_norm.parameters()) + \
               list(R2D2_NETWORK.critic_lstm.parameters()) + \
               list(R2D2_NETWORK.post_lstm_critic_norm.parameters()) + \
               list(R2D2_NETWORK.critic_mlp.parameters()) + \
               list(R2D2_NETWORK.critic_head.parameters())   

# optimizer

OPTIMIZER = optim.AdamW([
    
    {'params': critic_param, 'lr': critic_lr, 'weight_decay': 1e-6},
    {'params': hyper_param, 'lr': hyper_lr, 'weight_decay': 0},
    {'params': actor_param, 'lr': actor_lr, 'weight_decay': 0}
])

# scheduler

warmup_sch = optim.lr_scheduler.LinearLR(OPTIMIZER, 0.1, total_iters = warmup)
cosine_sch = optim.lr_scheduler.CosineAnnealingLR(OPTIMIZER, T_max = T_max - warmup, eta_min = 1e-6)

SCHEDULER = optim.lr_scheduler.SequentialLR(OPTIMIZER, [warmup_sch, cosine_sch], [warmup])


### **BUFFER**


In [12]:
class meta_buffer:
    
    def __init__(self, max_episodes):
        
        self.max_episodes = max_episodes
        self.episodes = []
        self.current_episode = []
        
    def add(self, state, action, log_prob, reward, done, next_state, actor_memory, critic_memory):
        
        # actor memory
        
        h_a = actor_memory[0]
        c_a = actor_memory[1]
        
        # critic memory
        
        h_c = critic_memory[0]
        c_c = critic_memory[1]
        
        # add to current episode
        
        step = {
            
            'states': safe_tensor(state),
            'actions': safe_tensor(action),
            'log_probs': safe_tensor(log_prob),
            'rewards': safe_tensor(reward),
            'dones': safe_tensor(done),
            'next_states': safe_tensor(next_state),
            'h_a': h_a,
            'c_a': c_a,
            'h_c': h_c,
            'c_c': c_c
            
        }
        
        self.current_episode.append(step)
        
        if safe_tensor(done).item() == 1:
            
            self.episodes.append(self.current_episode)
            self.current_episode = []
            
        if len(self.episodes) > self.max_episodes:
                
            self.episodes.pop(0)
                
                
    def sample(self, batch_size, unroll_len, burn_in):
        
        sampled_ep = random.sample(self.episodes, k = min(batch_size, len(self.episodes)))
        
        segments = []
        
        for ep in sampled_ep:
            
            if len(ep) < unroll_len + burn_in:
                
                seg = self.padding(ep, burn_in, unroll_len)
            
            else:
                
                start_idx = random.randint(0, len(ep) - (unroll_len + burn_in))
                
                seg = ep[start_idx : start_idx + unroll_len + burn_in]
            
            segments.append(seg)
            
        
        def safe_stack(x, seq):
            
            return torch.stack([torch.stack([s[x] for s in seg]) for seg in seq]).to(device)
        
        burn_in_batch = {
            
            'states': safe_stack('states', [seg[:burn_in] for seg in segments]),
            'actions': safe_stack('actions', [seg[:burn_in] for seg in segments]),
            'log_probs': safe_stack('log_probs', [seg[:burn_in] for seg in segments]),
            'rewards': safe_stack('rewards', [seg[:burn_in] for seg in segments]),
            'dones': safe_stack('dones', [seg[:burn_in] for seg in segments]),
            'next_states': safe_stack('next_states', [seg[:burn_in] for seg in segments]),

            'h_a': torch.stack([seg[0]['h_a'] for seg in segments]).to(device),
            'c_a': torch.stack([seg[0]['c_a'] for seg in segments]).to(device),
            'h_c': torch.stack([seg[0]['h_c'] for seg in segments]).to(device),
            'c_c': torch.stack([seg[0]['c_c'] for seg in segments]).to(device)
            
        }

        # training portion (after burn-in)
        
        train_batch = {
            
            'states': safe_stack('states', [seg[burn_in:] for seg in segments]),
            'actions': safe_stack('actions', [seg[burn_in:] for seg in segments]),
            'log_probs': safe_stack('log_probs', [seg[burn_in:] for seg in segments]),
            'rewards': safe_stack('rewards', [seg[burn_in:] for seg in segments]),
            'dones': safe_stack('dones', [seg[burn_in:] for seg in segments]),
            'next_states': safe_stack('next_states', [seg[burn_in:] for seg in segments]),

            # hidden states here are NOT reset — burn-in already set them up
            'h_a': torch.stack([seg[burn_in]['h_a'] for seg in segments]).to(device),
            'c_a': torch.stack([seg[burn_in]['c_a'] for seg in segments]).to(device),
            'h_c': torch.stack([seg[burn_in]['h_c'] for seg in segments]).to(device),
            'c_c': torch.stack([seg[burn_in]['c_c'] for seg in segments]).to(device)
            
        }
        
        return burn_in_batch, train_batch
        
    def padding(self, ep, burn_in, unroll_len):
        
        pad_length = (burn_in + unroll_len) - len(ep) 
        
        last_step = ep[-1]
        
        pad_step = {}
        
        for k, v in last_step.items():
            
            if torch.is_tensor(v):
                
                pad_step[k] = v
                
            else:
                
                pad_step[k] = v.clone()
                
        return ep + [pad_step] * pad_length
        


### **SET UP**

In [13]:
# max ep

max_episodes = 500

# setup

META_BUFFER = meta_buffer(max_episodes)


### **META RUNNER**

In [14]:
class meta_runner:
    
    def __init__(self, max_episode_range, R2D2_NETWORK = R2D2_NETWORK, META_BUFFER = META_BUFFER, META_ENV = META_ENV):
        
        self.network = R2D2_NETWORK
        self.buffer = META_BUFFER
        self.env = META_ENV
        self.max_episode_range = max_episode_range
        
    def run(self, num_tasks):
        
        tasks = self.env.sample_tasks(num_tasks)
        
        for task in tasks:
            
            self.env.set_task(task)
            
            obs = self.env.reset()
            
            obs = safe_tensor(obs).unsqueeze(0)
            
            hidden_actor_memory, hidden_critic_memory = self.network.init_hidden()
            
            for _ in range(self.max_episode_range):
                
                action, log_prob, _, actor_memory = self.network.actor_forward(obs, hidden_actor_memory)
                
                _, critic_memory = self.network.critic_forward(obs, hidden_critic_memory)

                action_np = action.detach().cpu().numpy()[0]

                next_obs, reward, done, _ = self.env.step(action_np)
                
                next_obs = safe_tensor(next_obs).unsqueeze(0)
                
                self.buffer.add(obs.squeeze(0), action.squeeze(0), log_prob.squeeze(0), [reward], done, next_obs.squeeze(0), hidden_actor_memory, hidden_critic_memory)
                
                obs = next_obs
                hidden_actor_memory = actor_memory
                hidden_critic_memory = critic_memory
                
                if done:
                    
                    break


### **SET UP**

In [15]:
# meta runner

max_episode_range = 512

# setup

META_RUNNER = meta_runner(max_episode_range)


### **LOSS FUNCTION**

In [16]:
class loss_func:
    
    def __init__(self, gamma, gae_lam, entropy_ceof, value_coef, clip_epsilon, OPTIMIZER = OPTIMIZER, SCHEDULER = SCHEDULER, R2D2_NETWORK = R2D2_NETWORK):
        
        # network
        
        self.r2d2 = R2D2_NETWORK
        
        # hyper params
        
        self.gamma = gamma
        self.gae_lam = gae_lam
        self.value_coef = value_coef
        self.entropy_coef = entropy_ceof
        self.clip_epsilon = clip_epsilon
        
        # optimizer
        
        self.optimizer = OPTIMIZER
        self.scheduler = SCHEDULER
        
        # buffer
        
        self.buffer = META_BUFFER
        
    def compute_gae(self, rewards, dones, value, last_value):
        
        values = torch.cat([value, last_value], dim = 1)
        
        gae = 0
        advantages = []
        
        for step in reversed(range(rewards.size(1))):
            
            delta = rewards[:, step] + self.gamma * (1 - dones[:, step]) * values[:, step + 1] - values[:, step]
            gae = delta + self.gae_lam * (1 - dones[:, step]) * gae
            
            advantages.insert(0, gae)
        
        advantages = torch.stack(advantages, dim = 1).to(device)
        
        returns = advantages + value
        
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-7)
        
        returns = safe_tensor(returns)
    
        return advantages.detach(), returns.detach()
    
    def policy_loss(self, old_log_probs, log_probs, advantages, entropy):
        
        ratio = torch.exp(log_probs - old_log_probs)
        
        surr1 = ratio * advantages    
        
        surr2 = torch.clamp(ratio, 1 - self.clip_epsilon, 1 + self.clip_epsilon) * advantages
        
        surrogate_loss = -torch.min(surr1, surr2).mean()
        
        policy_loss = surrogate_loss - self.entropy_coef * entropy.mean()
        
        return policy_loss
    
    def value_loss(self, value, returns):
        
        v_loss = F.mse_loss(value, returns)
        
        value_loss = v_loss * self.value_coef
        
        return value_loss
    
    def update(self, batch):
        
        # unpack batch
        
        states = batch['states']
        actions = batch['actions']
        old_log_probs = batch['log_probs']
        rewards = batch['rewards']
        dones = batch['dones']
        next_states = batch['next_states']
        
        h_a = batch['h_a']
        c_a = batch['c_a']
        h_c = batch['h_c']
        c_c = batch['c_c']
        
        # shape correction
        
        last_state = next_states[:, -1:]
        
        h_a, c_a = h_a.squeeze(2), c_a.squeeze(2)
        
        c_c, h_c = c_c.squeeze(2), h_c.squeeze(2)
        
        dones = dones.unsqueeze(2)
        
        # memory allocation
        
        actor_memory = (h_a, c_a)
        critic_memory = (h_c, c_c)
        
        # compute log probs and value
        
        _, log_probs, entropy, _ = self.r2d2.actor_forward(states, actor_memory)
        
        value, _ = self.r2d2.critic_forward(states, critic_memory)
        
        # compute last value
        
        with torch.no_grad():
            
            last_value, _ = self.r2d2.critic_forward(last_state, critic_memory)
                    
        # compute gae
        
        advantages, returns = self.compute_gae(rewards, dones, value, last_value)
        
        # policy loss
        
        policy_loss = self.policy_loss(old_log_probs.detach(), log_probs, advantages, entropy)
        
        # value loss 
        
        value_loss = self.value_loss(value, returns)
        
        # lets total loss and update optimize
        
        Agent_loss = policy_loss + value_loss
        
        self.optimizer.zero_grad()
        Agent_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.r2d2.parameters(), max_norm = 0.5)
        self.optimizer.step()
        self.scheduler.step()
        
        al = Agent_loss.detach().item()
        vl = value_loss.detach().item()
        pl = policy_loss.detach().item()
        
        return al, vl, pl
    

### **SET UP**

In [17]:
# hyper param

value_coef = 0.5
clip_epsilon = 0.2
entropy_coef = 0.01
gamma = 0.99
gae_lam = 0.95

# setup

LOSS_FUNCTION = loss_func(gamma, gae_lam, entropy_coef, value_coef, clip_epsilon)


### **R2D2 LOGIC**

In [18]:
def R2D2_LOGIC(burn_in_batch, NETWORK, train_batch):

    with torch.no_grad():
        
        # Actor burn-in
        
        h_a, c_a = burn_in_batch['h_a'], burn_in_batch['c_a']
        
        h_a, c_a = h_a.squeeze(2), c_a.squeeze(2)
        
        h_a = h_a.permute(1, 0, 2).contiguous()
        c_a = c_a.permute(1, 0, 2).contiguous()
        
        for t in range(burn_in_batch['states'].size(1)): 
            
            _, _, _, (h_a, c_a) = NETWORK.actor_forward(
                burn_in_batch['states'][:, t],
                (h_a, c_a)
            )
        
        # Critic burn-in
        
        h_c, c_c = burn_in_batch['h_c'], burn_in_batch['c_c']
        
        c_c, h_c = c_c.squeeze(2), h_c.squeeze(2)
        
        h_c = h_c.permute(1, 0, 2).contiguous()
        c_c = c_c.permute(1, 0, 2).contiguous()
        
        for t in range(burn_in_batch['states'].size(1)):
            
            _, (h_c, c_c) = NETWORK.critic_forward(
                burn_in_batch['states'][:, t],
                (h_c, c_c)
            )

    # 2. Build the training batch using updated hidden states
    
    batch = {
        
        'states': train_batch['states'],
        'actions': train_batch['actions'],
        'log_probs': train_batch['log_probs'],
        'rewards': train_batch['rewards'],
        'dones': train_batch['dones'],
        'next_states': train_batch['next_states'],
        'h_a': h_a,  # from post-burn-in
        'c_a': c_a,
        'h_c': h_c,
        'c_c': c_c
    }
    
    return batch


### **TRAINING**

In [19]:
def train_loop(meta_iteration, mini_steps, num_tasks, burn_in, unroll_len, batch_size, R2D2_NETWORK = R2D2_NETWORK, LOSS_FUNCTION = LOSS_FUNCTION, META_BUFFER = META_BUFFER, META_RUNNER = META_RUNNER):
    
    for iteration in range(meta_iteration):
        
        META_RUNNER.run(num_tasks)
        
        total_agent_loss, total_policy_loss, total_value_loss = 0.0, 0.0, 0.0
        
        for _ in range(mini_steps):
            
            burn_in_batch, train_batch = META_BUFFER.sample(batch_size, unroll_len, burn_in)
            
            batch = R2D2_LOGIC(burn_in_batch, R2D2_NETWORK, train_batch)
            
            agent_loss, policy_loss, value_loss = LOSS_FUNCTION.update(batch)
            
            total_agent_loss += agent_loss
            total_policy_loss += policy_loss
            total_value_loss += value_loss
            
        avg_agent_loss = total_agent_loss / mini_steps
        avg_policy_loss = total_policy_loss / mini_steps
        avg_value_loss = total_value_loss / mini_steps
        
        writer.add_scalar('Agent loss', avg_agent_loss, iteration)
        writer.add_scalar('Policy loss', avg_policy_loss, iteration)
        writer.add_scalar('Value loss', avg_value_loss, iteration)
        
        writer.flush()
        
        print(f'Epoch: {iteration} | agent loss: {avg_agent_loss:.3f} | policy loss: {avg_policy_loss:.3f} | value loss: {avg_value_loss:.3f}')
            

### **SET UP**

In [20]:
meta_iteration = 10
mini_steps = 64
num_tasks = 10
batch_size = 256            # why 512 cause max ep range is 512
burn_in = 64
unroll_len = 64

train_loop(meta_iteration, mini_steps, num_tasks, burn_in, unroll_len, batch_size)




Epoch: 0 | agent loss: 15.449 | policy loss: 0.644 | value loss: 14.805
Epoch: 1 | agent loss: 0.728 | policy loss: 0.115 | value loss: 0.613
Epoch: 2 | agent loss: 0.467 | policy loss: 0.206 | value loss: 0.261
Epoch: 3 | agent loss: 0.177 | policy loss: 0.032 | value loss: 0.144
Epoch: 4 | agent loss: 0.110 | policy loss: 0.026 | value loss: 0.083
Epoch: 5 | agent loss: 0.101 | policy loss: 0.031 | value loss: 0.070
Epoch: 6 | agent loss: 0.111 | policy loss: 0.015 | value loss: 0.095
Epoch: 7 | agent loss: 0.108 | policy loss: 0.006 | value loss: 0.102
Epoch: 8 | agent loss: 0.150 | policy loss: 0.012 | value loss: 0.138
Epoch: 9 | agent loss: 0.114 | policy loss: 0.007 | value loss: 0.107
