Copyright 2022 The DP Matrix Factorization Authors.
Licensed under the Apache License, Version 2.0
In this notebook we implement three different methods for computing optimal factorizations for [the matrix mechanism](https://people.cs.umass.edu/~mcgregor/papers/15-vldbj.pdf) under approximate differential privacy:

1. Gradient descent on an associated convex problem
1. The fixed-point iteration method described in the paper associated with this code.
1. A Newton-direction-based algorithm designed in [previous literature](https://arxiv.org/pdf/1602.04302v1.pdf).

We experimentally compare their numerical efficiency on the problem of factorizing the prefix-sum matrix: the lower-triangular matrix of all 1s, which takes a vector to its vector of partial sums. Other matrices can be used, and will generally produce similar results.

One note on initialization below: it is difficult to initialize the two descent-based schemes and the fixed-point based algorithm identically. In order to ensure similar initialization, it is easiest to select a *vector* (corresponding to the parameterization of the fixed-point algorithm), and attempt to generate a matrix from this vector. Generating this matrix, however, essentially takes advantage of the representations which yield the fixed-point problem--and this generated matrix usually has significantly lower loss than a general positive definite matrix with constant 1s on the diagonal. This observation is not necessarily surprising; the dimensionalities involved in a vector parameterization are much lower. It is slightly unfair to the fixed-point method to 'allow' the gradient-based methods to use this initialization; but since the fixed-point method is the one we propose, we reserve the right to make it look slightly worse than it otherwise might.

In [None]:
from typing import Optional, Tuple

import time

import jax
from jax import numpy as jnp
from jax import value_and_grad, jit
from jax.config import config
import numpy as np

# With large matrices, the extra precision afforded by performing all
# computations in float64 is critical. The flag below turns on JAX's 64-bit
# mode for the rest of the notebook.
# NOTE(review): recent JAX releases deprecate importing `config` from
# `jax.config`; `jax.config.update(...)` is the suggested spelling -- confirm
# against the JAX version this notebook is pinned to.
config.update('jax_enable_x64', True)


@jit
def diagonalize_and_take_jax_matrix_sqrt(matrix: jnp.ndarray, min_eval: float = 0.0) -> jnp.ndarray:
  """Matrix square root for positive-semi-definite, Hermitian matrices.

  The root is computed from the eigendecomposition matrix = V diag(w) V^H as
  V diag(sqrt(max(w, min_eval))) V^H.

  Args:
    matrix: Hermitian (real symmetric or complex) matrix to take the root of.
    min_eval: Floor applied to every eigenvalue before the square root is
      taken. The default of 0 clips the small negative eigenvalues that
      round-off can produce for a PSD input.

  Returns:
    A matrix `sqrt` of the same shape as `matrix` with
    `sqrt @ sqrt ~= matrix` (exactly so only if no eigenvalue was clipped).
  """
  evals, evecs = jnp.linalg.eigh(matrix)
  eval_sqrt = jnp.maximum(evals, min_eval)**0.5
  # Scaling the columns of evecs is equivalent to evecs @ diag(eval_sqrt) but
  # avoids materializing a dense diagonal matrix. The reconstruction must use
  # the conjugate transpose: a plain transpose is only correct for real input.
  return (evecs * eval_sqrt) @ jnp.conjugate(evecs).T

def hermitian_adjoint(matrix: jnp.ndarray) -> jnp.ndarray:
  """Returns the conjugate transpose of `matrix`."""
  conjugated = jnp.conj(matrix)
  return jnp.transpose(conjugated)

def compute_loss_in_x(target: jnp.ndarray, x: jnp.ndarray):
  """Evaluates the factorization loss at the matrix `x`.

  The value is trace(target^H @ target @ inv(x)) multiplied by the largest
  diagonal entry of `x`.
  """
  gram = hermitian_adjoint(target) @ target
  x_inverse = jnp.linalg.inv(x)
  largest_diagonal_entry = jnp.max(jnp.diag(x))
  return jnp.trace(gram @ x_inverse) * largest_diagonal_entry

def compute_normalized_x_from_vector(matrix_to_factorize, v, precomputed_sqrt: Optional[jnp.ndarray] = None):
  """Maps a vector `v` to a matrix whose diagonal is rescaled to all 1s.

  This is the vector -> matrix mapping that relates fixed points to optima,
  followed by a symmetric diagonal rescaling. At a fixed point of phi
  (equivalently an optimum of the symmetrized factorization problem) the
  rescaling is a no-op; away from one it makes iterates of the fixed-point
  method comparable with iterates of the descent-based methods, which always
  carry a constant-1 diagonal.

  Args:
    matrix_to_factorize: Matrix being factorized; only read when
      `precomputed_sqrt` is None.
    v: Vector parameterizing the matrix.
    precomputed_sqrt: Optional square root of
      diag(v)^{1/2} @ M^H M @ diag(v)^{1/2}, with M = `matrix_to_factorize`.
      It is used as-is, without validation.

  Returns:
    The rescaled matrix.
  """
  sqrt_scaling = jnp.diag(v ** 0.5)
  inverse_sqrt_scaling = jnp.diag(v ** -(0.5))
  if precomputed_sqrt is not None:
    # The caller is trusted to have computed this correctly.
    root = precomputed_sqrt
  else:
    gram = hermitian_adjoint(matrix_to_factorize) @ matrix_to_factorize
    root = diagonalize_and_take_jax_matrix_sqrt(
        sqrt_scaling @ gram.astype(sqrt_scaling.dtype) @ sqrt_scaling)
  unnormalized = inverse_sqrt_scaling @ root @ inverse_sqrt_scaling
  # Divide row i and column i by sqrt(unnormalized[i, i]) so that the diagonal
  # becomes all 1s. The descent-based methods require this of their initial
  # iterate, and it holds at the optimum.
  rescale = jnp.diag(jnp.diag(unnormalized) ** (-0.5))
  return rescale @ unnormalized @ rescale


# Algorithm implementation: gradient descent on the convex problem

In [None]:
def optimize_factorization_grad_descent(target: jnp.ndarray, n_iters: int, initial_x: jnp.ndarray, lr: float = 1., use_armijo_rule: bool = True):
  """Uses JAX-implemented gradient descent to optimize DP-MatFac problem.

  Each iteration zeroes the diagonal of the gradient (so the diagonal of the
  iterate is never moved), takes a step along the negative gradient, and
  symmetrizes the result.

  Args:
    target: Matrix to factorize.
    n_iters: Number of gradient steps to take.
    initial_x: Starting iterate.
    lr: Step size. With `use_armijo_rule` this is the initial trial step of
      the backtracking line search.
    use_armijo_rule: If True, choose each step by backtracking until the
      candidate is positive definite and satisfies the Armijo
      sufficient-decrease condition. If False, always step by `lr`.

  Returns:
    A tuple (x, loss_array, time_array): the final iterate, the loss measured
    at the start of every iteration, and the wall-clock time at the end of
    every iteration, shifted so that the first entry is 0.

  Raises:
    ValueError: If the line search cannot find an acceptable step, which
      happens when the current iterate is itself not positive definite or the
      loss or gradient is not finite.
  """

  # Capture target in loss definition.
  compute_loss = lambda x: compute_loss_in_x(target=target, x=x)
  compiled_loss = jit(compute_loss)
  loss_and_grad = value_and_grad(compute_loss)

  def find_next_iterate(x_iter, grad, init_lr):
    """Backtracking line search along -grad."""
    # Both quantities are constant across backtracking steps.
    current_loss = compiled_loss(x_iter)
    grad_sq_norm = jnp.sum(grad ** 2)
    step = init_lr
    # Iterating (instead of recursing) cannot overflow the Python stack; the
    # loop ends at the latest when the step size underflows to zero.
    while step > 0:
      candidate = x_iter - grad * step
      # cholesky yields NaNs when the candidate is not positive definite.
      non_pd = jnp.any(jnp.isnan(jnp.linalg.cholesky(candidate)))
      if not non_pd:
        # Armijo condition for the direction d = -grad: since <grad, d> is
        # -||grad||^2, the loss must drop by at least 0.25 * step * ||grad||^2.
        sufficient_decrease_condition = current_loss - step * 0.25 * grad_sq_norm
        if compiled_loss(candidate) <= sufficient_decrease_condition:
          return candidate
      # We choose 0.1 as the Armijo factor; this is what the paper we're
      # looking to reproduce does as well.
      step = step * 0.1
    raise ValueError(
        'Line search failed: the current iterate is not positive definite or '
        'the loss/gradient is not finite.')

  x_iter = initial_x

  loss_array = []
  time_array = []

  start = time.time()
  for _ in range(n_iters):
    # Gradient step
    loss, grad = loss_and_grad(x_iter)
    # Zero the diagonal of the gradient so the diagonal of x is left untouched.
    grad = grad.at[jnp.diag_indices_from(grad)].set(0)
    loss_array.append(loss)
    if use_armijo_rule:
      x_iter = find_next_iterate(x_iter, grad, lr)
    else:
      x_iter = x_iter - lr * grad
    # Orthogonally project onto symmetric matrices.
    x_iter = (x_iter + x_iter.T) / 2

    time_array.append(time.time() - start)

  # Suppress any costs to the first iteration. With n_iters == 0 there is
  # nothing to shift.
  if time_array:
    initial_time = time_array[0]
    time_array = [t - initial_time for t in time_array]
  return x_iter, loss_array, time_array

# Smoke test of gradient descent: factorize the 128 x 128 prefix-sum matrix
# (lower-triangular, all 1s) for 100 iterations, starting from the identity.
# `s_matrix` is reused by the later cells of this notebook.
s_matrix = jnp.tril(jnp.ones(shape=(128, 128)))
opt, losses, time_in_loop = optimize_factorization_grad_descent(s_matrix, 100, jnp.eye(s_matrix.shape[0]), lr=1.)

print(f'Time in loop: {time_in_loop}')
print(opt)
print(losses)
# Smallest eigenvalue of the final iterate.
print(jnp.min(jnp.linalg.eigh(opt)[0]))

# Algorithm implementation: fixed-point iteration

In [None]:
def compute_phi_fixed_point(
    matrix: jnp.ndarray,
    initial_v: jnp.ndarray,
    rtol: float = 1e-5,
    max_iterations: Optional[int] = None,
) -> Tuple[jnp.ndarray, int, float, list, list]:
  """Runs the fixed-point iteration, recording loss and timing as it goes.

  With target = matrix^H @ matrix, every step replaces `v` by the diagonal of
  the matrix square root of diag(v)^{1/2} @ target @ diag(v)^{1/2}.

  Args:
    matrix: Matrix to factorize.
    initial_v: Starting vector.
    rtol: Iteration stops once ||new_v - v|| / ||v|| falls below this value.
    max_iterations: Optional cap on the number of iterations; None means
      iterate until the `rtol` criterion is met.

  Returns:
    A tuple (v, n_iters, rel_norm_diff, time_array, loss_array): the final
    vector, the number of iterations performed, the last relative change
    (infinity if no iteration ran), the elapsed wall-clock time at each loss
    evaluation, and the loss at each evaluation. Loss is evaluated once before
    every update and once more for the final vector.
  """

  target = hermitian_adjoint(matrix) @ matrix
  v = initial_v

  n_iters = 0
  # Bound up front so that the final return is valid even when the loop body
  # never executes (max_iterations == 0).
  rel_norm_diff = float('inf')

  time_array = []
  loss_array = []

  def continue_loop(iteration: int) -> bool:
    return max_iterations is None or iteration < max_iterations

  def _scaled_target_sqrt(vec):
    # Elementwise power of a diagonal matrix: off-diagonal zeros stay zero.
    sqrt_scale = jnp.diag(vec) ** 0.5
    return diagonalize_and_take_jax_matrix_sqrt(sqrt_scale @ target @ sqrt_scale)

  @jit
  def _compute_loss(v, matrix_sqrt):
    # This computation in the middle may slow down our fixed point method and could
    # be bypassed. We only have it here to track the loss as we iterate.
    normalized_x = compute_normalized_x_from_vector(matrix, v, matrix_sqrt)
    loss = compute_loss_in_x(matrix, normalized_x)
    return loss

  def _update_loss(v, matrix_sqrt):
    loss = _compute_loss(v, matrix_sqrt)
    # We rely on Python late binding to capture start here.
    time_array.append(time.time() - start)
    loss_array.append(loss)

  # We keep around the previously computed matrix square root
  # to save time in evaluating loss.
  matrix_sqrt = _scaled_target_sqrt(v)

  start = time.time()
  while continue_loop(n_iters):
    n_iters += 1
    # Compute loss first, for first iteration, to normalize the loss trajectories
    # between these loss arrays and the descent-based ones.
    _update_loss(v, matrix_sqrt)
    new_v = jnp.diag(matrix_sqrt)
    # Set up matrix_sqrt for the next iteration. We use this wonky update order to be
    # able to cache this square root computation for loss evaluation.
    matrix_sqrt = _scaled_target_sqrt(new_v)
    rel_norm_diff = jnp.linalg.norm(new_v - v) / jnp.linalg.norm(v)
    if rel_norm_diff < rtol:
      _update_loss(new_v, matrix_sqrt)
      return new_v, n_iters, rel_norm_diff, time_array, loss_array
    v = new_v

  _update_loss(v, matrix_sqrt)
  return v, n_iters, rel_norm_diff, time_array, loss_array

def optimize_factorization_fixed_point(s_matrix: jnp.ndarray, max_iterations: int, rtol: float, initial_v: Optional[jnp.array]=None):
  """Optimizes the factorization of `s_matrix` via the fixed-point iteration.

  Args:
    s_matrix: Matrix to factorize.
    max_iterations: Iteration cap forwarded to `compute_phi_fixed_point`.
    rtol: Relative tolerance forwarded to `compute_phi_fixed_point`.
    initial_v: Optional starting vector; defaults to a vector of 1s shaped
      like the diagonal of `s_matrix`.

  Returns:
    A tuple (x, losses, adj_time_array): the matrix built from the final
    vector, the recorded losses, and the recorded timestamps shifted so that
    the first one is 0.
  """
  if initial_v is None:
    initial_v = jnp.ones_like(jnp.diag(s_matrix))
  fixed_point_result = compute_phi_fixed_point(
      s_matrix, initial_v=initial_v, rtol=rtol, max_iterations=max_iterations)
  lagrange_multiplier, _, _, timing, losses = fixed_point_result

  sqrt_scaling = jnp.diag(lagrange_multiplier**0.5)
  inverse_sqrt_scaling = jnp.diag(lagrange_multiplier**-(0.5))

  gram = hermitian_adjoint(s_matrix) @ s_matrix
  scaled_root = diagonalize_and_take_jax_matrix_sqrt(
      sqrt_scaling @ gram.astype(sqrt_scaling.dtype) @ sqrt_scaling)
  x = inverse_sqrt_scaling @ scaled_root @ inverse_sqrt_scaling

  # Suppress any costs to the first iteration, often due to tracing, etc
  first_timestamp = timing[0]
  adj_time_array = [stamp - first_timestamp for stamp in timing]

  return x, losses, adj_time_array


# Smoke test of the fixed-point method on the prefix-sum matrix defined above:
# at most 2 iterations, with a loose relative tolerance of 1e-1.
opt, loss, time_to_compute = optimize_factorization_fixed_point(s_matrix, 2, rtol=1e-1)
print(f'Opt array: {opt}')
print(f'Time to compute: {time_to_compute}')
print(f'losses: {loss}')


# Newton-step-style algorithm from [existing literature](https://arxiv.org/pdf/1602.04302v1.pdf)

In [None]:
# Implementing Alg 1 from https://arxiv.org/pdf/1602.04302v1.pdf.

def compute_newton_direction(Z, grad, max_iter: int = 5):
  """Implements algorithm 2 from the referenced paper.

  Runs at most `max_iter` conjugate-gradient-style updates, keeping the
  diagonal of both the direction and the residual at zero throughout.

  Args:
    Z: Matrix entering the update on both sides of D; the caller in this file
      passes the inverse of the current iterate.
    grad: Gradient of the loss at the current iterate.
    max_iter: Maximum number of updates to run.

  Returns:
    The direction D, with zero diagonal. It is all zeros when the residual is
    already zero on entry (i.e. `grad` has no off-diagonal mass).
  """
  # Initialize according to line 4.
  D = jnp.zeros_like(grad)
  R = -grad + Z @ D @ grad + grad @ D @ Z
  # Set diag of D and R to zero; line 5
  diag_elements = jnp.diag_indices_from(D)
  D = D.at[diag_elements].set(0)
  R = R.at[diag_elements].set(0)
  # Initialize P and r_old as in line 6.
  P = R
  # Interestingly, this is the inner product used in the paper.
  r_old = jnp.sum(R * R)
  for _ in range(max_iter):
    if r_old == 0:
      # A zero residual makes alpha (and the P update) 0 / 0, after which
      # everything is NaN. Checking before the update, rather than after it,
      # also covers a residual that is zero from the very start.
      break
    # Set B and alpha as in line 8
    B = -grad + Z @ D @ grad + grad @ D @ Z
    alpha = r_old / jnp.sum(P * B)
    # Update D and R as in line 9
    D = D + alpha * P
    R = R - alpha * B
    # Set diags of D and R to 0, as in line 10
    D = D.at[diag_elements].set(0)
    R = R.at[diag_elements].set(0)
    # Set r_new and update P, as in line 11
    r_new = jnp.sum(R * R)
    P = R + r_new / r_old * P
    # Update r_old; line 12
    r_old = r_new
  return D

def optimize_factorization_newton_step(target: jnp.ndarray, n_iters: int, initial_x: jnp.ndarray, init_lr=1.):
  """Optimizes the DP-MatFac problem with Newton-direction steps.

  Each iteration computes an approximate Newton direction with
  `compute_newton_direction`, picks a step size along it by backtracking line
  search, and symmetrizes the result.

  Args:
    target: Matrix to factorize.
    n_iters: Number of steps to take.
    initial_x: Starting iterate.
    init_lr: Initial trial step size of the line search. Setting to 1
      reproduces the paper of interest; see section 4.2. We parameterize for
      the purposes of tuning.

  Returns:
    A tuple (x, loss_array, time_array): the final iterate, the loss measured
    at the start of every iteration, and the wall-clock time at the end of
    every iteration, shifted so that the first entry is 0.

  Raises:
    ValueError: If the line search cannot find an acceptable step, which
      happens when the current iterate is itself not positive definite or the
      loss or direction is not finite.
  """

  # Capture target in loss definition.
  compute_loss = lambda x: compute_loss_in_x(target=target, x=x)
  compiled_loss = jit(compute_loss)
  loss_and_grad = value_and_grad(compiled_loss)

  def find_next_iterate_armijo(x_iter, grad, newton_dir, init_lr):
    """Computes step size as in Sec 4.2 of referenced paper."""
    # Both quantities are constant across backtracking steps.
    current_loss = compiled_loss(x_iter)
    directional_derivative = jnp.sum(grad * newton_dir)
    step = init_lr
    # Iterating (instead of recursing) cannot overflow the Python stack; the
    # loop ends at the latest when the step size underflows to zero.
    while step > 0:
      candidate = x_iter + newton_dir * step
      # This is essentially the method for checking positive-definiteness
      # proposed by the paper: cholesky yields NaNs for a non-PD candidate.
      non_pd = jnp.any(jnp.isnan(jnp.linalg.cholesky(candidate)))
      if not non_pd:
        # Equation (16)
        target_decrease = current_loss + step * 0.25 * directional_derivative
        if compiled_loss(candidate) <= target_decrease:
          return candidate
      # We choose 0.1 as the Armijo factor; this is what the paper we're
      # looking to reproduce does as well.
      step = step * 0.1
    raise ValueError(
        'Line search failed: the current iterate is not positive definite or '
        'the loss/direction is not finite.')

  x_iter = initial_x

  loss_array = []
  time_array = []

  start = time.time()
  for _ in range(n_iters):
    loss, grad = loss_and_grad(x_iter)
    inv_x = jnp.linalg.inv(x_iter)
    newton_direction = compute_newton_direction(Z=inv_x, grad=grad)
    loss_array.append(loss)
    x_iter = find_next_iterate_armijo(x_iter, grad, newton_direction, init_lr)
    # Orthogonally project onto symmetric matrices.
    x_iter = (x_iter + x_iter.T) / 2
    time_array.append(time.time() - start)

  # Suppress any costs to the first iteration. With n_iters == 0 there is
  # nothing to shift.
  if time_array:
    initial_time = time_array[0]
    time_array = [t - initial_time for t in time_array]
  return x_iter, loss_array, time_array

# Smoke test of the Newton-direction method on the prefix-sum matrix defined
# above: 100 iterations starting from the identity.
opt, losses, time_in_loop = optimize_factorization_newton_step(s_matrix, 100, jnp.eye(s_matrix.shape[0]))
# Despite the label, this prints the loss recorded at every iteration.
print(f'Final loss: {losses}')
print(f'Time in loop: {time_in_loop}')
# Smallest eigenvalue of the final iterate.
print(f'Min eval: {jnp.min(jnp.linalg.eigh(opt)[0])}')

# Data and plot generation

In [None]:
# Iteration cap handed to the fixed-point method by
# matrix_factorization_speed_data below.
_MAX_ITERS = 1000

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

def matrix_factorization_speed_data(matrix, dim: int, max_inv_rtol: int = 20, max_iter: int = 1000):
  """Runs all three methods on `matrix` and collects loss-versus-time data.

  All methods start from the same randomly drawn vector: the fixed-point
  method uses it directly, while the two descent-based methods start from the
  normalized matrix generated from it.

  Args:
    matrix: Matrix to factorize.
    dim: Unused; kept so that existing callers continue to work.
    max_inv_rtol: The fixed-point method runs with rtol = 10**-max_inv_rtol.
    max_iter: Number of iterations for the gradient-descent and Newton-step
      methods. NOTE(review): the fixed-point method is capped by the
      module-level _MAX_ITERS instead -- confirm this is intentional.

  Returns:
    A DataFrame with columns 'Elapsed Time (s)', 'Loss' and 'Method', holding
    one row per recorded loss of each method.
  """

  # We fix a seed for reproducibility, though the results are generally uniform
  # in this seed.
  key = jax.random.PRNGKey(256)
  initial_v = jax.random.uniform(key=key, shape=jnp.diag(matrix).shape)

  def _as_floats(loss_values):
    # float() converts a scalar JAX array directly; the `.to_py()` method used
    # previously is no longer available on current JAX arrays.
    return [float(value) for value in loss_values]

  _, losses_fp, fp_times = optimize_factorization_fixed_point(matrix, _MAX_ITERS, rtol=10**-max_inv_rtol, initial_v=initial_v)
  fp_losses = _as_floats(losses_fp)

  initial_x = compute_normalized_x_from_vector(matrix, initial_v)

  # True gradient descent on the convex problem
  _, losses_gd, gd_times = optimize_factorization_grad_descent(matrix, max_iter, initial_x, lr=1., use_armijo_rule=True)
  gd_losses = _as_floats(losses_gd)

  # The Newton-direction-based method of https://arxiv.org/pdf/1602.04302v1.pdf
  _, losses_ns, ns_times = optimize_factorization_newton_step(matrix, max_iter, initial_x, init_lr=1.)
  ns_losses = _as_floats(losses_ns)

  df = pd.DataFrame({
      'Elapsed Time (s)': fp_times + gd_times + ns_times,
      'Loss': fp_losses + gd_losses + ns_losses,
      'Method': (['Fixed point'] * len(fp_times)
                 + ['Gradient descent'] * len(gd_times)
                 + ['Newton-based step'] * len(ns_times)),
  })
  return df


def prefix_sum_factorization_speed_data(dim: int, max_inv_rtol: int = 20, max_iter: int = 1000) -> pd.DataFrame:
  """Collects speed data for the `dim` x `dim` prefix-sum matrix.

  The prefix-sum matrix is the lower-triangular matrix of all 1s.
  """
  all_ones = jnp.ones(shape=(dim, dim))
  prefix_sum_matrix = jnp.tril(all_ones)
  return matrix_factorization_speed_data(
      prefix_sum_matrix, dim, max_inv_rtol, max_iter)



In [None]:
def generate_and_plot_data(dim: int, max_inv_rtol: int, max_iter: int):
  """Generates prefix-sum speed data and plots loss against elapsed time.

  Draws one line per method and returns the DataFrame backing the plot.
  """
  time_column = 'Elapsed Time (s)'
  loss_column = 'Loss'
  method_column = 'Method'

  df = prefix_sum_factorization_speed_data(dim=dim, max_inv_rtol=max_inv_rtol, max_iter=max_iter)
  palette = sns.color_palette('muted')
  plt.figure(figsize=(10, 7))
  sns.set_style('whitegrid')
  sns.set_context('paper')
  sns.lineplot(
      data=df,
      x=time_column,
      y=loss_column,
      hue=method_column,
      palette=[palette[idx] for idx in range(3)],
  )

  # Largest elapsed time recorded by any single method.
  per_method_latest = [
      np.max(df[df[method_column] == method][time_column])
      for method in df[method_column].unique()
  ]
  latest_time = max(per_method_latest)

  highest_loss = np.max(df[loss_column])
  lowest_loss = np.min(df[loss_column])
  loss_span = highest_loss - lowest_loss

  # Heuristic axis ranges for visibility: leave a 5% margin below the lowest
  # loss and start the time axis at 0.
  plt.ylim(lowest_loss - loss_span * 0.05, highest_loss)
  plt.xlim(0, latest_time)
  return df

# Full experiment: 2048 x 2048 prefix-sum matrix, fixed-point rtol of 1e-10,
# and 1000 iterations for the descent-based methods.
df = generate_and_plot_data(2048, 10, 1000)
