In [1]:
# Importing Packages

import numpy as np
import matplotlib.pyplot as plt
import time

from typing import Optional, Union

import jax
import jax.numpy as jnp
from jax.scipy.special import erf
from jax import jit, vmap, block_until_ready, config

from rl_algos.ppo_clipped_normal import PPO_make_train as PPO_Clipped_make_train

In [2]:
# Seeding everything

seed = 30

rng = jax.random.PRNGKey(seed)
rng, _rng = jax.random.split(rng)

In [3]:
# Defining Sherbrooke Params and RL Params

tau_0 = 0.783
kappa = 14.31
chi = 0.31 * 2. * jnp.pi
kerr = 0.00
n0 = 53.8
res_amp_scaling = 1/0.348
actual_max_photons = n0 * (1. - 2. * jnp.cos(0.5 * chi * tau_0) * jnp.exp(-0.5 * kappa * tau_0) + jnp.exp(-0.5 * kappa * tau_0))
print(f"Rough Max Photons: {n0}")
print(f"Actual Max Photons: {actual_max_photons}")
nR = 0.01
snr_scale_factor = 1.1
gamma_I = 1/362.9
photon_gamma = 1/4000
init_fid = 1.

time_coeff = 10.0
snr_coeff = 20.0
smoothness_coeff = 10.0
smoothness_baseline_scale = 1.0
apply_smoothing = True
use_processed_action = True
bandwidth = 50.0
freq_relative_cutoff = 0.1
bandwidth_coeff = 0.0
apply_bandwidth_constraint = False
num_t1 = 9.0
photon_weight = 12.0
shot_noise_std = 0.0
standard_fid = 0.99

env_config = {"kappa": kappa,
    "chi": chi,
    "kerr": kerr,
    "time_coeff": time_coeff,
    "snr_coeff": snr_coeff,
    "smoothness_coeff": smoothness_coeff,
    "smoothness_baseline_scale": smoothness_baseline_scale,
    "apply_smoothing": apply_smoothing,
    "use_processed_action": use_processed_action,
    "bandwidth": bandwidth,
    "freq_relative_cutoff": freq_relative_cutoff,
    "bandwidth_coeff": bandwidth_coeff,
    "apply_bandwidth_constraint": apply_bandwidth_constraint,
    "n0": n0,
    "tau_0": tau_0,
    "res_amp_scaling": res_amp_scaling,
    "nR": nR,
    "snr_scale_factor": snr_scale_factor,
    "gamma_I": gamma_I,
    "photon_gamma": photon_gamma,
    "num_t1": num_t1,
    "init_fid": init_fid,
    "photon_weight": photon_weight,
    "standard_fid": standard_fid,
    "shot_noise_std": shot_noise_std,
}

num_envs = 256
num_updates = 4000
config = {
    "LR": 3e-4,
    "NUM_ENVS": num_envs,
    "NUM_STEPS": 1,
    "NUM_UPDATES": num_updates,
    "UPDATE_EPOCHS": 4,
    "NUM_MINIBATCHES": int(num_envs / 64),
    "CLIP_EPS": 0.2,
    "VALUE_CLIP_EPS": 0.2,
    "ENT_COEF": 0.0,
    "VF_COEF": 0.5,
    "MAX_GRAD_NORM": 0.5,
    "ACTIVATION": "relu6",
    "LAYER_SIZE": 128,
    "ENV_NAME": "single_langevin_env",
    "ENV_PARAMS": env_config,
    "ANNEAL_LR": False,
    "DEBUG": True,
    "DEBUG_ACTION": False,
    "PRINT_RATE": 100,
    "ACTION_PRINT_RATE": 100,
}

Rough Max Photons: 53.8
Actual Max Photons: 53.711451977836106


## Varying Kappa

In [4]:
param_error = 0.05
param_error_array = jnp.array([1 - param_error, 1. ,1. + param_error])
kappas = kappa * param_error_array
chis = chi * param_error_array

single_train_kappa_2 = jit(PPO_Clipped_make_train(config), static_argnums=-1)

print(f"Starting a Run of {num_updates} Updates")
start = time.time()
single_result_k2 = single_train_kappa_2(
    _rng, 
    num_envs)
end = time.time()
print(f"time taken: {end - start}")

Starting a Run of 4000 Updates
global update: 100
reward: -31.329
max pF: 1.516
max photon: 45.43
photon time: 0.9986
smoothness: 0.009866
bandwidth: 39.988
global update: 200
reward: -20.638
max pF: 1.605
max photon: 47.571
photon time: 0.7224
smoothness: 0.008218999999999999
bandwidth: 35.381
global update: 300
reward: -8.176
max pF: 1.663
max photon: 48.111000000000004
photon time: 0.6929000000000001
smoothness: 0.006418
bandwidth: 32.178
global update: 400
reward: 0.655
max pF: 1.713
max photon: 48.64
photon time: 0.7634000000000001
smoothness: 0.005202
bandwidth: 29.72
global update: 500
reward: 8.279
max pF: 1.7710000000000001
max photon: 49.967
photon time: 0.664
smoothness: 0.004024
bandwidth: 24.971
global update: 600
reward: 14.194
max pF: 1.829
max photon: 51.029
photon time: 0.5807
smoothness: 0.0035039999999999997
bandwidth: 22.778000000000002
global update: 700
reward: 16.644000000000002
max pF: 1.821
max photon: 50.943
photon time: 0.5468000000000001
smoothness: 0.003068

In [None]:
import copy

config_2 = copy.deepcopy(config)
config_2["ENV_PARAMS"]["kappa"] = kappas[0]

config_2

In [None]:
single_train_kappa_1 = jit(PPO_Clipped_make_train(config_2), static_argnums=-1)

print(f"Starting a Run of {num_updates} Updates")
start = time.time()
single_result_k1 = single_train_kappa_1(
    _rng, 
    num_envs)
end = time.time()
print(f"time taken: {end - start}")

In [None]:
config_3 = copy.deepcopy(config)
config_3["ENV_PARAMS"]["kappa"] = kappas[2]
config_3["ENV_PARAMS"]["num_t1"] = 7.0

config_3

In [None]:
single_train_kappa_3 = jit(PPO_Clipped_make_train(config_3), static_argnums=-1)

print(f"Starting a Run of {num_updates} Updates")
start = time.time()
single_result_k3 = single_train_kappa_3(
    _rng, 
    num_envs)
end = time.time()
print(f"time taken: {end - start}")

In [None]:
metric_k1 = single_result_k1["metrics"]
metric_k2 = single_result_k2["metrics"]
metric_k3 = single_result_k3["metrics"]

actions_k1 = metric_k1["action"]
actions_k2 = metric_k2["action"]
actions_k3 = metric_k3["action"]

photon_times_k1 = metric_k1["photon time"][-1, 0]
photon_times_k2 = metric_k2["photon time"][-1, 0]
photon_times_k3 = metric_k3["photon time"][-1, 0]

In [None]:
from utils import photon_env_dicts

env = photon_env_dicts()[config["ENV_NAME"]](**env_config)

In [None]:
raw_final_action_k1 = actions_k1[-1, 0]
raw_final_action_k2 = actions_k2[-1, 0]
raw_final_action_k3 = actions_k3[-1, 0]

ts_mod = env.ts_action / param_error_array[2] / (config["ENV_PARAMS"]["num_t1"]) * config_3["ENV_PARAMS"]["num_t1"]

raw_smooth_action_k1 = env.prepare_action(raw_final_action_k1) * jnp.heaviside(photon_times_k1 - env.ts_action / param_error_array[0], 0.)
raw_smooth_action_k2 = env.prepare_action(raw_final_action_k2) * jnp.heaviside(photon_times_k2 - env.ts_action / param_error_array[0], 0.)
raw_smooth_action_k3 = env.prepare_action(raw_final_action_k3) * jnp.heaviside(photon_times_k3 - ts_mod, 0.)

In [None]:
# Comparing Kappa Waveforms

# Divide Time by Error Percentage because it corresponds to change in timescale
# Amplitude can be left as it is because we would deal with normalized waveforms
# So absolute values aren't relevant, only the waveform shapes

plt.plot(env.ts_action / param_error_array[0], -raw_smooth_action_k1, label='Low Kappa')
plt.plot(env.ts_action / param_error_array[1], -raw_smooth_action_k2, label='Normal Kappa')
plt.plot(ts_mod, -raw_smooth_action_k3, label='High Kappa')
plt.legend()
plt.xlabel('Time (us)')
plt.ylabel('Amplitude (A.U.)')
plt.show()

## Chi Variance

In [None]:
config_4 = copy.deepcopy(config)
config_4["ENV_PARAMS"]["chi"] = chis[0]

config_4

In [None]:
single_train_chi_1 = jit(PPO_Clipped_make_train(config_4), static_argnums=-1)

print(f"Starting a Run of {num_updates} Updates")
start = time.time()
single_result_c1 = single_train_chi_1(
    _rng, 
    num_envs)
end = time.time()
print(f"time taken: {end - start}")

In [None]:
config_5 = copy.deepcopy(config)
config_5["ENV_PARAMS"]["chi"] = chis[1]

config_5

In [None]:
single_train_chi_2 = jit(PPO_Clipped_make_train(config_5), static_argnums=-1)

print(f"Starting a Run of {num_updates} Updates")
start = time.time()
single_result_c2 = single_train_chi_2(
    _rng, 
    num_envs)
end = time.time()
print(f"time taken: {end - start}")

In [None]:
config_6 = copy.deepcopy(config)
config_6["ENV_PARAMS"]["chi"] = chis[2]

config_6

In [None]:
single_train_chi_3 = jit(PPO_Clipped_make_train(config_6), static_argnums=-1)

print(f"Starting a Run of {num_updates} Updates")
start = time.time()
single_result_c3 = single_train_chi_3(
    _rng, 
    num_envs)
end = time.time()
print(f"time taken: {end - start}")

In [None]:
metric_c1 = single_result_c1["metrics"]
metric_c2 = single_result_c2["metrics"]
metric_c3 = single_result_c3["metrics"]

actions_c1 = metric_c1["action"]
actions_c2 = metric_c2["action"]
actions_c3 = metric_c3["action"]

In [None]:
raw_final_action_c1 = actions_c1[-1, 0]
raw_final_action_c2 = actions_c2[-1, 0]
raw_final_action_c3 = actions_c3[-1, 0]

raw_smooth_action_c1 = env.prepare_action(raw_final_action_c1)
raw_smooth_action_c2 = env.prepare_action(raw_final_action_c2)
raw_smooth_action_c3 = env.prepare_action(raw_final_action_c3)

In [None]:
# Comparing Chi Waveforms

plt.plot(env.ts_action, raw_smooth_action_c1, label='Low Chi')
plt.plot(env.ts_action, raw_smooth_action_c2, label='Normal Chi')
plt.plot(env.ts_action, raw_smooth_action_c3, label='High Chi')
plt.legend()
plt.xlabel('Time (us)')
plt.ylabel('Amplitude (A.U.)')
plt.show()