# **P2E - TRPO**

Integrating Trust to learn what's unknown: a first trial to break AI limits.

In [244]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import numpy as np

from collections import deque
import warnings

import gymnasium as gym
from gymnasium.wrappers import RescaleAction

import copy

In [245]:
# Suppress every UserWarning raised from here on, whichever library emits it.
warnings.filterwarnings(action = 'ignore', category = UserWarning)

## **Env Setup**


In [246]:
# Create the Humanoid-v5 environment.
env = gym.make('Humanoid-v5')

# Wrap it so that actions are exchanged with the agent on a [-1.0, 1.0] scale.
env = RescaleAction(env, min_action = -1.0, max_action = 1.0)

# Sizes of the flat observation / action vectors and the upper action bound of
# the wrapped environment (taken from the first action component).
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
max_action = env.action_space.high[0]

print(f'State dim: {state_dim} | Action dim: {action_dim} | Max Action: {max_action}')

State dim: 348 | Action dim: 17 | Max Action: 1.0


### **Device Setup**

In [247]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
# `device` is a module-level name that later cells read directly.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(f'Device using: {device}')

Device using: cuda


## **Encoder**

Converts raw environment states into a latent representation.


In [248]:
class Encoder(nn.Module):
    """Fully connected network that maps a raw state vector to a latent vector.

    The stack is four ``Linear -> SiLU`` stages of widths ``head_1`` to
    ``head_4`` followed by a final ``Linear`` projection to ``latent_dim``
    (no activation on the output).

    Args:
        input_dim: size of the incoming state vector.
        latent_dim: size of the produced latent vector.
        head_1, head_2, head_3, head_4: widths of the hidden layers, in order.
    """

    def __init__(self, input_dim, latent_dim, head_1, head_2, head_3, head_4):
        super().__init__()

        widths = [input_dim, head_1, head_2, head_3, head_4]

        # Hidden stages are created in order, so parameter names and
        # initialisation order are those of a hand-written Sequential.
        stages = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stages.append(nn.Linear(fan_in, fan_out))
            stages.append(nn.SiLU())

        # Output projection; other cells read `encoder[-1].out_features`.
        stages.append(nn.Linear(head_4, latent_dim))

        self.encoder = nn.Sequential(*stages)

    def forward(self, x):
        """Return the latent encoding of ``x``."""
        latent = self.encoder(x)
        return latent

## **GRU Based World Model**


In [249]:
class GRU_World_MODEL(nn.Module):
    """One-step latent dynamics model built around a 4-layer GRU.

    Given a latent state and an action it predicts the next latent state.
    The GRU output is layer-normalised and passed through an MLP head that
    projects back to ``latent_dim``.

    Args:
        latent_dim: size of the latent state (input and prediction).
        action_dim: size of the action vector.
        head_1: GRU hidden size.
        head_2, head_3, head_4: widths of the MLP head's hidden layers.
    """

    def __init__(self, latent_dim, action_dim, head_1, head_2, head_3, head_4):
        super(GRU_World_MODEL, self).__init__()

        self.gru = nn.GRU(latent_dim + action_dim, head_1, num_layers = 4, batch_first = True)

        self.norm = nn.LayerNorm(head_1)

        self.extract = nn.Sequential(

            nn.Linear(head_1, head_2),
            nn.SiLU(),

            nn.Linear(head_2, head_3),
            nn.SiLU(),

            nn.Linear(head_3, head_4),
            nn.SiLU(),

            nn.Linear(head_4, latent_dim)
        )

        self.head_1 = head_1

    def forward(self, latent_state, action, h = None):
        """Predict the next latent state.

        Args:
            latent_state: tensor whose last dimension is ``latent_dim``; it is
                concatenated with ``action`` along the last dimension and the
                first dimension is used as the batch size.
            action: tensor concatenated to ``latent_state`` (last dimension
                ``action_dim``).
            h: optional GRU hidden state; when omitted a zero state is used.

        Returns:
            ``(predicted_states, h_out)`` - the predicted latent state and the
            hidden state returned by the GRU.
        """

        x = torch.cat([latent_state, action], dim = -1)

        if h is None:

            # Build the initial state on the same device / dtype as the input
            # and with the GRU's own layer count, so the module keeps working
            # after being moved to another device and does not depend on a
            # module-level `device` variable.
            h = torch.zeros(
                self.gru.num_layers, x.size(0), self.head_1,
                device = x.device, dtype = x.dtype
            )

        # The GRU is batch_first: add a sequence axis of length one.
        x = x.unsqueeze(1)

        gru_out, h_out = self.gru(x, h)

        gru_out = gru_out.squeeze(1)

        norm = self.norm(gru_out)

        predicted_states = self.extract(norm)

        return predicted_states, h_out

## **Ensemble World Models**

In [250]:
def init_weights(m):
    """Re-initialise a linear layer in place; leave every other module alone.

    For ``nn.Linear`` modules the weight is drawn with Kaiming-normal
    initialisation and the bias, when present, is set to zero. Intended to be
    used through ``module.apply(init_weights)``.
    """

    if not isinstance(m, nn.Linear):
        return

    nn.init.kaiming_normal_(m.weight)

    has_bias = m.bias is not None
    if has_bias:
        nn.init.zeros_(m.bias)

In [251]:
class ENSEMBLE(nn.Module):
    """Ensemble of independently initialised copies of a world model.

    Each member is a deep copy of ``world_model`` whose linear layers are
    re-initialised with ``init_weights`` and whose parameters then all receive
    a small Gaussian perturbation, so that members disagree from the start.

    Args:
        world_model: template module; it is copied, never used directly.
        ensemble_size: number of members to create.
    """

    def __init__(self, world_model, ensemble_size):
        super(ENSEMBLE, self).__init__()

        self.models = nn.ModuleList(

            [copy.deepcopy(world_model) for _ in range(ensemble_size)]
        )

        for model in self.models:

            model.apply(init_weights)

            # Perturb under no_grad instead of writing through `.data`.
            with torch.no_grad():

                for param in model.parameters():

                    param.add_(0.01 * torch.randn_like(param))

    def forward(self, latent_state, action, h = None):
        """Run every member on the same input.

        Args:
            latent_state: latent state tensor handed to each member.
            action: tensor or NumPy array; arrays are converted to float32
                tensors on ``latent_state``'s device and 1-D actions get a
                leading batch axis.
            h: optional hidden state. ``None`` lets each member start from its
                own default state. A 4-D tensor is treated as one hidden state
                per member (the layout this method returns) and is indexed
                along its first axis; any other tensor is shared by all
                members. Previously this argument was accepted but ignored.

        Returns:
            ``(preds, hidden_states)`` - the member predictions and the member
            hidden states, each stacked along a new leading ensemble axis.
        """

        preds , hidden_states = [], []

        if isinstance(action, np.ndarray):

            action = torch.tensor(action, dtype = torch.float32, device = latent_state.device)

        if action.dim() == 1:

            action = action.unsqueeze(0)

        for index, model in enumerate(self.models):

            if h is None:

                # Same call as before for members that take no hidden state.
                pred, h_out = model(latent_state, action)

            else:

                member_h = h[index] if h.dim() == 4 else h

                pred, h_out = model(latent_state, action, member_h)

            preds.append(pred)
            hidden_states.append(h_out)

        return torch.stack(preds, dim = 0) , torch.stack(hidden_states, dim = 0)

### **Setup**

In [252]:
# Assembly: layer widths shared by the encoder, the world model and (in a
# later cell) the policy / value networks.

head_1 = 256
head_2 = 512
head_3 = 512
head_4 = 256

# Size of the latent state produced by the encoder.
latent_dim = 256

# Number of world-model copies in the ensemble.
ensemble_size = 5

# Encoder setup: raw state -> latent state

encoder = Encoder(state_dim, latent_dim, head_1, head_2, head_3, head_4).to(device)

# GRU World model: (latent state, action) -> next latent state

world_model = GRU_World_MODEL(latent_dim, action_dim, head_1, head_2, head_3, head_4).to(device)

# Ensemble: ENSEMBLE deep-copies `world_model`, so its members are separate
# modules with their own parameters.

ensemble = ENSEMBLE(world_model, ensemble_size).to(device)

# Print the three architectures, separated by rules.
print("-" * 70)

print(encoder)

print("-" * 70)

print(world_model)

print('-' * 70)

print(ensemble)

----------------------------------------------------------------------
Encoder(
  (encoder): Sequential(
    (0): Linear(in_features=348, out_features=256, bias=True)
    (1): SiLU()
    (2): Linear(in_features=256, out_features=512, bias=True)
    (3): SiLU()
    (4): Linear(in_features=512, out_features=512, bias=True)
    (5): SiLU()
    (6): Linear(in_features=512, out_features=256, bias=True)
    (7): SiLU()
    (8): Linear(in_features=256, out_features=256, bias=True)
  )
)
----------------------------------------------------------------------
GRU_World_MODEL(
  (gru): GRU(273, 256, num_layers=4, batch_first=True)
  (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  (extract): Sequential(
    (0): Linear(in_features=256, out_features=512, bias=True)
    (1): SiLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): SiLU()
    (4): Linear(in_features=512, out_features=256, bias=True)
    (5): SiLU()
    (6): Linear(in_features=256, out_features=256

## **TRPO NETWORK**

In [253]:
class Feature_Extractor(nn.Module):
    """Shared MLP trunk used in front of the policy and value heads.

    Layout: three ``Linear -> SiLU`` stages of width ``hidden_size``, a
    ``LayerNorm``, then three ``Linear -> SiLU`` stages that widen to
    ``hidden_size_2`` and finally project to ``output_dim`` (the output is
    SiLU-activated as well).

    Args:
        input_dim: size of the input vector.
        output_dim: size of the produced feature vector.
        hidden_size: width of the first three stages.
        hidden_size_2: width of the stages after the normalisation.
    """

    def __init__(self, input_dim, output_dim, hidden_size, hidden_size_2):
        super().__init__()

        before_norm = [
            (input_dim, hidden_size),
            (hidden_size, hidden_size),
            (hidden_size, hidden_size),
        ]
        after_norm = [
            (hidden_size, hidden_size_2),
            (hidden_size_2, hidden_size_2),
            (hidden_size_2, output_dim),
        ]

        layers = []

        for fan_in, fan_out in before_norm:
            layers.extend([nn.Linear(fan_in, fan_out), nn.SiLU()])

        layers.append(nn.LayerNorm(hidden_size))

        for fan_in, fan_out in after_norm:
            layers.extend([nn.Linear(fan_in, fan_out), nn.SiLU()])

        self.feature = nn.Sequential(*layers)

    def forward(self, x):
        """Return the extracted features for ``x``."""
        features = self.feature(x)
        return features

In [254]:
class Policy_Network(nn.Module):
    """Tanh-squashed Gaussian policy over latent states.

    A ``Feature_Extractor`` trunk feeds an MLP whose output drives two linear
    heads, one for the mean and one for the log standard deviation of a
    diagonal Gaussian. Actions are sampled with the reparameterisation trick,
    squashed with ``tanh`` and scaled by ``max_action``.

    Args:
        latent_dim: size of the latent state.
        action_dim: number of action components.
        head_1: output width of the feature trunk.
        head_2, head_3, head_4: widths of the policy MLP.
        hidden_size, hidden_size_2: widths handed to ``Feature_Extractor``.
        max_action: scale applied to the squashed action.
    """

    def __init__(self, latent_dim, action_dim, head_1, head_2, head_3, head_4, hidden_size, hidden_size_2, max_action = max_action):
        super(Policy_Network, self).__init__()

        self.max_action = max_action

        self.feature = Feature_Extractor(latent_dim, head_1, hidden_size, hidden_size_2)

        self.policy = nn.Sequential(

            nn.Linear(head_1, head_2),
            nn.SiLU(),

            nn.LayerNorm(head_2),
            nn.Linear(head_2, head_3),
            nn.SiLU(),

            nn.Linear(head_3, head_4),
            nn.SiLU()
        )

        self.mu = nn.Linear(head_4, action_dim)
        self.log_std = nn.Linear(head_4, action_dim)

    def forward(self, latent_state):
        """Sample an action and return its log-probability.

        Every call draws a fresh sample, so two calls on the same input return
        different actions and log-probabilities.

        Returns:
            ``(action, log_prob, mu, log_std)`` where ``action`` is the
            squashed, scaled sample, ``log_prob`` is its log-density summed
            over the action components (last dimension kept with size 1), and
            ``mu`` / ``log_std`` are the (tanh-bounded) Gaussian parameters.
        """

        x = self.feature(latent_state)

        policy = self.policy(x)

        # Both heads are bounded to [-1.0, 1.0] with tanh.
        mu = torch.tanh(self.mu(policy))
        log_std = torch.tanh(self.log_std(policy))

        # After the tanh above this clamp cannot trigger; it is kept as a
        # safety net in case the bounding of log_std is ever changed.
        log_std = log_std.clamp(min = -10, max = 2)
        std = torch.exp(log_std)

        dist = torch.distributions.Normal(mu, std)
        z = dist.rsample()

        # The squash must NOT be in place: `z` has to remain the pre-tanh
        # sample because both the Gaussian log-density and the Jacobian
        # correction below are functions of it. An in-place `tanh_` here made
        # both terms be evaluated at tanh(z) instead.
        tanh_z = torch.tanh(z)
        log_prob = dist.log_prob(z)
        action = tanh_z * self.max_action

        # Squashing: log(1 - tanh(z)^2) in its numerically stable form
        # 2 * (log 2 - z - softplus(-2z)).

        squash = 2 * (torch.log(torch.tensor(2.0, device=z.device)) - z - F.softplus(-2 * z))
        log_prob = log_prob - squash
        log_prob = torch.sum(log_prob, dim = -1, keepdim = True)

        return action, log_prob, mu, log_std

### **Value Network**

In [255]:
class Value_Network(nn.Module):
    """State-value estimator over latent states.

    A ``Feature_Extractor`` trunk is followed by a ``LayerNorm`` and an MLP
    that ends in a single linear output unit.

    Args:
        latent_dim: size of the latent state.
        head_1: output width of the feature trunk.
        head_2, head_3, head_4: widths of the value MLP.
        hidden_size, hidden_size_2: widths handed to ``Feature_Extractor``.
    """

    def __init__(self, latent_dim, head_1, head_2, head_3, head_4, hidden_size, hidden_size_2):
        super().__init__()

        self.feature = Feature_Extractor(latent_dim, head_1, hidden_size, hidden_size_2)

        self.norm = nn.LayerNorm(head_1)

        critic_layers = [
            nn.Linear(head_1, head_2),
            nn.SiLU(),
            nn.LayerNorm(head_2),
            nn.Linear(head_2, head_3),
            nn.SiLU(),
            nn.Linear(head_3, head_4),
            nn.SiLU(),
            nn.Linear(head_4, 1),
        ]
        self.values = nn.Sequential(*critic_layers)

    def forward(self, latent_state):
        """Return the value estimate (last dimension of size 1)."""
        return self.values(self.norm(self.feature(latent_state)))

### **Setup**

In [256]:
# Assembly: widths of the Feature_Extractor trunk used by both networks.

hidden_size = 256
hidden_size_2 = 512

# Learning rate and cosine-annealing period; the world-model optimizer cell
# reuses both values.
lr = 2.5e-4
T_max = 3000

# Policy network

policy_network = Policy_Network(latent_dim, action_dim, head_1, head_2, head_3, head_4, hidden_size, hidden_size_2).to(device)

print('-' * 70)

print(policy_network)

# Value network

value_network = Value_Network(latent_dim, head_1, head_2, head_3, head_4, hidden_size, hidden_size_2).to(device)

print('-' * 70)

print(value_network)

# Value optimizer and scheduler. Only the value network gets an optimizer
# here; no optimizer is created for the policy network in this cell.

value_optimizer = optim.AdamW(value_network.parameters(), lr, weight_decay = 0.001)
value_scheduler = optim.lr_scheduler.CosineAnnealingLR(value_optimizer, T_max)

----------------------------------------------------------------------


Policy_Network(
  (feature): Feature_Extractor(
    (feature): Sequential(
      (0): Linear(in_features=256, out_features=256, bias=True)
      (1): SiLU()
      (2): Linear(in_features=256, out_features=256, bias=True)
      (3): SiLU()
      (4): Linear(in_features=256, out_features=256, bias=True)
      (5): SiLU()
      (6): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (7): Linear(in_features=256, out_features=512, bias=True)
      (8): SiLU()
      (9): Linear(in_features=512, out_features=512, bias=True)
      (10): SiLU()
      (11): Linear(in_features=512, out_features=256, bias=True)
      (12): SiLU()
    )
  )
  (policy): Sequential(
    (0): Linear(in_features=256, out_features=512, bias=True)
    (1): SiLU()
    (2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (3): Linear(in_features=512, out_features=512, bias=True)
    (4): SiLU()
    (5): Linear(in_features=512, out_features=256, bias=True)
    (6): SiLU()
  )
  (mu): Linear(in_features=25

## **Intrinsic Reward Signal**

In [257]:
def intrinsic_reward_signal(ensemble, latent_states, actions, extrinsic_reward, beta, intrinsic_scale = 10.0):
    """Curiosity reward from ensemble disagreement, mixed with the task reward.

    The ensemble predictions' variance across members is scaled by its own
    standard deviation and ``beta * extrinsic_reward`` is added on top.

    Args:
        ensemble: callable returning ``(predictions, hidden_states)`` for
            ``(latent_states, actions)``; only the predictions are used.
        latent_states, actions: forwarded to ``ensemble`` unchanged.
        extrinsic_reward: environment reward to blend in.
        beta: weight of the extrinsic reward.
        intrinsic_scale: currently unused.

    Returns:
        The combined reward with all size-1 dimensions squeezed away.

    NOTE(review): the sum is taken over the axis that the variance has already
    reduced to size one, so the result keeps one value per latent feature
    rather than a single scalar per sample - confirm this is intended
    (``mix_batch`` averages stored rewards over dim 1 afterwards).
    """

    predictions, _ = ensemble(latent_states, actions)

    if predictions.dim() == 3:
        # [ensemble, batch, latent] -> [batch, ensemble, latent]
        predictions = predictions.permute(1, 0, 2)
        member_axis = 1
    else:
        member_axis = 0

    disagreement = torch.var(predictions, dim = member_axis, keepdim = True)
    bonus = torch.sum(disagreement, dim = member_axis, keepdim = True)

    # Scale by the spread of the bonus itself.
    bonus = bonus / (bonus.std() + 1e-8)

    combined = bonus + beta * extrinsic_reward

    return combined.squeeze()

## *World replay Buffer*

In [258]:
class OG_Replay_buffer:
    """Two stores in one object: a bounded world buffer and a rollout list.

    * ``buffer`` holds ``(latent_state, action, reward, latent_next_state,
      done)`` tuples and overwrites its oldest entries once ``capacity`` is
      reached.
    * ``roller_buffer`` is an unbounded list of per-step dictionaries filled
      by ``add_rollout_data``; nothing in this class ever clears it.

    Args:
        capacity: maximum number of transitions kept in ``buffer``.
    """

    def __init__(self, capacity):

        self.capacity = capacity
        self.pos = 0                              # next slot to overwrite once full
        self.buffer = deque(maxlen = capacity)
        self.roller_buffer = []

    @staticmethod
    def _to_float_tensor(value):
        """Return a detached float32 copy of ``value`` on the global ``device``.

        Tensors are copied with ``detach().to(copy = True)`` rather than being
        passed to ``torch.tensor`` again, which is what PyTorch recommends and
        avoids the copy-construct UserWarning.
        """

        if torch.is_tensor(value):
            return value.detach().to(device = device, dtype = torch.float32, copy = True)

        return torch.tensor(value, dtype = torch.float32, device = device)

    def add(self, latent_state, action, reward, latent_next_state, done):
        """Store one transition in the world buffer (ring overwrite when full)."""

        experience = tuple(
            self._to_float_tensor(item)
            for item in (latent_state, action, reward, latent_next_state, done)
        )

        if len(self.buffer) < self.capacity:

            self.buffer.append(experience)

        else:

            self.buffer[self.pos] = experience
            self.pos = (1 + self.pos) % self.capacity

    def sample(self, batch_size):
        """Draw ``batch_size`` world-buffer transitions (with replacement).

        Returns:
            ``(latent_states, actions, rewards, latent_next_states, dones)``,
            each stacked along a new leading batch axis.

        Raises:
            ValueError: if the buffer holds fewer than ``batch_size`` entries.
        """

        if len(self.buffer) < batch_size:
            raise ValueError("Not enough samples in World Buffer.")

        indices = np.random.choice(len(self.buffer), batch_size)

        columns = zip(*[self.buffer[ind] for ind in indices])

        latent_states, actions, rewards, latent_next_states, dones = (
            torch.stack(column).to(device) for column in columns
        )

        return (latent_states, actions, rewards, latent_next_states, dones)

    def add_rollout_data(self, state, action, reward, next_state, done, value):
        """Append one detached rollout step to ``roller_buffer``."""

        self.roller_buffer.append({

            'states': state.detach(),
            'actions': action.detach(),
            'rewards': reward.detach(),
            'next_states': next_state.detach(),
            'dones': done.detach(),
            'values': value.detach()
        })

    def sample_rollout(self, batch_size):
        """Draw ``batch_size`` rollout steps (with replacement).

        Returns:
            Dict with keys ``states``, ``actions``, ``rewards``,
            ``next_states`` and ``dones``; each value is the concatenation of
            the sampled entries along dim 0 (1-D entries first get a leading
            axis). Stored ``values`` are not returned.

        Raises:
            ValueError: if no rollout data has been stored yet.
        """

        if len(self.roller_buffer) == 0:
            raise ValueError("Not enough samples in Rollout Buffer.")

        indices = np.random.choice(len(self.roller_buffer), batch_size)

        samples = [self.roller_buffer[ind] for ind in indices]

        def safe_cat(key):
            items = []
            for s in samples:
                val = s[key]
                if not torch.is_tensor(val):
                    val = torch.tensor(val, dtype=torch.float32, device=device)
                if val.ndim == 1:
                    val = val.unsqueeze(0)
                items.append(val)
            return torch.cat(items, dim=0)

        return {
            'states': safe_cat('states'),
            'actions': safe_cat('actions'),
            'rewards': safe_cat('rewards'),
            'next_states': safe_cat('next_states'),
            'dones': safe_cat('dones')
        }

    def mix_batch(self, batch_size, ratio):
        """Build one batch from both stores.

        ``int(batch_size * ratio)`` transitions come from the world buffer and
        the remainder from the rollout list. A share of zero is skipped
        instead of crashing.

        World-buffer states, next states and actions have dim 1 squeezed away
        and world-buffer rewards are averaged over dim 1 so that they line up
        with the rollout entries. (This assumes world rewards are stored with
        at least one dimension, as produced by ``intrinsic_reward_signal``.)

        Returns:
            ``(states, actions, rewards, next_states, dones)`` with the world
            part first and the rollout part second along dim 0.

        Raises:
            ValueError: if ``batch_size`` is not positive, or if a requested
                share cannot be served by its store.
        """

        real_size = int(batch_size * ratio)
        batch_now = batch_size - real_size

        parts = []

        if real_size > 0:

            w_s, w_a, w_r, w_ns, w_d = self.sample(real_size)

            parts.append((
                w_s.squeeze(1),
                w_a.squeeze(1),
                w_r.mean(dim = 1, keepdim = True),
                w_ns.squeeze(1),
                w_d,
            ))

        if batch_now > 0:

            trpo_batch = self.sample_rollout(batch_now)

            parts.append(tuple(
                trpo_batch[key]
                for key in ('states', 'actions', 'rewards', 'next_states', 'dones')
            ))

        if not parts:
            raise ValueError("batch_size must be positive.")

        states, actions, rewards, next_states, dones = (
            torch.cat(field, dim = 0).to(device) for field in zip(*parts)
        )

        return states, actions, rewards, next_states, dones

### **Setup**

In [259]:
# Hyper params

# Maximum number of transitions kept in the world buffer.
capacity = 500_000

# Default batch size; later cells use it as a default argument value.
batch_size = 256

# Setup: one object holds both the world buffer and the rollout list.

world_buffer = OG_Replay_buffer(capacity)

## **World_loss_Function**


In [260]:
class WORLD_LOSS_FUNCTION:
    """One-step training routine for the world model.

    Args:
        world_model: module called as ``world_model(states, actions)`` and
            returning ``(predicted_next_states, hidden_state)``.
        world_optimizer: optimizer stepped after every loss computation.
        world_scheduler: learning-rate scheduler stepped right after it.
        encoder: module used to encode raw states; its
            ``encoder[-1].out_features`` defines the latent width.
    """

    def __init__(self, world_model, world_optimizer, world_scheduler, encoder):

        self.model, self.encoder = world_model, encoder
        self.world_optimizer, self.world_scheduler = world_optimizer, world_scheduler

    def compute_loss(self, trajectories):
        """Compute the prediction loss AND apply one optimisation step.

        Args:
            trajectories: mapping with tensor entries ``'states'``,
                ``'actions'`` and ``'next_states'``.

        Returns:
            The mean-squared error between the (latent) next states and the
            model's prediction, as a Python float.
        """

        current, taken, target = (
            trajectories[key].to(device)
            for key in ('states', 'actions', 'next_states')
        )

        # Inputs whose last dimension is not the latent width are treated as
        # raw states and encoded first; anything else is assumed to be latent
        # already.
        latent_width = self.encoder.encoder[-1].out_features
        if current.shape[-1] != latent_width:
            current = self.encoder(current)
            target = self.encoder(target)

        prediction, _ = self.model(current, taken)

        loss = F.mse_loss(target, prediction)

        # Gradient step with norm clipping, then advance the LR schedule.
        self.world_optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm = 1.0)
        self.world_optimizer.step()
        self.world_scheduler.step()

        return loss.item()

### **Setup**


In [261]:

# Optimizer and Scheduler (reuses `lr` and `T_max` from the TRPO network setup cell)

# NOTE(review): the optimizer is built from `world_model.parameters()` only.
# The ensemble members are deep copies of `world_model` and the encoder is a
# separate module, so neither is updated by this optimizer - confirm that
# this is intended.
world_optimizer = optim.AdamW(world_model.parameters(), lr, weight_decay = 0.001)
world_scheduler = optim.lr_scheduler.CosineAnnealingLR(world_optimizer, T_max)

# Setup

world_loss_function = WORLD_LOSS_FUNCTION(world_model, world_optimizer, world_scheduler, encoder)

### **TRPO LOSS FUNCTION**

In [262]:
class TRPO_LOSS_FUNCTION:
    """Numerical building blocks of the TRPO policy update.

    Provides generalised advantage estimation, the surrogate loss, the KL
    divergence between two policy snapshots, a conjugate-gradient solver, the
    Fisher-vector product and the backtracking line search.

    Args:
        policy_network: callable returning ``(action, log_prob, mu, log_std)``.
        value_network, value_optimizer, value_scheduler: stored for callers;
            not used by the methods of this class.
        entropy_coef: weight of the entropy bonus in the surrogate loss.
        max_iter: maximum number of conjugate-gradient iterations.
        tolerance: squared-residual threshold that stops conjugate gradient.
        damping: multiple of the vector added to the Fisher-vector product.
        max_back_tracks: number of step sizes tried by the line search.
        max_kl: trust-region radius.
    """

    def __init__(self, policy_network, value_network, value_optimizer, value_scheduler, entropy_coef, max_iter, tolerance, damping, max_back_tracks, max_kl):

        # Networks

        self.policy_network = policy_network
        self.value_network = value_network

        # Optimizer and scheduler

        self.value_opt = value_optimizer
        self.value_sch = value_scheduler

        # Hyper params

        self.max_kl = max_kl
        self.max_back_tracks = max_back_tracks
        self.entropy_coef = entropy_coef
        self.max_iter = max_iter
        self.tolerance = tolerance
        self.damping = damping


    def compute_gae(self, rewards, last_value, values, dones, gamma, gae_lambda):
        """Generalised advantage estimation over one sequence of steps.

        Args:
            rewards, values, dones: per-step quantities (tensors, arrays or
                lists); each is flattened to one dimension.
            last_value: value estimate used to bootstrap after the last step.
            gamma: discount factor.
            gae_lambda: GAE smoothing factor.

        Returns:
            ``(advantages, returns)`` - 1-D float32 tensors on the global
            ``device`` with one entry per step; ``returns`` is
            ``advantages + values``.
        """

        def as_flat(x):
            return torch.as_tensor(x, dtype = torch.float32, device = device).detach().reshape(-1)

        rewards = as_flat(rewards)
        dones = as_flat(dones)
        values = torch.cat([as_flat(values), as_flat(last_value)])

        advantages = torch.zeros_like(rewards)
        gae = 0.0

        for step in reversed(range(rewards.numel())):

            not_done = 1.0 - dones[step]

            delta_gae = rewards[step] + gamma * not_done * values[step + 1] - values[step]

            # The running advantage must be cut at episode ends as well,
            # otherwise advantages leak from one episode into the previous one.
            gae = delta_gae + gamma * gae_lambda * not_done * gae

            advantages[step] = gae

        returns = advantages + values[:rewards.numel()]

        return advantages, returns

    def safe_tensor(self, x, dtype):
        """Return ``x`` unchanged if it is a tensor, else wrap it on ``device``."""

        return x if torch.is_tensor(x) else torch.tensor(x, dtype = dtype, device = device)

    def surrogate_loss(self, old_log_probs, latent_states, advantages):
        """Negative importance-weighted advantage minus an entropy bonus.

        ``old_log_probs`` is treated as a constant, and both the probability
        ratio and the advantages are flattened before being multiplied so
        that a ``[B, 1]`` ratio and ``[B]`` advantages pair up element by
        element instead of broadcasting to ``[B, B]``.

        NOTE(review): the policy draws a new action on every forward pass, so
        the new log-probabilities belong to different actions than
        ``old_log_probs`` - a policy method that scores given actions would
        be needed for an exact ratio.
        """

        _ , log_probs, mu, log_std = self.policy_network(latent_states)

        if torch.is_tensor(old_log_probs):
            old_log_probs = old_log_probs.detach().reshape(-1)

        ratio = torch.exp(log_probs.reshape(-1) - old_log_probs)

        advantages = self.safe_tensor(advantages, dtype = torch.float32).detach().reshape(-1)

        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        surrogate = -(ratio * advantages).mean()

        std = torch.exp(log_std)

        dist = torch.distributions.Normal(mu, std)

        entropy = dist.entropy().sum(dim = -1).mean()

        total_loss = surrogate - entropy * self.entropy_coef

        return total_loss


    def compute_kl_divergence(self, latent_states, old_mu = None, old_log_std = None):
        """Mean KL(old || current) of the diagonal Gaussian policy.

        Args:
            latent_states: states on which both policies are compared.
            old_mu, old_log_std: parameters of the reference policy. When
                omitted they are taken from the current network, which makes
                the result (numerically) zero - pass a snapshot taken before
                the parameters were changed to measure a real step.
        """

        if old_mu is None or old_log_std is None:

            with torch.no_grad():

                _, _, old_mu, old_log_std = self.policy_network(latent_states)

        old_mu = old_mu.detach()
        old_log_std = old_log_std.detach()

        _, _, new_mu, new_log_std = self.policy_network(latent_states)

        old_std = torch.exp(old_log_std)
        new_std = torch.exp(new_log_std)

        # KL(N_old || N_new) = log(s_new / s_old)
        #                      + (s_old^2 + (m_old - m_new)^2) / (2 s_new^2) - 1/2
        KL_Div = (

            (new_log_std - old_log_std) +

            (((old_std).pow(2) + (old_mu - new_mu).pow(2)) / (2.0 * new_std.pow(2))) - 0.5
        )

        return KL_Div.sum(dim = -1).mean()


    def conjugate_gradient(self, b, Av_func):
        """Approximately solve ``A x = b`` given only the product ``v -> A v``.

        Runs at most ``self.max_iter`` iterations and stops early once the
        squared residual drops below ``self.tolerance``. ``b`` is not modified.
        """

        x = torch.zeros_like(b)                  # start from the zero vector

        p = b.clone()                            # search direction
        r = b.clone()                            # residual b - A x

        rs_old = torch.dot(r, r)                 # squared residual norm

        for _ in range(self.max_iter):

            Ap = Av_func(p)

            alpha = rs_old / (torch.dot(p, Ap) + 1e-8)

            x = x + alpha * p
            r = r - alpha * Ap

            rs_new = torch.dot(r, r)

            if rs_new < self.tolerance:

                break

            # The direction update belongs INSIDE the loop; without it every
            # iteration keeps searching along the initial direction.
            p = r + (rs_new / rs_old) * p
            rs_old = rs_new

        return x


    def fisher_vector_product(self, vector, latent_states):
        """Return ``(H + damping * I) @ vector``.

        ``H`` is the Hessian of the mean KL divergence between the current
        policy and a detached copy of itself, obtained by differentiating
        twice through the policy network.
        """

        vector = vector.detach()

        # Distribution that carries gradients

        _, _, mu, log_std = self.policy_network(latent_states)

        std = torch.exp(log_std)

        dist = torch.distributions.Normal(mu, std)

        # Detached reference distribution

        with torch.no_grad():

            _, _, old_mu, old_log_std = self.policy_network(latent_states)

        old_std = torch.exp(old_log_std)

        dist_2 = torch.distributions.Normal(old_mu, old_std)

        # KL divergence

        kl = torch.distributions.kl_divergence(dist, dist_2).mean()

        # First order

        kl_grad = torch.autograd.grad(kl, self.policy_network.parameters(), create_graph = True)
        flat_kl_grad = torch.cat([g.reshape(-1) for g in kl_grad])

        kl_v = (flat_kl_grad * vector).sum()

        # Second order

        kl_grad_2 = torch.autograd.grad(kl_v, self.policy_network.parameters(), retain_graph = True)
        flat_kl_grad_2 = torch.cat([g.contiguous().view(-1) for g in kl_grad_2])

        return flat_kl_grad_2 + self.damping * vector


    def get_kl_loss(self, latent_states, old_log_probs, advantages, old_mu = None, old_log_std = None):
        """Return ``(kl_divergence, surrogate_loss)`` for the current parameters.

        ``old_mu`` / ``old_log_std`` are forwarded to ``compute_kl_divergence``.
        """

        loss = self.surrogate_loss(old_log_probs, latent_states, advantages)

        kl_div = self.compute_kl_divergence(latent_states, old_mu, old_log_std)

        return kl_div, loss

    def line_search(self, full_step, latent_states, old_log_probs, advantages):
        """Backtracking line search along ``full_step``.

        Tries ``params + 0.5**i * full_step`` for ``i = 0 .. max_back_tracks-1``
        and accepts the first candidate whose KL divergence to the policy
        before the step is below ``max_kl`` and whose surrogate loss is lower
        than the loss before the step. ``full_step`` is therefore expected to
        be a descent direction for the surrogate loss.

        Returns:
            ``(True, new_params)`` with the accepted parameters left loaded in
            the network, or ``(False, prev_params)`` with the original
            parameters restored.
        """

        parameters = list(self.policy_network.parameters())

        prev_params = torch.nn.utils.parameters_to_vector(parameters).detach().clone()
        full_step = full_step.detach()

        # Snapshot the policy and its loss BEFORE any parameter is changed;
        # measuring the KL against the already-updated network always gives
        # zero and makes the trust region ineffective.
        with torch.no_grad():

            _, _, old_mu, old_log_std = self.policy_network(latent_states)

            baseline_loss = self.surrogate_loss(old_log_probs, latent_states, advantages)

        for step in (0.5 ** i for i in range(self.max_back_tracks)):

            new_params = prev_params + step * full_step

            torch.nn.utils.vector_to_parameters(new_params, parameters)

            with torch.no_grad():

                kl, loss = self.get_kl_loss(latent_states, old_log_probs, advantages, old_mu, old_log_std)

            if kl < self.max_kl and loss < baseline_loss:

                return True, new_params

        torch.nn.utils.vector_to_parameters(prev_params, parameters)

        return False, prev_params

### **Setup**

In [263]:
# Hyper params

gamma = 0.997            # discount factor (handed to TRPO_AGENT)
gae_lam = 0.99           # GAE lambda (handed to TRPO_AGENT)
entropy_coef = 0.01      # weight of the entropy bonus in the surrogate loss
max_iter = 10            # maximum conjugate-gradient iterations
max_back_tracks = 20     # number of step sizes tried by the line search
tolerance = 1e-9         # squared-residual threshold that stops conjugate gradient
damping = 1e-2           # damping added to the Fisher-vector product
max_kl = 0.01            # KL trust-region radius

# Setup

trpo_loss_function = TRPO_LOSS_FUNCTION(policy_network, value_network, value_optimizer, value_scheduler, entropy_coef, max_iter, tolerance, damping, max_back_tracks, max_kl)


## **TRPO Agent Class**

In [264]:
class TRPO_AGENT:
    """Runs one TRPO policy step plus several value-network regression steps.

    Args:
        policy_network: policy returning ``(action, log_prob, mu, log_std)``.
        value_network: state-value estimator.
        value_optimizer, value_scheduler: optimizer / LR scheduler of the
            value network.
        gamma: discount factor.
        gae_lambda: GAE smoothing factor.
        trpo_loss_function: helper providing ``compute_gae``,
            ``surrogate_loss``, ``fisher_vector_product``,
            ``conjugate_gradient``, ``line_search`` and ``max_kl``.
    """

    def __init__(self, policy_network, value_network, value_optimizer, value_scheduler, gamma, gae_lambda, trpo_loss_function = trpo_loss_function):

        self.policy = policy_network
        self.value = value_network

        self.value_opt = value_optimizer
        self.value_sch = value_scheduler

        self.gamma = gamma
        self.gae_lam = gae_lambda
        self.succeeded = 0                       # number of accepted policy steps

        self.loss_function = trpo_loss_function

    def update(self, replay_buffer, batch_size):
        """Perform one policy update and ten value updates on a mixed batch.

        Args:
            replay_buffer: object exposing ``mix_batch(batch_size, ratio)``.
            batch_size: total number of transitions to draw.

        Returns:
            ``(value_loss, succeeded)`` - the last value-regression loss as a
            float and the running count of accepted policy steps.
        """

        states, actions, rewards, next_states, dones = replay_buffer.mix_batch(batch_size, ratio = 0.25)

        # Value estimates for the batch and the bootstrap value of the last
        # next-state.
        with torch.no_grad():

            value = self.value(states).squeeze(-1)

            last_state = next_states[-1].unsqueeze(0).to(device)
            last_value = self.value(last_state).item()

        advantages , returns = self.loss_function.compute_gae(rewards, last_value, value, dones, self.gamma, self.gae_lam)

        # Old log probs are constants of the update: compute them without a
        # graph so no gradient flows through the "old" policy.
        with torch.no_grad():

            _, old_log_probs, _, _ = self.policy(states)

        # Surrogate loss and its gradient

        surrogate_loss = self.loss_function.surrogate_loss(old_log_probs, states, advantages)

        grads = torch.autograd.grad(surrogate_loss, self.policy.parameters())
        flat_grads = torch.cat([g.reshape(-1) for g in grads]).detach()

        # Natural-gradient direction. The surrogate is a LOSS, so the step has
        # to point along the negative gradient: solve F x = -g. Solving for
        # +g moved the policy towards a higher loss.

        Av_function = lambda v: self.loss_function.fisher_vector_product(v, states)
        step_direction = self.loss_function.conjugate_gradient(-flat_grads, Av_function)

        # Scale the step to the trust-region boundary. A non-positive or
        # non-finite curvature would make the square root meaningless, so the
        # policy step is skipped in that case.

        curvature = step_direction.dot(Av_function(step_direction))

        success = False

        if torch.isfinite(curvature) and curvature > 0:

            step_scale = (2 * self.loss_function.max_kl / (curvature + 1e-8 )) ** 0.5
            full_step = step_direction * step_scale

            # Line Search

            success, new_params = self.loss_function.line_search(full_step, states, old_log_probs, advantages)

        print(f'Success: {success}')

        if success:

            torch.nn.utils.vector_to_parameters(new_params, self.policy.parameters())
            self.succeeded += 1

        # Update Value Network. The regression targets are the standardised
        # returns; they are standardised once, outside the loop.

        returns = (returns - returns.mean()) / (returns.std() + 1e-8)

        for _ in range(10):

            v = self.value(states).squeeze(-1)

            loss = F.smooth_l1_loss(v, returns)

            self.value_opt.zero_grad()
            loss.backward()
            self.value_opt.step()
            self.value_sch.step()

        return loss.item(), self.succeeded

### **Setup**

In [265]:
# Build the agent; `trpo_loss_function` is picked up through the constructor's default argument.
trpo_agent = TRPO_AGENT(policy_network, value_network, value_optimizer, value_scheduler, gamma, gae_lam)

## **TRPO ROLLOUT**

In [266]:
def trpo_rollout(policy_network, env, max_steps, beta, world_buffer = world_buffer, encoder = encoder, ensemble = ensemble, trpo_loss_function = trpo_loss_function):
    """Interact with ``env`` for ``max_steps`` steps and fill the world buffer.

    Each step encodes the observation, samples an action from the policy,
    steps the environment, computes the intrinsic reward from the ensemble
    and stores ``(latent_state, action, intrinsic_reward, latent_next_state,
    terminated)`` in ``world_buffer``. The environment is reset whenever an
    episode terminates or is truncated.

    Args:
        policy_network: callable returning ``(action, log_prob, mu, log_std)``.
        env: Gymnasium environment.
        max_steps: number of environment steps to take.
        beta: weight of the extrinsic reward inside the intrinsic signal.
        world_buffer: buffer exposing ``add(...)``.
        encoder: observation encoder; ``encoder.encoder[-1].out_features`` is
            the latent width.
        ensemble: world-model ensemble used for the intrinsic reward.
        trpo_loss_function: unused; kept for interface compatibility.

    Returns:
        The intrinsic reward of the last step taken, or ``None`` when
        ``max_steps`` is zero.
    """

    latent_width = encoder.encoder[-1].out_features

    def to_latent(observation):
        # Batch the observation and encode it unless its width already equals
        # the latent width.
        obs = observation if torch.is_tensor(observation) else torch.tensor(observation, dtype = torch.float32, device = device)

        if obs.dim() == 1:
            obs = obs.unsqueeze(0)

        if obs.shape[-1] != latent_width:
            obs = encoder(obs)

        return obs

    intrinsic_reward = None

    observation, _ = env.reset()

    # Nothing collected here is differentiated later (the buffer stores
    # detached copies), so no autograd graph is needed.
    with torch.no_grad():

        for _ in range(max_steps):

            state = to_latent(observation)

            action, _, _, _ = policy_network(state)

            env_action = action.squeeze(0).detach().cpu().numpy()

            next_observation, reward, terminated, truncated, _ = env.step(env_action)

            next_state = to_latent(next_observation)

            action = torch.tensor(env_action, dtype = torch.float32, device = device).unsqueeze(0)

            intrinsic_reward = intrinsic_reward_signal(ensemble, state, action, reward, beta)

            # Only true termination is stored as `done`; truncation merely
            # ends the episode for the purpose of resetting.
            dones = torch.tensor([terminated], dtype=torch.float32, device=device)

            world_buffer.add(state, action, intrinsic_reward, next_state, dones)

            if terminated or truncated:

                observation, _ = env.reset()

            else:

                # Carry the RAW observation forward; the latent tensor must
                # not be fed back through the tensor conversion above.
                observation = next_observation

    # Returning inside the loop ended the rollout after a single step.
    return intrinsic_reward

In [267]:
def shaped_reward(reward , state):
    """Add hand-crafted shaping terms to ``reward`` and scale the sum by 1/10.

    ``state`` is squeezed along dim 0 and its leading entries are read by
    position. Following the original author's notes, index 2 is treated as a
    height and indices 3 and 4 as angular velocities about X and Y
    (NOTE(review): the training loop shown in this notebook passes the
    encoded latent state, so confirm that these positions carry that meaning).

    When index 2 is below 0.8 or either of indices 3 / 4 exceeds 0.4 in
    magnitude, a recovery and a survival bonus are added. Otherwise a small
    constant, bonuses on indices 0 to 2 and penalties on indices 3 to 5 are
    applied.
    """

    features = state.squeeze(0)

    height = features[2]
    ang_x = features[3]
    ang_y = features[4]

    upright = not (height < 0.8 or abs(ang_x) > 0.4 or abs(ang_y) > 0.4)

    if upright:

        reward += 0.05
        reward += 1.5 * features[1]
        reward += 0.3 * height ** 2                      # quadratic height bonus
        reward += 1.5 * features[0] * (1 + height)       # forward locomotion
        reward -= 0.1 * (features[3:6] ** 2).sum()       # angular penalty
        reward -= 0.1 * (ang_x ** 2 + ang_y ** 2)

    else:

        recovery_bonus = 3.0 * (1 - abs(ang_x))
        survival_bonus = 0.5 * (0.8 - height)
        reward += recovery_bonus + survival_bonus

    return reward / 10.0

## **Training**

In [268]:
def get_beta(episode, warmup=600, decay_start= 900, min_beta=0.01):
    """Piecewise schedule for the mixing coefficient ``beta``.

    Returns 1.0 while ``episode < warmup``, 0.5 while
    ``episode < decay_start``, and afterwards
    ``0.1 * 0.99 ** (episode - decay_start)`` floored at ``min_beta``.
    """

    if episode < warmup:
        return 1.0

    if episode < decay_start:
        return 0.5

    decayed = 0.1 * (0.99 ** (episode - decay_start))

    return max(decayed, min_beta)


In [269]:
def training_block(
    max_episodes,
    update_every,
    roll_steps,
    steps,
    warm_up,
    batch_size=batch_size,
    env=env,
    trpo_agent=trpo_agent,
    world_buffer=world_buffer,
    encoder=encoder,
    world_loss_function=world_loss_function,
    value_network = value_network,
    policy_network = policy_network
):
    """Alternate real-environment rollouts with world-model and TRPO updates.

    Every episode steps ``env`` for at most ``steps`` transitions and stores
    each transition through ``world_buffer.add_rollout_data``. Once
    ``episode > warm_up`` and ``world_buffer.roller_buffer`` holds at least
    ``batch_size`` items, the episode additionally:

    1. calls ``world_loss_function.compute_loss`` on 10 sampled rollout batches,
    2. on episodes where ``episode % update_every == 0``, calls
       ``trpo_rollout`` for ``roll_steps`` steps with a fixed ``beta`` of 0.1,
    3. calls ``trpo_agent.update`` if ``world_buffer.buffer`` holds at least
       ``batch_size`` items.

    Also relies on the notebook-level names ``device``, ``shaped_reward`` and
    ``trpo_rollout``.

    Args:
        max_episodes: Number of episodes to run.
        update_every: Period (in episodes) of the ``trpo_rollout`` call.
        roll_steps: Step count forwarded to ``trpo_rollout``.
        steps: Maximum environment steps per episode.
        warm_up: Updates are skipped while ``episode <= warm_up``.
        batch_size: Sample size and minimum buffer fill for updates.
        env: Environment exposing ``reset()`` and a 5-tuple ``step()``.
        trpo_agent: Object providing ``policy``, ``loss_function.safe_tensor``
            and ``update(world_buffer, batch_size) -> (loss, success)``.
        world_buffer: Buffer providing ``add_rollout_data``, ``sample_rollout``,
            ``roller_buffer`` and ``buffer``.
        encoder: Module whose ``encoder[-1].out_features`` is the latent width;
            inputs of any other width are passed through it.
        world_loss_function: Object providing ``compute_loss(trajectories)``.
        value_network: Callable mapping a state to a value estimate.
        policy_network: Callable returning a 4-tuple whose first item is the action.

    Returns:
        None. Progress is reported through ``print``.
    """
    episode_reward = []
    max_reward = -np.inf

    for episode in range(max_episodes):
        state, _ = env.reset()
        ep_reward = 0.0
        ep_intrinsic = 0.0

        # Reset the logged values every episode. Without this, an episode that
        # enters the update phase but skips the TRPO update would log the
        # previous episode's loss/success (or hit an unbound name).
        world_loss = torch.tensor(0.0)
        loss, success = 0.0, 0

        # ----------- STEP PHASE (interact with env) ------------
        for _ in range(steps):
            state = torch.tensor(state, dtype=torch.float32, device=device)

            if state.dim() == 1:
                state = state.unsqueeze(0)

            # Only encode inputs that are not already latent-sized.
            if state.shape[-1] != encoder.encoder[-1].out_features:
                state = encoder(state)

            action, _, _, _ = policy_network(state)

            action = action.squeeze(0).detach().cpu().numpy()

            # Gymnasium reports episode end through two flags: `terminated`
            # (the MDP reached a terminal state) and `truncated` (e.g. a time
            # limit). Both end the episode; the env must be reset after either.
            next_state, reward, terminated, truncated, _ = env.step(action)

            # NOTE(review): `state` may be the encoder output at this point,
            # while shaped_reward indexes individual entries of its input -
            # confirm the indices are meaningful for what is passed here.
            reward = shaped_reward(reward, state)

            with torch.no_grad():
                v = value_network(state).squeeze(-1)

            next_state = trpo_agent.loss_function.safe_tensor(next_state, dtype=torch.float32).unsqueeze(0)
            if next_state.shape[-1] != encoder.encoder[-1].out_features:
                next_state = encoder(next_state)

            # Store in roller buffer. The stored done flag is `terminated`
            # only, the same value the buffer received before truncation was
            # handled.
            action = torch.tensor(action, dtype = torch.float32, device = device)
            reward = torch.tensor([reward], dtype = torch.float32, device = device)
            done = torch.tensor([terminated], dtype = torch.float32, device = device)

            world_buffer.add_rollout_data(state, action, reward, next_state, done, v)

            ep_reward += reward.item()

            # Tracks the highest running return seen so far (checked per step).
            if ep_reward > max_reward:
                max_reward = ep_reward

            if terminated or truncated:
                break

            state = next_state

        # ----------- UPDATE PHASE (only after warmup) ------------

        if episode > warm_up and len(world_buffer.roller_buffer) >= batch_size:
            episode_reward.append(ep_reward)

            # 1. Train World Model on trajectory rollouts
            for _ in range(10):
                trajectories = world_buffer.sample_rollout(batch_size)
                world_loss = world_loss_function.compute_loss(trajectories)

            # 2. Rollout new imagined data using TRPO policy
            if episode % update_every == 0:
                # The get_beta(episode) schedule is currently not used here;
                # beta is fixed at 0.1.
                intrinsic_reward = trpo_rollout(trpo_agent.policy, env, roll_steps, beta = 0.1)
                ep_intrinsic += intrinsic_reward.sum().item()

                print(f"Episode: {episode} | Intrinsic reward: {ep_intrinsic:.4f}")

            # 3. Update TRPO Agent with mixed batch
            if len(world_buffer.buffer) >= batch_size:
                loss, success = trpo_agent.update(world_buffer, batch_size)
            else:
                print(f"Skipping TRPO update at episode {episode} | Not enough samples in World Buffer...")

        else:
            # No updates this episode; the per-episode defaults above are logged.
            print(f"Episode: {episode} | Reward: {ep_reward:.2f} | Max reward: {max_reward} | Warm-up phase...")

        # ----------- LOGGING PHASE ------------
        print(" - " * 30)
        print(f"Episode: {episode} | Reward: {ep_reward:.2f} | Max reward: {max_reward} "
              f"| Loss: {loss:.4f} | Success: {success}")
        print(f"World Loss: {world_loss:.4f}\n")
            

In [270]:
# Launch training: 3000 episodes of at most 1024 env steps each. Updates start
# once episode > 30; trpo_rollout runs for 5 steps on every 2nd episode.
training_block(
    max_episodes=3000,
    update_every=2,
    roll_steps=5,
    steps=1024,
    warm_up=30,
)

Episode: 0 | Reward: 21.80 | Max reward: 21.800170183181763 | Warm-up phase...
 -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  - 
Episode: 0 | Reward: 21.80 | Max reward: 21.800170183181763 | Loss: 0.0000 | Success: 0
World Loss: 0.0000

Episode: 1 | Reward: 20.98 | Max reward: 21.800170183181763 | Warm-up phase...
 -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  - 
Episode: 1 | Reward: 20.98 | Max reward: 21.800170183181763 | Loss: 0.0000 | Success: 0
World Loss: 0.0000

Episode: 2 | Reward: 23.01 | Max reward: 23.012564092874527 | Warm-up phase...
 -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  - 
Episode: 2 | Reward: 23.01 | Max reward: 23.012564092874527 | Loss: 0.0000 | Success: 0
World Loss: 0.0000

Episode: 3 | Reward: 16.60 | Max reward: 23.012564092874527 | Warm-up phase...
 -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  - 