In [1]:
%env CUDA_VISIBLE_DEVICES=0

from matplotlib.lines import Line2D
from matplotlib.patches import Circle
import matplotlib.pyplot as plt
import numpy as np
from brax import jumpy as jp

import brax

env: CUDA_VISIBLE_DEVICES=0


In [2]:
import sys
from typing import Tuple

import jax
import jax.numpy as jnp
from flax.struct import dataclass

#from evojax.task.base import VectorizedTask
#from evojax.task.base import TaskState

try:
    from brax.envs import create
    from brax.envs import State as BraxState
except ModuleNotFoundError:
    print('You need to install brax for Brax tasks:')
    print('  pip install git+https://github.com/google/brax.git@main')
    sys.exit()


@dataclass
class State():
    state: BraxState
    obs: jnp.ndarray
        
class BraxTask():
    """Tasks from the Brax simulator."""

    def __init__(self,
                 env_name: str,
                 max_steps: int = 1000,
                 test: bool = False):
        self.max_steps = max_steps
        self.test = test
        brax_env = create(env_name=env_name, episode_length=max_steps)
        self.obs_shape = tuple([brax_env.observation_size, ])
        self.act_shape = tuple([brax_env.action_size, ])

        def reset_fn(key):
            state = brax_env.reset(key)
            return State(state=state, obs=state.obs)

        self._reset_fn = jax.jit(jax.vmap(reset_fn))

        def step_fn(state, action):
            state = brax_env.step(state.state, action)
            return State(state=state, obs=state.obs), state.reward, state.done

        self._step_fn = jax.jit(jax.vmap(step_fn))

    def reset(self, key: jnp.ndarray) -> State:
        return self._reset_fn(key)

    def step(self,
             state: State,
             action: jnp.ndarray) -> Tuple[State, jnp.ndarray, jnp.ndarray]:
        return self._step_fn(state, action)

In [3]:
from brax import envs

env_name = "ant"  # @param ['ant', 'humanoid', 'halfcheetah', 'fetch']
env_fn = envs.create_fn(env_name=env_name)
env = env_fn()
state = env.reset(rng=jp.random_prngkey(seed=0))

In [4]:
%%time
rollout = []
for i in range(100):
    # wiggle sinusoidally
    action = jp.ones((env.action_size,)) * jp.sin(i * jp.pi / 15)
    state = env.step(state, action)
    rollout.append(state)

CPU times: user 9.34 s, sys: 2.73 ms, total: 9.34 s
Wall time: 9.34 s


In [5]:
state = jax.jit(env.step)(state, jnp.ones((env.action_size,)))

In [6]:
%%time
for _ in range(100):
    state = jax.jit(env.step)(state, jnp.ones((env.action_size,)))

CPU times: user 286 ms, sys: 11.8 ms, total: 298 ms
Wall time: 252 ms


In [10]:
jax.devices()

[GpuDevice(id=0, process_index=0)]

In [11]:
jp

<module 'brax.jumpy' from '/scratch/ak1774/vargpu_env/lib/python3.8/site-packages/brax/jumpy.py'>

In [12]:
import jax.numpy as jnp

In [13]:
jnp

<module 'jax.numpy' from '/scratch/ak1774/vargpu_env/lib/python3.8/site-packages/jax/numpy/__init__.py'>

In [15]:
a = jnp.ones(10)

In [19]:
a*a

DeviceArray([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], dtype=float32)

In [20]:
jax.local_device_count()

1

In [None]:
# The plan
# Things that need doing
# Parent selection
# Creating population (mutations)
# Evaluating the population
# Calculating novelty (needs an archive, neerest neigbor)



In [21]:
jax.process_count()

1

In [None]:
# Now lut us create a run_episode_batch function.
# Question, how does brax handle uneven episode length?


In [30]:
from flax import struct

@struct.dataclass
class A():
    a:int

In [31]:
v  = A(3)

In [32]:
v

A(a=3)

In [33]:
A()

TypeError: __init__() missing 1 required positional argument: 'a'

In [36]:
A.replace

<function flax.struct.dataclass.<locals>.replace(self, **updates)>

In [None]:
 # The way i want to use gpu

novelty_archive = jpn.array()  # previusly seen behavior to calculate novelty, stored on GPU
elites_map = np.array()  # the map is stored on the cpu

jax_do_es_update = jax.jit(evaluate_children_and_do_es_update)

for gen in range(1000):
    
    
    # select parent
    parent = select_random_elite(elites_map)
    
    # copy parent params to gpu
    parent_params = jpn.array(parent["params"])
    
    # evaluate_children and do es update
    # this is all the heavy calculations
    # - creating mutated copies
    # - running episodes for each children
    # - calculating the weighted sum for the es update
    child_eval_result,new_params = jax_do_es_update(parent_params)  
    
    # evaluate the new params
    new_params_eval_results = jax_do_eval_episodes(new_params)
    
    # copy back params and eval result to cpu
    new_params_eval_results = np.array(new_params_eval_results)
    child_eval_result = np.array(child_eval_result)
    new_params = np.array(new_params)
    
    # add to archive
    add_to_archive(elites_map,new_params_eval_results,child_eval_result,new_params)
    
# NOTE
# if i want to cache results, i should use a random table, which is the same for the gpu and cpu,
# so I can just store the random indicies for evaluations.
# Why do i actually need to cache results?
# The main reason is that i need to evaluate children to calculate evolvability, and want to reuse them later to do updates.
# If i do the x step per parent before switching, than this have marginal benefit, (calculating 11 pop instead of 10)
# So let us not do it.




In [None]:
def jax_do_es_update(parent_params,novelty_archive):
    
    noise = jpn.randn(POP_SIZE,NUM_PARAMS)
    pop = parent_params + noise
    
    
    
    
    
    
    

In [5]:
import numpy as np

In [1]:
%env CUDA_VISIBLE_DEVICES=0

import brax
from brax import jumpy as jp
from brax.envs import env
import jax.numpy as jnp
import jax

env: CUDA_VISIBLE_DEVICES=0


In [2]:
# To calculate novelty, we need to calculate the distance from each point in the archive,
# then calcululate the mean of the k nearest neigbor distance
key = jax.random.PRNGKey(seed=5)
archive = jax.random.normal(key, shape=(20,2))


In [10]:
eval_bds = jax.random.normal(key, shape=(5,2))

In [57]:
def get_calculate_novelty_fn(k):
    def calculate_novelty(bd,archive):
        # For no
        distances = jnp.sqrt(jnp.sum((archive - bd)**2,axis=1))
        nearest_neighbors,nearest_indicies = jax.lax.top_k(-distances, k) # take negative to calculate neerest instead of furthest
        novelty = -jnp.mean(nearest_neighbors)   # take negative again, to get the mean distance
        return novelty
    
    return calculate_novelty

calculate_novelty = get_calculate_novelty_fn(k=10)
    
# we batch only 
calculate_novelty_batch = jax.jit(jax.vmap(calculate_novelty,in_axes=[0, None]))
jitted_calculate_novelty_batch = jax.jit(jax.vmap(calculate_novelty,in_axes=[0, None]))




In [None]:
eval_bds = jax.random.normal(key, shape=(5,2))

In [45]:
novelties = calculate_novelty_batch(eval_bds,archive)

In [64]:

# TEST speed of novelty calculation

# CPU verison
def calculate_novelty_cpu(eval_bds,archive):
    from sklearn.neighbors import NearestNeighbors
    nn_model = NearestNeighbors(n_neighbors=5, algorithm='ball_tree', metric='euclidean')
    nn_model.fit(archive)
    distances, indicies = nn_model.kneighbors(eval_bds,n_neighbors=min(10,eval_bds.shape[0]))  
    return np.mean(distances,axis=1)
        
def calculate_novelty_gpu(eval_bds,archive):
    return calculate_novelty_batch(eval_bds,archive)

def calculate_novelty_gpu_jit(eval_bds,archive):
    return jitted_calculate_novelty_batch(eval_bds,archive)

key = jax.random.PRNGKey(seed=6)
archive = jax.random.normal(key, shape=(2000,2))
eval_bds = jax.random.normal(key, shape=(10000,2))
    
# copy to cpu beforehand
cpu_eval_bds = np.array(eval_bds)
cpu_archive = np.array(archive)
    


In [68]:
%%time
calculate_novelty_cpu(cpu_eval_bds,cpu_archive)

CPU times: user 98.3 ms, sys: 318 µs, total: 98.6 ms
Wall time: 96.7 ms


array([0.10318025, 0.10136115, 0.08714432, ..., 0.26581214, 0.45856768,
       0.06177525])

In [69]:
%%time
calculate_novelty_gpu(eval_bds,archive).block_until_ready()

CPU times: user 4.49 ms, sys: 154 µs, total: 4.64 ms
Wall time: 3.29 ms


DeviceArray([0.10318024, 0.10136115, 0.08714432, ..., 0.26581216,
             0.4585677 , 0.06177524], dtype=float32)

In [67]:
%%time
calculate_novelty_gpu_jit(eval_bds,archive).block_until_ready()

CPU times: user 4.16 ms, sys: 28 µs, total: 4.19 ms
Wall time: 3.3 ms


DeviceArray([0.10318024, 0.10136115, 0.08714432, ..., 0.26581216,
             0.4585677 , 0.06177524], dtype=float32)

In [None]:
key = jax.random.PRNGKey(seed=6)
noise = jax.random.normal(key, shape=(10000,100000))
weights = jax.random.normal(key, shape=(10000))

In [None]:
results = jnp.ma