# ISE 633 Homework 5

Author: Yue Wu <wu.yue@usc.edu>
$\def\vf#1{\boldsymbol{#1}}$

This notebook is the coding part of Homework 5 for ISE 633. Code and explanations are interleaved throughout. It is written in Python 3.12+ and is best viewed in a Jupyter environment.

### Dependencies

It requires the following packages:

```plain
jax==0.4.25
jaxtyping==0.2.25
beartype==0.17.2
equinox==0.11.3
```

and reasonably recent versions of `ipytest`, `seaborn`, `matplotlib`, `pandas` and `tqdm`. Since the computation is done in JAX, whose RNGs are fully deterministic given a key, the results should be completely reproducible. The local dependencies (i.e. `utils.*`) are available in the same directory as this notebook in my GitHub [repository](https://github.com/EtaoinWu/ise633).

A significant portion of the code in this notebook is identical to that in my previous homework assignments.

In [None]:
from functools import partial

import beartype
import equinox as eqx
import ipytest
import jax
import jax.scipy.optimize
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from beartype.typing import Any, Callable, cast
from jax import numpy as jnp, random as jr, tree_util as jtu
from jaxtyping import Array, Bool, Float, Integer, Key, Scalar, jaxtyped
from matplotlib.axes import Axes
import tqdm.notebook as tqdm

import utils.platform
from utils.tree import tree_nfold_cross_validation_tests, tree_nfold_cross_validation_trains

ipytest.autoconfig()
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_threefry_partitionable", True)

utils.platform.init_matplotlib("svg")
sns.set_theme("notebook", style="whitegrid")

SingleKey = Key[Scalar, ""]
typechecked = jaxtyped(typechecker=beartype.beartype)
FloatLike = float | Float[Scalar, ""]

We first define our model.

In [None]:
class PartitionProblem(eqx.Module):
    n: int = eqx.field(static=True)  # number of vertices
    l: Float[Scalar, ""]  # lambda in A_tilde = lambda * 1 1^T - A
    a: Float[Array, "n n"]  # symmetric input matrix A

    def __check_init__(self):
        assert self.a.shape == (self.n, self.n)
        assert jnp.all(jnp.equal(self.a, self.a.T))

    @property
    def a_tilde(self) -> Float[Array, "n n"]:
        # A_tilde = lambda * 1 1^T - A
        return jnp.ones_like(self.a) * self.l - self.a

@typechecked
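def load_partition(fn: str) -> PartitionProblem:
    # Read the comma-separated matrix from the given file (not a hardcoded path).
    a_np = np.loadtxt(fn, dtype=float, delimiter=",")
    a = jnp.array(a_np)
    # Fix lambda = 5, the value used in this homework.
    return PartitionProblem(n=a.shape[0], l=jnp.array(5.0), a=a)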
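As a sanity check on `a_tilde`: for any labelling $x\in\{\pm1\}^n$, the rank-one matrix $X = xx^\top$ satisfies $X\succeq 0$ and $X_{ii}=1$, and $\langle X, \lambda\vf 1\vf 1^\top - A\rangle = \lambda(\vf 1^\top x)^2 - x^\top A x$. The NumPy sketch below checks this identity on a small random symmetric 0/1 matrix; the matrix and the helper `a_tilde` are illustrative stand-ins, not the homework data or the notebook's own code.

```python
import numpy as np


def a_tilde(a: np.ndarray, lam: float) -> np.ndarray:
    # Same construction as PartitionProblem.a_tilde: lambda * 1 1^T - A
    return lam * np.ones_like(a) - a


rng = np.random.default_rng(0)
n, lam = 6, 5.0
# Random symmetric 0/1 matrix with zero diagonal, standing in for A.
upper = np.triu(rng.integers(0, 2, size=(n, n)), k=1).astype(float)
a = upper + upper.T

x = rng.choice([-1.0, 1.0], size=n)  # a +/-1 labelling of the vertices
X = np.outer(x, x)                   # rank one, PSD, unit diagonal

lhs = np.sum(X * a_tilde(a, lam))     # <X, A_tilde> (Frobenius inner product)
rhs = lam * x.sum() ** 2 - x @ a @ x  # lambda (1^T x)^2 - x^T A x
print(lhs, rhs)
```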

In [None]:
cur_partition = load_partition("hw5/A.txt")
print(f"Symmetric: {jnp.allclose(cur_partition.a, cur_partition.a.T)}")

### Visualization of the matrix


In [None]:
def visualize_matrix(a: Float[Array, "n n"], title: str | None = None) -> plt.Figure:
    fig, ax = plt.subplots()
    ax = cast(Axes, ax)
    sns.heatmap(a, ax=ax, square=True, cmap=["#ffffee", "black"], cbar=False, xticklabels=False, yticklabels=False)
    ax.set_title(title if title else "Matrix")
    return fig

In [None]:
plt.ioff()
fig = visualize_matrix(cur_partition.a, title="Input matrix A")
None

## Reformulation

We introduce an auxiliary variable $\vf{Z}$ to decouple the diagonal constraint from the PSD constraint. Consider the problem

$$
\begin{array}{rl}
\min\limits_{\vf X,\vf Z} &\langle \vf X,\tilde{\vf A}\rangle \\
\textrm{s.t. } & \vf X \succeq 0\\
&Z_{ii}=1, \quad \forall i\in \{1, \dots, n\}\\
&\vf Z=\vf X
\end{array}
$$

Since $\vf Z=\vf X$ forces $X_{ii}=Z_{ii}=1$, the feasible set for $\vf X$ and the objective are the same as in the original problem; the only difference is that each of the two constraints now acts on its own variable.
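The split pays off because each constraint set now has a cheap Euclidean (Frobenius) projection: projecting a symmetric matrix onto the PSD cone needs one symmetric eigendecomposition followed by clipping negative eigenvalues to zero, and projecting onto $\{\vf Z: Z_{ii}=1\}$ just overwrites the diagonal. A NumPy sketch of both projections, with hypothetical helper names and assuming symmetric input:

```python
import numpy as np


def project_psd(m: np.ndarray) -> np.ndarray:
    """Frobenius-norm projection of a symmetric matrix onto the PSD cone."""
    m = (m + m.T) / 2                          # symmetrize against round-off
    w, v = np.linalg.eigh(m)                   # m = v diag(w) v^T
    return (v * np.clip(w, 0.0, None)) @ v.T   # zero out negative eigenvalues


def project_unit_diag(m: np.ndarray) -> np.ndarray:
    """Projection onto {Z : Z_ii = 1}; off-diagonal entries are unchanged."""
    z = m.copy()
    np.fill_diagonal(z, 1.0)
    return z


rng = np.random.default_rng(0)
b = rng.standard_normal((4, 4))
m = (b + b.T) / 2
p = project_psd(m)
z = project_unit_diag(m)
print(np.linalg.eigvalsh(p).min(), np.diag(z))
```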

## Rewriting the problem for ADMM

With $\lambda=5$, we compute $\tilde{\vf A} = \lambda\vf 1\vf 1^\top - \vf A$:

In [None]:
a_tilde = cur_partition.a_tilde
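One possible way to set up ADMM (a sketch, not necessarily the splitting used later in this notebook) is to put the linear objective and the PSD constraint on $\vf X$ and the diagonal constraint on $\vf Z$. With penalty $\rho>0$ and scaled dual variable $\vf U$, completing the square in the augmented Lagrangian $\langle \vf X,\tilde{\vf A}\rangle + \frac{\rho}{2}\|\vf X-\vf Z+\vf U\|_F^2$ gives the scaled-form updates

$$
\vf X^{k+1} = \Pi_{\succeq 0}\!\left(\vf Z^k - \vf U^k - \tilde{\vf A}/\rho\right),\quad
\vf Z^{k+1} = \Pi_{\{Z_{ii}=1\}}\!\left(\vf X^{k+1}+\vf U^k\right),\quad
\vf U^{k+1} = \vf U^k + \vf X^{k+1}-\vf Z^{k+1}.
$$

The NumPy sketch below implements these updates; `admm_sdp`, `rho` and `iters` are hypothetical names, and the toy input $\tilde{\vf A} = -\vf 1\vf 1^\top$ is chosen because its unique optimum is $\vf X = \vf 1\vf 1^\top$.

```python
import numpy as np


def project_psd(m: np.ndarray) -> np.ndarray:
    m = (m + m.T) / 2
    w, v = np.linalg.eigh(m)
    return (v * np.clip(w, 0.0, None)) @ v.T


def project_unit_diag(m: np.ndarray) -> np.ndarray:
    z = m.copy()
    np.fill_diagonal(z, 1.0)
    return z


def admm_sdp(a_tilde: np.ndarray, rho: float = 1.0, iters: int = 500):
    """Scaled-form ADMM for min <X, A_tilde> s.t. X PSD, Z_ii = 1, X = Z."""
    n = a_tilde.shape[0]
    z = np.eye(n)          # feasible starting point for Z
    u = np.zeros((n, n))   # scaled dual variable
    for _ in range(iters):
        # X-step: argmin over PSD X of <X, A_tilde> + rho/2 ||X - Z + U||_F^2
        x = project_psd(z - u - a_tilde / rho)
        # Z-step: argmin over unit-diagonal Z of rho/2 ||X - Z + U||_F^2
        z = project_unit_diag(x + u)
        # Dual ascent on the scaled multiplier
        u = u + x - z
    return x, z


# Toy check: for A_tilde = -1 1^T the optimum is X = 1 1^T with value -n^2.
x, z = admm_sdp(-np.ones((5, 5)))
print(np.round(x, 4), np.sum(x * -np.ones((5, 5))))
```

In the notebook this could be called as `admm_sdp(np.asarray(a_tilde))`. A practical version would stop once the primal residual $\|\vf X-\vf Z\|_F$ and the dual residual are small instead of running a fixed number of iterations, and $\rho$ may need tuning.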