In [1]:
# Re-import edited modules (e.g. the project's rsm_learner / rsm_verifier / rsm_loop)
# automatically before each cell executes, so changes are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
import argparse
import os
import sys
import jax
import jax.numpy as jnp
# import jax.random
# from gymnasium import spaces
from tqdm import tqdm

from klax import lipschitz_l1_jax, triangular
from rl_environments import Vandelpol, Test2
# from rl_environments import LDSEnv, InvertedPendulum, CollisionAvoidanceEnv, Vandelpol
from rsm_learner import Learner
from rsm_verifier import Verifier
import time
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

In [3]:
def _str2bool(value):
    """Convert a command-line token such as ``"False"`` or ``"yes"`` into a bool.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy, so ``--debug False`` would silently
    leave debugging enabled.

    Args:
        value: The raw command-line token (or an already-parsed bool).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if token in {"0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean (true/false), got {value!r}")


parser = argparse.ArgumentParser(
    prog='System Normalization for Neural Certificates',
    description='Learning a reach-avoid supermartingale neural network with normalization techniques',
    epilog='By Musan (Hao Wu) @ SKLCS, Institute of Software, UCAS'
)

parser.add_argument("--debug", default=True, type=_str2bool)

# problem formulation
# Systems imported in this notebook: "test2" (Test2) and "vandelpol" (Vandelpol).
parser.add_argument("--env", default="test2", help='system')
parser.add_argument("--timeout", default=60, type=int, help='max time limit in minutes')
parser.add_argument("--reach_prob", default=0.8, type=float, help='reach-avoid probability')
parser.add_argument("--project", default=False, type=_str2bool, help='Normalize the system')

# neural network and training
parser.add_argument("--hidden", default=128, type=int, help='hidden neurons in each layer')
parser.add_argument("--num_layers", default=2, type=int, help='number of hidden layers')

# learner
parser.add_argument("--continue_rsm", type=int, default=0, help='use an existing network')

# verifier
# TODO(review): the next three values were undocumented ("???"); describe their
# meaning once confirmed against the Learner / RSMLoop implementations.
parser.add_argument("--eps", default=0.01, type=float)    # passed to Learner(eps=...)
parser.add_argument("--lip", default=0.1, type=float)     # passed to RSMLoop(lip_factor=...)
parser.add_argument("--l_lip", default=10.0, type=float)  # passed to Learner(l_lip=...)
parser.add_argument("--fail_check_fast", type=int, default=0)
parser.add_argument("--grid_factor", default=1.0, type=float)
parser.add_argument("--batch_size", default=512, type=int)
# Without a converter, any value given on the command line would be a truthy string.
parser.add_argument("--square_l_output", default=True, type=_str2bool)
parser.add_argument("--jitter_grid", type=int, default=0)
parser.add_argument("--soft_constraint", type=int, default=1)
parser.add_argument("--gamma_decrease", default=1.0, type=float)
parser.add_argument("--debug_k0", action="store_true")
parser.add_argument("--gen_plot", action="store_true")
parser.add_argument("--no_refinement", action="store_true")
parser.add_argument("--plot", action="store_true")
parser.add_argument("--small_mem", action="store_true")

# Parse an explicit argument list instead of overwriting sys.argv: inside a Jupyter
# kernel sys.argv holds the kernel launcher's own arguments, which this parser would
# reject. Pass e.g. ["--env", "vandelpol"] here to override options from the notebook.
args = parser.parse_args(args=[])

In [4]:
# Build the environment named by --env so the system being trained and the
# filenames derived from args.env cannot drift apart.
ENVIRONMENTS = {
    "vandelpol": Vandelpol,
    "test2": Test2,
}
if args.env not in ENVIRONMENTS:
    raise ValueError(f"Unknown --env {args.env!r}; expected one of {sorted(ENVIRONMENTS)}")
env = ENVIRONMENTS[args.env]()

learner = Learner(
    env=env,
    l_hidden=[args.hidden] * args.num_layers,  # num_layers hidden layers, `hidden` units each
    l_lip=args.l_lip,
    eps=args.eps,
    gamma_decrease=args.gamma_decrease,
    reach_prob=args.reach_prob,
    # NOTE(review): the CLI flag is --square_l_output but Learner's parameter is
    # softplus_l_output; confirm which output transform is actually intended.
    softplus_l_output=args.square_l_output,
    debug=args.debug,
)

In [5]:
# Verifier that checks the learned certificate on a grid over the environment.
verifier = Verifier(
    # what to verify
    learner=learner,
    env=env,
    reach_prob=args.reach_prob,
    # grid / batching configuration
    batch_size=args.batch_size,
    grid_factor=args.grid_factor,
    small_mem=args.small_mem,
    # flags (--fail_check_fast is an int on the CLI; non-zero means enabled)
    fail_check_fast=args.fail_check_fast != 0,
    debug=args.debug,
)

In [6]:
from rsm_loop import RSMLoop

# Learner/verifier loop; judging by its output below, each iteration trains the
# learner and then checks the decrease condition on a grid.
# The int-valued CLI flags are converted to bools here.
loop = RSMLoop(
    learner,
    verifier,
    env,
    lip_factor=args.lip,
    plot=args.plot,
    jitter_grid=bool(args.jitter_grid),
    soft_constraint=bool(args.soft_constraint),
    debug=args.debug,
)

# plot_l is handed a path under plots/; create the directory first so a fresh
# checkout does not fail with FileNotFoundError (assumes plot_l saves to that path).
os.makedirs("plots", exist_ok=True)
loop.plot_l(f"plots/{args.env}_start.png")

In [None]:
# --timeout is given in minutes; convert to the seconds presumably expected by run().
loop.run(60 * args.timeout)


## Iteration 0 (0:00 elapsed) ##


Train: loss=1.79, dec_loss=1.53, violations=0.561: 100%|██████████| 100/100 [00:23<00:00,  4.20it/s]


Trained on 40000 samples, start_loss=13.9, end_loss=1.79, start_violations=0.512, end_violations=0.561 in 23.8 seconds


100%|██████████| 3/3 [00:01<00:00,  2.59it/s]
100%|██████████| 77/77 [00:00<00:00, 89.01it/s] 


lipschitz_k=11.152777519226074 (without delta)
delta=0.010050251256281407
K=0.11208821627362889 (with delta)
Checking GRID of size 200


100%|██████████| 1/1 [00:15<00:00, 15.84s/it]

violations=24790
hard_violations=24790
24790/40000 violated decrease condition
24790/40000 hard violations
Train buffer len: 64790
Grid runtime=18.62 s
info= {'ds_size': 40000, 'lipschitz_k': 11.152777519226074, 'K_f': 1.08, 'K_l': 10.326645851135254, 'iter': 0, 'runtime': 0.15781497955322266, 'delta': 0.010050251256281407, 'K': 0.11208821627362889, 'avg_increase': 0, 'dec_violations': '24790/40000', 'hard_violations': '24790/40000'}






## Iteration 1 (0:43 elapsed) ##


Train: loss=1.4, dec_loss=1.16, violations=0.577: 100%|██████████| 100/100 [00:36<00:00,  2.73it/s]


Trained on 64790 samples, start_loss=1.38, end_loss=1.4, start_violations=0.606, end_violations=0.577 in 36.6 seconds


100%|██████████| 3/3 [00:00<00:00, 212.57it/s]
100%|██████████| 77/77 [00:00<00:00, 270.26it/s]


lipschitz_k=10.992952194213867 (without delta)
delta=0.010050251256281407
K=0.11048193160013937 (with delta)
Checking GRID of size 200


100%|██████████| 1/1 [00:05<00:00,  5.08s/it]

violations=24731
hard_violations=24731
24731/40000 violated decrease condition
24731/40000 hard violations
Train buffer len: 89521
Grid runtime=5.39 s
info= {'ds_size': 64790, 'lipschitz_k': 10.992952194213867, 'K_f': 1.08, 'K_l': 10.178659439086914, 'iter': 1, 'runtime': 43.37160897254944, 'delta': 0.010050251256281407, 'K': 0.11048193160013937, 'avg_increase': 0, 'dec_violations': '24731/40000', 'hard_violations': '24731/40000'}






## Iteration 2 (1:26 elapsed) ##


Train: loss=1.2, dec_loss=0.993, violations=0.512: 100%|██████████| 100/100 [00:52<00:00,  1.90it/s]


Trained on 89521 samples, start_loss=1.19, end_loss=1.2, start_violations=0.582, end_violations=0.512 in 52.6 seconds


100%|██████████| 3/3 [00:00<00:00, 155.63it/s]
100%|██████████| 77/77 [00:00<00:00, 266.10it/s]


lipschitz_k=11.567951545715333 (without delta)
delta=0.010050251256281407
K=0.11626081955492798 (with delta)
Checking GRID of size 200


100%|██████████| 1/1 [00:09<00:00,  9.34s/it]

violations=25370
hard_violations=25370
25370/40000 violated decrease condition
25370/40000 hard violations
Train buffer len: 114891
Grid runtime=9.67 s
info= {'ds_size': 89521, 'lipschitz_k': 11.567951545715333, 'K_f': 1.08, 'K_l': 10.711066246032715, 'iter': 2, 'runtime': 85.89312410354614, 'delta': 0.010050251256281407, 'K': 0.11626081955492798, 'avg_increase': 0, 'dec_violations': '25370/40000', 'hard_violations': '25370/40000'}






## Iteration 3 (2:29 elapsed) ##


Train: loss=1.07, dec_loss=0.86, violations=0.516: 100%|██████████| 100/100 [01:00<00:00,  1.65it/s] 


Trained on 114891 samples, start_loss=1.06, end_loss=1.07, start_violations=0.515, end_violations=0.516 in 1.0 minutes


100%|██████████| 3/3 [00:00<00:00, 199.73it/s]
100%|██████████| 77/77 [00:00<00:00, 270.66it/s]


lipschitz_k=11.607069740295412 (without delta)
delta=0.010050251256281407
K=0.11665396723914986 (with delta)
Checking GRID of size 200


100%|██████████| 1/1 [00:04<00:00,  4.91s/it]

violations=25551
hard_violations=25551
25551/40000 violated decrease condition
25551/40000 hard violations
Train buffer len: 140442
Grid runtime=5.22 s
info= {'ds_size': 114891, 'lipschitz_k': 11.607069740295412, 'K_f': 1.08, 'K_l': 10.747286796569824, 'iter': 3, 'runtime': 148.77628183364868, 'delta': 0.010050251256281407, 'K': 0.11665396723914986, 'avg_increase': 0, 'dec_violations': '25551/40000', 'hard_violations': '25551/40000'}






## Iteration 4 (3:35 elapsed) ##


Train: loss=1.01, dec_loss=0.776, violations=0.518: 100%|██████████| 100/100 [01:20<00:00,  1.24it/s]


Trained on 140442 samples, start_loss=0.996, end_loss=1.01, start_violations=0.518, end_violations=0.518 in 1.3 minutes


100%|██████████| 3/3 [00:00<00:00, 169.10it/s]
100%|██████████| 77/77 [00:00<00:00, 260.95it/s]


lipschitz_k=11.846925659179687 (without delta)
delta=0.010050251256281407
K=0.1190645794892431 (with delta)
Checking GRID of size 200


100%|██████████| 1/1 [00:02<00:00,  2.05s/it]

violations=25464
hard_violations=25464
25464/40000 violated decrease condition
25464/40000 hard violations
Train buffer len: 165906
Grid runtime=2.38 s
info= {'ds_size': 140442, 'lipschitz_k': 11.846925659179687, 'K_f': 1.08, 'K_l': 10.969375610351562, 'iter': 4, 'runtime': 215.0634059906006, 'delta': 0.010050251256281407, 'K': 0.1190645794892431, 'avg_increase': 0, 'dec_violations': '25464/40000', 'hard_violations': '25464/40000'}






## Iteration 5 (4:59 elapsed) ##


Train: loss=0.903, dec_loss=0.71, violations=0.519: 100%|██████████| 100/100 [01:35<00:00,  1.05it/s] 


Trained on 165906 samples, start_loss=0.909, end_loss=0.903, start_violations=0.516, end_violations=0.519 in 1.6 minutes


100%|██████████| 3/3 [00:00<00:00, 164.19it/s]
100%|██████████| 77/77 [00:00<00:00, 231.80it/s]


lipschitz_k=11.33647647857666 (without delta)
delta=0.010050251256281407
K=0.11393443697061971 (with delta)
Checking GRID of size 200


100%|██████████| 1/1 [00:01<00:00,  1.90s/it]

violations=25383
hard_violations=25383
25383/40000 violated decrease condition
25383/40000 hard violations
Train buffer len: 191289
Grid runtime=2.27 s
info= {'ds_size': 165906, 'lipschitz_k': 11.33647647857666, 'K_f': 1.08, 'K_l': 10.496737480163574, 'iter': 5, 'runtime': 298.55774784088135, 'delta': 0.010050251256281407, 'K': 0.11393443697061971, 'avg_increase': 0, 'dec_violations': '25383/40000', 'hard_violations': '25383/40000'}






## Iteration 6 (6:37 elapsed) ##


Train: loss=0.862, dec_loss=0.661, violations=0.519: 100%|██████████| 100/100 [01:41<00:00,  1.01s/it]


Trained on 191289 samples, start_loss=0.87, end_loss=0.862, start_violations=0.52, end_violations=0.519 in 1.7 minutes


100%|██████████| 3/3 [00:00<00:00, 118.70it/s]
100%|██████████| 77/77 [00:00<00:00, 258.84it/s]


lipschitz_k=11.3644246673584 (without delta)
delta=0.010050251256281407
K=0.11421532329003417 (with delta)
Checking GRID of size 200


100%|██████████| 1/1 [00:01<00:00,  1.66s/it]

violations=25371
hard_violations=25371
25371/40000 violated decrease condition
25371/40000 hard violations
Train buffer len: 216660
Grid runtime=2.00 s
info= {'ds_size': 191289, 'lipschitz_k': 11.3644246673584, 'K_f': 1.08, 'K_l': 10.522615432739258, 'iter': 6, 'runtime': 396.7664999961853, 'delta': 0.010050251256281407, 'K': 0.11421532329003417, 'avg_increase': 0, 'dec_violations': '25371/40000', 'hard_violations': '25371/40000'}






## Iteration 7 (8:21 elapsed) ##


Train: loss=0.813, dec_loss=0.622, violations=0.519:  32%|███▏      | 32/100 [00:40<01:32,  1.36s/it]