# PLnet for Rosenbrock

This notebook contains a tutorial on training Lipschitz-bounded deep networks with PLnet from [Wang & Dvijotham & Manchester (2024)](https://arxiv.org/html/2402.01344v2). We'll demonstrate how to train PL networks on the [Rosenbrock](https://en.wikipedia.org/wiki/Rosenbrock_function) function.

This notebook serves as an initial idea/simple example of training PLnet and implement the inverse. Better example might be provided later on. 

## dependency
The plnet in torch with version is mostly based on numpy and pytorch.

In [1]:
import os
from typing import Union, Dict

import torch
import torch.nn as nn
import torch.optim as optim
import scipy.io

import torch.nn.functional as F
import numpy as np
from flax import linen 

device = 'cuda' if torch.cuda.is_available() else 'cpu'



## Data generation
code to generate Rosenbrock data. 

In [2]:
def Rosenbrock(x: torch.Tensor) -> torch.Tensor:
    """
    Rosenbrock function for torch tensors.
    """
    f = lambda x, y: (x - 1.) ** 2 / 200. + 0.5 * (y - x ** 2) ** 2

    single = x.ndim == 1
    if single:
        x = x.unsqueeze(0)

    N = x.shape[-1]
    y = torch.stack([f(x[..., i], x[..., i+1]) for i in range(N - 1)], dim=1)
    y = torch.mean(y, dim=1)

    if single:
        y = y.squeeze(0)

    return y


def Sampler(
        batches: int,
        data_dim: int,
        x_min: Union[float, torch.Tensor] = -2.,
        x_max: Union[float, torch.Tensor] = 2.,
        device: str = "cpu"
) -> torch.Tensor:
    """
    Uniform sampler similar to JAX random.uniform.
    """
    return (x_max - x_min) * torch.rand((batches, data_dim), device=device) + x_min


def data_gen(
    data_dim: int = 20,
    val_min: float = -2.,
    val_max: float = 2.,
    train_batch_size: int = 200,
    test_batch_size: int = 5000,
    train_batches: int = 200,
    test_batches: int = 1,
    eval_batch_size: int = 5000,
    eval_batches: int = 100,
    device: str = "cpu"
) -> Dict[str, torch.Tensor]:
    """
    Generate synthetic train/test/eval datasets for the Rosenbrock function.
    """
    # Sample points
    xtrain = Sampler(train_batch_size * train_batches, data_dim, x_min=val_min, x_max=val_max, device=device)
    xtest  = Sampler(test_batch_size * test_batches, data_dim, x_min=val_min, x_max=val_max, device=device)
    xeval  = Sampler(eval_batch_size * eval_batches, data_dim, x_min=val_min, x_max=val_max, device=device)

    # Compute labels
    ytrain, ytest, yeval = Rosenbrock(xtrain), Rosenbrock(xtest), Rosenbrock(xeval)

    data = {
        "xtrain": xtrain,
        "ytrain": ytrain,
        "xtest": xtest,
        "ytest": ytest,
        "xeval": xeval,
        "yeval": yeval,
        "train_batches": train_batches,
        "train_batch_size": train_batch_size,
        "test_batches": test_batches,
        "test_batch_size": test_batch_size,
        "eval_batches": eval_batches,
        "eval_batch_size": eval_batch_size,
        "data_dim": data_dim
    }

    return data


### Model & Data Configuration

Configure the data dimension and learning parameters


In [3]:
data_dim = 20
lr_max = 1e-2
epochs = 100
n_batch = 50
depth = 2 
layer_size = [256]*8
tau=2


name = 'BiLipNet_torch'
root_dir =  f'{os.getcwd()}/plnet/results_exp/{name}-rosenbrock-dim{data_dim}-batch{n_batch}'

In [4]:
data= data_gen(data_dim, train_batches=n_batch, eval_batch_size=500,eval_batches=5)

# visualize the data
print(data)
print(data['xtrain'].shape)

{'xtrain': tensor([[-1.2666, -1.0204,  0.7798,  ...,  0.6958,  0.9287,  1.5379],
        [-1.0117, -0.8180,  1.7388,  ..., -0.1832,  0.7874,  1.9441],
        [-0.8629,  1.5679, -0.6890,  ..., -0.7470,  0.1682, -1.8098],
        ...,
        [-1.3806, -1.3370,  0.8478,  ...,  0.7322, -1.3290,  0.5042],
        [ 0.2738, -0.8625, -1.2013,  ...,  0.5024, -0.0582, -0.0217],
        [-0.5152, -0.0226,  0.2163,  ..., -0.2501,  1.0713, -1.4838]]), 'ytrain': tensor([1.2955, 1.0681, 1.5642,  ..., 4.1177, 2.3219, 0.9298]), 'xtest': tensor([[ 1.8743, -0.4425,  0.8785,  ..., -1.6351, -0.7958,  1.5558],
        [-1.0153, -0.0483, -0.7021,  ..., -0.6903,  1.0848, -0.8744],
        [ 1.3316, -1.2648, -1.2832,  ...,  0.2153,  1.0958,  1.3890],
        ...,
        [ 1.9334, -0.7758,  1.1866,  ..., -1.4052,  0.6648,  0.8666],
        [ 0.9033,  1.8937,  1.6830,  ..., -1.1012, -1.5068, -0.5422],
        [ 0.9079, -0.7346,  1.8269,  ..., -1.2821, -1.1292, -1.1362]]), 'ytest': tensor([1.9030, 1.1314, 2.7

# Train
train model for rosenbrock function

### train function
define train function with configurable loss function

In [5]:
def train_with_flexible_loss(
    model: nn.Module,
    data: dict,
    fitness_func,  # function(model, x, y) -> loss
    fitness_eval_func,  # function(model, x, y) -> eval_loss
    name: str = 'bilipnet',
    train_dir: str = './results/rosenbrock-nd',
    lr_max: float = 1e-3,
    epochs: int = 600,
    figure_generation_function=lambda model, epoch: (model, epoch),
    figure_generation_period=50,
    device: str = 'cpu',
):
    os.makedirs(train_dir, exist_ok=True)
    ckpt_dir = os.path.join(train_dir, 'ckpt')
    os.makedirs(ckpt_dir, exist_ok=True)

    train_batches = data['train_batches']
    train_batch_size = data['train_batch_size']

    train_size = train_batches * train_batch_size
    idx_shp = (train_batches, train_batch_size)

    model.to(device)
    model.train()

    # Count parameters
    param_count = sum(p.numel() for p in model.parameters())
    print(f'model: {name}, size: {param_count/1e6:.2f}M')

    # Optimizer
    optimizer = optim.Adam(model.parameters(), lr=lr_max)
    # Optionally: implement a linear schedule manually if needed
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=lr_max, total_steps=train_batches*epochs, pct_start=0.25
    )

    train_loss, val_loss = [], []
    Lipmin, Lipmax, Tau = [], [], []

    xtrain = data['xtrain'].to(device)
    ytrain = data['ytrain'].to(device)
    xtest = data['xtest'].to(device)
    ytest = data['ytest'].to(device)
    xeval = data['xeval'].to(device)
    yeval = data['yeval'].to(device)

    for epoch in range(epochs):
        # Shuffle indices
        idx = torch.randperm(train_size, device=device).view(idx_shp)
        tloss = 0.

        for b in range(train_batches):
            x = xtrain[idx[b, :], :]
            y = ytrain[idx[b, :]]

            optimizer.zero_grad()
            loss = fitness_func(model, x, y)
            loss.backward()
            optimizer.step()
            scheduler.step()  # update LR if using scheduler

            tloss += loss.item()

        tloss /= train_batches
        train_loss.append(tloss)

        # Validation loss
        model.eval()
        with torch.no_grad():
            vloss = fitness_eval_func(model, xtest, ytest)
        val_loss.append(vloss.item())
        model.train()

        # Example: get bounds if model has such method
        if hasattr(model, 'get_bounds'):
            lipmin, lipmax, tau = model.get_bounds()
            Lipmin.append(lipmin)
            Lipmax.append(lipmax)
            Tau.append(tau)
        else:
            Lipmin.append(0.0)
            Lipmax.append(0.0)
            Tau.append(0.0)

        if epoch % figure_generation_period == 0:
            figure_generation_function(model, epoch)

        print(f'Epoch: {epoch+1:3d} | loss: {tloss:.3f}/{vloss:.3f}, tau: {Tau[-1]:.3f}, Lip: {Lipmin[-1]:.3f}/{Lipmax[-1]:.2f}')

    # Eval loss
    model.eval()
    with torch.no_grad():
        eloss = fitness_eval_func(model, xeval, yeval)
    print(f'{name}: eval loss: {eloss:.3f}')

    # Save metrics to data dict
    data['train_loss'] = torch.tensor(train_loss)
    data['val_loss'] = torch.tensor(val_loss)
    data['lipmin'] = torch.tensor(Lipmin)
    data['lipmax'] = torch.tensor(Lipmax)
    data['tau'] = torch.tensor(Tau)
    data['eval_loss'] = torch.tensor(eloss)

    scipy.io.savemat(os.path.join(train_dir, 'data.mat'), data)

    # Save model checkpoint
    torch.save(model.state_dict(), os.path.join(ckpt_dir, 'params.pt'))


### Loss Configuration
Define loss function. Here we keep if as simple. You can choose MSE for regresion or entropy loss for classification. Here we use MSE for Rosbenbrock function.

In [6]:
def get_fitness_loss(
    optimal_point: torch.Tensor = None,
    threshold_value: float = 0.5,
    is_optimal: bool = False,
    is_regression: bool = True,   # choose loss type
):
    """
    Returns a loss function compatible with PyTorch models.
    """

    def fitloss(model, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Compute loss for a batch of inputs x and targets y.
        """
        # Apply the model
        if is_optimal and optimal_point is not None:
            yh = model(x, optimal_point)  # model should handle optional extra argument
        else:
            yh = model(x)

        # Regression or classification
        if is_regression:
            loss = F.mse_loss(yh, y, reduction='mean')
        else:
            # Binary classification with threshold
            # Compute sigmoid BCE loss
            loss = F.binary_cross_entropy_with_logits(yh - threshold_value, y.float(), reduction='mean')

        return loss

    return fitloss


Configure the neural network

In [7]:
from robustnn.plnet_torch.bilipnet import BiLipNet
from robustnn.plnet_torch.plnet import PLNet

train_dir = f'{root_dir}/{name}-{depth}-tau{tau}'
mu = 0.1
nu = 10
block = BiLipNet(data_dim, layer_size, mu=0.1, nu=10, depth=depth, tau=tau, is_tau_fixed=True)
model = PLNet(block)

Train the model

In [None]:
# Define loss function
fitness_func_pl = get_fitness_loss(model)

train_with_flexible_loss( model, data, fitness_func_pl, fitness_eval_func=fitness_func_pl, name=name, 
                         train_dir=train_dir, lr_max=lr_max, epochs=epochs)

model: BiLipNet_torch, size: 2.05M
Epoch:   1 | loss: 0.203/0.159, tau: 0.000, Lip: 0.000/0.00
Epoch:   2 | loss: 0.101/0.100, tau: 0.000, Lip: 0.000/0.00
Epoch:   3 | loss: 0.064/0.083, tau: 0.000, Lip: 0.000/0.00
Epoch:   4 | loss: 0.058/0.088, tau: 0.000, Lip: 0.000/0.00
Epoch:   5 | loss: 0.078/0.114, tau: 0.000, Lip: 0.000/0.00
Epoch:   6 | loss: 0.106/0.159, tau: 0.000, Lip: 0.000/0.00
Epoch:   7 | loss: 0.154/0.165, tau: 0.000, Lip: 0.000/0.00
Epoch:   8 | loss: 0.166/0.195, tau: 0.000, Lip: 0.000/0.00
Epoch:   9 | loss: 0.192/0.165, tau: 0.000, Lip: 0.000/0.00
Epoch:  10 | loss: 0.152/0.205, tau: 0.000, Lip: 0.000/0.00
Epoch:  11 | loss: 0.144/0.138, tau: 0.000, Lip: 0.000/0.00
Epoch:  12 | loss: 0.129/0.134, tau: 0.000, Lip: 0.000/0.00
Epoch:  13 | loss: 0.127/0.136, tau: 0.000, Lip: 0.000/0.00
Epoch:  14 | loss: 0.137/0.131, tau: 0.000, Lip: 0.000/0.00
Epoch:  15 | loss: 0.123/0.131, tau: 0.000, Lip: 0.000/0.00
Epoch:  16 | loss: 0.107/0.104, tau: 0.000, Lip: 0.000/0.00
Epoch

  data['eval_loss'] = torch.tensor(eloss)


Restore the model

In [8]:
ckpt_dir = os.path.join(train_dir, 'ckpt')
ckpt_path = os.path.join(ckpt_dir, 'params.pt')

# Load the saved parameters
if os.path.exists(ckpt_path):
	model.load_state_dict(torch.load(ckpt_path, map_location=device))
	model.eval()  # set model to evaluation mode
	model.to(device)
	print("Model parameters restored successfully.")
else:
    raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")


Model parameters restored successfully.


# Evaluation - Inverse
Inverse the model and compare. We sample 10 points with 20 dimension as z. Run the inverse on it inv(z). Then, we campare the forward of the inverse with the original value z. Ideally, z = PLNet ( inv (z))

Define inverse parameters

In [9]:
max_iter = 800
alpha = 0.1
Lambda = 1

Sample data z

In [10]:
z = Sampler( 10, 20, device=device)

Call inverse

In [11]:
inv_z = (model.bln.inverse(z.numpy(force=True), alphas=[alpha]*depth, 
                                     inverse_activation_fns=[linen.relu]*depth, 
                                     Lambdas=[Lambda]*depth, iterations=[max_iter]*depth))
print(f'inverse z: {inv_z} with shape {inv_z.shape}')

inverse z: [[-0.7266202   0.41303885  0.1398462  -0.0871107   0.84646493  1.520392
   1.1599947   1.3756735   0.8531401   0.60397243  0.4329725   0.33819008
   1.4826813   0.52883506  1.0807513   1.6351691   2.1282651   2.0434139
   3.6237276   6.4443817 ]
 [-1.4208779   1.8727071   1.3566608   1.0139096   1.345627    0.4884618
   1.2530352   0.48576882  1.7153794   0.8822845   0.7155441   0.872139
   0.7797189   2.0450435   1.1782582   0.50292677  1.0747726   1.0052083
  -0.26609278  5.148472  ]
 [ 1.209498    1.8629102   1.485064    0.3392176   0.84985745  0.77795625
   0.08917625 -0.06314468  0.6732873  -1.411096    1.6455653   1.0673857
  -1.0541252   1.3409362   1.2574241   1.8078521   1.4726223   2.2568402
   3.0292487   5.7099743 ]
 [-1.0098727   0.7919287   0.5293421  -1.3831154   0.5609558   1.2761406
   2.048948    1.5843995   1.5466778   0.9357493   0.43859237 -0.35181242
  -0.24067342 -1.3541054   1.2931135   1.346929    1.0012894  -0.05674374
   0.9728743   6.466585  ]
 [-

Check if inverse is correct - Ideally, the output should match z with minimal error

In [13]:
model.eval()  # Set model to evaluation mode
model.to(device)
inv_z_torch = torch.tensor(np.array(inv_z), dtype=torch.float32).to(device)
print(f"device of inv_z_torch: {inv_z_torch.device}")
print(f"device of model: {next(model.parameters()).device}")
with torch.no_grad():
	output = model.bln(inv_z_torch)

diff = torch.norm(z.to(device) - output, dim=1)
print(f'Mean diff between PLNet(inv(z) and z): {diff.mean()} with shape {diff.shape}')

print(diff)

device of inv_z_torch: cuda:0
device of model: cuda:0
Mean diff between PLNet(inv(z) and z): 1.0113153621205129e-05 with shape torch.Size([10])
tensor([7.0828e-06, 1.7102e-05, 1.4842e-05, 6.5778e-06, 9.1296e-06, 8.1695e-06,
        1.0929e-05, 6.7303e-06, 9.0258e-06, 1.1543e-05], device='cuda:0')
