# **Model Based Policy Optimization**

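Integrating an enhanced SAC agent with a learned world model. The world model is trained on real environment transitions, while SAC is trained on a mix of real transitions and short rollouts generated by the world model — so much of the agent's experience is effectively "dreamed".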

In [61]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import warnings
from collections import deque
import copy

import numpy as np

import gymnasium as gym
from gymnasium.wrappers import RescaleAction
import matplotlib.pyplot as plt

In [62]:
warnings.filterwarnings(action = 'ignore', category = UserWarning)

##  **Device Setup**

In [63]:
# Prefer the GPU whenever PyTorch can see one; every tensor below is moved to `device`.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print(f'Device working on: {device}')

Device working on: cuda


## *Env Setup*

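We will be working on **Humanoid-v5** (the environment created below: 348-dim observations, 17-dim actions), a standard MuJoCo benchmark, so results are easy to compare against existing **PyTorch and TensorFlow** implementations.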

In [64]:
# Humanoid-v5 with an explicit 1000-step time limit. Actions are rescaled to [-1, 1] so
# the actor's tanh-squashed output (times max_action) can be passed to env.step directly.
env = RescaleAction(
    gym.make('Humanoid-v5', max_episode_steps = 1000),
    min_action = -1.0,
    max_action = 1.0,
)

# Network input/output sizes are read off the wrapped env's spaces.
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
max_action = env.action_space.high[0]

print(f'State dim: {state_dim}\nAction dim: {action_dim} | Max Action range: {max_action}')

State dim: 348
Action dim: 17 | Max Action range: 1.0


# **World Model**

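Learned dynamics model: given a (state, action) pair it predicts the next state and the reward. It is trained on real transitions collected from the environment.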

In [65]:
class World_Model(nn.Module):
    """MLP dynamics model: (state, action) -> (predicted next state, predicted reward).

    Args:
        head_1, head_2, head_3, head_4: widths of the four hidden Linear layers.
        state_dim, action_dim: input sizes (default to the environment's dims).

    The network emits ``state_dim + 1`` values: the first ``state_dim`` are the
    predicted next state, the last one is the predicted reward.
    """

    def __init__(self, head_1, head_2, head_3, head_4, state_dim = state_dim, action_dim = action_dim):
        super(World_Model, self).__init__()

        self.world = nn.Sequential(

            # first Layer combination of state and action dim

            nn.Linear(state_dim + action_dim, head_1),
            nn.SiLU(),

            # Layer Norm for stabilization and proceeding with increment bottleneck technique

            nn.LayerNorm(head_1),
            nn.Linear(head_1, head_2),
            nn.SiLU(),


            # third layer: no layer norm, to keep gradient flow unrestricted

            nn.Linear(head_2, head_3),
            nn.SiLU(),


            # Fourth layer: normalise the widest (head_3) activation before projecting down

            nn.LayerNorm(head_3),
            nn.Linear(head_3, head_4),
            nn.SiLU(),


            # Output and final layer: next state (state_dim values) + reward (1 value)

            nn.Linear(head_4, state_dim + 1)
        )

    def forward(self, state, action):
        """Predict the next state and reward.

        Args:
            state: tensor (..., state_dim).
            action: tensor (..., action_dim) with the same leading shape as ``state``.

        Returns:
            next_state: (..., state_dim)
            reward: (..., 1)
        """
        x = torch.cat([state, action], dim = -1)

        out = self.world(x)

        # Slice the feature (last) axis rather than axis 1, so any leading batch shape works.
        next_state = out[..., :-1]
        reward = out[..., -1:]

        return next_state, reward

## **World Model 3**

In [66]:
class World_Model_3(nn.Module):
    """Attention-flavoured dynamics model: (state, action) -> (next state, reward).

    The concatenated input is projected to ``head_1``, layer-normalised, passed through
    8-head self-attention and then an MLP whose ``state_dim + 1`` outputs are split into
    next state (all but last) and reward (last).
    """

    def __init__(self, head_1, head_2, head_3, head_4, state_dim = state_dim, action_dim = action_dim):
        super().__init__()

        # Submodules are created in this order (projection, norm, MHA, fc) so parameter
        # initialisation and state_dict keys stay the same.
        self.projection = nn.Linear(state_dim + action_dim, head_1)
        self.norm = nn.LayerNorm(head_1)
        self.MHA = nn.MultiheadAttention(head_1, num_heads = 8, batch_first = True)
        self.fc = nn.Sequential(
            nn.Linear(head_1, head_2),
            nn.SiLU(),
            nn.LayerNorm(head_2),
            nn.Linear(head_2, head_3),
            nn.SiLU(),
            nn.Linear(head_3, head_4),
            nn.SiLU(),
            nn.Linear(head_4, state_dim + 1)
        )

    def forward(self, state, action):
        """Return ``(next_state, reward)`` for 2-D (batch, dim) inputs."""
        tokens = self.norm(self.projection(torch.cat([state, action], dim = -1)))

        # batch_first MHA wants (batch, seq, embed): each sample becomes a length-1 sequence.
        # NOTE(review): with a single token the attention weight is always 1, so this
        # layer effectively reduces to its value/output projections.
        tokens = tokens.unsqueeze(1)
        attended = self.MHA(tokens, tokens, tokens)[0].squeeze(1)

        prediction = self.fc(attended)
        return prediction[:, :-1], prediction[:, -1:]

#### **Initialize**

In [67]:
# Neuron assembly: hidden widths shared by every world model below

head_1 = 128
head_2 = 256
head_3 = 512
head_4 = 256

# World Model 1

world_model = World_Model(head_1, head_2, head_3, head_4).to(device)

print(world_model)

# World Model 2 (second MLP member of the ensemble). Constructed independently instead of
# copy.deepcopy(world_model): a deep copy starts from identical weights, so both members
# would make identical predictions and the ensemble would gain no diversity.

world_model_2 = World_Model(head_1, head_2, head_3, head_4).to(device)

print(" - " * 80)
print(world_model_2)

# World Model 3 (attention-based member)

world_model_3 = World_Model_3(head_1, head_2, head_3, head_4).to(device)

print(" - " * 80)
print(world_model_3)

World_Model(
  (world): Sequential(
    (0): Linear(in_features=365, out_features=128, bias=True)
    (1): SiLU()
    (2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (3): Linear(in_features=128, out_features=256, bias=True)
    (4): SiLU()
    (5): Linear(in_features=256, out_features=512, bias=True)
    (6): SiLU()
    (7): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (8): Linear(in_features=512, out_features=256, bias=True)
    (9): SiLU()
    (10): Linear(in_features=256, out_features=349, bias=True)
  )
)
 -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  -  - 
World_Model(
  (world): Sequential(
    (0): Linear(in_features=365, out_features=128, bias=True)
    (1): SiLU()
    (2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (3): Linear(in_features=128, out_fe

## **Ensemble**

In [68]:
class world_model_ensemble(nn.Module):
    """Averages the (next_state, reward) predictions of three world models.

    Every member is called as ``member(state, action)`` and must return
    ``(next_state, reward)``; the ensemble returns the element-wise mean of each.
    """

    def __init__(self, model_1, model_2, model_3):
        super().__init__()

        # ModuleList registers the members so .to(), .parameters() and state_dict see them.
        self.models = nn.ModuleList([model_1, model_2, model_3])

    def forward(self, state, action):
        next_states = []
        rewards = []
        for member in self.models:
            prediction = member(state, action)
            next_states.append(prediction[0])
            rewards.append(prediction[1])

        # Stack along a new leading "member" axis and average it away.
        mean_next_state = torch.stack(next_states, dim = 0).mean(dim = 0)
        mean_reward = torch.stack(rewards, dim = 0).mean(dim = 0)
        return mean_next_state, mean_reward

In [69]:
# Bundle the three world models. Each member was already moved to `device` when it was
# created; the outer .to(device) just keeps the container consistent.
ensemble = world_model_ensemble(
    world_model,
    world_model_2,
    world_model_3,
).to(device)

print(ensemble)

world_model_ensemble(
  (models): ModuleList(
    (0-1): 2 x World_Model(
      (world): Sequential(
        (0): Linear(in_features=365, out_features=128, bias=True)
        (1): SiLU()
        (2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (3): Linear(in_features=128, out_features=256, bias=True)
        (4): SiLU()
        (5): Linear(in_features=256, out_features=512, bias=True)
        (6): SiLU()
        (7): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (8): Linear(in_features=512, out_features=256, bias=True)
        (9): SiLU()
        (10): Linear(in_features=256, out_features=349, bias=True)
      )
    )
    (2): World_Model_3(
      (projection): Linear(in_features=365, out_features=128, bias=True)
      (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (MHA): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
      )
      (fc): Sequential(
        (0):

# **Dynamic Buffer**

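Stores real transitions (state, action, reward, next state, done) collected by interacting with the environment. They are used to train the world model, as start states for world-model rollouts, and as the real part of the mixed batches used to train SAC.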


In [70]:
class World_Buffer():
    """Fixed-capacity FIFO buffer of real environment transitions.

    Transitions are stored as float32 tensors on ``device``. Once ``capacity`` is reached,
    the oldest entry is overwritten (ring buffer driven by ``ptr``).

    Args:
        capacity: maximum number of transitions kept.
        batch_size: default number of transitions returned by ``sample()``.
    """

    def __init__(self, capacity, batch_size):

        self.ptr = 0
        self.batch_size = batch_size
        # A plain list (not a deque) gives O(1) random indexing in sample(); deque
        # indexing degrades to O(n) toward the middle, which is slow at 500k entries.
        self.buffer = []
        self.capacity = capacity

    def add(self, state, action, reward, next_state, done):
        """Store one transition (S, A, R, S', done); reward and done become 0-d tensors."""

        experience = (
            torch.tensor(state, dtype = torch.float32, device = device),
            torch.tensor(action, dtype = torch.float32, device = device),
            torch.tensor(float(reward), dtype = torch.float32, device = device),
            torch.tensor(next_state, dtype = torch.float32, device = device),
            torch.tensor(float(done), dtype = torch.float32, device = device),
        )

        if len(self.buffer) < self.capacity:
            self.buffer.append(experience)
        else:
            # Full: overwrite the oldest slot and advance the ring pointer.
            self.buffer[self.ptr] = experience
            self.ptr = (self.ptr + 1) % self.capacity

    def __len__(self):

        return len(self.buffer)

    def sample(self, batch_size = None):
        """Uniformly sample transitions with replacement.

        Args:
            batch_size: number of transitions; defaults to ``self.batch_size``.

        Returns:
            (states, actions, rewards, next_states, dones) stacked tensors; rewards and
            dones have shape (batch_size,).

        Raises:
            ValueError: if the buffer is empty.
        """
        if batch_size is None:
            batch_size = self.batch_size
        if not self.buffer:
            raise ValueError('cannot sample from an empty World_Buffer')

        indices = np.random.choice(len(self.buffer), batch_size)

        states, actions, rewards, next_states, dones = zip(*(self.buffer[i] for i in indices))

        return (
            torch.stack(states).to(device),
            torch.stack(actions).to(device),
            torch.stack(rewards).to(device),
            torch.stack(next_states).to(device),
            torch.stack(dones).to(device),
        )

In [71]:
# Real-environment transitions: filled by the training loop, sampled for world-model
# training, rollout start states and the real share of mixed SAC batches.
world_buffer = World_Buffer(batch_size = 512, capacity = 500_000)

# **World Loss**

In [72]:
class World_loss():
    """Supervised trainer for the world model.

    Minimises MSE(next_state) + MSE(reward). The same ``compute_loss`` call also takes one
    optimizer step (grad-norm clipped to 0.5) and one scheduler step.

    Args:
        world_optimizer: optimizer that should cover the parameters of ``ensemble``.
        world_scheduler: LR scheduler attached to ``world_optimizer``.
        ensemble: model called as ``model(state, action) -> (next_state, reward)``.
    """

    def __init__(self, world_optimizer, world_scheduler, ensemble):

        self.model = ensemble
        self.loss_func = nn.MSELoss()
        self.optimizer = world_optimizer
        self.scheduler = world_scheduler

    def compute_loss(self, state, action, reward, next_state, done):
        """Compute the world-model loss on one batch and apply one training step.

        Args:
            state, action, next_state: batched tensors. A singleton axis 1 on
                ``state`` / ``action`` (shape (B, 1, D)) is squeezed away.
            reward: per-sample rewards, shape (B,) or (B, 1).
            done: unused; accepted to match the buffer's tuple layout.

        Returns:
            (total_loss, state_loss, reward_loss) as Python floats.
        """
        state = state.to(device)
        action = action.to(device)
        reward = reward.to(device)
        next_state = next_state.to(device)

        if state.dim() == 3:
            state = state.squeeze(1)

        if action.dim() == 3:
            action = action.squeeze(1)

        pred_state, pred_reward = self.model(state, action)

        state_loss = self.loss_func(pred_state, next_state)
        # Flatten both sides so a (B, 1) prediction never broadcasts against a (B,)
        # target into a (B, B) loss.
        reward_loss = self.loss_func(pred_reward.reshape(-1), reward.reshape(-1))
        total_loss = state_loss + reward_loss

        self.optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm = 0.5)
        self.optimizer.step()
        self.scheduler.step()

        return total_loss.item(), state_loss.item(), reward_loss.item()

#### **Initialize**

In [73]:

# Optimizer over world_model's parameters only; World_loss below also receives that same
# single model, so only world-model member 1 is trained.
# NOTE(review): world_model_2 / world_model_3 (and `ensemble`) are never optimised in this
# notebook -- confirm that is intended.
world_optimizer = optim.AdamW(
    world_model.parameters(),
    lr = 1e-4,
    weight_decay = 0.001,
)

# Cosine LR schedule, stepped once per World_loss.compute_loss call.
# NOTE(review): with T_max = 30 the LR reaches its minimum after 30 updates and then rises
# again (CosineAnnealingLR keeps following the cosine past T_max) -- confirm intended.
world_scheduler = optim.lr_scheduler.CosineAnnealingLR(world_optimizer, T_max = 30)

# Loss / training wrapper
world_loss = World_loss(
    world_optimizer,
    world_scheduler,
    world_model,
)

# **SAC Agent**

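This is a modified ("enhanced") SAC variant. The actor and critic each use a deep MLP feature extractor followed by multi-head attention. The twin-critic target blends $0.75\,\min(Q_1, Q_2) + 0.25\,\max(Q_1, Q_2)$. The entropy coefficient $\alpha$ is auto-tuned but clamped to $[0.1, 0.2]$.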

In [74]:
class Feature_extractor(nn.Module):
    
    def __init__(self, input_dim, output_dim, h):
        super(Feature_extractor, self).__init__()
        
        self.feature = nn.Sequential(
            
            nn.Linear(input_dim, h),
            nn.SiLU(),
            
            nn.LayerNorm(h),
            nn.Linear(h,h),
            nn.SiLU(),
            
            nn.Linear(h,h),
            nn.SiLU(),
            
            
            nn.Linear(h,h),
            nn.SiLU(),
            
            nn.LayerNorm(h),
            nn.Linear(h,h),
            nn.SiLU(),
            
            
            nn.Linear(h,h),
            nn.SiLU(),
            
            nn.LayerNorm(h),
            nn.Linear(h,h),
            nn.SiLU(),
            
            nn.Linear(h,output_dim),
            nn.SiLU()
        )
        
    def forward(self, x):
        
            x = self.feature(x)
            
            return x

In [75]:
class Actor_network(nn.Module):
    """Tanh-squashed Gaussian policy used by SAC.

    Pipeline: state -> Feature_extractor -> LayerNorm -> single-token MultiheadAttention
    -> MLP -> (mean, log_std) -> reparameterised sample z -> tanh(z) * max_action.

    Args:
        state_dim, action_dim: observation / action sizes.
        n1: feature width, also the attention embedding size (must be divisible by 8 heads).
        n2, n3, n4: hidden widths of the actor MLP.
        h1: hidden width inside the feature extractor.
        max_action: scale applied to the tanh-squashed action.
    """

    # Range the log standard deviation is smoothly mapped into.
    LOG_STD_MIN = -5.0
    LOG_STD_MAX = 2.0

    def __init__(self, state_dim, action_dim, n1, n2, n3, n4, h1, max_action = max_action):
        super(Actor_network, self).__init__()

        self.max_action = max_action

        # Feature extractor
        self.feature = Feature_extractor(input_dim = state_dim, output_dim = n1, h = h1)

        # Multi-head attention block
        self.norm = nn.LayerNorm(n1)
        self.mha = nn.MultiheadAttention(embed_dim = n1, num_heads = 8, batch_first = True)

        # Actor MLP
        self.actor = nn.Sequential(

            nn.Linear(n1, n2),
            nn.SiLU(),

            nn.LayerNorm(n2),
            nn.Linear(n2, n3),
            nn.SiLU(),

            nn.Linear(n3, n4)
        )

        # Gaussian heads: mean and log standard deviation
        self.mu = nn.Linear(n4, action_dim)
        self.log_std = nn.Linear(n4, action_dim)

    def forward(self, state):
        """Sample an action for a batch of states.

        Args:
            state: tensor of shape (batch, state_dim).

        Returns:
            action: (batch, action_dim), in [-max_action, max_action].
            log_prob: (batch, 1), Gaussian log-density minus the tanh correction.
                NOTE(review): the constant -log(max_action) per dimension from the final
                scaling is not included; it is zero when max_action == 1.
        """
        feature = self.feature(state)
        x = self.norm(feature)

        # batch_first MHA wants (batch, seq, embed): treat each state as a length-1 sequence.
        # NOTE(review): with one token the attention weight is always 1, so this acts as
        # the value/output projections only.
        x = x.unsqueeze(1)
        x_2, _ = self.mha(x, x, x)
        x_2 = x_2.squeeze(1)

        x_3 = self.actor(x_2)

        # Pre-squash mean kept in [-1, 1] as in the original design.
        # NOTE(review): combined with the tanh after sampling, this caps the mean action at
        # tanh(1) ~ 0.76 * max_action -- confirm intended.
        mu = torch.tanh(self.mu(x_3))

        # Map log_std smoothly into [LOG_STD_MIN, LOG_STD_MAX]. A plain tanh followed by
        # clamp(-5, 2) leaves log_std in [-1, 1] (the clamp never triggers), so std could
        # never drop below e^-1 ~ 0.37.
        log_std = torch.tanh(self.log_std(x_3))
        log_std = self.LOG_STD_MIN + 0.5 * (self.LOG_STD_MAX - self.LOG_STD_MIN) * (log_std + 1.0)
        std = torch.exp(log_std)

        # Reparameterisation trick
        normal = torch.distributions.Normal(mu, std)
        z = normal.rsample()

        # Out-of-place tanh: z must stay the pre-squash sample, because both the Gaussian
        # log-density and the squash correction below are evaluated at z.
        tanh_z = torch.tanh(z)
        action = tanh_z * self.max_action

        # log pi(a|s) = log N(z; mu, std) - sum log(1 - tanh(z)^2), using the numerically
        # stable identity log(1 - tanh(z)^2) = 2 * (log 2 - z - softplus(-2z)).
        log_prob = normal.log_prob(z)
        squash = 2 * (torch.log(torch.tensor(2.0, device = z.device)) - z - F.softplus(-2 * z))
        log_prob = (log_prob - squash).sum(dim = -1, keepdim = True)

        return action, log_prob

In [76]:
class Critic_network(nn.Module):
    """Twin Q-network: a shared (state, action) encoder with two independent Q heads.

    (state, action) -> Feature_extractor -> LayerNorm -> single-token MultiheadAttention
    -> critic_1 / critic_2 MLPs, each producing a (batch, 1) Q-value.

    Args:
        state_dim, action_dim: input sizes.
        n1: feature width, also the attention embedding size (8 heads).
        n2, n3, n4: hidden widths of each Q head.
        h1: hidden width inside the feature extractor.
    """

    def __init__(self, state_dim, action_dim, n1, n2, n3, n4, h1):
        super().__init__()

        # Shared encoder over the concatenated (state, action)
        self.feature = Feature_extractor(state_dim + action_dim, n1, h1)

        # Multi-head attention block
        self.norm = nn.LayerNorm(n1)
        self.mha = nn.MultiheadAttention(n1, 8, batch_first = True)

        # Q head 1
        self.critic_1 = nn.Sequential(

            nn.Linear(n1, n2),
            nn.SiLU(),

            nn.LayerNorm(n2),
            nn.Linear(n2, n3),
            nn.SiLU(),

            nn.Linear(n3, n4),
            nn.SiLU(),

            nn.Linear(n4, 1)
        )

        # Q head 2 (independent weights)
        self.critic_2 = nn.Sequential(

            nn.Linear(n1, n2),
            nn.SiLU(),

            nn.LayerNorm(n2),
            nn.Linear(n2, n3),
            nn.SiLU(),

            nn.Linear(n3, n4),
            nn.SiLU(),

            nn.Linear(n4, 1)
        )

    def forward(self, state, action):
        """Return ``(q_1, q_2)``, each of shape (batch, 1).

        Inputs shaped (batch, 1, dim) are accepted and squeezed to (batch, dim).
        """
        # Drop a singleton sequence axis: (batch, 1, dim) -> (batch, dim). squeeze(1) targets
        # that axis; squeeze(-1) would target the feature axis, which is not size 1 here.
        if state.dim() == 3:
            state = state.squeeze(1)

        if action.dim() == 3:
            action = action.squeeze(1)

        x = torch.cat([state, action], dim = -1)

        feature = self.feature(x)

        # batch_first MHA wants (batch, seq, embed): use a length-1 sequence.
        x_2 = self.norm(feature)
        x_2 = x_2.unsqueeze(1)
        x_3, _ = self.mha(x_2, x_2, x_2)
        x_3 = x_3.squeeze(1)

        q_1 = self.critic_1(x_3)
        q_2 = self.critic_2(x_3)

        return q_1, q_2

####  **Initialize**

In [77]:
# Hidden widths for the SAC networks
n1, h1 = 128, 128
n2, n3, n4 = 256, 512, 256

# Online actor and twin-head critic
actor_network = Actor_network(state_dim, action_dim, n1, n2, n3, n4, h1).to(device)
print(actor_network)

print('-------------------------------------------------------------------------------')

critic_network = Critic_network(state_dim, action_dim, n1, n2, n3, n4, h1).to(device)
print(critic_network)

# Target critic: starts as an exact copy of the online critic and is afterwards moved
# only through soft_update (Polyak averaging), not by an optimizer.
target_critic = copy.deepcopy(critic_network).to(device)

Actor_network(
  (feature): Feature_extractor(
    (feature): Sequential(
      (0): Linear(in_features=348, out_features=128, bias=True)
      (1): SiLU()
      (2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (3): Linear(in_features=128, out_features=128, bias=True)
      (4): SiLU()
      (5): Linear(in_features=128, out_features=128, bias=True)
      (6): SiLU()
      (7): Linear(in_features=128, out_features=128, bias=True)
      (8): SiLU()
      (9): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (10): Linear(in_features=128, out_features=128, bias=True)
      (11): SiLU()
      (12): Linear(in_features=128, out_features=128, bias=True)
      (13): SiLU()
      (14): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (15): Linear(in_features=128, out_features=128, bias=True)
      (16): SiLU()
      (17): Linear(in_features=128, out_features=128, bias=True)
      (18): SiLU()
    )
  )
  (norm): LayerNorm((128,), eps=1e-05, elementwise_affin

### **Soft Update**

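Polyak (soft) update of the target network, giving smooth, slowly moving targets: $\theta' \leftarrow \tau\,\theta + (1-\tau)\,\theta'$.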

In [78]:
def soft_update(source, target, tau = 0.067):
    """Polyak-average ``source`` into ``target`` in place.

    Each target parameter becomes ``tau * source + (1 - tau) * target``. Parameters are
    paired by iteration order, so both modules must share the same architecture.
    """
    keep = 1 - tau
    for target_param, source_param in zip(target.parameters(), source.parameters()):
        blended = tau * source_param.data + keep * target_param.data
        target_param.data.copy_(blended)

# **Sac_Agent**

In [79]:
class SAC_Agent:
    """Soft Actor-Critic learner.

    Holds the actor, a twin-head critic with its Polyak-averaged target copy, their
    optimizers/schedulers, and an auto-tuned entropy coefficient alpha that is clamped
    to [alpha_min, 0.2].

    Args:
        actor_network: policy; ``actor(state)`` returns ``(action, log_prob)``.
        critic_network: ``critic(state, action)`` returns ``(q_1, q_2)``.
        target_critic: copy of the critic used for Bellman targets.
        actor_optimizer, actor_scheduler, critic_optimizer, critic_scheduler:
            each stepped once per ``update`` call.
        gamma: discount factor.
    """

    def __init__(self, actor_network, critic_network, target_critic, actor_optimizer, actor_scheduler, critic_optimizer, critic_scheduler, gamma):

        # Networks
        self.actor = actor_network
        self.critic = critic_network
        self.target_critic = target_critic

        # Hyperparameters
        self.gamma = gamma

        # Optimizers and schedulers
        self.actor_optimizer = actor_optimizer
        self.actor_scheduler = actor_scheduler
        self.critic_optimizer = critic_optimizer
        self.critic_scheduler = critic_scheduler

        # Entropy auto-tuning: target entropy -1.5 * action_dim; alpha learned in log
        # space, starting at 0.2.
        self.target_entropy = - action_dim * 1.5
        self.log_alpha = torch.tensor(np.log(0.2), requires_grad = True, device = device)
        self.alpha_min = 0.1
        self.alpha_optimizer = optim.AdamW([self.log_alpha], lr = 1e-4, weight_decay = 0.001)

    def compute_alpha(self):
        """Return the current detached alpha, clamped to [alpha_min, 0.2] (cached in self.alpha)."""
        self.alpha = self.log_alpha.exp().detach().clamp(min = self.alpha_min, max = 0.2)
        return self.alpha

    def update_target(self):
        """Polyak-update the target critic towards the online critic."""
        return soft_update(self.critic, self.target_critic)

    def select_action(self, state):
        """Sample an action for ``state`` (array-like or tensor); a 1-D state gets a batch dim.

        Returns ``(action, log_prob)`` from the actor. Autograd is NOT disabled here.
        """
        state = torch.as_tensor(state, dtype = torch.float32, device = device)

        if state.dim() == 1:
            state = state.unsqueeze(0)

        return self.actor(state)

    def update(self, replay_buffer, batch_size):
        """One SAC step: critic, actor, target network and alpha.

        Args:
            replay_buffer: object whose ``sample(batch_size)`` returns
                ``(state, action, reward, next_state, done)`` tensors.
            batch_size: number of transitions to sample.

        Returns:
            (actor_loss, critic_loss, alpha) as Python floats.
        """
        state, action, reward, next_state, done = replay_buffer.sample(batch_size)

        state = state.to(device)
        action = action.to(device)
        next_state = next_state.to(device)
        # Buffers return rewards/dones as (B,), while the critics output (B, 1). Reshape to
        # (B, 1) so the Bellman target does not broadcast into a (B, B) matrix.
        reward = reward.to(device).reshape(-1, 1)
        done = done.to(device).reshape(-1, 1)

        alpha = self.compute_alpha()

        # ---- Critic update ----
        with torch.no_grad():
            next_action, next_log_prob = self.actor(next_state)
            target_1, target_2 = self.target_critic(next_state, next_action)
            # Blend of the twin targets: 75% pessimistic min, 25% max.
            target_Q = 0.75 * torch.min(target_1, target_2) + 0.25 * torch.max(target_1, target_2)
            # Soft Bellman backup, including the entropy bonus -alpha * log pi(a'|s').
            target_value = reward + self.gamma * (1 - done) * (target_Q - alpha * next_log_prob)

        current_q1, current_q2 = self.critic(state, action)
        critic_loss = F.mse_loss(current_q1, target_value) + F.mse_loss(current_q2, target_value)

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.critic.parameters(), max_norm = 0.5)
        self.critic_optimizer.step()
        self.critic_scheduler.step()

        # ---- Actor update ----
        new_action, new_log_prob = self.actor(state)
        q_1, q_2 = self.critic(state, new_action)
        q_pi = 0.75 * torch.min(q_1, q_2) + 0.25 * torch.max(q_1, q_2)
        actor_loss = (alpha * new_log_prob - q_pi).mean()

        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.actor.parameters(), max_norm = 0.5)
        self.actor_optimizer.step()
        self.actor_scheduler.step()

        self.update_target()

        # ---- Alpha (entropy coefficient) update ----
        # Detached log-probs: only log_alpha receives gradient here.
        alpha_loss = - (self.log_alpha * (new_log_prob.detach() + self.target_entropy)).mean()
        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        torch.nn.utils.clip_grad_norm_([self.log_alpha], max_norm = 1.0)
        self.alpha_optimizer.step()

        return actor_loss.item(), critic_loss.item(), self.alpha.item()

#### *Reducing Boilerplate: Tensor Conversion Helper*

In [80]:
def safe_tensor(x):
    """Convert ``x`` to a float32 tensor on ``device`` unless it already is a tensor.

    Existing tensors are returned unchanged (same object, dtype and device).
    """
    if torch.is_tensor(x):
        return x
    return torch.tensor(x, dtype = torch.float32).to(device)

## **SAC Buffer**

In [81]:
class SAC_Buffer:
    """Fixed-capacity FIFO buffer of SAC transitions (filled by Rollout in this notebook).

    Inputs go through ``safe_tensor``: tensors are stored as given, anything else becomes a
    float32 tensor on ``device``. Once full, the oldest entry is overwritten (ring via ``ptr``).

    Args:
        capacity: maximum number of transitions kept.
        batch_size: default number of transitions returned by ``sample()``.
    """

    def __init__(self, capacity, batch_size):

        self.capacity = capacity
        self.batch_size = batch_size
        # A plain list gives O(1) random access in sample(); deque indexing degrades to
        # O(n) toward the middle.
        self.sac_buffer = []
        self.ptr = 0

    def add(self, state, action, reward, next_state, done):
        """Store one transition (S, A, R, S', done)."""
        experience = tuple(safe_tensor(item) for item in (state, action, reward, next_state, done))

        if len(self.sac_buffer) < self.capacity:
            self.sac_buffer.append(experience)
        else:
            # Full: overwrite the oldest slot and advance the ring pointer.
            self.sac_buffer[self.ptr] = experience
            self.ptr = (self.ptr + 1) % self.capacity

    def __len__(self):
        return len(self.sac_buffer)

    def sample(self, batch_size = None):
        """Uniformly sample transitions with replacement.

        Args:
            batch_size: number of transitions; defaults to ``self.batch_size``.

        Returns:
            (states, actions, rewards, next_states, dones) stacked tensors on ``device``.

        Raises:
            ValueError: if the buffer is empty.
        """
        if batch_size is None:
            batch_size = self.batch_size
        if not self.sac_buffer:
            raise ValueError('cannot sample from an empty SAC_Buffer')

        indices = np.random.choice(len(self.sac_buffer), batch_size)

        states, actions, rewards, next_states, dones = zip(*(self.sac_buffer[i] for i in indices))

        return (
            torch.stack(states).to(device),
            torch.stack(actions).to(device),
            torch.stack(rewards).to(device),
            torch.stack(next_states).to(device),
            torch.stack(dones).to(device),
        )

In [82]:
# Synthetic transitions produced by world-model rollouts (written by Rollout.roll_out).
sac_buffer = SAC_Buffer(batch_size = 512, capacity = 500_000)

## **Rollout**

In [83]:
class Rollout:
    """Generates short synthetic trajectories with the learned world model.

    Start states are sampled from ``buffer`` (real transitions). For each of ``length``
    steps the actor picks an action, the world model predicts (next_state, reward), and
    the transition is stored in ``sac_buffer`` with done = 0.
    """

    def __init__(self, world_model = world_model, actor_network = actor_network, buffer = world_buffer, sac_buffer = sac_buffer, length = 5):

        self.world = world_model
        self.actor = actor_network
        self.buffer = buffer
        self.length = length
        self.sac_buffer = sac_buffer

    def roll_out(self, batch_size):
        """Run one batched rollout of ``self.length`` steps from ``batch_size`` start states.

        Returns:
            (samples_added, avg_reward): number of synthetic transitions written to
            ``sac_buffer`` and their mean predicted reward (0.0 if none were added).
        """
        state = self.buffer.sample(batch_size)[0]

        samples_added = 0
        reward_sum = 0.0

        with torch.no_grad():
            for _ in range(self.length):

                action, _ = self.actor(state)
                next_state, reward = self.world(state, action)
                reward = reward.reshape(-1)

                # The world model does not predict termination, so every synthetic
                # transition is stored as non-terminal.
                done = torch.zeros_like(reward)

                # Store the device tensors directly: SAC_Buffer.add passes tensors through
                # safe_tensor unchanged, avoiding a GPU -> CPU -> GPU round trip per sample.
                for i in range(state.shape[0]):
                    self.sac_buffer.add(state[i], action[i], reward[i], next_state[i], done[i])

                samples_added += state.shape[0]
                reward_sum += reward.sum().item()

                state = next_state

        avg_reward = reward_sum / samples_added if samples_added else 0.0
        return samples_added, avg_reward
                

In [84]:
# Rollout generator: the actor + world model produce 5-step synthetic trajectories from real
# start states into sac_buffer. Arguments are spelled out; they mirror the class defaults.
rolling = Rollout(
    world_model = world_model,
    actor_network = actor_network,
    buffer = world_buffer,
    sac_buffer = sac_buffer,
    length = 5,
)

# **Hyper Params**

In [85]:
# Hyper params

gamma = 0.997        # discount factor
max_iter = 10_000    # T_max of the actor/critic cosine LR schedules

# NOTE(review): agent.update() steps both schedulers once per call, and the training loop
# makes max_epochs * policy_train_steps (300 * 40 = 12,000) calls -- more than max_iter,
# after which the cosine LR starts rising again. Confirm intended.

# Actor params
actor_optimizer = optim.AdamW(actor_network.parameters(), lr = 3e-4, weight_decay = 0.001)
actor_scheduler = optim.lr_scheduler.CosineAnnealingLR(actor_optimizer, T_max = max_iter)

# Critic params
critic_optimizer = optim.AdamW(critic_network.parameters(), lr = 1e-4, weight_decay = 0.001)
critic_scheduler = optim.lr_scheduler.CosineAnnealingLR(critic_optimizer, T_max = max_iter)

# SAC agent
agent = SAC_Agent(
    actor_network = actor_network,
    critic_network = critic_network,
    target_critic = target_critic,
    actor_optimizer = actor_optimizer,
    actor_scheduler = actor_scheduler,
    critic_optimizer = critic_optimizer,
    critic_scheduler = critic_scheduler,
    gamma = gamma,
)

## **Mix Batch**

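This is the most important part: each SAC training batch mixes real and synthetic (world-model-generated) transitions. The real-data fraction (`ratio`) assumed in this implementation is 0.25; the remaining 0.75 of each batch comes from world-model rollouts.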

In [86]:
class Mix_Batch:
    """Builds SAC training batches from real and model-generated transitions.

    ``ratio`` is the fraction of each batch drawn from real environment data
    (``world_buffer``, filled from ``env.step``). The remainder comes from world-model
    rollouts (``sac_buffer``, filled by ``Rollout``).
    """

    def __init__(self, sac_buffer = sac_buffer, world_buffer = world_buffer, ratio = 0.25):

        self.sac_buffer = sac_buffer
        self.world_buffer = world_buffer
        self.ratio = ratio

    def sample(self, batch_size):
        """Return (state, action, reward, next_state, done); real rows first, then synthetic."""
        real_size = int(self.ratio * batch_size)
        model_size = batch_size - real_size

        # Real transitions come from world_buffer, synthetic ones from sac_buffer. A source
        # whose share rounds to zero is skipped rather than sampled with size 0.
        parts = []
        if real_size > 0:
            parts.append(self.world_buffer.sample(real_size))
        if model_size > 0:
            parts.append(self.sac_buffer.sample(model_size))

        # zip(*parts) groups the same field (state, action, ...) across the sources.
        return tuple(torch.cat(field, dim = 0).to(device) for field in zip(*parts))

In [87]:
# Mixed real/synthetic sampler handed to agent.update(); ratio spelled out (class default).
mix_batch = Mix_Batch(sac_buffer = sac_buffer, world_buffer = world_buffer, ratio = 0.25)

## **Training Block**

In [88]:
# Outer-loop schedule
max_epochs = 300
world_train_steps = 50    # world-model updates per epoch
policy_train_steps = 40   # SAC updates per epoch
batch_size = 256

# Episode-return bookkeeping (updated by the training loop)
episode_rewards = []
ep_reward = 0

In [None]:
# Each epoch: (1) collect real transitions, (2) fit the world model, (3) generate synthetic
# rollouts, (4) update SAC on mixed real/synthetic batches, (5) log.
env_steps_per_epoch = 1000
log_window = 10   # number of most recent episodes averaged in the log

state, _ = env.reset()

for epoch in range(max_epochs):

    #### 1. ENV INTERACTION + WORLD BUFFER FILL ###
    # Reset only when an episode ends; resetting every step would only ever collect
    # transitions that start from an initial state.
    for _ in range(env_steps_per_epoch):
        with torch.no_grad():   # acting needs no autograd graph
            action, _ = agent.select_action(state)
        action = action.cpu().numpy()[0]
        next_state, reward, terminated, truncated, _ = env.step(action)

        # Only true termination stops bootstrapping; a time-limit truncation does not.
        world_buffer.add(state, action, reward, next_state, terminated)
        ep_reward += reward

        if terminated or truncated:
            episode_rewards.append(ep_reward)
            ep_reward = 0
            state, _ = env.reset()
        else:
            state = next_state

    #### 2. WORLD MODEL UPDATE ####
    for _ in range(world_train_steps):
        batch = world_buffer.sample(batch_size)
        total_wloss, state_wloss, reward_wloss = world_loss.compute_loss(*batch)

    #### 3. ROLLOUT FROM WORLD MODEL ####
    rolling.roll_out(batch_size)

    #### 4. MIXED BATCH SAMPLING + SAC UPDATE ####
    for _ in range(policy_train_steps):
        actor_loss, critic_loss, alpha = agent.update(mix_batch, batch_size)

    #### 5. LOGGING ####
    recent_returns = episode_rewards[-log_window:]
    avg_env_reward = sum(recent_returns) / len(recent_returns) if recent_returns else float('nan')

    print(f"Epoch: {epoch}/{max_epochs}")
    print(f"  Avg Env Reward     : {avg_env_reward:.2f}")
    print(f"  World Model Loss   : {total_wloss:.3f} (State: {state_wloss:.3f}, Reward: {reward_wloss:.3f})")
    print(f"  SAC Losses         : Actor: {actor_loss:.3f} | Critic: {critic_loss:.3f} | Alpha: {alpha:.3f}")
    print("-" * 80)