
In [None]:
"""
In this notebook, we take a very simple look at the rectangular Newton-Schulz iteration
as presented in:

Bernstein, J., & Newhouse, L. (2024). Modular Duality in Deep Learning. arXiv preprint arXiv:2410.21265.

This paper, and several preceding works, e.g.,

Bernstein, J., & Newhouse, L. (2024). Old optimizer, new norm: An anthology. arXiv preprint arXiv:2409.20325.

and

Yang, G., Simon, J. B., & Bernstein, J. (2023). A spectral condition for feature learning. arXiv preprint arXiv:2310.17813.

provide a rather fascinating look at architecture-adapted optimization algorithms for neural networks. The overarching
principle seems to be to keep the layer-wise weight matrices well-conditioned, or well-normed, at initialization
and throughout training. This is accomplished by ensuring that the input and output vectors of each layer are appropriately sized,
where "size" is measured by the norm of the corresponding vector space. The general rule is that a vector in a d-dimensional space
should have Euclidean norm of approximately \sqrt{d} (equivalently, entries of order one in root-mean-square), which is consistent
with what standard normalization schemes such as BatchNorm, LayerNorm, and RMSNorm enforce.

From my reading, this perspective seems quite distinct from, and complementary to, other work in the field, which focuses on the
training dynamics of stochastic gradient methods and their implicit biases (e.g., toward low-"complexity" solutions). There is a lot
more to explore in this space.
"""

In [None]:
import numpy as np
import matplotlib.pyplot as plt

In [None]:
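def Newton_Schulz(A: np.ndarray, n_steps: int = 15, eps: float = 1e-9) -> np.ndarray:
    """
    Rectangular Newton-Schulz iteration.

    Approximates the orthogonal polar factor U V' of A (where A = U S V' is a
    thin SVD) using only matrix products, without computing an SVD.
    """
    # Normalize by the Frobenius norm. Since ||A||_2 <= ||A||_F, every singular
    # value of X then lies in [0, 1], inside the interval (0, \sqrt{3}) on which
    # the iteration drives nonzero singular values to 1.
    X = A / (eps + np.linalg.norm(A, ord='fro'))
    a = 3 / 2
    b = 1 / 2
    for _ in range(n_steps):
        X = a * X - b * X @ X.T @ X
    return X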

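# Note: to see what the above iteration is doing, consider the thin SVD X = U S V'.
# Then X X' X = U S V' V S U' U S V' = U S^3 V', so the update equals
# U (a*S - b*S^3) V': the cubic polynomial p(s) = a*s - b*s^3 is applied to each
# singular value independently, while the singular vectors U and V are unchanged.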
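Because each step only reshapes the singular values, its behaviour can be read off from the scalar map p(s) = (3/2)s - (1/2)s^3. Its fixed points are 0 and ±1, p'(1) = 0 (so convergence near 1 is quadratic), and iterates started anywhere in (0, \sqrt{3}) converge to 1. This is why the normalization only needs to place the singular values in that interval. A value just above \sqrt{3} is mapped to a negative number and ends up at -1, which would flip the sign of the corresponding singular direction. In this sketch the helper name `p` and the starting values are illustrative choices, not part of the original notebook.

```python
import numpy as np

# Scalar version of one Newton-Schulz step, with a = 3/2 and b = 1/2 as in Newton_Schulz.
def p(s):
    return 1.5 * s - 0.5 * s**3

# Starting values: five inside (0, sqrt(3)) ~ (0, 1.732) and one just outside.
s0 = np.array([0.05, 0.5, 1.0, 1.5, 1.7, 1.8])
s = s0.copy()
for _ in range(15):
    s = p(s)

# Expected: approximately [1, 1, 1, 1, 1, -1]; 1.8 is first mapped to about -0.216.
print(np.round(s, 6))
```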

In [None]:
n = 100
d = 500
A = 1/np.sqrt(d) * np.random.randn(n,d)
Ua, Sa, Vha = np.linalg.svd(A, full_matrices=False)
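It helps to know roughly where the singular values of this test matrix lie. For A = randn(n, d) / sqrt(d) with n < d, the Marchenko-Pastur law predicts that they concentrate in about [1 - sqrt(n/d), 1 + sqrt(n/d)], and the Frobenius norm is close to sqrt(n). So after the normalization inside Newton_Schulz the iteration starts from singular values of roughly 0.05 to 0.15, which each step grows by about a factor of 3/2 until they approach 1. A small check under those assumptions, with its own seeded matrix:

```python
import numpy as np

# Gaussian test matrix with the same shape and scaling as in the notebook.
rng = np.random.default_rng(0)
n, d = 100, 500
A = rng.standard_normal((n, d)) / np.sqrt(d)
S = np.linalg.svd(A, compute_uv=False)  # singular values only

# Marchenko-Pastur prediction for the range of singular values (large n, d).
lo, hi = 1 - np.sqrt(n / d), 1 + np.sqrt(n / d)
print(f"observed  [{S.min():.3f}, {S.max():.3f}]")
print(f"predicted [{lo:.3f}, {hi:.3f}]")

# Frobenius normalization, as in Newton_Schulz: ||A||_F is about sqrt(n) = 10 here.
S_scaled = S / np.linalg.norm(A, ord="fro")
print(f"after Frobenius scaling [{S_scaled.min():.3f}, {S_scaled.max():.3f}]")
```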

In [None]:
X = Newton_Schulz(A)
Ux, Sx, Vhx = np.linalg.svd(X, full_matrices=False)

In [None]:
# Relative Frobenius error: how close is X to the polar factor U V'?
rel_error = np.linalg.norm(X - Ua@Vha) / np.linalg.norm(Ua@Vha)
print(rel_error)
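The single number printed above hides how quickly the iteration converges. The following sketch reruns the same cubic update one step at a time and records the relative error to U V' after every step. It uses its own seeded matrix, so the values will differ slightly from the cell above, and the choice of 20 steps is arbitrary. While the singular values are small each step multiplies them by roughly 3/2; once they are near 1 the error e roughly becomes 1.5 e^2 per step.

```python
import numpy as np

# Seeded test matrix with the same shape and scaling as in the notebook.
rng = np.random.default_rng(0)
n, d = 100, 500
A = rng.standard_normal((n, d)) / np.sqrt(d)
U, S, Vh = np.linalg.svd(A, full_matrices=False)
target = U @ Vh  # orthogonal polar factor of A

# Same normalization and update as Newton_Schulz, recording the error per step.
X = A / np.linalg.norm(A, ord="fro")
errors = []
for _ in range(20):
    X = 1.5 * X - 0.5 * X @ X.T @ X
    errors.append(np.linalg.norm(X - target) / np.linalg.norm(target))

for step, err in enumerate(errors, start=1):
    print(f"step {step:2d}: rel. error = {err:.2e}")
```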

In [None]:
fig, ax = plt.subplots(figsize=(5, 5))
bins = np.linspace(min(Sa.min(), Sx.min()), max(Sa.max(), Sx.max()), 20)
ax.hist(Sa, bins=bins, color='b', alpha=0.5, label='Pre-NS singular values')
ax.hist(Sx, bins=bins, color='r', alpha=0.5, label='Post-NS singular values')
ax.set_xlabel('Singular values')
ax.set_ylabel('Counts')
ax.set_title('Singular value distributions')
ax.legend()
plt.show()
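One practical detail: Newton_Schulz evaluates X @ X.T @ X left to right, so it first forms the n x n matrix X @ X.T. That is cheap for the wide matrix used here (n = 100 < d = 500) but wasteful for a tall matrix. A possible variant, our own sketch rather than something from the papers cited above, iterates on the transpose of a tall input so that the Gram matrix is always min(n, d) x min(n, d); reassociating as X @ (X.T @ X) would work equally well. The function name `newton_schulz_gram` is hypothetical.

```python
import numpy as np

def newton_schulz_gram(A: np.ndarray, n_steps: int = 15, eps: float = 1e-9) -> np.ndarray:
    """Hypothetical variant of Newton_Schulz that always forms the smaller Gram
    matrix: a tall A (more rows than columns) is transposed, iterated on, and
    transposed back. The polar factor of A' is V U', whose transpose is U V'."""
    transposed = A.shape[0] > A.shape[1]
    X = A.T if transposed else A
    X = X / (eps + np.linalg.norm(X, ord="fro"))
    for _ in range(n_steps):
        gram = X @ X.T  # min(n, d) x min(n, d)
        X = 1.5 * X - 0.5 * gram @ X
    return X.T if transposed else X

# Usage on a tall matrix: compare with the polar factor from an explicit SVD.
rng = np.random.default_rng(1)
B = rng.standard_normal((500, 100)) / np.sqrt(500)
U, S, Vh = np.linalg.svd(B, full_matrices=False)
Q = newton_schulz_gram(B)
print(np.linalg.norm(Q - U @ Vh) / np.linalg.norm(U @ Vh))
```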