In [1]:
import torch

# Switch the cluster's CUDA toolkit module from 11.6 to 11.4, load ffmpeg
# (presumably for rendering the rollout animations later on), then print the
# currently loaded modules.
# NOTE(review): `module` is not defined or imported anywhere in this notebook;
# it must be injected by the kernel/site startup (e.g. an Lmod or Environment
# Modules Python init). On a plain kernel this cell raises NameError.
for _module_action, _modulefile in (
    ("unload", "cuda/11.6"),
    ("load", "cuda/11.4"),
    ("load", "ffmpeg"),
):
    module(_module_action, _modulefile)
module("list")

In [2]:
import sys
from pathlib import Path
# Candidate locations of the warp-drive checkout, presumably one per cluster
# this notebook is run on.
# NOTE(review): absolute, user-specific paths — consider reading them from an
# environment variable instead.
path_root1 = Path('/cobra/u/kkumari/warp-drive')
path_root2 = Path('/project_ghent/warp-drive/')

# Add each root to the import path only once, so re-running this cell does not
# keep appending duplicate entries to sys.path.
for _root in (path_root1, path_root2):
    if str(_root) not in sys.path:
        sys.path.append(str(_root))

In [3]:
from warp_drive.env_wrapper import EnvWrapper
from warp_drive.utils.common import get_project_root

from animations import (
    generate_tag_env_rollout_animation,
)

In [4]:
from gym.spaces import Discrete, MultiDiscrete
from IPython.display import HTML
import yaml
import numpy as np

In [5]:
# Raise the root logger's threshold to ERROR so that loggers inheriting from it
# drop DEBUG / INFO / WARNING records (change the level here for more output).
import logging

logging.root.setLevel(logging.ERROR)

In [6]:
# Load the run config.

# Here we show an example configuration (YAML) for the tag_continuous environment.
# With these values training runs num_episodes * episode_length / train_batch_size
# = 500 * 500 / 10000 = 25 iterations.

CFG = """
# Sample YAML configuration for the tag continuous environment
name: "tag_continuous"

# Environment settings
env:
    num_preys: 50
    num_predators: 1
    stage_size: 30
    episode_length: 500
    preparation_length: 100
    max_acceleration: 0.1
    max_turn: 2.35  # 3*pi/4 radians
    num_acceleration_levels: 10
    num_turn_levels: 10
    starving_penalty_for_predator: -1.0
    surviving_reward_for_prey: 1.0
    edge_hit_penalty: -0.1
    end_of_game_penalty: -100.0
    end_of_game_reward: 100.0
    use_full_observation: False
    eating_distance: 0.05
    seed: 274880
    env_backend: "numba"

# Trainer settings
trainer:
    num_envs: 400 # number of environment replicas
    train_batch_size: 10000 # total batch size used for training per iteration (across all the environments)
    num_episodes: 500 # number of episodes to run the training for (can be arbitrarily high)
# Policy network settings
policy: # list all the policies below
    prey:
        to_train: True # flag indicating whether the model needs to be trained
        algorithm: "A2C" # algorithm used to train the policy
        gamma: 0.98 # discount rate gamma
        lr: 0.005 # learning rate
        vf_loss_coeff: 1 # loss coefficient for the value function loss
        entropy_coeff:
        - [0, 0.5]
        - [2000000, 0.05]
        model: # policy model settings
            module_name: "fully_connected" # model type
            class_name: "FullyConnected" # class type
            fc_dims: [256, 256] # dimension(s) of the fully connected layers as a list
            model_ckpt_filepath: "" # filepath (used to restore a previously saved model)
    predator:
        to_train: True
        algorithm: "A2C"
        gamma: 0.98
        lr: 0.002
        vf_loss_coeff: 1
        model:
            module_name: "fully_connected" # same model spec keys as the prey policy
            class_name: "FullyConnected"
            fc_dims: [256, 256]
            model_ckpt_filepath: ""

# Checkpoint saving setting
saving:
    metrics_log_freq: 100 # how often (in iterations) to print the metrics
    model_params_save_freq: 5000 # how often (in iterations) to save the model parameters
    basedir: "/tmp" # base folder used for saving
    name: "collective_v0"
    tag: "50preys_1predator"

"""

run_config = yaml.safe_load(CFG)

In [7]:
from warp_drive.utils.env_registrar import EnvironmentRegistrar
from custom_env import CUDACustomEnv

# The registered step-function source ("custom_env_step_numba") is Numba-only,
# so the backend is fixed here in one place. Fail fast if the run config (whose
# env section is also forwarded to the env via **run_config["env"]) asks for a
# different one, instead of silently wiring mismatched backends together.
env_backend = "numba"
if run_config["env"].get("env_backend", env_backend) != env_backend:
    raise ValueError(
        f"run_config['env']['env_backend'] must be {env_backend!r} for "
        f"'custom_env_step_numba', got {run_config['env']['env_backend']!r}"
    )

# Register the custom environment's step-function module with WarpDrive.
env_registrar = EnvironmentRegistrar()
env_registrar.add_cuda_env_src_path(
    CUDACustomEnv.name, "custom_env_step_numba", env_backend=env_backend
)

# Wrap `num_envs` replicas of the environment for on-device rollouts.
env_wrapper = EnvWrapper(
    env_obj=CUDACustomEnv(**run_config["env"]),
    num_envs=run_config["trainer"]["num_envs"],
    env_backend=env_backend,
    env_registrar=env_registrar,
)

  deprecation(


function_manager: Setting Numba to use CUDA device 0


In [8]:
# Agent ids controlled by each policy; the keys mirror the policy names under
# run_config["policy"] (predator and prey).
policy_tag_to_agent_id_map = dict(
    predator=list(env_wrapper.env.predators),
    prey=list(env_wrapper.env.preys),
)

In [9]:
import warp_drive.training.trainer
from importlib import reload

# Reload the trainer module *before* binding the class. Importing `Trainer`
# prior to reload() would keep the name pointing at the stale, pre-reload class
# object, so edits to trainer.py would be silently ignored.
reload(warp_drive.training.trainer)
from warp_drive.training.trainer import Trainer

# Build the trainer over the wrapped environments, using every visible CUDA device.
trainer = Trainer(
    env_wrapper=env_wrapper,
    config=run_config,
    policy_tag_to_agent_id_map=policy_tag_to_agent_id_map,
    num_devices=torch.cuda.device_count(),
)

Policy module FullyConnected loaded from warp_drive.training.models.fully_connected
Policy module FullyConnected loaded from warp_drive.training.models.fully_connected


In [10]:
# Run the training loop configured under run_config["trainer"]; progress, speed
# stats and per-policy metrics are printed while it runs.
# NOTE(review): consider releasing GPU resources once done (WarpDrive examples
# call trainer.graceful_close()) — confirm against the installed version.
trainer.train()



Device: 0
Iterations Completed                    : 1 / 25
Speed performance stats
Mean policy eval time per iter (ms)     :     367.69
Mean action sample time per iter (ms)   :      51.73
Mean env. step time per iter (ms)       :     211.82
Mean training time per iter (ms)        :     178.67
Mean total time per iter (ms)           :     825.83
Mean steps per sec (policy eval)        :   27196.61
Mean steps per sec (action sample)      :  193312.80
Mean steps per sec (env. step)          :   47210.13
Mean steps per sec (training time)      :   55969.88
Mean steps per sec (total)              :   12109.09
Metrics for policy 'prey'
VF loss coefficient                     :    1.00000
Entropy coefficient                     :    0.50000
Total loss                              :  187.24503
Policy loss                             :   49.55555
Value function loss                     :  140.08598
Mean rewards                            :    0.99992
Max. rewards                            :

In [12]:
# Restore policies saved by a previous training run. The checkpoint directory is
# <basedir>/<name>/<tag>/<run id>, built from run_config["saving"] so it follows
# config changes; only the run id and timestep are specific to that run.
# NOTE(review): basedir is /tmp, which may be wiped between sessions.
CHECKPOINT_RUN_ID = "1679099911"  # run directory created by the earlier training run
CHECKPOINT_TIMESTEP = 250000  # timestep suffix of the saved state dicts

_saving_cfg = run_config["saving"]
_checkpoint_dir = Path(
    _saving_cfg["basedir"], _saving_cfg["name"], _saving_cfg["tag"], CHECKPOINT_RUN_ID
)

trainer.load_model_checkpoint(
    {
        policy: str(_checkpoint_dir / f"{policy}_{CHECKPOINT_TIMESTEP}.state_dict")
        for policy in ("prey", "predator")
    }
)

[Device 0]: Loading the provided trainer model checkpoints. 
[Device 0]: Loading the 'prey' torch model from the previously saved checkpoint: '/tmp/collective_v0/50preys_1predator/1679099911/prey_250000.state_dict' 
[Device 0]: Updating the timestep for the 'prey' model to 250000. 
[Device 0]: Loading the 'predator' torch model from the previously saved checkpoint: '/tmp/collective_v0/50preys_1predator/1679099911/predator_250000.state_dict' 
[Device 0]: Updating the timestep for the 'predator' model to 250000. 
