In [None]:
# Importing Packages

import numpy as np
import matplotlib.pyplot as plt
import time
import copy

from typing import Optional, Union

import jax
import jax.numpy as jnp
from jax.scipy.special import erf
from jax import jit, vmap, block_until_ready

from utils import plot_learning
from envs.single_photon_env import SinglePhotonLangevinReadoutEnv
from env_configs import get_sherbrooke_config, get_kyoto_config

from rl_algos.ppo_continuous import PPO_make_train

In [None]:
# Seeding everything

seed = 30

rng = jax.random.PRNGKey(seed)
rng, _rng = jax.random.split(rng)

In [None]:
# Baseline Kyoto environment parameters and PPO hyperparameters.
k_config = get_kyoto_config()
k_env_config = copy.deepcopy(k_config)

num_envs = 256
num_updates = 6000
# Approximate samples per PPO minibatch (NUM_ENVS * NUM_STEPS / NUM_MINIBATCHES, with NUM_STEPS = 1).
minibatch_size = 64
kyoto_config = {
    "LR": 3e-4,
    "NUM_ENVS": num_envs,
    "NUM_STEPS": 1,
    "NUM_UPDATES": num_updates,
    "UPDATE_EPOCHS": 4,
    # Floor division clamped to >= 1: `int(num_envs / 64)` gives 0 minibatches
    # whenever num_envs < 64, which leaves the minibatch split ill-defined.
    "NUM_MINIBATCHES": max(1, num_envs // minibatch_size),
    "CLIP_EPS": 0.2,
    "VALUE_CLIP_EPS": 0.2,
    "ENT_COEF": 0.0,
    "VF_COEF": 0.5,
    "MAX_GRAD_NORM": 0.5,
    "ACTIVATION": "relu6",
    "LAYER_SIZE": 128,
    "ENV_NAME": "single_langevin_env",
    # Snapshot, so later in-place edits of k_env_config cannot silently
    # change this baseline config.
    "ENV_PARAMS": copy.deepcopy(k_env_config),
    "ANNEAL_LR": False,
    "DEBUG": True,
    "DEBUG_ACTION": False,
    "PRINT_RATE": 100,
    "ACTION_PRINT_RATE": 100,
}

In [None]:
# Baseline environment, built from the unmodified Kyoto parameters.
kyoto_env = SinglePhotonLangevinReadoutEnv(
    **k_env_config
)

In [None]:
# Inspect the baseline Kyoto environment parameters (displayed as cell output).
k_env_config

In [None]:
# Default Kyoto Action
# Flat-top pulse with Gaussian rise/fall edges, sampled on the environment's
# action time grid. The flat top has amplitude 1.

ts = kyoto_env.ts_action

gauss_sigma = 0.0142  # Gaussian edge width, in the same time units as `ts` (TODO confirm units)
num_sigma = 2.        # each edge spans num_sigma * gauss_sigma
total_duration = k_env_config["tau_0"]

rise_end = num_sigma * gauss_sigma
fall_start = total_duration - num_sigma * gauss_sigma

rise_edge = jnp.exp(-(ts - rise_end) ** 2 / (2 * gauss_sigma**2))
fall_edge = jnp.exp(-(ts - fall_start) ** 2 / (2 * gauss_sigma**2))

# Piecewise with jnp.where so every sample belongs to exactly one piece:
# rising edge for t <= rise_end, falling edge for t > fall_start, 1 in between.
# (A sum of heaviside-gated terms using heaviside(0, 1.) = 1 on both sides of
# t == fall_start double-counts that point, giving amplitude 2 there.)
# Assumes total_duration > 2 * num_sigma * gauss_sigma so the edges don't overlap.
default_pulse = jnp.where(
    ts <= rise_end,
    rise_edge,
    jnp.where(ts > fall_start, fall_edge, 1.0),
)

# Env's own action preprocessing, divided by kyoto_env.a0 so it can be overlaid
# on the unit-amplitude pulses (assumes a0 is the env's amplitude scale — TODO confirm).
prepped_pulse = kyoto_env.prepare_action(default_pulse) / kyoto_env.a0

# Step-by-step pipeline built from the env's individual processing helpers.
manual_pulse = kyoto_env.normalize_pulse(default_pulse)
manual_pulse = kyoto_env.drive_smoother(manual_pulse)

# Hand-rolled slope clipping: clip the finite-difference slope to +/- max_grad,
# then re-integrate from the first sample (presumably mirrors
# kyoto_env.gradient_clipper — TODO confirm). Kept for inspection only.
res_drive_diff = jnp.diff(manual_pulse, n=1) / kyoto_env.grad_dt
# Positional bounds: the `a_min`/`a_max` keywords are deprecated in recent JAX.
res_drive_diff_clipped = jnp.clip(res_drive_diff, -kyoto_env.max_grad, kyoto_env.max_grad)
res_drive_processed = jnp.cumsum(
    jnp.concatenate((jnp.array([manual_pulse[0]]), res_drive_diff_clipped * kyoto_env.grad_dt))
)

manual_pulse = kyoto_env.gradient_clipper(manual_pulse)
manual_pulse = kyoto_env.normalize_pulse(manual_pulse)
manual_pulse = kyoto_env.drive_smoother(manual_pulse)

fig, ax = plt.subplots()
ax.plot(ts, default_pulse, label='default kyoto pulse')
ax.plot(ts, manual_pulse, label='manual kyoto pulse')
ax.plot(ts, prepped_pulse, label='prepped kyoto pulse')
ax.set_xlabel("time")
ax.set_ylabel("amplitude (a.u.)")
ax.set_title("Default Kyoto readout pulse")
ax.legend()
plt.show()

In [None]:
# Roll out the hand-built default pulse in the baseline env; the result is
# kept under a name and shown as the cell output.
default_rollout = kyoto_env.rollout_action(_rng, default_pulse)
default_rollout

In [None]:
# Important coefficients to be set, can be played around with.
# Rebind to a deep copy before editing, so these tweaks never leak into any
# config dict that still references the baseline `k_env_config`
# (e.g. kyoto_config["ENV_PARAMS"]), and so re-running this cell is idempotent.
k_env_config = copy.deepcopy(k_env_config)
k_env_config["time_coeff"] = 2
k_env_config["smoothness_coeff"] = 1.

k_env_config

In [None]:
# PPO config for the modified environment. It uses the same hyperparameters as
# kyoto_config and differs only in ENV_PARAMS (the coefficient-adjusted
# k_env_config). Deriving it avoids two hand-maintained copies drifting apart.
mod_kyoto_config = {
    **kyoto_config,
    "ENV_PARAMS": k_env_config,
}

In [None]:
# JIT-compile the PPO training function. static_argnums=-1 marks the last
# argument (num_envs in the call below) as static.
kyoto_train = jit(PPO_make_train(mod_kyoto_config), static_argnums=-1)

print(f"Starting a Run of {mod_kyoto_config['NUM_UPDATES']} Updates")
start = time.perf_counter()
# JAX dispatches asynchronously: without block_until_ready the timer can stop
# before training has actually finished. The measured time also includes JIT
# compilation, since this is the first call of the freshly created kyoto_train.
kyoto_result = block_until_ready(kyoto_train(_rng, num_envs))
end = time.perf_counter()
print(f"time taken: {end - start}")