In [1]:
from typing import NamedTuple

import jax
import jax.numpy as jnp
import jaxopt
from jax import grad, random

from utils.dataset import load_mnist784
from utils.tm import get_2d_tm

In [2]:
class OracleParams(NamedTuple):
    """Sizes and fixed data of the problem, as consumed by ``init`` and ``oracle``."""

    d: int  # Problem size
    n: int  # Number of samples on the server
    T: int  # Total number of samples
    C: jnp.ndarray  # Matrix multiplied elementwise with x inside f (documented there as shape (d, d))
    q: jnp.ndarray  # Per-sample vectors; f_batched maps over its leading axis

In [3]:
# Argument handed to the project helper load_mnist784 (presumably selects which MNIST digit to load).
target_digit = 1
# Loaded samples converted to a JAX array; rows are later drawn from it with jax.random.choice.
data = jnp.array(load_mnist784(target_digit))

Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!


Metal device set to: Apple M1 Pro

systemMemory: 16.00 GB
maxCacheSize: 5.33 GB





In [5]:
# Total number of samples.
T = 10
# Number of samples held by the server. It must not exceed T: the previous
# `T // 3 * 4` evaluated to 12 for T = 10, which makes delta = 2 * (T - n) / T
# negative and therefore gamma = 1 / delta negative as well.
# NOTE(review): a 3/4 share is assumed to be the intent of the original expression - confirm.
n = 3 * T // 4
# Length of the vectors p, u, v (784 = 28 * 28, presumably matching get_2d_tm(28)).
d = 784
# Fixed seed so the sampled rows are reproducible.
key = random.PRNGKey(0)
# T distinct rows of `data`, drawn without replacement along axis 0.
q = jax.random.choice(key, data, shape=(T,), replace=False)
# Matrix returned by the project helper get_2d_tm; used as C in f.
C = jnp.array(get_2d_tm(28))

In [6]:
# Bundle the problem sizes and data; fields are spelled out so the order cannot be mixed up.
oracle_params = OracleParams(d=d, n=n, T=T, C=C, q=q)

In [7]:
class PAUSParams(NamedTuple):
    """Hyper-parameters read by ``update_paus``."""

    # Step size: multiplies the operator difference in the outer update and,
    # together with eta, the operator in every inner update.
    gamma: float
    # Constant entering the inner weight eta = 1 / (2 * gamma * L_F1).
    L_F1: float
    # Number of update_composite_MP iterations performed per update_paus call.
    composite_steps: int

In [8]:
# Twice the fraction of samples that are not among the first n; gamma is later set to 1 / delta.
# NOTE(review): delta is positive only when n < T and is zero when n == T (1 / delta then fails).
delta = 2 * (T - n) / T
# Constant used in eta = 1 / (2 * gamma * L_F1) inside update_paus.
L_F1 = 2
# Number of inner update_composite_MP iterations per update_paus call.
composite_steps = 100

In [9]:
# Step size is the reciprocal of delta; the remaining fields are passed through unchanged.
paus_params = PAUSParams(gamma=1 / delta, L_F1=L_F1, composite_steps=composite_steps)

In [45]:
from __future__ import annotations

In [52]:
class Point(NamedTuple):
    """
    Four-component iterate (x, p, u, v) with component-wise ``+`` and ``-``.

    Note that ``+`` is redefined here: it adds matching fields instead of
    concatenating like a plain tuple would.
    """

    x: jnp.ndarray
    p: jnp.ndarray
    u: jnp.ndarray
    v: jnp.ndarray

    def __add__(self, z: Point) -> Point:
        """Return a new Point whose fields are the field-wise sums of ``self`` and ``z``."""
        return Point(*(getattr(self, name) + getattr(z, name) for name in self._fields))

    def __sub__(self, z: Point) -> Point:
        """Return a new Point whose fields are the field-wise differences ``self - z``."""
        return Point(*(getattr(self, name) - getattr(z, name) for name in self._fields))

In [11]:
def init(oracle_params: OracleParams) -> Point:
    """
    Build the starting point of the iteration.

    The arrays are allocated directly in their batched shape instead of
    stacking ``T`` Python-level copies, which also yields correctly shaped
    (empty) arrays when ``T`` is 0.

    :param OracleParams oracle_params: Only ``d`` and ``T`` are read.
    :return Point: ``x`` of shape (T, d, d) with every entry 1 / d**2,
        ``p`` of shape (d,) with every entry 1 / d, and ``u``, ``v`` of
        shape (T, d) filled with zeros.
    """
    d, T = oracle_params.d, oracle_params.T
    x = jnp.ones((T, d, d)) / (d * d)
    p = jnp.ones(d) / d
    u = jnp.zeros((T, d))
    v = jnp.zeros((T, d))

    return Point(x, p, u, v)

In [12]:
def f(x: jnp.ndarray, p: jnp.ndarray, u: jnp.ndarray, v: jnp.ndarray, C: jnp.ndarray, q: jnp.ndarray) -> jnp.ndarray:
    """
    Evaluate the scalar objective for a single sample.

    The value is ``sum(C * x)`` plus ``2 * max|C|`` times the sum of two
    inner products: ``u`` with (row sums of ``x`` minus ``p``) and ``v``
    with (column sums of ``x`` minus ``q``).

    :param jnp.ndarray x: Shape (d, d).
    :param jnp.ndarray p: Shape (d, ).
    :param jnp.ndarray u: Shape (d, ).
    :param jnp.ndarray v: Shape (d, ).
    :param jnp.ndarray C: Shape (d, d).
    :param jnp.ndarray q: Shape (d, ).
    :return jnp.ndarray: Scalar value.
    """
    penalty_scale = 2 * jnp.max(jnp.abs(C))
    linear_term = jnp.sum(C * x)
    # Mismatch between the sums along the last / second-to-last axis of x and p / q.
    row_gap = jnp.sum(x, axis=-1) - p
    col_gap = jnp.sum(x, axis=-2) - q
    return linear_term + penalty_scale * (jnp.dot(u, row_gap) + jnp.dot(v, col_gap))

In [13]:
def f_batched(x: jnp.ndarray, p: jnp.ndarray, u: jnp.ndarray, v: jnp.ndarray, C: jnp.ndarray, q: jnp.ndarray) -> jnp.ndarray:
    """
    Sum of ``f`` over a batch of samples.

    ``x``, ``u``, ``v`` and ``q`` are mapped over their leading axis, while
    ``p`` and ``C`` are shared by every sample.

    :return jnp.ndarray: Scalar, the sum of the per-sample values.
    """
    per_sample_values = jax.vmap(f, in_axes=(0, None, 0, 0, None, 0))(x, p, u, v, C, q)
    return jnp.sum(per_sample_values)

In [14]:
# Jitted gradient of f_batched with respect to its first four arguments (x, p, u, v);
# C and q are not differentiated. The result is a 4-tuple in that argument order.
gradient = jax.jit(grad(f_batched, argnums=(0, 1, 2, 3)))

In [25]:
# Sanity check of the jitted gradient, evaluated at the initial point.
# The point is built here because z_k is only assigned by the run cell further
# down, so reading it made this cell fail with NameError on a fresh kernel.
z_0 = init(oracle_params)
grad_x, grad_p, grad_u, grad_v = gradient(z_0.x, z_0.p, z_0.u, z_0.v, oracle_params.C, oracle_params.q)

In [29]:
@jax.jit
def _per_sample_gradient(x, p, u, v, C, q):
    """Gradients of ``f`` w.r.t. (x, p, u, v) for every sample; each output has a leading sample axis."""
    return jax.vmap(jax.grad(f, argnums=(0, 1, 2, 3)), in_axes=(0, None, 0, 0, None, 0))(x, p, u, v, C, q)


def oracle(z: Point, oracle_params: OracleParams) -> tuple[Point, Point]:
    """
    Evaluate the averaged operator on all samples and on the server samples.

    Per-sample gradients are used so that every component, including the one
    for the shared variable ``p``, carries a leading sample axis. The gradient
    of ``f_batched`` w.r.t. ``p`` has shape (d,) because ``p`` is shared by
    all samples; summing it over axis 0 collapsed it to a scalar and slicing
    it with ``[:n]`` selected coordinates rather than samples.

    :param Point z: Point at which the operator is evaluated.
    :param OracleParams oracle_params: Supplies ``C``, ``q``, ``T`` and ``n``.
    :return tuple[Point, Point]: First the average over all samples (sum
        divided by ``T``), then the average over the first ``n`` samples (sum
        divided by ``n``). The ``u`` and ``v`` components are negated
        gradients; ``x`` and ``p`` are plain gradients.
    """
    T, n = oracle_params.T, oracle_params.n
    grad_x, grad_p, grad_u, grad_v = _per_sample_gradient(z.x, z.p, z.u, z.v, oracle_params.C, oracle_params.q)

    def _average(gx, gp, gu, gv, count) -> Point:
        # Sum over the sample axis and divide by the requested sample count.
        return Point(
            jnp.sum(gx, axis=0) / count,
            jnp.sum(gp, axis=0) / count,
            jnp.sum(-gu, axis=0) / count,
            jnp.sum(-gv, axis=0) / count,
        )

    return (
        _average(grad_x, grad_p, grad_u, grad_v, T),
        _average(grad_x[:n], grad_p[:n], grad_u[:n], grad_v[:n], n),
    )

In [56]:
def projection_simplex(x: jnp.ndarray) -> jnp.ndarray:
    """
    Project ``x``, viewed as one flat vector, with jaxopt's simplex projection.

    :param jnp.ndarray x: Array of any shape.
    :return jnp.ndarray: Projected values, reshaped back to ``x.shape``.
    """
    flat = jnp.ravel(x)
    projected = jaxopt.projection.projection_simplex(flat)
    return jnp.reshape(projected, x.shape)

In [57]:
# Batched variant: applies projection_simplex independently to every slice along the leading axis.
projection_simplex_vmap = jax.vmap(projection_simplex, in_axes=(0,))

In [38]:
def _composite_step(z_k: Point, v_t: Point, G: Point, gamma: float, eta: float) -> Point:
    """Apply one projected update of ``v_t`` anchored at ``z_k`` using the operator value ``G``."""
    step = gamma * eta
    power = 1 / (eta + 1)
    return Point(
        # x and p: multiplicative update followed by a simplex projection.
        projection_simplex_vmap(z_k.x * jnp.power(jnp.exp(-step * G.x) * v_t.x / z_k.x, power)),
        jaxopt.projection.projection_simplex(z_k.p * jnp.power(jnp.exp(-step * G.p) * v_t.p / z_k.p, power)),
        # u and v: additive update followed by a projection onto the l-infinity ball.
        # NOTE(review): the 0.5 weight matches the 1 / (eta + 1) exponent used above
        # only when eta == 1 - confirm this is intended.
        jaxopt.projection.projection_linf_ball(0.5 * (z_k.u + v_t.u - step * G.u)),
        jaxopt.projection.projection_linf_ball(0.5 * (z_k.v + v_t.v - step * G.v)),
    )


def update_composite_MP(
    z_k: Point, v_t: Point, G_z_k: Point, oracle_params: OracleParams, gamma: float, eta: float
) -> Point:
    """
    Perform one two-stage inner update of ``v_t`` (presumably a composite Mirror-Prox step).

    Stage one evaluates the server operator at ``v_t`` and produces an
    intermediate point; stage two re-evaluates the server operator at that
    intermediate point and updates ``v_t`` with it.

    Two defects of the previous version are fixed here:

    * ``oracle`` returns a ``(full, server)`` pair. Adding that pair to
      ``G_z_k`` concatenated tuples instead of adding Points, so the
      following ``.x`` access raised ``AttributeError``. The pair is now
      unpacked and only the server part is used, as in stage one.
    * Stage two used the stage-one operator for the ``u`` and ``v``
      components; all four components now use the stage-two operator.

    :param Point z_k: Anchor point of the outer iteration.
    :param Point v_t: Current inner iterate.
    :param Point G_z_k: Operator difference (full minus server) at ``z_k``.
    :param OracleParams oracle_params: Forwarded to ``oracle``.
    :param float gamma: Step size.
    :param float eta: Weight of the inner step.
    :return Point: The next inner iterate.
    """
    _, F1_v_t = oracle(v_t, oracle_params)
    G_t = F1_v_t + G_z_k
    v_t_next = _composite_step(z_k, v_t, G_t, gamma, eta)

    _, F1_v_t_next = oracle(v_t_next, oracle_params)
    G_t_next = F1_v_t_next + G_z_k
    return _composite_step(z_k, v_t, G_t_next, gamma, eta)

In [49]:
def update_paus(z_k: Point, u_k: Point, oracle_params: OracleParams, paus_params: PAUSParams) -> tuple[Point, Point]:
    """
    Run one outer iteration: an inner loop on ``u_k`` followed by a correction step.

    :param Point z_k: Current outer iterate.
    :param Point u_k: Inner iterate carried over from the previous call.
    :param OracleParams oracle_params: Forwarded to ``oracle`` and ``update_composite_MP``.
    :param PAUSParams paus_params: Supplies ``gamma``, ``L_F1`` and ``composite_steps``.
    :return tuple[Point, Point]: The new outer iterate and the inner iterate it was built from.
    """
    step = paus_params.gamma

    # Difference between the full and the server operator at the anchor point.
    full_at_z, server_at_z = oracle(z_k, oracle_params)
    G_z_k = full_at_z - server_at_z
    eta = 1 / (2 * step * paus_params.L_F1)

    # Inner loop: refine the inner iterate with z_k and G_z_k held fixed.
    inner = u_k
    for _ in range(paus_params.composite_steps):
        inner = update_composite_MP(z_k, inner, G_z_k, oracle_params, step, eta)

    # Correction: move from the inner iterate along the change of the operator difference.
    full_at_u, server_at_u = oracle(inner, oracle_params)
    shift = (full_at_u - server_at_u) - G_z_k

    outer = Point(
        projection_simplex_vmap(inner.x * jnp.exp(-step * shift.x)),
        jaxopt.projection.projection_simplex(inner.p * jnp.exp(-step * shift.p)),
        jaxopt.projection.projection_linf_ball(inner.u - step * shift.u),
        jaxopt.projection.projection_linf_ball(inner.v - step * shift.v),
    )
    return outer, inner

In [40]:
# Jitted wrapper around update_paus; the run loop in the last cell calls the un-jitted update_paus instead.
# NOTE(review): update_paus passes paus_params.composite_steps to range() and oracle slices with
# oracle_params.n; as ordinary jit arguments these fields would be traced values - confirm this
# wrapper actually compiles before relying on it.
update_paus_jit = jax.jit(update_paus)

In [61]:
# Start both sequences from the same initial point and run 100 outer iterations.
# NOTE(review): the recorded output of this cell is an XlaRuntimeError ("failed to legalize
# operation 'mhlo.pad'") raised from jnp.cumsum inside jaxopt's projection code while running on
# the experimental METAL platform - presumably a backend limitation; confirm by running on CPU.
z_k, u_k = init(oracle_params), init(oracle_params)
for _ in range(100):
    z_k, u_k = update_paus(z_k, u_k, oracle_params, paus_params)

XlaRuntimeError: UNKNOWN: /Users/michael/Documents/PAUS/env/lib/python3.11/site-packages/jaxopt/_src/projection.py:101:13: error: failed to legalize operation 'mhlo.pad'
  cumsum_u = jnp.cumsum(u)
            ^
/Users/michael/Documents/PAUS/env/lib/python3.11/site-packages/jaxopt/_src/projection.py:101:13: note: called from
  cumsum_u = jnp.cumsum(u)
            ^
/Users/michael/Documents/PAUS/env/lib/python3.11/site-packages/jaxopt/_src/projection.py:101:13: note: see current operation: %235 = "mhlo.pad"(%234, %1) {edge_padding_high = dense<[0, 1]> : tensor<2xi64>, edge_padding_low = dense<0> : tensor<2xi64>, interior_padding = dense<[0, 1]> : tensor<2xi64>} : (tensor<10x1xf32>, tensor<f32>) -> tensor<10x2xf32>
  cumsum_u = jnp.cumsum(u)
            ^
