In [1]:
!sudo apt-get update
!sudo apt-get install -y swig g++ python3-dev
!pip install gymnasium[box2d] torch

0% [Working]            Hit:1 https://packages.cloud.google.com/apt gcsfuse-jammy InRelease
0% [Waiting for headers] [Connecting to security.ubuntu.com (185.125.190.82)] [                                                                               Hit:2 http://archive.ubuntu.com/ubuntu jammy InRelease
0% [Waiting for headers] [Connecting to security.ubuntu.com (185.125.190.82)] [                                                                               Hit:3 https://cloud.r-project.org/bin/linux/ubuntu jammy-cran40/ InRelease
Hit:4 http://archive.ubuntu.com/ubuntu jammy-updates InRelease
Hit:5 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  InRelease
Hit:6 http://archive.ubuntu.com/ubuntu jammy-backports InRelease
Hit:7 http://security.ubuntu.com/ubuntu jammy-security InRelease
Hit:8 https://r2u.stat.illinois.edu/ubuntu jammy InRelease
Hit:9 https://ppa.launchpadcontent.net/deadsnakes/ppa/ubuntu jammy InRelease
Hit:10 https://ppa.launchpadcont

In [2]:
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Normal
import random
from collections import deque
import time
import matplotlib.pyplot as plt
import datetime

# --- Experiment configuration ------------------------------------------------
# Gymnasium environment id and the seed used for Python/NumPy/torch RNGs.
ENV_ID = "BipedalWalker-v3"
RANDOM_SEED = 0
# Adam learning rates for the policy, the twin critics and log(alpha).
ACTOR_LR = 3e-4
CRITIC_LR = 3e-4
ALPHA_LR = 3e-4
# gamma in the Bellman target, and tau for Polyak-averaging the target critic.
DISCOUNT_FACTOR = 0.99
TARGET_UPDATE_RATE = 0.005
# Entropy temperature handed to the agent as alpha_init; AUTO_TUNE_TEMP enables
# learning alpha towards a target entropy.
INITIAL_ALPHA = 0.2
AUTO_TUNE_TEMP = True
# Replay buffer size (transitions), mini-batch size, and hidden-layer width of
# both actor and critics.
BUFFER_CAPACITY = int(1e6)
TRAIN_BATCH_SIZE = 256
NETWORK_HIDDEN_UNITS = 256
# Steps of uniform random actions before the policy is used; gradient updates
# start only after this many steps.
EXPLORATION_STEPS = 10000
TOTAL_TRAINING_STEPS = int(1e6)
# If True the final evaluation env is created with render_mode="human".
RENDER_EVALUATION = True
# Loss lines are printed every 5 * LOGGING_INTERVAL steps; REWARD_LOG_WINDOW is
# the number of episodes in the running-average reward printed per episode.
LOGGING_INTERVAL = 1000
REWARD_LOG_WINDOW = 100

# Early stopping: training ends once the mean reward of the last STOP_WINDOW
# episodes exceeds STOP_REWARD_TARGET.
STOP_WINDOW = 20
STOP_REWARD_TARGET = 250.0

# Seed the RNGs the script draws from: Python `random` (replay sampling),
# NumPy, and torch (network init and action sampling).
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)


# Device shared by every network and tensor created below.
compute_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {compute_device}")


class ExperienceReplayBuffer:
    """Fixed-capacity FIFO store of transitions with uniform random sampling.

    Transitions live in a Python list used as a ring buffer: once ``max_size``
    entries are stored, each new transition overwrites the oldest one. A list
    is used instead of ``collections.deque`` because ``random.sample`` indexes
    the population, and deque indexing is O(n) away from its ends, which made
    every mini-batch draw scale with the buffer size.
    """

    def __init__(self, max_size):
        """Create an empty buffer holding at most ``max_size`` transitions.

        Raises:
            ValueError: if ``max_size`` is not positive.
        """
        if max_size <= 0:
            raise ValueError(f"max_size must be positive, got {max_size}")
        self.storage = []
        self._max_size = max_size
        # Slot that the next transition is written to (the oldest entry once full).
        self._next_idx = 0

    def add_transition(self, obs, act, rew, next_obs, terminated):
        """Store one transition, evicting the oldest one when at capacity.

        ``obs`` and ``next_obs`` get a leading axis of size 1 so that
        ``sample_batch`` can stack them with ``np.concatenate``.
        """
        transition = (np.expand_dims(obs, 0), act, rew, np.expand_dims(next_obs, 0), terminated)
        if len(self.storage) < self._max_size:
            self.storage.append(transition)
        else:
            self.storage[self._next_idx] = transition
        self._next_idx = (self._next_idx + 1) % self._max_size

    def sample_batch(self, num_samples):
        """Draw ``num_samples`` distinct transitions uniformly at random.

        Returns:
            ``(obs, act, rew, next_obs, terminated)`` where ``obs`` and
            ``next_obs`` are arrays stacked along axis 0 and the other fields
            are tuples in the same sampled order.

        Raises:
            ValueError: (from ``random.sample``) if fewer than ``num_samples``
                transitions are stored.
        """
        batch = random.sample(self.storage, num_samples)
        obs_batch, act_batch, rew_batch, next_obs_batch, term_batch = zip(*batch)
        return np.concatenate(obs_batch), act_batch, rew_batch, np.concatenate(next_obs_batch), term_batch

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.storage)


# Bounds applied to the predicted log standard deviation, and the epsilon that
# keeps log(1 - tanh^2) finite when the squashed action saturates.
MAX_LOG_STD = 2
MIN_LOG_STD = -20
NUMERICAL_STABILITY_EPS = 1e-6

class PolicyNetwork(nn.Module):
    """Tanh-squashed diagonal Gaussian policy.

    Two ReLU hidden layers feed separate heads for the mean and the (clamped)
    log standard deviation of a Gaussian over pre-squash actions. Sampled
    actions are squashed with ``tanh`` and scaled by ``action_limit_high``.
    """

    def __init__(self, observation_dim, action_dim, hidden_units, action_limit_high):
        super(PolicyNetwork, self).__init__()
        self.layer1 = nn.Linear(observation_dim, hidden_units)
        self.layer2 = nn.Linear(hidden_units, hidden_units)
        self.mean_layer = nn.Linear(hidden_units, action_dim)
        self.log_std_layer = nn.Linear(hidden_units, action_dim)
        # A non-persistent buffer follows .to()/.cpu()/.cuda() together with the
        # layers but is excluded from state_dict, so saved checkpoints keep the
        # same keys as before.
        self.register_buffer(
            "action_scale",
            torch.as_tensor(action_limit_high, dtype=torch.float32),
            persistent=False,
        )

    def forward(self, observation):
        """Return ``(mean, log_std)`` of the pre-squash Gaussian for a batch of observations."""
        features = F.relu(self.layer1(observation))
        features = F.relu(self.layer2(features))
        action_mean = self.mean_layer(features)
        action_log_std = self.log_std_layer(features)
        action_log_std = torch.clamp(action_log_std, min=MIN_LOG_STD, max=MAX_LOG_STD)
        return action_mean, action_log_std

    def sample_action(self, observation):
        """Sample a reparameterised action.

        Returns:
            ``(scaled_action, log_prob, deterministic_action)``:
            ``scaled_action`` is ``tanh(u) * action_scale`` with ``u`` drawn via
            ``rsample`` (so gradients flow to the policy); ``log_prob`` is its
            log-density summed over action dims with shape ``(batch, 1)``;
            ``deterministic_action`` is ``tanh(mean) * action_scale``.
        """
        mean, log_std = self.forward(observation)
        std = log_std.exp()
        distribution = Normal(mean, std)
        raw_action = distribution.rsample()
        squashed_action = torch.tanh(raw_action)
        scaled_action = squashed_action * self.action_scale
        log_prob = distribution.log_prob(raw_action)
        # Change-of-variables correction for a = scale * tanh(u).
        log_prob -= torch.log(self.action_scale * (1 - squashed_action.pow(2)) + NUMERICAL_STABILITY_EPS)
        log_prob = log_prob.sum(1, keepdim=True)
        deterministic_action = torch.tanh(mean) * self.action_scale
        return scaled_action, log_prob, deterministic_action

class QValueNetwork(nn.Module):
    """Two independent state-action value heads evaluated on the same input.

    Each head is ``Linear -> ReLU -> Linear -> ReLU -> Linear`` applied to the
    concatenation of observation and action, giving one value per batch row.
    """

    def __init__(self, observation_dim, action_dim, hidden_units):
        super(QValueNetwork, self).__init__()
        joint_dim = observation_dim + action_dim

        # First head. Creation order (q1 before q2, layer1 -> layer2 -> output)
        # fixes parameter-init RNG use and the state_dict key order.
        self.q1_layer1 = nn.Linear(joint_dim, hidden_units)
        self.q1_layer2 = nn.Linear(hidden_units, hidden_units)
        self.q1_output = nn.Linear(hidden_units, 1)

        # Second head, same architecture, independent weights.
        self.q2_layer1 = nn.Linear(joint_dim, hidden_units)
        self.q2_layer2 = nn.Linear(hidden_units, hidden_units)
        self.q2_output = nn.Linear(hidden_units, 1)

    @staticmethod
    def _run_head(joint_input, first, second, head):
        """Apply one three-layer critic head with ReLU between layers."""
        return head(F.relu(second(F.relu(first(joint_input)))))

    def forward(self, observation, action):
        """Return ``(q1, q2)``, each of shape ``(batch, 1)``."""
        joint_input = torch.cat((observation, action), dim=1)
        first_q = self._run_head(joint_input, self.q1_layer1, self.q1_layer2, self.q1_output)
        second_q = self._run_head(joint_input, self.q2_layer1, self.q2_layer2, self.q2_output)
        return first_q, second_q


class SoftActorCriticAgent:
    """Soft Actor-Critic agent.

    Holds a tanh-Gaussian actor, twin Q critics with a Polyak-averaged target
    copy, and an optional learned entropy temperature.

    Args:
        obs_dim, act_dim: observation / action vector sizes.
        act_limit: scale applied to tanh-squashed actions.
        hidden_dim: width of the hidden layers of actor and critics.
        actor_lr, critic_lr, alpha_lr: Adam learning rates.
        gamma: discount factor.
        tau: Polyak rate for the target-critic update.
        alpha_init: initial entropy temperature. It is also the starting
            point of the learned temperature when ``tune_alpha`` is True.
        tune_alpha: learn log(alpha) towards ``target_ent``.
        target_ent: target entropy; defaults to ``-act_dim``.

    Raises:
        ValueError: if ``tune_alpha`` is True and ``alpha_init`` is not positive.
    """

    def __init__(self, obs_dim, act_dim, act_limit, hidden_dim, actor_lr, critic_lr, alpha_lr, gamma, tau, alpha_init, tune_alpha, target_ent=None):
        self.gamma = gamma
        self.tau = tau
        self.temperature = alpha_init
        self.tune_temperature = tune_alpha

        self.actor_net = PolicyNetwork(obs_dim, act_dim, hidden_dim, act_limit).to(compute_device)
        self.actor_optim = optim.Adam(self.actor_net.parameters(), lr=actor_lr)

        self.critic_net = QValueNetwork(obs_dim, act_dim, hidden_dim).to(compute_device)
        self.critic_target_net = QValueNetwork(obs_dim, act_dim, hidden_dim).to(compute_device)
        self.critic_target_net.load_state_dict(self.critic_net.state_dict())
        # The target critic is only changed by Polyak averaging, never by gradients.
        for p in self.critic_target_net.parameters():
            p.requires_grad_(False)
        self.critic_optim = optim.Adam(self.critic_net.parameters(), lr=critic_lr)

        if self.tune_temperature:
            self.target_entropy = -float(act_dim) if target_ent is None else target_ent
            if alpha_init <= 0:
                raise ValueError(f"alpha_init must be positive when tuning alpha, got {alpha_init}")
            # Start the learned temperature at alpha_init (previously it always
            # started at exp(0) = 1.0, silently ignoring alpha_init).
            self.log_temperature = torch.tensor(
                [float(np.log(alpha_init))], dtype=torch.float32, device=compute_device, requires_grad=True
            )
            self.temperature_optim = optim.Adam([self.log_temperature], lr=alpha_lr)
            self.temperature = self.log_temperature.exp().item()
        else:
            self.target_entropy = None
            self.log_temperature = None
            self.temperature_optim = None

    def get_action(self, current_obs, is_eval=False):
        """Return one action (NumPy array) for a single unbatched observation.

        ``is_eval=True`` returns the deterministic ``tanh(mean)`` action;
        otherwise a stochastic policy sample is returned.
        """
        obs_tensor = torch.as_tensor(np.asarray(current_obs), dtype=torch.float32, device=compute_device).unsqueeze(0)
        # Acting never needs gradients; skip building an autograd graph per env step.
        with torch.no_grad():
            stochastic_action, _, deterministic_action = self.actor_net.sample_action(obs_tensor)
        action_tensor = deterministic_action if is_eval else stochastic_action
        return action_tensor.cpu().numpy()[0]

    def train_step(self, replay_buffer, sample_size):
        """Run one SAC update: critics, actor, temperature (if tuned), target.

        Returns:
            ``(critic_loss, actor_loss, temperature_loss)`` as floats. All are
            0.0 when the buffer holds fewer than ``sample_size`` transitions;
            ``temperature_loss`` is 0.0 when tuning is disabled.
        """
        if len(replay_buffer) < sample_size:
            return 0.0, 0.0, 0.0

        obs_t, act_t, rew_t, next_obs_t, term_t = self._batch_to_tensors(replay_buffer.sample_batch(sample_size))
        critic_loss = self._update_critic(obs_t, act_t, rew_t, next_obs_t, term_t)
        actor_loss, log_probs = self._update_actor(obs_t)
        temperature_loss = self._update_temperature(log_probs)
        self._soft_update_target()
        return critic_loss, actor_loss, temperature_loss

    @staticmethod
    def _batch_to_tensors(batch):
        """Convert a sampled batch to float32 tensors on compute_device.

        Rewards and terminal flags become ``(batch, 1)`` columns.
        """
        obs, act, rew, next_obs, term = batch

        def to_tensor(values):
            return torch.as_tensor(np.asarray(values, dtype=np.float32), device=compute_device)

        return (
            to_tensor(obs),
            to_tensor(act),
            to_tensor(rew).unsqueeze(1),
            to_tensor(next_obs),
            to_tensor(term).unsqueeze(1),
        )

    def _update_critic(self, obs_t, act_t, rew_t, next_obs_t, term_t):
        """Regress both Q heads onto the soft Bellman target; return the summed MSE."""
        with torch.no_grad():
            next_act, next_log_prob, _ = self.actor_net.sample_action(next_obs_t)
            q1_next, q2_next = self.critic_target_net(next_obs_t, next_act)
            soft_next_value = torch.min(q1_next, q2_next) - self.temperature * next_log_prob
            # A flag of 1 removes the bootstrap term, so only genuine terminal
            # states (not time-limit truncations) should be stored with it.
            q_target = rew_t + (1.0 - term_t) * self.gamma * soft_next_value

        current_q1, current_q2 = self.critic_net(obs_t, act_t)
        critic_loss = F.mse_loss(current_q1, q_target) + F.mse_loss(current_q2, q_target)

        self.critic_optim.zero_grad()
        critic_loss.backward()
        self.critic_optim.step()
        return critic_loss.item()

    def _update_actor(self, obs_t):
        """Minimise ``alpha * log_pi - min(Q1, Q2)``; return ``(loss, detached log_probs)``."""
        # Freeze critic parameters so the actor loss does not leave gradients on
        # them; always unfreeze, even if the update raises.
        for p in self.critic_net.parameters():
            p.requires_grad_(False)
        try:
            sampled_actions, log_probs, _ = self.actor_net.sample_action(obs_t)
            q1_pi, q2_pi = self.critic_net(obs_t, sampled_actions)
            actor_loss = (self.temperature * log_probs - torch.min(q1_pi, q2_pi)).mean()

            self.actor_optim.zero_grad()
            actor_loss.backward()
            self.actor_optim.step()
        finally:
            for p in self.critic_net.parameters():
                p.requires_grad_(True)
        return actor_loss.item(), log_probs.detach()

    def _update_temperature(self, log_probs):
        """Take one step on log(alpha) towards the target entropy; return the loss (0.0 if not tuning)."""
        if not self.tune_temperature:
            return 0.0
        temp_loss = -(self.log_temperature.exp() * (log_probs + self.target_entropy).detach()).mean()

        self.temperature_optim.zero_grad()
        temp_loss.backward()
        self.temperature_optim.step()
        self.temperature = self.log_temperature.exp().item()
        return temp_loss.item()

    def _soft_update_target(self):
        """Polyak update: target <- tau * online + (1 - tau) * target."""
        with torch.no_grad():
            for target_p, local_p in zip(self.critic_target_net.parameters(), self.critic_net.parameters()):
                target_p.mul_(1.0 - self.tau).add_(local_p, alpha=self.tau)


def run_evaluation(agent_policy, environment, num_episodes=10, should_render=False):
    """Roll out the agent's deterministic policy and return the mean episode reward.

    Args:
        agent_policy: object exposing ``get_action(obs, is_eval=True)``.
        environment: Gymnasium-style env. ``reset()`` returns ``(obs, info)``
            and ``step(a)`` returns ``(obs, reward, terminated, truncated, info)``.
        num_episodes: number of episodes to average over; must be positive.
        should_render: call ``environment.render()`` after each step. This is
            skipped for envs whose ``render_mode`` is ``"human"``, because they
            already render inside ``step()``. It is disabled after the first
            rendering error.

    Returns:
        The average undiscounted episode reward.

    Raises:
        ValueError: if ``num_episodes`` is not positive.
    """
    if num_episodes <= 0:
        raise ValueError(f"num_episodes must be positive, got {num_episodes}")
    # Human-mode Gymnasium envs draw a frame during step(); an explicit render()
    # would draw (and frame-rate throttle) every step a second time.
    if getattr(environment, "render_mode", None) == "human":
        should_render = False

    total_rew = 0.0
    print(f"\n--- Starting Final Evaluation ({num_episodes} episodes) ---")
    for ep_idx in range(num_episodes):
        obs, _ = environment.reset()
        ep_reward = 0.0
        terminated = truncated = False
        while not (terminated or truncated):
            action = agent_policy.get_action(obs, is_eval=True)
            obs, reward, terminated, truncated, _ = environment.step(action)
            ep_reward += reward
            if should_render:
                try:
                    environment.render()
                    time.sleep(0.01)
                except Exception as render_err:
                    # Rendering is best-effort: report once, then stop trying.
                    print(f"Rendering failed during eval ep {ep_idx+1}: {render_err}")
                    should_render = False
        print(f"  Eval Episode {ep_idx+1}/{num_episodes} | Reward: {ep_reward:.2f}")
        total_rew += ep_reward
    avg_rew = total_rew / num_episodes
    print(f"--- Final Evaluation Finished | Average Reward: {avg_rew:.2f} ---")
    return avg_rew


if __name__ == "__main__":
    train_env = gym.make(ENV_ID, render_mode=None)
    eval_render_mode = "human" if RENDER_EVALUATION else None
    eval_environment = gym.make(ENV_ID, render_mode=eval_render_mode)

    # Close both environments even if training is interrupted (e.g. a notebook
    # KeyboardInterrupt) or raises.
    try:
        obs_dimension = train_env.observation_space.shape[0]
        act_dimension = train_env.action_space.shape[0]
        action_limit = train_env.action_space.high[0] if isinstance(train_env.action_space.high, np.ndarray) else train_env.action_space.high
        # The action space has its own RNG; seed it so the random warm-up
        # actions are reproducible along with the seeded reset below.
        train_env.action_space.seed(RANDOM_SEED)

        print(f"Environment: {ENV_ID}")
        print(f"State Dim: {obs_dimension}, Action Dim: {act_dimension}, Action Max: {action_limit}")
        print(f"Max Timesteps: {TOTAL_TRAINING_STEPS}")
        print(f"Early Stopping: Avg Reward > {STOP_REWARD_TARGET} over last {STOP_WINDOW} episodes.")

        sac_agent = SoftActorCriticAgent(obs_dimension, act_dimension, action_limit, NETWORK_HIDDEN_UNITS, ACTOR_LR, CRITIC_LR, ALPHA_LR, DISCOUNT_FACTOR, TARGET_UPDATE_RATE, INITIAL_ALPHA, AUTO_TUNE_TEMP)

        replay_memory = ExperienceReplayBuffer(BUFFER_CAPACITY)

        total_steps = 0
        ep_count = 0
        current_ep_reward = 0.0
        current_ep_steps = 0
        stopped_early = False
        current_obs, _ = train_env.reset(seed=RANDOM_SEED)

        training_start_time = time.time()
        reward_log_deque = deque(maxlen=REWARD_LOG_WINDOW)
        stopping_deque = deque(maxlen=STOP_WINDOW)

        print("Starting training...")

        while total_steps < TOTAL_TRAINING_STEPS and not stopped_early:
            current_ep_steps += 1
            total_steps += 1

            # Uniform random actions during warm-up, policy samples afterwards.
            if total_steps < EXPLORATION_STEPS:
                current_action = train_env.action_space.sample()
            else:
                current_action = sac_agent.get_action(current_obs)

            next_obs, reward_val, is_terminated, is_truncated, _ = train_env.step(current_action)

            # Only a true terminal state stops bootstrapping. A time-limit
            # truncation is not terminal: its Q target must still include
            # gamma * V(next_obs), so the stored flag is `is_terminated` alone.
            replay_memory.add_transition(current_obs, current_action, reward_val, next_obs, float(is_terminated))

            current_obs = next_obs
            current_ep_reward += reward_val

            if total_steps > EXPLORATION_STEPS:
                crit_loss, act_loss, alph_loss = sac_agent.train_step(replay_memory, TRAIN_BATCH_SIZE)

                if total_steps % (LOGGING_INTERVAL * 5) == 0:
                    print(f"       T: {total_steps}/{TOTAL_TRAINING_STEPS} | Losses C:{crit_loss:.2f} A:{act_loss:.2f} Alpha:{alph_loss:.2f} | Alpha: {sac_agent.temperature:.3f}")

            # The episode ends on either termination or truncation.
            if is_terminated or is_truncated:
                reward_log_deque.append(current_ep_reward)
                stopping_deque.append(current_ep_reward)

                avg_reward_log = np.mean(reward_log_deque) if reward_log_deque else 0.0

                elapsed_seconds = time.time() - training_start_time
                time_str = str(datetime.timedelta(seconds=int(elapsed_seconds)))

                print(f"Ep {ep_count+1}, Reward: {current_ep_reward:.2f}, Avg ({REWARD_LOG_WINDOW}): {avg_reward_log:.2f}, Steps: {current_ep_steps}, Total T: {total_steps}/{TOTAL_TRAINING_STEPS}, Time: {time_str}")

                if len(stopping_deque) == STOP_WINDOW:
                    avg_reward_stopping = np.mean(stopping_deque)
                    print(f"  Avg Reward (last {STOP_WINDOW} eps): {avg_reward_stopping:.2f} (Threshold: {STOP_REWARD_TARGET})")
                    if avg_reward_stopping > STOP_REWARD_TARGET:
                        print(f"\n--- Early Stopping Condition Met! ---")
                        print(f"Average reward over last {STOP_WINDOW} episodes ({avg_reward_stopping:.2f}) exceeded threshold ({STOP_REWARD_TARGET}).")
                        stopped_early = True

                current_obs, _ = train_env.reset()
                current_ep_reward = 0.0
                current_ep_steps = 0
                ep_count += 1

        elapsed_seconds = time.time() - training_start_time
        time_str = str(datetime.timedelta(seconds=int(elapsed_seconds)))
        stop_msg = 'Early Stop' if stopped_early else 'Max Timesteps'
        print(f"\n--- Training Finished ({stop_msg}) ---")
        print(f"Total Timesteps: {total_steps}")
        print(f"Total Episodes: {ep_count}")
        print(f"Total Time: {time_str}")

        final_eval_reward = run_evaluation(sac_agent, eval_environment, num_episodes=10, should_render=RENDER_EVALUATION)

        stop_tag = "earlystop" if stopped_early else "maxt"
        actor_save_path = f'sac_actor_{ENV_ID}_{stop_tag}_{total_steps}.pth'
        critic_save_path = f'sac_critic_{ENV_ID}_{stop_tag}_{total_steps}.pth'
        torch.save(sac_agent.actor_net.state_dict(), actor_save_path)
        torch.save(sac_agent.critic_net.state_dict(), critic_save_path)
        print(f"--- Final Model Saved ---")
        print(f"Actor: {actor_save_path}")
        print(f"Critic: {critic_save_path}")
    finally:
        train_env.close()
        eval_environment.close()
    print("Script finished.")

Using device: cpu
Environment: BipedalWalker-v3
State Dim: 24, Action Dim: 4, Action Max: 1.0
Max Timesteps: 1000000
Early Stopping: Avg Reward > 250.0 over last 20 episodes.
Starting training...
Ep 1, Reward: -108.19, Avg (100): -108.19, Steps: 51, Total T: 51/1000000, Time: 0:00:00
Ep 2, Reward: -114.23, Avg (100): -111.21, Steps: 96, Total T: 147/1000000, Time: 0:00:00
Ep 3, Reward: -80.16, Avg (100): -100.86, Steps: 1600, Total T: 1747/1000000, Time: 0:00:01
Ep 4, Reward: -109.20, Avg (100): -102.94, Steps: 88, Total T: 1835/1000000, Time: 0:00:01
Ep 5, Reward: -120.91, Avg (100): -106.54, Steps: 124, Total T: 1959/1000000, Time: 0:00:01
Ep 6, Reward: -117.94, Avg (100): -108.44, Steps: 91, Total T: 2050/1000000, Time: 0:00:01
Ep 7, Reward: -119.60, Avg (100): -110.03, Steps: 79, Total T: 2129/1000000, Time: 0:00:01
Ep 8, Reward: -81.26, Avg (100): -106.44, Steps: 1600, Total T: 3729/1000000, Time: 0:00:02
Ep 9, Reward: -113.52, Avg (100): -107.22, Steps: 49, Total T: 3778/1000000,