In [1]:
# Imports

import time
import numpy as np
import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp
from jax import jit, vmap, block_until_ready

from envs.photon_env import BatchedPhotonLangevinReadoutEnv
from envs.single_photon_env import SinglePhotonLangevinReadoutEnv

from rl_algos.rl_wrappers import VecEnv

In [2]:
# Defining Cairo Params and RL Params

tau_0 = 0.398
kappa = 20.0
chi = 0.65 * 2. * jnp.pi
kerr = 0.002
gamma = 1/140
time_coeff = 10.0
snr_coeff = 10.0
smoothness_coeff = 10.0
n0 = 43
res_amp_scaling = 1/0.43
actual_max_photons = n0 * (1 - jnp.exp(-0.5 * kappa * tau_0))**2
print(f"Rough Max Photons: {n0}")
print(f"Actual Max Photons: {actual_max_photons}")
nR = 0.01
snr_scale_factor = 1.9
gamma_I = 1/140
num_t1 = 5.0
photon_gamma = 1/1500
init_fid = 1 - 1e-4
photon_weight = 4.0

batchsize = 64
num_envs = 8
num_updates = 2000
config = {
    "LR": 3e-3,
    "NUM_ENVS": num_envs,
    "NUM_STEPS": batchsize,
    "NUM_UPDATES": num_updates,
    "UPDATE_EPOCHS": 4,
    "NUM_MINIBATCHES": int(batchsize * num_envs / 64),
    "CLIP_EPS": 0.2,
    "VALUE_CLIP_EPS": 0.2,
    "ENT_COEF": 0.0,
    "VF_COEF": 0.5,
    "MAX_GRAD_NORM": 0.5,
    "ACTIVATION": "relu6",
    "LAYER_SIZE": 64,
    "ENV_NAME": "photon_langevin_readout_env",
    "ANNEAL_LR": False,
    "DEBUG": True,
    "DEBUG_ACTION": False,
    "PRINT_RATE": 100,
    "ACTION_PRINT_RATE": 100,
}

Rough Max Photons: 43
Actual Max Photons: 41.40804860100575


In [3]:
single_env = SinglePhotonLangevinReadoutEnv(
    kappa=kappa,
    chi=chi,
    kerr=kerr,
    time_coeff=time_coeff,
    snr_coeff=snr_coeff,
    smoothness_coeff=smoothness_coeff,
    n0=n0,
    tau_0=tau_0,
    res_amp_scaling=res_amp_scaling,
    nR=nR,
    snr_scale_factor=snr_scale_factor,
    gamma_I=gamma_I,
    photon_gamma=photon_gamma,
    num_t1=num_t1,
    init_fid=init_fid,
    photon_weight=photon_weight,
)

In [4]:
### Timing for Single Env ###

rng = jax.random.PRNGKey(seed=30)
rng, rng_reset = jax.random.split(rng)

env_params = single_env.default_params

init_obs, init_state = single_env.reset(rng_reset, env_params)

dummy_action = jnp.ones_like(single_env.ts_action)

rng, rng_step = jax.random.split(rng)

obs, state, reward, done, info = single_env.step(rng_step, init_state, dummy_action, env_params)

start = time.time()
obs, state, reward, done, info = single_env.step(rng_step, init_state, dummy_action, env_params)
end = time.time()
time_taken = end - start
us = 1e-6
ms = 1e-3

print(f"Time taken for single action: {time_taken / ms}ms")

Time taken for single action: 5.030155181884766ms


In [12]:
### Timing when vmapped ###

vec_env = VecEnv(single_env)

rng = jax.random.PRNGKey(seed=30)

num_envs = 1024

rng, _rng = jax.random.split(rng)
rng_reset = jax.random.split(_rng, num_envs)

vec_env_params = vec_env.default_params

init_vec_obs, init_vec_state = vec_env.reset(rng_reset, vec_env_params)

rng, _rng = jax.random.split(rng)
rng_step = jax.random.split(_rng, num_envs)

vec_action = jnp.tile(dummy_action, (num_envs,1))
jit_vmap_step = jax.jit(vec_env.step)

obs, state, reward, done, info = jit_vmap_step(rng_step, init_vec_state, vec_action, env_params)

start = time.time()
obs, state, reward, done, info = block_until_ready(jit_vmap_step(rng_step, init_vec_state, vec_action, env_params))
end = time.time()
time_taken = end - start

print(f"Time taken for {num_envs} vmapped actions: {time_taken / ms}ms")
print(f"Time taken per action: {time_taken / num_envs / us}us")

Time taken for 1024 vmapped actions: 211.51208877563477ms
Time taken per action: 206.55477419495583us


In [13]:
batched_env = BatchedPhotonLangevinReadoutEnv(
    kappa=kappa,
    chi=chi,
    batchsize=num_envs,
    kerr=kerr,
    time_coeff=time_coeff,
    snr_coeff=snr_coeff,
    smoothness_coeff=smoothness_coeff,
    n0=n0,
    tau_0=tau_0,
    res_amp_scaling=res_amp_scaling,
    nR=nR,
    snr_scale_factor=snr_scale_factor,
    gamma_I=gamma_I,
    photon_gamma=photon_gamma,
    num_t1=num_t1,
    init_fid=init_fid,
    photon_weight=photon_weight,
)

In [17]:
### Comparing with Old Batched Environment

rng = jax.random.PRNGKey(seed=30)
rng, rng_reset = jax.random.split(rng)

batched_env_params = batched_env.default_params
init_obs, init_state = batched_env.reset(rng_reset, batched_env_params)

rng, rng_step = jax.random.split(rng)

jit_batched_step = jax.jit(batched_env.step)

obs, state, reward, done, info = jit_batched_step(rng_step, init_state, vec_action, batched_env_params)

start = time.time()
obs, state, reward, done, info = block_until_ready(jit_batched_step(rng_step, init_state, vec_action, batched_env_params))
end = time.time()
time_taken = end - start

print(f"Time taken for {num_envs} vmapped actions: {time_taken / ms}ms")
print(f"Time taken per action: {time_taken / num_envs / us}us")

Time taken for 1024 vmapped actions: 189.3310546875ms
Time taken per action: 184.89360809326172us
