In [1]:
from time import sleep

import numpy as np
import torch
from torch import nn
import gym
from gym.envs.registration import register
from stable_baselines3 import *
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
import wandb
from wandb.integration.sb3 import WandbCallback

from gsnake.env import GoogleSnakeEnv
from gsnake.configs import GoogleSnakeConfig

# Expose the snake game to `gym.make` / `make_vec_env` under this id.
# max_episode_steps makes gym cap every episode at 500 steps.
register('GoogleSnake-v1', entry_point=GoogleSnakeEnv, max_episode_steps=500)


class CustomCNN(BaseFeaturesExtractor):
    """Small convolutional feature extractor for stable-baselines3 policies.

    Architecture:
    - A 3x3 conv with 64 filters and padding=1, which keeps the spatial size.
    - A second 3x3 conv with 64 filters and no padding, which shrinks each
      spatial dimension by 2.
    - A flatten, then ``Linear -> ReLU -> Linear`` down to ``features_dim``.

    Observations are expected channels-first (C x H x W). Any re-ordering
    must happen upstream, in preprocessing or an env wrapper.

    :param observation_space: Box space of the observations; ``shape[0]`` is
        used as the number of input channels.
    :param features_dim: length of the feature vector returned by ``forward``.
    """

    def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 256):
        super().__init__(observation_space, features_dim)
        in_channels = observation_space.shape[0]
        conv_layers = [
            nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.Flatten(),
        ]
        self.cnn = nn.Sequential(*conv_layers)

        # Infer the flattened conv output size from one dummy forward pass on a
        # sampled observation ([None] adds the batch dimension).
        with torch.no_grad():
            dummy_batch = torch.as_tensor(observation_space.sample()[None]).float()
            flat_size = self.cnn(dummy_batch).shape[1]

        self.linear = nn.Sequential(nn.Linear(flat_size, 64), nn.ReLU())
        self.linear2 = nn.Sequential(nn.Linear(64, features_dim))

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        """Map a batch of observations to ``features_dim``-sized feature vectors."""
        conv_features = self.cnn(observations)
        hidden = self.linear(conv_features)
        return self.linear2(hidden)

# policy_kwargs for plugging CustomCNN into an SB3 policy, e.g.
# DQN(<policy>, env, policy_kwargs=cnnpolicy_kwargs).
# normalize_images=False disables SB3's /255 scaling of image-like observations.
# Presumably the snake observations are not raw uint8 pixels (TODO confirm).
# NOTE(review): the training cell builds DQN("MlpPolicy", ...) without
# policy_kwargs, so this dict is currently not used.
cnnpolicy_kwargs = {
    'features_extractor_class': CustomCNN,
    'features_extractor_kwargs': {'features_dim': 128},
    'normalize_images': False,
}

pygame 2.1.0 (SDL 2.0.16, Python 3.9.13)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
# Environment / reward configuration for training.
config = GoogleSnakeConfig(
    multi_channel=True,
    # alternative: reward_mode='basic',
    reward_mode='time_constrained',
    reward_scale=1,
    n_foods=3
)
# The first '_'-separated token of the run name ('DQN') is also used as a wandb tag.
name = 'DQN_MLP_time_nch'
run = wandb.init(
    job_type='train',
    config=config.__dict__,
    project='RL2',
    tags=[name.split('_')[0], 'gsnake'],
    name=name,
    sync_tensorboard=True,
    # Must be a real bool. The string 'False' is truthy, so wandb would treat it
    # as monitor_gym=True.
    monitor_gym=False,
)
# Parallel environments: 10 copies of the registered env, vectorised by SB3.
env = make_vec_env("GoogleSnake-v1", n_envs=10, env_kwargs={'config': config})
model = DQN("MlpPolicy", env, verbose=0, tensorboard_log=f'runs/{run.id}')

try:
    model.learn(total_timesteps=1_000_000, callback=WandbCallback(verbose=2), progress_bar=True)
finally:
    # Close the wandb run and the training envs even if learning is interrupted
    # (e.g. "Interrupt kernel"), so the next run does not inherit an open run.
    run.finish()
    env.close()

model.save(f'{name}.pt')
del model  # free the in-memory model; reload it from f'{name}.pt' when needed

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mdev-jahn[0m. Use [1m`wandb login --relogin`[0m to force relogin


Output()

0,1
global_step,▁▁▁▄▄▄▄▄▄▄▅▅▅▅▅▅▅▅▅▆▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇███
rollout/ep_len_mean,▃▃█▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▃
rollout/ep_rew_mean,▃▆▄▅▄▃▃▄▃▄▁▂▂▅▄▅▅▃▂▂▄▄▃▄▃▄▃█▃▃▃▁▂▄▃▃▅▆▃▅
rollout/exploration_rate,█▆▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
time/fps,▇▆▅█▇▆▆▅▅▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁
train/learning_rate,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/loss,█▂▂▁▁▂▁▁▄▁▁▂▂▁▅▁▁▁▂▅▁▅▁▅▁▅▁▂▄▇▂▂█▁▁▅▂▂▂▅

0,1
global_step,999980.0
rollout/ep_len_mean,25.22
rollout/ep_rew_mean,5.96
rollout/exploration_rate,0.05
time/fps,1446.0
train/learning_rate,0.0001
train/loss,1.42113


In [6]:
####################################################################
# Human evaluation
####################################################################
# Watch a trained agent play in the GUI until interrupted ("Interrupt kernel" / Ctrl-C).
# NOTE(review): this loads a PPO checkpoint, not the DQN model saved by the training
# cell ('DQN_MLP_time_nch.pt'). Check that the checkpoint was trained with the same
# observation settings as `config` below (multi_channel=True) — TODO confirm.
model = PPO.load("PPO_MLP_time.pt")
config = GoogleSnakeConfig(
    # alternative: reward_mode='basic',
    multi_channel=True,
    reward_mode='time_constrained',
    reward_scale=1,
    n_foods=3
)
# NOTE(review): 42 is presumably the seed and 'gui' the render mode — confirm
# against GoogleSnakeEnv's signature.
env = GoogleSnakeEnv(config, 42, 'gui')
obs = env.reset()
try:
    while True:
        action, _states = model.predict(obs)
        obs, rewards, dones, info = env.step(action)
        env.render()
        sleep(0.5)
        if dones:
            # The episode is over. Start a new one instead of stepping a
            # finished episode, whose behaviour the gym API leaves undefined.
            obs = env.reset()
except KeyboardInterrupt:
    print('Terminated')
finally:
    env.close()

Terminated
