In [1]:
import gymnasium as gym 
import torch
from collections import namedtuple
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from gymnasium.wrappers import RecordVideo

In [2]:
HIDDEN_SIZE = 128
BATCH_SIZE = 32
PERCENTILE = 70
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
class Policy(torch.nn.Module):
    def __init__(self, input_size, ouput_size, hidden_size):
        super(Policy, self).__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(input_size, hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_size, hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_size, ouput_size)
        )
    
    def forward(self, x):
        return self.net(x)

In [4]:
Episode = namedtuple('Episode', field_names=['reward', 'steps'])
EpisodeStep = namedtuple('EpisodeStep', field_names=['observation', 'action']) 

In [5]:
def iterate_bathes(env, policy, batch_size, device):
    batch = []
    episode_reward = 0.0
    episode_steps = []
    obs, info = env.reset()
    sm = torch.nn.Softmax(dim=1)
    while True: 
        obs_v = torch.tensor(np.array([obs]), dtype=torch.float32).to(device)
        act_probs_v = sm(policy(obs_v))
        act_probs = act_probs_v.cpu().data.numpy()[0]
        action = np.random.choice(len(act_probs), p=act_probs)
        next_obs, reward, terminated, truncated, info = env.step(action)
        episode_reward += reward
        step = EpisodeStep(observation=obs, action=action)
        episode_steps.append(step)
        if (terminated or truncated):
            e = Episode(reward=episode_reward, steps=episode_steps)
            batch.append(e)
            episode_reward = 0.0
            episode_steps = []
            obs, info = env.reset()
            if (len(batch) == batch_size):
                yield batch
                batch = []
        obs = next_obs

In [6]:
def filter_batch(batch, percentile):
    rewards = list(map(lambda s: s.reward, batch))
    reward_bound = np.percentile(rewards, percentile)
    reward_mean = np.mean(rewards)
    train_obs = []
    train_act = []
    for reward, steps in batch:
        if (reward >= reward_bound):
            train_obs.extend(map(lambda step: step.observation, steps))
            train_act.extend(map(lambda step: step.action, steps))
    train_obs_tensor = torch.tensor(train_obs, dtype=torch.float32)
    train_act_tensor = torch.tensor(train_act, dtype=torch.float32)
    return train_obs_tensor, train_act_tensor, reward_bound, reward_mean

In [7]:
class CustomTimeLimit(gym.Wrapper):
    def __init__(self, env, max_episode_steps):
        super().__init__(env)
        self.max_episode_steps = max_episode_steps
        self.current_step = 0

    def reset(self, **kwargs):
        self.current_step = 0
        return self.env.reset(**kwargs)

    def step(self, action):
        observation, reward, terminated, truncated, info = self.env.step(action)
        self.current_step += 1
        if self.current_step >= self.max_episode_steps:
            truncated = True
        return observation, reward, terminated, truncated, info

In [8]:
env = gym.make("CartPole-v0", render_mode="rgb_array", max_episode_steps=5000)
_ = env.reset()
obs_size = env.observation_space.shape[0]
#act_size = env.action_space.shape[0]
n_actions = env.action_space.n
policy = Policy(obs_size, n_actions, HIDDEN_SIZE).to(device)
loss = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=policy.parameters(), lr=0.01)
writer = SummaryWriter(comment="CartPole-v0")

  logger.deprecation(


In [9]:
for iter_num, batch in enumerate(iterate_bathes(env, policy, BATCH_SIZE, device)):
    obs_tensor, act_tensor, reward_bound, reward_mean = filter_batch(batch, PERCENTILE)
    optimizer.zero_grad()
    action_scores_tensor = policy(obs_tensor.to(device))
    loss_tensor = loss(action_scores_tensor, (act_tensor.long()).to(device))
    loss_tensor.backward()
    optimizer.step()
    print("%d: loss=%.3f, reward_mean=%.1f, rw_bound=%.1f" % (
            iter_num, loss_tensor.item(), reward_mean, reward_bound))
    writer.add_scalar("loss", loss_tensor.item(), iter_num)
    writer.add_scalar("reward_bound", reward_bound, iter_num)
    writer.add_scalar("reward_mean", reward_mean, iter_num)

    if reward_mean > 2000:
        print("Solved!")
        break
writer.close()

  train_obs_tensor = torch.tensor(train_obs, dtype=torch.float32)


0: loss=0.689, reward_mean=25.1, rw_bound=24.7
1: loss=0.655, reward_mean=29.8, rw_bound=35.7
2: loss=0.642, reward_mean=23.3, rw_bound=24.7
3: loss=0.615, reward_mean=43.0, rw_bound=43.0
4: loss=0.596, reward_mean=52.8, rw_bound=58.7
5: loss=0.575, reward_mean=49.6, rw_bound=50.7
6: loss=0.553, reward_mean=55.7, rw_bound=60.7
7: loss=0.520, reward_mean=65.8, rw_bound=78.4
8: loss=0.522, reward_mean=76.3, rw_bound=81.0
9: loss=0.511, reward_mean=88.1, rw_bound=105.5
10: loss=0.493, reward_mean=77.8, rw_bound=85.4
11: loss=0.461, reward_mean=86.1, rw_bound=98.7
12: loss=0.441, reward_mean=78.9, rw_bound=90.1
13: loss=0.444, reward_mean=94.8, rw_bound=100.0
14: loss=0.434, reward_mean=104.2, rw_bound=113.4
15: loss=0.408, reward_mean=123.6, rw_bound=134.4
16: loss=0.422, reward_mean=117.5, rw_bound=123.4
17: loss=0.434, reward_mean=91.9, rw_bound=98.1
18: loss=0.396, reward_mean=98.2, rw_bound=104.8
19: loss=0.405, reward_mean=135.5, rw_bound=151.8
20: loss=0.394, reward_mean=240.0, rw_b

In [27]:
env = gym.make('CartPole-v0', render_mode='rgb_array', max_episode_steps=5000)
env = RecordVideo(env, video_folder="./videos", name_prefix="eval",
                  episode_trigger=lambda x: True)


  logger.deprecation(
  logger.warn(


In [28]:
policy.eval()
obs, info = env.reset()
sm = torch.nn.Softmax(dim=1)
for i in range(5000):
    with torch.no_grad():
        act_probs_v = sm(policy(torch.tensor([obs], dtype=torch.float32).to(device)))
        act_probs = act_probs_v.cpu().data.numpy()[0]
        action = np.random.choice(len(act_probs), p=act_probs)
        obs, reward, terminated, truncated, info = env.step(action)
        if (terminated or truncated):
            env.close()
            print(f'Agent failed on step: {i}')
            break
env.close()


Moviepy - Building video /home/dmitriy/ITMO/DISS/chapter_4/videos/eval-episode-0.mp4.
Moviepy - Writing video /home/dmitriy/ITMO/DISS/chapter_4/videos/eval-episode-0.mp4



                                                               

Moviepy - Done !
Moviepy - video ready /home/dmitriy/ITMO/DISS/chapter_4/videos/eval-episode-0.mp4
Agent failed on step: 868
