In [1]:
!pip install gymnasium[atari]
!pip install gymnasium[accept-rom-license]
!pip install stable_baselines3
!pip install utils

Defaulting to user installation because normal site-packages is not writeable
[33mDEPRECATION: matlabengineforpython R2021b has a non-standard version number. pip 24.1 will enforce this behaviour change. A possible replacement is to upgrade to a newer version of matlabengineforpython or contact the author to suggest that they release a version with a conforming version number. Discussion can be found at https://github.com/pypa/pip/issues/12063[0m[33m
[0mDefaulting to user installation because normal site-packages is not writeable
[33mDEPRECATION: matlabengineforpython R2021b has a non-standard version number. pip 24.1 will enforce this behaviour change. A possible replacement is to upgrade to a newer version of matlabengineforpython or contact the author to suggest that they release a version with a conforming version number. Discussion can be found at https://github.com/pypa/pip/issues/12063[0m[33m
[0mDefaulting to user installation because normal site-packages is not writeabl

In [2]:

import os
import random
import time
import uuid
from collections import deque, Counter, namedtuple, defaultdict
from itertools import count


import numpy as np
import gymnasium as gym
from stable_baselines3.common.atari_wrappers import ClipRewardEnv, FireResetEnv, MaxAndSkipEnv, NoopResetEnv


import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.nn.utils as nn_utils
from torch.distributions import Categorical
from torch.nn import init



from matplotlib import pyplot as plt
import seaborn as sns
from tqdm import tqdm


import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=UserWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)



  DESCRIPTOR = _descriptor.FileDescriptor(
  _descriptor.FieldDescriptor(
  _HISTOGRAMPROTO = _descriptor.Descriptor(
  DESCRIPTOR = _descriptor.FileDescriptor(
  _descriptor.FieldDescriptor(
  _TENSORSHAPEPROTO_DIM = _descriptor.Descriptor(
  DESCRIPTOR = _descriptor.FileDescriptor(
  _descriptor.EnumValueDescriptor(
  _DATATYPE = _descriptor.EnumDescriptor(
  _descriptor.FieldDescriptor(
  _SERIALIZEDDTYPE = _descriptor.Descriptor(
  DESCRIPTOR = _descriptor.FileDescriptor(
  _descriptor.FieldDescriptor(
  _RESOURCEHANDLEPROTO_DTYPEANDSHAPE = _descriptor.Descriptor(
  DESCRIPTOR = _descriptor.FileDescriptor(
  _descriptor.FieldDescriptor(
  _TENSORPROTO = _descriptor.Descriptor(


In [3]:
class EnvironmentSetup:
    ENV_ARGS = {'id': "PongDeterministic-v4"}
    SEED = 1
    OUTPUT_DIR = 'output'
    NUM_ENVS = 1

    def __init__(self):
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.root_dir = os.getcwd()
        self.output_path = os.path.join(self.root_dir, self.OUTPUT_DIR)
        self.create_output_dir()
        self.set_random_seeds()
        print('device = ', self.device)

    def create_output_dir(self):
        if not os.path.exists(self.output_path):
            os.makedirs(self.output_path)

    def set_random_seeds(self):
        random.seed(self.SEED)
        np.random.seed(self.SEED)
        torch.manual_seed(self.SEED)

class TrainingConfig:
    def __init__(self):
        self.lr = 3e-4
        self.num_steps = 2048
        self.num_envs = 3
        self.num_iterations = 800
        self.gamma = 0.99
        self.gae_lambda = 0.95
        self.update_epochs = 10
        self.clip_coef = 0.2
        self.entropy_coef = 0.01
        self.vf_coef = 0.5
        self.max_grad_norm = 0.5
        self.mini_batch_count = 64
        self.update_plots = 10
        self.output_dir = 'output'


In [4]:
class EnvironmentCreator:
    def __init__(self, env_args):
        self.env_args = env_args

    def make_env(self):
        env = gym.make(**self.env_args)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        env = NoopResetEnv(env, noop_max=30)
        env = MaxAndSkipEnv(env, skip=4)
        env = ClipRewardEnv(env)
        env = gym.wrappers.ResizeObservation(env, (84, 84))
        env = gym.wrappers.GrayScaleObservation(env)
        env = gym.wrappers.FrameStack(env, 4)
        return env

class Plotter:
    @staticmethod
    def plot(history, show=False, save_path=None):
        sns.lineplot(y=history['reward'], x=list(range(len(history['reward']))))
        if save_path:
            plt.savefig(save_path)
        if show:
            plt.show()
        plt.clf()
        plt.close()

class Evaluator:
    def __init__(self, agent, env_creator, episodes=10):
        self.agent = agent
        self.env_creator = env_creator
        self.episodes = episodes

    def evaluate(self):
        envs = gym.vector.SyncVectorEnv([lambda: self.env_creator.make_env()])
        self.agent.eval()
        total_rewards = []
        next_obs, _ = envs.reset()

        while len(total_rewards) < self.episodes:
            next_obs = torch.Tensor(next_obs)
            with torch.no_grad():
                action, log_prob, _, value = self.agent.get_action_and_value(next_obs)
            next_obs, reward, terminated, truncated, info = envs.step(action.numpy())

            if 'final_info' in info:
                for data in info['final_info']:
                    if data:
                        reward = data['episode']['r'][0]
                        total_rewards.append(reward)

        return total_rewards

In [5]:
class Agent(nn.Module):
    def __init__(self, action_space, hidden_size=512):
        super().__init__()
        self.action_space = action_space

        self.network = nn.Sequential(
            nn.Conv2d(4, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Flatten()
        )
        self.fc = nn.Linear(7 * 7 * 64, hidden_size) 
        self.relu = nn.ReLU()
        self.actor = nn.Linear(hidden_size, self.action_space.n)
        self.critic = nn.Linear(hidden_size, 1)

    def forward(self, x):
        x = self.network(x)
        x = self.fc(x)
        x = self.relu(x)
        action_probs = self.actor(x)
        state_values = self.critic(x)
        return action_probs, state_values

    def get_action_and_value(self, x, action=None):
        logits, value = self.forward(x / 255.0)  
        probs = Categorical(logits=logits)
        if action is None:
            action = probs.sample()
        log_prob = probs.log_prob(action)
        entropy = probs.entropy()
        return action, log_prob, entropy, value

    def get_value(self, x):
        _, value = self.forward(x / 255.0)
        return value


In [6]:
class AgentTester:
    def __init__(self, env_creator, num_envs=1):
        self.envs = gym.vector.SyncVectorEnv(
            [lambda: env_creator.make_env() for _ in range(num_envs)]
        )
        assert isinstance(self.envs.single_action_space, gym.spaces.Discrete), 'Only discrete action is supported'

    def test_agent(self, agent):
        obs, info = self.envs.reset()
        obs = torch.tensor(obs).float()
        print('obs shape = ', obs.shape)

        action, log_prob, entropy, value = agent.get_action_and_value(obs)

        print('action shape = ', action.shape)
        print('log prob shape = ', log_prob.shape)
        print('entropy shape = ', entropy.shape)
        print('value shape = ', value.shape)

    def close_envs(self):
        self.envs.close()


# env_args = {'id': "PongDeterministic-v4"}
# env_creator = EnvironmentCreator(env_args)
# tester = AgentTester(env_creator, EnvironmentSetup.NUM_ENVS)
# agent = Agent(action_space=tester.envs.single_action_space)

# tester.test_agent(agent)
# tester.close_envs()


# del agent, tester

A.L.E: Arcade Learning Environment (version 0.8.1+53f58b7)
[Powered by Stella]


obs shape =  torch.Size([1, 4, 84, 84])
action shape =  torch.Size([1])
log prob shape =  torch.Size([1])
entropy shape =  torch.Size([1])
value shape =  torch.Size([1, 1])


In [7]:
class TrainingEnvironment:
    def __init__(self, config, num_envs, device):
        self.config = config
        self.envs = gym.vector.AsyncVectorEnv(
            [lambda: EnvironmentCreator({'id': "PongDeterministic-v4"}).make_env() for _ in range(num_envs)]
        )
        self.agent = Agent(self.envs.single_action_space).to(device)
        self.device = device
        self.optimizer = torch.optim.Adam(self.agent.parameters(), lr=self.config.lr, eps=1e-5)
        self.check_initialization()
        self.setup_directories()
        self.best_score = float('-inf')

    def check_initialization(self):
        save_path = os.path.join(self.config.output_dir, 'ppo_checkpoint.torch')
        if os.path.exists(save_path):
            print("Loading checkpoint...")
            self.agent.load_state_dict(torch.load(save_path))
            print("Checkpoint loaded.")
        else:
            print("Starting from scratch...")

    def setup_directories(self):
        self.label = str(uuid.uuid4()).split('-')[0]
        self.save_dir = os.path.join(self.config.output_dir, self.label)
        if not os.path.exists(self.save_dir):
            os.makedirs(self.save_dir)
        self.fig_save_path = os.path.join(self.save_dir, 'plot.png')

    def train(self):
        M = self.config.num_steps
        N = self.config.num_envs
        reward_window = deque(maxlen=100)
        history = defaultdict(list)
        global_step = 0
        self.best_score = float('-inf')
        loss = float('inf')


        next_obs, _ = self.envs.reset()
        next_obs = torch.tensor(next_obs, device=self.device)
        next_done = torch.zeros(N, device=self.device)


        obs = torch.zeros((M, N) + self.envs.single_observation_space.shape, device=self.device)
        actions = torch.zeros((M, N) + self.envs.single_action_space.shape, device=self.device)
        log_probs = torch.zeros((M, N), device=self.device)
        rewards = torch.zeros((M, N), device=self.device)
        dones = torch.zeros((M, N), device=self.device)
        values = torch.zeros((M, N), device=self.device)

        self.loop = tqdm(range(self.config.num_iterations), desc="Training Loop")
        for iteration in self.loop:
            if iteration % self.config.update_plots == 0:
                self.plot(history)

            for step in range(M):
                global_step += N


                obs[step] = next_obs
                dones[step] = next_done


                with torch.no_grad():
                    action, log_prob, _, value = self.agent.get_action_and_value(next_obs)
                    values[step] = value.flatten()


                actions[step] = action
                log_probs[step] = log_prob


                next_obs, reward, terminated, truncated, info = self.envs.step(action.cpu().numpy())
                next_done = torch.logical_or(torch.tensor(terminated), torch.tensor(truncated)).to(self.device)

  
                rewards[step] = torch.tensor(reward, device=self.device).view(-1)
                next_obs = torch.tensor(next_obs, device=self.device)

 
                if 'final_info' in info:
                    for data in info['final_info']:
                        if data:
                            reward = data['episode']['r']
                            reward_window.append(reward)
                            avg_reward = torch.tensor(list(reward_window)).mean().item()
                            history['reward'].append(avg_reward)
                            self.loop.set_description(f"Reward = {avg_reward:.2f}, Global Step = {global_step}, Best Score = {self.best_score:.2f}, Loss = {loss:.2f}")

                            if self.best_score < avg_reward:
                                self.best_score = avg_reward
                                torch.save(self.agent.state_dict(), os.path.join(self.save_dir, 'ppo_checkpoint_best.torch'))


            self.optimize_policy(obs, actions, log_probs, rewards, dones, values, next_obs, next_done)


        torch.save(self.agent.state_dict(), os.path.join(self.save_dir, 'final_model.pth'))
        print("Training completed. Model saved.")

    def plot(self, history):
        if history['reward']:
            plt.figure(figsize=(10, 5))
            plt.plot(range(len(history['reward'])), history['reward'], label='Rewards')
            plt.xlabel('Episode')
            plt.ylabel('Reward')
            plt.title('Training Rewards Over Time')
            plt.legend()
            plt.grid(True)
            plt.savefig(self.fig_save_path)
            plt.close()

    def optimize_policy(self, obs, actions, log_probs, rewards, dones, values, next_obs, next_done):
        M, N = self.config.num_steps, self.config.num_envs
        with torch.no_grad():

            next_value = self.agent.get_value(next_obs).reshape(1, -1)
            advantages = torch.zeros_like(rewards, device=self.device)
            last_gae_lam = 0


            for t in reversed(range(M)):
                if t == M - 1:
                    next_non_terminal = 1.0 - next_done.float()
                    next_values = next_value
                else:
                    next_non_terminal = 1.0 - dones[t + 1].float()
                    next_values = values[t + 1]

                delta = rewards[t] + self.config.gamma * next_values * next_non_terminal - values[t]
                advantages[t] = last_gae_lam = delta + self.config.gamma * self.config.gae_lambda * next_non_terminal * last_gae_lam


            returns = advantages + values


        b_obs = obs.view((-1,) + self.envs.single_observation_space.shape)
        b_actions = actions.view((-1,) + self.envs.single_action_space.shape)
        b_log_probs = log_probs.view(-1)
        b_advantages = advantages.view(-1)
        b_returns = returns.view(-1)
        b_values = values.view(-1)

        batch_size = M * N
        mini_batch_size = batch_size // self.config.mini_batch_count
        batch_indices = torch.arange(batch_size, device=self.device)

        for epoch in range(self.config.update_epochs):

            batch_indices = batch_indices[torch.randperm(batch_size)]

            for start in range(0, batch_size, mini_batch_size):
                end = start + mini_batch_size
                mini_indices = batch_indices[start:end]


                b_obs_mini = b_obs[mini_indices]
                b_actions_mini = b_actions[mini_indices]
                b_log_probs_old_mini = b_log_probs[mini_indices]
                b_advantages_mini = b_advantages[mini_indices]
                b_returns_mini = b_returns[mini_indices]


                _, new_log_prob, entropy, new_value = self.agent.get_action_and_value(b_obs_mini, b_actions_mini)


                log_ratio = new_log_prob - b_log_probs_old_mini
                ratio = torch.exp(log_ratio)
                surr1 = ratio * b_advantages_mini
                surr2 = torch.clamp(ratio, 1.0 - self.config.clip_coef, 1.0 + self.config.clip_coef) * b_advantages_mini
                policy_loss = -torch.min(surr1, surr2).mean()


                value_loss = 0.5 * (new_value.view(-1) - b_returns_mini).pow(2).mean()


                loss = policy_loss + self.config.vf_coef * value_loss - self.config.entropy_coef * entropy.mean()


                self.optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.agent.parameters(), self.config.max_grad_norm)
                self.optimizer.step()


                clip_fraction = (torch.abs(ratio - 1.0) > self.config.clip_coef).float().mean().item()


            self.loop.set_description(f"Loss: {loss.item():.4f}, Clip Frac: {clip_fraction:.2f}, Best Score: {self.best_score:.2f}")




In [8]:
config = TrainingConfig()
env_creator = EnvironmentCreator({'id': "PongDeterministic-v4"})
device = 'cuda' if torch.cuda.is_available() else 'cpu'
training_environment = TrainingEnvironment(config, num_envs=1, device=device)
training_environment.train()


Starting from scratch...


Loss: -0.0153, Clip Frac: 0.29, Best Score: -2.25: 100%|██████████| 800/800 [7:49:44<00:00, 35.23s/it]                       

Training completed. Model saved.



