In [None]:
import sys
import os
# Add the src directory to Python path so model.py can find ssn and net modules
sys.path.append(os.path.abspath('../src'))
import numpy as np
from loguru import logger
import torch

Load the data generated by the open-loop optimization.

In [3]:
# Load the raw samples; the recorded run shows a structured array with fields 'x', 'dv' and 'v'.
path = '../data_result/raw_data/VDP_beta_0.1_grid_30x30.npy'
data = np.load(path)
logger.info(f"Loaded data with shape: {data.shape}, dtype: {data.dtype}")

[32m2025-09-22 02:04:16.976[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m4[0m - [1mLoaded data with shape: (900,), dtype: [('x', '<f8', (2,)), ('dv', '<f8', (2,)), ('v', '<f8')][0m


## SSN(line-search) method for outer weights ##


In [8]:
# Hyper-parameters shared by the cells below.
power = 2.1  # forwarded to the model constructor as `power`
M = 50  # number of points requested from _sample_uniform_sphere_points (greedy-insertion candidates)
num_iterations = 10  # NOTE(review): not read by any cell visible in this notebook
loss_weights = (1.0, 0.0)  # forwarded to the model constructor as `loss_weights`
pruning_threshold = 1e-15  # NOTE(review): not read by any cell visible in this notebook

# Regularization settings; the optimizer setup log reports alpha, gamma and th.
gamma = 5.0
alpha = 1e-5
lr_adam = 1e-5  # NOTE(review): not read by any cell visible in this notebook
regularization = (gamma, alpha)  # passed to the model as one tuple, in (gamma, alpha) order
th = 0.0  # forwarded to the model constructor as `th`

In [14]:
# Project imports needed by this cell.
from src.model import model
from src.greedy_insertion import _sample_uniform_sphere_points

# Model wrapper configured with the hyper-parameters defined earlier and the
# trust-region SSN optimizer.
test = model(
    activation=torch.relu,
    power=power,
    regularization=regularization,
    optimizer='SSN_TR',
    loss_weights=loss_weights,
    th=th,
    train_outerweights=True,
)

# Split the loaded samples into a training part and a validation part.
data_train, data_valid = test._prepare_data(data)

# Sample the hidden-layer weights and biases; later cells pass them to the
# network as inner_weights / inner_bias.
W_hidden, b_hidden = _sample_uniform_sphere_points(M)

# The one-call training entry point is left disabled here:
# test.train(data_train, data_valid, inner_weights=W_hidden, inner_bias=b_hidden)

[32m2025-09-22 02:27:35[0m | [1mINFO    [0m | [36msrc.model[0m:[36m_configure_logger[0m:[36m105[0m - [1mModel initialized[0m
[32m2025-09-22 02:27:35[0m | [1mINFO    [0m | [36msrc.model[0m:[36m_prepare_data[0m:[36m150[0m - [1mTraining set: 810 samples, Validation set: 90 samples[0m
[32m2025-09-22 02:27:35[0m | [1mINFO    [0m | [36msrc.model[0m:[36m_prepare_data[0m:[36m153[0m - [1mData ranges - x: [-3.00, 3.00], v: [0.00, 10.96], dv: [-13.19, 13.13][0m


In [15]:
import torch
from src.ssn import SSN
from src.ssn_tr import SSN_TR

# Create the network from the sampled hidden layer, unless one already exists.
if test.net is None:
    test._create_network(inner_weights=W_hidden, inner_bias=b_hidden)

# Likewise create the optimizer only when the model does not have one yet.
if test.optimizer is None:
    test._setup_optimizer()

# Training tensors used by the closure below.
train_x_tensor, train_v_tensor, train_dv_tensor = data_train


def closure():
    """Return the training loss at the model's current parameters.

    When the optimizer is an SSN or SSN_TR instance, the hidden activations on
    the training inputs are recomputed first (without gradient tracking) and
    stored on the optimizer as ``hidden_activations``.
    """
    optimizer = test.optimizer
    if isinstance(optimizer, (SSN, SSN_TR)):
        with torch.no_grad():
            _, hidden = test.net.forward_with_hidden(train_x_tensor.detach())
        optimizer.hidden_activations = hidden.detach()
    loss_value, _, _ = test._compute_loss(train_x_tensor, train_v_tensor, train_dv_tensor)
    return loss_value

[32m2025-09-22 02:27:49[0m | [1mINFO    [0m | [36msrc.model[0m:[36m_setup_optimizer[0m:[36m227[0m - [1mUsing SSN_TR optimizer with alpha=1e-05, gamma=5.0, th=0.0[0m


In [16]:
from src.utils import _ddphi

# Consistency check: compare the optimizer's _Gradient(q, params, loss) with the
# explicit expression assembled in `rhs` below.
param_list = test.optimizer.param_groups[0]["params"]
params = torch.cat([p.view(-1) for p in param_list])
loss = closure()

q = test.optimizer._transform_param2q(params, loss)
Gq = test.optimizer._Gradient(q, params, loss)

# Autograd gradient of the loss with respect to the same parameters, flattened.
grads = torch.autograd.grad(loss, param_list, create_graph=True, retain_graph=True)
grad_flat = torch.cat([g.view(-1) for g in grads])

# Term built from _ddphi evaluated at |params|, signed like the parameters.
D_nonconvex = params.sign() * (_ddphi(params.abs(), test.th, test.gamma) - 1)

lhs = Gq
rhs = test.optimizer.c * (q - params) + test.alpha * D_nonconvex + grad_flat
gq_norm = float(lhs.norm())
mismatch_norm = float((lhs - rhs).norm())
print("||Gq||:", gq_norm, "||lhs - rhs||:", mismatch_norm)

||Gq||: 3.1245061830789997e-15 ||lhs - rhs||: 0.0


In [17]:
from src.utils import _compute_prox

# Probe the loss along the straight segment from the current parameters
# (t = 0) to prox(q) (t = 1), restoring the parameters after every evaluation.
params = torch.cat([p.view(-1) for p in test.optimizer.param_groups[0]["params"]])
loss0 = closure()

mu = test.alpha / test.optimizer.c
q = test.optimizer._transform_param2q(params, loss0)
unew = _compute_prox(q, mu)

# Snapshot of the starting point, taken once: torch.cat returned a new tensor,
# so `params` holds the same values for every trial point.
backup = params.clone()

vals = []
ts = torch.linspace(0, 1, steps=11)
for t in ts:
    u_t = params*(1 - t) + unew*t
    try:
        test.optimizer._update_parameters(u_t)
        vals.append(float(closure()))
    finally:
        # Put the starting parameters back even if the evaluation fails, so
        # later cells never start from a trial point.
        test.optimizer._update_parameters(backup)

print("loss at t grid from params->prox(q):")
for t, v in zip(ts, vals):
    print(f"t={float(t):.1f} loss={v}")
# loss0 is a tensor; convert it so the difference prints as a plain number
# instead of a tensor repr carrying grad_fn.
print("Δloss at t=1:", vals[-1] - float(loss0))

loss at t grid from params->prox(q):
t=0.0 loss=14.882359343530393
t=0.1 loss=2387.4469079444934
t=0.2 loss=10186.331179815632
t=0.3 loss=23411.536390744655
t=0.4 loss=42063.058934943285
t=0.5 loss=66140.90045274328
t=0.6 loss=95645.07047727126
t=0.7 loss=130575.54022679082
t=0.8 loss=170932.35134168164
t=0.9 loss=216715.4557128873
t=1.0 loss=267924.9079181408
Δloss at t=1: tensor(267910.0256, dtype=torch.float64, grad_fn=<RsubBackward1>)


In [18]:
from src.utils import _compute_dprox, _compute_prox
from src.mpcg import mpcg

# Freeze current state
params = torch.cat([p.view(-1) for p in test.optimizer.param_groups[0]["params"]])
loss = closure()

# Use q = params to avoid algebraic cancellation
q = params.clone().detach()
Gq = test.optimizer._Gradient(q, params, loss)          # now ≈ alpha*D_nonconvex + grad_flat
DG = test.optimizer._Hessian(q, params, loss)
mu = test.alpha / test.optimizer.c
DP = _compute_dprox(q, mu)

# One MPCG step. The iteration cap is twice the number of non-zero diagonal
# entries of DP (at least 1); sigma is only taken from SSN_TR optimizers.
# NOTE(review): 1e-3 is presumably the solver tolerance — confirm against src.mpcg.
I_active = (torch.diagonal(DP) != 0)
kmaxit = max(1, int(2 * I_active.sum().item()))
sigma = test.optimizer.sigma if isinstance(test.optimizer, (SSN_TR,)) else 0.0

dq, flag, pred, relres, iters = mpcg(DG, -Gq, 1e-3, kmaxit, sigma, DP)
print("mpcg: flag", flag, "pred", pred, "relres", relres, "iters", iters, "||dq||", float(torch.norm(dq)))

# Try a grid of step sizes along dq: q_new = q + t*dq, u_new = prox(q_new)
ts = [1.0, 0.5, 0.25, 0.125, 0.0625, 0.03125]
best = (None, float('inf'))
# Snapshot taken once; `params` is the same for every candidate.
backup = params.clone()
for t in ts:
    q_t = q + t * dq
    u_t = _compute_prox(q_t, mu)
    try:
        test.optimizer._update_parameters(u_t)
        loss_t = float(closure())
    finally:
        # Restore the starting parameters even if the evaluation fails, so the
        # model is never left at a trial point.
        test.optimizer._update_parameters(backup)
    print(f"t={t:.5f} loss={loss_t} Δ={loss_t - float(loss)}")
    if loss_t < best[1]:
        best = (t, loss_t)
print("best t:", best[0], "best loss:", best[1], "Δbest:", best[1] - float(loss))

mpcg: flag radius pred -27.3105670419037 relres 0.1519114752983517 iters 3 ||dq|| 1.0
t=1.00000 loss=13.68151188778112 Δ=-1.2008474557492725
t=0.50000 loss=1.2271700255679998 Δ=-13.655189317962392
t=0.25000 loss=4.790326326920983 Δ=-10.09203301660941
t=0.12500 loss=9.019587950729846 Δ=-5.862771392800546
t=0.06250 loss=11.746075485041867 Δ=-3.1362838584885253
t=0.03125 loss=13.262394543528318 Δ=-1.6199648000020748
best t: 0.5 best loss: 1.2271700255679998 Δbest: -13.655189317962392


In [19]:
from src.utils import _compute_dprox

# Report the optimizer's sigma (0.0 when it has no such attribute) and count
# the positive diagonal entries of the prox derivative at the current parameters.
print("sigma:", float(getattr(test.optimizer, "sigma", 0.0)))
DP = _compute_dprox(
    torch.cat([p.view(-1) for p in test.optimizer.param_groups[0]["params"]]),
    test.alpha / test.optimizer.c,
)
print("active (DP>0):", int((torch.diagonal(DP) > 0).sum().item()))

sigma: 1.0
active (DP>0): 50


In [20]:
from src.utils import _compute_dprox, _compute_prox, _ddphi
from src.mpcg import mpcg

# `closure` and the training tensors are defined once, in the cell that builds
# the network and the optimizer (which has to run before this one because it
# creates test.optimizer). They are deliberately not redefined here, so there
# is a single definition to maintain.

# Freeze state
params = torch.cat([p.view(-1) for p in test.optimizer.param_groups[0]["params"]])
loss = closure()

# q = params
q = params.clone().detach()
Gq = test.optimizer._Gradient(q, params, loss)
DG = test.optimizer._Hessian(q, params, loss)
mu = test.alpha / test.optimizer.c
DP = _compute_dprox(q, mu)

# Iteration cap: twice the number of non-zero diagonal entries of DP, at least 1.
I_active = (torch.diagonal(DP) != 0)
kmaxit = max(1, int(2 * I_active.sum().item()))
# Falls back to 0.0 when the optimizer has no `sigma` attribute.
sigma = getattr(test.optimizer, "sigma", 0.0)

dq, flag, pred, relres, iters = mpcg(DG, -Gq, 1e-3, kmaxit, sigma, DP)
print("mpcg: flag", flag, "pred", pred, "relres", relres, "iters", iters, "||dq||", float(torch.norm(dq)))

# Grid of step sizes along dq
ts = [1.0, 0.5, 0.25, 0.125, 0.0625, 0.03125]
best = (None, float('inf'))
# Snapshot taken once; `params` is the same for every candidate.
backup = params.clone()
for t in ts:
    q_t = q + t * dq
    u_t = _compute_prox(q_t, mu)
    try:
        test.optimizer._update_parameters(u_t)
        loss_t = float(closure())
    finally:
        # Restore the starting parameters even if the evaluation fails.
        test.optimizer._update_parameters(backup)
    print(f"t={t:.5f} loss={loss_t} Δ={loss_t - float(loss)}")
    if loss_t < best[1]:
        best = (t, loss_t)
print("best t:", best[0], "best loss:", best[1], "Δbest:", best[1] - float(loss))

mpcg: flag radius pred -27.3105670419037 relres 0.1519114752983517 iters 3 ||dq|| 1.0
t=1.00000 loss=13.68151188778112 Δ=-1.2008474557492725
t=0.50000 loss=1.2271700255679998 Δ=-13.655189317962392
t=0.25000 loss=4.790326326920983 Δ=-10.09203301660941
t=0.12500 loss=9.019587950729846 Δ=-5.862771392800546
t=0.06250 loss=11.746075485041867 Δ=-3.1362838584885253
t=0.03125 loss=13.262394543528318 Δ=-1.6199648000020748
best t: 0.5 best loss: 1.2271700255679998 Δbest: -13.655189317962392


In [21]:
from copy import deepcopy

# NOTE: this cell changes the model. After it runs, the parameters are those of
# the last accepted iterate; `params0` keeps the starting values for the
# optional restore at the bottom.
params0 = torch.cat([p.view(-1) for p in test.optimizer.param_groups[0]["params"]])
loss0 = float(closure())
print("start loss:", loss0)

mu = test.alpha / test.optimizer.c
history = [loss0]

for k in range(10):  # few iterations to check trend
    params = torch.cat([p.view(-1) for p in test.optimizer.param_groups[0]["params"]])
    loss = closure()
    q = params.clone().detach()
    Gq = test.optimizer._Gradient(q, params, loss)
    DG = test.optimizer._Hessian(q, params, loss)
    DP = _compute_dprox(q, mu)
    I_active = (torch.diagonal(DP) != 0)
    kmaxit = max(1, int(2 * I_active.sum().item()))
    sigma = getattr(test.optimizer, "sigma", 0.0)

    dq, flag, pred, relres, iters = mpcg(DG, -Gq, 1e-3, kmaxit, sigma, DP)

    # Grid search over step sizes: every candidate is evaluated and the one
    # with the lowest loss is kept (there is no early exit).
    ts = [1.0, 0.5, 0.25, 0.125, 0.0625]
    chosen = None
    best_loss = float('inf')
    best_u = None
    # Snapshot of this iteration's starting point, shared by all candidates.
    backup = params.clone()
    for t in ts:
        q_t = q + t * dq
        u_t = _compute_prox(q_t, mu)
        try:
            test.optimizer._update_parameters(u_t)
            loss_t = float(closure())
        finally:
            # Restore the starting point even if the evaluation fails.
            test.optimizer._update_parameters(backup)
        if loss_t < best_loss:
            best_loss = loss_t
            chosen = t
            best_u = u_t.clone()

    if best_u is None:
        # No candidate gave a loss below +inf (all NaN or inf): there is
        # nothing to apply, so stop instead of passing None to the optimizer.
        print(f"iter {k}: flag={flag} no candidate step produced a finite loss; stopping")
        break

    test.optimizer._update_parameters(best_u)
    history.append(best_loss)
    print(f"iter {k}: flag={flag} t={chosen} loss={best_loss}")

print("loss trend:", history)
# restore original if you don't want to keep the change:
# test.optimizer._update_parameters(params0)

start loss: 14.882359343530393
iter 0: flag=radius t=0.5 loss=1.2271700255679998
iter 1: flag=radius t=0.5 loss=0.15885395032724095
iter 2: flag=radius t=0.5 loss=0.11689316853411141
iter 3: flag=radius t=0.5 loss=0.0790634517740953
iter 4: flag=radius t=0.5 loss=0.049637269853174314
iter 5: flag=radius t=0.5 loss=0.0369527069684916
iter 6: flag=radius t=0.5 loss=0.030293802755239747
iter 7: flag=radius t=0.5 loss=0.02767588239475441
iter 8: flag=radius t=0.5 loss=0.0249693121558565
iter 9: flag=radius t=0.5 loss=0.023130127942680948
loss trend: [14.882359343530393, 1.2271700255679998, 0.15885395032724095, 0.11689316853411141, 0.0790634517740953, 0.049637269853174314, 0.0369527069684916, 0.030293802755239747, 0.02767588239475441, 0.0249693121558565, 0.023130127942680948]


In [22]:
from src.utils import _ddphi, _compute_prox, _compute_dprox

# Scan starting points q = params - (beta / c) * rterm for several beta values
# and, for each, report the best loss reachable along the resulting MPCG
# direction. The model parameters are restored after every evaluation.
params = torch.cat([p.view(-1) for p in test.optimizer.param_groups[0]["params"]])
loss = closure()
grads = torch.autograd.grad(loss, test.optimizer.param_groups[0]["params"], create_graph=True, retain_graph=True)
grad_flat = torch.cat([g.view(-1) for g in grads])
D_nonconvex = torch.sign(params) * (_ddphi(torch.abs(params), test.th, test.gamma) - 1)
rterm = test.alpha * D_nonconvex + grad_flat
mu = test.alpha / test.optimizer.c

# Snapshot taken once; `params` does not change across beta values or steps.
backup = params.clone()

for beta in [0.0, 0.25, 0.5, 0.75, 1.0]:
    q = params - (beta / test.optimizer.c) * rterm
    Gq = test.optimizer._Gradient(q, params, loss)
    DG = test.optimizer._Hessian(q, params, loss)
    DP = _compute_dprox(q, mu)
    I_active = (torch.diagonal(DP) != 0)
    kmaxit = max(1, int(2 * I_active.sum().item()))
    sigma = getattr(test.optimizer, "sigma", 0.0)
    dq, flag, pred, relres, iters = mpcg(DG, -Gq, 1e-3, kmaxit, sigma, DP)

    # evaluate best t along dq
    ts = [1.0, 0.5, 0.25, 0.125]
    best_loss = float('inf')
    for t in ts:
        q_t = q + t * dq
        u_t = _compute_prox(q_t, mu)
        try:
            test.optimizer._update_parameters(u_t)
            loss_t = float(closure())
        finally:
            # Restore the starting parameters even if the evaluation fails.
            test.optimizer._update_parameters(backup)
        best_loss = min(best_loss, loss_t)

    print(f"beta={beta:.2f} ||Gq||={float(torch.norm(Gq)):.3e} flag={flag} best_loss={best_loss} Δ={best_loss - float(loss):.3e}")

beta=0.00 ||Gq||=4.149e-02 flag=radius best_loss=0.02125751619929503 Δ=-1.873e-03
beta=0.25 ||Gq||=3.112e-02 flag=radius best_loss=0.028915530467504266 Δ=5.785e-03
beta=0.50 ||Gq||=2.074e-02 flag=radius best_loss=0.052909690238746886 Δ=2.978e-02
beta=0.75 ||Gq||=1.037e-02 flag=radius best_loss=0.09391576295462797 Δ=7.079e-02
beta=1.00 ||Gq||=1.391e-16 flag=maxitr best_loss=0.15519202462791862 Δ=1.321e-01


## SSN(line-search) for the outer weights ##

In [None]:
# Train the outer weights with the 'SSN' optimizer for 10000 iterations.
# NOTE(review): `model_outerweights`, `weights`, `bias` and `outer_weights` are
# not defined in any cell visible in this notebook (the class imported earlier
# is `model`), so this cell is expected to raise NameError on a fresh
# Restart & Run All unless they are defined elsewhere — confirm.
# NOTE(review): `bias` is passed in as `inner_bias` and rebound by the returned
# tuple, so re-running the cell feeds it its own previous output.
vdp_model = model_outerweights(data, torch.relu, 2.0, regularization, 'SSN', loss_weights=(1.0, 0.0))
model_result, weight, bias, output_weight = vdp_model.train(
    inner_weights=weights, inner_bias=bias, outer_weights=outer_weights,
    iterations=10000,
    display_every=500
)

## SSN(trust-region) for the outer weights

In [None]:
# Train the outer weights with the 'SSN_TR' optimizer for 10000 iterations.
# NOTE(review): `model_outerweights`, `weights`, `bias` and `outer_weights` are
# not defined in any cell visible in this notebook (the class imported earlier
# is `model`), so this cell is expected to raise NameError on a fresh
# Restart & Run All unless they are defined elsewhere — confirm.
# NOTE(review): `bias` is passed in as `inner_bias` and rebound by the returned
# tuple, so re-running the cell feeds it its own previous output.
vdp_model = model_outerweights(data, torch.relu, 2.0, regularization, 'SSN_TR', loss_weights=(1.0, 0.0))
model_result, weight, bias, output_weight = vdp_model.train(
    inner_weights=weights, inner_bias=bias, outer_weights=outer_weights,
    iterations=10000,
    display_every=500
)