<a href="https://colab.research.google.com/github/KapilKhanal/thesis_differentiable/blob/master/IsoPerformancePipeline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install optimistix

Collecting optimistix
  Downloading optimistix-0.0.9-py3-none-any.whl.metadata (17 kB)
Collecting equinox>=0.11.7 (from optimistix)
  Downloading equinox-0.11.9-py3-none-any.whl.metadata (18 kB)
Collecting jaxtyping>=0.2.23 (from optimistix)
  Downloading jaxtyping-0.2.36-py3-none-any.whl.metadata (6.5 kB)
Collecting lineax>=0.0.6 (from optimistix)
  Downloading lineax-0.0.7-py3-none-any.whl.metadata (17 kB)
Downloading optimistix-0.0.9-py3-none-any.whl (83 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m83.6/83.6 kB[0m [31m3.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading equinox-0.11.9-py3-none-any.whl (179 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m179.3/179.3 kB[0m [31m7.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading jaxtyping-0.2.36-py3-none-any.whl (55 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m55.8/55.8 kB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading lineax-0.0.7-py3-none-any.whl (67 kB)
[2

In [52]:
import jax.numpy as jnp
import jax
import optimistix as optx

# Define the system of equations for coupled disciplines
def equations(vars, data):
    x1, z1, z2 = data
    y1, y2 = vars
    eq1 = y1 - z1**2 - z2 - x1 + y2
    eq2 = y2 - jnp.sqrt(y1) - z1 - z2
    return jnp.array([eq1, eq2])


def objectives(vars,data):
    x1, z1, z2 = data
    y1, y2 = vars
    obj1 = y1 * jnp.exp(y2)
    obj2 = y2 * jnp.exp(y1)
    obj3 = jnp.tanh(y1 * y2)
    return jnp.array([obj1, obj2, obj3])


# system objectives
def full_system(data):
    x1, z1, z2 = data
    initial_guess = jnp.array([0.1, 0.0])

    # Solve the coupled equations
    solver = optx.Newton(rtol=1e-5, atol=1e-5)
    solution = optx.root_find(equations, solver,initial_guess, args=data)
    y1, y2 = solution.value

    # Objectives
    obj1 = y1 * jnp.exp(y2)
    obj2 = y2 * jnp.exp(y1)
    obj3 = jnp.tanh(y1 * y2)
    return jnp.array([obj1, obj2, obj3])

# Example data / design variable which we can change
x1, z1, z2 = 1.0, 2.0, 3.0
data = (x1, z1, z2)

# Compute the Jacobian of the full objective with respect to data
full_system_jacobian = jax.jacobian(full_system)(data)

# Solve for y1, y2 - again outside to calculatee jacobian of equation
initial_guess = jnp.array([0.1, 0.0])
solver = optx.Newton(rtol=1e-5, atol=1e-5)
solution = optx.root_find(equations, solver, initial_guess, args=data)
y1, y2 = solution.value

# Define the regularization parameter - change and compare
mu = 1e3

# Compute the Jacobian of the equations with respect to vars - obtained via solution.value
vars = jnp.array([y1, y2])
jacobian_vars = jax.jacobian(objectives, argnums=0)(vars, data)

# Compute the regularized Jacobian
regularized_jacobian = mu * jacobian_vars


In [55]:
# Combine the Jacobians (regualirzed and full objective) into the isoJacobian
isoJacobian = jnp.hstack([
   jnp.vstack(full_system_jacobian),
  jnp.vstack( regularized_jacobian)
])

# Null space calculation via SVD
U, s, Vt = jnp.linalg.svd(isoJacobian)
null_space_basis = Vt[-2:]  # Assuming last two are in the null space

# Verify if null space basis is orthonormal
dot_product = jnp.dot(null_space_basis, null_space_basis.T)
tolerance = 1e-3
is_orthonormal = jnp.allclose(dot_product, jnp.eye(null_space_basis.shape[0]), atol=tolerance)

# Output results
print("Null space basis:", null_space_basis)
if is_orthonormal:
    print("The null space basis is orthonormal.")
else:
    print("The null space basis is not orthonormal.")


Null space basis: [[ 4.7087246e-03 -7.9947454e-01  6.0067827e-01  1.7297617e-03
  -9.9972822e-04]
 [-3.5377166e-03  6.0066599e-01  7.9949087e-01 -1.2996483e-03
   7.5114169e-04]]
The null space basis is orthonormal.
