In [77]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [78]:
import argparse
import os
import sys
import jax
import jax.numpy as jnp
import jax.random
from gymnasium import spaces
from tqdm import tqdm

from klax import lipschitz_l1_jax, triangular
from rl_environments import Vandelpol
# from rl_environments import LDSEnv, InvertedPendulum, CollisionAvoidanceEnv, Vandelpol
from rsm_learner import Learner
from rsm_verifier import Verifier
import time
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

In [79]:
def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy, so the flag could never be
    switched off from the command line.

    Args:
        value: The raw token (or an already-parsed bool, e.g. a default).

    Returns:
        ``True`` or ``False``.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean
            spelling.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in {"1", "true", "t", "yes", "y", "on"}:
        return True
    # The empty string is kept falsy, matching what bool("") used to return.
    if token in {"", "0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser(
    prog='System Normalization for Neural Certificates',
    description='Learning a reach-avoid supermartingale neural network with normalization techniques',
    epilog='By Musan (Hao Wu) @ SKLCS, Institute of Software, UCAS'
)

parser.add_argument("--debug", default=True, type=_str2bool)

# --- problem formulation ---
parser.add_argument("--env", default="vandelpol", help='control system')
parser.add_argument("--timeout", default=60, type=int, help='max time limit in minutes')
parser.add_argument("--reach_prob", default=0.8, type=float, help='reach-avoid probability')
parser.add_argument("--project", default=False, type=_str2bool, help='Normalize the system')

# --- neural network and training ---
parser.add_argument("--hidden", default=128, type=int, help='hidden neurons in each layer')
parser.add_argument("--num_layers", default=2, type=int, help='number of hidden layers')

# --- learner ---
parser.add_argument("--continue_rsm", type=int, default=0, help='use an existing network')

# --- verifier ---
# TODO: document the meaning of --eps, --lip and --l_lip (marked "???" by the author).
parser.add_argument("--eps", default=0.05, type=float)
parser.add_argument("--lip", default=0.01, type=float)

# parser.add_argument("--p_lip", default=0.0, type=float)
parser.add_argument("--l_lip", default=4.0, type=float)
parser.add_argument("--fail_check_fast", type=int, default=0)
parser.add_argument("--grid_factor", default=1.0, type=float)
parser.add_argument("--batch_size", default=512, type=int)
# parser.add_argument("--ppo_iters", default=50, type=int)
# parser.add_argument("--policy", default="policies/lds0_zero.jax")
# parser.add_argument("--train_p", type=int, default=1)
# Previously untyped: a supplied value stayed a (always truthy) string.
parser.add_argument("--square_l_output", default=True, type=_str2bool)
parser.add_argument("--jitter_grid", type=int, default=0)
parser.add_argument("--soft_constraint", type=int, default=1)
parser.add_argument("--gamma_decrease", default=1.0, type=float)
parser.add_argument("--debug_k0", action="store_true")
parser.add_argument("--gen_plot", action="store_true")
parser.add_argument("--no_refinement", action="store_true")
parser.add_argument("--plot", action="store_true")
parser.add_argument("--small_mem", action="store_true")

# Inside a notebook sys.argv holds the kernel launcher's own arguments, which
# this parser would reject; replace it with the arguments of the experiment.
sys.argv = ['test.ipynb', '--env', 'vandelpol']
args = parser.parse_args()

In [80]:
# Instantiate the control environment that the certificate is learned for.
env = Vandelpol()

# Build the learner from the parsed CLI options.
# NOTE(review): the `softplus_l_output` keyword is fed from the
# `--square_l_output` option; the names disagree, confirm this is intended.
learner = Learner(
    env=env,
    reach_prob=args.reach_prob,
    eps=args.eps,
    l_lip=args.l_lip,
    gamma_decrease=args.gamma_decrease,
    # One entry of width `--hidden` per hidden layer.
    l_hidden=[args.hidden for _ in range(args.num_layers)],
    softplus_l_output=args.square_l_output,
    debug=args.debug,
)

In [81]:
# Build the verifier around the learner and environment created above.
# `--fail_check_fast` is parsed as an int and converted to a boolean here.
verifier = Verifier(
    env=env,
    learner=learner,
    reach_prob=args.reach_prob,
    batch_size=args.batch_size,
    grid_factor=args.grid_factor,
    fail_check_fast=(args.fail_check_fast != 0),
    small_mem=args.small_mem,
    debug=args.debug,
)

In [82]:
from rsm_loop import RSMLoop

# Wire learner, verifier and environment into the learner/verifier loop.
# The int-valued CLI switches are turned into booleans at this point.
loop = RSMLoop(
    learner,
    verifier,
    env,
    debug=args.debug,
    plot=args.plot,
    lip_factor=args.lip,
    soft_constraint=(args.soft_constraint != 0),
    jitter_grid=(args.jitter_grid != 0),
)

# Optional: plot the initial certificate before training.
# loop.plot_l(f"plots/{args.env}_start.png")

In [64]:
# Start the learner/verifier loop with a budget derived from --timeout.
# NOTE(review): --timeout is documented as "max time limit in minutes", yet the
# value is multiplied by 10 here; confirm which unit loop.run() expects.
# NOTE(review): in the recorded run the training loss rose from 8.93 to 6.37e+06
# and the cell was interrupted manually during the grid check.
loop.run(args.timeout*10)


## Iteration 0 (0:00 elapsed) ##
create grid.


Train: loss=6.37e+06, dec_loss=5.68e+06, violations=0.51: 100%|██████████| 100/100 [03:42<00:00,  2.23s/it]


Trained on 40000 samples, start_loss=8.93, end_loss=6.37e+06, start_violations=0.496, end_violations=0.51 in 3.7 minutes


100%|██████████| 2/2 [00:00<00:00, 27.43it/s]
100%|██████████| 76/76 [00:03<00:00, 25.18it/s]


lipschitz_k=71967882.24000001 (without delta)
delta=0.004008016032064128
K=288448.4258116233 (with delta)
Checking GRID of size 500




KeyboardInterrupt: 

In [86]:
# Debugging aid: display the train buffer's cached dataset.
# NOTE(review): `_cached_ds` is a private attribute of the buffer; prefer a
# public accessor if one exists.
loop.verifier.train_buffer._cached_ds

NoneType

In [87]:
# NOTE(review): this import lives outside the notebook's import cell.
import tensorflow as tf
# Expose the verifier's train buffer as a dataset in batches of 4096 and
# obtain a NumPy iterator over it.
# NOTE(review): the recorded output of this cell is
# "ValueError: need at least one array to concatenate" -- presumably the buffer
# was still empty because loop.run() had not been executed in that kernel
# session; confirm and run the loop first.
train_ds = loop.verifier.train_buffer.as_tfds(batch_size=4096)
iterator = train_ds.as_numpy_iterator()

ValueError: need at least one array to concatenate

In [72]:
# Sanity check: is the object returned by as_tfds() a tf.data.Dataset?
isinstance(train_ds, tf.data.Dataset)

True

In [76]:
# Pull a single batch from the dataset iterator, convert it to a JAX array
# and print it; an exhausted iterator prints nothing.
try:
    state = next(iter(iterator))
except StopIteration:
    pass
else:
    state = jnp.array(state)
    print(state)

[[ 0.15499997 -0.46500003]
 [ 0.66499996 -0.55499995]
 [ 0.54499996  0.775     ]
 ...
 [-0.47500002  0.515     ]
 [ 0.11500002  0.515     ]
 [ 0.08499993 -0.995     ]]
