In [4]:
# Install the packages listed in requirements.txt into the environment of the running kernel
# (`%pip`, unlike `!pip`, targets the kernel's interpreter). Restart the kernel afterwards so
# newly installed versions are actually imported.
%pip install -r requirements.txt

/bin/bash: /home/henriquesabino/miniconda3/lib/libtinfo.so.6: no version information available (required by /bin/bash)
Note: you may need to restart the kernel to use updated packages.


In [5]:
import gym    
import slimevolleygym
from slimevolleygym import SurvivalRewardEnv
import numpy as np
from gym.wrappers.gray_scale_observation import GrayScaleObservation
from gym.wrappers.resize_observation import ResizeObservation
from atari_wrappers import RenderWrapper, BufferWrapper, ImageToPyTorch

# Frame preprocessing settings: frames are resized to INPUT_SHAPE and the last
# WINDOW_LENGTH frames are stacked into one observation by BufferWrapper.
INPUT_SHAPE = (84, 84)
WINDOW_LENGTH = 4

env_name = 'SlimeVolleyNoFrameskip-v0'

# Training environment. Each entry wraps the result of the previous one, so
# SurvivalRewardEnv ends up innermost and BufferWrapper outermost.
_training_wrappers = (
    SurvivalRewardEnv,
    RenderWrapper,
    lambda inner: ResizeObservation(inner, INPUT_SHAPE),
    lambda inner: GrayScaleObservation(inner, True),  # positional arg is keep_dim
    ImageToPyTorch,
    lambda inner: BufferWrapper(inner, WINDOW_LENGTH, np.uint8),
)

env = gym.make(env_name)
for _wrap in _training_wrappers:
    env = _wrap(env)

In [6]:
import os
from stable_baselines3 import DQN
from stable_baselines3.common.callbacks import CheckpointCallback, EveryNTimesteps
from train_metrics import ModelMetricsCallback
from gym.wrappers.record_video import RecordVideo

def createOrLoadModel(model_prefix, model_suffix, models_dir='./models/'):
    """Return the newest saved DQN checkpoint, or a fresh model if none can be used.

    Checkpoints are looked up in ``models_dir`` and must be named
    ``f"{model_prefix}{steps}{model_suffix}"``, where ``steps`` is a decimal integer.
    Every other file in that directory (for example metrics written alongside the
    checkpoints) is ignored instead of aborting the resume.

    Args:
        model_prefix: Filename text before the step count, e.g. ``'dqn_'``.
        model_suffix: Filename text after the step count, e.g. ``'_steps.zip'``.
        models_dir: Directory to search. Defaults to ``'./models/'``.

    Returns:
        A ``(model, steps)`` tuple: the loaded model and the step count parsed from
        its filename, or ``(defaultDQNModel(), 0)`` when there is no usable checkpoint.
    """
    if not os.path.isdir(models_dir):
        return defaultDQNModel(), 0

    # Map step count -> filename for every file matching the checkpoint pattern.
    checkpoints = {}
    for file_name in os.listdir(models_dir):
        if not (file_name.startswith(model_prefix) and file_name.endswith(model_suffix)):
            continue
        step_text = file_name[len(model_prefix):len(file_name) - len(model_suffix)]
        if step_text.isdecimal():
            checkpoints[int(step_text)] = file_name

    if not checkpoints:
        return defaultDQNModel(), 0

    latest_steps = max(checkpoints)
    # os.listdir returns bare names, so the path must be re-joined with the directory;
    # loading the bare name would resolve against the current working directory.
    model_path = os.path.join(models_dir, checkpoints[latest_steps])
    print(f'Loading {model_path}')
    try:
        # Attach the training env: a model loaded without one cannot continue learning.
        return DQN.load(model_path, env=env), latest_steps
    except Exception as exc:  # keep the best-effort fallback, but make it visible
        print(f'Could not load {model_path} ({exc!r}); starting a fresh model')
        return defaultDQNModel(), 0

def defaultDQNModel(buffer_size=50000):
    """Build a fresh, untrained DQN agent with a CNN policy on the training env ``env``.

    Args:
        buffer_size: Replay-buffer capacity in transitions. Defaults to 50000, far
            below SB3's own default, to keep memory use of the stacked-frame
            observations bounded.

    Returns:
        A new ``stable_baselines3.DQN`` instance with ``verbose=0``.
    """
    return DQN("CnnPolicy", env, verbose=0, buffer_size=buffer_size)
    
# Training budget and checkpoint naming. CheckpointCallback writes files named
# f"{CHECKPOINT_NAME_PREFIX}_{num_timesteps}_steps.zip", so the prefix/suffix used to
# locate the newest checkpoint are derived from the same constant.
TOTAL_TIMESTEPS = 1_000_000
CHECKPOINT_NAME_PREFIX = 'dqn'

model, steps_done = createOrLoadModel(f'{CHECKPOINT_NAME_PREFIX}_', '_steps.zip')

# Evaluation environment used by ModelMetricsCallback.
# NOTE(review): unlike `env`, this pipeline skips SurvivalRewardEnv (presumably so the
# metrics see the unshaped game reward) and RenderWrapper. If RenderWrapper changes the
# observations the policy receives, evaluation inputs will not match training - confirm.
eval_env = gym.make(env_name)
eval_env = ResizeObservation(eval_env, INPUT_SHAPE)
eval_env = GrayScaleObservation(eval_env, True)
eval_env = ImageToPyTorch(eval_env)
eval_env = BufferWrapper(eval_env, WINDOW_LENGTH, np.uint8)

checkpoint_callback = CheckpointCallback(save_freq=10000, save_path='./models', name_prefix=CHECKPOINT_NAME_PREFIX)
model_metrics_callback = ModelMetricsCallback(eval_env, './models', num_episodes=30, verbose=0)
# Trigger the metrics evaluation once every 1000 environment steps.
metrics_callback = EveryNTimesteps(n_steps=1000, callback=model_metrics_callback)

callbacks = [checkpoint_callback, metrics_callback]

# reset_num_timesteps=False keeps the loaded model's step counter. As a result, checkpoint
# names keep increasing instead of restarting at 10000 and overwriting older files, and the
# epsilon-greedy exploration schedule resumes instead of restarting from its initial value.
# In that mode SB3 adds the current counter to total_timesteps, so passing only the remaining
# (integer) budget stops training at TOTAL_TIMESTEPS.
remaining_timesteps = max(0, TOTAL_TIMESTEPS - steps_done)
model = model.learn(
    total_timesteps=remaining_timesteps,
    log_interval=4,
    callback=callbacks,
    reset_num_timesteps=False,
)

# Play one episode with the deterministic (greedy) policy through a RecordVideo wrapper
# whose output folder is ./videos, then close the wrapper (and the env it wraps).
record_env = RecordVideo(env, './videos')
obs = record_env.reset()

while True:
    action, _ = model.predict(obs, deterministic=True)
    obs, reward, done, info = record_env.step(int(action))
    if done:
        break
record_env.close()

KeyboardInterrupt: 