In [1]:
from time import sleep

import numpy as np
import torch
from torch import nn
import gym
from gym.envs.registration import register
from stable_baselines3 import *
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
import wandb
from wandb.integration.sb3 import WandbCallback

from gsnake.env import GoogleSnakeEnv
from gsnake.configs import GoogleSnakeConfig

# Expose the snake environment under a Gym id so it can be built by name
# (make_vec_env below looks it up as "GoogleSnake-v1"). The entry point is the
# env class itself; max_episode_steps=1000 is the per-episode step limit
# passed to Gym at registration.
register(id='GoogleSnake-v1', entry_point=GoogleSnakeEnv, max_episode_steps=1000)

pygame 2.1.0 (SDL 2.0.16, Python 3.9.13)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
# Experiment identifier: used as the W&B run name, as the source of the
# algorithm tag (text before the first underscore), and as the checkpoint
# file stem when the model is saved.
name = 'PPO_MLP_time_nch_obsfix_50M'

# Environment configuration shared by every training env.
config = GoogleSnakeConfig(
    # alternative reward mode tried previously: reward_mode='basic'
    multi_channel=True,
    reward_mode='time_constrained',
    reward_scale=1,
    n_foods=3
)

# Start the Weights & Biases run; the env config is logged as the run config
# and TensorBoard scalars written by SB3 are mirrored to W&B.
run = wandb.init(
    job_type='train', config=config.__dict__,
    project='RL2',
    tags=[name.split('_')[0], 'gsnake'],
    name=name,
    sync_tensorboard=True,
    # Must be a real boolean: the string 'False' is truthy and would turn
    # Gym monitoring ON instead of off.
    monitor_gym=False
)
# Parallel environments: 10 copies of the registered env, each constructed
# with the shared `config`.
env = make_vec_env("GoogleSnake-v1", n_envs=10, env_kwargs={'config': config})

# Turn off SB3's built-in image normalisation for the policy's observations.
# NOTE(review): presumably the env already emits suitably scaled observations - confirm.
policy_kwargs = {'normalize_images': False}
model = PPO(
    "MultiInputPolicy",
    env,
    policy_kwargs=policy_kwargs,
    verbose=0,
    tensorboard_log=f'runs/{run.id}',
)
# Alternative policy tried previously:
# model = PPO("CnnPolicy", env, verbose=1, tensorboard_log=f'runs/{run.id}')

try:
    model.learn(
        total_timesteps=50_000_000,
        callback=WandbCallback(verbose=2),
        progress_bar=True,
    )
finally:
    # Always close the W&B run, even when the (very long) training is
    # interrupted or raises, so the run is not left dangling as "running".
    run.finish()

# Only reached when training completed; the checkpoint is named after the run.
model.save(f'{name}.pt')

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mdev-jahn[0m. Use [1m`wandb login --relogin`[0m to force relogin


Output()

0,1
global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
rollout/ep_len_mean,▁▂▄▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇███▇▇██▇██▇▇███
rollout/ep_rew_mean,▁▁▂▃▃▄▄▄▅▅▅▅▆▆▅▆▆▆▆▇▆▆▇▇▇▇▇▇▇▇▇▇▇██▇██▇█
time/fps,▁▂▄▆▆▇▇▇▇▇▇██████████████████▇▆▆▅▅▅▄▄▄▃▃
train/approx_kl,█▅▅▄▄▃▂▃▂▃▃▄▃▃▄▄▄▂▃▂▂▂▂▂▁▂▂▂▁▃▂▁▂▂▂▁▁▁▂▁
train/clip_fraction,█▄▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▂▁▁▁▂▂▁▁▁▁▁▂▁▁▁▁▁▂▁
train/clip_range,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/entropy_loss,▁▇██████████████████████████████████████
train/explained_variance,▁▂▄▃▅▅▅▆▇▆▇▇▆█▇█▇▆▇▆▆▆▆▇▇▇█▇▇▇▇██▇█▇█▇██
train/learning_rate,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
global_step,50012160.0
rollout/ep_len_mean,459.89001
rollout/ep_rew_mean,2988.29004
time/fps,1269.0
train/approx_kl,0.00398
train/clip_fraction,0.02447
train/clip_range,0.2
train/entropy_loss,-0.08138
train/explained_variance,0.90449
train/learning_rate,0.0003


In [6]:
####################################################################
# Human evaluation
####################################################################
# Load a saved agent and watch it play in the GUI until the kernel is
# interrupted (KeyboardInterrupt).
# NOTE(review): this loads "PPO_MLP_time.pt", not the checkpoint written by the
# training cell (f'{name}.pt') - confirm this is the intended model.
model = PPO.load("PPO_MLP_time.pt")
config = GoogleSnakeConfig(
    # alternative reward mode tried previously: reward_mode='basic'
    multi_channel=True,
    reward_mode='time_constrained',
    reward_scale=1,
    n_foods=3
)
env = GoogleSnakeEnv(config, 42, 'gui')
try:
    # Reset inside the try so the env is closed even if the first reset fails.
    obs = env.reset()
    while True:
        action, _states = model.predict(obs)
        obs, rewards, dones, info = env.step(action)
        env.render()
        sleep(0.5)
        # A bare (non-vectorised) env does not reset itself: start a new
        # episode once this one has ended instead of stepping a finished env.
        # Done after render/sleep so the final frame stays visible.
        if dones:
            obs = env.reset()
except KeyboardInterrupt:
    print('Terminated')
finally:
    env.close()

Terminated
