In [12]:
import torch as th
import torch.nn as nn
import gymnasium as gym

from stable_baselines3 import PPO
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor


class CustomCNN(BaseFeaturesExtractor):
    """
    Convolutional features extractor for channels-first observations.

    The network is three 3x3 convolutions (32, 64 and 64 output channels),
    a 2x2 average pooling and a flatten, followed by a linear layer with
    ReLU that projects the flattened activations to ``features_dim``.

    :param observation_space: (gym.Space) Box space; its first dimension is
        read as the number of input channels.
    :param features_dim: (int) Number of features extracted.
        This corresponds to the number of units of the last layer.
    """

    def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 256):
        super().__init__(observation_space, features_dim)
        # We assume CxHxW images (channels first)
        # Re-ordering will be done by pre-preprocessing or wrapper
        n_input_channels = observation_space.shape[0]
        self.cnn = nn.Sequential(
            nn.Conv2d(n_input_channels, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            # padding=0 here: this layer shrinks each spatial dimension by 2.
            # There is no activation between this convolution and the pooling.
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
            nn.AvgPool2d(2),
            nn.Flatten(),
        )

        # Compute the flattened size by doing one forward pass on a sampled
        # observation (a batch dimension is added with [None]).
        with th.no_grad():
            n_flatten = self.cnn(
                th.as_tensor(observation_space.sample()[None]).float()
            ).shape[1]

        self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())

    def forward(self, observations: th.Tensor) -> th.Tensor:
        """
        Extract features from a batch of observations.

        :param observations: (th.Tensor) Batch fed to the convolutional stack.
        :return: (th.Tensor) Output of the final linear + ReLU layer.
        """
        # The convolutional stack is evaluated exactly once per call; the
        # former debug print evaluated it a second time on every forward pass.
        return self.linear(self.cnn(observations))

# Keyword arguments handed to the policy: use CustomCNN as the features
# extractor (built with features_dim=128) and the layer widths listed in
# ``net_arch``.
policy_kwargs = {
    "features_extractor_class": CustomCNN,
    "features_extractor_kwargs": {"features_dim": 128},
    "net_arch": [16, 1024, 32, 32, 16, 4],
}
# model = PPO("CnnPolicy", "BreakoutNoFrameskip-v4", policy_kwargs=policy_kwargs, verbose=1)
# model.learn(1000)

In [2]:
from M2048Cnn import M2048
from stable_baselines3 import DQN
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
from stable_baselines3.common.vec_env import SubprocVecEnv
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor
from sb3_contrib import MaskablePPO
from sb3_contrib.common.wrappers import ActionMasker
import numpy as np

def linear_schedule(initial_value, final_value=0.0):
    """
    Build a schedule that interpolates linearly between two values.

    The returned callable maps ``progress`` to
    ``final_value + progress * (initial_value - final_value)``, i.e. it
    yields ``initial_value`` at ``progress == 1`` and ``final_value`` at
    ``progress == 0``.

    :param initial_value: (float or str) Value returned at progress 1.
        A string is parsed with ``float`` and must then be strictly positive.
    :param final_value: (float or str) Value returned at progress 0.
        A string is parsed with ``float``.
    :return: (callable) Function of one argument ``progress``.
    :raises AssertionError: if ``initial_value`` is a string that parses to a
        non-positive number.
    """
    if isinstance(initial_value, str):
        initial_value = float(initial_value)
        final_value = float(final_value)
        # Positivity is only checked for string input.
        assert (initial_value > 0.0)
    elif isinstance(final_value, str):
        # A string final value combined with a numeric initial value used to
        # be left unconverted and failed with TypeError inside the scheduler.
        final_value = float(final_value)

    def scheduler(progress):
        return final_value + progress * (initial_value - final_value)

    return scheduler

# --- Configuration ---------------------------------------------------------
# NOTE(review): NUM_ENVS is only referenced by the commented-out
# SubprocVecEnv line; the DummyVecEnv below wraps a single environment.
NUM_ENVS = 32
# True: build a fresh MaskablePPO model; False: resume from RESUME_CHECKPOINT.
TRAIN_FROM_SCRATCH = False
RESUME_CHECKPOINT = "./model/2048_14.pkl"
N_TRAINING_ROUNDS = 100
# Integer step budget per learn() call (previously the float 1e5).
TIMESTEPS_PER_ROUND = 100_000


def make_env(seed=None):
    """Create one M2048 environment wrapped in a Monitor."""
    return Monitor(M2048(4, seed=seed))


# --- Environment -----------------------------------------------------------
# env = SubprocVecEnv([make_env(seed=s) for s in np.random.randint(1,1e9, NUM_ENVS)])
env = DummyVecEnv([make_env])
# env = ActionMasker(env, MinesweeperEnv.get_action_mask)

# --- Schedules (only passed to the model in the from-scratch branch) --------
lr_schedule = linear_schedule(2.5e-2, 2.5e-5)
clip_range_schedule = linear_schedule(0.15, 0.025)

# --- Model -----------------------------------------------------------------
if TRAIN_FROM_SCRATCH:
    model = MaskablePPO(
        "MlpPolicy",
        env=env,
        batch_size=2048 * 16,
        policy_kwargs=policy_kwargs,
        verbose=1,
        tensorboard_log="./tensorboard/",
        learning_rate=lr_schedule,
        clip_range=clip_range_schedule,
        device='cuda',
    )
else:
    model = MaskablePPO.load(RESUME_CHECKPOINT, env=env)

# --- Training --------------------------------------------------------------
# NOTE(review): checkpoints are numbered from 0 regardless of the checkpoint
# that was loaded, so round 14 overwrites "./model/2048_14.pkl" (the resume
# file) and earlier rounds overwrite earlier checkpoints — confirm intended.
for i in range(N_TRAINING_ROUNDS):
    model.learn(total_timesteps=TIMESTEPS_PER_ROUND)
    model.save(f"./model/2048_{i}.pkl")


Logging to ./tensorboard/PPO_24


---------------------------------
| rollout/           |          |
|    ep_len_mean     | 169      |
|    ep_rew_mean     | 1.64e+03 |
| time/              |          |
|    fps             | 778      |
|    iterations      | 1        |
|    time_elapsed    | 2        |
|    total_timesteps | 2048     |
---------------------------------
------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 163          |
|    ep_rew_mean          | 1.52e+03     |
| time/                   |              |
|    fps                  | 425          |
|    iterations           | 2            |
|    time_elapsed         | 9            |
|    total_timesteps      | 4096         |
| train/                  |              |
|    approx_kl            | 0.0021813652 |
|    clip_fraction        | 0.0508       |
|    clip_range           | 0.147        |
|    entropy_loss         | -0.862       |
|    explained_variance   | 0            |
|    learning_r

In [None]:
# `model` is a MaskablePPO, so evaluate it with the maskable variant of
# evaluate_policy from sb3_contrib: the generic stable_baselines3 helper calls
# predict() without action masks. Imported under an alias so the generic
# `evaluate_policy` name imported earlier in the notebook is not shadowed.
from sb3_contrib.common.maskable.evaluation import (
    evaluate_policy as maskable_evaluate_policy,
)

mean_reward, std_reward = maskable_evaluate_policy(model, env, n_eval_episodes=100)
# The environment is closed here, so the training cell must be re-run before
# this cell can be executed again.
env.close()
mean_reward, std_reward