# **P2E**

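Plan2Explore: the agent learns a world model and is rewarded for visiting states where an ensemble of models disagrees, instead of only chasing the task reward, so that it can learn the whole environment.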


In [373]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
 
 
from collections import deque

import numpy as np
import copy

import gymnasium as gym
from gymnasium.wrappers import RescaleAction

import warnings

In [374]:
# Silence every UserWarning raised from here on.
warnings.filterwarnings(action = 'ignore', category = UserWarning)

## **Env Setup**


In [375]:
# Create the Gymnasium environment registered as 'Humanoid-v5'.
env = gym.make('Humanoid-v5')

# Wrap it so the agent's actions live in [-1.0, 1.0].
env = RescaleAction(env, min_action = -1.0, max_action = 1.0)

# Length of the flat observation vector.
state_dim = env.observation_space.shape[0]

# Number of action components.
action_dim = env.action_space.shape[0]

# Upper bound of the first action component, used later to scale actor output.
max_action = env.action_space.high[0]

print(f'State dim: {state_dim} | Action dim: {action_dim} | Max action: {max_action}')

State dim: 348 | Action dim: 17 | Max action: 1.0


## **Device Setup**

In [376]:
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
# Every later cell reads this module-level `device`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {device}')

Device: cuda


In [377]:
def safe_tensor(x):
    """Return ``x`` as a tensor.

    An existing tensor is handed back untouched (no dtype or device change).
    Anything else is converted to a float32 tensor and moved to the
    module-level ``device``.
    """
    if torch.is_tensor(x):
        return x

    converted = torch.tensor(x, dtype=torch.float32)
    return converted.to(device)

## **Imagination_Buffer**

In [378]:
class replay_buffer:
    """Fixed-capacity ring buffer of transitions with uniform random sampling.

    Each stored item is a ``(state, action, reward, next_state, done)`` tuple
    of tensors. Once ``capacity`` items are stored, new items overwrite the
    oldest ones in insertion order.
    """

    def __init__(self, capacity):
        """Create an empty buffer that holds at most ``capacity`` transitions."""
        self.capacity = capacity
        # Index of the next slot to overwrite once the buffer is full.
        self.pos = 0
        # A plain list gives O(1) access at arbitrary indices, which both
        # `add` (overwrite at `pos`) and `sample` (random indices) rely on.
        # A deque is only O(1) at its ends.
        self.buffer = []

    def add(self, state, action, reward, next_state, done):
        """Store one transition, converting non-tensor fields via ``safe_tensor``."""
        state = safe_tensor(state)
        action = safe_tensor(action)
        reward = safe_tensor(reward)
        next_state = safe_tensor(next_state)
        done = safe_tensor(done)

        # Drop a singleton middle axis so stored states are 2-D.
        if state.dim() == 3:
            state = state.squeeze(1)

        experience = (state, action, reward, next_state, done)

        if len(self.buffer) < self.capacity:
            self.buffer.append(experience)
        else:
            # Full: overwrite the oldest slot and advance the write cursor.
            self.buffer[self.pos] = experience
            self.pos = (self.pos + 1) % self.capacity

    def sample(self, batch_size):
        """Draw ``batch_size`` transitions uniformly (with replacement).

        Returns five tensors (states, actions, rewards, next_states, dones),
        each stacked along a new leading axis and moved to ``device``.
        """
        indices = np.random.choice(len(self.buffer), batch_size)
        picked = [self.buffer[ind] for ind in indices]

        states, actions, rewards, next_states, dones = zip(*picked)

        states = torch.stack(states).to(device)
        actions = torch.stack(actions).to(device)
        rewards = torch.stack(rewards).to(device)
        next_states = torch.stack(next_states).to(device)
        dones = torch.stack(dones).to(device)

        return states, actions, rewards, next_states, dones

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)

#### **Setup**

In [379]:

# Maximum number of entries the imagination buffer keeps before overwriting.
capacity = 500_000

# Buffer that receives the transitions produced by the world-model roll-outs.
imagination_buffer = replay_buffer(capacity)

## **Encoder**

In [380]:
class Encoder(nn.Module):
    """MLP that maps an input vector to a ``latent_dim``-sized vector.

    Four hidden Linear layers (widths ``head_1`` .. ``head_4``), each followed
    by SiLU, then a final Linear projection without activation.
    """

    def __init__(self, input_dim, latent_dim, head_1, head_2, head_3, head_4):
        super().__init__()

        widths = [input_dim, head_1, head_2, head_3, head_4]

        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.SiLU())

        # Output projection: no activation after it.
        stack.append(nn.Linear(head_4, latent_dim))

        self.encode = nn.Sequential(*stack)

    def forward(self, x):
        """Encode ``x`` and return the latent vector."""
        latent = self.encode(x)
        return latent

### **Setup**

In [381]:
# Assembly

# Width of the latent vector produced by the encoder.
latent_dim = 256
# Hidden-layer widths. These module-level names are reused further down
# when the actor and critic networks are constructed.
head_1 = 128
head_2 = 256
head_3 = 512
head_4 = 256

# Setup

# Encoder from raw observations (state_dim) to latent vectors (latent_dim).
encode = Encoder(input_dim = state_dim, latent_dim = latent_dim, head_1 = head_1, head_2 = head_2, head_3 = head_3, head_4 = head_4).to(device)

print('-' * 70)

print(encode)

----------------------------------------------------------------------
Encoder(
  (encode): Sequential(
    (0): Linear(in_features=348, out_features=128, bias=True)
    (1): SiLU()
    (2): Linear(in_features=128, out_features=256, bias=True)
    (3): SiLU()
    (4): Linear(in_features=256, out_features=512, bias=True)
    (5): SiLU()
    (6): Linear(in_features=512, out_features=256, bias=True)
    (7): SiLU()
    (8): Linear(in_features=256, out_features=256, bias=True)
  )
)


## **GRU World Model**


In [382]:
class Gru_World_Model(nn.Module):
    """One-step latent dynamics model: (latent, action) -> predicted next latent.

    A 4-layer GRU with hidden size ``head_1`` consumes the concatenated
    latent/action vector as a length-1 sequence; an MLP head maps the GRU
    output back to ``latent_dim``.

    ``state_dim`` is accepted for signature compatibility but is not used.
    """

    def __init__(self, latent_dim, state_dim, action_dim, head_1, head_2, head_3, head_4):
        super().__init__()

        self.gru = nn.GRU(latent_dim + action_dim, hidden_size=head_1, num_layers=4, batch_first=True)

        self.fc = nn.Sequential(
            nn.Linear(head_1, head_2),
            nn.SiLU(),

            nn.Linear(head_2, head_3),
            nn.SiLU(),

            nn.Linear(head_3, head_4),
            nn.SiLU(),

            nn.Linear(head_4, latent_dim),
        )

    def forward(self, state, action, h=None):
        """Predict the next latent state.

        Args:
            state: latent batch; a 3-D tensor has its axis 1 squeezed first.
            action: action batch, concatenated to ``state`` on the last axis.
            h: optional GRU hidden state; zeros are used when ``None``.

        Returns:
            ``(pred_state, h_out)``: the predicted next latent and the GRU's
            final hidden state.
        """
        if state.dim() == 3:
            state = state.squeeze(1)

        x = torch.cat([state, action], dim=-1)

        if h is None:
            # Size the initial hidden state from the GRU itself (not from
            # hard-coded numbers) and create it on the input's device/dtype,
            # so any `head_1` and any device work.
            h = x.new_zeros(self.gru.num_layers, x.size(0), self.gru.hidden_size)

        # The GRU expects (batch, seq, feature); feed a length-1 sequence.
        out, h_out = self.gru(x.unsqueeze(1), h)
        out = out.squeeze(1)

        pred_state = self.fc(out)

        return pred_state, h_out

### **Setup**

In [383]:

# Latent dynamics model; head_1 = 256 is the GRU hidden size and the remaining
# heads are the widths of the MLP that follows the GRU.
world_model = Gru_World_Model(latent_dim, state_dim, action_dim, head_1 = 256, head_2 = 512, head_3 = 512, head_4 = 256).to(device)

print(world_model)

Gru_World_Model(
  (gru): GRU(273, 256, num_layers=4, batch_first=True)
  (fc): Sequential(
    (0): Linear(in_features=256, out_features=512, bias=True)
    (1): SiLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): SiLU()
    (4): Linear(in_features=512, out_features=256, bias=True)
    (5): SiLU()
    (6): Linear(in_features=256, out_features=256, bias=True)
  )
)


## *Ensembly*

In [384]:
def init_weights(m):
    """Re-initialise an ``nn.Linear``: Kaiming-normal weights, zero bias.

    Intended for ``module.apply(init_weights)``; any module that is not an
    ``nn.Linear`` is left untouched.
    """
    if not isinstance(m, nn.Linear):
        return

    nn.init.kaiming_normal_(m.weight)

    bias = m.bias
    if bias is not None:
        nn.init.zeros_(bias)

In [385]:
class Ensembly(nn.Module):
    """Ensemble of independently re-initialised copies of ``base_model``.

    Each member is a deep copy of ``base_model`` whose Linear layers are
    re-initialised through ``init_weights``; afterwards every parameter of
    every member receives Gaussian noise scaled by 0.01.
    """

    def __init__(self, ensemble_size, base_model):
        super().__init__()

        members = []
        for _ in range(ensemble_size):
            members.append(copy.deepcopy(base_model))
        self.models = nn.ModuleList(members)

        # First pass: re-initialise every member.
        for member in self.models:
            member.apply(init_weights)

        # Second pass: perturb every parameter in place.
        for member in self.models:
            for weight in member.parameters():
                weight.data.add_(0.01 * torch.randn_like(weight))

    def forward(self, state, action, h=None):
        """Run every member on the same inputs (and the same ``h``).

        Returns the members' predictions and hidden states, each stacked
        along a new leading ensemble axis.
        """
        outputs = [member(state, action, h) for member in self.models]

        preds = torch.stack([pred for pred, _ in outputs])
        hidden_states = torch.stack([h_out for _, h_out in outputs])

        return preds, hidden_states

### **Setup**


In [386]:

# Number of world-model copies in the ensemble.
ensemble_size = 4

# Ensemble built from deep copies of `world_model` (see Ensembly).
ensemble = Ensembly(ensemble_size, world_model)

print(ensemble)

Ensembly(
  (models): ModuleList(
    (0-3): 4 x Gru_World_Model(
      (gru): GRU(273, 256, num_layers=4, batch_first=True)
      (fc): Sequential(
        (0): Linear(in_features=256, out_features=512, bias=True)
        (1): SiLU()
        (2): Linear(in_features=512, out_features=512, bias=True)
        (3): SiLU()
        (4): Linear(in_features=512, out_features=256, bias=True)
        (5): SiLU()
        (6): Linear(in_features=256, out_features=256, bias=True)
      )
    )
  )
)


### **World Model Loss**

In [387]:
class world_model_loss:
    """One-step latent prediction loss plus the optimisation step for the model.

    Holds the dynamics ``model``, the ``encode`` network used to obtain
    latents, and the optimizer/scheduler that are stepped on every call to
    ``compute_loss``.
    """

    def __init__(self, model, optimizer, scheduler, encode):
        self.encoder = encode
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.model = model

    def compute_loss(self, replay_buffer, batch_size, memory):
        """Sample a batch, take one gradient step, and return the loss value.

        Args:
            replay_buffer: object whose ``sample(batch_size)`` returns
                ``(states, actions, rewards, next_states, dones)``.
            batch_size: number of transitions to sample.
            memory: hidden state forwarded to the model (may be ``None``).

        Returns:
            The smooth-L1 loss as a Python float.
        """
        states, actions, _rewards, next_states, _dones = replay_buffer.sample(batch_size)

        latent = self.encoder(states)

        # The encoded next state is a regression target: keep it out of the
        # autograd graph so no gradient is pushed through the target branch.
        with torch.no_grad():
            target_latent = self.encoder(next_states)

        pred_next_latent, _ = self.model(latent, actions, memory)

        # smooth_l1_loss takes (input, target).
        loss = F.smooth_l1_loss(pred_next_latent, target_latent)

        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)
        self.optimizer.step()
        self.scheduler.step()

        return loss.item()

### **Setup**


In [388]:
# Hyper params

# Learning rate; reused below for the actor and critic optimizers.
lr = 3e-4
# Period (in scheduler steps) handed to the cosine-annealing schedules.
T_max = 3000

# optimizers & scheduler

# Only the parameters of `world_model` are registered with this optimizer.
optimizer = optim.AdamW(world_model.parameters(), lr, weight_decay = 0.001)

scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max)

# Loss setup

# Bundles the world model, its optimizer/scheduler and the encoder.
world_loss_function = world_model_loss(model = world_model, optimizer = optimizer, scheduler = scheduler, encode = encode)


## **SAC Design**


In [389]:
class Feature_extractor(nn.Module):
    """Six-layer MLP trunk shared by the actor and critic networks.

    Layer widths run input -> hidden -> hidden -> hidden_2 -> hidden_2 ->
    hidden -> output, with SiLU after every layer except the last.
    """

    def __init__(self, input_dim, output_dim, hidden_size, hidden_size_2):
        super().__init__()

        dims = [
            input_dim,
            hidden_size,
            hidden_size,
            hidden_size_2,
            hidden_size_2,
            hidden_size,
            output_dim,
        ]

        modules = []
        final_index = len(dims) - 2
        for index, (fan_in, fan_out) in enumerate(zip(dims, dims[1:])):
            modules.append(nn.Linear(fan_in, fan_out))
            # No activation after the output projection.
            if index != final_index:
                modules.append(nn.SiLU())

        self.feature = nn.Sequential(*modules)

    def forward(self, x):
        """Return the extracted features for ``x``."""
        features = self.feature(x)
        return features

### **Actor network**

In [390]:
class Actor_Network(nn.Module):
    """Tanh-squashed Gaussian policy over latent states.

    Pipeline: feature extractor -> LayerNorm -> self-attention -> MLP ->
    (mean, log-std) heads. ``forward`` samples an action with the
    reparameterisation trick and returns it with its log-probability.
    """

    def __init__(self,action_dim, latent_dim, head_1, head_2, head_3, head_4, hidden_size, hidden_size_2, max_action = max_action):
        super().__init__()

        self.feature = Feature_extractor(input_dim = latent_dim, output_dim = head_1, hidden_size = hidden_size, hidden_size_2 = hidden_size_2)

        self.norm = nn.LayerNorm(head_1)

        self.mha = nn.MultiheadAttention(head_1, num_heads = 4, batch_first = True)

        self.actor = nn.Sequential(

            nn.Linear(head_1, head_2),
            nn.SiLU(),

            nn.Linear(head_2, head_3),
            nn.SiLU(),

            nn.LayerNorm(head_3),
            nn.Linear(head_3, head_4),
            nn.SiLU()
        )

        self.mu = nn.Linear(head_4, action_dim)

        self.log_std = nn.Linear(head_4, action_dim)

        # Actions are tanh-squashed to [-1, 1] and then scaled by this bound.
        self.max_action = max_action

    def forward(self, state):
        """Sample an action for ``state``.

        Args:
            state: latent batch, 2-D ``(batch, latent)`` or 3-D with a
                singleton axis 1.

        Returns:
            ``(action, log_prob)``: the squashed, scaled action and the
            log-probability of the squashed sample summed over the action
            axis (``keepdim=True``).
        """
        # state -> feature -> layer norm
        feature = self.feature(state)
        norm = self.norm(feature)

        # nn.MultiheadAttention interprets a 2-D tensor as ONE unbatched
        # sequence, which would make the samples of a batch attend to each
        # other. Give every sample its own length-1 sequence instead, as the
        # critic does.
        if norm.dim() == 2:
            norm = norm.unsqueeze(1)

        attn, _ = self.mha(norm, norm, norm)
        attn = attn.squeeze(1)

        # Attn -> Actor trunk
        x = self.actor(attn)

        # Mean and log-std, both bounded to [-1, 1] by tanh.
        mu = torch.tanh(self.mu(x))
        log_std = torch.tanh(self.log_std(x))
        # Kept from the original; it cannot bind because tanh already limits
        # log_std to [-1, 1].
        log_std = log_std.clamp(min = -5, max = 2)
        std = torch.exp(log_std)

        # Reparameterization trick
        normal = torch.distributions.Normal(mu, std)
        z = normal.rsample()

        # Out-of-place tanh: `z` must keep the pre-squash sample, because both
        # log_prob(z) and the Jacobian correction below are functions of it.
        tanh_z = torch.tanh(z)
        log_prob = normal.log_prob(z)
        action = tanh_z * self.max_action

        # Squashing correction: log(1 - tanh(z)^2) written in the numerically
        # stable form 2 * (log 2 - z - softplus(-2z)).
        squash = 2 * (torch.log(torch.tensor(2.0, device=z.device)) - z - F.softplus(-2 * z))
        log_prob = log_prob - squash
        log_prob = torch.sum(log_prob, dim = -1, keepdim = True)

        return action, log_prob
    

### **Critic Network**

In [391]:
class Critic_Network(nn.Module):
    """Twin Q-value network over (latent state, action) pairs.

    Both Q heads share one feature extractor and one attention layer; each
    head has its own LayerNorm and its own MLP.
    """

    def __init__(self, latent_dim, action_dim, head_1, head_2, head_3, head_4, hidden_size, hidden_size_2):
        super().__init__()

        self.feature = Feature_extractor(latent_dim + action_dim, head_1, hidden_size, hidden_size_2)

        self.norm = nn.LayerNorm(head_1)
        self.norm_1 = nn.LayerNorm(head_1)

        # One attention layer serves both heads.
        self.mha = nn.MultiheadAttention(head_1, num_heads=4, batch_first=True)

        def build_q_head():
            # MLP that maps attended features to a single Q-value.
            return nn.Sequential(
                nn.Linear(head_1, head_2),
                nn.SiLU(),
                nn.Linear(head_2, head_3),
                nn.SiLU(),
                nn.Linear(head_3, head_4),
                nn.SiLU(),
                nn.Linear(head_4, 1),
            )

        self.critic = build_q_head()
        self.critic_2 = build_q_head()

    def forward(self, state, action):
        """Return the two Q-value estimates ``(q_1, q_2)``."""
        joint = torch.cat((state, action), dim=-1)
        shared = self.feature(joint)

        q_values = []
        branches = ((self.norm, self.critic), (self.norm_1, self.critic_2))
        for norm_layer, q_head in branches:
            # Normalise, add a length-1 sequence axis, attend, drop the axis.
            tokens = norm_layer(shared).unsqueeze(1)
            attended, _ = self.mha(tokens, tokens, tokens)
            q_values.append(q_head(attended.squeeze(1)))

        q_1, q_2 = q_values
        return q_1, q_2

### **Setup**


In [392]:

# Widths used inside the shared Feature_extractor trunk.
hidden_size = 256
hidden_size_2 = 512


# Actor setup

# head_1 .. head_4 here are the module-level values defined next to the
# encoder, not the ones passed to the world model.
actor_network = Actor_Network(action_dim, latent_dim, head_1, head_2, head_3, head_4, hidden_size, hidden_size_2).to(device)

print(actor_network)

# Critic setup

critic_network = Critic_Network(latent_dim, action_dim, head_1, head_2, head_3, head_4, hidden_size, hidden_size_2).to(device)

print('-' * 70)

print(critic_network)

# Target critic

# Starts as an exact copy of the online critic.
target_critic = copy.deepcopy(critic_network).to(device)

# Target actor
# Starts as an exact copy of the online actor.
target_actor = copy.deepcopy(actor_network).to(device)

Actor_Network(
  (feature): Feature_extractor(
    (feature): Sequential(
      (0): Linear(in_features=256, out_features=256, bias=True)
      (1): SiLU()
      (2): Linear(in_features=256, out_features=256, bias=True)
      (3): SiLU()
      (4): Linear(in_features=256, out_features=512, bias=True)
      (5): SiLU()
      (6): Linear(in_features=512, out_features=512, bias=True)
      (7): SiLU()
      (8): Linear(in_features=512, out_features=256, bias=True)
      (9): SiLU()
      (10): Linear(in_features=256, out_features=128, bias=True)
    )
  )
  (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (mha): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
  )
  (actor): Sequential(
    (0): Linear(in_features=128, out_features=256, bias=True)
    (1): SiLU()
    (2): Linear(in_features=256, out_features=512, bias=True)
    (3): SiLU()
    (4): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (5):

## **Agent Class**

In [393]:
class Sac_Agent:
    """Soft Actor-Critic style agent with a target actor and a learned temperature.

    The agent owns the online actor/critic, their target copies, the
    optimizers/schedulers passed in, and a learnable ``log_alpha``.
    """

    def __init__(self, action_dim, actor_network, critic_network, target_critic, actor_opt, critic_opt, actor_sch, critic_sch, gamma, tau, world_loss_function, beta, encode = encode, target_actor = target_actor):

        # Networks
        self.actor = actor_network
        self.critic = critic_network
        self.target_critic = target_critic
        self.target_actor = target_actor
        self.encode = encode

        # Optimizers and schedulers
        self.actor_opt = actor_opt
        self.actor_sch = actor_sch

        self.critic_opt = critic_opt
        self.critic_sch = critic_sch

        # Hyper params
        self.beta = beta      # weight of the target-actor regulariser
        self.gamma = gamma    # discount factor
        self.tau = tau        # soft-update mixing coefficient
        self.world_loss_function = world_loss_function

        self.target_entropy = - action_dim * 1.5
        self.log_alpha = torch.tensor(np.log(0.3), requires_grad = True, device = device)
        self.alpha_min = 0.1
        self.alpha_opt = optim.AdamW([self.log_alpha], lr = 1e-4, weight_decay = 0.001)

    def soft_update(self, source, target, tau):
        """Polyak-average ``source`` into ``target``.

        ``target <- tau * source + (1 - tau) * target``. Only ``target`` is
        modified; the online network is left as the optimizer produced it.
        """
        with torch.no_grad():
            for param, target_param in zip(source.parameters(), target.parameters()):
                target_param.copy_(tau * param + (1 - tau) * target_param)

    def select_action(self, state):
        """Encode a raw ``state`` and sample ``(action, log_prob)`` from the actor."""
        latent = self.encode(state)

        action, log_prob = self.actor(latent)

        return action, log_prob

    def compute_alpha(self):
        """Return the detached temperature, clamped to ``[alpha_min, 0.3]``."""
        alpha = self.log_alpha.exp().detach()
        alpha = alpha.clamp(self.alpha_min, 0.3)
        return alpha

    def update(self, replay_buffer, batch_size):
        """Run one critic, actor and temperature update.

        Args:
            replay_buffer: object whose ``sample(batch_size)`` returns
                ``(states, actions, rewards, next_states, dones)``. The
                states are fed to the actor and critic directly, so they
                must already be latent vectors.
            batch_size: number of transitions to sample.

        Returns:
            ``(actor_loss, critic_loss, alpha)`` as Python floats.
        """
        states, actions, rewards, next_states, dones = replay_buffer.sample(batch_size)

        states = states.to(device)
        actions = actions.to(device)
        rewards = rewards.to(device)
        next_states = next_states.to(device)
        dones = dones.to(device)

        # ---- Bootstrapped target ------------------------------------------
        with torch.no_grad():

            next_action, _ = self.actor(next_states)

            target_1, target_2 = self.target_critic(next_states, next_action)
            # Blend of the pessimistic and optimistic twin estimates.
            target = (0.75 * torch.min(target_1, target_2) + 0.25 * torch.max(target_1, target_2))

            target_qvals = rewards + self.gamma * (1 - dones) * target

        # ---- Critic update --------------------------------------------------
        # Evaluate the critic on the actions that were actually stored with
        # the transition; the reward and next state belong to those actions,
        # not to a freshly sampled one.
        critic_1, critic_2 = self.critic(states, actions)
        critic_loss_1 = F.smooth_l1_loss(critic_1, target_qvals)
        critic_loss_2 = F.smooth_l1_loss(critic_2, target_qvals)

        critic_loss = critic_loss_1 + critic_loss_2

        self.critic_opt.zero_grad()
        # retain_graph: `states` may still carry the graph of whatever
        # produced it, and the actor update below backpropagates through
        # that same graph a second time.
        critic_loss.backward(retain_graph = True)
        torch.nn.utils.clip_grad_norm_(self.critic.parameters(), max_norm = 0.5)
        self.critic_opt.step()
        self.critic_sch.step()

        # ---- Actor update ---------------------------------------------------
        n_action, n_log_prob = self.actor(states)
        q_1, q_2 = self.critic(states, n_action)
        q_pi = 0.75 * torch.min(q_1, q_2) + 0.25 * torch.max(q_1, q_2)

        # Minimise alpha * log_pi - Q: raise the Q-value of the policy's
        # actions and, through the alpha term, its entropy.
        actor_loss = (self.compute_alpha() * n_log_prob - q_pi).mean()

        with torch.no_grad():
            _, log_prob = self.target_actor(states)

        # Regulariser towards the target actor.
        # NOTE(review): the two log-probs belong to different sampled
        # actions, so this is a heuristic and not a KL estimate.
        kl_div = (n_log_prob - log_prob).mean()

        actor_loss += self.beta * kl_div

        self.actor_opt.zero_grad()
        actor_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.actor.parameters(), max_norm = 0.5)
        self.actor_opt.step()
        self.actor_sch.step()

        # ---- Temperature update ---------------------------------------------
        alpha_loss = -(self.log_alpha * (self.target_entropy + n_log_prob.detach())).mean()

        self.alpha_opt.zero_grad()
        alpha_loss.backward()
        torch.nn.utils.clip_grad_norm_([self.log_alpha], max_norm = 0.5)
        self.alpha_opt.step()

        # ---- Target networks ------------------------------------------------
        self.soft_update(self.critic, self.target_critic, self.tau)
        self.soft_update(self.actor, self.target_actor, self.tau)

        self.alpha = self.compute_alpha()

        return actor_loss.item(), critic_loss.item(),  self.alpha.item()

### **Setup**

In [394]:
# Hyper params

gamma = 0.997   # discount factor
tau = 0.067     # soft-update mixing coefficient for the target networks

# Opt and Sch

actor_optimizer = optim.AdamW(actor_network.parameters(), lr, weight_decay = 0.001)
# The actor schedule must drive the ACTOR optimizer. Binding it to the
# world-model `optimizer` would leave the actor's learning rate constant and
# step the world model's rate a second time on every agent update.
actor_scheduler = optim.lr_scheduler.CosineAnnealingLR(actor_optimizer, T_max)

critic_optimizer = optim.AdamW(critic_network.parameters(), lr, weight_decay = 0.001)
critic_scheduler = optim.lr_scheduler.CosineAnnealingLR(critic_optimizer, T_max)

# Agent setup

sac_agent = Sac_Agent(action_dim,
                      actor_network,
                      critic_network,
                      target_critic,
                      actor_optimizer,
                      critic_optimizer,
                      actor_scheduler,
                      critic_scheduler,
                      gamma,
                      tau,
                      beta = 0.6,
                      world_loss_function = world_loss_function)

## **Intrinsic reward**

In [395]:
def intrinsic_reward_function(ensemble, encoder, state, action, scale = 4.0):
    """Exploration bonus: disagreement between ensemble members.

    The state is encoded, every ensemble member predicts the next latent,
    and the variance of those predictions across members is summed over the
    latent axis and multiplied by ``scale``.

    Args:
        ensemble: callable returning ``(preds, hidden)`` where ``preds`` is
            stacked with the ensemble axis first.
        encoder: callable mapping ``state`` to its latent.
        state: raw state (batched or a single vector).
        action: action forwarded to the ensemble.
        scale: multiplier applied to the summed variance (default 4.0).

    Returns:
        ``(batch, 1)`` tensor for batched predictions, or a 0-dim tensor when
        the predictions are unbatched.

    Raises:
        ValueError: if the stacked predictions are neither 2-D nor 3-D.
    """
    latent = encoder(state)

    preds, _ = ensemble(latent, action)

    # Inspect the rank BEFORE touching the axes; the ensemble axis is axis 0
    # in both layouts, so no permute is required.
    if preds.dim() == 3:                                         # [ensemble, batch, latent]

        disagreement = torch.var(preds, dim = 0)

        intrinsic_reward = torch.sum(disagreement, dim = -1, keepdim = True)

    elif preds.dim() == 2:                                       # [ensemble, latent]

        disagreement = torch.var(preds, dim = 0)

        intrinsic_reward = torch.sum(disagreement, dim = 0)

    else:

        raise ValueError(f"expected 2-D or 3-D ensemble predictions, got {preds.dim()}-D")

    return intrinsic_reward * scale

### **SAC Buffer**

In [396]:
class SAC_BUFFER:
    """Fixed-capacity ring buffer of real environment transitions.

    Each stored item is a ``(state, action, reward, next_state, done)`` tuple
    of tensors. Once ``capacity`` items are stored, new items overwrite the
    oldest ones in insertion order.
    """

    def __init__(self, capacity):
        """Create an empty buffer that holds at most ``capacity`` transitions."""
        self.capacity = capacity
        # A plain list gives O(1) access at arbitrary indices, which both
        # `add` (overwrite at `pos`) and `sample` (random indices) rely on.
        # A deque is only O(1) at its ends.
        self.buffer = []
        # Index of the next slot to overwrite once the buffer is full.
        self.pos = 0

    def add(self, state, action, reward, next_state, done):
        """Store one transition, converting non-tensor fields via ``safe_tensor``."""
        state = safe_tensor(state)
        action = safe_tensor(action)
        reward = safe_tensor(reward)
        next_state = safe_tensor(next_state)
        done = safe_tensor(done)

        experience = (state, action, reward, next_state, done)

        if len(self.buffer) < self.capacity:
            self.buffer.append(experience)
        else:
            # Full: overwrite the oldest slot and advance the write cursor.
            self.buffer[self.pos] = experience
            self.pos = (self.pos + 1) % self.capacity

    def sample(self, batch_size):
        """Draw ``batch_size`` transitions uniformly (with replacement).

        Returns five tensors (states, actions, rewards, next_states, dones)
        stacked along a new leading axis and moved to ``device``. A 3-D
        ``states`` batch has its axis 1 squeezed.
        """
        indices = np.random.choice(len(self.buffer), batch_size)
        picked = [self.buffer[ind] for ind in indices]

        states, actions, rewards, next_states, dones = zip(*picked)

        states = torch.stack(states).to(device)
        actions = torch.stack(actions).to(device)
        rewards = torch.stack(rewards).to(device)
        next_states = torch.stack(next_states).to(device)
        dones = torch.stack(dones).to(device)

        if states.dim() == 3:
            states = states.squeeze(1)

        return states, actions, rewards, next_states, dones

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)

### **Setup**


In [397]:
# Buffer for transitions collected from the real environment.
sac_buffer = SAC_BUFFER(capacity = 500_000)

## **RollOut Trajectories**

In [398]:
def roll_out(world_model, agent, batch_size, ensemble = ensemble, horizon_length = 7, sac_buffer = sac_buffer, encoder = encode, imagination_buffer = imagination_buffer):
    """Imagine transitions with the world model and store them.

    A batch of states is sampled from ``sac_buffer`` and encoded. The agent
    picks actions, ``world_model`` predicts the next latent, the ensemble
    disagreement is used as the reward, and the whole batch is appended to
    ``imagination_buffer`` as one entry.

    Returns the accumulated per-step maximum of the intrinsic reward as a
    Python float.

    NOTE(review): the ``return`` at the bottom is indented inside the ``for``
    loop, so the function returns after the first iteration whenever
    ``horizon_length >= 1`` and returns ``None`` when it is 0.
    """
    
    # Start from real states drawn from the SAC buffer.
    states, _, _, _, _ = sac_buffer.sample(batch_size)
    
    latent = encoder(states)
    
    # No hidden state yet; the world model creates its own initial one.
    h = None
    
    total_intrinsive_reward = 0.0
    
    for length in range(horizon_length):
        
        with torch.no_grad():
            
            # select_action encodes `states` itself before calling the actor.
            action, _ = agent.select_action(states)
            
            # NOTE(review): `latent` is never refreshed inside the loop.
            next_states, h = world_model(latent, action, h)
            
            intrinsic_reward = intrinsic_reward_function(ensemble, encoder, states, action)   
            # Only the largest reward in the batch is added to the total.
            total_intrinsive_reward += intrinsic_reward.max().item()
            
            # Imagined transitions are never marked terminal.
            dones = torch.zeros_like(intrinsic_reward)
            
            imagination_buffer.add(states, action, intrinsic_reward, next_states, dones)
            
            # NOTE(review): `next_states` is the world model's output, while
            # select_action and intrinsic_reward_function both run `encoder`
            # on `states` - confirm the sizes agree before moving the return
            # below out of the loop.
            states = next_states
            
            return total_intrinsive_reward

## **Mix Batch**


In [None]:
class Mix_Batch:
    """Sampler that blends real and imagined transitions into one batch.

    ``ratio`` is the fraction of each batch taken from the real SAC buffer;
    the rest comes from the imagination buffer. Real states and next states,
    and imagined states, are passed through ``encode``; imagined next states
    are used as stored.
    """

    def __init__(self, ratio, imagination_buffer = imagination_buffer, sac_buffer = sac_buffer, encode = encode):
        self.ratio = ratio
        self.sac_buffer = sac_buffer
        self.imagination_buffer = imagination_buffer
        self.encode = encode

    def sample(self, batch_size):
        """Return ``(states, actions, rewards, next_states, dones)``.

        Real transitions come first along the batch axis, imagined ones after.
        """
        real_count = int(self.ratio * batch_size)
        imagined_count = batch_size - real_count

        # Imagination buffer is sampled first, then the real buffer.
        imagined = self.imagination_buffer.sample(imagined_count)
        real = self.sac_buffer.sample(real_count)

        # Every imagination entry holds a whole batch; keep its last element.
        i_s, i_a, i_r, i_ns, i_d = (field[:, -1, :] for field in imagined)

        s, a, r, ns, d = real

        # Column vectors, to line up with the imagined rewards / dones.
        r = r.view(-1, 1)
        d = d.view(-1, 1)

        ns = self.encode(ns)
        s = self.encode(s)
        i_s = self.encode(i_s)

        merged = []
        for real_part, imagined_part in ((s, i_s), (a, i_a), (r, i_r), (ns, i_ns), (d, i_d)):
            merged.append(torch.cat([real_part, imagined_part], dim=0).to(device))

        states, actions, rewards, next_states, dones = merged

        return states, actions, rewards, next_states, dones

In [400]:
# Fraction of each mixed batch drawn from the real SAC buffer; the remainder
# is drawn from the imagination buffer (see Mix_Batch.sample).
ratio = 0.25

# Setup

mix_batch = Mix_Batch(ratio)

| Tensor          | Shape                |
|-----------------|----------------------|
| Raw State       | [B, state_dim=348]   |
| Encoded State   | [B, latent_dim=256]  |
| GRU Input       | [B, 1, 256+17]       |
| GRU Hidden      | [4, B, 256]          |

## **Training Block**


In [401]:
def shaped_reward(reward , state):
    """Add hand-crafted shaping terms to ``reward`` and scale the sum by 1/10.

    Args:
        reward: environment reward for the step; it is not modified in place.
        state: indexable observation with at least six entries (list, NumPy
            array or tensor on any device).

    Returns:
        ``(reward + shaping) / 10.0``.

    NOTE(review): the meaning attached to each index below (height, angular
    velocities, ...) is the original author's; verify it against the
    observation layout of the environment in use.
    """
    is_falling = (state[2] < 0.8 or          # Height threshold
                 abs(state[3]) > 0.4 or      # Angular velocity X
                 abs(state[4]) > 0.4)        # Angular velocity Y

    # Accumulate into a fresh name so a tensor/array `reward` passed by the
    # caller is never mutated through `+=`.
    total = reward

    if is_falling:

        recovery_bonus = 3.0 * (1 - abs(state[3]))
        survival_bonus = 0.5 * (0.8 - state[2])
        total = total + (recovery_bonus + survival_bonus)

    else:

        total = total + 0.05
        total = total + 1.5 * state[1]
        total = total + 0.3 * state[2] ** 2             # Height bonus, quadratic
        total = total + 1.5 * state[0] * (1 + state[2])   # Forward locomotion

        # Sum of squares over state[3:6]. Written without NumPy so it also
        # works when `state` is a tensor living on an accelerator, which
        # NumPy cannot read.
        angular_energy = sum(value ** 2 for value in state[3:6])
        total = total - 0.1 * angular_energy              # Angular penalty
        total = total - 0.1 * (state[3] ** 2 + state[4] ** 2)

    return total / 10.0

In [402]:
def training_block(env, agent, world_model, world_loss_function,
                   sac_buffer, mix_batch, max_episodes,
                   world_optimizer, world_scheduler, encoder,
                   batch_size, roll_out_break, memory, warm_up):
    """Main loop: collect an episode, then train the world model and the agent.

    Per episode:
      1. Act in ``env`` until the episode terminates or is truncated, storing
         shaped transitions in ``sac_buffer``.
      2. After ``warm_up`` episodes (and once the buffer holds a batch), take
         one world-model gradient step.
      3. Every ``roll_out_break`` episodes, run an imagination roll-out.
      4. Update the agent from ``mix_batch`` once its imagination buffer
         holds at least ``batch_size`` entries.

    ``world_optimizer``, ``world_scheduler`` and ``encoder`` are accepted for
    signature compatibility and are not used here.
    """

    world_loss, actor_loss, critic_loss, alpha = 0, 0, 0, 0

    for episode in range(1, max_episodes + 1):

        state, _ = env.reset()
        ep_reward = 0
        ep_intrinsive = 0
        done = False

        while not done:
            # Prepare state
            state = torch.tensor(state, dtype=torch.float32).unsqueeze(0).to(device)

            # Select and execute action. No gradients are needed while
            # acting, so skip building the autograd graph.
            with torch.no_grad():
                action, _ = agent.select_action(state)
            action = action.detach().cpu().numpy()[0]

            # env.step returns (obs, reward, terminated, truncated, info).
            # The episode ends on either flag, but only a true termination
            # is stored as `done`, so that a time-limit cut-off does not
            # zero out the bootstrapped value.
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated

            # Store transition
            reward += shaped_reward(reward, state.squeeze(0))
            sac_buffer.add(state, action, reward, next_state, float(terminated))
            ep_reward += reward
            state = next_state

        # Skip training until buffer is warm
        if episode <= warm_up or len(sac_buffer) < batch_size:
            print(f"Episode: {episode} | Reward: {ep_reward:.2f} | Warm-up phase...")
            continue

        # Train world model
        world_loss = world_loss_function.compute_loss(sac_buffer, batch_size, memory)

        # Perform imagination rollouts at fixed intervals
        if episode % roll_out_break == 0:

            intrinsive_reward = roll_out(world_model, agent, batch_size)
            ep_intrinsive += intrinsive_reward

            print(f"Episode: {episode} | Intrinsic reward: {ep_intrinsive}")

        # Update SAC with mixed buffer—only if imagination is ready.
        # NOTE(review): until this condition holds the agent is not updated
        # at all and the losses printed below stay at their initial 0.
        if len(mix_batch.imagination_buffer) >= batch_size:
            actor_loss, critic_loss, alpha = agent.update(mix_batch, batch_size)

        # Logging
        print(" - " * 30)
        print(f"Episode: {episode} | Reward: {ep_reward:.2f}  "
              f"Actor Loss: {actor_loss:.4f} | Critic Loss: {critic_loss:.4f} | Alpha: {alpha:.4f}")
        print(f"World Loss: {world_loss:.4f}\n")

In [403]:
# Launch training with the objects built in the cells above.
training_block(
    env,
    sac_agent,
    world_model,
    world_loss_function,
    sac_buffer,
    mix_batch,
    max_episodes = 3000,
    world_optimizer = optimizer,
    world_scheduler = scheduler,
    encoder = encode,
    batch_size = 256,
    roll_out_break = 2,      # run an imagination roll-out every 2nd episode
    memory = None,           # no hidden state is carried into the world-model loss
    warm_up = 10             # episodes collected before any training step
)

Episode: 1 | Reward: 91.21 | Warm-up phase...
Episode: 2 | Reward: 162.63 | Warm-up phase...
Episode: 3 | Reward: 112.28 | Warm-up phase...
Episode: 4 | Reward: 91.38 | Warm-up phase...
Episode: 5 | Reward: 89.37 | Warm-up phase...
Episode: 6 | Reward: 91.11 | Warm-up phase...
Episode: 7 | Reward: 90.45 | Warm-up phase...
Episode: 8 | Reward: 103.72 | Warm-up phase...
Episode: 9 | Reward: 127.68 | Warm-up phase...
Episode: 10 | Reward: 104.48 | Warm-up phase...
Episode: 11 | Reward: 84.89 | Warm-up phase...
Episode: 12 | Reward: 113.31 | Warm-up phase...
Episode: 13 | Reward: 109.25 | Warm-up phase...
Episode: 14 | Intrinsic reward: 0.3629230260848999
 -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  - 
Episode: 14 | Reward: 85.52  Actor Loss: 0.0000 | Critic Loss: 0.0000 | Alpha: 0.0000
World Loss: 0.0501

 -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  - 
Episode: 15 | Reward: 123.72  Actor Loss: 0.0000 | C

KeyboardInterrupt: 