In [1]:
import sys
import os
# Add the src directory to Python path so model.py can find ssn and net modules
sys.path.append(os.path.abspath('../src'))
import numpy as np
from loguru import logger
import torch

Load the data that is generated from the open-loop optimization

In [2]:
# load the data
path = '../data_result/raw_data/VDP_beta_0.1_grid_30x30.npy'# Initialize the weights
data = np.load(path)
logger.info(f"Loaded data with shape: {data.shape}, dtype: {data.dtype}")

[32m2025-09-22 10:43:23.987[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m4[0m - [1mLoaded data with shape: (900,), dtype: [('x', '<f8', (2,)), ('dv', '<f8', (2,)), ('v', '<f8')][0m


## SSN(trust-region) method for outer weights ##


In [3]:
# Initialize the parameter
power = 2.1
M = 50 # number greedy insertion selected
num_iterations = 10
loss_weights = (1.0, 0.0)
pruning_threshold = 1e-15

gamma = 5.0
alpha = 1e-5
lr_adam = 1e-5
regularization = (gamma, alpha) 
th = 0.0

In [None]:
from src.model import model
from src.greedy_insertion import _sample_uniform_sphere_points

test = model(activation=torch.relu, power=power, regularization=regularization, optimizer='SSN_TR', loss_weights=loss_weights, th=th, train_outerweights=True)

# prepare the data
data_train, data_valid = test._prepare_data(data)

W_hidden, b_hidden = _sample_uniform_sphere_points(M)
test.train(data_train, data_valid, inner_weights=W_hidden, inner_bias=b_hidden)

[32m2025-09-22 10:47:29[0m | [1mINFO    [0m | [36msrc.model[0m:[36m_configure_logger[0m:[36m105[0m - [1mModel initialized[0m


[32m2025-09-22 10:47:29[0m | [1mINFO    [0m | [36msrc.model[0m:[36m_prepare_data[0m:[36m150[0m - [1mTraining set: 810 samples, Validation set: 90 samples[0m
[32m2025-09-22 10:47:29[0m | [1mINFO    [0m | [36msrc.model[0m:[36m_prepare_data[0m:[36m153[0m - [1mData ranges - x: [-3.00, 3.00], v: [0.00, 10.96], dv: [-13.19, 13.13][0m
[32m2025-09-22 10:47:29[0m | [1mINFO    [0m | [36msrc.model[0m:[36mtrain[0m:[36m307[0m - [1mStarting network training session[0m
[32m2025-09-22 10:47:29[0m | [1mINFO    [0m | [36msrc.model[0m:[36m_setup_optimizer[0m:[36m227[0m - [1mUsing SSN_TR optimizer with alpha=1e-05, gamma=5.0, th=0.0[0m
[32m2025-09-22 10:47:29[0m | [1mINFO    [0m | [36msrc.model[0m:[36mtrain[0m:[36m322[0m - [1mTraining hyperparameters: iterations=5000, batch_size=1620, display_every=1000[0m
[32m2025-09-22 10:47:29[0m | [1mINFO    [0m | [36msrc.model[0m:[36mtrain[0m:[36m323[0m - [1mLoss weights: value=1.0, gradient=0.0[

In [20]:
test_ls = model(activation=torch.relu, power=power, regularization=regularization, optimizer='SSN', loss_weights=loss_weights, th=th, train_outerweights=True)
test.train(data_train, data_valid, inner_weights=W_hidden, inner_bias=b_hidden)

[32m2025-09-22 10:50:30[0m | [1mINFO    [0m | [36msrc.model[0m:[36m_configure_logger[0m:[36m105[0m - [1mModel initialized[0m
[32m2025-09-22 10:50:30[0m | [1mINFO    [0m | [36msrc.model[0m:[36mtrain[0m:[36m307[0m - [1mStarting network training session[0m
[32m2025-09-22 10:50:31[0m | [1mINFO    [0m | [36msrc.model[0m:[36m_setup_optimizer[0m:[36m227[0m - [1mUsing SSN_TR optimizer with alpha=1e-05, gamma=5.0, th=0.0[0m
[32m2025-09-22 10:50:31[0m | [1mINFO    [0m | [36msrc.model[0m:[36mtrain[0m:[36m322[0m - [1mTraining hyperparameters: iterations=5000, batch_size=1620, display_every=1000[0m
[32m2025-09-22 10:50:31[0m | [1mINFO    [0m | [36msrc.model[0m:[36mtrain[0m:[36m323[0m - [1mLoss weights: value=1.0, gradient=0.0[0m
[32m2025-09-22 10:50:31[0m | [1mINFO    [0m | [36msrc.model[0m:[36mtrain[0m:[36m362[0m - [1mEpoch 0: Train Loss = 16.611090, Val Loss = 59.024638[0m
[32m2025-09-22 10:50:40[0m | [1mINFO    [0m | [3

## Test Cases ##

In [5]:
import torch
from src.ssn import SSN
from src.ssn_tr import SSN_TR

# Build net once if needed
if test.net is None:
    test._create_network(inner_weights=W_hidden, inner_bias=b_hidden)

# Build optimizer once if needed
if test.optimizer is None:
    test._setup_optimizer()

# Define closure for the current data tensors
train_x_tensor, train_v_tensor, train_dv_tensor = data_train
def closure():
    if isinstance(test.optimizer, (SSN, SSN_TR)):
        with torch.no_grad():
            _, hidden_activations = test.net.forward_with_hidden(train_x_tensor.detach())
        test.optimizer.hidden_activations = hidden_activations.detach()
    total_loss, _, _ = test._compute_loss(train_x_tensor, train_v_tensor, train_dv_tensor)
    return total_loss

[32m2025-09-22 10:43:52[0m | [1mINFO    [0m | [36msrc.model[0m:[36m_setup_optimizer[0m:[36m227[0m - [1mUsing SSN_TR optimizer with alpha=1e-05, gamma=5.0, th=0.0[0m


In [10]:
from src.utils import _ddphi

params = torch.cat([p.view(-1) for p in test.optimizer.param_groups[0]["params"]])
loss = closure()

q = test.optimizer._transform_param2q(params, loss)
Gq = test.optimizer._Gradient(q, params, loss)

grads = torch.autograd.grad(loss, test.optimizer.param_groups[0]["params"], create_graph=True, retain_graph=True)
grad_flat = torch.cat([g.view(-1) for g in grads])
D_nonconvex = torch.sign(params) * (_ddphi(torch.abs(params), test.th, test.gamma) - 1)

lhs = Gq
rhs = test.optimizer.c * (q - params) + test.alpha * D_nonconvex + grad_flat
print("||Gq||:", float(torch.norm(lhs)), "||lhs - rhs||:", float(torch.norm(lhs - rhs)))

||Gq||: 136.8681762444426 ||lhs - rhs||: 0.0


In [11]:
from src.utils import _compute_prox

params = torch.cat([p.view(-1) for p in test.optimizer.param_groups[0]["params"]])
loss0 = closure()

mu = test.alpha / test.optimizer.c
q = test.optimizer._transform_param2q(params, loss0)
unew = _compute_prox(q, mu)

vals = []
ts = torch.linspace(0, 1, steps=11)
for t in ts:
    u_t = params*(1 - t) + unew*t
    backup = params.clone()
    test.optimizer._update_parameters(u_t)
    vals.append(float(closure()))
    test.optimizer._update_parameters(backup)

print("loss at t grid from params->prox(q):")
for t, v in zip(ts, vals):
    print(f"t={float(t):.1f} loss={v}")
print("Δloss at t=1:", vals[-1] - loss0)

loss at t grid from params->prox(q):
t=0.0 loss=38.21510561321624
t=0.1 loss=38.21510492660367
t=0.2 loss=38.21510607095796
t=0.3 loss=38.21510561321624
t=0.4 loss=38.21510652869967
t=0.5 loss=38.21510561321624
t=0.6 loss=38.21510561321624
t=0.7 loss=38.215105613216245
t=0.8 loss=38.215105613216245
t=0.9 loss=38.215105613216245
t=1.0 loss=38.215105613216245
Δloss at t=1: tensor(7.1054e-15, dtype=torch.float64, grad_fn=<RsubBackward1>)


In [12]:
from src.utils import _compute_dprox, _compute_prox

# Freeze current state
params = torch.cat([p.view(-1) for p in test.optimizer.param_groups[0]["params"]])
loss = closure()

# Use q = params to avoid algebraic cancellation
q = params.clone().detach()
Gq = test.optimizer._Gradient(q, params, loss)          # now ≈ alpha*D_nonconvex + grad_flat
DG = test.optimizer._Hessian(q, params, loss)
mu = test.alpha / test.optimizer.c
DP = _compute_dprox(q, mu)

# One MPCG step
from src.mpcg import mpcg
I_active = (torch.diagonal(DP) != 0)
kmaxit = max(1, int(2 * I_active.sum().item()))
sigma = test.optimizer.sigma if isinstance(test.optimizer, (SSN_TR,)) else 0.0

dq, flag, pred, relres, iters = mpcg(DG, -Gq, 1e-3, kmaxit, sigma, DP)
print("mpcg: flag", flag, "pred", pred, "relres", relres, "iters", iters, "||dq||", float(torch.norm(dq)))

# Try a backtracking grid along dq: q_new = q + t*dq, u_new = prox(q_new)
ts = [1.0, 0.5, 0.25, 0.125, 0.0625, 0.03125]
best = (None, float('inf'))
for t in ts:
    q_t = q + t * dq
    u_t = _compute_prox(q_t, mu)
    backup = params.clone()
    test.optimizer._update_parameters(u_t)
    loss_t = float(closure())
    test.optimizer._update_parameters(backup)
    print(f"t={t:.5f} loss={loss_t} Δ={loss_t - float(loss)}")
    if loss_t < best[1]:
        best = (t, loss_t)
print("best t:", best[0], "best loss:", best[1], "Δbest:", best[1] - float(loss))

mpcg: flag radius pred -71.33386699443048 relres 0.1512846050380833 iters 1 ||dq|| 0.9999999999999999
t=1.00000 loss=32.41331244223033 Δ=-5.80179317098591
t=0.50000 loss=2.5480531923313716 Δ=-35.667052420884865
t=0.25000 loss=12.189804984138405 Δ=-26.025300629077833
t=0.12500 loss=23.15368325820984 Δ=-15.0614223550064
t=0.06250 loss=30.17189165504805 Δ=-8.043213958168188
t=0.03125 loss=34.06452212180356 Δ=-4.150583491412675
best t: 0.5 best loss: 2.5480531923313716 Δbest: -35.667052420884865


In [13]:
print("sigma:", float(getattr(test.optimizer, "sigma", 0.0)))
DP = __import__("src.utils", fromlist=["_compute_dprox"])._compute_dprox(
    torch.cat([p.view(-1) for p in test.optimizer.param_groups[0]["params"]]),
    test.alpha / test.optimizer.c
)
print("active (DP>0):", int((torch.diagonal(DP) > 0).sum().item()))

sigma: 1.0
active (DP>0): 50


In [14]:
from src.utils import _compute_dprox, _compute_prox, _ddphi
from src.mpcg import mpcg
from src.ssn import SSN
from src.ssn_tr import SSN_TR

train_x_tensor, train_v_tensor, train_dv_tensor = data_train

def closure():
    if isinstance(test.optimizer, (SSN, SSN_TR)):
        with torch.no_grad():
            _, S = test.net.forward_with_hidden(train_x_tensor.detach())
        test.optimizer.hidden_activations = S.detach()
    total_loss, _, _ = test._compute_loss(train_x_tensor, train_v_tensor, train_dv_tensor)
    return total_loss

# Freeze state
params = torch.cat([p.view(-1) for p in test.optimizer.param_groups[0]["params"]])
loss = closure()

# q = params
q = params.clone().detach()
Gq = test.optimizer._Gradient(q, params, loss)
DG = test.optimizer._Hessian(q, params, loss)
mu = test.alpha / test.optimizer.c
DP = _compute_dprox(q, mu)

I_active = (torch.diagonal(DP) != 0)
kmaxit = max(1, int(2 * I_active.sum().item()))
sigma = getattr(test.optimizer, "sigma", 0.0)

dq, flag, pred, relres, iters = mpcg(DG, -Gq, 1e-3, kmaxit, sigma, DP)
print("mpcg: flag", flag, "pred", pred, "relres", relres, "iters", iters, "||dq||", float(torch.norm(dq)))

# Backtracking along dq
ts = [1.0, 0.5, 0.25, 0.125, 0.0625, 0.03125]
best = (None, float('inf'))
for t in ts:
    q_t = q + t * dq
    u_t = _compute_prox(q_t, mu)
    backup = params.clone()
    test.optimizer._update_parameters(u_t)
    loss_t = float(closure())
    test.optimizer._update_parameters(backup)
    print(f"t={t:.5f} loss={loss_t} Δ={loss_t - float(loss)}")
    if loss_t < best[1]:
        best = (t, loss_t)
print("best t:", best[0], "best loss:", best[1], "Δbest:", best[1] - float(loss))

mpcg: flag radius pred -71.33386699443048 relres 0.1512846050380833 iters 1 ||dq|| 0.9999999999999999
t=1.00000 loss=32.41331244223033 Δ=-5.80179317098591
t=0.50000 loss=2.5480531923313716 Δ=-35.667052420884865
t=0.25000 loss=12.189804984138405 Δ=-26.025300629077833
t=0.12500 loss=23.15368325820984 Δ=-15.0614223550064
t=0.06250 loss=30.17189165504805 Δ=-8.043213958168188
t=0.03125 loss=34.06452212180356 Δ=-4.150583491412675
best t: 0.5 best loss: 2.5480531923313716 Δbest: -35.667052420884865


In [15]:
from copy import deepcopy

params0 = torch.cat([p.view(-1) for p in test.optimizer.param_groups[0]["params"]])
loss0 = float(closure())
print("start loss:", loss0)

mu = test.alpha / test.optimizer.c
history = [loss0]

for k in range(10):  # few iterations to check trend
    params = torch.cat([p.view(-1) for p in test.optimizer.param_groups[0]["params"]])
    loss = closure()
    q = params.clone().detach()
    Gq = test.optimizer._Gradient(q, params, loss)
    DG = test.optimizer._Hessian(q, params, loss)
    DP = _compute_dprox(q, mu)
    I_active = (torch.diagonal(DP) != 0)
    kmaxit = max(1, int(2 * I_active.sum().item()))
    sigma = getattr(test.optimizer, "sigma", 0.0)

    dq, flag, pred, relres, iters = mpcg(DG, -Gq, 1e-3, kmaxit, sigma, DP)

    # simple backtracking
    ts = [1.0, 0.5, 0.25, 0.125, 0.0625]
    chosen = None
    best_loss = float('inf')
    best_u = None
    for t in ts:
        q_t = q + t * dq
        u_t = _compute_prox(q_t, mu)
        backup = params.clone()
        test.optimizer._update_parameters(u_t)
        loss_t = float(closure())
        test.optimizer._update_parameters(backup)
        if loss_t < best_loss:
            best_loss = loss_t
            chosen = t
            best_u = u_t.clone()

    test.optimizer._update_parameters(best_u)
    history.append(best_loss)
    print(f"iter {k}: flag={flag} t={chosen} loss={best_loss}")

print("loss trend:", history)
# restore original if you don't want to keep the change:
# test.optimizer._update_parameters(params0)

start loss: 38.21510561321624
iter 0: flag=radius t=0.5 loss=2.5480531923313716
iter 1: flag=radius t=0.5 loss=0.15064699374156462
iter 2: flag=radius t=0.5 loss=0.10947923170336686
iter 3: flag=radius t=0.5 loss=0.07109906685868055
iter 4: flag=radius t=0.5 loss=0.05296490902586828
iter 5: flag=radius t=0.5 loss=0.04083015756411104
iter 6: flag=radius t=1.0 loss=0.03432765462993657
iter 7: flag=radius t=0.5 loss=0.027993926795296657
iter 8: flag=radius t=0.5 loss=0.025564907746785687
iter 9: flag=radius t=0.5 loss=0.02356511665798449
loss trend: [38.21510561321624, 2.5480531923313716, 0.15064699374156462, 0.10947923170336686, 0.07109906685868055, 0.05296490902586828, 0.04083015756411104, 0.03432765462993657, 0.027993926795296657, 0.025564907746785687, 0.02356511665798449]


In [16]:
from src.utils import _ddphi, _compute_prox, _compute_dprox

params = torch.cat([p.view(-1) for p in test.optimizer.param_groups[0]["params"]])
loss = closure()
grads = torch.autograd.grad(loss, test.optimizer.param_groups[0]["params"], create_graph=True, retain_graph=True)
grad_flat = torch.cat([g.view(-1) for g in grads])
D_nonconvex = torch.sign(params) * (_ddphi(torch.abs(params), test.th, test.gamma) - 1)
rterm = test.alpha * D_nonconvex + grad_flat
mu = test.alpha / test.optimizer.c

for beta in [0.0, 0.25, 0.5, 0.75, 1.0]:
    q = params - (beta / test.optimizer.c) * rterm
    Gq = test.optimizer._Gradient(q, params, loss)
    DG = test.optimizer._Hessian(q, params, loss)
    DP = _compute_dprox(q, mu)
    I_active = (torch.diagonal(DP) != 0)
    kmaxit = max(1, int(2 * I_active.sum().item()))
    sigma = getattr(test.optimizer, "sigma", 0.0)
    dq, flag, pred, relres, iters = mpcg(DG, -Gq, 1e-3, kmaxit, sigma, DP)

    # evaluate best t along dq
    ts = [1.0, 0.5, 0.25, 0.125]
    best_loss = float('inf')
    for t in ts:
        q_t = q + t * dq
        u_t = _compute_prox(q_t, mu)
        backup = params.clone()
        test.optimizer._update_parameters(u_t)
        loss_t = float(closure())
        test.optimizer._update_parameters(backup)
        best_loss = min(best_loss, loss_t)

    print(f"beta={beta:.2f} ||Gq||={float(torch.norm(Gq)):.3e} flag={flag} best_loss={best_loss} Δ={best_loss - float(loss):.3e}")

beta=0.00 ||Gq||=1.150e-02 flag=radius best_loss=0.02184191053258934 Δ=-1.723e-03
beta=0.25 ||Gq||=8.627e-03 flag=radius best_loss=0.021798467732128454 Δ=-1.767e-03
beta=0.50 ||Gq||=5.751e-03 flag=radius best_loss=0.02193777233292615 Δ=-1.627e-03
beta=0.75 ||Gq||=2.876e-03 flag=radius best_loss=0.023128094518807398 Δ=-4.370e-04
beta=1.00 ||Gq||=1.350e-16 flag=maxitr best_loss=0.028410341204405783 Δ=4.845e-03
