## **Plan 2 Explore**

A model that learns not just *what* to chase but *why* to chase it: an unsupervised RL agent that uses an intrinsic reward signal to gain mastery of its environment.

In [91]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from collections import deque
import warnings
import gymnasium as gym
from gymnasium.wrappers import RescaleAction
import numpy as np
import copy

In [92]:
warnings.filterwarnings(action = 'ignore', category = UserWarning)

#### **Device setup**

In [93]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print(f'Device: {device}')

Device: cuda


### **Env Setup**

In [94]:
# Humanoid-v5 with actions rescaled so the agent always emits values in [-1, 1].
env = RescaleAction(
    gym.make('Humanoid-v5', max_episode_steps = 1000),
    min_action = -1.0,
    max_action = 1.0,
)

# Flat observation / action sizes, and the action bound used to scale the actor output.
obs_space, act_space = env.observation_space, env.action_space
state_dim = obs_space.shape[0]
action_dim = act_space.shape[0]
max_action = act_space.high[0]

print(f'State dim: {state_dim} | Action dim: {action_dim} | Max action range: {max_action}')

State dim: 348 | Action dim: 17 | Max action range: 1.0


### **Replay Buffer**

The GRU-powered world model (encoder–decoder style) uses this buffer to store the trajectories it generates while interacting with the environment; these trajectories are then used to train the SAC / PPO agent.

##### **Tensor Safe Conversion**

In [95]:
def safe_tensor(x):
    """Pass tensors through untouched; turn anything else into a float32 tensor.

    An existing tensor keeps its own dtype and device. Python scalars, lists and
    numpy arrays are copied into a new float32 tensor (created on PyTorch's default
    device, since no device is given).
    """
    if torch.is_tensor(x):
        return x
    return torch.tensor(x, dtype = torch.float32)

In [96]:
class replay_buffer:
    """Fixed-capacity experience store with uniform sampling (with replacement).

    Every field of a transition is converted with ``safe_tensor`` and moved to
    ``device`` when added, so ``sample`` can stack the fields directly. Once the
    buffer is full, new transitions overwrite the oldest ones in ring order.
    """

    def __init__(self, capacity):

        if capacity <= 0:
            raise ValueError(f'capacity must be positive, got {capacity}')

        self.capacity = capacity
        self.pos = 0          # next slot to overwrite once the buffer is full
        # A list rather than a deque: sample() indexes at random positions, and
        # deque indexing is O(n) away from its ends, while list indexing is O(1).
        self.buffer = []

    def add(self, state, action, reward, next_state, done):
        """Store one transition; all five fields become tensors on ``device``."""

        # Safe conversion
        state = safe_tensor(state).to(device)
        action = safe_tensor(action).to(device)
        reward = safe_tensor(reward).to(device)
        next_state = safe_tensor(next_state).to(device)
        done = safe_tensor(done).to(device)

        experience = (state, action, reward, next_state, done)

        if len(self.buffer) < self.capacity:
            self.buffer.append(experience)
        else:
            # Full: replace the oldest entry and advance the ring pointer.
            self.buffer[self.pos] = experience
            self.pos = (self.pos + 1) % self.capacity

    def sample(self, batch_size):
        """Draw ``batch_size`` transitions uniformly at random, with replacement.

        Returns:
            (states, actions, rewards, next_states, dones), each stacked along a
            new leading batch dimension.

        Raises:
            ValueError: if the buffer is empty.
        """
        if not self.buffer:
            raise ValueError('cannot sample from an empty replay buffer')

        indices = np.random.choice(len(self.buffer), batch_size)

        states, actions, rewards, next_states, dones = zip(*[self.buffer[ind] for ind in indices])

        states = torch.stack(states).to(device)
        actions = torch.stack(actions).to(device)
        rewards = torch.stack(rewards).to(device)
        next_states = torch.stack(next_states).to(device)
        dones = torch.stack(dones).to(device)

        return states, actions, rewards, next_states, dones

    def __len__(self):

        return len(self.buffer)

In [97]:
# Holds transitions imagined by the world-model ensemble during rollouts.
IMAGINATION_CAPACITY = 500_000
imagination_buffer = replay_buffer(IMAGINATION_CAPACITY)

## **GRU Transformer**

#### **MLP Encoder**

In [98]:
class Encoder(nn.Module):
    """MLP that compresses a raw observation into a ``latent_dim`` vector.

    Layout: input -> head_1 -> head_2 -> LayerNorm -> head_3 -> head_4 -> latent,
    with SiLU after each hidden Linear layer and no activation on the output.
    """

    def __init__(self, input_dim, latent_dim, head_1, head_2, head_3, head_4):
        super().__init__()

        # The order below fixes both the state_dict keys (layer.0 ... layer.9) and
        # the order in which parameters are randomly initialised.
        blocks = [
            nn.Linear(input_dim, head_1), nn.SiLU(),
            nn.Linear(head_1, head_2), nn.SiLU(),
            nn.LayerNorm(head_2),
            nn.Linear(head_2, head_3), nn.SiLU(),
            nn.Linear(head_3, head_4), nn.SiLU(),
            nn.Linear(head_4, latent_dim),
        ]
        self.layer = nn.Sequential(*blocks)

    def forward(self, x):
        """Map the last dimension of ``x`` from ``input_dim`` to ``latent_dim``."""
        return self.layer(x)

### **GRU Powered Predictor**

In [99]:
class World_GRU_Model(nn.Module):
    """Recurrent latent dynamics model: predicts the next latent state from (latent state, action).

    A 2-layer GRU reads the concatenated [latent, action] features and an MLP head
    projects every GRU output back to ``latent_dim``.
    """

    def __init__(self, latent_dim, action_dim, h_1, h_2, h_3):
        super(World_GRU_Model, self).__init__()

        self.gru = nn.GRU(latent_dim + action_dim, h_1, num_layers = 2, batch_first = True)

        self.projection = nn.Sequential(

            nn.Linear(h_1, h_2),
            nn.SiLU(),

            nn.Linear(h_2, h_3),
            nn.SiLU(),

            nn.LayerNorm(h_3),
            nn.Linear(h_3, latent_dim)
        )

    def forward(self, state, action, h = None):
        """Predict next latent states.

        Args:
            state: latent states shaped (batch, time, latent_dim), or (batch, latent_dim)
                for independent single-step transitions. A 4-D input has its leading
                axis squeezed (it only disappears if that axis has size 1).
            action: actions with the same leading dimensions as ``state``.
            h: optional GRU hidden state to continue from; None starts from zeros.

        Returns:
            (pred_state, h): predictions with the same leading dims as ``state`` (after
            the optional squeeze), and the GRU's final hidden state.
        """
        if state.dim() == 4:
            state = state.squeeze(0)      # drop a leading singleton (ensemble) axis

        # nn.GRU reads a 2-D input as ONE unbatched sequence whose length is the batch
        # size, which would let every sampled transition see the ones before it.
        # Treat 2-D inputs as a batch of length-1 sequences instead.
        single_step = state.dim() == 2
        if single_step:
            state = state.unsqueeze(1)
        if action.dim() == 2:
            action = action.unsqueeze(1)

        x = torch.cat([state, action], dim = -1)      # [batch, time_step, latent_dim + action_dim]

        out, h = self.gru(x, h)

        pred_state = self.projection(out)

        if single_step:
            pred_state = pred_state.squeeze(1)

        return pred_state, h

#### **Setup**

In [100]:
# Assembly: layer widths for the observation encoder and the GRU world model.

latent_dim = 128                                        # size of the latent state
head_1, head_2, head_3, head_4 = 256, 512, 256, 256     # encoder hidden widths
h_1, h_2, h_3 = 128, 256, 512                           # GRU hidden size + projection widths

separator = " - " * 80

## Encoder: raw observation (state_dim) -> latent (latent_dim)
encoder = Encoder(state_dim, latent_dim, head_1, head_2, head_3, head_4).to(device)
print(separator)
print(encoder)

# GRU World Model: (latent, action) -> next latent
world_model = World_GRU_Model(latent_dim, action_dim, h_1, h_2, h_3).to(device)
print(separator)
print(world_model)

 -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  - 
Encoder(
  (layer): Sequential(
    (0): Linear(in_features=348, out_features=256, bias=True)
    (1): SiLU()
    (2): Linear(in_features=256, out_features=512, bias=True)
    (3): SiLU()
    (4): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (5): Linear(in_features=512, out_features=256, bias=True)
    (6): SiLU()
    (7): Linear(in_features=256, out_features=256, bias=True)
    (8): SiLU()
    (9): Linear(in_features=256, out_features=128, bias=True)
  )
)
 -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  - 
World_GRU_Model(
  (gru): GRU(145, 128, num

### **World Model Loss**

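Here we use the squared residual between the true and predicted next latent states (the residual sum of squares, RSS), averaged over all elements, i.e. the mean squared error computed by `F.mse_loss`:

$$\mathcal{L} = \frac{1}{N d}\sum_{i=1}^{N} \left\lVert z_i^{\text{true}} - z_i^{\text{pred}} \right\rVert^2$$

where $N$ is the batch size and $d$ the latent dimension. Minimising this loss gives the world model its learning signal for the environment dynamics; the model's uncertainty is measured separately through ensemble disagreement (see below).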

In [101]:
class World_Model_Loss:
    """Trains the world model to predict the next latent state, one optimiser step per call.

    Constructor args:
        model: dynamics model called as ``model(states, actions)`` and returning
            ``(pred_next_states, hidden)``.
        optimizer / scheduler: optimiser over ``model``'s parameters and its LR schedule;
            both are stepped once per ``compute_loss`` call.
        Encoder: stored as ``self.encoder`` but not used by ``compute_loss``.
            NOTE(review): presumably because the buffer already yields latent
            states (e.g. ``Mix_Batch``) — confirm.
    """

    def __init__(self, model, optimizer, scheduler, Encoder):

        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.encoder = Encoder

    def compute_loss(self, buffer, batch_size):
        """Sample a batch and take one optimiser step on the next-state MSE.

        Args:
            buffer: object with ``sample(batch_size)`` returning
                (states, actions, rewards, next_states, dones).
            batch_size: number of transitions to sample.

        Returns:
            (loss, memory): the loss tensor (as computed before the step) and the
            model's final hidden state.
        """
        states, actions, _, next_states, _ = buffer.sample(batch_size)

        states = states.to(device)
        actions = actions.to(device)
        # The next latent state is a regression *target*: detach it so this loss does
        # not push gradients into whatever produced it (e.g. the encoder in Mix_Batch).
        next_states = next_states.to(device).detach()

        pred_latent_next_state, memory = self.model(states, actions)

        # F.mse_loss(input, target): prediction first, target second.
        loss = F.mse_loss(pred_latent_next_state, next_states)

        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm = 0.5)
        self.optimizer.step()
        self.scheduler.step()

        return loss, memory

### **Hyper Params and Setup**

In [102]:
## Hyper params for the world-model optimiser

lr = 3e-4        # learning rate
T_max = 1000     # cosine-annealing period in scheduler steps (also reused by the actor/critic schedulers)

# AdamW over the world model's parameters, with a cosine-annealed learning rate.
optimizer = optim.AdamW(world_model.parameters(), lr = lr, weight_decay = 0.001)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max = T_max)

## **Ensemble Dynamics**

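Disagreement between ensemble members is used as the intrinsic reward signal: the more the members disagree about the next latent state, the stronger the signal to explore and learn those dynamics.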

In [103]:
def init_weights(m):
    """Re-initialise ``m`` in place if it is an ``nn.Linear``; any other module is left untouched.

    Weights get Kaiming-normal initialisation with PyTorch's default arguments
    (fan_in mode, leaky_relu gain with a=0) and biases, when present, are zeroed.
    Intended for use with ``nn.Module.apply``.
    """
    if not isinstance(m, nn.Linear):
        return

    nn.init.kaiming_normal_(m.weight)

    if m.bias is not None:
        nn.init.zeros_(m.bias)

In [104]:
class Ensemble_Dynamics(nn.Module):
    """Ensemble of independently re-initialised copies of a dynamics model.

    The spread of the members' predictions is the disagreement signal used by
    ``intrinsive_reward_signal``.

    NOTE(review): no optimiser for the ensemble members is visible in this
    section — confirm they are trained elsewhere; otherwise their disagreement
    stays that of randomly initialised networks.
    """

    def __init__(self, ensemble_size, base_model = world_model):
        super().__init__()

        self.ensemble = nn.ModuleList(
            [copy.deepcopy(base_model) for _ in range(ensemble_size)]
        )

        for model in self.ensemble:

            # init_weights only touches nn.Linear. Without this, every member would
            # keep an *identical* copy of the base model's recurrent (GRU) weights,
            # shrinking the disagreement the ensemble exists to measure.
            for module in model.modules():
                if isinstance(module, nn.RNNBase):
                    module.reset_parameters()

            model.apply(init_weights)

    def forward(self, latent_state, action):
        """Run every member on the same inputs.

        Returns:
            Member predictions stacked along a new dim 0:
            [ensemble_size, *member_output_shape].
        """
        preds = [model(latent_state, action)[0] for model in self.ensemble]

        return torch.stack(preds)

### **Setup**

In [105]:
# Number of independently initialised dynamics models; their disagreement is the intrinsic reward.
ensemble_size = 4

ensemble = Ensemble_Dynamics(ensemble_size = ensemble_size)
print(ensemble)

# World-model trainer: one optimiser step on `world_model` per compute_loss call.
world_loss = World_Model_Loss(
    model = world_model,
    optimizer = optimizer,
    scheduler = scheduler,
    Encoder = encoder,
)

Ensemble_Dynamics(
  (ensemble): ModuleList(
    (0-3): 4 x World_GRU_Model(
      (gru): GRU(145, 128, num_layers=2, batch_first=True)
      (projection): Sequential(
        (0): Linear(in_features=128, out_features=256, bias=True)
        (1): SiLU()
        (2): Linear(in_features=256, out_features=512, bias=True)
        (3): SiLU()
        (4): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (5): Linear(in_features=512, out_features=128, bias=True)
      )
    )
  )
)


## **Intrinsic Reward**

Ensemble disagreement, turned into the agent's learning signal (the intrinsic reward).

In [106]:
def intrinsive_reward_signal(latent_state, scale = 4.0):
    """Ensemble-disagreement intrinsic reward.

    Args:
        latent_state: stacked ensemble predictions with the ensemble on dim 0, e.g.
            the [ensemble_size, ..., latent_dim] output of ``Ensemble_Dynamics``.
        scale: multiplier applied to the disagreement (default 4.0).

    Returns:
        Tensor of shape ``latent_state.shape[1:-1]``: for each prediction, the
        variance across ensemble members averaged over the latent features,
        multiplied by ``scale``.
    """
    # Variance across ensemble members (dim 0). torch.var is unbiased by default,
    # so a single-member ensemble yields NaN.
    disagreement = torch.var(latent_state, dim = 0)

    # Average the per-feature disagreement over the latent dimension (last axis).
    intrinsive_reward = torch.mean(disagreement, dim = -1)

    return intrinsive_reward * scale

# **SAC Agent**

### **SAC BUFFER**

In [107]:
class SAC_BUFFER:
    """Fixed-capacity store of real environment transitions with uniform sampling.

    Fields are converted to tensors on ``device`` when added; scalar rewards and
    done flags become 0-d float tensors. Once full, new transitions overwrite the
    oldest ones in ring order.
    """

    def __init__(self, capacity):

        if capacity <= 0:
            raise ValueError(f'capacity must be positive, got {capacity}')

        self.capacity = capacity
        self.pos = 0          # next slot to overwrite once the buffer is full
        # A list rather than a deque: sample() indexes at random positions, and
        # deque indexing is O(n) away from its ends, while list indexing is O(1).
        self.buffer = []

    def add(self, state, action, reward, next_state, done):
        """Store one transition; all five fields become tensors on ``device``."""

        # Safe conversion
        state = safe_tensor(state).to(device)
        action = safe_tensor(action).to(device)
        reward = safe_tensor(np.array(reward)).to(device)     # Python / numpy scalar -> 0-d float32
        next_state = safe_tensor(next_state).to(device)
        done = safe_tensor(float(done)).to(device)            # bool / int flag -> 0-d float32

        experience = (state, action, reward, next_state, done)

        if len(self.buffer) < self.capacity:
            self.buffer.append(experience)
        else:
            # Full: replace the oldest entry and advance the ring pointer.
            self.buffer[self.pos] = experience
            self.pos = (self.pos + 1) % self.capacity

    def sample(self, batch_size):
        """Draw ``batch_size`` transitions uniformly at random, with replacement.

        Returns:
            (states, actions, rewards, next_states, dones), each stacked along a
            new leading batch dimension.

        Raises:
            ValueError: if the buffer is empty.
        """
        if not self.buffer:
            raise ValueError('cannot sample from an empty replay buffer')

        indices = np.random.choice(len(self.buffer), batch_size)

        states, actions, rewards, next_states, dones = zip(*[self.buffer[ind] for ind in indices])

        states = torch.stack(states).to(device)
        actions = torch.stack(actions).to(device)
        rewards = torch.stack(rewards).to(device)
        next_states = torch.stack(next_states).to(device)
        dones = torch.stack(dones).to(device)

        return states, actions, rewards, next_states, dones

    def __len__(self):

        return len(self.buffer)

In [108]:
# Real environment transitions (raw observations; Mix_Batch encodes them when sampling).
REAL_BUFFER_CAPACITY = 500_000
sac_buffer = SAC_BUFFER(REAL_BUFFER_CAPACITY)

### **Agent Network**

In [109]:
class Feature_Extractor(nn.Module):
    """Five-layer MLP trunk: input_dim -> head -> head -> head -> head -> output_dim.

    SiLU follows every hidden Linear layer; the final projection has no activation.
    """

    def __init__(self, input_dim, output_dim, head):
        super().__init__()

        layers = [nn.Linear(input_dim, head), nn.SiLU()]
        for _ in range(3):
            layers += [nn.Linear(head, head), nn.SiLU()]
        layers.append(nn.Linear(head, output_dim))

        self.feature = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``."""
        return self.feature(x)

##### **Actor network**


In [110]:
class Actor_Network(nn.Module):
    """Tanh-squashed Gaussian policy over a latent state (SAC actor).

    Pipeline: Feature_Extractor -> LayerNorm -> self-attention -> MLP ->
    (mu, log_std) heads -> reparameterised Gaussian sample squashed by tanh and
    scaled by ``max_action``.
    """

    def __init__(self, latent_dim, action_dim, head_1, head_2, head_3, head_4, head, max_action = max_action):
        super(Actor_Network, self).__init__()

        self.max_action = max_action

        # Pass to Feature network
        self.feature = Feature_Extractor(latent_dim, head_1, head)

        # Pass to norm
        self.norm = nn.LayerNorm(head_1)

        # Pass norm to MHA
        self.mha = nn.MultiheadAttention(head_1, num_heads = 4, batch_first = True)

        # Pass to actor sequence
        self.actor = nn.Sequential(

            nn.Linear(head_1, head_2),
            nn.SiLU(),

            nn.Linear(head_2, head_3),
            nn.SiLU(),

            nn.Linear(head_3, head_4),
            nn.SiLU()
        )

        self.mu = nn.Linear(head_4, action_dim)

        self.log_std = nn.Linear(head_4, action_dim)

    def forward(self, state):
        """Sample an action with the reparameterisation trick.

        Args:
            state: latent state, expected as (batch, latent_dim). A 3-D input is given to
                the attention layer as-is (assumed (batch, seq, latent_dim) — TODO confirm).

        Returns:
            (action, log_prob): ``action = max_action * tanh(z)`` with ``z ~ N(mu, std)``,
            and the log-density of the squashed sample summed over the last (action)
            dimension with ``keepdim=True``.
        """
        feature = self.feature(state)
        norm = self.norm(feature)

        # batch_first MultiheadAttention wants (batch, seq, embed): treat each state
        # as a length-1 sequence.
        if norm.dim() == 2:
            norm = norm.unsqueeze(1)  # [batch, 1, embed_dim]

        attn, _ = self.mha(norm, norm, norm)
        attn = attn.squeeze(1)        # back to [batch, embed_dim]

        x = self.actor(attn)

        # Bound both heads with tanh. log_std therefore lies in [-1, 1], i.e.
        # std in [e^-1, e], so no further clamp is needed.
        mu = torch.tanh(self.mu(x))
        log_std = torch.tanh(self.log_std(x))
        std = torch.exp(log_std)

        # Reparameterization trick
        normal = torch.distributions.Normal(mu, std)
        z = normal.rsample()

        # Out-of-place tanh: an in-place tanh_ would overwrite z itself, and the
        # Gaussian log-prob and squash correction below must be evaluated at the
        # pre-squash sample z, not at tanh(z).
        tanh_z = torch.tanh(z)
        action = self.max_action * tanh_z

        # Change of variables: log pi(a) = log N(z) - sum log(1 - tanh(z)^2), using the
        # numerically stable identity log(1 - tanh(z)^2) = 2 * (log 2 - z - softplus(-2z)).
        log_prob = normal.log_prob(z)
        squash = 2 * (torch.log(torch.tensor(2.0, device = z.device)) - z - F.softplus(-2 * z))
        log_prob = log_prob - squash
        log_prob = torch.sum(log_prob, dim = -1, keepdim = True)

        return action, log_prob

### **Critic Network**

In [111]:
class Critic_Network(nn.Module):
    """Twin Q-network over (latent state, action).

    Both Q heads read one shared trunk: Feature_Extractor -> LayerNorm -> self-attention.

    Args:
        latent_dim, action_dim: input sizes; the trunk consumes their concatenation.
        head_1 .. head_4: trunk width (head_1) and Q-head hidden widths.
        head: accepted for signature parity with Actor_Network but unused here.

    NOTE(review): because the heads share the trunk, their errors are correlated,
    which weakens the min/max blending of twin estimates; separate trunks would
    change the state_dict layout.
    """

    def __init__(self, latent_dim, action_dim, head_1, head_2, head_3, head_4, head):
        super(Critic_Network, self).__init__()

        # Pass [state, action] -> Feature Extractor
        self.feature = Feature_Extractor(input_dim = latent_dim + action_dim, output_dim = head_1, head = head_1)

        # Norm and MHA
        self.norm = nn.LayerNorm(head_1)

        self.mha = nn.MultiheadAttention(head_1, num_heads = 4, batch_first = True)

        # Critic network 1
        self.critic_1 = nn.Sequential(

            nn.Linear(head_1, head_2),
            nn.SiLU(),

            nn.Linear(head_2, head_3),
            nn.SiLU(),

            nn.Linear(head_3, head_4),
            nn.SiLU(),

            nn.Linear(head_4, 1)
        )

        # Critic network 2
        self.critic_2 = nn.Sequential(

            nn.Linear(head_1, head_2),
            nn.SiLU(),

            nn.Linear(head_2, head_3),
            nn.SiLU(),

            nn.Linear(head_3, head_4),
            nn.SiLU(),

            nn.Linear(head_4, 1)
        )

    def forward(self, state, action):
        """Return both Q estimates for (state, action).

        The shared trunk is evaluated once and fed to both heads. Running it twice on
        the same input, as before, produced identical features at twice the cost.

        Returns:
            (q_1, q_2), each of shape (batch, 1) for 2-D inputs.
        """
        cat = torch.cat([state, action], dim = -1)

        norm = self.norm(self.feature(cat))

        # batch_first MultiheadAttention wants (batch, seq, embed).
        if norm.dim() == 2:
            norm = norm.unsqueeze(1)  # [batch, 1, embed_dim]

        attn, _ = self.mha(norm, norm, norm)
        attn = attn.squeeze(1)

        q_1 = self.critic_1(attn)
        q_2 = self.critic_2(attn)

        return q_1, q_2
        

#### **Setup**

In [112]:
# Assembly for the SAC networks, which act on latent states.

head = 128          # hidden width of the actor's feature extractor
latent_dim = 128    # must match the encoder's latent size

# Actor network
actor_network = Actor_Network(latent_dim, action_dim, head_1, head_2, head_3, head_4, head).to(device)
print(" - " * 70)
print(actor_network)

# Critic network (twin Q heads)
critic_network = Critic_Network(latent_dim, action_dim, head_1, head_2, head_3, head_4, head).to(device)
print(" - " * 70)
print(critic_network)

# Target critic: starts as an exact copy of the critic and is then tracked via soft updates.
target_critic = copy.deepcopy(critic_network)

 -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  - 
Actor_Network(
  (feature): Feature_Extractor(
    (feature): Sequential(
      (0): Linear(in_features=128, out_features=128, bias=True)
      (1): SiLU()
      (2): Linear(in_features=128, out_features=128, bias=True)
      (3): SiLU()
      (4): Linear(in_features=128, out_features=128, bias=True)
      (5): SiLU()
      (6): Linear(in_features=128, out_features=128, bias=True)
      (7): SiLU()
      (8): Linear(in_features=128, out_features=256, bias=True)
    )
  )
  (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  (mha): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
  )
  (actor): Sequential(
    (0): Linear(in_features=256, out_features=512, bias=True)
    (1): SiLU()
    (2): Linear(in_fe

### *Soft Update*

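Softly update the target critic towards the online critic (Polyak averaging): $\theta_{\text{target}} \leftarrow \tau\,\theta + (1-\tau)\,\theta_{\text{target}}$.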

In [113]:
def soft_update(source, target, tau = 0.067):
    """Polyak-average ``source``'s parameters into ``target``, in place.

    For each parameter pair: ``target <- tau * source + (1 - tau) * target``.
    Parameters are paired positionally, so both modules must share the same
    architecture. ``source`` is left unchanged and buffers are not touched.
    """
    with torch.no_grad():
        for param, target_param in zip(source.parameters(), target.parameters()):
            # Write into the *target*: the online (source) network must keep the
            # weights its optimiser just produced.
            target_param.copy_(tau * param + (1 - tau) * target_param)

## **Agent Initialize**

In [114]:
class SAC_AGENT:
    """Soft Actor-Critic agent that also takes one world-model step per update.

    Args:
        actor: policy; ``actor(state)`` returns ``(action, log_prob)``.
        critic / target_critic: twin-Q networks; ``critic(state, action)`` returns ``(q_1, q_2)``.
        actor_optimizer, actor_scheduler, critic_optimizer, critic_scheduler: optimisers and
            LR schedulers, each stepped once per ``update``.
        gamma: discount factor.
        action_dim: sets the entropy target to ``-2.5 * action_dim``.
        world_loss: object whose ``compute_loss(buffer, batch_size)`` trains the world model.

    Note: the defaults for ``action_dim`` and ``world_loss`` are bound to the notebook
    globals when this class is defined, not when it is instantiated.
    """

    def __init__(self, actor, critic, target_critic, actor_optimizer, actor_scheduler, critic_optimizer, critic_scheduler, gamma, action_dim = action_dim, world_loss = world_loss):

        # Network
        self.actor = actor
        self.critic = critic
        self.target_critic = target_critic

        # Actor
        self.actor_optimizer = actor_optimizer
        self.actor_scheduler = actor_scheduler

        # Critic
        self.critic_optimizer = critic_optimizer
        self.critic_scheduler = critic_scheduler

        # Hyper param
        self.gamma = gamma
        self.world_loss = world_loss

        # Automatic entropy tuning: learnable log-temperature, starting at alpha = 0.2.
        self.target_entropy = - action_dim * 2.5
        self.log_alpha = torch.tensor(np.log(0.2), requires_grad = True, device = device)
        self.alpha_optimizer = optim.AdamW([self.log_alpha], lr = 1e-4, weight_decay = 0.001)
        self.alpha_min = 0.1

        # Define self.alpha up front so it exists even before the first update().
        self.compute_alpha()

    def compute_alpha(self):
        """Return (and cache in ``self.alpha``) exp(log_alpha), detached and clamped to [alpha_min, 0.2].

        NOTE(review): only the returned value is clamped; ``log_alpha`` itself can keep
        drifting outside that range.
        """
        self.alpha = self.log_alpha.exp().detach()
        self.alpha = self.alpha.clamp(min = self.alpha_min, max = 0.2)

        return self.alpha

    def update_target(self):
        """Polyak-update the target critic from the online critic."""
        return soft_update(self.critic, self.target_critic)

    def select_action(self, state):
        """Sample ``(action, log_prob)`` from the actor (with gradients unless the caller disables them)."""
        action, log_prob = self.actor(state)

        return action, log_prob

    def update(self, replay_buffer,  batch_size):
        """Run one SAC update (critics, actor, temperature, target) plus one world-model step.

        Args:
            replay_buffer: anything with ``sample(batch_size)`` returning
                (states, actions, rewards, next_states, dones), e.g. ``Mix_Batch``.
            batch_size: number of transitions per sample.

        Returns:
            (actor_loss, critic_loss, alpha, world_model_loss) as Python floats.
        """
        # Sample from buffer
        states, actions, rewards, next_states, dones = replay_buffer.sample(batch_size)

        next_states = next_states.to(device)
        # The critics output (batch, 1). Buffers may return rewards / dones as (batch,)
        # vectors; left as-is, `rewards + gamma * (1 - dones) * target_q` broadcasts to
        # (batch, batch) and every Q estimate gets regressed onto every target.
        rewards = rewards.to(device).reshape(-1, 1)
        dones = dones.to(device).reshape(-1, 1)

        # World Model Loss (draws its own, independent batch from the same buffer)
        loss, memory = self.world_loss.compute_loss(replay_buffer, batch_size)

        # TD target from the target critic, using a fresh policy action at s'.
        # NOTE(review): standard SAC subtracts alpha * log_prob of the next action in
        # the target; it is omitted here — confirm that is intended.
        with torch.no_grad():

            new_action, _ = self.select_action(next_states)
            target_1, target_2 = self.target_critic(next_states, new_action)
            # Mostly-pessimistic blend of the twin estimates
            target_q = (0.75 * torch.min(target_1, target_2) + 0.25 * torch.max(target_1, target_2))

            target = rewards + self.gamma * (1 - dones) * target_q

        # Critic loss against the (batch, 1) target
        critic_1, critic_2 = self.critic(states, actions)
        loss_1 = F.smooth_l1_loss(critic_1, target)
        loss_2 = F.smooth_l1_loss(critic_2, target)

        critic_loss = loss_1 + loss_2

        # Update Critics. retain_graph: `states` may carry an autograd graph from
        # upstream (Mix_Batch runs the encoder on real observations) that the actor
        # pass below reuses.
        self.critic_optimizer.zero_grad()
        critic_loss.backward(retain_graph = True)
        torch.nn.utils.clip_grad_norm_(self.critic.parameters(), max_norm = 0.5)
        self.critic_optimizer.step()
        self.critic_scheduler.step()

        # Actor loss: alpha * log_prob - blended Q. Its backward also writes grads into
        # the critic; those are cleared by the critic zero_grad of the next update.
        pi_action, log_prob = self.select_action(states)

        q_1, q_2 = self.critic(states, pi_action)
        q_pi = (0.75 * torch.min(q_1, q_2) + 0.25 * torch.max(q_1, q_2))

        actor_loss = (self.compute_alpha() * log_prob - q_pi).mean()

        # Update Actor
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.actor.parameters(), max_norm = 0.5)
        self.actor_optimizer.step()
        self.actor_scheduler.step()

        # Temperature (alpha) loss towards the entropy target
        alpha_loss = - (self.log_alpha * (log_prob.detach() + self.target_entropy)).mean()

        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        torch.nn.utils.clip_grad_norm_([self.log_alpha], max_norm = 0.5)
        self.alpha_optimizer.step()

        # Polyak-update the target critic
        self.update_target()

        return actor_loss.item(), critic_loss.item(), self.alpha.item(), loss.item()

In [115]:

# Discount factor for the TD target
gamma = 0.997

# Actor and critic use the same recipe: AdamW (lr 3e-4, weight decay 1e-3) with a
# cosine-annealed learning rate over the same T_max as the world model.
actor_optimizer = optim.AdamW(actor_network.parameters(), lr = 3e-4, weight_decay = 0.001)
actor_scheduler = optim.lr_scheduler.CosineAnnealingLR(actor_optimizer, T_max = T_max)

critic_optimizer = optim.AdamW(critic_network.parameters(), lr = 3e-4, weight_decay = 0.001)
critic_scheduler = optim.lr_scheduler.CosineAnnealingLR(critic_optimizer, T_max = T_max)

# Agent setup
agent = SAC_AGENT(
    actor = actor_network,
    critic = critic_network,
    target_critic = target_critic,
    actor_optimizer = actor_optimizer,
    actor_scheduler = actor_scheduler,
    critic_optimizer = critic_optimizer,
    critic_scheduler = critic_scheduler,
    gamma = gamma,
)

## **Mix Batch**

Mix imagined (world-model) and real (environment) transitions into a single training batch.

In [116]:
class Mix_Batch:
    """Samples a training batch that mixes imagined and real transitions.

    ``ratio`` is the fraction of the batch drawn from the real buffer (rounded down).
    Real observations are passed through ``encoder`` so that both parts of the batch
    are in latent space.

    Args:
        imagine_buffer: buffer of imagined transitions.
        real_buffer: buffer of real environment transitions.
        ratio: fraction of each batch taken from ``real_buffer``.
        verbose: if True, print shape diagnostics on every ``sample`` call (off by
            default, since ``sample`` runs once per training update).
    """

    def __init__(self, imagine_buffer, real_buffer, ratio, verbose = False):

        self.real_buffer = real_buffer
        self.imagine_buffer = imagine_buffer
        self.ratio = ratio
        self.verbose = verbose

    def sample(self, batch_size):
        """Return (states, actions, rewards, next_states, dones) with the imagined rows first."""
        real_size = int(self.ratio * batch_size)
        imag_size = batch_size - real_size

        # Sample from buffers
        i_s, i_a, i_r, i_ns, i_d = self.imagine_buffer.sample(imag_size)
        s, a, r, ns, d = self.real_buffer.sample(real_size)

        if self.verbose:
            print(f"Imaginary state shape: {i_s.shape}, Real state shape: {s.shape}")

        # Each imagination-buffer entry is one batched rollout step. Squeeze the
        # singleton axis 2 and keep only the LAST element along dim 1.
        # NOTE(review): if rollouts store (batch, 1, dim) states, dim 1 here is the
        # rollout's start-state batch rather than a time axis, so only one transition
        # per stored entry is used — confirm this is intended.
        i_s  = i_s.squeeze(2)[:, -1, :]     # -> [imag_size, feat_dim]
        i_a  = i_a.squeeze(2)[:, -1, :]     # -> [imag_size, act_dim]
        i_r  = i_r.squeeze(2)[:, -1]        # -> [imag_size]
        i_ns = i_ns.squeeze(2)[:, -1, :]    # -> [imag_size, feat_dim]
        i_d  = i_d.squeeze(2)[:, -1]        # -> [imag_size]

        # Squeeze real data if needed
        if s.dim() == 3:
            s = s.squeeze(1)
        if ns.dim() == 3:
            ns = ns.squeeze(1)

        # Encode real observations into the latent space
        s_encoded  = encoder(s.to(device))
        ns_encoded = encoder(ns.to(device))

        # Merge imagined (first) and real (second) rows
        states      = torch.cat([i_s.to(device), s_encoded], dim=0)
        actions     = torch.cat([i_a.to(device), a.to(device)], dim=0)
        rewards     = torch.cat([i_r.to(device), r.to(device)], dim=0)
        next_states = torch.cat([i_ns.to(device), ns_encoded], dim=0)
        dones       = torch.cat([i_d.to(device), d.to(device)], dim=0)

        if self.verbose:
            print(f"Encoded real state shape: {s_encoded.shape}")
            print(f"Final mixed states shape: {states.shape}")

        return states, actions, rewards, next_states, dones




In [117]:
# About half of every training batch is real transitions, the rest imagined ones.
REAL_RATIO = 0.5
mix_batch = Mix_Batch(imagine_buffer = imagination_buffer, real_buffer = sac_buffer, ratio = REAL_RATIO)

## **Rollout**

In [118]:
def rollout_trajectories(start_state, agent = agent, ensemble = ensemble , imagination_buffer = imagination_buffer, horizon = 5):
    """Roll the policy forward inside the learned ensemble and store the imagined transitions.

    For each of ``horizon`` steps: the agent picks an action for the current latent
    state, every ensemble member predicts the next latent state, the ensemble mean
    becomes the next state, and the ensemble disagreement is the intrinsic reward.
    Each step is added to ``imagination_buffer`` as one entry, with ``done`` set to zeros.

    Args:
        start_state: latent start states; presumably (batch, 1, latent_dim), since
            dim 1 is squeezed before the actor and actions get a matching
            unsqueeze(1) — TODO confirm against the caller.
        agent, ensemble, imagination_buffer: default to the notebook-level objects
            as they were when this cell ran (Python binds defaults at definition
            time, so re-creating e.g. ``agent`` later requires re-running this cell).
        horizon: number of imagined steps.
    """
    
    state = start_state
    
    for _ in range(horizon):
        
        # No gradients: rollouts only generate data for the imagination buffer.
        with torch.no_grad():
            # Actor expects (batch, latent_dim); restore the singleton axis afterwards.
            action, _ = agent.select_action(state.squeeze(1))

            action = action.unsqueeze(1)

            # [ensemble_size, ...] predictions of the next latent state
            ensemble_preds = ensemble(state, action)
            
            # Ensemble mean is used as the imagined next state
            next_state = ensemble_preds.mean(dim = 0)
            
            # Ensemble disagreement is the intrinsic reward
            reward = intrinsive_reward_signal(ensemble_preds)
            
            # Imagined transitions are never terminal
            done = torch.zeros_like(reward)
            
            imagination_buffer.add(state, action, reward, next_state, done)
            
            state = next_state 
            
            #print(f'Rollout succesfull')