In [None]:
import numpy as np
from numpy.testing import assert_allclose

import matplotlib.pyplot as plt
plt.style.use('seaborn-v0_8')  # make plots a bit nicer by default

# custom imports
from plot_utils import plot_matrix_evolution

## To do / Questions

1. Finish rewriting old code
    * $Q$, multi-$w$
    * $UU^\top$, training loop
    * $UU^\top$, single $w$
    * $UU^\top$, multi-$w$

2. Check that each of the above is learning properly
    * Performs better than one step of GD
        * on new data with the same $w_\star$
        * on new data with a different $w_\star$??
    * How consistent are learned matrices over multiple runs w/ same init?
        * if not consistent, what properties do they share?

3. What patterns are visible in learned $Q$?
    * Eigenvalues, vectors.
    * Low-rank $\Sigma_w$: do you capture this subspace?
    * Patterns in imshow

4. Update potential based on whole batch rather than a single datapoint, for easier interp?

5. Optimization-style objective, rather than crossval

# Generating data

In [None]:
def generate_linear_data(d_feature: int, n_samples: int, w_cov: np.ndarray = None, noise_scale: float = 0.0) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Sample a single noisy linear-regression task: y = X @ w + e.

    w ~ N(0, w_cov), each row of X ~ N(0, I), and e ~ N(0, noise_scale^2 I).
    Randomness comes from NumPy's global RNG, drawn in the order w, X, e.

    Args:
        d_feature: Dimension of feature space.
        n_samples: Number of samples to generate.
        w_cov: Covariance matrix for generating w. If None, uses identity matrix.
        noise_scale: Standard deviation of Gaussian noise.

    Returns:
        X: Feature matrix of shape (n_samples, d_feature).
        y: Target vector of shape (n_samples,).
        w: True weight vector of shape (d_feature,).
    """
    cov = np.eye(d_feature) if w_cov is None else w_cov

    w = np.random.multivariate_normal(mean=np.zeros(d_feature), cov=cov)
    X = np.random.randn(n_samples, d_feature)
    noise = np.random.normal(scale=noise_scale, size=(n_samples,))
    y = X @ w + noise

    return X, np.ravel(y), w


def generate_mixed_linear_data(d_feature: int, n_samples_per_w: int, n_ws: int, w_cov: np.ndarray = None, noise_scale: float = 0.0) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Sample several linear-regression tasks and stack them row-wise.

    Task k owns rows [k * n_samples_per_w, (k + 1) * n_samples_per_w) with y = X @ w_k + e.
    Each w_k ~ N(0, w_cov), rows of X ~ N(0, I), e ~ N(0, noise_scale^2). Randomness comes
    from NumPy's global RNG, drawn in the order W, X, e.

    Args:
        d_feature: Dimension of feature space.
        n_samples_per_w: Number of samples to generate for each w.
        n_ws: Number of different w vectors to use.
        w_cov: Covariance matrix for generating w. If None, uses identity matrix.
        noise_scale: Standard deviation of Gaussian noise.

    Returns:
        X: Feature matrix of shape (n_ws * n_samples_per_w, d_feature).
        y: Target vector of shape (n_ws * n_samples_per_w,).
        W: Matrix of true weight vectors of shape (n_ws, d_feature).
    """
    cov = np.eye(d_feature) if w_cov is None else w_cov

    W = np.random.multivariate_normal(mean=np.zeros(d_feature), cov=cov, size=n_ws)
    X = np.random.randn(n_ws * n_samples_per_w, d_feature)

    per_task = []
    for k, w_k in enumerate(W):
        start = k * n_samples_per_w
        per_task.append(X[start:start + n_samples_per_w] @ w_k)
    y = np.concatenate(per_task)
    y += np.random.normal(scale=noise_scale, size=y.shape)

    return X, y, W


def generate_specified_linear_data(n_samples_per_w: int, w_stars: np.ndarray, noise_scale: float = 0.0) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Sample inputs for a fixed set of tasks: task k owns n_samples_per_w consecutive rows
    with y = X @ w_stars[k] + e.

    Rows of X ~ N(0, I) and e ~ N(0, noise_scale^2), drawn from NumPy's global RNG
    in the order X, e. The task vectors are returned unchanged (the same object).

    Args:
        n_samples_per_w: Number of samples to generate for each w.
        w_stars: Array of weight vectors, shape (n_ws, d_feature).
        noise_scale: Standard deviation of Gaussian noise.

    Returns:
        X: Feature matrix of shape (n_ws * n_samples_per_w, d_feature).
        y: Target vector of shape (n_ws * n_samples_per_w,).
        W: Matrix of true weight vectors of shape (n_ws, d_feature).
    """
    n_ws, d_feature = w_stars.shape
    X = np.random.randn(n_ws * n_samples_per_w, d_feature)

    targets = [
        X[k * n_samples_per_w:(k + 1) * n_samples_per_w] @ w_k
        for k, w_k in enumerate(w_stars)
    ]
    y = np.concatenate(targets)
    y += np.random.normal(scale=noise_scale, size=y.shape)

    return X, y, w_stars


def generate_std_basis_data(d_feature: int, n_samples: int, w: np.ndarray = None) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Noiseless data y = X @ w whose inputs are the first n_samples standard basis vectors,
    so y[k] = w[k].

    Args:
        d_feature: Dimension of feature space.
        n_samples: Number of samples to generate (must be <= d_feature).
        w: True weight vector. If None, drawn from N(0, I) via NumPy's global RNG.

    Returns:
        X: Feature matrix of shape (n_samples, d_feature).
        y: Target vector of shape (n_samples,).
        w: True weight vector of shape (d_feature,).
    """
    assert n_samples <= d_feature, "n_samples must be <= d_feature for standard basis vectors"

    w = np.random.randn(d_feature) if w is None else w

    # rows e_0, ..., e_{n_samples - 1}
    X = np.eye(d_feature)[:n_samples]
    return X, (X @ w).ravel(), w


def generate_mixed_std_basis_data(d_feature: int, n_samples_per_w: int, n_ws: int, w_cov: np.ndarray = None, noise_scale: float = 0.0) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Multi-task version of `generate_std_basis_data`: every task reuses the same first
    n_samples_per_w standard basis vectors as inputs, with its own w_k ~ N(0, w_cov).

    Because e_j @ w_k = w_k[j], task k's targets are w_k[:n_samples_per_w] plus noise
    e ~ N(0, noise_scale^2). Randomness comes from NumPy's global RNG, drawn in the order W, e.

    Args:
        d_feature: Dimension of feature space.
        n_samples_per_w: Number of samples to generate for each w (must be <= d_feature).
        n_ws: Number of different w vectors to use.
        w_cov: Covariance matrix for generating w. If None, uses identity matrix.
        noise_scale: Standard deviation of Gaussian noise.

    Returns:
        X: Feature matrix of shape (n_ws * n_samples_per_w, d_feature).
        y: Target vector of shape (n_ws * n_samples_per_w,).
        W: Matrix of true weight vectors of shape (n_ws, d_feature).
    """
    assert n_samples_per_w <= d_feature, "n_samples_per_w must be <= d_feature for standard basis vectors"

    cov = np.eye(d_feature) if w_cov is None else w_cov
    W = np.random.multivariate_normal(mean=np.zeros(d_feature), cov=cov, size=n_ws)

    basis = np.eye(d_feature)[:n_samples_per_w]   # (n_samples_per_w, d_feature)
    X = np.tile(basis, (n_ws, 1))
    y = np.concatenate([w_k[:n_samples_per_w] for w_k in W])
    y += np.random.normal(scale=noise_scale, size=y.shape)

    return X, y, W

# $Q$ parameterization: not ensuring psd or symmetric

## Basic training loop ✅

In [None]:
def mirror_descent_step(w: np.ndarray, Q: np.ndarray, lr: float, x: np.ndarray, y: float) -> np.ndarray:
    """
    One mirror-descent step on the squared loss (<w, x> - y)^2 with potential matrix Q.

    The loss gradient is 2 (<w, x> - y) x; the step preconditions it with Q:
        w_new = w - 2 * lr * (<w, x> - y) * Q @ x
    [ 7.1: this function isn't actually used right now, since the potential update computes this manually ]

    Args:
        w: Current weight vector (d_feature,). (np.inner contracts the last axes of w and x,
           so a (d_feature, 1) column would not work here.)
        Q: Potential matrix (d_feature, d_feature).
        lr: Learning rate.
        x: Feature vector (d_feature,).
        y: Target value.

    Returns:
        Updated weight vector (d_feature,).
    """
    residual = np.inner(w, x) - y
    return w - 2 * lr * residual * (Q @ x)


def crossval(w: np.ndarray, Q: np.ndarray, lr: float, X: np.ndarray, y: np.ndarray,
             ignore_diag=True) -> float:
    """
    Leave-one-out style cross-validation of a single mirror-descent step.

    For every ordered pair (i, j): take one mirror-descent step from w on (x_i, y_i), then
    measure the squared residual on (x_j, y_j). With residuals z = X @ w - y that residual is
        L[i, j] = z_j - 2 * lr * z_i * x_j^T Q x_i.
    Returns sum(L[i, j]^2) / (2 * number_of_pairs).

    Tested against non-vectorized version, performs identically.

    Args:
        w: Initial weight vector (d_feature,).
        Q: Potential matrix (d_feature, d_feature).
        lr: Learning rate.
        X: Feature matrix (n_samples, d_feature).
        y: Target vector (n_samples,).
        ignore_diag: If True, skip the i == j pairs (n (n - 1) pairs); otherwise use all n^2.

    Returns:
        Half the mean squared post-step residual over the selected (i, j) pairs.
    """
    n_rows, _ = X.shape

    residuals = X @ w - y                      # (n_samples,)
    cross = X @ Q.T @ X.T                      # cross[i, j] = x_j^T Q x_i
    post_step = residuals - 2 * lr * residuals[:, np.newaxis] * cross
    squared = post_step ** 2                   # (n_samples, n_samples)

    if ignore_diag:
        # zero the i == j terms so they drop out of the sum
        np.fill_diagonal(squared, 0)
        n_pairs = n_rows * (n_rows - 1)
    else:
        n_pairs = n_rows ** 2

    return np.sum(squared) / (2 * n_pairs)

def potential_update(w: np.ndarray, Q: np.ndarray, outer_lr: float, inner_lr: float, X: np.ndarray, y: np.ndarray,
                     ignore_diag=True) -> np.ndarray:
    """
    One gradient-descent step on Q for the cross-validation loss computed by `crossval`.

    With residuals z = X @ w - y and
        L[i, j] = z_j - 2 * inner_lr * z_i * x_j^T Q x_i
    (the residual on sample j after one mirror-descent step on sample i), `crossval` returns
    sum(L[i, j]^2) / (2 * n_pairs). Its gradient with respect to Q is
        -2 * inner_lr / n_pairs * sum_{i,j} L[i, j] * z_i * x_j x_i^T,
    and this function returns Q - outer_lr * gradient. The loop-based reference
    `potential_update_nonvec` in the test cell checks the same formula.

    Args:
        w: Current weight vector (d_feature,).
        Q: Current potential matrix (d_feature, d_feature). Not modified.
        outer_lr: Learning rate for updating Q.
        inner_lr: Learning rate for the inner mirror descent step (η in the formula).
        X: Feature matrix (n_samples, d_feature).
        y: Target vector (n_samples,).
        ignore_diag: If True, average over the n (n - 1) pairs with i != j;
            otherwise over all n^2 pairs.

    Returns:
        Updated potential matrix Q (d_feature, d_feature).

    Raises:
        ValueError: If there are no (i, j) pairs to average over
            (n_samples == 0, or n_samples == 1 with ignore_diag=True).
    """
    n_samples = X.shape[0]

    errors = X @ w - y                                            # z, (n_samples,)
    XQX = X @ Q.T @ X.T                                           # [i, j] = x_j^T Q x_i
    L = errors - 2 * inner_lr * errors[:, np.newaxis] * XQX       # L[i, j] as above

    if ignore_diag:
        np.fill_diagonal(L, 0)
        denom = n_samples * (n_samples - 1)
    else:
        denom = n_samples ** 2
    if denom == 0:
        raise ValueError(
            f"potential_update: no (i, j) pairs to average over "
            f"(n_samples={n_samples}, ignore_diag={ignore_diag})"
        )

    # (L.T * errors)[j, i] = L[i, j] * z_i, so X.T @ (...) @ X = sum_{i,j} L[i, j] z_i x_j x_i^T
    grad = X.T @ (L.T * errors) @ X                               # (d_feature, d_feature)
    grad *= -2 * inner_lr / denom
    return Q - outer_lr * grad

### `potential_training_loop`

In [None]:
def potential_training_loop(d, n, inner_lr, outer_lr, w0, Q0, n_iters,
                            w_cov=None, noise_scale=0., seed=None):
    """
    Basic training loop for the potential matrix Q on one linear-regression task.

    Generates a dataset with `generate_linear_data` and runs `potential_update` for
    n_iters iterations starting from Q0.

    Args:
        d: Feature dimension.
        n: Number of samples.
        inner_lr: Learning rate of the inner mirror-descent step.
        outer_lr: Learning rate for updating Q.
        w0: Initial weight vector (d,) used for every inner step.
        Q0: Initial potential matrix (d, d). Not modified.
        n_iters: Number of potential updates.
        w_cov: Covariance for sampling w_star (identity if None).
        noise_scale: Standard deviation of the label noise.
        seed: If not None, passed to np.random.seed before the data is generated.
            The data generators draw from NumPy's global RNG, so this makes the run reproducible.

    Returns:
        crossvals: (n_iters,) crossval loss (ignore_diag=True) of Qs[i].
        Qs: (n_iters, d, d) iterates; Qs[0] is Q0 and Qs[i] is Q after i updates.
        X, y, w_star: The generated data.
    """
    if seed is not None:
        # Call (not assign) np.random.seed; `np.random.seed = seed` would replace NumPy's
        # function with an int. `is not None` so that seed=0 is honoured too.
        np.random.seed(seed)

    X, y, w_star = generate_linear_data(d, n, w_cov, noise_scale)

    crossvals = np.zeros(n_iters)
    Qs = np.zeros((n_iters, d, d))
    Q = Q0.copy()
    for i in range(n_iters):
        # Store Q before updating so that Qs[i] and crossvals[i] describe the same iterate
        # (same convention as potential_training_loop_mixed_w).
        Qs[i] = Q
        crossvals[i] = crossval(w0, Q, inner_lr, X, y, ignore_diag=True)
        # NOTE: Q is trained on the all-pairs objective (ignore_diag=False), while the
        # reported crossvals use the i != j objective.
        Q = potential_update(w0, Q, outer_lr, inner_lr, X, y, ignore_diag=False)

    return crossvals, Qs, X, y, w_star

### Checking correctness of `crossval` and `potential_update` ✅

In [None]:
# Check `crossval` and `potential_update` against straightforward loop-based implementations.

def random_experiment_setup():
    """
    Draw a random problem instance for the correctness checks below.

    Returns (d_feature, n_samples, lr, X, y, w_star, w, Q):
    - d_feature in [5, 15) and n_samples in [5, 100);
    - lr ~ U(0.01, 1) and a noise level ~ U(0, 1);
    - X, y, w_star from generate_linear_data;
    - w and Q drawn independently from the standard normal.
    """
    dim = np.random.randint(5, 15)
    num = np.random.randint(5, 100)
    step = np.random.uniform(0.01, 1)
    noise = np.random.uniform(0, 1)

    X, y, w_star = generate_linear_data(dim, num, noise_scale=noise)

    w_init = np.random.randn(dim)
    Q_init = np.random.randn(dim, dim)

    return dim, num, step, X, y, w_star, w_init, Q_init

n_tests = 15

# ------------------------ Testing L_ij -----------------------------
# L[i, j] should equal z_j - 2 lr z_i x_j^T Q x_i: the residual on sample j after one
# mirror-descent step taken on sample i.
for _ in range(n_tests):
    d, n, lr, X, y, w_star, w, Q = random_experiment_setup()
    z = X @ w - y
    M = X @ Q.T @ X.T
    L = (z - 2 * lr * z[:, np.newaxis] * M)  # L[i,j] = L_{ij}

    # Compare every entry of the (n, n) matrix, including rows/columns with index >= d.
    manual = np.array([[z[j] - 2 * lr * z[i] * (X[j] @ Q @ X[i]) for j in range(n)]
                       for i in range(n)])
    assert_allclose(L, manual, rtol=1e-5, atol=1e-8, err_msg="L_ij mismatch")
print("L_ij passed test.")


# ------------------------ Testing crossval -----------------------------
def crossval_nonvec(w, Q, lr, xs, ys):
    """
    Loop-based reference for `crossval` with ignore_diag=True.

    For every ordered pair (i, j) with i != j: take one mirror-descent step on (xs[i], ys[i]),
    square the residual on (xs[j], ys[j]), and sum. Returns total / (2 k (k - 1)) with k = len(xs).
    In the checks below, w and every x are passed as column vectors of shape (d, 1).
    """
    k = len(xs)
    samples = list(zip(xs, ys))

    def squared_residual(xi, xj, yi, yj):
        step_err = w.T @ xi - yi
        return (w.T @ xj - 2 * lr * step_err * (xj.T @ Q @ xi) - yj) ** 2

    total = sum(
        squared_residual(xi, xj, yi, yj).item()
        for i, (xi, yi) in enumerate(samples)
        for j, (xj, yj) in enumerate(samples)
        if i != j
    )
    return total / (2 * k * (k - 1))

for _ in range(n_tests):
    # every parameter (sizes, lr, data, w, Q) is re-randomized on each run
    d_feature, n_samples, lr, X, y, w_star, w, Q = random_experiment_setup()

    # the loop-based reference works on (d_feature, 1) column vectors and plain floats
    xs = list(X[:, :, np.newaxis])
    ys = y.tolist()

    result_nonvec = crossval_nonvec(w[:, np.newaxis], Q, lr, xs, ys)
    result_vec = crossval(w, Q, lr, X, y)
    assert np.isclose(result_nonvec, result_vec) 
print("crossval passed test.")


# ------------------------ Testing potential_update -----------------------------
def potential_update_nonvec(w, Q, outer_lr, inner_lr, xs, ys):
    """
    Loop-based reference for `potential_update` with ignore_diag=True.

    Over ordered pairs (i, j) with i != j, accumulates
        L_ij * e_i * x_j x_i^T,   L_ij = w.x_j - 2*inner_lr*e_i*(x_j^T Q x_i) - y_j,   e_i = w.x_i - y_i.
    It then scales the sum by -2 * inner_lr / (k (k - 1)) to get the gradient of the
    cross-validation loss w.r.t. Q, and returns Q - outer_lr * gradient.
    xs are (d, 1) column vectors.
    """
    k = len(xs)

    def pair_term(xi, xj, yi, yj):
        err_i = (w.T @ xi - yi).item()
        residual = (w.T @ xj - 2 * inner_lr * err_i * (xj.T @ Q @ xi) - yj).item()
        return residual * err_i * (xj @ xi.T)

    grad = sum(
        (pair_term(xi, xj, yi, yj)
         for i, (xi, yi) in enumerate(zip(xs, ys))
         for j, (xj, yj) in enumerate(zip(xs, ys))
         if i != j),
        0,
    )
    grad = - 2 * inner_lr * grad / (k * (k - 1))
    return Q - outer_lr * grad

for _ in range(n_tests):
    # random problem plus a random outer learning rate
    d_feature, n_samples, inner_lr, X, y, w_star, w, Q = random_experiment_setup()
    outer_lr = np.random.uniform(0.01, 1)

    # the reference implementation expects (d_feature, 1) column vectors and plain floats
    xs = list(X[:, :, np.newaxis])
    ys = y.tolist()

    result_nonvec = potential_update_nonvec(w, Q, outer_lr, inner_lr, xs, ys)
    result_vec = potential_update(w, Q, outer_lr, inner_lr, X, y)
    assert np.allclose(result_nonvec, result_vec) 
print("potential_update passed test.")
    


In [None]:
# Scratch work: vectorizing the "sum of outer products of rows of X",
# i.e. S = sum_{i,j} np.outer(x_i, x_j), checking each rewrite against a plain loop.
d, n = 5, 10
X = np.random.randn(n, d)

# (1) reference: explicit double loop over all ordered row pairs
sopr = np.zeros((d, d))
for i in range(n):
    for j in range(n):
        sopr += np.outer(X[j], X[i])

# (2) materialize every outer product as an (n, n, d, d) stack, then sum over the two row axes
expand = np.array([[np.outer(X[i], X[j]) for j in range(n)] for i in range(n)])
sopr2 = expand.sum((0, 1))
assert_allclose(sopr, sopr2)

# (3) the same (n, n, d, d) expansion written as an einsum, followed by the same reduction
einsum_expand = np.einsum('ik,jl->ijkl', X, X)
assert np.allclose(einsum_expand, expand)
sopr3 = einsum_expand.sum((0, 1))
assert_allclose(sopr, sopr3)

# (4) let a single einsum do both the expansion and the reduction
sopr4 = np.einsum('ik,jl->kl', X, X)
assert_allclose(sopr, sopr4)

# (5) no einsum: the double sum factors as (sum_i x_i)(sum_j x_j)^T
rowsums = X.T @ np.ones(n)   # sum of the rows of X, shape (d,)
sopr5 = np.outer(rowsums, rowsums)
assert_allclose(sopr, sopr5)

### Known example: noiseless case, standard basis data ✅

In [None]:
def optimal_non_pd_potential_std_basis(w0: np.ndarray, w_star: np.ndarray, lr: float = 0.5) -> np.ndarray:
    """
    Matrix Q that makes every L_ij = 0 for standard-basis data.

    With z = w0 - w_star, the entries are Q[i, j] = z_i / (z_j * 2 * lr). Q is generally
    not symmetric or positive definite; entries divide by zero where z_j == 0.

    Args:
        w0: Initial weight vector (d_feature,).
        w_star: True weight vector (d_feature,).
        lr: Learning rate.

    Returns:
        Optimal potential matrix Q (d_feature, d_feature).
    """
    d_feature = w0.shape[0]
    z = np.array([(w0[k] - w_star[k]).item() for k in range(d_feature)])
    ratios = z[:, np.newaxis] / z[np.newaxis, :]   # ratios[i, j] = z_i / z_j
    return ratios / (2 * lr)

def stationarity_condition_Q_Lij(Q, X, w, w_star, lr):
    """
    Per-pair residual of the stationarity condition <u, Qv> = <u, z> / (2η <v, z>), where z = w_0 - w_star.

    For a data matrix X (n, d), returns the (n, n) matrix whose (i, j) entry is
        <X[i], Q X[j]> - <X[i], z> / (2 * lr * <X[j], z>).
    For noiseless data (y = X w_star) and <X[j], z> != 0, entry (i, j) is zero exactly
    when one mirror-descent step on sample j leaves zero residual on sample i.
    """
    residuals = X @ (w - w_star)                                # <X[i], z>, shape (n,)
    target = residuals[:, np.newaxis] / (2 * lr * residuals)    # [i, j] = <X[i], z> / (2 lr <X[j], z>)
    return X @ Q @ X.T - target

def stationarity_condition_Q_full(Q, X, w, w_star, lr):
    """
    Residual of the full (all-pairs) stationarity condition for Q.

    With errors e = X @ (w - w_star), returns the (d, d) matrix
        2 * lr * sum_{i,j} e_j^2 (x_i^T Q x_j) x_i x_j^T  -  (X^T e)(X^T e)^T.
    For noiseless data (y = X w_star) this is a positive multiple (for lr > 0) of the gradient
    used by potential_update(..., ignore_diag=False). It therefore vanishes exactly at
    stationary points of that objective.

    Implemented with matrix products, which gives the same result as
    np.einsum('j,ia,ab,jb,ik,jl->kl', e**2, X, Q, X, X, X). The unoptimized 6-operand
    einsum iterates over n^2 d^4 index combinations.
    """
    errors = X @ (w - w_star)                          # (n,)

    XQX = X @ Q @ X.T                                  # (n, n): [i, j] = x_i^T Q x_j
    # scale column j by e_j^2, then contract both sides with X: sum_{i,j} x_i [..]_{ij} x_j^T
    left_hand_side = X.T @ (XQX * errors**2) @ X       # (d, d)

    XTe = X.T @ errors
    right_hand_side = np.outer(XTe, XTe)               # (d, d)

    return 2 * lr * left_hand_side - right_hand_side

In [None]:
# Noiseless standard-basis data: train Q from the identity and track the CV loss and the
# distance to the hand-derived optimum Q_star.
d, n = 5, 5
inner_lr, outer_lr = 1, 0.4
n_potential_iterations = 6000

X, y, w_star = generate_std_basis_data(d, n)
w0 = np.ones((d,))
Q0 = np.eye(d)
# alternative random init: Q0 = np.random.randn(d, d)

Q_star = optimal_non_pd_potential_std_basis(w0, w_star, inner_lr)

crossvals = np.zeros(n_potential_iterations)
Qs = np.zeros((n_potential_iterations, d, d))
Q = Q0.copy()
for i in range(n_potential_iterations):
    crossvals[i] = crossval(w0, Q, inner_lr, X, y)
    Qs[i] = Q                     # iterate i, recorded before its update
    Q = potential_update(w0, Q, outer_lr, inner_lr, X, y, ignore_diag=False)
Q_dists = np.linalg.norm(Qs - Q_star, axis=(1, 2))

# Plotting results
fig, (ax_loss, ax_dist) = plt.subplots(1, 2, figsize=(12, 5))

ax_loss.plot(crossvals)
ax_loss.set(title="Crossval loss", yscale='log', xlabel="Iterations", ylabel="Loss")

ax_dist.plot(Q_dists)
ax_dist.set(title="$d(Q, Q_\\star)$", yscale='log', xlabel="Iterations", ylabel="Distance")

fig.tight_layout()
plt.show()

**Plot failure of stationarity condition per iteration for several different dimensions**


In [None]:
# For several dimensions (two random w_star each), train Q on standard-basis data and plot how
# far every iterate is from satisfying the L_ij stationarity condition.
colors = ['blue', 'red', 'green', 'orange', 'gray', 'brown', 'purple']
eps = 1e-8           # stop training once the (i != j) crossval loss drops below this
ignore_diag = False  # objective used by potential_update

for size in range(3, 8):
    d, n = size, size
    inner_lr, outer_lr = 1, 0.4
    n_potential_iterations = 1000 * size

    w0 = np.ones((d,))
    Q0 = np.eye(d)
    color = colors[size - 3]

    for rep in range(2):
        X, y, w_star = generate_std_basis_data(d, n)
        Q_star = optimal_non_pd_potential_std_basis(w0, w_star, inner_lr)
        Q = Q0.copy()
        crossvals = []
        Qs = []

        for i in range(n_potential_iterations):
            Q = potential_update(w0, Q, outer_lr, inner_lr, X, y, ignore_diag=ignore_diag)
            crossvals.append(crossval(w0, Q, inner_lr, X, y))
            Qs.append(Q.copy())
            if crossvals[-1] < eps:
                break

        crossvals = np.array(crossvals)
        Qs = np.array(Qs)

        sp_diffs = np.array([stationarity_condition_Q_Lij(Q_iter, X, w0, w_star, inner_lr) for Q_iter in Qs])
        mean_abs_diffs = np.mean(np.abs(sp_diffs), axis=(1, 2))
        plt.plot(mean_abs_diffs, color=color, label=f"d={size}, rep={rep+1}")

# (no axhline at 0: y = 0 cannot be drawn on a log axis)
plt.yscale('log')
plt.legend()
plt.ylabel("Stationary point difference")
plt.xlabel("Iteration")
plt.title("Stationarity condition vs iteration for different dimensions, std basis data")
plt.show()

**Plot failure of stationarity condition at $\epsilon$ convergence**

In [None]:
# L_ij stationarity residual along the last training run above, on the training X and on a
# fresh Gaussian X_val. Uses inner_lr, the learning rate Q was trained with. The global `lr`
# is not set by the training cell and would silently pick up an unrelated leftover value.
X_val = np.random.randn(n, d)

sp_diffs = np.array([stationarity_condition_Q_Lij(Q, X, w0, w_star, inner_lr) for Q in Qs])
sp_diffs_val = np.array([stationarity_condition_Q_Lij(Q, X_val, w0, w_star, inner_lr) for Q in Qs])

mean_abs_diffs = np.mean(np.abs(sp_diffs), axis=(1,2))
max_abs_diffs = np.max(np.abs(sp_diffs), axis=(1,2))
min_abs_diffs = np.min(np.abs(sp_diffs), axis=(1,2))

mean_abs_diffs_val = np.mean(np.abs(sp_diffs_val), axis=(1,2))
max_abs_diffs_val = np.max(np.abs(sp_diffs_val), axis=(1,2))
min_abs_diffs_val = np.min(np.abs(sp_diffs_val), axis=(1,2))

#plt.plot(max_abs_diffs)
plt.plot(mean_abs_diffs, label="train")
plt.plot(mean_abs_diffs_val, label="val")
#plt.plot(min_abs_diffs)
#plt.yscale('log')
plt.legend()
plt.xlabel("Iteration")
plt.ylabel("Mean |stationarity residual|")
plt.title("$L_{ij}$ stationarity condition on std basis data")
plt.show()

In [None]:
# Full stationarity residual along the last training run above. Uses inner_lr, the learning
# rate Q was trained with, rather than the global `lr`, which the training cell does not set.
X_val = np.random.randn(n, d)

sp_diffs = np.array([stationarity_condition_Q_full(Q, X, w0, w_star, inner_lr) for Q in Qs])
sp_diffs_val = np.array([stationarity_condition_Q_full(Q, X_val, w0, w_star, inner_lr) for Q in Qs])

mean_abs_diffs = np.mean(np.abs(sp_diffs), axis=(1,2))
max_abs_diffs = np.max(np.abs(sp_diffs), axis=(1,2))
min_abs_diffs = np.min(np.abs(sp_diffs), axis=(1,2))

mean_abs_diffs_val = np.mean(np.abs(sp_diffs_val), axis=(1,2))
max_abs_diffs_val = np.max(np.abs(sp_diffs_val), axis=(1,2))
min_abs_diffs_val = np.min(np.abs(sp_diffs_val), axis=(1,2))

#plt.plot(max_abs_diffs)
plt.plot(mean_abs_diffs, label="train")
#plt.plot(mean_abs_diffs_val, label="val")
plt.legend()
#plt.plot(min_abs_diffs)
#plt.yscale('log')
plt.xlabel("Iteration")
plt.ylabel("Mean |stationarity residual|")
plt.title("Full stationarity condition on std basis data")
plt.show()

In [None]:
# Quick shape check: z = w0 - w_star should be (d,) and X should be (n, d).
# (A bare expression mid-cell is never displayed, so show both shapes as one tuple.)
z = w0 - w_star
z.shape, X.shape

In [None]:
def stationary_closed_form(X, w, w_star, lr):
    """
    Closed-form Q satisfying the full stationarity condition (see `stationarity_condition_Q_full`).

    With z = w - w_star, e = X @ z and E = diag(e**2), the condition
        2 * lr * X^T X Q X^T E X = X^T e e^T X = X^T X z z^T X^T X
    is solved (assuming X^T X and X^T E X are invertible) by
        Q = z z^T X^T X (X^T E X)^{-1} / (2 * lr).

    Args:
        X: Data matrix (n, d).
        w: Initial weight vector (d,).
        w_star: True weight vector (d,).
        lr: Inner learning rate.

    Returns:
        Q as a (d, d) array.

    Raises:
        numpy.linalg.LinAlgError: If X^T E X is singular.
    """
    z = w - w_star
    errors = X @ z
    # X^T diag(e^2) X without materializing the (n, n) diagonal matrix
    XtEX = (X.T * errors**2) @ X
    A = np.outer(z, z) @ (X.T @ X)
    # A @ inv(XtEX) == solve(XtEX.T, A.T).T, but solving is cheaper and more accurate than inverting
    return np.linalg.solve(XtEX.T, A.T).T / (2 * lr)


# Check the closed form on anisotropic inputs: the first 3 coordinates of x have variance 1000.
# d and n must be set before X_cov is built; otherwise X_cov picks up a stale d from an earlier cell.
d, n = 10, 100
lr = 0.5
X_cov = np.diag([1000.0] * 3 + [1.0] * (d - 3))

X = np.random.multivariate_normal(mean=np.zeros(d), cov=X_cov, size=(n,))
w0 = np.random.randn(d)
w_star = np.ones(d)

Q_star = stationary_closed_form(X, w0, w_star, lr)
plt.imshow(Q_star)
plt.grid(False)
plt.colorbar()
plt.title("Closed-form stationary $Q$")
plt.show()

# pretty relaxed atol: residual entries are large and X_cov is poorly conditioned
assert_allclose(stationarity_condition_Q_full(Q_star, X, w0, w_star, lr), 0, atol=1e-3, rtol=0)

In [None]:
# Idea: resample X several times with the same covariance (which has a distinctive 3-dim subspace)
# and check which properties of the `stationary_closed_form` matrix are preserved across draws.
# Here: are the top eigenvectors (largest |eigenvalue|) of the different Q's aligned?
#
# Earlier note: "eigenvector thing seems mostly negative. a bit complicated by complex values,
# not sure what to make of that." NOTE(review): that run indexed *rows* of `eigenvectors`, but
# np.linalg.eig returns eigenvectors as *columns*, so that conclusion should be re-checked.

d, n = 10, 100
lr = 0.5
w_star = np.ones(d)
X_cov = np.diag([1000.0] * 3 + [1.0] * (d - 3))

Qs, Xs, ws = [], [], []
eigenvecs = []
for i in range(8):
    Xs.append(np.random.multivariate_normal(mean=np.zeros(d), cov=X_cov, size=(n,)))
    ws.append(np.random.randn(d))
    Qs.append(stationary_closed_form(Xs[-1], ws[-1], w_star, lr))

    eigs = np.linalg.eig(Qs[-1])
    absvals = np.abs(eigs.eigenvalues)
    eigindex = np.argmax(absvals)
    eigenvecs.append(eigs.eigenvectors[:, eigindex])   # column eigindex is the eigenvector

eigenvecs = np.array(eigenvecs)            # (8, d): one unit-norm (possibly complex) eigenvector per row
xxT = eigenvecs @ np.conj(eigenvecs).T     # [a, b] = sum_k v_a[k] * conj(v_b[k])
np.fill_diagonal(xxT, 0)
plt.imshow(np.abs(xxT))  # values near 1 mean two draws share the top eigenvector (up to phase)
plt.grid(False)
plt.colorbar()
plt.title("|inner product| between top eigenvectors of different draws")
plt.show()

## Graphs for single $w$

In [None]:
# Training setup
d,n = 5, 50
inner_lr, outer_lr = 0.1, 0.02
n_iters = 3500
w_cov = np.eye(d)
w0 = np.ones(d)
Q0 = np.eye(d)

# Training loop
crossvals, Qs, X, y, w_star = potential_training_loop(d, n, inner_lr, outer_lr, w0, Q0, n_iters, w_cov, seed=123)

# Plot crossvals
plt.plot(crossvals)
plt.title(f"CV loss: d={d}, n={n}")
#plt.yscale('log')

plot_matrix_evolution(Qs, main_title=f"Evolution of $Q$ during training: d={d}, n={n}")

### Compare to a single step of gradient descent

In [None]:
# Compare the trained potential with plain gradient descent (Q = I) on fresh noiseless
# validation datasets that share this run's w_star.
Q_final = Qs[-1]
lr = inner_lr

trained_md_crossvals, gd_crossvals, X_vals = [], [], []
for i in range(1500):
    X_val = np.random.randn(n, d)
    y_val = X_val @ w_star
    trained_md_crossvals.append(crossval(w0, Q_final, lr, X_val, y_val))
    gd_crossvals.append(crossval(w0, np.eye(d), lr, X_val, y_val))
    X_vals.append(X_val)


In [None]:
# Sort the validation sets by (md - gd) crossval, so the sets where md wins by the most come first.
# NOTE: X_vals is replaced by its sorted array. Re-run the previous cell before re-running this
# one; otherwise the already-sorted X_vals gets permuted a second time.
md_gd_crossvals = np.array([trained_md_crossvals, gd_crossvals])   # row 0: md, row 1: gd
sorted_indices = np.argsort(md_gd_crossvals[0] - md_gd_crossvals[1])
md_gd_crossvals = md_gd_crossvals[:, sorted_indices]
X_vals = np.array(X_vals)[sorted_indices]

plt.plot(md_gd_crossvals[0], label="md", color='blue')
plt.plot(md_gd_crossvals[1], label="gd", color='green')
plt.plot(md_gd_crossvals[1] - md_gd_crossvals[0], alpha=0.5, linestyle='--', color='gray', label="diffs (gd - md)")
plt.axhline(y=np.mean(md_gd_crossvals[0]), label="md avg", color='blue', linestyle='--')
plt.axhline(y=np.mean(md_gd_crossvals[1]), label="gd avg", color='green', linestyle='--')
plt.xlabel("Validation dataset (sorted by md - gd crossval)")
plt.ylabel("Cross-validation loss over validation datasets")
plt.title("Mirror descent (blue) vs gradient descent (green).\nSorted by md - gd crossval difference")
plt.legend()
plt.show()

**Explanation of next cell: Investigating `X_val` statistics across runs.**

`-----------------There's stuff here that I'd like to follow up on-------------------`

Typically, what I see above is that mirror descent does better (yay!).

Sometimes it seems to do worse, consistently -- not sure if this was an error in the code, a bad $Q$, or what.

And sometimes it seems to do better most of the time, maybe 90%, but then do worse sometimes. I (i.e. Claude) wrote the code below to investigate whether, in these cases, the `X_val` data that md performs better on is systematically different from the data that gd performs better on.

I did a 3-minute training run for $Q$ to make sure I was working with something good, and then saw a consistently good gap between md and gd over many iterations -- they never swapped places. So maybe it was an undertraining issue the whole time?

In [None]:
from scipy import stats

# Do the validation sets where md beats gd differ statistically from those where gd wins?
# Expects md_gd_crossvals and X_vals sorted by (md - gd) crossval, as done in the previous cell.

# First index k with md_gd_crossvals[0, k] >= md_gd_crossvals[1, k] (md no better than gd).
md_not_better = md_gd_crossvals[0] >= md_gd_crossvals[1]
# np.argmax returns 0 when no entry is True; in that case md wins everywhere, so split at the end.
split_point = int(np.argmax(md_not_better)) if md_not_better.any() else len(md_not_better)

# Per-dataset statistics of each X_val
means = np.mean(X_vals, axis=(1, 2))
mean_distances = np.abs(means)
cov_matrices = np.array([np.cov(X_v.T) for X_v in X_vals])
# covariance matrices are symmetric: eigvalsh returns real eigenvalues in ascending order
top_eigenvalues = np.array([np.linalg.eigvalsh(cov)[-1] for cov in cov_matrices])
condition_numbers = np.array([np.linalg.cond(cov) for cov in cov_matrices])
std_devs = np.std(X_vals, axis=(1, 2))
standardized = ((X_vals - np.mean(X_vals, axis=(1, 2), keepdims=True))
                / np.std(X_vals, axis=(1, 2), keepdims=True))
skewness = np.mean(standardized**3, axis=(1, 2))
kurtosis = np.mean(standardized**4, axis=(1, 2)) - 3
ks_stats = np.array([stats.kstest(X_v.flatten(), 'norm').statistic for X_v in X_vals])

# Per-dataset gd - md crossval gap (positive = md better). This replaces the old
# "Mean Differences Between Runs" panel. That panel used X_vals[1] - X_vals[0], the difference
# of just two validation sets averaged per row, which is not a per-dataset quantity.
crossval_gaps = md_gd_crossvals[1] - md_gd_crossvals[0]


def plot_metric(ax, metric, title):
    """Line plot of a per-dataset metric with the md/gd split point marked."""
    ax.plot(metric)
    ax.set_title(title)
    ax.axvline(x=split_point, color='r', linestyle='--', label=f'Split point (k={split_point})')
    ax.legend()


fig, axs = plt.subplots(3, 3, figsize=(15, 15))
plot_metric(axs[0, 0], mean_distances, 'Mean Distances from Zero')
plot_metric(axs[0, 1], top_eigenvalues, 'Top Eigenvalues')
plot_metric(axs[0, 2], condition_numbers, 'Condition Numbers')
plot_metric(axs[1, 0], std_devs, 'Standard Deviations')
plot_metric(axs[1, 1], skewness, 'Skewness')
plot_metric(axs[1, 2], kurtosis, 'Kurtosis')
plot_metric(axs[2, 0], ks_stats, 'KS Statistics')
plot_metric(axs[2, 1], crossval_gaps, 'gd - md crossval (positive = md better)')

# md vs gd crossvals themselves
axs[2, 2].plot(md_gd_crossvals[0], label='md')
axs[2, 2].plot(md_gd_crossvals[1], label='gd')
axs[2, 2].axvline(x=split_point, color='r', linestyle='--', label=f'Split point (k={split_point})')
axs[2, 2].set_title('md_gd_crossvals')
axs[2, 2].legend()

plt.tight_layout()
plt.show()


def plot_overlapping_histogram(data, title):
    """Overlay histograms of `data` before and after the split point."""
    plt.figure(figsize=(10, 6))
    plt.hist(data[:split_point], bins=20, alpha=0.5, label='Before split')
    plt.hist(data[split_point:], bins=20, alpha=0.5, label='After split')
    plt.title(title)
    plt.legend()
    plt.show()


def print_summary(metric, name):
    """Print mean/std of a metric before and after the split point."""
    print(f"\n{name}:")
    print(f"Before split: mean = {np.mean(metric[:split_point]):.4f}, std = {np.std(metric[:split_point]):.4f}")
    print(f"After split: mean = {np.mean(metric[split_point:]):.4f}, std = {np.std(metric[split_point:]):.4f}")


# Set to True for per-metric histograms and before/after-split summaries.
SHOW_SPLIT_DETAILS = False
if SHOW_SPLIT_DETAILS:
    metrics = [mean_distances, top_eigenvalues, condition_numbers, std_devs, skewness, kurtosis, ks_stats, crossval_gaps]
    metric_names = ['Mean Distances', 'Top Eigenvalues', 'Condition Numbers', 'Standard Deviations',
                    'Skewness', 'Kurtosis', 'KS Statistics', 'gd - md crossval']
    for metric, name in zip(metrics, metric_names):
        plot_overlapping_histogram(metric, name)
    for metric, name in zip(metrics, metric_names):
        print_summary(metric, name)

# Print the split point
if md_not_better.any():
    print(f"\nSplit point (k) where md_gd_crossvals[0,k] >= md_gd_crossvals[1,k]: {split_point}")
else:
    print("\nmd beats gd on every validation set; no split point.")

## Graphs for mixed $w$ / recovering covariance

Questions:
* Are there consistent, visible patterns in $Q$?
* What properties are preserved between multiple runs with different data?
* How does $\Sigma_w$ itself perform as a mirror map?

### Generic "run experiment" function

In [None]:
def potential_training_loop_mixed_w(d, n_samples_per_w, n_ws, inner_lr, outer_lr, w0, Q0, n_iters,
                                    w_cov=None, noise_scale=0., seed=None):
    """
    Training loop for Q on data pooled from several tasks w_star.

    Generates n_ws tasks with `generate_mixed_linear_data` and runs `potential_update` on the
    pooled (X, y) for n_iters iterations. Cross-validation pairs (i, j) are taken over the pooled
    data, so i and j may come from different tasks.

    Args:
        d: Feature dimension.
        n_samples_per_w: Samples per task.
        n_ws: Number of tasks.
        inner_lr: Learning rate of the inner mirror-descent step.
        outer_lr: Learning rate for updating Q.
        w0: Initial weight vector (d,).
        Q0: Initial potential matrix (d, d). Not modified.
        n_iters: Number of potential updates.
        w_cov: Covariance for sampling the tasks (identity if None).
        noise_scale: Standard deviation of the label noise.
        seed: If not None, passed to np.random.seed before the data is generated
            (same convention as potential_training_loop).

    Returns:
        crossvals: (n_iters,) crossval loss (ignore_diag=True) of Qs[i].
        Qs: (n_iters, d, d) iterates; Qs[0] is Q0 and Qs[i] is Q after i updates.
        X, y, W_stars: The generated data and the (n_ws, d) task vectors.
    """
    if seed is not None:
        np.random.seed(seed)

    X, y, W_stars = generate_mixed_linear_data(d, n_samples_per_w, n_ws, w_cov, noise_scale)

    Q = Q0.copy()
    crossvals = np.zeros(n_iters)
    Qs = np.zeros((n_iters, d, d))
    for i in range(n_iters):
        Qs[i] = Q   # recorded before the update, so it matches crossvals[i]
        crossvals[i] = crossval(w0, Q, inner_lr, X, y, ignore_diag=True)
        # trained on the all-pairs objective, reported on the i != j objective
        Q = potential_update(w0, Q, outer_lr, inner_lr, X, y, ignore_diag=False)

    return crossvals, Qs, X, y, W_stars

### Low-rank $\Sigma_w$

In [None]:
# Low-rank task covariance: w_star ~ N(0, diag(1, ..., 1, 0, ..., 0)) with `rank` ones.
d, nspw, n_ws = 8, 3, 25  # nspw = n_samples_per_w
inner_lr, outer_lr = 0.3, 0.005
n_iters = 700
w0 = np.ones(d)
Q0 = np.eye(d)

# setting up w_cov
rank = 2
assert rank < d, f"rank {rank} should be less than dimension {d}"
w_cov = np.diag([1.0] * rank + [0.0] * (d - rank))

# Train potential
crossvals, Qs, X, y, W_stars = potential_training_loop_mixed_w(d, nspw, n_ws, inner_lr, outer_lr, w0, Q0, n_iters, w_cov)

# Plot crossvals (raw f-string so that "\S" in the LaTeX is not treated as a string escape)
plt.plot(crossvals)
plt.title(rf"CV loss: d={d}, n={nspw * n_ws}, rank($\Sigma_w$)={rank}")
plt.yscale('log')

# Plot images of Q iterates over time
plot_matrix_evolution(Qs, extra_matrix=w_cov, extra_matrix_title=f'w_cov, rank={rank}', main_title='Evolution of $Q$ during training')

### Poorly-conditioned $\Sigma_w$

## Stationary-Point Conditions

### $L_{ij}$ stationary point condition: $x_j^\top Q x_i = z_j / (2 \eta z_i)$

In [None]:
# L_ij stationarity residual for every stored iterate in Qs, on the training X and on a fresh
# Gaussian X_val, averaged over entries and plotted per iteration.
# NOTE(review): on Restart & Run All this cell mixes state from different earlier cells:
#   - Qs, X, w0 and d come from the low-rank mixed-w run (d = 8, many W_stars);
#   - w_star and n come from the single-w run (d = 5, n = 50);
#   - lr is the `lr = inner_lr` from the md-vs-gd comparison cell.
# Inside stationarity_condition_Q_Lij, `w - w_star` subtracts an 8-vector and a 5-vector, which
# cannot broadcast. The single-w_star condition also does not directly apply to mixed-w data.
# TODO: decide which run this should analyse and pass its variables explicitly.
X_val = np.random.randn(n, d)

sp_diffs = np.array([stationarity_condition_Q_Lij(Q, X, w0, w_star, lr) for Q in Qs])
sp_diffs_val = np.array([stationarity_condition_Q_Lij(Q, X_val, w0, w_star, lr) for Q in Qs])

mean_abs_diffs = np.mean(np.abs(sp_diffs), axis=(1,2))
max_abs_diffs = np.max(np.abs(sp_diffs), axis=(1,2))
min_abs_diffs = np.min(np.abs(sp_diffs), axis=(1,2))

mean_abs_diffs_val = np.mean(np.abs(sp_diffs_val), axis=(1,2))
max_abs_diffs_val= np.max(np.abs(sp_diffs_val), axis=(1,2))
min_abs_diffs_val= np.min(np.abs(sp_diffs_val), axis=(1,2))

#plt.plot(max_abs_diffs)
plt.plot(mean_abs_diffs, label="train")
plt.plot(mean_abs_diffs_val, label="val")
plt.legend()
#plt.plot(min_abs_diffs)
plt.yscale('log')
plt.show()


### Full-energy stationary-point condition

$$2\eta \sum_{i,j} \varepsilon_j^2 (x_i^\top Q x_j) x_{ik}x_{j\ell}  = \sum_{i,j} \varepsilon_i \varepsilon_j x_{ik} x_{j\ell}.$$

Here $\varepsilon_i = x_i^\top (w_0 - w_\star)$, and both sums run over all pairs $(i, j)$.

In [None]:
# Full (all-pairs) stationarity residual for every stored iterate in Qs, on the training X and
# on a fresh Gaussian X_val, averaged over entries and plotted per iteration.
# NOTE(review): same hidden-state mix as the L_ij cell above. Qs, X, w0 and d (= 8) come from the
# mixed-w run, w_star (length 5) and n from the single-w run, and lr from the md-vs-gd cell.
# `w - w_star` inside stationarity_condition_Q_full cannot broadcast on a fresh Run All.
# TODO: pass the intended run's variables explicitly.
X_val = np.random.randn(n, d)

sp_diffs = np.array([stationarity_condition_Q_full(Q, X, w0, w_star, lr) for Q in Qs])
sp_diffs_val = np.array([stationarity_condition_Q_full(Q, X_val, w0, w_star, lr) for Q in Qs])

mean_abs_diffs = np.mean(np.abs(sp_diffs), axis=(1,2))
max_abs_diffs = np.max(np.abs(sp_diffs), axis=(1,2))
min_abs_diffs = np.min(np.abs(sp_diffs), axis=(1,2))

mean_abs_diffs_val = np.mean(np.abs(sp_diffs_val), axis=(1,2))
max_abs_diffs_val= np.max(np.abs(sp_diffs_val), axis=(1,2))
min_abs_diffs_val= np.min(np.abs(sp_diffs_val), axis=(1,2))

#plt.plot(max_abs_diffs)
plt.plot(mean_abs_diffs, label="train")
plt.plot(mean_abs_diffs_val, label="val")
plt.legend()
#plt.plot(min_abs_diffs)
plt.yscale('log')
plt.show()

In [None]:
# Zoom in on the early iterates of the full-energy stationarity diff (training inputs only).
n_zoom = 500
fig, ax = plt.subplots()
ax.plot(mean_abs_diffs[:n_zoom], label="train")
ax.set(title=f"Full-energy stationarity: first {n_zoom} iterates", xlabel="training iterate", ylabel="mean |diff|")
ax.legend()
plt.show()

# $UU^\top$ parameterization

## Basic training loop

In [None]:
def mirror_descent_step_U(w: np.ndarray, U: np.ndarray, lr: float, x: np.ndarray, y: float) -> np.ndarray:
    """
    One mirror-descent step with the potential matrix given in factored form.

    Builds Q = U @ U.T and hands everything to `mirror_descent_step`.

    [ 7.1: this function isn't actually used right now, since the potential update computes this manually ]

    Args:
        w: Current weight vector. NOTE(review): earlier docs said (d_feature, 1) but the return
            is documented as (d_feature,); confirm which shape `mirror_descent_step` expects.
        U: Factor of the potential matrix, so that Q = U @ U.T.
        lr: Learning rate.
        x: Feature vector (d_feature,).
        y: Target value.

    Returns:
        Whatever `mirror_descent_step` returns: the updated weight vector (d_feature,).
    """
    potential = U @ U.T
    return mirror_descent_step(w, potential, lr, x, y)


def crossval_U(w: np.ndarray, U: np.ndarray, lr: float, X: np.ndarray, y: np.ndarray,
               ignore_diag=True) -> float:
    """
    Leave-one-out cross-validation loss for the factored potential Q = U @ U.T.

    Builds Q from its factor and delegates to `crossval`. For each (x_i, y_i):
    1. "train" a model with a single mirror-descent step on (x_i, y_i);
    2. evaluate it on the rest of the dataset.

    The vectorized `crossval` was tested against a non-vectorized version and performs identically.

    Args:
        w: Initial weight vector (d_feature,).
        U: Factor of the potential matrix, so that Q = U @ U.T.
        lr: Learning rate of the inner mirror-descent step.
        X: Feature matrix (n_samples, d_feature).
        y: Target vector (n_samples,).
        ignore_diag: Forwarded to `crossval`; True drops the i == j terms.

    Returns:
        Average loss over pairs (i, j), with i != j when ignore_diag is True.
    """
    potential = U @ U.T
    return crossval(w, potential, lr, X, y, ignore_diag)

def potential_update_U(w: np.ndarray, U: np.ndarray, outer_lr: float, inner_lr: float, X: np.ndarray, y: np.ndarray,
                       ignore_diag=True) -> np.ndarray:
    """
    One gradient-descent step on the factor U of the potential Q = U @ U.T.

    With residuals eps = X @ w - y and eta = inner_lr, the objective is

        E(U) = 1 / (2 * D) * sum_{i,j} L_ij^2,
        L_ij = eps_j - 2 * eta * eps_i * x_i^T U U^T x_j,

    where D is the number of (i, j) pairs kept: n^2, or n^2 - n when the diagonal is ignored.
    L_ij is the residual on sample j after one mirror-descent step on sample i, assuming that
    step is w - 2 * eta * eps_i * Q x_i. Its gradient is

        dE/dU = -(2 * eta / D) * X^T (M + M^T) X U,   with M_ij = L_ij * eps_i.

    Args:
        w: Weight vector at which the residuals are evaluated (d_feature,).
        U: Current potential factor (d_feature, r); r = d_feature gives a full-rank Q.
        outer_lr: Learning rate for updating U.
        inner_lr: Learning rate for the inner mirror descent step (eta in the formula).
        X: Feature matrix (n_samples, d_feature).
        y: Target vector (n_samples,).
        ignore_diag: Whether to drop the i == j terms from the objective. When True,
            n_samples must be at least 2, otherwise D = 0 and the normalisation divides by zero.

    Returns:
        U - outer_lr * dE/dU, with the same shape as U.
    """
    n_samples = X.shape[0]
    denom = n_samples ** 2

    errors = X @ w - y          # (n_samples,)
    XU = X @ U                  # (n_samples, r), reused for the Gram matrix and the gradient
    gram = XU @ XU.T            # X Q X^T, (n_samples, n_samples)
    # L[i, j] = errors[j] - 2 * eta * errors[i] * gram[i, j]
    L = errors[np.newaxis, :] - 2 * inner_lr * errors[:, np.newaxis] * gram

    if ignore_diag:
        np.fill_diagonal(L, 0)
        denom -= n_samples

    L_errors = L * errors[:, np.newaxis]  # M_ij = L_ij * eps_i, (n_samples, n_samples)
    grad = (-2 * inner_lr / denom) * (X.T @ (L_errors + L_errors.T) @ XU)
    assert grad.shape == U.shape, f"Potential update of shape {grad.shape} is incompatible with U shape of {U.shape}."

    # Update U using gradient descent
    return U - outer_lr * grad

In [None]:
def potential_training_loop_U(d, n, inner_lr, outer_lr, w0, U0, n_iters,
                              w_cov=None, noise_scale=0., seed=None,
                              update_ignore_diag=False):
    """
    Basic training loop for the factored potential Q = U @ U.T.

    Generates one linear dataset with `generate_linear_data`, then applies
    `potential_update_U` n_iters times starting from U0, always at the weights w0.

    Args:
        d: Feature dimension.
        n: Number of samples.
        inner_lr: Inner mirror-descent step size (eta).
        outer_lr: Step size for the updates to U.
        w0: Weight vector at which the CV objective is evaluated (d,).
        U0: Initial factor of Q. Not modified.
        n_iters: Number of updates to U.
        w_cov: Covariance used to draw w_star (identity if None).
        noise_scale: Standard deviation of the label noise.
        seed: If not None, seeds NumPy's global RNG before the data are drawn
            (generate_linear_data samples from it), making the run reproducible.
        update_ignore_diag: `ignore_diag` passed to `potential_update_U`. The default
            False keeps the i == j terms in the objective being optimized, whereas the
            recorded CV loss always excludes them.

    Returns:
        crossvals: The n_iters + 1 CV losses (diagonal excluded), before training and
            after each update.
        Us: The n_iters + 1 iterates of U stacked along axis 0, starting with U0.
        X, y, w_star: The generated data and true weight vector.
    """
    if seed is not None:
        np.random.seed(seed)
    X, y, w_star = generate_linear_data(d, n, w_cov, noise_scale)

    U = U0.copy()
    Us = [U0]
    crossvals = [crossval_U(w0, U, inner_lr, X, y, ignore_diag=True)]
    for _ in range(n_iters):
        U = potential_update_U(w0, U, outer_lr, inner_lr, X, y, ignore_diag=update_ignore_diag)
        Us.append(U)
        crossvals.append(crossval_U(w0, U, inner_lr, X, y, ignore_diag=True))

    return np.array(crossvals), np.array(Us), X, y, w_star

### Claude's numerical correctness check for `potential_update_U` 
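**TODO: verify that this check makes sense.**

One point is already checked. `compute_cv_loss` builds $L_{ij} = \varepsilon_i - 2\eta\,\varepsilon_j\, x_i^\top UU^\top x_j$. This is the transpose of the matrix $L_{ij} = \varepsilon_j - 2\eta\,\varepsilon_i\, x_i^\top UU^\top x_j$ used in `potential_update_U`. Because $XUU^\top X^\top$ is symmetric, the two have the same sum of squares. The finite-difference loss and the analytic update therefore target the same objective, $\frac{1}{2n(n-1)}\sum_{i\neq j} L_{ij}^2$.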

In [None]:
from scipy.linalg import eigvalsh

# TODO: check that this is the same as `crossval`, and replace?
def compute_cv_loss(w, U, inner_lr, X, y):
    """
    Leave-one-out cross-validation loss for Q = U @ U.T, with the i == j pairs excluded.

    With residuals e = X @ w - y and K = X @ U @ U.T @ X.T, this returns

        sum_{i != j} (e_i - 2 * inner_lr * e_j * K_ij)^2 / (2 * n * (n - 1)),

    where n = X.shape[0].
    """
    resid = X @ w - y
    kernel = X @ U @ U.T @ X.T
    pairwise = resid[:, np.newaxis] - 2 * inner_lr * resid * kernel
    np.fill_diagonal(pairwise, 0)
    n = X.shape[0]
    n_pairs = n * (n - 1)
    return np.sum(pairwise ** 2) / (2 * n_pairs)

def numerical_gradient(w, U, inner_lr, X, y, epsilon=1e-8):
    """
    Central finite-difference estimate of d(compute_cv_loss)/dU.

    Entry (a, b) is obtained by nudging U[a, b] by +epsilon and -epsilon:
        grad[a, b] = (loss(U + eps*E_ab) - loss(U - eps*E_ab)) / (2 * eps).
    This costs 2 * U.shape[0] * U.shape[1] loss evaluations.
    """
    grad = np.zeros_like(U)
    n_rows, n_cols = U.shape[0], U.shape[1]
    for a in range(n_rows):
        for b in range(n_cols):
            bumped_up = U.copy()
            bumped_up[a, b] += epsilon
            bumped_down = U.copy()
            bumped_down[a, b] -= epsilon
            loss_up = compute_cv_loss(w, bumped_up, inner_lr, X, y)
            loss_down = compute_cv_loss(w, bumped_down, inner_lr, X, y)
            grad[a, b] = (loss_up - loss_down) / (2 * epsilon)
    return grad

def test_potential_update_U(seed=0):
    """
    Check `potential_update_U` against a finite-difference gradient of `compute_cv_loss`.

    Since U_new = U - outer_lr * grad, the quantity (U_new - U) / outer_lr must equal minus the
    numerical gradient. A seeded generator makes the check reproducible on Restart & Run All.
    The weights are evaluated away from the data-generating w, so the residuals are O(1) rather
    than pure 0.1-scale label noise. Pure noise would make the gradient small and the absolute
    tolerance weak.
    """
    rng = np.random.default_rng(seed)

    # Generate random data
    n_samples, d_feature = 20, 5
    X = rng.standard_normal((n_samples, d_feature))
    w_true = rng.standard_normal(d_feature)
    y = X @ w_true + 0.1 * rng.standard_normal(n_samples)
    w = w_true + rng.standard_normal(d_feature)  # evaluation point, off the true weights
    U = rng.standard_normal((d_feature, d_feature))

    # Set learning rates
    outer_lr, inner_lr = 0.01, 0.1

    # Compute update using our function
    U_new = potential_update_U(w, U, outer_lr, inner_lr, X, y)
    actual_update = (U_new - U) / outer_lr

    # Compute numerical gradient; epsilon ~ 1e-6 balances truncation and round-off for central differences
    numerical_grad = numerical_gradient(w, U, inner_lr, X, y, epsilon=1e-6)

    # Compare the results
    assert_allclose(actual_update, -numerical_grad, rtol=1e-4, atol=1e-4,
                    err_msg="Gradient from potential_update_U doesn't match numerical gradient")

    # Q = U U^T is positive semidefinite by construction, so this only guards against
    # numerical round-off, not against errors in the update.
    Q_new = U_new @ U_new.T
    min_eigenvalue = eigvalsh(Q_new).min()
    assert min_eigenvalue >= -1e-10, f"Resulting Q is not positive semidefinite. Min eigenvalue: {min_eigenvalue}"

    print("All tests passed!")

# Run the test
test_potential_update_U()

## Graphs for single $w$

In [None]:
# Fixed initial factor U0 for the single-w experiments below. It lives in its own cell so the
# training cell can be re-run from the same initialisation. `d` is set here rather than
# inherited from earlier sections, so the matrix matches the training cell's d. The matrix is
# seeded so it is the same on every Restart & Run All.
d = 5
certified_random_matrix = np.random.default_rng(0).standard_normal((d, d))

In [None]:
# Training setup
d, n = 5, 100
inner_lr, outer_lr = 0.1, 0.01
n_iters = 1000
w_cov = np.eye(d)
w0 = np.ones(d)
U0 = certified_random_matrix
# U0 is built in a separate cell; fail early if it was built for a different d.
assert U0.shape == (d, d), f"U0 has shape {U0.shape}, expected ({d}, {d}); re-run the cell defining certified_random_matrix."

# Training loop
crossvals, Us, X, y, w_star = potential_training_loop_U(d, n, inner_lr, outer_lr, w0, U0, n_iters, w_cov)

# Q = U U^T at every iterate, for plotting
UUTs = np.array([U @ U.T for U in Us])

# Plot crossvals
plt.plot(crossvals)
plt.title(f"CV loss: d={d}, n={n}")
plt.xlabel("iteration")
plt.ylabel("CV loss")
plt.yscale('log')
plt.show()

plot_matrix_evolution(Us, main_title=f"Evolution of $U$ during training: d={d}, n={n}")
plot_matrix_evolution(UUTs, main_title=f"Evolution of $UU^\\top$ during training: d={d}, n={n}")

In [None]:
# Later part of training for UU^T, plus the asymmetry U - U^T over the whole run.
late_start = 250
plot_matrix_evolution(UUTs[late_start:], 5, main_title=f"Evolution of $UU^\\top$ from iteration {late_start}")
plot_matrix_evolution(Us - Us.transpose(0, 2, 1), 5, main_title="Asymmetry during training: $U - U^\\top$")

## Graphs for mixed $w$ / recovering covariance