<a href="https://colab.research.google.com/github/RiaStevens/rmt-research/blob/main/RMT_Sim_Plots.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Random Matrix Theory Simulations with Neural Networks

Within this Jupyter notebook, we analyze the eigenvalue structure of the Hessians of a single-ReLU layer neural network throughout gradient descent and how this relates to its loss. For quick results, run "Imports" and "Key Function Definitions" as sections before plotting.

We set up our neural network using the following parameters:
* $n_0$ represents the size of the input layer (the number of input features). This is customizable.
* $n_1$ represents the size of the hidden ReLU layer. This is customizable.
* $n_2$ represents the size of the output layer. This is customizable.
* $m$ represents the number of data points passed into the neural network. This is customizable.
* $W^{(1)}$ represents the first weight matrix in the network. It is an $n_1$ x $n_0$ matrix with each entry being Gaussian with mean 0 and variance $\sigma^2$ at initialization.
* $W^{(2)}$ represents the second weight matrix in the network. It is an $n_2$ x $n_1$ matrix with each entry being Gaussian with mean 0 and variance $\sigma^2$ at initialization.
* $x$ represents our input data. It is an $n_0$ x $m$ matrix with Gaussian entries with mean 0 and variance $\sigma^2$.
* $y$ represents our target data. It is an $n_2$ x $m$ matrix with Gaussian entries with mean 0 and variance $\sigma^2$.

To generate predictions given $W^{(1)}$ and $W^{(2)}$ at some time $t$, we define $z = W^{(1)}x$ and use $[A_{xy}]_+ = \text{max}(A_{xy}, 0)$ as our ReLU function, where $A$ is any matrix. This gives us the following output, $\hat{y}$:
\begin{equation}
  \hat{y}_{i\mu} =  \sum_{k=1}^{n_1} W^{(2)}_{ik} [z_{k\mu}]_+
\end{equation}
Using $\hat{y}$, we define $e = \hat{y} - y$ and continue by defining our loss function, $\mathcal{L}$, as the mean squared error using $e$:
\begin{equation}
  \mathcal{L} = \frac{1}{2m} \|e\|^2 =  \frac{1}{2m} \sum_{i=1}^{n_2} \sum_{\mu=1}^{m} e_{i\mu}^2
\end{equation}

# Imports

In [69]:
import numpy as np
import itertools
import matplotlib.pyplot as plt
import numpy.linalg as npla
from scipy.optimize import curve_fit
from ipywidgets import widgets, interactive
from IPython.utils import io

# Key Function Definitions

## Simulation Definitions

In [70]:
def gradient(x, W_2, e, prop):
  """Return the gradients of the loss (1 / 2m) * ||e||^2 w.r.t. W1 and W2.

  With y_hat = W_2 @ relu(W1 @ x) and e = y_hat - y, the chain rule gives
    dL/dW1 = ((W_2^T @ e) * 1[prop > 0]) @ x^T / m
    dL/dW2 = e @ prop^T / m

  Args:
    x: input data, shape (n0, m).
    W_2: second-layer weights, shape (n2, n1).
    e: residuals y_hat - y, shape (n2, m).
    prop: hidden-layer ReLU activations, shape (n1, m).

  Returns:
    (A, B): A is the gradient for W1 with shape (n1, n0); B is the gradient
    for W2 with shape (n2, n1).
  """
  m = x.shape[1]

  # Back-propagate the residual through W_2; a hidden unit only passes
  # gradient for the samples where its ReLU is active.
  delta = (W_2.T @ e) * (prop > 0)
  A = delta @ x.T / m

  # Output i depends only on row i of W_2, so row i of the gradient uses
  # only e[i]. (Building the Jacobian from an all-ones (n2, n2) factor
  # instead of the identity would sum every output's residual into every row.)
  B = e @ prop.T / m
  return A, B

In [71]:
def gradient_descent(W1, W2, x, y, num_iters, step, prop, e, id):
  """Run `num_iters` full-batch gradient-descent updates on W1 and W2.

  Each iteration takes one step along the negative gradient (as computed by
  `gradient` from the current residuals and activations), then recomputes the
  forward pass and records the loss ||e||^2 / (2m) reached after that step.

  Args:
    W1, W2: current weight matrices; they are rebound, not modified in place.
    x, y: input and target data; m is taken from x.shape[1].
    num_iters: number of update steps.
    step: learning rate.
    prop, e: activations and residuals matching the incoming W1 and W2.
    id: unused inside this function.

  Returns:
    (W1, W2, losses, e, prop): the final weights, a 1-D array holding one loss
    per iteration, and the residuals / activations of the final forward pass.
  """
  _, n_samples = x.shape

  history = []
  for _ in range(num_iters):
    grad_W1, grad_W2 = gradient(x, W2, e, prop)
    W1 = W1 - grad_W1 * step
    W2 = W2 - grad_W2 * step

    # Forward pass with the updated weights
    prop = np.maximum(W1 @ x, 0)
    e = W2 @ prop - y

    history.append(npla.norm(e) ** 2 / (2 * n_samples))

  losses = np.asarray(history, dtype=float)
  return W1, W2, losses, e, prop

In [72]:
def hessian(W1, W2, x, y, l, l0):
  """Build the two-term decomposition H = H0 + H1 of the loss Hessian.

  Parameters are flattened as (W1 in row-major order, then W2 in row-major
  order), so every returned matrix is square with side n1 * n0 + n2 * n1.

  Args:
    W1: first-layer weights, shape (n1, n0).
    W2: second-layer weights, shape (n2, n1).
    x: input data, shape (n0, m).
    y: target data, shape (n2, m).
    l: loss value used as the numerator of the H1 scale factor.
    l0: loss value used as the denominator of the H1 scale factor.

  Returns:
    (H0, H1, H):
      H0 = J @ J^T, where J is the Jacobian of y_hat w.r.t. all weights.
      H1 = the residual-weighted second-derivative term, divided by m and
           multiplied by sqrt(l / l0).
      H  = H0 + H1.

  NOTE(review): H0 is not divided by m while H1 is, and H1 carries the extra
  sqrt(l / l0) factor. This scaling is kept exactly as it was; confirm it is
  the intended normalisation.
  """
  (n2, n1) = W2.shape
  (n0, m) = x.shape

  # Forward pass
  prop = np.maximum(W1 @ x, 0)
  e = W2 @ prop - y
  # 1 where the hidden ReLU is active for a given (unit, sample), else 0
  active = (prop > 0).astype(float)

  def compute_H0():
    # d y_hat[i,u] / d W1[a,b] = W2[i,a] * 1[prop[a,u] > 0] * x[b,u]
    J_1 = np.einsum('ia,au,bu->iabu', W2, active, x)
    # d y_hat[i,u] / d W2[c,d] = delta_{ic} * prop[d,u]; the identity (not an
    # all-ones matrix) encodes that output i only depends on row i of W2.
    J_2 = np.tensordot(np.eye(n2), prop, axes=0)

    # Rows index a flattened weight, columns index a flattened (output, sample)
    J_1_flat = J_1.transpose(1, 2, 0, 3).reshape(n1 * n0, n2 * m)
    J_2_flat = J_2.transpose(1, 2, 0, 3).reshape(n2 * n1, n2 * m)
    J = np.vstack([J_1_flat, J_2_flat])

    # One matrix product yields all four blocks [[tl, tr], [bl, br]]
    return J @ J.T

  def compute_H1():
    # y_hat is piecewise linear in W1 and linear in W2, so only the mixed
    # W1/W2 second derivatives are non-zero:
    #   d^2 y_hat[i,u] / (d W1[a,b] d W2[c,d])
    #     = delta_{ic} * delta_{ad} * 1[prop[a,u] > 0] * x[b,u]
    # cross[a, b, c] = sum_u 1[prop[a,u] > 0] * x[b,u] * e[c,u]
    cross = np.einsum('au,bu,cu->abc', active, x, e)

    # Scatter onto the d == a diagonal of a (n1, n0, n2, n1) tensor
    H_1_tr = np.zeros((n1, n0, n2, n1))
    hidden = np.arange(n1)
    H_1_tr[hidden, :, :, hidden] = cross
    H_1_tr = H_1_tr.reshape(n1 * n0, n2 * n1)

    H_1_tl = np.zeros((n1 * n0, n1 * n0))
    H_1_br = np.zeros((n1 * n2, n1 * n2))
    H_1_bl = H_1_tr.transpose()

    H_1 = np.block([[H_1_tl, H_1_tr], [H_1_bl, H_1_br]])
    return H_1 / m * np.sqrt(l / l0)

  H0, H1 = compute_H0(), compute_H1()
  H = H0 + H1

  return H0, H1, H

In [73]:
def run_gd(n0, n1, n2, m, sigma, num_iters, step=0.0005):
  """Draw a random network and data set, then train it by gradient descent.

  Weights, inputs and targets are all sampled as standard normals scaled by
  `sigma`. The initial forward pass is computed here because
  `gradient_descent` expects the activations and residuals that match the
  starting weights.

  Args:
    n0, n1, n2: sizes of the input, hidden and output layers.
    m: number of data points.
    sigma: standard deviation of every random entry.
    num_iters: number of gradient-descent steps.
    step: learning rate; defaults to 0.0005, the value that was previously
      fixed inside this function.

  Returns:
    (losses, W1_gd, W2_gd, x, y): the loss history, the trained weights and
    the data they were trained on.
  """
  # The sampling order (W1, W2, x, y) is kept so seeded runs stay reproducible
  W1 = np.random.randn(n1, n0) * sigma
  W2 = np.random.randn(n2, n1) * sigma

  x = np.random.randn(n0, m) * sigma
  y = np.random.randn(n2, m) * sigma

  # Initial forward pass
  prop = np.maximum(W1 @ x, 0)
  e = W2 @ prop - y

  W1_gd, W2_gd, losses, _, _ = gradient_descent(W1, W2, x, y, num_iters, step, prop, e, 1)

  return losses, W1_gd, W2_gd, x, y

In [74]:
def run_sim(n0, n1, n2, m, sigma, num_runs):
  num_iters = 5000
  H0_eigs_arr, H1_eigs_arr, H_eigs_arr = np.empty(0), np.empty(0), np.empty(0)
  phi = ((n0 + n2) * n1) / (n2 * m)
      
  loss_array = np.empty((num_runs, num_iters))
  for j in range(0, num_runs):
    losses, W1, W2, x, y = run_gd(n0, n1, n2, m, sigma, num_iters)

    # find hessians of W1, W2
    H0, H1, H = hessian(W1, W2, x, y, losses[-1], losses[0])
    eigs_H0, eigs_H1, eigs_H = npla.eigvalsh(H0), npla.eigvalsh(H1), npla.eigvalsh(H)

    H0_eigs_arr = np.concatenate((H0_eigs_arr, eigs_H0), axis=None)
    H1_eigs_arr = np.concatenate((H1_eigs_arr, eigs_H1), axis=None)
    H_eigs_arr = np.concatenate((H_eigs_arr, eigs_H), axis=None)
    
    loss_array[j] = losses
  return H0_eigs_arr, H1_eigs_arr, H_eigs_arr, loss_array

## Plot Definitions

In [75]:
def plot_losses(losses, phi, num_iters, num_runs):
  mean_losses = np.average(losses, axis = 0)
  std_losses = np.std(losses, axis=0)
  ci = 1.96 * std_losses / np.sqrt(num_runs)
  
  plt.yscale('log')
  plt.title('Loss Function throughout Gradient Descent')
  plt.ylabel('Loss')
  plt.xlabel('Epoch')

  plt.plot(np.arange(1, len(mean_losses) + 1), mean_losses, label='$\phi$ = '+ str(round(phi, 1)))
  plt.fill_between(np.arange(1, len(mean_losses) + 1), mean_losses + ci, mean_losses - ci, alpha=.2)
  plt.annotate(mean_losses[num_iters - 1], (num_iters - 1, mean_losses[num_iters - 1]))
  plt.legend()
  plt.show()

In [76]:
def plot_eigs(H0_eigs, H1_eigs, H_eigs):
  plt.title("H0's Eigenvalues")
  plt.ylabel('Density')
  plt.xlabel('Eigenvalue')
  plt.hist(H0_eigs[(abs(H0_eigs) < 100) & (abs(H0_eigs) > 0.01)], bins=50, density=True, ec='black')
  plt.show()
  
  plt.title("H1's Eigenvalues")
  plt.ylabel('Density')
  plt.xlabel('Eigenvalue')
  plt.hist(H1_eigs[(abs(H1_eigs) < 100) & (abs(H1_eigs) > 0.0001)], bins=50, density=True, ec='black')
  plt.show()

  plt.title("H's Eigenvalues")
  plt.ylabel('Density')
  plt.xlabel('Eigenvalue')
  plt.hist(H_eigs[(abs(H_eigs) < 100) & (abs(H_eigs) > 0.0001)], bins=50, density=True, ec='black')
  plt.show()

# Plotting

In [77]:
n0_widget = widgets.SelectionSlider(options=[10, 20, 50], 
                             value=20, 
                             description='n0', 
                             disabled=False, 
                             continuous_update=True, 
                             orientation='horizontal', 
                             readout=True)
n1_widget = widgets.SelectionSlider(options=[10, 20, 50], 
                             value=20, 
                             description='n1', 
                             disabled=False, 
                             continuous_update=True, 
                             orientation='horizontal', 
                             readout=True)
n2_widget = widgets.SelectionSlider(options=[10, 20, 50], 
                             value=20, 
                             description='n2', 
                             disabled=False, 
                             continuous_update=True, 
                             orientation='horizontal', 
                             readout=True)
m_widget = widgets.SelectionSlider(options=[10, 50, 200], 
                             value=50, 
                             description='m', 
                             disabled=False, 
                             continuous_update=True, 
                             orientation='horizontal', 
                             readout=True)
num_iters_widget = widgets.SelectionSlider(options=[0, 10, 100, 1000, 2500, 5000, 10000], 
                             value=2500, 
                             description='num_iters', 
                             disabled=False, 
                             continuous_update=True, 
                             orientation='horizontal', 
                             readout=True)

def plot_it(n0, n1, n2, m, num_iters):
  string = f'n0 = {n0}, n1 = {n1}, n2 = {n2}, m = {m}, num_iters = {num_iters}'
  print('updating the plot with ' + string)

  sigma = 1
  num_runs = 2
  H0_eigs, H1_eigs, H_eigs, losses = run_sim(n0, n1, n2, m, sigma, num_runs)

  phi = ((n0 + n2) * n1) / (n2 * m)
  plot_losses(losses, phi, num_iters, num_runs)
  plot_eigs(H0_eigs, H1_eigs, H_eigs)

interactive(plot_it, n0=n0_widget, n1=n1_widget, n2=n2_widget, m=m_widget, num_iters=num_iters_widget)

interactive(children=(SelectionSlider(description='n0', index=1, options=(10, 20, 50), value=20), SelectionSli…