In [1]:
import torch

In [2]:
# Swap the CUDA toolkit modulefile from 11.6 to 11.4 (presumably so the Numba
# backend used further down picks up CUDA 11.4), then list the loaded
# modulefiles to confirm the swap.
# NOTE(review): `module` is never defined or imported in this notebook; it is
# presumably injected by the HPC site's Environment Modules Python init —
# confirm, otherwise this cell raises NameError on a fresh kernel.
for _action, _modulefile in (("unload", "cuda/11.6"), ("load", "cuda/11.4")):
    module(_action, _modulefile)
module("list")

Currently Loaded Modulefiles:
 1) rvs/1.0(default)   2) anaconda/3/2021.11   3) cuda/11.4  

Key:
(symbolic-version)  


In [3]:
import os
import sys
from pathlib import Path

# Make the local warp-drive checkout importable. Point WARP_DRIVE_ROOT at a
# different checkout instead of editing this hard-coded default.
path_root = Path(os.environ.get("WARP_DRIVE_ROOT", "/cobra/u/kkumari/warp-drive"))
# Guard the append so re-running this cell does not pile up duplicate entries.
if str(path_root) not in sys.path:
    sys.path.append(str(path_root))

In [4]:
from warp_drive.env_wrapper import EnvWrapper
from warp_drive.utils.common import get_project_root

from animations import (
    generate_tag_env_rollout_animation,
)

In [5]:
from gym.spaces import Discrete, MultiDiscrete
from IPython.display import HTML
import yaml
import numpy as np

In [6]:
# Root logger threshold; choose one of DEBUG, INFO, WARNING, ERROR.
# Only records at ERROR or above pass the root logger.
import logging

logging.root.setLevel("ERROR")

In [7]:
# Load the run config.

# Example run configuration for the continuous tag environment, parsed with
# yaml.safe_load into a nested dict with top-level sections `name`, `env`,
# `trainer`, `policy` and `saving`.

CFG = """
# Sample YAML configuration for the tag continuous environment
name: "tag_continuous"

# Environment settings
env:
    num_preys: 50
    num_predators: 1
    stage_size: 30
    episode_length: 500
    preparation_length: 100
    max_acceleration: 0.1
    max_turn: 2.35  # 3*pi/4 radians
    num_acceleration_levels: 10
    num_turn_levels: 10
    eating_reward_for_predator: 10.0
    eating_penalty_for_prey: -10.0
    edge_hit_penalty: -0.0
    end_of_game_penalty: -1.0
    end_of_game_reward: 1.0
    use_full_observation: False
    eating_distance: 0.05
    seed: 274880
    env_backend: "numba"

# Trainer settings
trainer:
    num_envs: 400 # number of environment replicas
    train_batch_size: 10000 # total batch size used for training per iteration (across all the environments)
    num_episodes: 500 # number of episodes to run the training for (can be arbitrarily high)

# Policy network settings
policy: # list all the policies below
    prey:
        to_train: True # flag indicating whether the model needs to be trained
        algorithm: "A2C" # algorithm used to train the policy
        gamma: 0.98 # discount factor gamma
        lr: 0.005 # learning rate
        vf_loss_coeff: 1 # loss coefficient for the value function loss
        entropy_coeff: # schedule, presumably [timestep, coefficient] pairs
        - [0, 0.5]
        - [2000000, 0.05]
        model: # policy model settings
            module_name: "fully_connected" # model type
            class_name: "FullyConnected" # class type
            fc_dims: [256, 256] # dimension(s) of the fully connected layers as a list
            model_ckpt_filepath: "" # filepath (used to restore a previously saved model)
    predator:
        to_train: True
        algorithm: "A2C"
        gamma: 0.98
        lr: 0.002
        vf_loss_coeff: 1
        model:
            # Same model-spec keys as `prey`, so both policies are configured alike.
            module_name: "fully_connected"
            class_name: "FullyConnected"
            type: "fully_connected" # original key; kept in case the installed trainer reads it
            fc_dims: [256, 256]
            model_ckpt_filepath: ""

# Checkpoint saving setting
saving:
    metrics_log_freq: 100 # how often (in iterations) to print the metrics
    model_params_save_freq: 5000 # how often (in iterations) to save the model parameters
    basedir: "/tmp" # base folder used for saving
    name: "collective_v0"
    tag: "50preys_1predator"

"""

run_config = yaml.safe_load(CFG)

In [8]:
from warp_drive.utils.env_registrar import EnvironmentRegistrar
from custom_env import CUDACustomEnv

# Backend shared by the registrar, the wrapper and the env itself. The step
# module registered below ("custom_env_step_numba") is the Numba implementation,
# so fail fast if the run config asks the env for a different backend.
ENV_BACKEND = "numba"
config_backend = run_config["env"].get("env_backend", ENV_BACKEND)
if config_backend != ENV_BACKEND:
    raise ValueError(
        f"run_config['env']['env_backend'] is {config_backend!r}, but this cell "
        f"registers the {ENV_BACKEND!r} step module 'custom_env_step_numba'."
    )

# Register the custom env's step-function source under the env's name.
env_registrar = EnvironmentRegistrar()
env_registrar.add_cuda_env_src_path(
    CUDACustomEnv.name, "custom_env_step_numba", env_backend=ENV_BACKEND
)

# Wrap the env into `num_envs` replicas (see the `trainer:` section of the config).
env_wrapper = EnvWrapper(
    env_obj=CUDACustomEnv(**run_config["env"]),
    num_envs=run_config["trainer"]["num_envs"],
    env_backend=ENV_BACKEND,
    env_registrar=env_registrar,
)

  deprecation(


function_manager: Setting Numba to use CUDA device 0




In [9]:
# Assign agent ids to policies. The keys mirror the policy names under
# `policy:` in the run config.
policy_tag_to_agent_id_map = dict(
    predator=list(env_wrapper.env.predators),
    prey=list(env_wrapper.env.preys),
)

In [10]:
# Build the trainer from the wrapped env, the run config and the policy->agent map.
from importlib import reload

import warp_drive.training.trainer

# Reload BEFORE binding `Trainer`. `from ... import Trainer` copies the class
# object into this namespace, so importing first and reloading afterwards would
# leave `Trainer` bound to the stale, pre-reload class.
reload(warp_drive.training.trainer)
from warp_drive.training.trainer import Trainer  # noqa: E402

trainer = Trainer(
    env_wrapper=env_wrapper,
    config=run_config,
    policy_tag_to_agent_id_map=policy_tag_to_agent_id_map,
    # NOTE(review): training is driven from this single notebook process; confirm
    # the installed Trainer behaves as intended when device_count() > 1.
    num_devices=torch.cuda.device_count(),
)

Policy module FullyConnected loaded from warp_drive.training.models.fully_connected
Policy module FullyConnected loaded from warp_drive.training.models.fully_connected


since Python 3.9 and will be removed in a subsequent version. The only 
supported seed types are: None, int, float, str, bytes, and bytearray.
  random.seed(seed)


In [11]:
# Roll out the not-yet-trained policies and embed the episode as an HTML5 video.
# ffmpeg is loaded because `to_html5_video` (presumably matplotlib's Animation
# method) encodes the video with it by default.
# NOTE(review): relies on the same externally provided `module` helper as the
# CUDA cell at the top.
module("load", "ffmpeg")
anim = generate_tag_env_rollout_animation(trainer)
rollout_video_html = anim.to_html5_video()
HTML(rollout_video_html)

In [12]:
# Run the training loop configured under `trainer:` in the run config. The logged
# output shows 25 iterations (= num_episodes * episode_length / train_batch_size).
# Per the config comments, metrics are printed every `metrics_log_freq`
# iterations and model parameters are saved under the `saving:` directory.
trainer.train()



Device: 0
Iterations Completed                    : 1 / 25
Speed performance stats
Mean policy eval time per iter (ms)     :     101.86
Mean action sample time per iter (ms)   :      35.53
Mean env. step time per iter (ms)       :     339.65
Mean training time per iter (ms)        :     115.98
Mean total time per iter (ms)           :     602.30
Mean steps per sec (policy eval)        :   98174.95
Mean steps per sec (action sample)      :  281471.18
Mean steps per sec (env. step)          :   29442.05
Mean steps per sec (training time)      :   86218.67
Mean steps per sec (total)              :   16602.99
Metrics for policy 'prey'
VF loss coefficient                     :    1.00000
Entropy coefficient                     :    0.50000
Total loss                              :   -2.38771
Policy loss                             :    0.00908
Value function loss                     :    0.00002
Mean rewards                            :    0.00000
Max. rewards                            :

In [13]:
# Restore policy weights from a previous training run's checkpoints. They live
# under <basedir>/<name>/<tag>/<run id>/<policy>_<timestep>.state_dict, built
# from the `saving:` section of the run config.
# The run id is presumably the start timestamp of that earlier run. Update both
# constants to restore a different run or timestep.
CHECKPOINT_RUN_ID = "1678980039"
CHECKPOINT_TIMESTEP = 250000

checkpoint_dir = Path(
    run_config["saving"]["basedir"],
    run_config["saving"]["name"],
    run_config["saving"]["tag"],
    CHECKPOINT_RUN_ID,
)
checkpoint_paths = {
    policy: checkpoint_dir / f"{policy}_{CHECKPOINT_TIMESTEP}.state_dict"
    for policy in run_config["policy"]
}

# Fail with one clear message (listing every missing file) rather than partway
# through loading.
missing_checkpoints = [str(p) for p in checkpoint_paths.values() if not p.is_file()]
if missing_checkpoints:
    raise FileNotFoundError(f"Checkpoint file(s) not found: {missing_checkpoints}")

trainer.load_model_checkpoint(
    {policy: str(path) for policy, path in checkpoint_paths.items()}
)

[Device 0]: Loading the provided trainer model checkpoints. 
[Device 0]: Loading the 'prey' torch model from the previously saved checkpoint: '/tmp/collective_v0/50preys_1predator/1678980039/prey_250000.state_dict' 
[Device 0]: Updating the timestep for the 'prey' model to 250000. 
[Device 0]: Loading the 'predator' torch model from the previously saved checkpoint: '/tmp/collective_v0/50preys_1predator/1678980039/predator_250000.state_dict' 
[Device 0]: Updating the timestep for the 'predator' model to 250000. 


In [14]:
# Render an entire episode, rolled out with the currently loaded policy
# weights, and embed it as an HTML5 video.
anim = generate_tag_env_rollout_animation(trainer)
rollout_video_html = anim.to_html5_video()
HTML(rollout_video_html)