In [80]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [168]:
import argparse
import os
import sys
import jax
import jax.numpy as jnp
import jax.random
from gymnasium import spaces
from tqdm import tqdm

from klax import lipschitz_l1_jax, triangular
from rl_environments import Vandelpol
# from rl_environments import LDSEnv, InvertedPendulum, CollisionAvoidanceEnv, Vandelpol
from rsm_learner import Learner
from rsm_verifier import Verifier
import time
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

In [82]:
def _str2bool(value):
    """Convert a command-line token into a real boolean.

    argparse hands option values over as strings, and every non-empty string
    (including "False" and "0") is truthy. This converter maps the usual
    spellings onto True/False and rejects anything else.

    Args:
        value: A bool (returned unchanged) or a string such as "true"/"0".

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if token in {"0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser(
    prog='System Normalization for Neural Certificates',
    description='Learning a reach-avoid supermartingale neural network with normalization techniques',
    epilog='By Musan (Hao Wu) @ SKLCS, Institute of Software, UCAS'
)

# --- problem formulation ---
parser.add_argument("--env", default="vandelpol", help='control system')
parser.add_argument("--timeout", default=60, type=int, help='max time limit in minutes')
parser.add_argument("--reach_prob", default=0.8, type=float, help='reach-avoid probability')

# --- neural network and training ---
parser.add_argument("--hidden", default=128, type=int, help='hidden neurons in each layer')
parser.add_argument("--num_layers", default=2, type=int, help='number of hidden layers')

# --- learner ---
parser.add_argument("--continue_rsm", type=int, default=0, help='use an existing network')

# --- verifier ---
# TODO: document the next three options; in this notebook they are forwarded as
#   --eps   -> Learner(eps=...)
#   --lip   -> RSMLoop(lip_factor=...)
#   --l_lip -> Learner(l_lip=...)
parser.add_argument("--eps", default=0.05, type=float)
parser.add_argument("--lip", default=0.01, type=float)
# parser.add_argument("--p_lip", default=0.0, type=float)
parser.add_argument("--l_lip", default=4.0, type=float)
parser.add_argument("--fail_check_fast", type=int, default=0)
parser.add_argument("--grid_factor", default=1.0, type=float)
parser.add_argument("--batch_size", default=512, type=int)
# parser.add_argument("--ppo_iters", default=50, type=int)
# parser.add_argument("--policy", default="policies/lds0_zero.jax")
# parser.add_argument("--train_p", type=int, default=1)
# Parsed through _str2bool so that "--square_l_output False" really disables it;
# without a type the string "False" would be stored and evaluate as truthy.
parser.add_argument("--square_l_output", default=True, type=_str2bool)
parser.add_argument("--jitter_grid", type=int, default=0)
parser.add_argument("--soft_constraint", type=int, default=1)
parser.add_argument("--gamma_decrease", default=1.0, type=float)
parser.add_argument("--debug_k0", action="store_true")
parser.add_argument("--gen_plot", action="store_true")
parser.add_argument("--no_refinement", action="store_true")
parser.add_argument("--plot", action="store_true")
parser.add_argument("--small_mem", action="store_true")

# Inside a notebook sys.argv belongs to the kernel launcher, so the options are
# passed to parse_args explicitly rather than by overwriting the global sys.argv.
NOTEBOOK_ARGV = ['--env', 'vandelpol']
args = parser.parse_args(NOTEBOOK_ARGV)

In [154]:
# Create the environment and the learner that is configured from the parsed
# command-line options.
env = Vandelpol()

# One width entry per hidden layer, all of the same size.
hidden_layer_sizes = [args.hidden] * args.num_layers

learner = Learner(
    env=env,
    l_hidden=hidden_layer_sizes,
    l_lip=args.l_lip,
    eps=args.eps,
    gamma_decrease=args.gamma_decrease,
    reach_prob=args.reach_prob,
    # NOTE(review): the option is called --square_l_output but feeds the
    # softplus_l_output keyword; confirm the two names mean the same thing.
    softplus_l_output=args.square_l_output,
)

In [159]:
# Build the verifier around the learner and environment created above.
# --fail_check_fast is parsed as an int, so it is converted to a bool first.
fail_fast = bool(args.fail_check_fast)

verifier = Verifier(
    learner=learner,
    env=env,
    batch_size=args.batch_size,
    reach_prob=args.reach_prob,
    fail_check_fast=fail_fast,
    grid_factor=args.grid_factor,
    small_mem=args.small_mem,
)

In [160]:
from rsm_loop import RSMLoop

# Tie learner, verifier and environment together in the loop object.
# The int-valued CLI switches are converted to booleans before being passed on.
use_jitter_grid = bool(args.jitter_grid)
use_soft_constraint = bool(args.soft_constraint)

loop = RSMLoop(
    learner,
    verifier,
    env,
    lip_factor=args.lip,
    plot=args.plot,
    jitter_grid=use_jitter_grid,
    soft_constraint=use_soft_constraint,
)

# Disabled calls kept from the original cell:
# txt_return = learner.evaluate_rl()
# loop.plot_l(f"plots/{args.env}_start.png")

In [86]:
# Start the loop with a budget derived from --timeout.
# NOTE(review): --timeout is described as minutes, but neither the reason for
# the factor 10 nor the unit expected by RSMLoop.run is visible here - confirm.
run_budget = args.timeout * 10
loop.run(run_budget)


## Iteration 0 (0:00 elapsed) ##
use train buffer.


Train: loss=43.4, dec_loss=34, violations=0.504: 100%|██████████| 100/100 [01:12<00:00,  1.37it/s]


Trained on 40000 samples, start_loss=7.3, end_loss=43.4, start_violations=0.499, end_violations=0.504 in 1.2 minutes


100%|██████████| 1/1 [00:02<00:00,  2.51s/it]
100%|██████████| 74/74 [00:07<00:00, 10.37it/s]


lipschitz_k=1932.7548828125 (without delta)
delta=0.008016032064128256
K=15.49302511272545 (with delta)
Checking GRID of size 500


100%|██████████| 1/1 [02:53<00:00, 173.72s/it]


violations=38920
hard_violations=38636
38920/250000 violated decrease condition
38636/250000 hard violations
Train buffer len: 78636
Grid runtime=3 min
info= {'ds_size': 40000, 'lipschitz_k': 1932.7548828125, 'K_f': 4.0, 'K_l': 483.188720703125, 'iter': 0, 'runtime': 0.016546249389648438, 'delta': 0.008016032064128256, 'K': 15.49302511272545, 'avg_increase': 0, 'dec_violations': '38920/250000', 'hard_violations': '38636/250000'}

## Iteration 1 (4:20 elapsed) ##
use train buffer.




KeyboardInterrupt: 

In [172]:
# Partition [-1, 1) along each of the two axes into equal-width cells.
# endpoint=False makes every sample the lower edge of a cell, and retstep=True
# also returns the cell width.
# NOTE(review): the recorded output of this cell shows 10 samples per axis,
# while the code requests 2 - re-run the cell to refresh the output.
axis_specs = [
    jnp.linspace(-1.0, 1.0, 2, endpoint=False, retstep=True)
    for _ in range(2)
]
grid = [axis for axis, _ in axis_specs]
steps = [width for _, width in axis_specs]

print(grid)

# Flattened lower corners of all cells (one array per axis); the upper corners
# are the lower corners shifted by the cell width of the matching axis.
grid_lb = [mesh.flatten() for mesh in jnp.meshgrid(*grid)]
grid_ub = [lower + width for lower, width in zip(grid_lb, steps)]
print(grid_lb)

# Presumably the noise probability mass falling into each cell - TODO confirm
# against Vandelpol.integrate_noise.
pmass = env.integrate_noise(grid_lb, grid_ub)

[Array([-1.        , -0.79999995, -0.6       , -0.39999998, -0.20000002,
        0.        ,  0.20000005,  0.39999998,  0.6       ,  0.8000001 ],      dtype=float32), Array([-1.        , -0.79999995, -0.6       , -0.39999998, -0.20000002,
        0.        ,  0.20000005,  0.39999998,  0.6       ,  0.8000001 ],      dtype=float32)]
[0.02       0.06       0.09999999 0.14       0.18000004 0.18000004
 0.14000002 0.09999998 0.05999998 0.01999999 0.02       0.06
 0.09999999 0.14       0.18000004 0.18000004 0.14000002 0.09999998
 0.05999998 0.01999999 0.02       0.06       0.09999999 0.14
 0.18000004 0.18000004 0.14000002 0.09999998 0.05999998 0.01999999
 0.02       0.06       0.09999999 0.14       0.18000004 0.18000004
 0.14000002 0.09999998 0.05999998 0.01999999 0.02       0.06
 0.09999999 0.14       0.18000004 0.18000004 0.14000002 0.09999998
 0.05999998 0.01999999 0.02       0.06       0.09999999 0.14
 0.18000004 0.18000004 0.14000002 0.09999998 0.05999998 0.01999999
 0.02       0.06     

In [173]:
# Total of the per-cell values returned by env.integrate_noise above; shown as
# the cell output (the recorded run displays 1.0).
jnp.sum(pmass)

Array(1., dtype=float32)