In [2]:
# Pick up edits to the project modules (rsm_learner, rsm_verifier, rsm_loop, ...)
# live, without restarting the kernel: mode 2 reloads modules before running code.
%load_ext autoreload
%autoreload 2

In [17]:
import argparse
import os
import sys
import time

import jax
import jax.numpy as jnp
import jax.random
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from gymnasium import spaces
from tqdm import tqdm

from klax import lipschitz_l1_jax, triangular
from rl_environments import Vandelpol
# from rl_environments import LDSEnv, InvertedPendulum, CollisionAvoidanceEnv, Vandelpol
from rsm_learner import Learner
from rsm_loop import RSMLoop
from rsm_verifier import Verifier

In [4]:
def parse_bool_flag(value):
    """Convert a command-line token such as "true"/"false"/"1"/"0" to a bool.

    argparse has no built-in boolean type: an option without ``type`` keeps
    whatever string it receives, so ``--square_l_output False`` would yield the
    truthy string ``"False"``. Values that are already ``bool`` (e.g. a ``True``
    default, which argparse does not run through ``type``) are returned as-is.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if token in {"0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Arguments this notebook "passes on the command line". They are handed to
# parse_args() explicitly instead of overwriting sys.argv, which in a Jupyter
# kernel typically holds the kernel's own launch arguments.
NOTEBOOK_ARGV = ["--env", "vandelpol"]

parser = argparse.ArgumentParser(
    prog='System Normalization for Neural Certificates',
    description='Learning a reach-avoid supermartingale neural network with normalization techniques',
    epilog='By Musan (Hao Wu) @ SKLCS, Institute of Software, UCAS'
)
# problem formulation
parser.add_argument("--env", default="vandelpol", help='control system')
parser.add_argument("--timeout", default=60, type=int, help='max time limit in minutes')
parser.add_argument("--reach_prob", default=0.8, type=float, help='reach-avoid probability')

# neural network and training
parser.add_argument("--hidden", default=128, type=int, help='hidden neurons in each layer')
parser.add_argument("--num_layers", default=2, type=int, help='number of hidden layers')
# learner
parser.add_argument("--continue_rsm", type=int, default=0, help='use an existing network')

# verifier
# TODO(review): document the next three; consumption points are noted per flag.
parser.add_argument("--eps", default=0.05, type=float)  # consumed as Learner(eps=...)
parser.add_argument("--lip", default=0.01, type=float)  # consumed as RSMLoop(lip_factor=...)
# parser.add_argument("--p_lip", default=0.0, type=float)
parser.add_argument("--l_lip", default=4.0, type=float)  # consumed as Learner(l_lip=...)
parser.add_argument("--fail_check_fast", type=int, default=0)
parser.add_argument("--grid_factor", default=1.0, type=float)
parser.add_argument("--batch_size", default=512, type=int)
# parser.add_argument("--ppo_iters", default=50, type=int)
# parser.add_argument("--policy", default="policies/lds0_zero.jax")
# parser.add_argument("--train_p", type=int, default=1)
# Forwarded to Learner(softplus_l_output=...); parsed as a real boolean.
parser.add_argument("--square_l_output", default=True, type=parse_bool_flag)
parser.add_argument("--jitter_grid", type=int, default=0)
parser.add_argument("--soft_constraint", type=int, default=1)
parser.add_argument("--gamma_decrease", default=1.0, type=float)
parser.add_argument("--debug_k0", action="store_true")
parser.add_argument("--gen_plot", action="store_true")
parser.add_argument("--no_refinement", action="store_true")
parser.add_argument("--plot", action="store_true")
parser.add_argument("--small_mem", action="store_true")
args = parser.parse_args(NOTEBOOK_ARGV)

In [5]:
# Environments this notebook can build, keyed by the --env value. Previously
# Vandelpol was hard-coded while args.env was still used to name plot files, so
# any other --env silently trained Vandelpol under a misleading name. Extend
# this mapping when further environments from rl_environments are imported.
ENV_BUILDERS = {"vandelpol": Vandelpol}
if args.env not in ENV_BUILDERS:
    raise ValueError(
        f"Unsupported --env {args.env!r}; expected one of {sorted(ENV_BUILDERS)}"
    )
env = ENV_BUILDERS[args.env]()

learner = Learner(
    env=env,
    # --num_layers hidden layers, each --hidden units wide.
    l_hidden=[args.hidden] * args.num_layers,
    l_lip=args.l_lip,
    eps=args.eps,
    gamma_decrease=args.gamma_decrease,
    reach_prob=args.reach_prob,
    # NOTE(review): the CLI flag is --square_l_output but the Learner kwarg is
    # softplus_l_output; confirm which output transform is actually applied.
    softplus_l_output=args.square_l_output,
)

  gym.logger.warn(f"Box bound precision lowered by casting to {self.dtype}")


In [6]:
# --fail_check_fast arrives from the CLI as an int flag; hand the Verifier a bool.
fail_check_fast_flag = bool(args.fail_check_fast)

verifier = Verifier(
    env=env,
    learner=learner,
    reach_prob=args.reach_prob,
    batch_size=args.batch_size,
    grid_factor=args.grid_factor,
    fail_check_fast=fail_check_fast_flag,
    small_mem=args.small_mem,
)

In [14]:
# Wire the learner, verifier and environment into the RSM training loop.
loop = RSMLoop(
    learner,
    verifier,
    env,
    lip_factor=args.lip,
    plot=args.plot,
    # train_p=bool(args.train_p),  # disabled: --train_p is commented out in the parser
    # TODO(review): document what jitter_grid / soft_constraint control in RSMLoop.
    jitter_grid=bool(args.jitter_grid),
    soft_constraint=bool(args.soft_constraint),
)

# Plot before training starts. plot_l presumably writes the figure to this
# path, so make sure the output directory exists on a fresh checkout.
PLOT_DIR = "plots"
os.makedirs(PLOT_DIR, exist_ok=True)
# txt_return = learner.evaluate_rl()
loop.plot_l(f"{PLOT_DIR}/{args.env}_start.png")

In [15]:
# Run the learn/verify loop within the --timeout budget, converted from
# minutes (per the CLI help text) to seconds.
# NOTE(review): the recorded run of this cell failed with "TypeError:
# differentiating with respect to argnums=(0, 1) requires at least 2 positional
# arguments ..." raised from code outside this notebook; fix it at the source.
timeout_seconds = 60 * args.timeout
loop.run(timeout_seconds)

Creating pre-train buffer ...  [done]

## Iteration 0 (0:00 elapsed) ##


  0%|          | 0/200 [00:00<?, ?it/s]

TypeError: differentiating with respect to argnums=(0, 1) requires at least 2 positional arguments to be passed by the caller, but got only 1 positional arguments.

In [22]:
# Inspect the jittered domain grid produced by the existing verifier instance.
# (Verifier is already imported in the imports cell; re-importing it here was
# redundant and would not refresh the `verifier` object anyway.)
# NOTE(review): the boolean array printed by this call looks like a leftover
# debug print inside get_domain_jitter_grid; confirm and remove it there.
verifier.get_domain_jitter_grid(10)

[False False False False False False False False False False False False
  True False False False False False False False False False False False
 False]


(Array([[-1.6000000e+00, -1.6000000e+00],
        [-8.0000007e-01, -1.6000000e+00],
        [-2.9802322e-08, -1.6000000e+00],
        [ 8.0000007e-01, -1.6000000e+00],
        [ 1.6000000e+00, -1.6000000e+00],
        [-1.6000000e+00, -8.0000007e-01],
        [-8.0000007e-01, -8.0000007e-01],
        [-2.9802322e-08, -8.0000007e-01],
        [ 8.0000007e-01, -8.0000007e-01],
        [ 1.6000000e+00, -8.0000007e-01],
        [-1.6000000e+00, -2.9802322e-08],
        [-8.0000007e-01, -2.9802322e-08],
        [ 8.0000007e-01, -2.9802322e-08],
        [ 1.6000000e+00, -2.9802322e-08],
        [-1.6000000e+00,  8.0000007e-01],
        [-8.0000007e-01,  8.0000007e-01],
        [-2.9802322e-08,  8.0000007e-01],
        [ 8.0000007e-01,  8.0000007e-01],
        [ 1.6000000e+00,  8.0000007e-01],
        [-1.6000000e+00,  1.6000000e+00],
        [-8.0000007e-01,  1.6000000e+00],
        [-2.9802322e-08,  1.6000000e+00],
        [ 8.0000007e-01,  1.6000000e+00],
        [ 1.6000000e+00,  1.600000