In [1]:
import numpy as np
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
import equinox as eqx

import optax

import timeit

In [None]:
sigma2 = 0.5**2

L = 3.0


a = 0.75
b = 0.75

v_max = 1.5
p_max = 1 # rho_max cannot be learned Susana's paper

In [None]:
# Define a Multilayer Perceptron

class MLP(eqx.Module):
    layers: list  # Type annotation indicating that layers is a list. Needed to work properly.


    def __init__(self, key, input_dim, hidden_dims, n_layers, output_dim):
        """Define the MLP

        Args:
            key (jax key): Random Jey
            input_dim (int): Input layer dimension
            hidden_dims (int): Hidden layers dimension
            n_layers (_type_): Number of hidden layers
            output_dim (_type_): Output layer dimension
        """


        
        # jax handle random number generation in a different way. Favoring reproducibility.
        # This gives each layer a rnd initialization.
        keys = jax.random.split(key, n_layers + 1)
        
        dims = [input_dim] + [hidden_dims] * n_layers + [output_dim]
        self.layers = [eqx.nn.Linear(dims[i], dims[i + 1], key=keys[i]) for i in range(len(dims) - 1)]

    
    def __call__(self, x):
        """Forward-pass the input through the network

        Args:
            x (jnp.array): Input

        Returns:
            jnp.array: NN output
        """
        for layer in self.layers[:-1]:
            x = jax.nn.mish(layer(x))
        
        x = self.layers[-1](x)
        return x

In [None]:
def loss(model, sigma2, L, a, b, steps=50):

    x_flat = jnp.linspace(0, L, steps)
    inputs = jnp.concat([x_flat, jnp.array([0.0]), jnp.array([L])])
    

    rho, grad = jax.vmap(jax.value_and_grad(model))(inputs)
    hess = jax.vmap(jax.hessian(model))(inputs)

    dp_dx = grad
    d2p_dx2 = hess

    # PDE Loss
    pde_loss = sigma2*d2p_dx2 - dp_dx*(v_max*(1-2*rho))

    # in loss
    in_loss = sigma2*dp_dx[-2] - rho[-2]*v_max*(1-rho[-2]) + a*(1-rho[-2])

    # out loss
    out_loss = -sigma2*dp_dx[-1] + rho[-1]*v_max*(1-rho[-1]) - b*rho[-1]

    return jnp.average(pde_loss**2) + in_loss**2 + out_loss**2


In [133]:
# Initialize the model with multiple hidden layers and scalar output layer
model_key = jax.random.PRNGKey(1)

input_dim = 'scalar'
output_dim = 'scalar'
n_layers = 3
n_nodes = 3

model = MLP(model_key, input_dim=input_dim, hidden_dims=n_nodes, n_layers=n_layers, output_dim=output_dim)

In [130]:
# PSO parameters
n_particles = 30
n_iterations = 1000
w = 0.5  # inertia weight
c1 = 2.0  # cognitive parameter
c2 = 2.0  # social parameter

# Initialize particles
key = jax.random.PRNGKey(0)
particle_keys = jax.random.split(key, n_particles)
particles = [jax.random.normal(k, (sum(x.size for x in jax.tree_util.tree_leaves(model)),)) for k in particle_keys]
velocities = [jax.random.normal(k, (sum(x.size for x in jax.tree_util.tree_leaves(model)),)) for k in particle_keys]
personal_best_positions = particles.copy()
personal_best_scores = [float('inf')] * n_particles
global_best_position = None
global_best_score = float('inf')

def get_model_params(model):
    leaves = jax.tree_util.tree_leaves(model)
    return jnp.concatenate([x.ravel() for x in leaves])

def set_model_params(model, params):
    leaves = jax.tree_util.tree_leaves(model)
    indices = np.cumsum([0] + [x.size for x in leaves])
    new_leaves = [params[indices[i]:indices[i+1]].reshape(x.shape) for i, x in enumerate(leaves)]
    return jax.tree_util.tree_unflatten(jax.tree_util.tree_structure(model), new_leaves)


In [131]:
# Training loop
for iteration in range(n_iterations):
    for i, particle in enumerate(particles):
        # Update model parameters
        model = set_model_params(model, particle)
        # Calculate loss (fitness)
        fitness = loss(model, sigma2, L, a, b)
        if fitness < personal_best_scores[i]:
            personal_best_scores[i] = fitness
            personal_best_positions[i] = particle
        if fitness < global_best_score:
            global_best_score = fitness
            global_best_position = particle

    # Update velocities and positions
    for i, particle in enumerate(particles):
        r1 = jax.random.uniform(model_key)
        r2 = jax.random.uniform(model_key)
        cognitive_velocity = c1 * r1 * (personal_best_positions[i] - particle)
        social_velocity = c2 * r2 * (global_best_position - particle)
        velocities[i] = w * velocities[i] + cognitive_velocity + social_velocity
        particles[i] = particle + velocities[i]

    # Print loss every 100 iterations
    if iteration % 100 == 0:
        print(f"Iteration {iteration}, Loss: {global_best_score}")

# Set the model to the best found parameters
model = set_model_params(model, global_best_position)

Iteration 0, Loss: 0.05563843995332718


KeyboardInterrupt: 

In [141]:
# PSO parameters
n_particles = 30
n_iterations = 1000
w = 0.5  # inertia weight
c1 = 2.0  # cognitive parameter
c2 = 2.0  # social parameter

# Initialize particles
key = jax.random.PRNGKey(0)
particle_keys = jax.random.split(key, n_particles)
particles = [jax.random.normal(k, (sum(x.size for x in jax.tree_util.tree_leaves(model)),)) for k in particle_keys]
velocities = [jax.random.normal(k, (sum(x.size for x in jax.tree_util.tree_leaves(model)),)) for k in particle_keys]
personal_best_positions = particles.copy()
personal_best_scores = [float('inf')] * n_particles
global_best_position = None
global_best_score = float('inf')

def get_model_params(model):
    leaves = jax.tree_util.tree_leaves(model)
    return jnp.concatenate([x.ravel() for x in leaves])

def set_model_params(model, params):
    leaves = jax.tree_util.tree_leaves(model)
    indices = np.cumsum([0] + [x.size for x in leaves])
    new_leaves = [jnp.array(params)[indices[i]:indices[i+1]].reshape(x.shape) for i, x in enumerate(leaves)]
    return jax.tree_util.tree_unflatten(jax.tree_util.tree_structure(model), new_leaves)

@jax.jit
def compute_fitness(params, sigma2, L, a, b, model):
    model = set_model_params(model, params)
    return loss(model, sigma2, L, a, b)

@jax.jit
def update_particles(particles, velocities, personal_best_positions, global_best_position, r1, r2, w, c1, c2):
    cognitive_velocity = c1 * r1 * (personal_best_positions - particles)
    social_velocity = c2 * r2 * (global_best_position - particles)
    new_velocities = w * velocities + cognitive_velocity + social_velocity
    new_particles = particles + new_velocities
    return new_particles, new_velocities

# Vectorize the fitness computation across all particles
compute_fitness_vectorized = jax.vmap(compute_fitness, in_axes=(0, None, None, None, None, None))


In [142]:
# Training loop
for iteration in range(n_iterations):
    fitness_values = compute_fitness_vectorized(particles, sigma2, L, a, b, model)
    
    # Update personal bests and global best
    for i in range(n_particles):
        if fitness_values[i] < personal_best_scores[i]:
            personal_best_scores[i] = fitness_values[i]
            personal_best_positions[i] = particles[i]
        if fitness_values[i] < global_best_score:
            global_best_score = fitness_values[i]
            global_best_position = particles[i]
    
    # Generate random coefficients for velocity update
    r1 = jax.random.uniform(model_key, shape=(n_particles, particles[0].shape[0]))
    r2 = jax.random.uniform(model_key, shape=(n_particles, particles[0].shape[0]))
    
    # Update particles and velocities
    particles, velocities = update_particles(
        jnp.array(particles), jnp.array(velocities),
        jnp.array(personal_best_positions), jnp.array(global_best_position),
        r1, r2, w, c1, c2
    )

    # Print loss every 100 iterations
    if iteration % 100 == 0:
        print(f"Iteration {iteration}, Loss: {global_best_score}")

# Set the model to the best found parameters
model = set_model_params(model, global_best_position)

TypeError: cannot reshape array of shape (0,) (size 0) into shape (1, 3) (size 3)

In [46]:
@eqx.filter_jit # decorator that applies Just-in-time compilation just to the relevant parts. Improves performance.
def train_step(model, opt_state, L, sigma2, a, b, steps=20):
    loss_value, grads = jax.value_and_grad(loss)(model, sigma2, L, a, b) # Compute the loss and gradient
    updates, opt_state = optimizer.update(grads, opt_state) # Apply the backward propagation
    model = eqx.apply_updates(model, updates) # Update the NN
    return model, opt_state, loss_value

In [47]:
# Optimizer definition
lr = 10e-4
optimizer = optax.adam(learning_rate=lr)
opt_state = optimizer.init(model)

In [48]:
# Training loop
num_epochs = 10000
losses = []

sigma2 = 0.5**2


for epoch in range(num_epochs):
    
    model, opt_state, loss_value = train_step(model, opt_state, L, sigma2, a, b, steps=200)

    if epoch % 100 == 0:
        losses.append(loss_value)
        print(f"Epoch {epoch}, Loss: {loss_value}")

        
    if epoch % 1000 == 0:
        lr *= 0.4
        optimizer = optax.adam(learning_rate=lr)
        opt_state = optimizer.init(model)  # Re-initialize optimizer state with new learning rate



# Print final loss
final_loss = loss(model, L, sigma2, a, b, steps=20)
print(f"Final Loss: {final_loss}")

Epoch 0, Loss: 3.3649051189422607
Epoch 100, Loss: 0.6583548188209534
Epoch 200, Loss: 0.20324784517288208
Epoch 300, Loss: 0.12363242357969284
Epoch 400, Loss: 0.08543533831834793
Epoch 500, Loss: 0.04002106562256813
Epoch 600, Loss: 0.031140180304646492
Epoch 700, Loss: 0.030779216438531876
Epoch 800, Loss: 0.03067518025636673
Epoch 900, Loss: 0.030574560165405273
Epoch 1000, Loss: 0.0304742269217968
Epoch 1100, Loss: 0.029191453009843826
Epoch 1200, Loss: 0.02875533141195774
Epoch 1300, Loss: 0.028714673593640327
Epoch 1400, Loss: 0.028697365894913673
Epoch 1500, Loss: 0.028681080788373947
Epoch 1600, Loss: 0.028665846213698387
Epoch 1700, Loss: 0.028651002794504166
Epoch 1800, Loss: 0.02863633818924427
Epoch 1900, Loss: 0.028621751815080643
Epoch 2000, Loss: 0.02860664762556553
Epoch 2100, Loss: 0.028589705005288124
Epoch 2200, Loss: 0.028566980734467506
Epoch 2300, Loss: 0.02853696420788765
Epoch 2400, Loss: 0.028489001095294952
Epoch 2500, Loss: 0.02837412618100643
Epoch 2600, Lo