In [1]:
import torch, os, random
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical
from collections import deque
import gymnasium as gym

In [2]:
SEED = 42
ENV_NAME = "LunarLander-v3"
GAMMA = 0.99
MAX_ENVS = 8
LEARNING_RATE = 5e-4
MAX_STEPS = 256
TOTAL_STEPS = 10_000_000
BATCH_SIZE = MAX_ENVS * MAX_STEPS
NUM_UPDATES = TOTAL_STEPS // BATCH_SIZE
VALUE_COEFF = 0.5
ENTROPY_COEFF = 0.01
LOG_EVERY_N_STEPS = 50
print("Num updates: ", NUM_UPDATES)

random.seed(SEED)
torch.manual_seed(SEED)
np.random.seed(SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

Num updates:  4882
cuda


In [3]:
def make_env(idx, is_eval=False):
    def _init():
        render_mode = None if not is_eval else "rgb_array"
        env = gym.make(id=ENV_NAME, render_mode=render_mode)
        env.action_space.seed(SEED + idx)
        return env
    return _init

In [4]:
envs = gym.vector.SyncVectorEnv([make_env(idx) for idx in range(MAX_ENVS)])

In [5]:
action_space = envs.single_action_space.n
observation_space = envs.single_observation_space.shape[0]
print(f"Action space: {action_space}")
print(f"Observation space: {observation_space}")

Action space: 4
Observation space: 8


In [6]:
obs, _ = envs.reset()

In [7]:
class Actor(nn.Module):
    def __init__(self, observation_space, action_space):
        super().__init__()
        print(f"Observation space: {observation_space}, action space: {action_space}")
        self.fc1 = nn.Linear(in_features=observation_space, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=256)
        self.fc3 = nn.Linear(in_features=256, out_features=256)

        self.out = nn.Linear(in_features=256, out_features=action_space)
    
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = F.softmax(self.out(x), dim=-1)
        return x

    def get_action(self, x):
        action_probs = self.forward(x)
        dist = Categorical(probs=action_probs)
        action = dist.sample()
        return action, dist.log_prob(action), dist.entropy()


In [8]:
class Critic(nn.Module):
    def __init__(self, observation_space, action_space):
        super().__init__()
        print(f"Observation space: {observation_space}, action space: {action_space}")
        self.fc1 = nn.Linear(in_features=observation_space, out_features=64)
        self.fc2 = nn.Linear(in_features=64, out_features=128)
        self.fc3 = nn.Linear(in_features=128, out_features=128)
        self.out = nn.Linear(in_features=128, out_features=1)
    
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))

        return self.out(x)

In [9]:
actor_net = Actor(observation_space, action_space).to(device=device)
critic_net = Critic(observation_space, action_space).to(device=device)

optimizer = optim.AdamW(list(actor_net.parameters()) + list(critic_net.parameters()), lr=LEARNING_RATE)

Observation space: 8, action space: 4
Observation space: 8, action space: 4


In [10]:
test_x = torch.randn(size=(2, 8), device=device)
with torch.no_grad():
    test_action, test_action_probs, test_entropy = actor_net.get_action(test_x)
    test_value = critic_net(test_x)

print(test_action.shape)
print(test_action_probs.shape)
print(test_entropy.shape)

torch.Size([2])
torch.Size([2])
torch.Size([2])


In [11]:
rewards_storage = torch.zeros(size=(MAX_STEPS, MAX_ENVS), device=device)
dones_storage = torch.zeros(size=(MAX_STEPS, MAX_ENVS), device=device)

In [12]:
obs, _ = envs.reset()
obs = torch.Tensor(obs).to(device)
done = torch.zeros(MAX_ENVS, device=device)

for update in range(1, NUM_UPDATES + 1):

    log_probs_list = []
    entropies_list = []
    values_list = []

    for step in range(MAX_STEPS):
        
        dones_storage[step] = done

        action, log_probs, entropy = actor_net.get_action(obs)

        values = critic_net(obs).flatten()
        log_probs_list.append(log_probs)
        entropies_list.append(entropy)
        values_list.append(values)

        obs, reward, terminated, truncated, info = envs.step(action.cpu().numpy())

        rewards_storage[step] = torch.as_tensor(reward, dtype=torch.float32, device=device)
        done = torch.as_tensor(np.logical_or(terminated, truncated), dtype=torch.float32, device=device)
        obs = torch.as_tensor(obs, dtype=torch.float32, device=device)
    
    log_probs_storage = torch.stack(log_probs_list)   # (T, N)
    entropies = torch.stack(entropies_list)          # (T, N)
    values_storage = torch.stack(values_list) 

    returns = torch.zeros_like(rewards_storage, device=device)
    with torch.no_grad():
        next_value = critic_net(obs).squeeze() * (1.0 - done)
    
    gt_next_state = next_value
    for t in reversed(range(MAX_STEPS)):
        returns[t] = rewards_storage[t] + GAMMA * gt_next_state 
        gt_next_state = returns[t] * (1 - dones_storage[t])

    returns_flat = returns.view(-1)
    values_flat = values_storage.view(-1)
    log_probs_flat = log_probs_storage.view(-1)
    entropies_flat = entropies.view(-1)
    
    advantages = (returns_flat - values_flat).detach()
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    policy_loss = -(log_probs_flat * advantages).mean()
    critic_loss = VALUE_COEFF * F.smooth_l1_loss(values_flat, returns_flat)
    entropy_loss = -ENTROPY_COEFF * entropies_flat.mean()

    total_loss = policy_loss + critic_loss + entropy_loss
    optimizer.zero_grad()
            
    total_loss.backward()
    nn.utils.clip_grad_norm_(list(actor_net.parameters()) + list(critic_net.parameters()), 1.0)
    optimizer.step()

    if update % LOG_EVERY_N_STEPS == 0:
        print(f"[STEP]: {update}, [TOTAL_LOSS]: {total_loss.item()}, [ACTOR_LOSS]: {policy_loss.item()}, [CRITIC_LOSS]: {critic_loss.item()}, [REWARDS]: {rewards_storage.mean().item()}, [MEAN ENTROPY]: {entropies_flat.mean().item()}")

[STEP]: 50, [TOTAL_LOSS]: 25.68102264404297, [ACTOR_LOSS]: -0.005055042915046215, [CRITIC_LOSS]: 25.699565887451172, [REWARDS]: -1.191617727279663, [MEAN ENTROPY]: 1.3488019704818726
[STEP]: 100, [TOTAL_LOSS]: 10.635954856872559, [ACTOR_LOSS]: -0.006335852202028036, [CRITIC_LOSS]: 10.654838562011719, [REWARDS]: -0.348219633102417, [MEAN ENTROPY]: 1.2547186613082886
[STEP]: 150, [TOTAL_LOSS]: 14.343056678771973, [ACTOR_LOSS]: -0.00582401268184185, [CRITIC_LOSS]: 14.360677719116211, [REWARDS]: -0.25883153080940247, [MEAN ENTROPY]: 1.1796603202819824
[STEP]: 200, [TOTAL_LOSS]: 11.706589698791504, [ACTOR_LOSS]: -0.0021101105958223343, [CRITIC_LOSS]: 11.72041130065918, [REWARDS]: 0.058290183544158936, [MEAN ENTROPY]: 1.1711573600769043
[STEP]: 250, [TOTAL_LOSS]: 2.3185698986053467, [ACTOR_LOSS]: 0.005786249414086342, [CRITIC_LOSS]: 2.3254294395446777, [REWARDS]: 0.024688702076673508, [MEAN ENTROPY]: 1.264561414718628
[STEP]: 300, [TOTAL_LOSS]: 9.494268417358398, [ACTOR_LOSS]: 0.039219651371

In [27]:
def evaluate(model, device, num_eval_eps=1, record=False):
    eval_env = make_env(99, is_eval=True)()
    eval_env.action_space.seed(SEED)
    
    model = model.to(device)
    model.eval()
    returns = []
    frames = []

    model.eval()
    for eps in range(num_eval_eps):
        obs, _ = eval_env.reset()
        done = False
        episode_reward = 0.0
   
        while not done:
            if record:
                frame = eval_env.render()
                frames.append(frame)

            with torch.no_grad():
                
                act, _, _ = model.get_action(torch.tensor(obs, device=device, dtype=torch.float32).unsqueeze(0))
                # print(act.cpu().numpy().item())
                obs, reward, terminated, truncated, _ = eval_env.step(act.cpu().numpy().item())
                done = terminated or truncated
                
                episode_reward += reward
                
        returns.append(episode_reward)
    
    eval_env.close()
    model.train()
    return returns, frames

In [28]:
import imageio
train_video_path = rf"B:\Pytorch\RL\videos\a2c_lunar_lander.mp4"
returns, frames = evaluate(actor_net, device, record=True, num_eval_eps=1)

if frames and len(frames) > 0:
    imageio.mimsave(
        train_video_path,
        frames,
        fps=30,
        codec='libx264',
        macro_block_size=1
    )