Реализуйте алгоритм SAC (Soft Actor-Critic) для среды Lunar Lander.

In [1]:
!pip install swig
!pip install "gymnasium[box2d]"

Collecting swig
  Using cached swig-4.3.1-py3-none-win_amd64.whl.metadata (3.5 kB)
Using cached swig-4.3.1-py3-none-win_amd64.whl (2.6 MB)
Installing collected packages: swig
Successfully installed swig-4.3.1
Collecting box2d-py==2.3.5 (from gymnasium[box2d])
  Using cached box2d-py-2.3.5.tar.gz (374 kB)
  Preparing metadata (setup.py): started
  Preparing metadata (setup.py): finished with status 'done'
Collecting pygame>=2.1.3 (from gymnasium[box2d])
  Downloading pygame-2.6.1-cp311-cp311-win_amd64.whl.metadata (13 kB)
Downloading pygame-2.6.1-cp311-cp311-win_amd64.whl (10.6 MB)
   ---------------------------------------- 0.0/10.6 MB ? eta -:--:--
   --- ------------------------------------ 1.0/10.6 MB 8.4 MB/s eta 0:00:02
   ----------------------- ---------------- 6.3/10.6 MB 20.3 MB/s eta 0:00:01
   ---------------------------------------- 10.6/10.6 MB 25.5 MB/s eta 0:00:00
Building wheels for collected packages: box2d-py
  Building wheel for box2d-py (setup.py): started
  Buildin

  DEPRECATION: Building 'box2d-py' using the legacy setup.py bdist_wheel mechanism, which will be removed in a future version. pip 25.3 will enforce this behaviour change. A possible replacement is to use the standardized build interface by setting the `--use-pep517` option, (possibly combined with `--no-build-isolation`), or adding a `pyproject.toml` file to the source tree of 'box2d-py'. Discussion can be found at https://github.com/pypa/pip/issues/6334


In [1]:
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from collections import deque
import random
from torch.distributions import Normal

In [2]:
# --- SAC objective ---
GAMMA = 0.99            # discount factor applied to the bootstrapped target
TAU = 0.005             # interpolation weight of the soft target-critic update
ALPHA = 0.2             # fixed weight of the log-probability (entropy) term

# --- Optimisation ---
ACTOR_LR = 3e-4         # Adam learning rate for the actor
CRITIC_LR = 3e-4        # Adam learning rate for the critic
BATCH_SIZE = 256        # transitions sampled per gradient step

# --- Data collection / schedule ---
REPLAY_SIZE = 100_000   # maximum number of stored transitions
START_STEPS = 10_000    # steps taken with uniformly random actions
TOTAL_STEPS = 200_000   # total environment steps in the training loop
UPDATE_AFTER = 1_000    # first step at which gradient updates may run
UPDATE_EVERY = 50       # update period; also the number of updates per period

In [6]:
class Actor(nn.Module):
    """Tanh-squashed Gaussian policy.

    A two-layer MLP trunk feeds two linear heads producing the mean and the
    log standard deviation of a diagonal Normal. Samples are squashed with
    tanh and mapped affinely to ``act_bias +/- act_limit``.

    Args:
        obs_dim: Size of the observation vector.
        act_dim: Size of the action vector.
        act_limit: Half-width of the action range (scale applied after tanh).
        act_bias: Centre of the action range (offset applied after scaling).
            Defaults to 0.0, i.e. a range symmetric around zero.
    """

    # Bounds applied to the predicted log-std before exponentiation.
    LOG_STD_MIN = -20
    LOG_STD_MAX = 2

    def __init__(self, obs_dim, act_dim, act_limit, act_bias=0.0):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, 256), nn.ReLU(),
            nn.Linear(256, 256), nn.ReLU(),
        )
        self.mu_layer = nn.Linear(256, act_dim)
        self.log_std_layer = nn.Linear(256, act_dim)
        self.act_limit = act_limit
        self.act_bias = act_bias

    def _distribution_params(self, obs):
        """Return ``(mean, std)`` of the pre-squash Normal for ``obs``."""
        # The trunk already ends in ReLU, so no extra activation is needed here.
        features = self.net(obs)
        mean = self.mu_layer(features)
        log_std = torch.clamp(
            self.log_std_layer(features), self.LOG_STD_MIN, self.LOG_STD_MAX
        )
        return mean, log_std.exp()

    def forward(self, obs):
        """Sample a differentiable action and its log-probability.

        Args:
            obs: Observation tensor whose last dimension is ``obs_dim``.

        Returns:
            ``(action, log_prob)``: ``action`` has the action dimension last;
            ``log_prob`` is summed over that dimension and keeps it with size 1.
        """
        mean, std = self._distribution_params(obs)
        normal = Normal(mean, std)

        x_t = normal.rsample()  # reparameterised, so gradients reach mean/std
        y_t = torch.tanh(x_t)
        # Use the limits stored on the module instead of notebook-level globals,
        # so the policy works regardless of which cells have been executed.
        action = y_t * self.act_limit + self.act_bias

        # Change of variables for a = act_limit * tanh(u) + act_bias:
        # |da/du| = act_limit * (1 - tanh(u)^2); the epsilon keeps log() finite.
        log_prob = normal.log_prob(x_t)
        log_prob = log_prob - torch.log(self.act_limit * (1 - y_t.pow(2)) + 1e-6)
        # Sum over the last (action) dimension so unbatched input works as well.
        log_prob = log_prob.sum(-1, keepdim=True)

        return action, log_prob

    def get_action(self, obs, deterministic=False):
        """Return an action as a NumPy array, without tracking gradients.

        Args:
            obs: Observation tensor whose last dimension is ``obs_dim``.
            deterministic: If True use the squashed mean; otherwise sample.

        Returns:
            NumPy array of actions with the same leading dimensions as ``obs``.
        """
        # Both branches run under no_grad: calling .numpy() on a tensor that
        # requires grad raises, which the deterministic path used to trigger.
        with torch.no_grad():
            mean, std = self._distribution_params(obs)
            u = mean if deterministic else Normal(mean, std).sample()
            action = torch.tanh(u) * self.act_limit + self.act_bias
        return action.cpu().numpy()




In [7]:
class Critic(nn.Module):
    """Pair of independent Q-networks evaluated on the same (obs, act) input.

    Each network is an MLP with two hidden layers of 256 ReLU units and a
    single scalar output.
    """

    def __init__(self, obs_dim, act_dim):
        super().__init__()
        input_width = obs_dim + act_dim
        # q1 is built before q2 so parameter creation order stays the same.
        self.q1 = self._q_head(input_width)
        self.q2 = self._q_head(input_width)

    @staticmethod
    def _q_head(input_width):
        """Build one Q-network mapping ``input_width`` features to one value."""
        layers = [
            nn.Linear(input_width, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
        ]
        return nn.Sequential(*layers)

    def forward(self, obs, act):
        """Return ``(q1, q2)`` for observations and actions joined on the last dim."""
        joint = torch.cat((obs, act), dim=-1)
        q1_value = self.q1(joint)
        q2_value = self.q2(joint)
        return q1_value, q2_value

In [8]:
class ReplayBuffer:
    """Bounded FIFO store of transitions with uniform random sampling."""

    def __init__(self, size):
        # A deque with maxlen drops its oldest entry once `size` is exceeded.
        self.buffer = deque(maxlen=size)

    def add(self, *args):
        """Append one transition; its fields are kept in the order given."""
        transition = tuple(args)
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions as float32 tensors.

        Returns:
            ``(states, actions, rewards, next_states, dones)``; ``rewards`` and
            ``dones`` get an extra trailing dimension of size 1.
        """
        picked = random.sample(self.buffer, batch_size)
        # Transpose the list of transitions into one array per field.
        columns = [np.array(column) for column in zip(*picked)]
        states, actions, rewards, next_states, dones = columns

        def as_float(array, column_vector=False):
            tensor = torch.tensor(array, dtype=torch.float32)
            return tensor.unsqueeze(1) if column_vector else tensor

        return (
            as_float(states),
            as_float(actions),
            as_float(rewards, column_vector=True),
            as_float(next_states),
            as_float(dones, column_vector=True),
        )

In [9]:
# Seed every random source used below so a fresh "Restart & Run All" is as
# repeatable as the underlying libraries allow.
SEED = 0
random.seed(SEED)        # ReplayBuffer sampling
np.random.seed(SEED)
torch.manual_seed(SEED)  # network initialisation and policy sampling

env = gym.make("LunarLanderContinuous-v3")
env.action_space.seed(SEED)  # the action space draws samples from its own RNG
obs_dim = env.observation_space.shape[0]
act_dim = env.action_space.shape[0]
# Bounds are read from the first action component only.
action_low, action_high = float(env.action_space.low[0]), float(env.action_space.high[0])

# Half-width of the action range; the actor scales its tanh output by this.
act_limit = (action_high - action_low) / 2


actor = Actor(obs_dim, act_dim, act_limit)
critic = Critic(obs_dim, act_dim)
# The target critic starts as an exact copy of the online critic.
critic_target = Critic(obs_dim, act_dim)
critic_target.load_state_dict(critic.state_dict())
# The target network is never optimised directly, so it needs no gradients.
for target_param in critic_target.parameters():
    target_param.requires_grad_(False)

actor_optim = optim.Adam(actor.parameters(), lr=ACTOR_LR)
critic_optim = optim.Adam(critic.parameters(), lr=CRITIC_LR)

replay_buffer = ReplayBuffer(REPLAY_SIZE)

# Seeding the first reset initialises the environment's RNG.
obs, _ = env.reset(seed=SEED)
episode_return, episode_len = 0, 0

In [10]:
def update():
    """Run one SAC gradient step on a batch drawn from ``replay_buffer``.

    Performs, in order: a critic update towards the entropy-regularised
    clipped double-Q target, an actor update, and a soft (Polyak) update of
    ``critic_target``. Returns without doing anything while fewer than
    ``BATCH_SIZE`` transitions are stored.
    """
    # Check how many transitions are actually stored; comparing the buffer
    # capacity (REPLAY_SIZE) with BATCH_SIZE is constant and never guards.
    if len(replay_buffer.buffer) < BATCH_SIZE:
        return

    state, action, reward, next_state, done = replay_buffer.sample(BATCH_SIZE)

    # The target is a fixed regression label, so no gradients are tracked.
    with torch.no_grad():
        next_action, next_log_prob = actor(next_state)
        q1_target, q2_target = critic_target(next_state, next_action)
        # Minimum of the two target Q-values, minus the weighted log-probability.
        q_target = torch.min(q1_target, q2_target) - ALPHA * next_log_prob
        # (1 - done) removes the bootstrap term for transitions flagged as done.
        target = reward + (1 - done) * GAMMA * q_target

    q1, q2 = critic(state, action)
    critic_loss = F.mse_loss(q1, target) + F.mse_loss(q2, target)

    critic_optim.zero_grad()
    critic_loss.backward()
    critic_optim.step()

    # Actor step: fresh reparameterised actions scored by the updated critic.
    new_action, new_log_prob = actor(state)
    q1, q2 = critic(state, new_action)
    actor_loss = (ALPHA * new_log_prob - torch.min(q1, q2)).mean()

    # This backward pass also leaves gradients on the critic parameters; they
    # are cleared by critic_optim.zero_grad() at the start of the next call.
    actor_optim.zero_grad()
    actor_loss.backward()
    actor_optim.step()

    # Soft update: target <- TAU * online + (1 - TAU) * target.
    with torch.no_grad():
        for param, target_param in zip(critic.parameters(), critic_target.parameters()):
            target_param.copy_(TAU * param + (1 - TAU) * target_param)

In [13]:
device = "cpu"
actor.to(device)
critic.to(device)
critic_target.to(device)


for step in range(TOTAL_STEPS):
    # Warm-up: uniformly random actions until START_STEPS, then the policy.
    if step < START_STEPS:
        act = env.action_space.sample()
    else:
        with torch.no_grad():
            # Add a batch dimension for the network and strip it from the result.
            obs_t = torch.tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)
            act = actor.get_action(obs_t)[0]

    next_obs, rew, terminated, truncated, _ = env.step(act)
    done = terminated or truncated
    # Store only `terminated` as the done flag. A truncated episode (time
    # limit) did not reach a terminal state, so the critic target must still
    # bootstrap from next_obs; storing `done` would zero that term.
    replay_buffer.add(obs, act, rew, next_obs, terminated)

    obs = next_obs
    episode_return += rew
    episode_len += 1

    # The episode ends on either signal, so reset on both.
    if done:
        obs, _ = env.reset()
        print(f"Step: {step}, Return: {episode_return:.2f}, Len: {episode_len}")
        episode_return, episode_len = 0, 0

    # Every UPDATE_EVERY steps run UPDATE_EVERY gradient updates, which keeps
    # the ratio of gradient steps to environment steps at one.
    if step >= UPDATE_AFTER and step % UPDATE_EVERY == 0:
        for _ in range(UPDATE_EVERY):
            update()

Step: 16, Return: -112.37, Len: 156
Step: 110, Return: -8.22, Len: 94
Step: 188, Return: -65.01, Len: 78
Step: 269, Return: -143.62, Len: 81
Step: 372, Return: -275.20, Len: 103
Step: 478, Return: -488.76, Len: 106
Step: 548, Return: -146.84, Len: 70
Step: 632, Return: -277.26, Len: 84
Step: 706, Return: -100.20, Len: 74
Step: 800, Return: -307.76, Len: 94
Step: 901, Return: -166.95, Len: 101
Step: 976, Return: -67.93, Len: 75
Step: 1165, Return: -209.79, Len: 189
Step: 1238, Return: -38.73, Len: 73
Step: 1387, Return: -295.04, Len: 149
Step: 1512, Return: -208.74, Len: 125
Step: 1596, Return: -122.18, Len: 84
Step: 1672, Return: -88.86, Len: 76
Step: 1822, Return: -234.44, Len: 150
Step: 1928, Return: -304.47, Len: 106
Step: 2084, Return: -74.82, Len: 156
Step: 2188, Return: -156.19, Len: 104
Step: 2292, Return: -255.42, Len: 104
Step: 2442, Return: -208.43, Len: 150
Step: 2542, Return: -283.29, Len: 100
Step: 2663, Return: -478.93, Len: 121
Step: 2750, Return: -457.06, Len: 87
Step: 

KeyboardInterrupt: 