# Soft Actor-Critic (SAC) — Lunar Lander (Continuous)

A concise notebook implementing SAC on `LunarLander-v3` (continuous action space).

## 1. Imports and Device Setup

In [None]:
from collections import deque, namedtuple
import gymnasium as gym
import numpy as np
import random

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

In [None]:
!pip install swig
!pip install gymnasium[box2d]

In [None]:
# Use the first CUDA GPU when PyTorch reports one as available, else the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Show which device was selected.
print(str(device))

In [None]:
# Seed the Python, NumPy and PyTorch random number generators with one value.
seed = 0
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(seed)

# Seed every CUDA device as well when CUDA is available.
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

## 2. Environment Initialization

In [None]:
# Continuous-action Lunar Lander; "rgb_array" makes env.render() return frames.
env = gym.make(
    "LunarLander-v3",
    render_mode="rgb_array",
    continuous=True,
)

In [None]:
# Seed the environment on its first reset. The notebook seeds Python, NumPy
# and PyTorch with `seed`, but without this the environment itself starts
# from an unseeded generator.
env.reset(seed=seed)

# Render the initial state and display it.
frame = env.render()
plt.imshow(frame)
plt.show()

In [None]:
# Observation shape (a tuple) and the length of the action vector.
state_size = env.observation_space.shape
num_actions = env.action_space.shape[0]

print(f"State Shape: {state_size}")
print(f"Number of actions: {num_actions}")

## 3. Hyperparameters

In [None]:
# Learning rate shared by every Adam optimizer
LR = 3e-4

# Discount factor
gamma = 0.99

# Number of transitions sampled from the replay buffer for each update
batch_size = 256

# Soft update parameter (weight of the online critic in the target update)
tau = 0.005

# Number of initial environment steps intended to use random actions,
# so the replay buffer is filled with diverse experience
start_steps = 10_000

# Maximum number of transitions kept in the replay buffer
memory_size = 1_000_000

# Minimum number of transitions in replay buffer before training starts
replay_fill = 10_000

# Entropy target for temperature tuning: minus the number of action dimensions
target_entropy = -num_actions


## 4. Actor and Critic Networks

In [None]:
class Critic(nn.Module):
    """State-action value network built from three linear layers.

    Args:
        state_dim: Number of state features. When ``None`` (the default) the
            value is read from the notebook-level ``state_size[0]``.
        action_dim: Number of action components. When ``None`` (the default)
            the value is read from the notebook-level ``num_actions``.
        hidden_dim: Width of both hidden layers.
    """

    def __init__(self, state_dim=None, action_dim=None, hidden_dim=256):
        super().__init__()
        # Fall back to the notebook globals so `Critic()` keeps working.
        if state_dim is None:
            state_dim = state_size[0]
        if action_dim is None:
            action_dim = num_actions

        # Layer names are unchanged so existing state dicts still load.
        self.l1 = nn.Linear(state_dim + action_dim, hidden_dim)
        self.l2 = nn.Linear(hidden_dim, hidden_dim)
        self.l3 = nn.Linear(hidden_dim, 1)

    def forward(self, x, a):
        """Return one value per row for the given states and actions.

        Args:
            x: Batch of states; concatenated with ``a`` along dimension 1.
            a: Batch of actions with the same leading dimension as ``x``.

        Returns:
            Tensor whose last dimension has size 1.
        """
        state_action = torch.cat((x, a), 1)
        hidden = torch.relu(self.l1(state_action))
        hidden = torch.relu(self.l2(hidden))
        return self.l3(hidden)

In [None]:
class Actor(nn.Module):
    """Tanh-squashed Gaussian policy network.

    Args:
        state_dim: Number of state features. When ``None`` (the default) the
            value is read from the notebook-level ``state_size[0]``.
        action_dim: Number of action components. When ``None`` (the default)
            the value is read from the notebook-level ``num_actions``.
        hidden_dim: Width of both hidden layers.
    """

    # Bounds applied to the predicted log standard deviation before exp().
    LOG_SIGMA_MIN = -10.0
    LOG_SIGMA_MAX = 2.0

    def __init__(self, state_dim=None, action_dim=None, hidden_dim=256):
        super().__init__()
        # Fall back to the notebook globals so `Actor()` keeps working.
        if state_dim is None:
            state_dim = state_size[0]
        if action_dim is None:
            action_dim = num_actions

        # Layer names are unchanged so existing state dicts still load.
        self.l1 = nn.Linear(state_dim, hidden_dim)
        self.l2 = nn.Linear(hidden_dim, hidden_dim)
        self.mu = nn.Linear(hidden_dim, action_dim)
        self.log_sigma = nn.Linear(hidden_dim, action_dim)

    def forward(self, x, deterministic=False, with_log=True):
        """Map states to actions in (-1, 1) and optionally their log-probability.

        Args:
            x: Batch of states.
            deterministic: If true, return ``tanh(mu)`` without sampling.
            with_log: If true (and not deterministic), also return the
                log-probability of the sampled action.

        Returns:
            ``(action, log_prob)``. ``log_prob`` is ``None`` when
            ``deterministic`` is true or ``with_log`` is false; otherwise it
            is summed over the action dimension with that dimension kept.
        """
        hidden = torch.relu(self.l1(x))
        hidden = torch.relu(self.l2(hidden))
        mu = self.mu(hidden)

        if deterministic:
            return torch.tanh(mu), None

        log_sigma = torch.clamp(
            self.log_sigma(hidden),
            min=self.LOG_SIGMA_MIN,
            max=self.LOG_SIGMA_MAX,
        )
        dist = torch.distributions.Normal(mu, torch.exp(log_sigma))
        # rsample() keeps the sample differentiable w.r.t. mu and sigma.
        x_t = dist.rsample()

        log_prob = None
        if with_log:
            gaussian_log_prob = dist.log_prob(x_t).sum(dim=-1, keepdim=True)
            # Change-of-variables term for the tanh squashing:
            # log(1 - tanh(u)^2) == 2 * (log 2 - u - softplus(-2u)).
            squash_correction = (
                2 * (np.log(2) - x_t - nn.functional.softplus(-2 * x_t))
            ).sum(dim=-1, keepdim=True)
            log_prob = gaussian_log_prob - squash_correction

        return torch.tanh(x_t), log_prob


In [None]:
# Online networks: two critics and the policy. The construction order is kept
# (critics, actor, target critics) so seeded weight initialisation is unchanged.
critic_1, critic_2 = Critic().to(device), Critic().to(device)
actor = Actor().to(device)

# Target critics.
critic_target_1, critic_target_2 = Critic().to(device), Critic().to(device)

# Entropy temperature is learned in log space; exp(log_alpha) starts at 0.2.
log_alpha = torch.tensor(
    np.log(0.2),
    dtype=torch.float32,
    device=device,
    requires_grad=True,
)

# One Adam optimizer per trainable component, all using the same learning rate.
actor_optimizer = optim.Adam(actor.parameters(), lr=LR)
critic_optimizer_1 = optim.Adam(critic_1.parameters(), lr=LR)
critic_optimizer_2 = optim.Adam(critic_2.parameters(), lr=LR)
log_alpha_optimizer = optim.Adam([log_alpha], lr=LR)

In [None]:
# Start each target critic as an exact copy of its online critic.
for _online_net, _target_net in ((critic_1, critic_target_1), (critic_2, critic_target_2)):
    _target_net.load_state_dict(_online_net.state_dict())


In [None]:
# Target critics are never trained by an optimizer, so exclude their
# parameters from autograd.
for _target_net in (critic_target_1, critic_target_2):
    for _weight in _target_net.parameters():
        _weight.requires_grad = False

## 5. Replay Buffer

In [None]:
# One stored transition.
Experience = namedtuple(
    "Experience",
    ("state", "action", "reward", "next_state", "done"),
)

# Bounded buffer: once `memory_size` items are held, appending drops the oldest.
memory_buffer = deque(maxlen=memory_size)

def get_experiences(memory_buffer, batch_size):
    """Sample a batch of transitions and return it as float32 tensors.

    Args:
        memory_buffer: Sequence of ``Experience`` items to sample from.
        batch_size: Number of items to draw (without replacement).

    Returns:
        Tuple ``(states, actions, rewards, next_states, dones)``. Each entry
        is a float32 tensor on ``device`` with one row per sampled item;
        scalar fields are stacked into a single column.
    """
    sampled = random.sample(memory_buffer, batch_size)

    def stack_field(field_name):
        # np.vstack puts one sampled item per row.
        rows = [getattr(item, field_name) for item in sampled]
        return torch.tensor(np.vstack(rows), dtype=torch.float32, device=device)

    field_order = ("state", "action", "reward", "next_state", "done")
    return tuple(stack_field(field_name) for field_name in field_order)

## 6. SAC Update Functions

In [None]:
def update_networks(experiences, critic_1, critic_2, critic_target_1, critic_target_2, actor, gamma, log_alpha, criterion, critic_optimizer_1, critic_optimizer_2, actor_optimizer, log_alpha_optimizer, target_entropy):
    """Run one gradient step on both critics, the actor and the temperature.

    Args:
        experiences: Tuple ``(states, actions, rewards, next_states, dones)``.
        critic_1, critic_2: Online critics, called as ``critic(states, actions)``.
        critic_target_1, critic_target_2: Target critics used for the target value.
        actor: Policy, called as ``actor(states, False, True)`` and expected to
            return ``(actions, log_probs)``.
        gamma: Discount factor.
        log_alpha: Log of the entropy temperature (a tensor that requires grad).
        criterion: Loss applied between critic outputs and the target.
        critic_optimizer_1, critic_optimizer_2: Optimizers of the online critics.
        actor_optimizer: Optimizer of the actor.
        log_alpha_optimizer: Optimizer of ``log_alpha``.
        target_entropy: Entropy target used in the temperature loss.

    Returns:
        None. Networks and ``log_alpha`` are updated in place.
    """
    states, actions, rewards, next_states, dones = experiences
    online_critics = (critic_1, critic_2)

    # Current value estimates for the sampled state-action pairs.
    q_vals_1 = critic_1(states, actions)
    q_vals_2 = critic_2(states, actions)

    # Target value: no gradients flow through the target computation.
    with torch.no_grad():
        next_actions, next_log_probs = actor(next_states, False, True)
        min_next_q = torch.min(
            critic_target_1(next_states, next_actions),
            critic_target_2(next_states, next_actions),
        )
        soft_next_value = min_next_q - torch.exp(log_alpha) * next_log_probs
        target = rewards + gamma * (1 - dones) * soft_next_value

    # Both losses are formed first, then each critic takes its own step.
    critic_losses = [criterion(q_vals, target) for q_vals in (q_vals_1, q_vals_2)]
    critic_optimizers = (critic_optimizer_1, critic_optimizer_2)
    for optimizer, critic_loss in zip(critic_optimizers, critic_losses):
        optimizer.zero_grad()
        critic_loss.backward()
        optimizer.step()

    # Freeze the critics so the actor's backward pass leaves them without gradients.
    for critic in online_critics:
        for weight in critic.parameters():
            weight.requires_grad = False

    # Actor step: maximise min-Q minus the temperature-weighted log-probability.
    actions_pi, log_probs_pi = actor(states, False, True)
    q_value_pi = torch.min(
        critic_1(states, actions_pi),
        critic_2(states, actions_pi),
    )
    alpha = torch.exp(log_alpha).detach()
    actor_loss = -torch.mean(q_value_pi - alpha * log_probs_pi)
    actor_optimizer.zero_grad()
    actor_loss.backward()
    actor_optimizer.step()

    # Temperature step: only log_alpha receives a gradient here.
    entropy_gap = (-log_probs_pi - target_entropy).detach()
    alpha_loss = (torch.exp(log_alpha) * entropy_gap).mean()
    log_alpha_optimizer.zero_grad()
    alpha_loss.backward()
    log_alpha_optimizer.step()

    # Re-enable critic gradients for the next call.
    for critic in online_critics:
        for weight in critic.parameters():
            weight.requires_grad = True


In [None]:
def soft_update_target_network(critic_1, critic_2, critic_target_1, critic_target_2, tau):
    """Blend each online critic into its target critic.

    Every target parameter becomes ``tau * online + (1 - tau) * target``.

    Args:
        critic_1, critic_2: Online critics (read only).
        critic_target_1, critic_target_2: Target critics (updated).
        tau: Weight given to the online parameters.
    """
    network_pairs = (
        (critic_target_1, critic_1),
        (critic_target_2, critic_2),
    )
    for target_net, online_net in network_pairs:
        for target_param, param in zip(target_net.parameters(), online_net.parameters()):
            blended = tau * param.data + (1 - tau) * target_param.data
            target_param.data = blended

## 7. Training Loop

Main training loop: environment interaction, fill replay buffer, sample batches, update actor/critics, soft-update target networks, and periodic logging.

In [None]:
num_episodes = 2000
start_episode = 0
criterion = nn.MSELoss()
num_to_print = 100
total_rewards_list = []
max_steps = 1000
CHECKPOINT_PATH = "sac_lunar_lander_checkpoint.pt"

# Environment steps taken so far. Starting from the current buffer length
# means re-running this cell with a populated buffer does not repeat the
# random warm-up phase.
total_steps = len(memory_buffer)

# Seed the sampler used for warm-up actions with the notebook-wide seed.
env.action_space.seed(seed)

for i in range(start_episode, num_episodes):
    state = env.reset()[0]
    total_rewards = 0

    for j in range(max_steps):
        if total_steps < start_steps:
            # Warm-up: random actions for the first `start_steps` steps so the
            # replay buffer holds diverse experience before the policy acts.
            action = env.action_space.sample()
        else:
            with torch.no_grad():
                state_tensor = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
                # Stochastic action; the log-probability is not needed here.
                action, _ = actor(state_tensor, False, False)
                action = action.cpu().numpy()[0]

        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        total_steps += 1
        total_rewards += reward

        # Store `terminated` only: it is the flag that disables bootstrapping
        # in the critic target, and a truncated episode is cut short rather
        # than ended by the task itself.
        exper = Experience(state, action, reward, next_state, terminated)
        memory_buffer.append(exper)

        # Train only once the buffer holds at least `replay_fill` transitions.
        if len(memory_buffer) >= replay_fill:

            experiences = get_experiences(memory_buffer, batch_size)

            update_networks(experiences, critic_1, critic_2, critic_target_1, critic_target_2,
                            actor, gamma, log_alpha, criterion, critic_optimizer_1, critic_optimizer_2,
                            actor_optimizer, log_alpha_optimizer, target_entropy)

            soft_update_target_network(critic_1, critic_2, critic_target_1, critic_target_2, tau)

        state = next_state

        # Stop on either flag; the environment must be reset before stepping again.
        if done:
            break

    total_rewards_list.append(total_rewards)
    avg_last_rewards = np.mean(total_rewards_list[-num_to_print:])

    if (i+1) % 20 == 0:
        # NOTE: the whole replay buffer is serialized on every checkpoint.
        checkpoint = {
            "episode": i + 1,
            "actor": actor.state_dict(),
            "critic_1": critic_1.state_dict(),
            "critic_2": critic_2.state_dict(),
            "critic_target_1": critic_target_1.state_dict(),
            "critic_target_2": critic_target_2.state_dict(),
            "actor_optimizer": actor_optimizer.state_dict(),
            "critic_optimizer_1": critic_optimizer_1.state_dict(),
            "critic_optimizer_2": critic_optimizer_2.state_dict(),
            "log_alpha": log_alpha.detach().cpu(),
            "log_alpha_optimizer": log_alpha_optimizer.state_dict(),
            "memory_buffer": list(memory_buffer),
            "total_rewards_list": total_rewards_list,
        }

        torch.save(checkpoint, CHECKPOINT_PATH)
        print(f"\rEpisode {i+1} | {avg_last_rewards} ", end="")

    # Require a full window of episodes so a single lucky early episode
    # cannot end training.
    if len(total_rewards_list) >= num_to_print and avg_last_rewards >= 200.0:
        print(f"\n\nEnvironment solved in {i+1} episodes!")
        torch.save(actor.state_dict(), "sac_lunar_lander_actor_network.pth")
        break




## 8. Evaluation

In [None]:
rewards_list = []
episodes = 50

actor.eval()
with torch.no_grad():
    for ep in range(episodes):
        state = env.reset()[0]
        done = False
        total_reward = 0.0

        while not done:
            state_tensor = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
            # Deterministic action: tanh of the policy mean, no sampling.
            action, _ = actor(state_tensor, True, False)
            action = action.cpu().numpy()[0]

            next_state, reward, terminated, truncated, _ = env.step(action)
            # End the episode on either flag. Checking only `terminated`
            # would keep stepping after a truncation, and the loop has no
            # other exit.
            done = terminated or truncated
            total_reward += reward
            state = next_state

        rewards_list.append(total_reward)

print(f"Average reward over {episodes} episodes:", np.mean(rewards_list))

## 9. Video Recording

In [None]:
import imageio

In [None]:
import imageio
import torch
import numpy as np

def create_video(filename, env, actor, fps=30, max_steps=1000):
    """Roll out one episode with deterministic actions and save it as a video.

    Args:
        filename: Output path passed to ``imageio.mimsave``.
        env: Environment whose ``render()`` result is used as a video frame.
        actor: Policy, called as ``actor(state, True, False)``.
        fps: Frames per second of the written video.
        max_steps: Upper bound on the number of environment steps recorded.
    """
    frames = []

    # Build input tensors on whichever device holds the actor's parameters.
    device = next(actor.parameters()).device
    state = env.reset()[0]

    for _ in range(max_steps):
        # Capture the frame for the current state before acting.
        frames.append(env.render())

        with torch.no_grad():
            state_tensor = torch.tensor(
                state,
                dtype=torch.float32,
                device=device
            ).unsqueeze(0)

            action, _ = actor(state_tensor, True, False)

        action = action.cpu().numpy()[0]

        state, _, terminated, truncated, _ = env.step(action)

        # Stop on either flag; the environment must be reset before it is
        # stepped again.
        if terminated or truncated:
            break

    imageio.mimsave(
        filename,
        frames,
        fps=fps,
        codec="libx264"
    )


In [None]:
import IPython
import base64
def embed_mp4(filename):
    """Return an IPython HTML object that embeds the given MP4 file inline.

    Args:
        filename: Path of the video file to read.

    Returns:
        ``IPython.display.HTML`` wrapping a ``<video>`` tag whose source is
        the file content encoded as a base64 data URI.
    """
    # The context manager closes the file handle as soon as the bytes are read.
    with open(filename, "rb") as video_file:
        video = video_file.read()
    b64 = base64.b64encode(video)
    tag = """
    <video width="840" height="480" controls>
    <source src="data:video/mp4;base64,{0}" type="video/mp4">
    Your browser does not support the video tag.
    </video>""".format(
        b64.decode()
    )

    return IPython.display.HTML(tag)

In [None]:
# Record one episode to disk, then display it inline; the final expression
# is the cell's output.
filename = "./lunar_lander.mp4"
create_video(filename=filename, env=env, actor=actor)
embed_mp4(filename=filename)