In [None]:
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import functools
import nvgpu

In [None]:
# We want to JIT the methods of a class on a specific device. For convenience,
# we define a decorator that does this for us: it uses the `_device` attribute
# of the instance to determine which device to JIT on.
def jit_with_device(method):
    """Decorate a method so it is JIT-compiled for its instance's device.

    The target device is read from ``self._device`` each time the decorated
    method is called, so different instances of one class can run the same
    method on different devices.

    NOTE(review): ``self`` is handed to the jitted function as an ordinary
    argument, so JAX must be able to treat the instance as a pytree.

    Parameters
    ----------
    method : callable
        Plain function defined in a class body; its first positional
        argument is ``self``.

    Returns
    -------
    callable
        Wrapper carrying the name and docstring of ``method`` (via
        ``functools.wraps``) that forwards all arguments to the jitted method.
    """
    # One jitted callable per device, created on first use. Rebuilding the
    # ``jax.jit`` wrapper on every single call is redundant work.
    # Assumes device objects are hashable (they are used as dict keys).
    jitted_by_device = {}

    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        device = self._device
        try:
            wrapped = jitted_by_device[device]
        except KeyError:
            # Compile the method with jax.jit for the specific device
            wrapped = jitted_by_device[device] = jax.jit(method, device=device)
        return wrapped(self, *args, **kwargs)

    return wrapper

In [None]:
from jax.tree_util import register_pytree_node
import copy


# This will be registered as a proper pytree
class Optimizable:
    """Minimal stand-in for an optimizable object.

    Attributes
    ----------
    N : number of coefficients (stored as given).
    coefs : the coefficient values (stored by reference, not copied).
    """

    def __init__(self, N, coefs):
        # actual Optimizable class has more attributes, but we don't need them here
        self.N, self.coefs = N, coefs


# This will be registered as a proper pytree
class Objective:
    """Residual of a cosine expansion against target values on one grid.

    The model values are ``A @ coefs`` where column ``i`` of ``A`` is
    ``cos(i * grid)``; the error is the difference to ``target``. Both the
    error and its Jacobian are JIT-compiled for the GPU selected by
    ``device_id`` through the ``jit_with_device`` decorator.

    Parameters
    ----------
    opt : object exposing ``N`` (number of modes), e.g. an ``Optimizable``.
    grid : points at which the expansion is evaluated.
    target : values the expansion should match on ``grid``.
    device_id : index into ``jax.devices("gpu")``; defaults to 0.
    """

    def __init__(self, opt, grid, target, device_id=0):
        self.opt, self.grid, self.target = opt, grid, target
        self._device_id = device_id
        # Resolve the index to the actual device object once, up front.
        self._device = jax.devices("gpu")[device_id]

    def build(self):
        """Precompute ``self.A``; must run before using the default ``A``."""
        # the transform matrix A such that A @ coefs gives the
        # values of the function at the grid points
        modes = [jnp.cos(k * self.grid) for k in range(self.opt.N)]
        self.A = jnp.vstack(modes).T

    @jit_with_device
    def compute_error(self, coefs, A=None):
        """Return ``A @ coefs - target``; ``A`` falls back to ``self.A``."""
        basis = self.A if A is None else A
        return basis @ coefs - self.target

    @jit_with_device
    def jac_error(self, coefs, A=None):
        """Forward-mode Jacobian of ``compute_error`` with respect to ``coefs``."""
        basis = self.A if A is None else A
        return jax.jacfwd(self.compute_error)(coefs, basis)

In [None]:
def pconcat(arrays, device=None):
    """Concatenate arrays from multiple devices.

    Every array is first copied to one common device with
    ``jax.device_put``; the copies are then joined with ``jnp.concatenate``
    using its default axis.

    Parameters
    ----------
    arrays : iterable of array-like
        Pieces to join, possibly living on different devices.
    device : jax device, optional
        Device that gathers the pieces. When omitted, the first GPU
        (``jax.devices("gpu")[0]``) is used, as before; pass e.g.
        ``jax.devices("cpu")[0]`` to gather on the CPU instead.

    Returns
    -------
    The concatenated array on ``device``.
    """
    # we will use either CPU or GPU[0] for the matrix decompositions, so the
    # array of float64 should fit into single device
    if device is None:
        device = jax.devices("gpu")[0]
    gathered = [jax.device_put(x, device=device) for x in arrays]
    return jnp.concatenate(gathered)

In [None]:
# This will be registered as a proper pytree
class ObjectiveFunctionParallel:
    """Evaluate several objectives, each on its own device, and join the results.

    Parameters
    ----------
    objectives : sequence
        Objects exposing ``_device``, ``opt.coefs``, ``compute_error(coefs, A)``
        and ``jac_error(coefs, A)`` (e.g. ``Objective`` instances).
    """

    def __init__(self, objectives):
        self.objectives = objectives
        self.num_device = len(objectives)

    def build(self):
        """Placeholder for constructing constant arrays (currently a no-op)."""
        # construct the constant arrays
        ...

    def _per_objective_args(self, coefs, A):
        """Expand omitted arguments into one entry per objective.

        ``coefs=None`` becomes each objective's current ``opt.coefs`` (read at
        call time, so updates to the optimizable are picked up). ``A=None``
        becomes ``None`` for every objective, which the objective's own
        methods are then called with.
        """
        if coefs is None:
            coefs = [obj.opt.coefs for obj in self.objectives]
        if A is None:
            A = [None] * len(self.objectives)
        return coefs, A

    def compute_error(self, coefs=None, A=None):
        """Concatenated error of all objectives.

        Parameters
        ----------
        coefs : sequence, optional
            One coefficient array per objective. Defaults to each objective's
            ``opt.coefs``.
        A : sequence, optional
            One transform matrix (or ``None``) per objective. Defaults to
            ``None`` for every objective.
        """
        # Without this step the ``None`` defaults would reach ``zip`` and raise.
        coefs, A = self._per_objective_args(coefs, A)
        # compute the error for each objective and concatenate them
        fs = [
            obj.compute_error(jax.device_put(coefi, device=obj._device), Ai)
            for obj, coefi, Ai in zip(self.objectives, coefs, A)
        ]
        return pconcat(fs)

    def jac_error(self, coefs=None, A=None):
        """Concatenated Jacobian of all objectives; same arguments as ``compute_error``."""
        coefs, A = self._per_objective_args(coefs, A)
        # compute the jacobian for each objective and concatenate them
        fs = [
            obj.jac_error(jax.device_put(coefi, device=obj._device), Ai)
            for obj, coefi, Ai in zip(self.objectives, coefs, A)
        ]
        return pconcat(fs)

In [9]:
# Problem setup: N cosine modes, num_nodes grid points per half-domain.
N, num_nodes = 40, 15

# initial guess: only mode 2 is non-zero
coefs = np.zeros(N)
coefs[2] = 3
eq = Optimizable(N, coefs)

# [-pi, 0) and [0, pi) are handled by separate objectives; grid3/target3
# cover the full domain.
grid1 = jnp.linspace(-jnp.pi, 0, num_nodes, endpoint=False)
grid2 = jnp.linspace(0, jnp.pi, num_nodes, endpoint=False)
grid3 = jnp.concatenate([grid1, grid2])
target1, target2, target3 = grid1**2, grid2**2, grid3**2

obj1 = Objective(eq, grid1, target1, device_id=0)
obj2 = Objective(eq, grid2, target2, device_id=1)
obj1.build()
obj2.build()

# each objective is placed on its own device
obj1 = jax.device_put(obj1, jax.devices("gpu")[0])
obj2 = jax.device_put(obj2, jax.devices("gpu")[1])
# jax.device_put creates copies (at different memory locations), so without
# re-assigning `eq` there would be no connection between the objectives'
# `opt.coefs`. Both objectives are supposed to share the same optimizable,
# so point them back at the very same object.
obj1.opt = obj2.opt = eq

objp_fun = ObjectiveFunctionParallel([obj1, obj2])
objp_fun.build()

In [10]:
# Damped pseudo-inverse iteration on the optimizable's coefficients:
#     coefs <- coefs - 0.1 * pinv(J) @ error
# repeated until the error norm is at most 1e-3.
objective = objp_fun
# The residual is evaluated once per iterate and reused for the convergence
# check, the update and the final report (instead of being recomputed).
error = objective.compute_error()
print(jnp.linalg.norm(error))
step = 0
while jnp.linalg.norm(error) > 1e-3:
    eq.coefs = eq.coefs - 1e-1 * jnp.linalg.pinv(objective.jac_error()) @ error
    step += 1
    # refresh the residual for the updated coefficients
    error = objective.compute_error()

print(jnp.linalg.norm(error))

25.148645
0.00091715116
