In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import yaml
import tempfile
import os
from pathlib import Path

import matplotlib.pyplot as plt

from stable_baselines3.common.results_plotter import plot_results, X_TIMESTEPS

In [None]:
# Run-wide settings shared by the training, plotting, evaluation and upload cells.
algo = "dqn"  # passed as --algo to every rl_zoo3 command below
env = "SpaceInvadersNoFrameskip-v4"  # environment id; also the top-level key of the YAML config
log_dir = "logs/"  # folder given to rl_zoo3.train (-f) and rl_zoo3.enjoy (--folder)
repo = f"{algo}-{env}"  # passed as --repo-name to rl_zoo3.push_to_hub
hf_username = "AntonDergunov"  # passed as --orga to rl_zoo3.push_to_hub
device = "mps"  # passed as --device to training; NOTE(review): presumably Apple-silicon only, adjust on other hardware

In [None]:
# Hyperparameters for this run. They are serialised to YAML below and the
# resulting file is handed to rl_zoo3.train through its -c option.
dqn_hyperparams = dict(
    env_wrapper=["stable_baselines3.common.atari_wrappers.AtariWrapper"],
    frame_stack=4,
    policy="CnnPolicy",
    n_timesteps=1e7,
    buffer_size=100_000,
    learning_rate=1e-4,
    batch_size=32,
    learning_starts=1_000,
    target_update_interval=1000,
    train_freq=4,
    gradient_steps=1,
    exploration_fraction=0.1,
    exploration_final_eps=0.01,
    optimize_memory_usage=False,
)

# The file is keyed by environment id, with that environment's settings nested below it.
config = {env: dqn_hyperparams}

# Fixed file name inside the system temp directory: re-running this cell
# overwrites the previous config instead of accumulating files.
tmpdir = tempfile.gettempdir()
yaml_path = os.path.join(tmpdir, "space_invaders_dqn.yml")

with open(yaml_path, "w") as yaml_file:
    yaml.dump(config, stream=yaml_file)

In [None]:
# Launch rl_zoo3 training as a subprocess, using the algorithm, environment,
# log folder, YAML config file and device chosen in the cells above.
!python -m rl_zoo3.train \
    --algo {algo} \
    --env {env} \
    -f {log_dir} \
    -c {yaml_path} \
    --device {device}

In [None]:
# Pick the most recently modified run folder under <log_dir>/<algo>/ whose
# name starts with "<env>_" and plot its results against timesteps.
base = Path(log_dir, algo)
runs = list(base.glob(f"{env}_*"))
# Oldest first, so the newest run ends up last.
runs.sort(key=lambda run_dir: run_dir.stat().st_mtime)

if not runs:
    print("No runs found.")
else:
    latest = runs[-1]
    print("Plotting:", latest)
    plot_results(
        [str(latest)],
        10_000_000,
        X_TIMESTEPS,
        "DQN Results",
        figsize=(8, 4),
    )
    plt.show()

In [None]:
# Run rl_zoo3.enjoy on the log folder for 5000 timesteps with rendering
# disabled. NOTE(review): presumably this loads the latest saved run for
# this algo/env pair; confirm against the rl_zoo3 documentation.
!python -m rl_zoo3.enjoy \
    --algo {algo} \
    --env {env} \
    --no-render \
    --n-timesteps 5000 \
    --folder {log_dir}

In [None]:
# Upload the trained agent to the Hugging Face Hub as <hf_username>/<repo>.
# -f must be the same folder training wrote to, i.e. `log_dir`; that is the
# only log-folder variable this notebook defines.
!python -m rl_zoo3.push_to_hub \
    --algo {algo} \
    --env {env} \
    --repo-name {repo} \
    --orga {hf_username} \
    -f {log_dir}