In [1]:
import datetime
import os
import argparse
import torch

from rlpyt.runners.minibatch_rl import MinibatchRlEval, MinibatchRl
from rlpyt.samplers.serial.sampler import SerialSampler
from rlpyt.utils.logging.context import logger_context
from rlpyt.envs.atari.atari_env import AtariTrajInfo

from dreamer_agent import AtariDreamerAgent
from algorithm import Dreamer
from envs.atari import AtariEnv
from envs.wrapper import make_wapper
from envs.one_hot import OneHotAction
from envs.time_limit import TimeLimit




def build_and_train(
    log_dir,
    game="pong",
    run_ID=0,
    cuda_idx=None,
    eval=False,
    save_model="last",
    load_model_path=None,
    action_repeat=2,
    n_steps=5e6,
    log_interval_steps=1e3,
):
    """Build a Dreamer agent for an Atari game and train it with rlpyt.

    Args:
        log_dir: Root directory handed to ``logger_context`` for logs and
            snapshots.
        game: Atari game name forwarded to ``AtariEnv`` as ``name``.
        run_ID: Run identifier passed to ``logger_context``.
        cuda_idx: CUDA device index placed in the runner's ``affinity``;
            ``None`` presumably selects the CPU (rlpyt convention — confirm).
        eval: If True, train with ``MinibatchRlEval``; otherwise with
            ``MinibatchRl``. (Shadows the builtin ``eval``; the name is kept so
            existing keyword callers keep working.)
        save_model: ``snapshot_mode`` for the logger, e.g. ``"last"``.
        load_model_path: Optional path to a checkpoint containing
            ``"agent_state_dict"`` and/or ``"optimizer_state_dict"``. Any falsy
            value (``None``, ``False``, ``""``) trains from scratch.
        action_repeat: Value forwarded to ``AtariEnv``; the ``TimeLimit``
            duration is ``1000 / action_repeat`` so that
            ``duration * action_repeat`` stays 1000.
        n_steps: Total number of sampler steps the runner trains for.
        log_interval_steps: Steps between diagnostic logs / snapshots.
    """
    # Map checkpoint tensors to the CPU so a snapshot written from a GPU run can
    # be resumed on a CPU-only machine or a different GPU. PyTorch's
    # Optimizer.load_state_dict casts state to each parameter's device; the
    # agent presumably moves its model to the target device after loading
    # (rlpyt behaviour — confirm).
    params = (
        torch.load(load_model_path, map_location="cpu") if load_model_path else {}
    )
    agent_state_dict = params.get("agent_state_dict")
    optimizer_state_dict = params.get("optimizer_state_dict")

    env_kwargs = dict(
        name=game,
        action_repeat=action_repeat,
        size=(64, 64),
        grayscale=False,
        life_done=True,
        sticky_actions=False,
    )
    # Wrap AtariEnv with one-hot actions and a time limit; each wrapper gets the
    # kwargs dict at the same position in the second list.
    factory_method = make_wapper(
        AtariEnv,
        [OneHotAction, TimeLimit],
        [dict(), dict(duration=1000 / action_repeat)],
    )
    sampler = SerialSampler(
        EnvCls=factory_method,
        TrajInfoCls=AtariTrajInfo,  # default traj info + GameScore
        env_kwargs=env_kwargs,
        eval_env_kwargs=env_kwargs,
        batch_T=1,
        batch_B=1,
        max_decorrelation_steps=0,
        eval_n_envs=10,
        eval_max_steps=int(10e3),
        eval_max_trajectories=5,
    )
    algo = Dreamer(
        horizon=10,
        kl_scale=0.1,
        use_pcont=True,
        initial_optim_state_dict=optimizer_state_dict,
    )
    agent = AtariDreamerAgent(
        train_noise=0.4,
        eval_noise=0,
        expl_type="epsilon_greedy",
        expl_min=0.1,
        # NOTE(review): magic constant carried over unchanged; its intended
        # meaning (decay horizon in steps?) is not visible here — confirm.
        expl_decay=2000 / 0.3,
        initial_model_state_dict=agent_state_dict,
        model_kwargs=dict(use_pcont=True),
    )
    runner_cls = MinibatchRlEval if eval else MinibatchRl
    runner = runner_cls(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=n_steps,
        log_interval_steps=log_interval_steps,
        affinity=dict(cuda_idx=cuda_idx),
    )
    config = dict(game=game)
    name = "dreamer_" + game
    with logger_context(
        log_dir,
        run_ID,
        name,
        config,
        snapshot_mode=save_model,
        override_prefix=True,
        use_summary_writer=True,
    ):
        runner.train()



2024-01-07 16:11:22.717921: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
# Logs/snapshots go under $RL_LOG_DIR when it is set; otherwise fall back to the
# original project location. abspath resolves a relative override against the
# current working directory.
log_dir = os.path.abspath(
    os.environ.get("RL_LOG_DIR", "/home/eddy/Projects/RL_project/logs_atari")
)

build_and_train(
    log_dir,
    game="assault",
    run_ID=0,
    cuda_idx=0,
    eval=False,
    save_model="last",
    load_model_path=None,  # None trains from scratch; pass a checkpoint path to resume
)

2024-01-07 16:11:24.031908  | dreamer_assault_0 Runner  master CPU affinity: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11].
2024-01-07 16:11:24.032600  | dreamer_assault_0 Runner  master Torch threads: 6.
[32musing seed 6944[0m
2024-01-07 16:11:25.489589  | dreamer_assault_0 Sampler decorrelating envs, max steps: 0
2024-01-07 16:11:25.490290  | dreamer_assault_0 Serial Sampler initialized.
2024-01-07 16:11:25.490614  | dreamer_assault_0 Running 5000000 iterations of minibatch RL.
2024-01-07 16:11:26.247878  | dreamer_assault_0 Initialized agent model on device: cuda:0.
DAI CHE SIAMO VICINIIIIIIII
size passato a SequenceNStepReturnBuffer = 300000
2024-01-07 16:11:26.249539  | dreamer_assault_0 Optimizing over 1000 iterations.


  super().__init__(params, defaults)


2024-01-07 16:11:30.648241  | dreamer_assault_0 itr #999 saving snapshot...
2024-01-07 16:11:30.683151  | dreamer_assault_0 itr #999 saved
2024-01-07 16:11:30.689410  | -----------------------------  -----------
2024-01-07 16:11:30.689923  | Diagnostics/NewCompletedTrajs     3
2024-01-07 16:11:30.690231  | Diagnostics/StepsInTrajWindow   915
2024-01-07 16:11:30.690590  | Diagnostics/Iteration           999
2024-01-07 16:11:30.691053  | Diagnostics/CumTime (s)           4.43372
2024-01-07 16:11:30.692392  | Diagnostics/CumSteps           1000
2024-01-07 16:11:30.692868  | Diagnostics/CumCompletedTrajs     3
2024-01-07 16:11:30.693276  | Diagnostics/CumUpdates            0
2024-01-07 16:11:30.694197  | Diagnostics/StepsPerSecond      225.544
2024-01-07 16:11:30.694667  | Diagnostics/UpdatesPerSecond      0
2024-01-07 16:11:30.694997  | Diagnostics/ReplayRatio           0
2024-01-07 16:11:30.695412  | Diagnostics/CumReplayRatio        0
2024-01-07 16:11:30.695902  | Length/Average        

Imagination: 100%|██████████| 100/100 [01:57<00:00,  1.17s/it]


2024-01-07 16:13:47.608939  | dreamer_assault_0 itr #5999 saving snapshot...
2024-01-07 16:13:47.718611  | dreamer_assault_0 itr #5999 saved
2024-01-07 16:13:47.726787  | -----------------------------  ---------------
2024-01-07 16:13:47.727238  | Diagnostics/NewCompletedTrajs      5
2024-01-07 16:13:47.727628  | Diagnostics/StepsInTrajWindow   5995
2024-01-07 16:13:47.728021  | Diagnostics/Iteration           5999
2024-01-07 16:13:47.728429  | Diagnostics/CumTime (s)          141.469
2024-01-07 16:13:47.728849  | Diagnostics/CumSteps            6000
2024-01-07 16:13:47.729703  | Diagnostics/CumCompletedTrajs     22
2024-01-07 16:13:47.730186  | Diagnostics/CumUpdates             0
2024-01-07 16:13:47.730601  | Diagnostics/StepsPerSecond         8.23055
2024-01-07 16:13:47.731053  | Diagnostics/UpdatesPerSecond       0
2024-01-07 16:13:47.731473  | Diagnostics/ReplayRatio            0
2024-01-07 16:13:47.731955  | Diagnostics/CumReplayRatio         0
2024-01-07 16:13:47.732813  | Lengt

Imagination: 100%|██████████| 100/100 [02:03<00:00,  1.24s/it]


2024-01-07 16:15:55.970858  | dreamer_assault_0 itr #6999 saving snapshot...
2024-01-07 16:15:56.165458  | dreamer_assault_0 itr #6999 saved
2024-01-07 16:15:56.189491  | -----------------------------  --------------
2024-01-07 16:15:56.191275  | Diagnostics/NewCompletedTrajs      4
2024-01-07 16:15:56.192486  | Diagnostics/StepsInTrajWindow   6999
2024-01-07 16:15:56.194509  | Diagnostics/Iteration           6999
2024-01-07 16:15:56.195827  | Diagnostics/CumTime (s)          269.917
2024-01-07 16:15:56.197700  | Diagnostics/CumSteps            7000
2024-01-07 16:15:56.198831  | Diagnostics/CumCompletedTrajs     26
2024-01-07 16:15:56.199776  | Diagnostics/CumUpdates             0
2024-01-07 16:15:56.201010  | Diagnostics/StepsPerSecond         7.78526
2024-01-07 16:15:56.201936  | Diagnostics/UpdatesPerSecond       0
2024-01-07 16:15:56.202735  | Diagnostics/ReplayRatio            0
2024-01-07 16:15:56.203639  | Diagnostics/CumReplayRatio         0
2024-01-07 16:15:56.206138  | Length

Imagination: 100%|██████████| 100/100 [02:03<00:00,  1.23s/it]


2024-01-07 16:18:03.855901  | dreamer_assault_0 itr #7999 saving snapshot...
2024-01-07 16:18:04.028521  | dreamer_assault_0 itr #7999 saved
2024-01-07 16:18:04.054324  | -----------------------------  --------------
2024-01-07 16:18:04.056083  | Diagnostics/NewCompletedTrajs      3
2024-01-07 16:18:04.057111  | Diagnostics/StepsInTrajWindow   7774
2024-01-07 16:18:04.058261  | Diagnostics/Iteration           7999
2024-01-07 16:18:04.059300  | Diagnostics/CumTime (s)          397.78
2024-01-07 16:18:04.060222  | Diagnostics/CumSteps            8000
2024-01-07 16:18:04.063117  | Diagnostics/CumCompletedTrajs     29
2024-01-07 16:18:04.064536  | Diagnostics/CumUpdates             0
2024-01-07 16:18:04.065592  | Diagnostics/StepsPerSecond         7.82086
2024-01-07 16:18:04.066762  | Diagnostics/UpdatesPerSecond       0
2024-01-07 16:18:04.067732  | Diagnostics/ReplayRatio            0
2024-01-07 16:18:04.068859  | Diagnostics/CumReplayRatio         0
2024-01-07 16:18:04.070073  | Length/

Imagination: 100%|██████████| 100/100 [02:06<00:00,  1.27s/it]


2024-01-07 16:20:14.832356  | dreamer_assault_0 itr #8999 saving snapshot...
2024-01-07 16:20:14.952993  | dreamer_assault_0 itr #8999 saved
2024-01-07 16:20:14.961429  | -----------------------------  --------------
2024-01-07 16:20:14.962074  | Diagnostics/NewCompletedTrajs      5
2024-01-07 16:20:14.962560  | Diagnostics/StepsInTrajWindow   8990
2024-01-07 16:20:14.962936  | Diagnostics/Iteration           8999
2024-01-07 16:20:14.963326  | Diagnostics/CumTime (s)          528.704
2024-01-07 16:20:14.963829  | Diagnostics/CumSteps            9000
2024-01-07 16:20:14.964666  | Diagnostics/CumCompletedTrajs     34
2024-01-07 16:20:14.965194  | Diagnostics/CumUpdates             0
2024-01-07 16:20:14.965843  | Diagnostics/StepsPerSecond         7.63804
2024-01-07 16:20:14.966425  | Diagnostics/UpdatesPerSecond       0
2024-01-07 16:20:14.966827  | Diagnostics/ReplayRatio            0
2024-01-07 16:20:14.967443  | Diagnostics/CumReplayRatio         0
2024-01-07 16:20:14.967896  | Length

Imagination:  23%|██▎       | 23/100 [00:29<01:39,  1.29s/it]


KeyboardInterrupt: 