In [None]:
# pip install -i https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple/ --upgrade "sssnake"
# pip install stable_baselines3 matplotlib
# Saving the .mp4 at the end also needs ffmpeg (with libx264) available on PATH.

In [None]:
import sssnake

In [None]:
import gymnasium as gym

# Base Snake env; "rgb_array" mode makes env.render() return image frames as arrays.
base_env = gym.make(id="Sssnake-v0", render_mode="rgb_array")

In [None]:
class BasicObservationWrapper(gym.ObservationWrapper):
    """Reduce a Dict observation to a fixed subset of its keys.

    Only the keys in ``keys_to_keep`` (by default "head_position",
    "head_direction_vec" and "candy_position") survive; every other entry of
    the wrapped env's observation is dropped, and the advertised
    ``observation_space`` is narrowed to match.
    """

    def __init__(self, env: gym.Env, keys_to_keep=("head_position",
                                                   "head_direction_vec",
                                                   "candy_position")):
        super().__init__(env)
        self._keys = tuple(keys_to_keep)
        # Advertise only the kept sub-spaces so downstream wrappers
        # (e.g. FlattenObservation) size themselves to the reduced observation.
        full_space = env.observation_space
        self.observation_space = gym.spaces.Dict(
            {key: full_space[key] for key in self._keys}
        )

    def observation(self, obs):
        """Return a new dict containing only the kept keys of ``obs``."""
        return {key: obs[key] for key in self._keys}

In [None]:
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv

In [None]:
# Number of parallel training envs. With a single env, an in-process DummyVecEnv
# avoids the subprocess/IPC overhead that SubprocVecEnv would add for no gain.
N_ENVS = 1


def make_snake_env():
    """Build a fresh Sssnake env with the reduced observation, flattened to a 1-D vector.

    make_vec_env calls this once per sub-env, so each slot gets its own
    independent env instead of sharing (or pickling) one pre-built instance.
    """
    # Importing sssnake is what makes "Sssnake-v0" available to gym.make (see the
    # top-level `import sssnake`); repeat it here so subprocess workers started
    # with spawn/forkserver also have the id registered.
    import sssnake  # noqa: F401

    env = gym.make("Sssnake-v0", render_mode="rgb_array")
    env = BasicObservationWrapper(env)
    return gym.wrappers.FlattenObservation(env)


vec_env = make_vec_env(
    make_snake_env,
    n_envs=N_ENVS,
    vec_env_cls=SubprocVecEnv if N_ENVS > 1 else DummyVecEnv,
)

In [None]:
# Fixed seed so training (policy init, action sampling, env resets) is reproducible.
SEED = 0
model = PPO("MlpPolicy", vec_env, verbose=1, device="cpu", seed=SEED)

In [None]:
TOTAL_TIMESTEPS = 500_000  # training budget, in environment steps
model.learn(total_timesteps=TOTAL_TIMESTEPS)

In [None]:
MODEL_PATH = "ppo_snake_candy"  # SB3 writes this as a .zip archive
model.save(MODEL_PATH)

In [None]:
import matplotlib as mpl
# Allow embedded HTML/JS animations up to 100 MB (matplotlib's default cap is 20 MB),
# since a full episode rendered frame-by-frame can be large.
mpl.rcParams.update({"animation.embed_limit": 100})

In [None]:
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import HTML

# Safety cap: stop recording if the episode never terminates or truncates on its own.
MAX_RENDER_STEPS = 10_000

# Same observation pipeline as training, plus episode statistics for reporting.
render_env = gym.wrappers.RecordEpisodeStatistics(
    gym.make("Sssnake-v0", render_mode="rgb_array")
)
render_env = BasicObservationWrapper(render_env)
render_env = gym.wrappers.FlattenObservation(render_env)

obs, _ = render_env.reset()
# Capture the starting state too, so the animation begins before the first move.
frames = [render_env.render()]
done = False
info = {}
step_count = 0

while not done and step_count < MAX_RENDER_STEPS:
    action, _ = model.predict(obs, deterministic=False)
    obs, _reward, terminated, truncated, info = render_env.step(action)
    done = terminated or truncated
    frames.append(render_env.render())
    step_count += 1

# Release the env's rendering resources; the frames are already collected.
render_env.close()

print("Frames for this episode: ", len(frames))
# RecordEpisodeStatistics adds an "episode" entry to info when the episode ends.
if "episode" in info:
    print("Episode return:", info["episode"]["r"], "length:", info["episode"]["l"])

In [None]:
%matplotlib inline

fig = plt.figure(figsize=(6, 6))
plt.axis("off")
im = plt.imshow(frames[0])

def update(frame):
    im.set_array(frame)
    return [im]

ani = animation.FuncAnimation(fig, update, frames=frames, interval=20)
HTML(ani.to_jshtml())


In [None]:
# Encode as H.264 MP4 with matplotlib's movie writer (ffmpeg by default; must be installed).
# 50 fps matches the 20 ms frame interval used for the inline animation.
ani.save(
    "snake_animation.mp4",
    fps=50,
    extra_args=["-vcodec", "libx264"],
)