In [1]:
import gymnasium as gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import TanhTransform, AffineTransform
import imageio
from dmc import make_dmc_env
from collections import deque
import os
import random
import time

  fn()
  from pkg_resources import resource_stream, resource_exists


In [None]:
# Debug switch. No cell in the visible part of this notebook reads it.
DEBUG = False

In [2]:
# Build a state-based (non-pixel, flattened) humanoid-walk environment with a
# randomly drawn seed, show its spaces, and record the vector sizes.
env_name = "humanoid-walk"
env = make_dmc_env(
    env_name,
    np.random.randint(0, 1000000),
    flatten=True,
    use_pixels=False,
)
print(env.action_space, env.observation_space, sep="\n")
input_dim = env.observation_space.shape[0]
output_dim = env.action_space.shape[0]

Box(-1.0, 1.0, (21,), float64)
Box(-inf, inf, (67,), float64)


In [None]:
# Drive the environment with 100 random actions and save the rendered frames
# as an animated GIF.
env.reset()
frames = []
for _ in range(100):
    # The transition itself is irrelevant here; only the rendered frame is kept.
    _, _, _, _, _ = env.step(env.action_space.sample())
    frames.append(env.render())
imageio.mimsave("humanoid.gif", frames)

In [17]:
# Show the observation and action vector sizes.
print(input_dim, output_dim)

67 21


In [4]:
# Everything runs on the CPU. To use CUDA when available, swap in:
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = "cpu"
print(device)

cpu


In [5]:
# Bounds applied to the policy's predicted log standard deviation to keep
# training numerically stable. The lower bound must be below the upper bound:
# torch.clamp(x, min, max) with min > max sets every element to max, which
# would pin the log-std to one constant and make the log-std head unlearnable.
# NOTE(review): checkpoints produced while these bounds were swapped were
# sampled with a fixed log-std of -20; their log-std head is presumably
# untrained, so prefer the mean action when evaluating them.
LOG_STD_MIN = -20
LOG_STD_MAX = 2
# Small constant added inside the log of the tanh-squash correction so the
# argument never reaches zero.
EPS = 1e-6
# absolute bounds for weights initialization

In [5]:
# --- Replay Buffer ---
class ReplayBufferSAC:
    """Fixed-capacity FIFO store of (state, action, reward, next_state, done) transitions."""

    def __init__(self, capacity):
        # A bounded deque drops the oldest transition once `capacity` is reached.
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Append one transition, converting every field to a numpy array."""
        transition = (
            np.asarray(state),
            np.asarray(action),
            np.asarray([reward]),   # wrapped in a list so it stacks with a trailing axis
            np.asarray(next_state),
            np.asarray([done]),     # same wrapping as the reward
        )
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Draw `batch_size` distinct transitions and stack each field into one array.

        The `done` column is returned as float32; the other columns keep the
        dtype they were stored with.
        """
        picked = random.sample(self.buffer, batch_size)
        columns = [np.stack(field) for field in zip(*picked)]
        states, actions, rewards, next_states, dones = columns
        return states, actions, rewards, next_states, dones.astype(np.float32)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)

In [6]:
class ValueNet(nn.Module):
    """State-value network V(s): an MLP with two ReLU hidden layers.

    The SAC agent defined in this notebook does not instantiate it.
    """

    def __init__(self, input_size, output_size=1, hidden_size=128):
        super().__init__()
        layers = [
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size),
        ]
        self.seq = nn.Sequential(*layers)

    def forward(self, x):
        """Map a batch of states to value estimates."""
        return self.seq(x)

# Q(s,a) - action value
class QNet(nn.Module):
    """Action-value network Q(s, a): an MLP over the concatenated state and action."""

    def __init__(self, state_dim, action_dim, hidden_size=128):
        super().__init__()
        in_features = state_dim + action_dim
        self.seq = nn.Sequential(
            nn.Linear(in_features, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),  # one scalar Q-value per sample
        )

    def forward(self, state, action):
        """Concatenate state and action along dim 1 and return the Q-value."""
        joint = torch.cat((state, action), dim=1)
        return self.seq(joint)

# Policy Network pi(a|s) for SAC
class PolicyNetSAC(nn.Module):
    """Tanh-squashed diagonal-Gaussian policy pi(a|s) for SAC.

    A shared two-layer ReLU trunk feeds two linear heads producing the mean
    and the log standard deviation of a Normal distribution; samples are
    squashed with tanh so actions lie in (-1, 1).
    """

    def __init__(self, state_dim, action_dim, hidden_size=128, action_bound_val=1.0):
        super(PolicyNetSAC, self).__init__()
        self.action_dim = action_dim
        # Stored for reference only: no method of this class rescales actions
        # with it. It is a plain attribute (not a buffer), so it is neither
        # moved by `.to()` nor included in `state_dict()`.
        self.action_scale = torch.tensor(action_bound_val)

        self.seq = nn.Sequential(
            nn.Linear(state_dim, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU()
        )
        self.mean_layer = nn.Linear(hidden_size, action_dim)
        self.log_std_layer = nn.Linear(hidden_size, action_dim)

    def forward(self, state):
        """Return `(mean, log_std)` of the pre-tanh Gaussian for a batch of states.

        `log_std` is passed through
        `torch.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)`; those module-level
        constants must satisfy LOG_STD_MIN <= LOG_STD_MAX for the clamp to
        describe a real interval.
        """
        x = self.seq(state)
        mean = self.mean_layer(x)
        log_std = self.log_std_layer(x)
        log_std = torch.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
        return mean, log_std

    def sample(self, state):
        """Draw a reparameterised action and its log-probability.

        Returns:
            action: tanh of a Normal rsample, same leading shape as `state`.
            log_prob: log-density of `action`, summed over the action
                dimension and kept as a trailing axis of size 1.
        """
        mean, log_std = self.forward(state)
        std = log_std.exp()
        dist = Normal(mean, std)
        sample = dist.rsample()
        action = torch.tanh(sample)
        # Change of variables for the tanh squash; EPS keeps the log argument
        # strictly positive when |action| approaches 1.
        log_prob = dist.log_prob(sample) - torch.log(1 - action.pow(2) + EPS)
        log_prob = log_prob.sum(dim=-1, keepdim=True)

        return action, log_prob

    def get_action(self, state):
        """Sample an action for a single numpy state.

        The state is moved to the device this module's parameters live on
        (rather than a notebook-level global), so the method works wherever
        the network has been placed.

        Returns:
            numpy array of shape (1, action_dim).
        """
        module_device = next(self.parameters()).device
        state = torch.from_numpy(state).unsqueeze(0).to(module_device, dtype=torch.float)
        true_action, _ = self.sample(state)
        return true_action.detach().cpu().numpy()
    
# --- SAC Agent ---
class SAC_agent:
    """Soft Actor-Critic agent with twin Q-critics and a learned temperature.

    Holds the actor, two critics with their Polyak-averaged targets, the
    optimisers, the entropy temperature (stored as `log_alpha`) and the
    replay buffer.

    Args:
        state_dim: size of the flat observation vector.
        action_dim: size of the action vector.
        device: torch device (or device string) for all networks.
        hidden_size: width of the hidden layers of every network.
        lr_actor, lr_critic, lr_alpha: Adam learning rates.
        gamma: discount factor.
        tau: Polyak coefficient for the target critics.
        alpha: initial entropy temperature (must be > 0); tuned afterwards.
        target_update_interval: number of updates between target refreshes.
        replay_capacity: maximum number of stored transitions.
        action_bound_val: forwarded to the policy network.
    """

    def __init__(self, state_dim, action_dim, device,
                 hidden_size=128,
                 lr_actor=3e-4, lr_critic=3e-4, lr_alpha=3e-4,
                 gamma=0.99, tau=0.005,
                 alpha=0.2,
                 target_update_interval=1,
                 replay_capacity=1000000,
                 action_bound_val=1.0):
        if alpha <= 0:
            raise ValueError("alpha must be positive because it is stored as log(alpha)")

        self.device = device
        self.gamma = gamma
        self.tau = tau
        self.target_update_interval = target_update_interval
        self.action_dim = action_dim

        # Actor network
        self.actor = PolicyNetSAC(state_dim, action_dim, hidden_size, action_bound_val=action_bound_val).to(self.device)
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=lr_actor)

        # Twin critics, each with a target copy initialised to the same weights
        self.critic1 = QNet(state_dim, action_dim, hidden_size).to(self.device)
        self.critic1_target = QNet(state_dim, action_dim, hidden_size).to(self.device)
        self.critic1_target.load_state_dict(self.critic1.state_dict())

        self.critic2 = QNet(state_dim, action_dim, hidden_size).to(self.device)
        self.critic2_target = QNet(state_dim, action_dim, hidden_size).to(self.device)
        self.critic2_target.load_state_dict(self.critic2.state_dict())
        # One optimiser drives both critics; their losses are summed in update_parameters.
        self.critic_optimizer = optim.Adam(
            list(self.critic1.parameters()) + list(self.critic2.parameters()), lr=lr_critic)

        # Automatic entropy tuning. The target entropy is -action_dim.
        # (torch.Tensor(n) with an int n allocates an *uninitialised* tensor of
        # length n, so it must not be used to compute this value.)
        self.target_entropy = -float(action_dim)

        # The temperature is optimised in log space, starting from `alpha`.
        self.log_alpha = torch.tensor([float(np.log(alpha))], dtype=torch.float32,
                                      device=device, requires_grad=True)
        self.alpha = self.log_alpha.exp().item()
        self.alpha_optimizer = optim.Adam([self.log_alpha], lr=lr_alpha)

        self.replay_buffer = ReplayBufferSAC(replay_capacity)
        self.updates = 0  # counts calls that performed an update, for target refresh

    def select_action(self, state, evaluate=False):
        """Choose an action for a single observation.

        Args:
            state: flat observation (array-like).
            evaluate: if True return the deterministic action tanh(mean);
                otherwise sample from the policy.

        Returns:
            1-D numpy array of length `action_dim`.
        """
        state_tensor = torch.as_tensor(
            np.asarray(state), dtype=torch.float32, device=self.device).unsqueeze(0)
        with torch.no_grad():  # acting never needs gradients
            if evaluate:
                mean, _ = self.actor(state_tensor)
                action = torch.tanh(mean)
            else:
                action, _ = self.actor.sample(state_tensor)
        return action.cpu().numpy().flatten()

    def store_transition(self, state, action, reward, next_state, done):
        """Add one transition to the replay buffer."""
        self.replay_buffer.push(state, action, reward, next_state, done)

    def update_parameters(self, batch_size):
        """Run one SAC update (critics, actor, temperature, targets).

        Returns:
            `(critic1_loss, actor_loss, alpha)` as Python floats, or
            `(None, None, None)` when the buffer holds fewer than
            `batch_size` transitions.
        """
        if len(self.replay_buffer) < batch_size:
            return None, None, None  # not enough samples to train

        state_np, action_np, reward_np, next_state_np, done_np = self.replay_buffer.sample(batch_size)

        state = torch.FloatTensor(state_np).to(self.device)
        action = torch.FloatTensor(action_np).to(self.device)
        reward = torch.FloatTensor(reward_np).to(self.device)
        next_state = torch.FloatTensor(next_state_np).to(self.device)
        done = torch.FloatTensor(done_np).to(self.device)

        # --- Critic update ---
        # Soft Bellman target built from the target critics; no gradients needed.
        with torch.no_grad():
            next_actions, next_log_probs = self.actor.sample(next_state)
            q1_next_target = self.critic1_target(next_state, next_actions)
            q2_next_target = self.critic2_target(next_state, next_actions)
            min_q_next_target = torch.min(q1_next_target, q2_next_target) - self.alpha * next_log_probs
            next_q_value = reward + (1 - done) * self.gamma * min_q_next_target

        q1 = self.critic1(state, action)
        q2 = self.critic2(state, action)

        critic1_loss = F.mse_loss(q1, next_q_value)
        critic2_loss = F.mse_loss(q2, next_q_value)
        critic_total_loss = critic1_loss + critic2_loss

        self.critic_optimizer.zero_grad()
        critic_total_loss.backward()
        self.critic_optimizer.step()

        # --- Actor update ---
        # Gradients also reach the critic parameters here, but they are cleared
        # by critic_optimizer.zero_grad() before the next critic step.
        pi_actions, pi_log_probs = self.actor.sample(state)  # reparameterised
        q1_pi = self.critic1(state, pi_actions)
        q2_pi = self.critic2(state, pi_actions)
        min_q_pi = torch.min(q1_pi, q2_pi)

        actor_loss = ((self.alpha * pi_log_probs) - min_q_pi).mean()

        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # --- Temperature update ---
        alpha_loss = -(self.log_alpha.exp() * (pi_log_probs + self.target_entropy).detach()).mean()

        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.alpha_optimizer.step()
        self.alpha = self.log_alpha.exp().item()

        # --- Soft update of the target critics ---
        self.updates += 1
        if self.updates % self.target_update_interval == 0:
            self._soft_update(self.critic1_target, self.critic1)
            self._soft_update(self.critic2_target, self.critic2)

        return critic1_loss.item(), actor_loss.item(), self.alpha

    def _soft_update(self, target_net, source_net):
        """Polyak-average `source_net` into `target_net` with coefficient `tau`."""
        with torch.no_grad():
            for target_param, source_param in zip(target_net.parameters(), source_net.parameters()):
                target_param.copy_(self.tau * source_param + (1.0 - self.tau) * target_param)

    def save_checkpoint(self, path, filename_prefix="sac_checkpoint"):
        """Write networks, optimisers, temperature and replay buffer to `<path>/<prefix>.pth`."""
        os.makedirs(path, exist_ok=True)
        checkpoint = {
            'actor_state_dict': self.actor.state_dict(),
            'critic1_state_dict': self.critic1.state_dict(),
            'critic2_state_dict': self.critic2.state_dict(),
            'critic1_target_state_dict': self.critic1_target.state_dict(),
            'critic2_target_state_dict': self.critic2_target.state_dict(),
            'alpha_state_dict': self.alpha,  # float temperature, kept for older readers
            'log_alpha': self.log_alpha.detach().cpu(),  # the value actually optimised
            'actor_optimizer_state_dict': self.actor_optimizer.state_dict(),
            'critic_optimizer_state_dict': self.critic_optimizer.state_dict(),
            'alpha_optimizer_state_dict': self.alpha_optimizer.state_dict(),
            'replay_buffer': self.replay_buffer.buffer  # the whole deque is pickled
        }
        torch.save(checkpoint, os.path.join(path, f"{filename_prefix}.pth"))
        print(f"SAC checkpoint saved to {os.path.join(path, f'{filename_prefix}.pth')}")

    def load_checkpoint(self, path, filename_prefix="sac_checkpoint"):
        """Restore the agent from `<path>/<prefix>.pth`.

        Returns:
            0 on success, -1 if the file does not exist (agent left untouched).
        """
        checkpoint_path = os.path.join(path, f"{filename_prefix}.pth")
        if not os.path.exists(checkpoint_path):
            print(f"No SAC checkpoint found at {checkpoint_path}, starting from scratch.")
            return -1

        print(checkpoint_path)
        # SECURITY: weights_only=False unpickles arbitrary objects (needed for
        # the stored deque). Only load checkpoint files you trust.
        checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
        self.actor.load_state_dict(checkpoint['actor_state_dict'])
        self.critic1.load_state_dict(checkpoint['critic1_state_dict'])
        self.critic2.load_state_dict(checkpoint['critic2_state_dict'])

        # Prefer the saved target critics; fall back to copying the online
        # critics only when a checkpoint does not contain them.
        target1_state = checkpoint.get('critic1_target_state_dict')
        target2_state = checkpoint.get('critic2_target_state_dict')
        self.critic1_target.load_state_dict(
            target1_state if target1_state is not None else self.critic1.state_dict())
        self.critic2_target.load_state_dict(
            target2_state if target2_state is not None else self.critic2.state_dict())

        # Restore the temperature into log_alpha itself: update_parameters
        # recomputes self.alpha from log_alpha on every step. Checkpoints that
        # only hold the float temperature are converted back to log space.
        if 'log_alpha' in checkpoint:
            restored_log_alpha = torch.as_tensor(checkpoint['log_alpha'], dtype=torch.float32)
        else:
            saved_alpha = max(float(checkpoint['alpha_state_dict']), float(np.finfo(np.float32).tiny))
            restored_log_alpha = torch.tensor([float(np.log(saved_alpha))], dtype=torch.float32)
        with torch.no_grad():
            # In-place copy keeps the tensor object registered with alpha_optimizer.
            self.log_alpha.copy_(restored_log_alpha.reshape(self.log_alpha.shape).to(self.log_alpha.device))
        self.alpha = self.log_alpha.exp().item()

        self.actor_optimizer.load_state_dict(checkpoint['actor_optimizer_state_dict'])
        self.critic_optimizer.load_state_dict(checkpoint['critic_optimizer_state_dict'])
        self.alpha_optimizer.load_state_dict(checkpoint['alpha_optimizer_state_dict'])

        self.replay_buffer.buffer = checkpoint.get(
            'replay_buffer', deque(maxlen=self.replay_buffer.buffer.maxlen))

        print(f"SAC checkpoint loaded from {checkpoint_path}")
        return 0

def evaluate_agent(env, agent, num_episodes, max_episode_len, seed_offset=100000):
    """Return the mean undiscounted return of `agent` over `num_episodes` episodes.

    Episode i is reset with seed `seed_offset + i` and runs until the
    environment reports terminated/truncated or `max_episode_len` steps have
    been taken. Actions come from `agent.select_action(obs, evaluate=True)`.

    Note: this function is defined again in a later cell of the notebook; on a
    top-to-bottom run the later definition shadows this one.
    """
    returns_sum = 0
    for episode_idx in range(num_episodes):
        # assumes env.reset accepts a `seed` keyword — TODO confirm for the wrapper in use
        obs, _ = env.reset(seed=seed_offset + episode_idx)
        episode_return = 0
        steps_taken = 0
        while steps_taken < max_episode_len:
            action = agent.select_action(obs, evaluate=True)
            obs, reward, terminated, truncated, _ = env.step(action)
            episode_return += reward
            steps_taken += 1
            if terminated or truncated:
                break
        returns_sum += episode_return
    return returns_sum / num_episodes

In [8]:
def evaluate_agent(env, agent, num_episodes, max_episode_len, seed_offset=100000):
    """Average episode return of `agent` on `env`.

    Runs `num_episodes` episodes, resetting episode i with seed
    `seed_offset + i`, and queries `agent.select_action(obs, evaluate=True)`
    at every step. An episode stops when the environment signals
    terminated/truncated or after `max_episode_len` steps.

    This re-defines the function of the same name from an earlier cell.
    """
    episode_returns = []
    for offset in range(num_episodes):
        # presumably the env wrapper forwards `seed` to its reset; verify against make_dmc_env
        obs, _ = env.reset(seed=seed_offset + offset)
        ret = 0
        for _step in range(max_episode_len):
            obs, reward, terminated, truncated, _ = env.step(
                agent.select_action(obs, evaluate=True)
            )
            ret += reward
            if terminated or truncated:
                break
        episode_returns.append(ret)
    return sum(episode_returns) / num_episodes

In [3]:
if __name__ == '__main__':
    # Train a SAC agent on DMC humanoid-walk: collect random transitions for
    # START_TIMESTEPS steps, then alternate one environment step with
    # UPDATES_PER_STEP gradient updates, evaluating and checkpointing periodically.

    # --- Hyperparameters ---
    ENV_NAME = "humanoid-walk"
    SEED = np.random.randint(0, 1000000) # Initial seed for the training environment
    DEVICE = torch.device("cpu")
    print(f"Using device: {DEVICE}")

    # SAC Agent Hyperparameters
    HIDDEN_SIZE = 256 # Width of every hidden layer
    LR_ACTOR = 3e-4
    LR_CRITIC = 3e-4
    LR_ALPHA = 3e-4 # Learning rate for temperature
    GAMMA = 0.99    # Discount factor
    TAU = 0.005     # Target network soft update rate
    ALPHA_INIT = 0.2 # Initial temperature, will be tuned
    REPLAY_CAPACITY = int(1e6)
    ACTION_BOUND_VAL = 1.0 # From env.action_space: Box(-1.0, 1.0, ...)

    # Training Loop Hyperparameters
    TOTAL_TIMESTEPS = int(1e7)       # Total timesteps for training
    START_TIMESTEPS = int(1e4)       # Timesteps for random actions before training starts
    BATCH_SIZE = 256                 # Batch size for SAC updates
    UPDATES_PER_STEP = 1             # Number of SAC updates per environment step
    MAX_EPISODE_LEN = 1000           # Max length of each episode

    EVAL_FREQ = int(2e4)             # Evaluate agent every N timesteps
    EVAL_EPISODES = 10               # Number of episodes for evaluation
    SAVE_FREQ = int(1e5)             # Save model every N timesteps
    MODEL_SAVE_PATH = "./sac_models" # Path to save models

    # --- Initialization ---
    train_env = make_dmc_env(ENV_NAME, SEED, flatten=True, use_pixels=False)
    # Separate evaluation environment, built with a different seed
    eval_env = make_dmc_env(ENV_NAME, SEED + 12345, flatten=True, use_pixels=False)

    state_dim = train_env.observation_space.shape[0]
    action_dim = train_env.action_space.shape[0]
    print(f"State dim: {state_dim}, Action dim: {action_dim}")
    print(f"Action space low: {train_env.action_space.low[0]}, high: {train_env.action_space.high[0]}")


    agent = SAC_agent(state_dim, action_dim, DEVICE,
                      hidden_size=HIDDEN_SIZE,
                      lr_actor=LR_ACTOR, lr_critic=LR_CRITIC, lr_alpha=LR_ALPHA,
                      gamma=GAMMA, tau=TAU, alpha=ALPHA_INIT,
                      replay_capacity=REPLAY_CAPACITY,
                      action_bound_val=ACTION_BOUND_VAL)

    # Optional: Load checkpoint if resuming
    # agent.load_checkpoint(MODEL_SAVE_PATH, "sac_checkpoint_some_timestep")

    obs, _ = train_env.reset(seed=SEED)
    episode_reward = 0
    episode_timesteps = 0
    episode_num = 0

    # Most recent loss values, shown in the per-episode log line
    recent_critic_loss = 0
    recent_actor_loss = 0
    recent_alpha = agent.alpha

    print(f"Starting training for {TOTAL_TIMESTEPS} timesteps...")
    start_time = time.time()

    for t in range(TOTAL_TIMESTEPS):
        episode_timesteps += 1

        if t < START_TIMESTEPS:
            action = train_env.action_space.sample() # Random action
        else:
            action = agent.select_action(obs, evaluate=False)

        next_obs, reward, terminated, truncated, _ = train_env.step(action)
        done = terminated or truncated

        # Store only true termination as the bootstrap mask. A time-limit
        # truncation ends the episode but the next state still has value, so
        # the critic target must keep bootstrapping from it.
        agent.store_transition(obs, action, reward, next_obs, terminated)
        obs = next_obs
        episode_reward += reward

        if t >= START_TIMESTEPS:
            for _ in range(UPDATES_PER_STEP):
                c_loss, a_loss, current_alpha_val = agent.update_parameters(BATCH_SIZE)
                if c_loss is not None: # If an update happened
                    recent_critic_loss = c_loss
                    recent_actor_loss = a_loss
                    recent_alpha = current_alpha_val

        # `done` (terminated or truncated) still decides when to reset;
        # MAX_EPISODE_LEN additionally caps episodes that never end on their own.
        if done or episode_timesteps >= MAX_EPISODE_LEN:
            elapsed_time = time.time() - start_time
            print(f"Total T: {t+1}/{TOTAL_TIMESTEPS} | Episode Num: {episode_num+1} | Episode T: {episode_timesteps} | Reward: {episode_reward:.2f} | Alpha: {recent_alpha:.4f} | C_Loss: {recent_critic_loss:.8f} | A_Loss: {recent_actor_loss:.8f} | Time: {elapsed_time/60:.1f}m")

            obs, _ = train_env.reset()
            episode_reward = 0
            episode_timesteps = 0
            episode_num += 1

        if (t + 1) % EVAL_FREQ == 0 and t >= START_TIMESTEPS:
            avg_eval_reward = evaluate_agent(eval_env, agent, EVAL_EPISODES, MAX_EPISODE_LEN)
            print("--------------------------------------------------------")
            print(f"Evaluation at T: {t+1} | Avg Reward over {EVAL_EPISODES} episodes: {avg_eval_reward:.2f}")
            print("--------------------------------------------------------")

        if (t + 1) % SAVE_FREQ == 0 and t >= START_TIMESTEPS :
            agent.save_checkpoint(MODEL_SAVE_PATH, f"sac_humanoid_t{t+1}")

Using device: cpu
State dim: 67, Action dim: 21
Action space low: -1.0, high: 1.0



KeyboardInterrupt



In [29]:
# Sweep the checkpoints saved every 100000 steps inside `step_range`, evaluate
# each one that exists over 100 episodes, and report the best average reward.
step_range = (5600000, 5900001)
test_agent = SAC_agent(
    state_dim, action_dim, DEVICE,
    hidden_size=HIDDEN_SIZE,
    lr_actor=LR_ACTOR, lr_critic=LR_CRITIC, lr_alpha=LR_ALPHA,
    gamma=GAMMA, tau=TAU, alpha=ALPHA_INIT,
    replay_capacity=REPLAY_CAPACITY,
    action_bound_val=ACTION_BOUND_VAL,
)
best_step, best_reward = 0, -np.inf
for steps in range(step_range[0], step_range[1] + 1, 100000):
    # load_checkpoint returns -1 when the file is missing; skip those steps.
    if test_agent.load_checkpoint(MODEL_SAVE_PATH, f"sac_humanoid_t{steps}") != -1:
        avg_eval_reward = evaluate_agent(eval_env, test_agent, 100, 2000)
        print(f"Average evaluation reward at step {steps}: {avg_eval_reward:.2f}")
        if avg_eval_reward > best_reward:
            best_reward, best_step = avg_eval_reward, steps
print(f"Best reward: {best_reward:.2f} at step: {best_step}")
    

./sac_models\sac_humanoid_t5600000.pth
SAC checkpoint loaded from ./sac_models\sac_humanoid_t5600000.pth
Average evaluation reward at step 5600000: 804.01
No SAC checkpoint found at ./sac_models\sac_humanoid_t5700000.pth, starting from scratch.
./sac_models\sac_humanoid_t5800000.pth
SAC checkpoint loaded from ./sac_models\sac_humanoid_t5800000.pth
Average evaluation reward at step 5800000: 727.81
./sac_models\sac_humanoid_t5900000.pth
SAC checkpoint loaded from ./sac_models\sac_humanoid_t5900000.pth
Average evaluation reward at step 5900000: 482.80
Best reward: 804.01 at step: 5600000


In [7]:
# Dump the policy only: rebuild an agent with the training hyperparameters and
# load the chosen checkpoint so its actor weights can be exported.

# SAC agent hyperparameters
HIDDEN_SIZE = 256
LR_ACTOR = LR_CRITIC = LR_ALPHA = 3e-4  # actor, critic and temperature learning rates
GAMMA = 0.99                # discount factor
TAU = 0.005                 # target network soft update rate
ALPHA_INIT = 0.2            # initial temperature
REPLAY_CAPACITY = 1_000_000
ACTION_BOUND_VAL = 1.0      # env.action_space is Box(-1.0, 1.0, ...)

# Training loop hyperparameters
TOTAL_TIMESTEPS = 10_000_000   # total timesteps for training
START_TIMESTEPS = 10_000       # random-action steps before training starts
BATCH_SIZE = 256               # batch size for SAC updates
UPDATES_PER_STEP = 1           # SAC updates per environment step
MAX_EPISODE_LEN = 1000         # max length of each episode

EVAL_FREQ = 20_000             # evaluate every N timesteps
EVAL_EPISODES = 10             # episodes per evaluation
SAVE_FREQ = 100_000            # save model every N timesteps
MODEL_SAVE_PATH = "./"         # directory holding the checkpoint to load

test_agent = SAC_agent(
    state_dim, action_dim, DEVICE,
    hidden_size=HIDDEN_SIZE,
    lr_actor=LR_ACTOR, lr_critic=LR_CRITIC, lr_alpha=LR_ALPHA,
    gamma=GAMMA, tau=TAU, alpha=ALPHA_INIT,
    replay_capacity=REPLAY_CAPACITY,
    action_bound_val=ACTION_BOUND_VAL,
)

# Left as the cell's final expression so the status code (0 / -1) is displayed.
test_agent.load_checkpoint(MODEL_SAVE_PATH, f"sac_humanoid_t{8700000}")


./sac_humanoid_t8700000.pth
SAC checkpoint loaded from ./sac_humanoid_t8700000.pth


0

In [8]:
# Export only the actor (policy) weights of `test_agent` to "policy.pth".
torch.save(test_agent.actor.state_dict(), "policy.pth")