In [24]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [25]:
import os
import agent
import environment

import jax.random as jrandom
import jax.numpy as jnp
import time

In [26]:
os.environ.setdefault('JAX_PLATFORM_NAME', 'gpu')     # tell JAX to use GPU
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.7'  # don't use all gpu mem
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'    

In [27]:
def number_of_frames_processed_per_second(step_idx, time_to_start_training):
    time_taken_until_training = time.time() - time_to_start_training
    return step_idx / time_taken_until_training

In [28]:



env = environment.create("PongNoFrameskip-v4")
REWARD_BOUND_TO_SOLVE_PONG = 19.5
MINIMUM_REPLAY_BUFFER_SIZE = 64
SYNC_EVERY_N_STEPS = 1000


model_hparams = agent.ModelHParams()
agent_hparams = agent.AgentHParams()
net = agent.Agent(
    model_architecture=agent.Brain,
    model_hparams=model_hparams,
    agent_hparams=agent_hparams,
    n_actions=env.action_space.n,
)
target_net = agent.Agent(
    model_architecture=agent.Brain,
    model_hparams=model_hparams,
    agent_hparams=agent_hparams,
    n_actions=env.action_space.n,
)
rng = jrandom.PRNGKey(net.seed)

observation, _ = env.reset()
batch = jnp.expand_dims(observation, axis=0)
state = net.initialize(batch)
target_state = target_net.initialize(batch)

average_loss = []
total_reward = 0
total_rewards = []
best_reward = 0
step_idx = 0

In [7]:
time_to_start_training = time.time()
epsilon = 0.0
for _ in range(100):
    action = agent.policy(state, observation, env.action_space.n, epsilon, action_rng)



In [11]:
time_to_start_training = time.time()
for _ in range(100):
    train_batch = net.get_batch(rng)

In [12]:
time_to_start_training = time.time()
for _ in range(100):
    agent.train_step(train_batch, state, target_state, 0.99)

In [18]:

for _ in range(100):
    idxs = net.memory.get_batch_idxs(32, rng)


In [20]:
import dataclasses

In [21]:
for _ in range(100):
    states, actions, rewards, is_dones, next_states = zip(*[dataclasses.asdict(net.memory.memory[idx]).values() for idx in idxs])

In [23]:
for _ in range(100):
    return_value = (
            jnp.asarray(states, dtype=jnp.float32), 
            jnp.asarray(actions), 
            jnp.asarray(rewards, dtype=jnp.float32), 
            jnp.asarray(is_dones, dtype=jnp.uint8), 
            jnp.asarray(next_states, dtype=jnp.float32)
    )

In [30]:
time_to_start_training = time.time()
epsilon = 0.0
for step_idx in range(1000):
    action = agent.policy(state, observation, env.action_space.n, epsilon, action_rng)
    next_observation, reward, is_done, *_ = env.step(action)
    if is_done:
        next_observation, _ = env.reset()
    observation = next_observation
    train_batch = net.get_batch(rng)
    agent.train_step(train_batch, state, target_state, 0.99)

end_time = time.time()
fps = step_idx / (end_time - time_to_start_training)
print(fps)


In [29]:
import jax

# jax.profiler.start_server(9999)
total_reward = 0
best_reward = -100
step_idx = 0
episode_id = 0
time_to_start_training = time.time()

is_log_start_training = True

log_dir = "./log_dir/tensorboard/"

for _ in range(200):
    
    step_idx += 1
    rng, action_rng = jrandom.split(rng)
    epsilon = net.get_epsilon(step_idx=step_idx)
    time_to_execute_policy = time.time()
    with jax.profiler.trace(log_dir=log_dir):
        action = agent.policy(state, observation, env.action_space.n, epsilon, action_rng)
    with jax.profiler.trace(log_dir=log_dir):
        next_observation, reward, is_done, *_ = env.step(action)
    net.memory.add_experience_to_memory(observation, action, reward, is_done, next_observation)
    time_taken_to_execute_policy = time.time() - time_to_execute_policy
    total_reward += reward
    if is_done:
        episode_id += 1
        print("time taken to execute policy: %.3f" % time_taken_to_execute_policy)
        print("episode: %d, total steps: %d, reward: %.3f, best reward: %.3f, epsilon: %.2f" % (episode_id, step_idx, total_reward, best_reward, epsilon))
        print("number of frames processed per second: %.3f" % number_of_frames_processed_per_second(step_idx, time_to_start_training))
        next_observation, _ = env.reset()
        total_reward = 0
        if net.memory_size > MINIMUM_REPLAY_BUFFER_SIZE:
            print("time taken for training step: %.3f" % (time_train_step_end - time_train_step_start))

    observation = next_observation

    if net.memory_size < MINIMUM_REPLAY_BUFFER_SIZE:
        continue

    if is_log_start_training:
        time_taken_until_training = time.time() - time_to_start_training
        # print time taken until training in seconds
        print("time taken until training: %.3f" % time_taken_until_training)
        is_log_start_training = False

    if step_idx % SYNC_EVERY_N_STEPS == 0:
        target_state = agent.sync_target_network(state=state, target_state=target_state)


    with jax.profiler.trace(log_dir=log_dir):

        rng, batch_rng = jrandom.split(rng)
        loss, state, rng = net.train_step(state=state, target_state=target_state, rng=rng)
        time_train_step_end = time.time()
        loss.block_until_ready()
    
    


time taken until training: 30.113


In [1]:
import jax

with jax.profiler.trace(log_dir="./logdir/default"):
  key = jax.random.PRNGKey(0)
  x = jax.random.normal(key, (5000, 5000))
  y = x @ x
  y.block_until_ready()

2023-01-02 04:57:25.632052: W external/org_tensorflow/tensorflow/tsl/platform/default/dso_loader.cc:66] Could not load dynamic library 'libcupti.so.11.2'; dlerror: libcupti.so.11.2: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/intel/compilers_and_libraries_2018.3.222/linux/mpi/intel64/lib:/opt/intel/compilers_and_libraries_2018.3.222/linux/mpi/mic/lib:/opt/intel/compilers_and_libraries_2018.3.222/linux/mpi/intel64/lib:/opt/intel/compilers_and_libraries_2018.3.222/linux/mpi/mic/lib::/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64/:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64/
2023-01-02 04:57:25.642136: E external/org_tensorflow/tensorflow/core/profiler/backends/gpu/cupti_error_manager.cc:192] cuptiSubscribe: error 15: CUPTI_ERROR_NOT_INITIALIZED
2023-01-02 04:57:25.642158: E external/org_tensorflow/tensorflow/core/profiler/backends/gpu/cupti_error_manager.cc:457] cuptiGetResultString: ignored due to a previous error.
2023-01-02 0

In [1]:
import jax
with jax.profiler.trace("./logdir/jax-trace", create_perfetto_link=True):
  # Run the operations to be profiled
  key = jax.random.PRNGKey(0)
  x = jax.random.normal(key, (5000, 5000))
  y = x @ x
  y.block_until_ready()

2023-01-02 05:03:25.471636: W external/org_tensorflow/tensorflow/tsl/platform/default/dso_loader.cc:66] Could not load dynamic library 'libcupti.so.11.2'; dlerror: libcupti.so.11.2: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/intel/compilers_and_libraries_2018.3.222/linux/mpi/intel64/lib:/opt/intel/compilers_and_libraries_2018.3.222/linux/mpi/mic/lib:/opt/intel/compilers_and_libraries_2018.3.222/linux/mpi/intel64/lib:/opt/intel/compilers_and_libraries_2018.3.222/linux/mpi/mic/lib::/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64/:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64/
2023-01-02 05:03:25.481410: E external/org_tensorflow/tensorflow/core/profiler/backends/gpu/cupti_error_manager.cc:192] cuptiSubscribe: error 15: CUPTI_ERROR_NOT_INITIALIZED
2023-01-02 05:03:25.481429: E external/org_tensorflow/tensorflow/core/profiler/backends/gpu/cupti_error_manager.cc:457] cuptiGetResultString: ignored due to a previous error.
2023-01-02 0

ValueError: Invalid trace folder: /mnt/batch/tasks/shared/LS_root/mounts/clusters/gpu-deeprl/code/Users/stefruinard/deep_reinforcement_learning/003_dqn_pong/logdir/jax-trace/plugins/profile/2023_01_02_05_03_27

In [22]:
start_time = time.time()
env.reset()
i = 0
while True:
    i+=1
    action = agent.policy(state, observation, env.action_space.n, 0.01, action_rng)
    observation, reward, is_done, *_ = env.step(env.action_space.sample())
    if is_done:
        print("done")
        env.reset()
        end_time = time.time()
        break
    loss, state, rng = net.train_step(state=state, target_state=target_state, rng=rng)
print("Time taken to execute one episode: %.3f" % (end_time - start_time))

done
Time taken to execute one episode: 139.725


In [24]:
i / (end_time - start_time)

7.228485958879105

In [None]:
0.480