Write a simple example for dys with rosenbrock 

Created on 8/July-2024


# dependency

In [1]:
from plnet.layer import PBiLipNet, PPLNet, MLP
from plnet.solver import mln_back_solve_dys_demo, get_partial_bilipnet_params
import jax.random as random
import orbax.checkpoint
from plnet.rosenbrock_utils import Sampler, PRosenbrock
import matplotlib.pyplot as plt
import scipy.io 

from plnet.train import data_gen_partial, train_partial



# Data
Generate the data with rosenbrock

# Train 


Default values

In [2]:
data_dim = 20
lr_max = 1e-2
epochs = 500
n_batch = 50
name = 'PBiLipNet'

depth = 2 
p_dim = 2
units = [256] * 2
# for unitary, which is data dimension
po_units = [256, data_dim]  
# for monlip,zx 
pb_units = [256, sum(units)]

root_dir = f'/home/rover/Desktop/rl_with_plnet/RL-with-PLnet/rl/plnet/results_exp/{name}-rosenbrock-dim{data_dim}-batch{n_batch}'
rng = random.PRNGKey(42)
rng, rng_data = random.split(rng, 2)


## Train func
train examples for rosenbrock with some configurations

## Generate data and train

In [4]:
data= data_gen_partial(rng_data, train_batches=n_batch, data_dim=data_dim)
print(data["xtrain"].shape)
print(data["ptrain"].shape)
print(data["ytrain"].shape)

(10000, 20)
(10000, 2)
(10000,)


In [4]:
for tau in [2]:
	train_dir = f'{root_dir}/{name}-{depth}-tau{tau}'
	block = PBiLipNet(units, po_units, pb_units, depth=depth, tau=tau)
	model = PPLNet(block)
	train_partial(rng, model, data, name=name, train_dir=train_dir, lr_max=lr_max, epochs=epochs)


model: PBiLipNet, size: 0.70M
Epoch:   1 | loss: 2257.6484/1117.7542, tau: 2.0, Lip: 1.878/3.76
Epoch:   2 | loss: 977.9875/850.7519, tau: 2.0, Lip: 1.788/3.58
Epoch:   3 | loss: 766.0264/675.7066, tau: 2.0, Lip: 1.700/3.40
Epoch:   4 | loss: 609.0770/537.6015, tau: 2.0, Lip: 1.616/3.23
Epoch:   5 | loss: 487.1208/433.0805, tau: 2.0, Lip: 1.540/3.08
Epoch:   6 | loss: 393.1337/353.8495, tau: 2.0, Lip: 1.472/2.94
Epoch:   7 | loss: 320.8917/289.4265, tau: 2.0, Lip: 1.409/2.82
Epoch:   8 | loss: 264.5515/238.7356, tau: 2.0, Lip: 1.353/2.71
Epoch:   9 | loss: 219.0271/198.9546, tau: 2.0, Lip: 1.302/2.60
Epoch:  10 | loss: 183.2645/166.4635, tau: 2.0, Lip: 1.255/2.51
Epoch:  11 | loss: 154.5052/140.9893, tau: 2.0, Lip: 1.213/2.43
Epoch:  12 | loss: 131.1465/120.3471, tau: 2.0, Lip: 1.174/2.35
Epoch:  13 | loss: 111.7632/102.7805, tau: 2.0, Lip: 1.138/2.28
Epoch:  14 | loss: 95.7979/88.4605, tau: 2.0, Lip: 1.105/2.21
Epoch:  15 | loss: 82.4725/76.5294, tau: 2.0, Lip: 1.075/2.15
Epoch:  16 |

# Solve


Restore the model

In [3]:
tau=2

block = PBiLipNet(units, po_units, pb_units, depth=depth, tau=tau)
model = PPLNet(block)
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()

train_dir = f'{root_dir}/{name}-{depth}-tau{tau}'
# where the param comes from
params = orbax_checkpointer.restore(f'{train_dir}/ckpt/params')

rng = random.PRNGKey(43)
key_z, keyp = random.split(rng)
z = Sampler(rng, 10000, data_dim)
p = Sampler(rng, 10000, p_dim)

# run plnet func
fn = lambda x, p: model.apply(params, x, p)

# model.apply(params, z, bh_p)

In [4]:
# test 
print(params['params']['PBiLipBlock'])

{'mon_0': {'Fab0': array([[ 1.1653428 , -0.0687464 , -0.05806358, ..., -0.0464643 ,
        -0.15812951,  0.34365946],
       [ 0.03124196,  2.4540002 , -0.08172581, ..., -0.0605182 ,
        -0.023707  ,  0.2202434 ],
       [ 0.09185564,  0.05957979,  2.2282016 , ..., -0.06542158,
        -0.13526024,  0.03641318],
       ...,
       [ 0.16434495, -0.00283598,  0.09350877, ...,  1.1284621 ,
        -0.04327191,  0.07825875],
       [-0.00636231,  0.03510911,  0.08703255, ..., -0.07029201,
         1.4553713 , -0.06496897],
       [-0.19476318, -0.16457419, -0.09534018, ..., -0.06068465,
         0.07194604,  0.8192592 ]], dtype=float32), 'Fab1': array([[-1.98145664e+00,  2.53760628e-02,  4.49671894e-02, ...,
        -1.62805419e-03,  3.90794761e-02, -2.57879924e-02],
       [-6.65239543e-02, -1.44831812e+00,  3.37532093e-03, ...,
        -1.50213212e-01,  3.90280299e-02,  2.69624442e-02],
       [-1.09072372e-01,  5.09645268e-02,  2.07885575e+00, ...,
        -1.27384573e-01,  7.3956

solve the x based on given z

In [5]:
max_iter = 50
alpha = 1.0
Lambda = 1.0


In [6]:
(uni_params, mon_params, b_params, bh_params) = get_partial_bilipnet_params(params, p, 
                                                	tau, depth, units, po_units, pb_units)



In [7]:
data = mln_back_solve_dys_demo(uni_params, mon_params, b_params, bh_params, 
                          z, fn, p, units, max_iter=max_iter, alpha=alpha, Lambda=Lambda, depth = depth)


Iter.    0 | v: 4.71569490
Iter.    1 | v: 44.84608459
Iter.    2 | v: 11.06963825
Iter.    3 | v: 4.53247738
Iter.    4 | v: 4.21412945
Iter.    5 | v: 4.42254305
Iter.    6 | v: 4.05555010
Iter.    7 | v: 3.31029677
Iter.    8 | v: 2.48341203
Iter.    9 | v: 1.74969745
Iter.   10 | v: 1.17705703
Iter.   11 | v: 0.76314569
Iter.   12 | v: 0.48058826
Iter.   13 | v: 0.29559582
Iter.   14 | v: 0.17830700
Iter.   15 | v: 0.10582250
Iter.   16 | v: 0.06194878
Iter.   17 | v: 0.03584557
Iter.   18 | v: 0.02053519
Iter.   19 | v: 0.01166179
Iter.   20 | v: 0.00657249
Iter.   21 | v: 0.00367954
Iter.   22 | v: 0.00204774
Iter.   23 | v: 0.00113354
Iter.   24 | v: 0.00062465
Iter.   25 | v: 0.00034270
Iter.   26 | v: 0.00018737
Iter.   27 | v: 0.00010204
Iter.   28 | v: 0.00005545
Iter.   29 | v: 0.00003005
Iter.   30 | v: 0.00001627
Iter.   31 | v: 0.00000881
Iter.   32 | v: 0.00000478
Iter.   33 | v: 0.00000261
Iter.   34 | v: 0.00000144
Iter.   35 | v: 0.00000081
Iter.   36 | v: 0.00000047

save

In [10]:
plt.semilogy(data['step'], data['vgap'])
plt.savefig(f'{train_dir}/DYS-PLNet-alpha{alpha:.1f}-lambda{Lambda:.1f}.pdf')
plt.close()
scipy.io.savemat(f'{train_dir}/DYS-PLNet-alpha{alpha:.1f}-lambda{Lambda:.1f}.mat', data)

plot