In [23]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [24]:
import argparse
import os
import sys
import jax
import jax.numpy as jnp
# import jax.random
# from gymnasium import spaces
from tqdm import tqdm

from klax import lipschitz_l1_jax, triangular
from rl_environments import Vandelpol, Test2, LinearLQR
# from rl_environments import LDSEnv, InvertedPendulum, CollisionAvoidanceEnv, Vandelpol
from rsm_learner import Learner
from rsm_verifier import Verifier
import time
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

In [17]:
def str2bool(value):
    """Convert a command-line string to a bool, for use as an argparse ``type=``.

    ``type=bool`` does not work with argparse: it calls ``bool("False")``, which
    is ``True`` because every non-empty string is truthy. Use this converter for
    every on/off option that takes a value.

    Args:
        value: Raw option value (a ``str``); a ``bool`` is returned unchanged.

    Returns:
        ``True`` for 1/true/t/yes/y/on and ``False`` for 0/false/f/no/n/off.
        Matching is case-insensitive and ignores surrounding whitespace.

    Raises:
        argparse.ArgumentTypeError: For any other spelling. argparse reports
            this as a normal usage error.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if normalized in {"0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


parser = argparse.ArgumentParser(
    prog='System Normalization for Neural Certificates',
    description='Learning a reach-avoid supermartingale neural network with normalization techniques',
    epilog='By Musan (Hao Wu) @ SKLCS, Institute of Software, UCAS'
)

parser.add_argument("--debug", default=True, type=str2bool)

# --- problem formulation ---
# parser.add_argument("--env", default="vandelpol", help='system')
parser.add_argument("--env", default="test2", help='system')
parser.add_argument("--timeout", default=60, type=int, help='max time limit in minutes')
parser.add_argument("--reach_prob", default=0.8, type=float, help='reach-avoid probability')

parser.add_argument("--project", default=False, type=str2bool, help='Normalize the system')

# --- neural network and training ---
parser.add_argument("--hidden", default=128, type=int, help='hidden neurons in each layer')
parser.add_argument("--num_layers", default=2, type=int, help='number of hidden layers')

# --- learner ---
parser.add_argument("--continue_rsm", type=int, default=0, help='use an existing network')

# --- verifier ---
# TODO(review): the meaning of --eps, --lip and --l_lip is undocumented. This notebook
# forwards --eps to Learner(eps=...), --lip to RSMLoop(lip_factor=...) and
# --l_lip to Learner(l_lip=...).
parser.add_argument("--eps", default=0.01, type=float)
parser.add_argument("--lip", default=0.1, type=float)
# parser.add_argument("--p_lip", default=0.0, type=float)
parser.add_argument("--l_lip", default=10.0, type=float)
parser.add_argument("--fail_check_fast", type=int, default=0)
parser.add_argument("--grid_factor", default=1.0, type=float)
parser.add_argument("--batch_size", default=512, type=int)
# parser.add_argument("--ppo_iters", default=50, type=int)
# parser.add_argument("--policy", default="policies/lds0_zero.jax")
# parser.add_argument("--train_p", type=int, default=1)
parser.add_argument("--square_l_output", default=True, type=str2bool)
parser.add_argument("--jitter_grid", type=int, default=0)
parser.add_argument("--soft_constraint", type=int, default=1)
parser.add_argument("--gamma_decrease", default=1.0, type=float)
parser.add_argument("--debug_k0", action="store_true")
parser.add_argument("--gen_plot", action="store_true")
parser.add_argument("--no_refinement", action="store_true")
parser.add_argument("--plot", action="store_true")
parser.add_argument("--small_mem", action="store_true")

# Pass the notebook's arguments explicitly instead of overwriting sys.argv.
# Inside Jupyter, sys.argv holds the kernel process's own command line.
NOTEBOOK_ARGV = ["--env", "LinearLQR"]
args = parser.parse_args(NOTEBOOK_ARGV)

In [25]:
# Build the environment named by --env instead of hard-coding one. This keeps the
# system actually trained in sync with args.env, which is also used in plot file names.
# Registry keys are lower-case, and the lookup ignores the case of --env.
ENV_REGISTRY = {
    "vandelpol": Vandelpol,
    "test2": Test2,
    "linearlqr": LinearLQR,
}
env_key = args.env.lower()
if env_key not in ENV_REGISTRY:
    raise ValueError(
        f"Unknown --env {args.env!r}; expected one of {sorted(ENV_REGISTRY)}"
    )
env = ENV_REGISTRY[env_key]()

# Learner network: `num_layers` hidden layers with `hidden` units each.
learner = Learner(
    env=env,
    l_hidden=[args.hidden] * args.num_layers,
    l_lip=args.l_lip,
    eps=args.eps,
    gamma_decrease=args.gamma_decrease,
    reach_prob=args.reach_prob,
    # NOTE(review): the CLI flag is named square_l_output, but it is passed as
    # softplus_l_output. Confirm which output transform is intended.
    softplus_l_output=args.square_l_output,
    debug=args.debug,
)

In [26]:
# Verifier built from the learner and environment above plus the parsed CLI options.
verifier = Verifier(
    env=env,
    learner=learner,
    reach_prob=args.reach_prob,
    batch_size=args.batch_size,
    grid_factor=args.grid_factor,
    fail_check_fast=args.fail_check_fast != 0,  # --fail_check_fast is an int flag
    small_mem=args.small_mem,
    debug=args.debug,
)

In [27]:
from rsm_loop import RSMLoop

# Outer loop tying together the learner, the verifier and the environment
# (positional order: learner, verifier, env).
loop = RSMLoop(
    learner,
    verifier,
    env,
    lip_factor=args.lip,
    plot=args.plot,
    jitter_grid=args.jitter_grid != 0,          # int CLI flags -> bool
    soft_constraint=args.soft_constraint != 0,
    debug=args.debug,
)

# Initial plot, made before loop.run(); the environment name is part of the file name.
# Presumably plot_l writes the figure to this path, so plots/ must exist.
start_plot_path = "plots/" + args.env + "_start.png"
loop.plot_l(start_plot_path)

In [21]:
# --timeout is given in minutes; convert it to seconds for the loop's time budget.
timeout_seconds = args.timeout * 60
loop.run(timeout_seconds)


## Iteration 0 (0:00 elapsed) ##



Train: loss=4.05, dec_loss=1.73, violations=0.425: 100%|██████████| 100/100 [00:37<00:00,  2.70it/s]


Trained on 40000 samples, start_loss=25.6, end_loss=4.05, start_violations=0.514, end_violations=0.425 in 37.1 seconds



100%|██████████| 3/3 [00:00<00:00, 137.10it/s]

100%|██████████| 77/77 [00:00<00:00, 215.37it/s]


lipschitz_k=33.01280227661133 (without delta)
delta=0.010050251256281407
K=0.3317869575538827 (with delta)
Checking GRID of size 200



100%|██████████| 1/1 [00:15<00:00, 15.48s/it]

violations=20557
hard_violations=20557
20557/40000 violated decrease condition
20557/40000 hard violations
Train buffer len: 60557
Grid runtime=15.88 s
info= {'ds_size': 40000, 'lipschitz_k': 33.01280227661133, 'K_f': 1.08, 'K_l': 30.56740951538086, 'iter': 0, 'runtime': 0.0020599365234375, 'delta': 0.010050251256281407, 'K': 0.3317869575538827, 'avg_increase': 0, 'dec_violations': '20557/40000', 'hard_violations': '20557/40000'}






## Iteration 1 (0:53 elapsed) ##



Train: loss=0.812, dec_loss=0.621, violations=0.52:  36%|███▌      | 36/100 [28:11<50:06, 46.98s/it]

Train: loss=3.68, dec_loss=1.52, violations=0.468: 100%|██████████| 100/100 [00:50<00:00,  1.98it

Trained on 60557 samples, start_loss=3.84, end_loss=3.68, start_violations=0.468, end_violations=0.468 in 50.6 seconds


100%|██████████| 3/3 [00:00<00:00, 159.19it/s]
100%|██████████| 77/77 [00:00<00:00, 259.42it/s]


lipschitz_k=30.778768157958986 (without delta)
delta=0.010050251256281407
K=0.3093343533463215 (with delta)
Checking GRID of size 200


100%|██████████| 1/1 [00:10<00:00, 10.21s/it]

violations=21085
hard_violations=21085
21085/40000 violated decrease condition
21085/40000 hard violations
Train buffer len: 81642
Grid runtime=10.55 s
info= {'ds_size': 60557, 'lipschitz_k': 30.778768157958986, 'K_f': 1.08, 'K_l': 28.498859405517578, 'iter': 1, 'runtime': 53.11841106414795, 'delta': 0.010050251256281407, 'K': 0.3093343533463215, 'avg_increase': 0, 'dec_violations': '21085/40000', 'hard_violations': '21085/40000'}






## Iteration 2 (1:54 elapsed) ##


Train: loss=3.52, dec_loss=1.34, violations=0.483: 100%|██████████| 100/100 [01:05<00:00,  1.53it/s]


Trained on 81642 samples, start_loss=3.55, end_loss=3.52, start_violations=0.483, end_violations=0.483 in 1.1 minutes


100%|██████████| 3/3 [00:00<00:00, 170.20it/s]
100%|██████████| 77/77 [00:00<00:00, 253.31it/s]


lipschitz_k=31.471372375488283 (without delta)
delta=0.010050251256281407
K=0.3162951997536511 (with delta)
Checking GRID of size 200


100%|██████████| 1/1 [00:10<00:00, 10.12s/it]

violations=20805
hard_violations=20805
20805/40000 violated decrease condition
20805/40000 hard violations
Train buffer len: 102447
Grid runtime=10.47 s
info= {'ds_size': 81642, 'lipschitz_k': 31.471372375488283, 'K_f': 1.08, 'K_l': 29.140159606933594, 'iter': 2, 'runtime': 114.3666410446167, 'delta': 0.010050251256281407, 'K': 0.3162951997536511, 'avg_increase': 0, 'dec_violations': '20805/40000', 'hard_violations': '20805/40000'}






## Iteration 3 (3:10 elapsed) ##


Train: loss=3.34, dec_loss=1.2, violations=0.491: 100%|██████████| 100/100 [01:20<00:00,  1.24it/s]


Trained on 102447 samples, start_loss=3.39, end_loss=3.34, start_violations=0.491, end_violations=0.491 in 1.3 minutes


100%|██████████| 3/3 [00:00<00:00, 144.03it/s]
100%|██████████| 77/77 [00:00<00:00, 259.23it/s]


lipschitz_k=31.002335128784182 (without delta)
delta=0.010050251256281407
K=0.3115812575757204 (with delta)
Checking GRID of size 200


100%|██████████| 1/1 [00:09<00:00,  9.07s/it]

violations=20519
hard_violations=20519
20519/40000 violated decrease condition
20519/40000 hard violations
Train buffer len: 122966
Grid runtime=9.41 s
info= {'ds_size': 102447, 'lipschitz_k': 31.002335128784182, 'K_f': 1.08, 'K_l': 28.70586585998535, 'iter': 3, 'runtime': 190.4919810295105, 'delta': 0.010050251256281407, 'K': 0.3115812575757204, 'avg_increase': 0, 'dec_violations': '20519/40000', 'hard_violations': '20519/40000'}






## Iteration 4 (4:41 elapsed) ##


Train: loss=3.25, dec_loss=1.12, violations=0.494: 100%|██████████| 100/100 [01:36<00:00,  1.04it/s]


Trained on 122966 samples, start_loss=3.28, end_loss=3.25, start_violations=0.494, end_violations=0.494 in 1.6 minutes


100%|██████████| 3/3 [00:00<00:00, 159.02it/s]
100%|██████████| 77/77 [00:00<00:00, 257.50it/s]


lipschitz_k=30.356561508178714 (without delta)
delta=0.010050251256281407
K=0.30509107043395695 (with delta)
Checking GRID of size 200


100%|██████████| 1/1 [00:05<00:00,  5.57s/it]

violations=21075
hard_violations=21075
21075/40000 violated decrease condition
21075/40000 hard violations
Train buffer len: 144041
Grid runtime=5.91 s
info= {'ds_size': 122966, 'lipschitz_k': 30.356561508178714, 'K_f': 1.08, 'K_l': 28.107927322387695, 'iter': 4, 'runtime': 280.7755649089813, 'delta': 0.010050251256281407, 'K': 0.30509107043395695, 'avg_increase': 0, 'dec_violations': '21075/40000', 'hard_violations': '21075/40000'}






## Iteration 5 (6:23 elapsed) ##


Train: loss=3.24, dec_loss=1.06, violations=0.496: 100%|██████████| 100/100 [01:51<00:00,  1.12s/it]


Trained on 144041 samples, start_loss=3.24, end_loss=3.24, start_violations=0.497, end_violations=0.496 in 1.9 minutes


100%|██████████| 3/3 [00:00<00:00, 144.26it/s]
100%|██████████| 77/77 [00:00<00:00, 252.66it/s]


lipschitz_k=30.724979095458988 (without delta)
delta=0.010050251256281407
K=0.30879375975335666 (with delta)
Checking GRID of size 200


100%|██████████| 1/1 [00:03<00:00,  3.76s/it]

violations=20893
hard_violations=20893
20893/40000 violated decrease condition
20893/40000 hard violations
Train buffer len: 164934
Grid runtime=4.11 s
info= {'ds_size': 144041, 'lipschitz_k': 30.724979095458988, 'K_f': 1.08, 'K_l': 28.449054718017578, 'iter': 5, 'runtime': 383.0679111480713, 'delta': 0.010050251256281407, 'K': 0.30879375975335666, 'avg_increase': 0, 'dec_violations': '20893/40000', 'hard_violations': '20893/40000'}






## Iteration 6 (8:19 elapsed) ##


Train: loss=3.15, dec_loss=1.02, violations=0.498: 100%|██████████| 100/100 [02:08<00:00,  1.29s/it]


Trained on 164934 samples, start_loss=3.17, end_loss=3.15, start_violations=0.498, end_violations=0.498 in 2.1 minutes


100%|██████████| 3/3 [00:00<00:00, 186.55it/s]
100%|██████████| 77/77 [00:00<00:00, 246.22it/s]


lipschitz_k=30.699755172729493 (without delta)
delta=0.010050251256281407
K=0.30854025299225624 (with delta)
Checking GRID of size 200


100%|██████████| 1/1 [00:04<00:00,  4.97s/it]

violations=21490
hard_violations=21490
21490/40000 violated decrease condition
21490/40000 hard violations
Train buffer len: 186424
Grid runtime=5.32 s
info= {'ds_size': 164934, 'lipschitz_k': 30.699755172729493, 'K_f': 1.08, 'K_l': 28.42569923400879, 'iter': 6, 'runtime': 498.8892910480499, 'delta': 0.010050251256281407, 'K': 0.30854025299225624, 'avg_increase': 0, 'dec_violations': '21490/40000', 'hard_violations': '21490/40000'}






## Iteration 7 (10:33 elapsed) ##


Train: loss=3.11, dec_loss=0.978, violations=0.499: 100%|██████████| 100/100 [02:23<00:00,  1.43s/it]


Trained on 186424 samples, start_loss=3.07, end_loss=3.11, start_violations=0.499, end_violations=0.499 in 2.4 minutes


100%|██████████| 3/3 [00:00<00:00, 167.31it/s]
100%|██████████| 77/77 [00:00<00:00, 255.35it/s]


lipschitz_k=30.98436218261719 (without delta)
delta=0.010050251256281407
K=0.3114006249509265 (with delta)
Checking GRID of size 200


100%|██████████| 1/1 [00:05<00:00,  5.43s/it]

violations=21899
hard_violations=21899
21899/40000 violated decrease condition
21899/40000 hard violations
Train buffer len: 208323
Grid runtime=5.77 s
info= {'ds_size': 186424, 'lipschitz_k': 30.98436218261719, 'K_f': 1.08, 'K_l': 28.689224243164062, 'iter': 7, 'runtime': 633.2700700759888, 'delta': 0.010050251256281407, 'K': 0.3114006249509265, 'avg_increase': 0, 'dec_violations': '21899/40000', 'hard_violations': '21899/40000'}






## Iteration 8 (13:02 elapsed) ##


Train: loss=3.05, dec_loss=0.93, violations=0.5: 100%|██████████| 100/100 [02:34<00:00,  1.55s/it]  


Trained on 208323 samples, start_loss=3.13, end_loss=3.05, start_violations=0.5, end_violations=0.5 in 2.6 minutes


100%|██████████| 3/3 [00:00<00:00, 187.73it/s]
100%|██████████| 77/77 [00:00<00:00, 256.76it/s]


lipschitz_k=30.13361251831055 (without delta)
delta=0.010050251256281407
K=0.30285037706844775 (with delta)
Checking GRID of size 200


100%|██████████| 1/1 [00:03<00:00,  3.34s/it]

violations=20389
hard_violations=20389
20389/40000 violated decrease condition
20389/40000 hard violations
Train buffer len: 228712
Grid runtime=3.68 s
info= {'ds_size': 208323, 'lipschitz_k': 30.13361251831055, 'K_f': 1.08, 'K_l': 27.901493072509766, 'iter': 8, 'runtime': 782.4714872837067, 'delta': 0.010050251256281407, 'K': 0.30285037706844775, 'avg_increase': 0, 'dec_violations': '20389/40000', 'hard_violations': '20389/40000'}






## Iteration 9 (15:41 elapsed) ##


Train: loss=3.03, dec_loss=0.912, violations=0.5: 100%|██████████| 100/100 [02:50<00:00,  1.70s/it]


Trained on 228712 samples, start_loss=3.05, end_loss=3.03, start_violations=0.5, end_violations=0.5 in 2.8 minutes


100%|██████████| 3/3 [00:00<00:00, 181.50it/s]
100%|██████████| 77/77 [00:00<00:00, 253.99it/s]


lipschitz_k=30.761658325195313 (without delta)
delta=0.010050251256281407
K=0.3091623952280936 (with delta)
Checking GRID of size 200


100%|██████████| 1/1 [00:04<00:00,  4.25s/it]

violations=20862
hard_violations=20862
20862/40000 violated decrease condition
20862/40000 hard violations
Train buffer len: 249574
Grid runtime=4.60 s
info= {'ds_size': 228712, 'lipschitz_k': 30.761658325195313, 'K_f': 1.08, 'K_l': 28.483016967773438, 'iter': 9, 'runtime': 941.2607421875, 'delta': 0.010050251256281407, 'K': 0.3091623952280936, 'avg_increase': 0, 'dec_violations': '20862/40000', 'hard_violations': '20862/40000'}






## Iteration 10 (18:36 elapsed) ##


Train: loss=2.98, dec_loss=0.878, violations=0.501: 100%|██████████| 100/100 [03:04<00:00,  1.85s/it]


Trained on 249574 samples, start_loss=2.98, end_loss=2.98, start_violations=0.501, end_violations=0.501 in 3.1 minutes


100%|██████████| 3/3 [00:00<00:00, 121.88it/s]
100%|██████████| 77/77 [00:00<00:00, 263.17it/s]


lipschitz_k=30.511343078613283 (without delta)
delta=0.010050251256281407
K=0.3066466641066662 (with delta)
Checking GRID of size 200


100%|██████████| 1/1 [00:04<00:00,  4.59s/it]

violations=22073
hard_violations=22073
22073/40000 violated decrease condition
22073/40000 hard violations
Train buffer len: 271647
Grid runtime=4.92 s
info= {'ds_size': 249574, 'lipschitz_k': 30.511343078613283, 'K_f': 1.08, 'K_l': 28.251243591308594, 'iter': 10, 'runtime': 1116.24321103096, 'delta': 0.010050251256281407, 'K': 0.3066466641066662, 'avg_increase': 0, 'dec_violations': '22073/40000', 'hard_violations': '22073/40000'}






## Iteration 11 (21:46 elapsed) ##


Train: loss=2.98, dec_loss=0.869, violations=0.501: 100%|██████████| 100/100 [03:21<00:00,  2.02s/it]


Trained on 271647 samples, start_loss=2.99, end_loss=2.98, start_violations=0.501, end_violations=0.501 in 3.4 minutes


100%|██████████| 3/3 [00:00<00:00, 193.90it/s]
100%|██████████| 77/77 [00:00<00:00, 256.57it/s]


lipschitz_k=30.65292045593262 (without delta)
delta=0.010050251256281407
K=0.30806955232093086 (with delta)
Checking GRID of size 200


100%|██████████| 1/1 [00:03<00:00,  3.76s/it]

violations=20641
hard_violations=20641
20641/40000 violated decrease condition
20641/40000 hard violations
Train buffer len: 292288
Grid runtime=4.10 s
info= {'ds_size': 271647, 'lipschitz_k': 30.65292045593262, 'K_f': 1.08, 'K_l': 28.382333755493164, 'iter': 11, 'runtime': 1306.317039012909, 'delta': 0.010050251256281407, 'K': 0.30806955232093086, 'avg_increase': 0, 'dec_violations': '20641/40000', 'hard_violations': '20641/40000'}






## Iteration 12 (25:12 elapsed) ##


Train: loss=2.91, dec_loss=0.849, violations=0.501: 100%|██████████| 100/100 [03:39<00:00,  2.19s/it]


Trained on 292288 samples, start_loss=2.94, end_loss=2.91, start_violations=0.501, end_violations=0.501 in 3.7 minutes


100%|██████████| 3/3 [00:00<00:00, 170.23it/s]
100%|██████████| 77/77 [00:00<00:00, 225.45it/s]


lipschitz_k=29.86609268188477 (without delta)
delta=0.010050251256281407
K=0.30016173549632935 (with delta)
Checking GRID of size 200


100%|██████████| 1/1 [00:03<00:00,  3.70s/it]

violations=20316
hard_violations=20316
20316/40000 violated decrease condition
20316/40000 hard violations
Train buffer len: 312604
Grid runtime=4.08 s
info= {'ds_size': 292288, 'lipschitz_k': 29.86609268188477, 'K_f': 1.08, 'K_l': 27.653789520263672, 'iter': 12, 'runtime': 1512.1729819774628, 'delta': 0.010050251256281407, 'K': 0.30016173549632935, 'avg_increase': 0, 'dec_violations': '20316/40000', 'hard_violations': '20316/40000'}






## Iteration 13 (28:56 elapsed) ##


Train: loss=2.93, dec_loss=0.841, violations=0.501: 100%|██████████| 100/100 [03:52<00:00,  2.32s/it]


Trained on 312604 samples, start_loss=2.9, end_loss=2.93, start_violations=0.501, end_violations=0.501 in 3.9 minutes


100%|██████████| 3/3 [00:00<00:00, 139.77it/s]
100%|██████████| 77/77 [00:00<00:00, 262.83it/s]


lipschitz_k=31.322158813476566 (without delta)
delta=0.010050251256281407
K=0.31479556596458863 (with delta)
Checking GRID of size 200


100%|██████████| 1/1 [00:03<00:00,  3.13s/it]

violations=21040
hard_violations=21040
21040/40000 violated decrease condition
21040/40000 hard violations
Train buffer len: 333644
Grid runtime=3.47 s
info= {'ds_size': 312604, 'lipschitz_k': 31.322158813476566, 'K_f': 1.08, 'K_l': 29.001998901367188, 'iter': 13, 'runtime': 1735.6440451145172, 'delta': 0.010050251256281407, 'K': 0.31479556596458863, 'avg_increase': 0, 'dec_violations': '21040/40000', 'hard_violations': '21040/40000'}






## Iteration 14 (32:51 elapsed) ##


Train: loss=2.89, dec_loss=0.821, violations=0.501: 100%|██████████| 100/100 [04:08<00:00,  2.48s/it]


Trained on 333644 samples, start_loss=2.92, end_loss=2.89, start_violations=0.501, end_violations=0.501 in 4.1 minutes


100%|██████████| 3/3 [00:00<00:00, 184.51it/s]
100%|██████████| 77/77 [00:00<00:00, 253.40it/s]


lipschitz_k=30.584705657958985 (without delta)
delta=0.010050251256281407
K=0.30738397646189936 (with delta)
Checking GRID of size 200


100%|██████████| 1/1 [00:03<00:00,  3.57s/it]

violations=21702
hard_violations=21702
21702/40000 violated decrease condition
21702/40000 hard violations
Train buffer len: 355346
Grid runtime=3.92 s
info= {'ds_size': 333644, 'lipschitz_k': 30.584705657958985, 'K_f': 1.08, 'K_l': 28.319171905517578, 'iter': 14, 'runtime': 1971.3920311927795, 'delta': 0.010050251256281407, 'K': 0.30738397646189936, 'avg_increase': 0, 'dec_violations': '21702/40000', 'hard_violations': '21702/40000'}






## Iteration 15 (37:04 elapsed) ##


Train: loss=2.89, dec_loss=0.811, violations=0.501: 100%|██████████| 100/100 [04:21<00:00,  2.62s/it]


Trained on 355346 samples, start_loss=2.9, end_loss=2.89, start_violations=0.502, end_violations=0.501 in 4.4 minutes


100%|██████████| 3/3 [00:00<00:00, 179.93it/s]
100%|██████████| 77/77 [00:00<00:00, 221.72it/s]


lipschitz_k=30.901111907958985 (without delta)
delta=0.010050251256281407
K=0.3105639387734571 (with delta)
Checking GRID of size 200


100%|██████████| 1/1 [00:04<00:00,  4.06s/it]

violations=20065
hard_violations=20065
20065/40000 violated decrease condition
20065/40000 hard violations
Train buffer len: 375411
Grid runtime=4.45 s
info= {'ds_size': 355346, 'lipschitz_k': 30.901111907958985, 'K_f': 1.08, 'K_l': 28.612140655517578, 'iter': 15, 'runtime': 2223.5274879932404, 'delta': 0.010050251256281407, 'K': 0.3105639387734571, 'avg_increase': 0, 'dec_violations': '20065/40000', 'hard_violations': '20065/40000'}






## Iteration 16 (41:30 elapsed) ##


Train: loss=2.86, dec_loss=0.8, violations=0.502: 100%|██████████| 100/100 [04:36<00:00,  2.77s/it]  


Trained on 375411 samples, start_loss=2.88, end_loss=2.86, start_violations=0.502, end_violations=0.502 in 4.6 minutes


100%|██████████| 3/3 [00:00<00:00, 131.81it/s]
100%|██████████| 77/77 [00:00<00:00, 262.31it/s]


lipschitz_k=29.97559066772461 (without delta)
delta=0.010050251256281407
K=0.3012622177660765 (with delta)
Checking GRID of size 200


100%|██████████| 1/1 [00:02<00:00,  2.69s/it]

violations=21724
hard_violations=21724
21724/40000 violated decrease condition
21724/40000 hard violations
Train buffer len: 397135
Grid runtime=3.03 s
info= {'ds_size': 375411, 'lipschitz_k': 29.97559066772461, 'K_f': 1.08, 'K_l': 27.755176544189453, 'iter': 16, 'runtime': 2489.93682718277, 'delta': 0.010050251256281407, 'K': 0.3012622177660765, 'avg_increase': 0, 'dec_violations': '21724/40000', 'hard_violations': '21724/40000'}






## Iteration 17 (46:10 elapsed) ##


Train: loss=2.86, dec_loss=0.793, violations=0.502: 100%|██████████| 100/100 [04:57<00:00,  2.97s/it]


Trained on 397135 samples, start_loss=2.88, end_loss=2.86, start_violations=0.502, end_violations=0.502 in 5.0 minutes


100%|██████████| 3/3 [00:00<00:00, 155.09it/s]
100%|██████████| 77/77 [00:00<00:00, 241.24it/s]


lipschitz_k=30.567453689575196 (without delta)
delta=0.010050251256281407
K=0.30721058984497684 (with delta)
Checking GRID of size 200


100%|██████████| 1/1 [00:03<00:00,  3.53s/it]

violations=19953
hard_violations=19953
19953/40000 violated decrease condition
19953/40000 hard violations
Train buffer len: 417088
Grid runtime=3.89 s
info= {'ds_size': 397135, 'lipschitz_k': 30.567453689575196, 'K_f': 1.08, 'K_l': 28.303197860717773, 'iter': 17, 'runtime': 2769.938891887665, 'delta': 0.010050251256281407, 'K': 0.30721058984497684, 'avg_increase': 0, 'dec_violations': '19953/40000', 'hard_violations': '19953/40000'}






## Iteration 18 (51:11 elapsed) ##


Train: loss=2.87, dec_loss=0.787, violations=0.502: 100%|██████████| 100/100 [05:08<00:00,  3.08s/it]


Trained on 417088 samples, start_loss=2.87, end_loss=2.87, start_violations=0.502, end_violations=0.502 in 5.1 minutes


100%|██████████| 3/3 [00:00<00:00, 137.61it/s]
100%|██████████| 77/77 [00:00<00:00, 255.31it/s]


lipschitz_k=30.242424545288088 (without delta)
delta=0.010050251256281407
K=0.30394396527927725 (with delta)
Checking GRID of size 200


100%|██████████| 1/1 [00:02<00:00,  2.28s/it]

violations=20378
hard_violations=20378
20378/40000 violated decrease condition
20378/40000 hard violations
Train buffer len: 437466
Grid runtime=2.64 s
info= {'ds_size': 417088, 'lipschitz_k': 30.242424545288088, 'K_f': 1.08, 'K_l': 28.00224494934082, 'iter': 18, 'runtime': 3071.341467142105, 'delta': 0.010050251256281407, 'K': 0.30394396527927725, 'avg_increase': 0, 'dec_violations': '20378/40000', 'hard_violations': '20378/40000'}






## Iteration 19 (56:22 elapsed) ##


Train: loss=2.86, dec_loss=0.776, violations=0.502: 100%|██████████| 100/100 [05:20<00:00,  3.21s/it]


Trained on 437466 samples, start_loss=2.86, end_loss=2.86, start_violations=0.502, end_violations=0.502 in 5.3 minutes


100%|██████████| 3/3 [00:00<00:00, 182.34it/s]
100%|██████████| 77/77 [00:00<00:00, 258.31it/s]


lipschitz_k=29.83112113952637 (without delta)
delta=0.010050251256281407
K=0.29981026270880773 (with delta)
Checking GRID of size 200


100%|██████████| 1/1 [00:01<00:00,  1.92s/it]

violations=21359
hard_violations=21359
21359/40000 violated decrease condition
21359/40000 hard violations
Train buffer len: 458825
Grid runtime=2.26 s
info= {'ds_size': 437466, 'lipschitz_k': 29.83112113952637, 'K_f': 1.08, 'K_l': 27.621408462524414, 'iter': 19, 'runtime': 3382.314952135086, 'delta': 0.010050251256281407, 'K': 0.29981026270880773, 'avg_increase': 0, 'dec_violations': '21359/40000', 'hard_violations': '21359/40000'}
Timeout!





False

In [22]:
# Show which JAX devices are available (the recorded output lists only a CpuDevice).
devices = list(jax.devices())
print(devices)

[CpuDevice(id=0)]
