# The Ultimate SAC

# Import tools

In [678]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import numpy as np
import gymnasium as gym
from gymnasium.wrappers import RescaleAction
import copy

import warnings

## Preferences

In [679]:
warnings.filterwarnings('ignore' , category = UserWarning)

# Env Setup

In [680]:
env = gym.make('Humanoid-v5' , max_episode_steps = 3000)
env = RescaleAction(env , min_action = -1.0 , max_action = 1.0)

state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]

action_max = env.action_space.high[0]

print(f'State dim : {state_dim} | Action dim : {action_dim} | Action max : {action_max}')

State dim : 348 | Action dim : 17 | Action max : 1.0


## Device setup

In [681]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device : {device}')

Device : cuda


In [682]:
log_std_min = -10
log_std_max = 2

# ULtimate design

### Feature extractor

In [683]:
class feature_extractor(nn.Module):
    def __init__(self, input_dim , output_dim , h):
        super(feature_extractor , self).__init__()
        
        
        self.f1 = nn.Sequential(
            
            nn.Linear(input_dim , h),
            nn.Linear(h,h),
        )
        
        
        self.gru = nn.GRU(input_size = h , hidden_size = h , num_layers = 1 ,batch_first = True)
        #self.gru_prod = nn.Linear(512 , h)
        self.multi = nn.MultiheadAttention(embed_dim = h , num_heads = 8 , batch_first = True)
        
        
        self.f = nn.Sequential(
            
            nn.Linear(h , h),
            nn.LayerNorm(h),
            nn.SiLU(),
            
            
            nn.Linear(h , h),
            nn.SiLU(),
            
            nn.Linear(h , h),
            nn.SiLU(),
            
            nn.Linear(h , h),
            nn.LayerNorm(h),
            nn.SiLU(),
            
            nn.Linear(h , h),
            nn.SiLU(),
            
            
            nn.Linear(h , output_dim),
            nn.SiLU(),
        )
        
        
        self.h = h
        
    def forward(self , x , hidden = None):
        
        if x.dim() == 1:
            x = x.unsqueeze(0)
        else:
            x = x.squeeze(1)
            
            
        x = self.f1(x)
        
        gru_out , hidden = self.gru(x , hidden)
        #gru_prod = self.gru_prod(gru_out)
        
        x = x.unsqueeze(1)
        adv , _ = self.multi(gru_out , gru_out , gru_out)
        adv = adv.reshape(-1 , self.h)
        adv = F.silu(adv)
        
        
        x1 = self.f(adv)
        
        return x1

### Actor network

In [684]:
class actor(nn.Module):
    def __init__(self, state_dim , action_dim , head1 , head2 , head3 , head4 , h1 , feature_extractor = feature_extractor , max_action_range = action_max):
        super(actor , self).__init__()
        
        # Feature extractor
        self.feature = feature_extractor(state_dim , head1 , h1)
        
        # 4 Layers + SiLU
        self.fc = nn.Sequential(
            
            nn.Linear(head1 , head1),
            nn.LayerNorm(head1),
            nn.SiLU(),
            
            nn.Linear(head1 , head2),
            nn.SiLU(),
            
            nn.Linear(head2 , head3),
            nn.SiLU(),
            
            nn.Linear(head3 , head4),
            nn.LayerNorm(head4),
            nn.SiLU()
        )
        
        self.mu = nn.Linear(head4 , action_dim)
        
        self.log_std = nn.Linear(head4 , action_dim)
        
        self.max_action = max_action_range
        
    def forward(self , state):
        
        x = self.feature(state)
        
        x = self.fc(x)
        
        mu = self.mu(x)
        log_std = self.log_std(x)
        
        # Smoother scaling (key improvement)
        mu = torch.tanh(mu) * 0.5  # Constrain initial mu to [-0.5, 0.5]
        log_std = torch.tanh(log_std)  # Constrain log_std to [-1, 1] before scaling
        
        log_std = torch.clamp(log_std , log_std_min , log_std_max)
        std = torch.exp(log_std)
        
        # Reparamaterization trick
        
        normal = torch.distributions.Normal(mu , std)
        z = normal.rsample()
        tanh_z = torch.tanh(z)
        action = tanh_z * self.max_action
        
        
        # Tanh Squashing
        
        log_prob = normal.log_prob(z)
        squash = 2 * (torch.log(torch.tensor(2.0, device=z.device)) - z - F.softplus(-2 * z))
        log_prob = log_prob - squash
        log_prob = log_prob.sum(dim = 1 , keepdim = True)
        
        return action , log_prob

### Critic network


In [685]:
class critic(nn.Module):
    def __init__(self, state_dim , action_dim , head1 , head2 , head3 , head4 , h1 , feature_extractor = feature_extractor):
        super(critic , self).__init__()
        
        # Feature
        self.feature = feature_extractor(state_dim + action_dim , head1 , h1)
        
        # 2 Critic - c1 & c2
        self.c1 = nn.Sequential(
            
            nn.Linear(head1 , head1),
            nn.LayerNorm(head1),
            nn.SiLU(),
            
            nn.Linear(head1 , head2),
            nn.SiLU(),
            
            nn.Linear(head2 , head3),
            nn.SiLU(),
            
            nn.Linear(head3 , head4),
            nn.LayerNorm(head4),
            nn.SiLU(),
            
            nn.Linear(head4 , 1)
        )
        
        
        self.c2 = nn.Sequential(
            
            nn.Linear(head1 , head1),
            nn.LayerNorm(head1),
            nn.SiLU(),
            
            nn.Linear(head1 , head2),
            nn.SiLU(),
            
            nn.Linear(head2 , head3),
            nn.SiLU(),
            
            nn.Linear(head3 , head4),
            nn.LayerNorm(head4),
            nn.SiLU(),
            
            nn.Linear(head4 , 1)
        )
        
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.orthogonal_(m.weight, gain=0.01)
                nn.init.constant_(m.bias, 0.0)
        
    def forward(self , state , action):
                
        if state.dim() == 3:
            state = state.squeeze(1)
        if action.dim() == 3:
            action = action.squeeze(1)

        
        sa = torch.cat([state , action] , dim = 1)
        feature = self.feature(sa)
        feature = feature + 0.1 * feature.detach()
        
        q1 = self.c1(feature)
        q2 = self.c2(feature)
        
        return q1 , q2

## Model Setup

In [686]:
# Neuron setting
head1 = 256
head2 = 512
head3 = 512
head4 = 256
h1 = 256

# Model

actor_network = actor(state_dim , action_dim , head1 , head2 , head3 , head4 , h1).to(device)
critic_network = critic(state_dim , action_dim , head1 , head2 , head3 , head4 , h1).to(device)

target_critic = copy.deepcopy(critic_network).to(device)

print('-----------------------------------------------------------------------')
print(actor_network)
print('-----------------------------------------------------------------------')
print(critic_network)

-----------------------------------------------------------------------
actor(
  (feature): feature_extractor(
    (f1): Sequential(
      (0): Linear(in_features=348, out_features=256, bias=True)
      (1): Linear(in_features=256, out_features=256, bias=True)
    )
    (gru): GRU(256, 256, batch_first=True)
    (multi): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
    )
    (f): Sequential(
      (0): Linear(in_features=256, out_features=256, bias=True)
      (1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (2): SiLU()
      (3): Linear(in_features=256, out_features=256, bias=True)
      (4): SiLU()
      (5): Linear(in_features=256, out_features=256, bias=True)
      (6): SiLU()
      (7): Linear(in_features=256, out_features=256, bias=True)
      (8): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (9): SiLU()
      (10): Linear(in_features=256, out_features=256, bias=True)
      (11):

# PER As Roller Buffer

In [687]:
class PER_new():
    def __init__(self, capacity, batch_size, beta=0.4, alpha_per=0.6):
        self.capacity = capacity
        self.buffer = []
        self.priorities = np.zeros(capacity, dtype=np.float32)
        self.avg_priority = 1.0
        self.beta = beta
        self.alpha_per = alpha_per
        self.batch_size = batch_size

    def add(self, state, action, reward, next_state, done):
        
        # convert them to tensor
        state = state
        next_state = next_state
        reward = torch.tensor(reward , dtype = torch.float32).to(device)
        action = torch.tensor(action ,dtype = torch.float32).to(device)
        done = torch.tensor(done , dtype = torch.float32).to(device)
        
        max_priority = self.priorities.max() if self.buffer else 1.0
        
        if len(self.buffer) < self.capacity:
            self.buffer.append((state, action, reward, next_state, done))
            self.priorities[len(self.buffer)-1] = max_priority
        else:
            idx = np.argmin(self.priorities)
            self.buffer[idx] = (state, action, reward, next_state, done)
            self.priorities[idx] = max_priority

    def sample(self):
        if len(self.buffer) == 0:
            raise ValueError("Buffer is empty")
            
        probs = self.priorities[:len(self.buffer)] ** self.alpha_per
        probs /= probs.sum()
        
        indices = np.random.choice(len(self.buffer), size=min(self.batch_size, len(self.buffer)), 
                                 p=probs, replace=False)
        samples = [self.buffer[i] for i in indices]
        
        weights = (len(self.buffer) * probs[indices]) ** -self.beta
        weights /= weights.max()
        
        return samples, indices, torch.FloatTensor(weights).to(device)

    def update_priorities(self, indices, td_errors):
        # Ensure td_errors is numpy array with correct shape
        if isinstance(td_errors, torch.Tensor):
            td_errors = td_errors.detach().cpu().numpy()
        td_errors = np.reshape(td_errors, -1)  # Flatten
        
        # Clip and update
        clipped_errors = np.clip(td_errors, -6, 10)
        self.priorities[indices] = np.abs(clipped_errors) + 1e-6
        self.avg_priority = 0.99 * self.avg_priority + 0.01 * np.mean(clipped_errors)

### Replay buffer setup

In [688]:
# Buffer hyper params
capacity = 500_000
batch_size = 512

# Buffer setup
buffer = PER_new(capacity , batch_size)

# Soft update

In [689]:
def soft_update(target , source , tau = 0.0067):
    
    for target_param , param in zip(target.parameters() , source.parameters()):
        
        target_param.data.copy_((tau * param.data) + ((1 - tau) * target_param.data))

# Loss function

In [690]:
class loss_function():
    def __init__(self , gamma , action_dim , actor_network , critic_network , target_critic , policy_delay , total_step , critic_optimizer , critic_scheduler , actor_optimizer , actor_scheduler , soft_update , batch_size , episode):
        
        self.actor = actor_network
        self.critic = critic_network
        self.target_critic = target_critic
        self.gamma = gamma
        self.action_dim = action_dim
        self.critic_optimizer = critic_optimizer
        self.critic_scheduler = critic_scheduler
        self.actor_optimizer = actor_optimizer
        self.actor_scheduler = actor_scheduler
        self.total_step = total_step
        self.policy_delay = policy_delay
        self.update = soft_update
        self.batch_size = batch_size
        self.episode = episode
        
        # Entropy
        
        def update_entropy(episode):
            
            if episode < 3000 :
                self.target_entropy = - self.action_dim * 1.5
            elif 3000 < episode < 7000 :
                self.target_entropy = - self.action_dim * 1.0
            else:
                self.target_entropy = - self.action_dim * 0.8 
                  
            return self.target_entropy
        
        
        self.log_alpha = torch.tensor(np.log(0.2) , requires_grad = True , device = device)
        self.alpha_optimizer = optim.AdamW([self.log_alpha] , lr = 1e-4 , weight_decay = 0.01)
        self.target_entropy = update_entropy(episode)
        self.alpha_min = 0.15
        
        
    def compute_loss(self , state , action , reward , next_state , done , weights):
        
        # Alpha 
        
        self.alpha = self.log_alpha.exp().detach()
        self.alpha = self.alpha.clamp(min = self.alpha_min)
        
        # Tesnor handling
        
        state = torch.stack(state).to(device)
        action = torch.stack(action).to(device)
        reward = torch.stack(reward).unsqueeze(1).to(device)
        next_state = torch.stack(next_state).to(device)
        done = torch.stack(done).unsqueeze(1).to(device)
        weights = torch.tensor(weights , dtype = torch.float32).unsqueeze(1).to(device)
        
        # Compute target_value
        
        with torch.no_grad():
            next_actions , next_log_probs = self.actor(next_state)
            if next_actions.dim() == 1:
                next_actions = next_actions.unsqueeze(0)

            target_1 , target_2 = self.target_critic(next_state , next_actions)
            target_q = (0.75 * torch.min(target_1, target_2) + 0.25 * torch.max(target_1, target_2)).clamp(-200 , 600) - self.alpha * next_log_probs          # Removed clamp from torch.min let them unbound
            target_vals = reward + self.gamma * (1 - done) * target_q
            
            
        # Current q val and critic loss
        
        current_q1 , current_q2 = self.critic(state , action)
        td_error_1 = F.huber_loss(current_q1 , target_vals , reduction = 'none')
        td_error2 = F.huber_loss(current_q2 , target_vals , reduction = 'none')
        
        critic_loss = (td_error2 + td_error_1) * weights
        critic_loss = critic_loss.mean()
        
        # Td error
        
        with torch.no_grad():
            td_error  = ((current_q1 - target_vals).abs() + (current_q2 - target_vals).abs()) / 2.0
            td_error = td_error.view(self.batch_size, -1).mean(1)
            
            
        # Optimize
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.critic.parameters() , max_norm = 0.5)
        self.critic_optimizer.step()
        
        
        # Actor loss
        
        new_action , log_prob = self.actor(state)
        with torch.no_grad():
            _ , old_log_prob = self.actor(state)              # replaced _ , old _log_prob with new_action and log_prob
            q1 , q2 = self.critic(state , new_action)
            q1 = q1.detach()
            q2 = q2.detach()
            q_pi = torch.min(q1 , q2)
            adv = q_pi - q_pi.mean()
        
        # Dummy actor loss
        
        kl_div = (log_prob - old_log_prob).mean()
        actor_loss = torch.tensor(0.0).to(device)
        
        # Optimize & Alpha loss
        if self.episode % self.policy_delay == 0 :  
            
            # Actor optimized
            actor_loss = (self.alpha * log_prob - adv).mean()
            actor_loss += 0.05 * kl_div
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.actor.parameters() , max_norm = 0.5)
            self.actor_optimizer.step()
                    
            # Soft update
            self.update(self.target_critic , self.critic)
            

        # alpha optimized
        log_prob = log_prob.detach()
        alpha_loss = -(self.log_alpha * (log_prob + self.target_entropy + 0.2 * kl_div.detach())).mean()
        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_([self.log_alpha], max_norm=float('inf'))
        if grad_norm > 10.0:  # Only clip if exploding
            torch.nn.utils.clip_grad_norm_([self.log_alpha], max_norm=10.0)
        torch.nn.utils.clip_grad_norm_([self.log_alpha] , max_norm = 1.0)           # Removed gradient clipping let it explode 
        self.alpha_optimizer.step()
            
        if self.episode % 1000 == 0:
            
            self.actor_scheduler.step(actor_loss.item())
            self.critic_scheduler.step(critic_loss.item())
            #self.target_entropy *= 0.99
            
            
        return actor_loss.item() , critic_loss.item() , self.alpha.item() , td_error

# Loss setup

In [691]:
# Hyper params
gamma = 0.999
policy_delay = 3
total_step = 0
episode = 0
actor_lr = 3e-5
critic_lr = 3e-4
max_iter = 5_000_000
n_steps = 1024

# Optimizers

actor_optimizer = optim.SGD(actor_network.parameters() , lr = actor_lr , weight_decay = 0.001 , momentum = 0.9 , nesterov = True)
critic_optimizer = optim.SGD(critic_network.parameters() , lr = critic_lr , weight_decay = 0.001 , momentum = 0.9 , nesterov = True)


# Schedulers

actor_scheduler = optim.lr_scheduler.CosineAnnealingLR(actor_optimizer , T_max = max_iter)
critic_scheduler = optim.lr_scheduler.CosineAnnealingLR(critic_optimizer , T_max = max_iter)

# loss setup

loss = loss_function(gamma = gamma , action_dim = action_dim ,actor_network = actor_network, critic_network = critic_network
                     ,target_critic = target_critic , policy_delay = policy_delay , total_step = total_step
                     ,critic_optimizer = critic_optimizer ,  critic_scheduler = critic_scheduler,
                     actor_optimizer = actor_optimizer , actor_scheduler = actor_scheduler,
                     soft_update = soft_update , batch_size = batch_size , episode = episode)


# Reward Shaping


In [692]:
def shaped_reward(reward , state , max_reward):
    
    is_falling = (state[2] < 0.8 or          # Height threshold
                 abs(state[3]) > 0.4 or      # Angular velocity X
                 abs(state[4]) > 0.4)        # Angular velocity Y
    
    if is_falling:
        
        recovery_bonus = 3.0 * (1 - abs(state[3]))
        survival_bonus = 0.5 * (0.8 - state[2])
        reward += recovery_bonus + survival_bonus
    
    else:
        
        reward += 0.05
        reward += 1.5 * state[1]
        reward += 0.3 * state[2] ** 2 # Height bonus quadritic
        reward += 1.5 * state[0] * (1 + state[2])   # Forward locomotion
        reward -= 0.1 * np.square(state[3:6]).sum()  # Angular penalty
        reward -= 0.1 * (state[3] ** 2 + state[4] ** 2)
        
    return reward / 10.0

In [693]:
def shaped_reward_2(reward, state , step):
    
    # Dynamic components
    progress = min(step / 500_000 , 1.0)
    target_speed = min(1.0 + 0.5 * progress , 2.0)
    forward_vel = state[0]
    height = state[2]
    
    # Progressive falling detection
    fall_risk = max(0, 0.8 - height) + 0.5*(abs(state[3]) + abs(state[4]))
    is_falling = fall_risk > 1.0
    
    # Episode length normalization
    #ep_length_factor = min(1.0, current_step/500)
    
    if is_falling:
        
        reward -= 1.0 * fall_risk  # Gradual penalty
        survival_bonus = 0.5 * (0.8 - state[2])   
        recovery_bonus = 3.0 * (1 - abs(state[3]))
        reward += survival_bonus + recovery_bonus
    else:
        
        # Speed reward (bell curve around target speed)
        speed_reward = 3.0 * np.clip(1 - abs(forward_vel - target_speed)/target_speed, 0, 1) * min(1.5, 1 + height)

        # Core rewards
        reward += 0.05  # Survival
        reward += 0.3 * height**2  # Upright
        reward += speed_reward  # Movement
        reward -= 0.3 * abs(state[1])  # Drift
        reward -= 0.1 * np.square(state[3:6]).sum()  # Angular penalty
        
        # Energy efficiency (if actions available)
        #if action is not None:
            #reward -= 0.01 * np.sum(np.square(action))
        
        # Foot contact bonus (if available)
        #if foot_contacts and all(foot_contacts):
            #reward += 0.2
    
    return reward 

# Trainig Loop

In [None]:
n_steps = 1024
max_iter = 5_000_000
rewards_per_iter = []
ep_reward = 0
episode = 0
alpha = 0
total_env_steps = 0
max_reward = 10.0


state, _ = env.reset()

for iter in range(0, max_iter, n_steps):  # total steps = ~5M, chunked by 1024 steps

    for _ in range(n_steps):
        
        # Convert state to tensor
        if len(state.shape) == 1:
            state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0).to(device)
        else:
            state_tensor = torch.tensor(state, dtype=torch.float32).to(device)

        # Get action from actor
        action, _ = actor_network(state_tensor)
        action = action.squeeze(0).detach().cpu().numpy()

        # Step environment
        next_state, reward, done, _, _ = env.step(action)

        # Reward shaping
        reward = shaped_reward(reward = reward , state = state , max_reward = max_reward )#, step = total_env_steps)
        scaled_reward = reward 

        # Next state tensor
        if len(next_state.shape) == 1:
            next_state_tensor = torch.tensor(next_state, dtype=torch.float32).unsqueeze(0).to(device)
        else:
            next_state_tensor = torch.tensor(next_state, dtype=torch.float32).to(device)

        # Add to buffer
        buffer.add(state_tensor, action, scaled_reward, next_state_tensor, done)
        
        # Increment
        
        state = next_state
        ep_reward += scaled_reward
        total_env_steps += 1 
        loss.total_step = total_env_steps

        if done:
            rewards_per_iter.append(ep_reward)
            if ep_reward > max_reward:
                max_reward = ep_reward
            print(f"Alpha: {alpha:.4f} | Episode: {loss.episode} | Step: {total_env_steps} | Reward: {ep_reward:.2f} | Max : {max_reward:.2f}")
            ep_reward = 0
            loss.episode += 1
            state, _ = env.reset()

    # Only update after n_steps
    if len(buffer.buffer) > 10_000:
        batch, indices, weights = buffer.sample()
        states, actions, rewards, next_states, dones = zip(*batch)

        actor_loss, critic_loss, alpha, td_error = loss.compute_loss(
            states, actions, rewards, next_states, dones, weights
        )
        buffer.update_priorities(indices , td_error.detach().cpu().numpy())

        
    # Optional: save model every X steps
    if iter % 50_000 == 0:
        torch.save(actor_network.state_dict(), f"sac_actor_step_{iter}.pth")
        torch.save(critic_network.state_dict(), f"sac_critic_step_{iter}.pth")

# Final save
torch.save(actor_network.state_dict(), "sac_actor_final.pth")
torch.save(critic_network.state_dict(), "sac_critic_final.pth")



Alpha: 0.0000 | Episode: 0 | Step: 22 | Reward: 16.87 | Max : 16.87
Alpha: 0.0000 | Episode: 1 | Step: 62 | Reward: 31.23 | Max : 31.23
Alpha: 0.0000 | Episode: 2 | Step: 81 | Reward: 14.91 | Max : 31.23
Alpha: 0.0000 | Episode: 3 | Step: 98 | Reward: 13.39 | Max : 31.23
Alpha: 0.0000 | Episode: 4 | Step: 126 | Reward: 22.14 | Max : 31.23
Alpha: 0.0000 | Episode: 5 | Step: 148 | Reward: 17.52 | Max : 31.23
Alpha: 0.0000 | Episode: 6 | Step: 166 | Reward: 14.26 | Max : 31.23
Alpha: 0.0000 | Episode: 7 | Step: 188 | Reward: 17.36 | Max : 31.23
Alpha: 0.0000 | Episode: 8 | Step: 224 | Reward: 26.76 | Max : 31.23
Alpha: 0.0000 | Episode: 9 | Step: 241 | Reward: 13.09 | Max : 31.23
Alpha: 0.0000 | Episode: 10 | Step: 278 | Reward: 26.94 | Max : 31.23
Alpha: 0.0000 | Episode: 11 | Step: 295 | Reward: 13.05 | Max : 31.23
Alpha: 0.0000 | Episode: 12 | Step: 312 | Reward: 13.53 | Max : 31.23
Alpha: 0.0000 | Episode: 13 | Step: 335 | Reward: 17.66 | Max : 31.23
Alpha: 0.0000 | Episode: 14 | Step

In [None]:


def evaluate_model(actor_model, critic_model, env, num_episodes=10, max_timesteps=1000, render=False):
    """
    Function to evaluate the Actor-Critic model's performance.

    Args:
        actor_model (nn.Module): The Actor model.
        critic_model (nn.Module): The Critic model.
        env (gym.Env): The environment to evaluate on.
        num_episodes (int): The number of episodes to run during evaluation.
        max_timesteps (int): Maximum timesteps per episode.
        render (bool): Whether to render the environment during evaluation.

    Returns:
        average_reward (float): Average reward across all episodes.
    """
    actor_model.eval()  # Set actor model to evaluation mode
    critic_model.eval()  # Set critic model to evaluation mode
    total_rewards = []

    with torch.no_grad():  # No gradients needed for evaluation
        for episode in range(num_episodes):
            state , _ = env.reset()  # Reset environment at the start of each episode
            episode_reward = 0

            for t in range(max_timesteps):
                if render:
                    env.render()  # Render environment if required

                # Get the action from the actor model
                state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0).to(device)
                action_probs , log_prob  = actor_model(state_tensor)  # Actor generates action probabilities
                action = action_probs.detach().cpu().numpy()[0]  # Assuming discrete actions

                # Use the critic to evaluate the state
                #state_value = critic_model(state_tensor , action)  # Critic gives value estimate

                # Apply the action to the environment
                next_state, reward, done, _ , info= env.step(action)

                episode_reward += reward
                state = next_state

                if done:
                    break

            total_rewards.append(episode_reward)
            
            print(total_rewards)

    average_reward = sum(total_rewards) / num_episodes
    print(f"Average reward over {num_episodes} episodes: {average_reward}")
    

    
    return average_reward
