In [205]:
%load_ext autoreload
%autoreload 2
import argparse
import os
import sys
import jax
import jax.numpy as jnp
import gymnasium as gym
# import jax.random
# from gymnasium import spaces
from tqdm import tqdm

from klax import lipschitz_l1_jax, triangular
from rl_environments import Vandelpol, Poly1, LinearLQR, NonPoly1
# from rl_environments import LDSEnv, InvertedPendulum, CollisionAvoidanceEnv, Vandelpol
from rsm_learner import Learner
from rsm_verifier import Verifier
import time
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from rsm_loop import RSMLoop
from klax import (
    project,
    inverse_project,
    v_project,
    v_inverse_project,
    jax_save,
    jax_load,
    lipschitz_l1_jax,
    # martingale_loss,
    # triangular,
    IBPMLP,
    MLP,
    create_train_state,
    # zero_at_zero_loss,
    clip_grad_norm,
)


def _str2bool(value):
    """Convert a command-line token into a real ``bool``.

    ``argparse`` calls ``type`` on the raw string, and ``bool("False")`` is
    ``True`` (every non-empty string is truthy), so ``type=bool`` cannot be
    used to switch a flag off.  This converter understands the usual textual
    spellings instead.

    Args:
        value: The raw token (a ``str``), or an actual ``bool``.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in {"1", "true", "t", "yes", "y", "on"}:
        return True
    # The empty string is kept falsy, matching what ``bool("")`` used to return.
    if token in {"", "0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser(
    prog='System Normalization for Neural Certificates',
    description='Learning a reach-avoid supermartingale neural network with normalization techniques',
    epilog='By Musan (Hao Wu) @ SKLCS, Institute of Software, UCAS'
)
parser.add_argument("--env", default="1", help='dynamical system')
parser.add_argument("--timeout", default=60, type=int, help='max time limit in minutes')
parser.add_argument("--reach_prob", default=0.9, type=float, help='reach-avoid probability')
parser.add_argument("--batch_size", default=512, type=int, help='batch size in training and verification')

# learner
parser.add_argument("--hidden", default=16, type=int, help='hidden neurons in each layer')
parser.add_argument("--num_layers", default=2, type=int, help='number of hidden layers')
# Parsed as a boolean so that "--square_l_output False" is not a truthy string.
parser.add_argument("--square_l_output", default=True, type=_str2bool, help='use square activation in the last layer')
parser.add_argument("--continue_rsm", type=int, default=0, help='use an existing network')

# verifier
parser.add_argument("--eps", default=0.05, type=float, help='epsilon in the RASM condition')
parser.add_argument("--lip", default=0.1, type=float, help='regularization term for lipschitz constant')
parser.add_argument("--l_lip", default=1.0, type=float, help='target lipschitz constant of the neural network')
parser.add_argument("--grid_size", default=200, type=int, help='grid size for verification')

# Boolean switches: ``type=bool`` would turn "False" into True, hence _str2bool.
parser.add_argument("--normalize", default=False, type=_str2bool, help='Nomalize the system')
parser.add_argument("--debug", default=True, type=_str2bool)

# Overwrite sys.argv so that parse_args() sees exactly these arguments
# (script name plus "--env 1") instead of the interpreter's own command line.
sys.argv = ['test.ipynb', '--env', '1']
args = parser.parse_args()

# Environment under study.
# NOTE(review): args.env is not consulted when building the environment,
# but it is used below for the checkpoint path — confirm this is intended.
env = NonPoly1(difficulty=5)

# Flags that are forwarded unchanged to the learner, the verifier and the loop.
shared_flags = {"normalize": args.normalize, "debug": args.debug}

# Learner: one entry of width args.hidden per hidden layer.
learner = Learner(
    env=env,
    l_hidden=[args.hidden for _ in range(args.num_layers)],
    l_lip=args.l_lip,
    eps=args.eps,
    reach_prob=args.reach_prob,
    softplus_l_output=True,
    **shared_flags,
)

# Verifier built around the learner and the same environment.
verifier = Verifier(
    learner=learner,
    env=env,
    batch_size=args.batch_size,
    reach_prob=args.reach_prob,
    grid_size=args.grid_size,
    **shared_flags,
)

# Optionally resume from a previously saved network.
if args.continue_rsm > 0:
    learner.load(f"saved/{args.env}_loop.jax")

# Loop object tying learner, verifier and environment together.
loop = RSMLoop(
    learner,
    verifier,
    env,
    lip_factor=args.lip,
    **shared_flags,
)

# loop.plot_l(f"plots/{args.env}_start.png")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [203]:
# Call the verifier's compute_bound_l on the environment's unsafe spaces.
# The recorded output is a pair of floats (presumably lower/upper bound of
# the learned function on that set — confirm in rsm_verifier).
verifier.compute_bound_l(env.unsafe_spaces)

100%|██████████| 1/1 [00:00<00:00,  2.18it/s]


(3.6276910305023193, 4.719462871551514)

In [202]:
# Restore the saved network before evaluating anything.
learner.load("saved/NonPoly1_loop.jax")

# Projected corners of the first unsafe box, each wrapped in a one-element list.
# (Alternative kept for reference: the raw, unprojected .low / .high corners.)
unsafe_box = env.unsafe_spaces[0]
lb, ub = [project(unsafe_box.low)], [project(unsafe_box.high)]

# Build the box(es), map them through v_get_y_box, then bound on the result.
x_boxes = verifier.make_box(lb, ub)
y_lb, y_ub = verifier.v_get_y_box(x_boxes)
print(y_lb, y_ub)
verifier.compute_bounds_on_set(y_lb, y_ub)

# verifier.compute_bounds_on_set(jnp.array([lb]), jnp.array([ub]))

[[-0.35998467 -0.3478004 ]] [[-0.13042949  0.2105075 ]]


100%|██████████| 1/1 [00:00<00:00, 135.91it/s]


(1.907335877418518, 3.9509012699127197)

In [194]:
# Restore the saved network, then bound it on the projected first initial box.
learner.load("saved/NonPoly1_loop.jax")

# (Alternative kept for reference: the raw corners of env.unsafe_spaces[0].)
init_box = env.init_spaces[0]
lb, ub = project(init_box.low), project(init_box.high)

# Each corner is wrapped in a batch of size one before being handed over.
verifier.compute_bounds_on_set(jnp.array([lb]), jnp.array([ub]))

100%|██████████| 1/1 [00:00<00:00, 56.33it/s]


(0.757574200630188, 0.8432248830795288)

In [171]:
# Bound the current network on the projected first initial box
# (no checkpoint is loaded in this cell).
first_init_space = env.init_spaces[0]
lb = project(first_init_space.low)
ub = project(first_init_space.high)
lower_batch, upper_batch = jnp.array([lb]), jnp.array([ub])
verifier.compute_bounds_on_set(lower_batch, upper_batch)

100%|██████████| 1/1 [00:00<00:00, 84.70it/s]


(0.3501799702644348, 0.836003839969635)

In [148]:
# Empirical check: smallest network output over 256 projected unsafe samples.
rng = jax.random.PRNGKey(0)
unsafe_samples = v_project(learner.sample_unsafe(rng, 256))

# Evaluate the learner's network with its current parameters.
l_state = learner.l_state
l_at_unsafe = l_state.apply_fn(l_state.params, unsafe_samples)

min_at_unsafe = jnp.min(l_at_unsafe)
print("sample min unsafe:", min_at_unsafe)

sample min unsafe: 0.72276515


In [144]:
# Scratch check: np.maximum of the scalars 2 and 1.
a = np.maximum(2, 1)
a

2

In [207]:
# Fixed seed so the draw below is repeatable.
rng = jax.random.PRNGKey(0)
# Ask the learner for 5 initial-state samples; the recorded output is a
# (5, 2) float32 array.
learner.sample_init(rng, 5)

Array([[ 0.5000663 ,  3.169514  ],
       [ 0.28297544,  3.027704  ],
       [ 1.1871221 ,  3.3483384 ],
       [ 0.21701396,  3.102299  ],
       [-0.8663714 ,  3.4614828 ]], dtype=float32)

In [187]:
# Call the verifier's compute_bound_l on the environment's initial spaces.
# The recorded output is a pair of floats (presumably lower/upper bound —
# confirm in rsm_verifier).
verifier.compute_bound_l(env.init_spaces)

[-0.16666667  0.14285715]


100%|██████████| 1/1 [00:00<00:00, 95.31it/s]


(0.31266385316848755, 0.9417946338653564)

In [213]:
# Draw a single initial-state sample. `rng` is not created in this cell, so
# it must already be bound (to a JAX key, presumably — confirm) by an earlier cell.
s = learner.sample_init(rng, 1)
# Collapse the sample to 1-D for display.
s.flatten()

Array([1.0598719, 3.6419444], dtype=float32)

In [210]:
# Draw one state uniformly from the first initial box, with the box corners
# clamped to [-10000, 10000]. `rng` must already be bound by an earlier cell.
init_box = env.init_spaces[0]
clamped_low = np.maximum(init_box.low, -10000)
clamped_high = np.minimum(init_box.high, 10000)
state = jax.random.uniform(rng, 2, minval=clamped_low, maxval=clamped_high)
print(state)

[-0.7092616  3.505155 ]


In [221]:
# Scratch check: a three-element list filled with negative infinity.
low = [-np.inf for _ in range(3)]
print(low)

[-inf, -inf, -inf]


In [219]:
# Scratch check: scale the pair (-0.25, 0.5) by a common factor.
scale = 5
a = [factor * scale for factor in (-0.25, 0.5)]
a

[-1.25, 2.5]

In [218]:
# Apply klax.project to the list `a` (bound by an earlier cell) after
# converting it to a JAX array; the result is only displayed.
project(jnp.array(a))

Array([-0.2631579,  0.5263158], dtype=float32)

In [225]:
# Scratch check: size of the trailing axis of a 2x2 integer array.
a = jnp.array(((1, 2), (3, 4)))
a.shape[-1]

2

In [227]:
# Derive two subkeys from a fixed root key (split yields two keys by default).
rng = jax.random.PRNGKey(0)
r1, r2 = jax.random.split(rng)

In [228]:
# Display r1; it is not created in this cell and must already be bound.
r1

Array([4146024105,  967050713], dtype=uint32)

In [230]:
# Split `rng` (bound by an earlier cell) into 4 keys; the recorded output is
# a (4, 2) uint32 array. The result is only displayed, not stored.
jax.random.split(rng, 4)

Array([[2285895361, 1501764800],
       [1518642379, 4090693311],
       [ 433833334, 4221794875],
       [ 839183663, 3740430601]], dtype=uint32)

In [238]:
# Seeded NumPy generator; note that `rng` is rebound here to a
# numpy Generator rather than a JAX key.
rng = np.random.default_rng(1)
# First float drawn from the fresh generator.
rfloat = rng.random()
print(rfloat)

0.5118216247002567


In [239]:
# Draw the next float from `rng`. The generator is not created in this cell,
# so the value printed depends on how many draws were made before it.
rfloat = rng.random()
print(rfloat)

0.9504636963259353


In [240]:
# Another draw from the same `rng` (bound by an earlier cell); each call
# advances the generator, so re-running this cell prints a different value.
rfloat = rng.random()
print(rfloat)

0.14415961271963373


In [248]:
# Scratch check: permutation of 1..9 from a freshly seeded generator.
s = list(range(1, 10))
rng = np.random.default_rng(1)
rng.permutation(s)

array([8, 1, 2, 5, 3, 6, 9, 7, 4])

In [247]:
# Permute `s` again with the existing `rng` (both bound by an earlier cell);
# the generator has advanced, so the ordering differs from a fresh seed.
rng.permutation(s)

array([2, 8, 7, 1, 3, 4, 5, 6, 9])

In [251]:
import tensorflow as tf

# Bind the Dataset to its own name instead of overwriting `s`: rebinding `s`
# made this cell self-referential (a second run would feed a Dataset back
# into from_tensor_slices) and changed what earlier cells that use the
# plain list `s` would see on re-run.
s_dataset = tf.data.Dataset.from_tensor_slices(s)
# shuffle() returns a new dataset; it is only displayed here, not stored.
s_dataset.shuffle(buffer_size=9, seed=1)

<_ShuffleDataset element_spec=TensorSpec(shape=(), dtype=tf.int32, name=None)>

In [262]:
rng = jax.random.PRNGKey(0)
# Stored as `rand_ints`, not `int`: binding the name `int` at notebook scope
# shadows the builtin for every later cell (e.g. `type=int` in the argparse
# cell would then receive an array instead of the integer type).
rand_ints = jax.random.randint(rng, minval=0, maxval=9, shape=(2,))
print(78)
print(rand_ints[0])


6
3


In [266]:
# Construct a NumPy Generator seeded with 78; it is only displayed,
# not bound to a name.
np.random.default_rng(78)

Generator(PCG64) at 0x348209700

In [269]:
# Natural logarithm of 2.0 (recorded output: 0.6931472, float32).
jnp.log(2.0)

Array(0.6931472, dtype=float32, weak_type=True)

In [274]:
# Uniform CDF evaluated at x=2 with loc=0, scale=2 (recorded output: 1.0).
jax.scipy.stats.uniform.cdf(2, loc=0, scale=2)

Array(1., dtype=float32)

In [276]:
# Uniform CDF evaluated at x=0 with loc=-1, scale=2 (recorded output: 0.5).
jax.scipy.stats.uniform.cdf(0, loc=-1, scale=2)


Array(0.5, dtype=float32)

In [278]:
# Normal CDF evaluated at x=-1 with loc=0, scale=1 (recorded output: 0.15865527).
jax.scipy.stats.norm.cdf(-1, loc=0, scale=1)

Array(0.15865527, dtype=float32)

In [279]:
# Scratch check of broadcasting: a row of shape (4,) compared against a
# column of shape (4, 1) gives a (4, 4) boolean matrix.
a = jnp.array([4, 3, 2, 1])
b = jnp.array([1, 2, 3, 4])[:, None]
b >= a

Array([[False, False, False,  True],
       [False, False,  True,  True],
       [False,  True,  True,  True],
       [ True,  True,  True,  True]], dtype=bool)