In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

# How to use the state evolution package

Let's look at a simple example of how to use the state evolution package with custom teacher-student covariance matrices. The package has three components:
- `data_model`: this class defines everything concerning the generative model for data - i.e. it initialises the covariances $\Psi, \Phi, \Omega$ and pre-computes all the quantities required for the state evolution.
- `model`: this class defines the task. It contains the updates for the overlaps and their conjugates. So far, ridge and logistic regression are implemented.
- `algorithms`: this class defines the iterator for the state evolution.

In [2]:
from state_evolution.models.logistic_regression import LogisticRegression # logistic regression task
from state_evolution.algorithms.state_evolution import StateEvolution # Standard SP iteration

## Example 1: custom data model

Let's look at a simple example where we input the covariances.

In [3]:
from state_evolution.data_models.custom import Custom # Custom data model. You input the covariances

Recall that the G$^3$M is defined by a teacher-student model with:
- Teacher : $y = f^{0}(\theta^{0}\cdot z)$, $\theta^{0}\sim\mathcal{N}(0,\rm{I}_{k})$
- Student : $\hat{y} = \hat{f}(w\cdot x)$
where $z\in\mathbb{R}^{k}$ and $x\in\mathbb{R}^{p}$ are jointly Gaussian variables with covariances
$$ \Psi = \mathbb{E}zz^{\top}\in\mathbb{R}^{k\times k}, \qquad \Phi = \mathbb{E}xz^{\top}\in\mathbb{R}^{p\times k}, \qquad \Omega = \mathbb{E}xx^{\top}\in\mathbb{R}^{p\times p}. $$
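The three matrices are the blocks of the joint covariance of the stacked vector $(z, x)$, which must therefore be positive semi-definite. The sketch below samples from a G$^3$M given $(\Psi, \Phi, \Omega)$. The helpers `g3m_sample` and `toy_linear_g3m` are hypothetical names, and the toy model $x = Az + \text{noise}$ with $z\sim\mathcal{N}(0,\rm{I}_{k})$ and a sign teacher is an illustrative assumption, not part of the package.

```python
import numpy as np

def g3m_sample(Psi, Phi, Omega, n, rng):
    """Draw n jointly Gaussian pairs (z, x) with E zz^T = Psi, E xz^T = Phi, E xx^T = Omega."""
    k, p = Psi.shape[0], Omega.shape[0]
    # Joint covariance of the stacked vector (z, x); it must be positive semi-definite
    joint = np.block([[Psi, Phi.T], [Phi, Omega]])
    zx = rng.multivariate_normal(np.zeros(k + p), joint, size=n)
    return zx[:, :k], zx[:, k:]

def toy_linear_g3m(k, p, noise, rng):
    """Hypothetical toy G3M: x = A z + Gaussian noise, with z ~ N(0, I_k)."""
    A = rng.standard_normal((p, k)) / np.sqrt(k)
    Psi = np.eye(k)                                # E[z z^T]
    Phi = A @ Psi                                  # E[x z^T] = A E[z z^T], shape (p, k)
    Omega = A @ Psi @ A.T + noise**2 * np.eye(p)   # E[x x^T], shape (p, p)
    return Psi, Phi, Omega

rng = np.random.default_rng(0)
Psi_toy, Phi_toy, Omega_toy = toy_linear_g3m(k=5, p=8, noise=0.5, rng=rng)
z_toy, x_toy = g3m_sample(Psi_toy, Phi_toy, Omega_toy, n=200_000, rng=rng)
theta0 = rng.standard_normal(Psi_toy.shape[0])    # teacher weights, theta0 ~ N(0, I_k)
y_toy = np.sign(z_toy @ theta0)                   # teacher labels with f0 = sign
```

The empirical moments `z_toy.T @ z_toy / n`, `x_toy.T @ z_toy / n` and `x_toy.T @ x_toy / n` should match `Psi_toy`, `Phi_toy` and `Omega_toy` up to Monte Carlo error of order $n^{-1/2}$.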

The class `Custom` takes as input the three covariance matrices that define a G$^3$M. 

As an example, let's look at a simple model where both the teacher and student come from a hidden-manifold model with different dimensions and activation functions:
$$
z = \rm{sign}\left(\frac{1}{\sqrt{d}}\bar{\rm{F}}c\right), \qquad x = \rm{erf}\left(\frac{1}{\sqrt{d}}\rm{F}c\right), \qquad c\sim\mathcal{N}(0,\rm{I}_{d})
$$

In this case, recall that the covariances can be computed analytically: to leading order in high dimensions, they are given by

 \begin{align}
 \Psi = \frac{\bar{\kappa}_{1}^2}{d} \bar{\rm{F}}\bar{\rm{F}}^{\top}+\bar{\kappa}_{\star}^2\rm{I}_{k}, && \Phi = \frac{\bar{\kappa}_{1}\kappa_{1}}{d} \rm{F}\bar{\rm{F}}^{\top}, && \Omega = \frac{\kappa_{1}^2}{d} \rm{F}\rm{F}^{\top}+\kappa_{\star}^2\rm{I}_{p}
 \end{align}
 
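with $\kappa_{0} \equiv \mathbb{E}\left[\sigma(\xi)\right]$, $\kappa_{1} \equiv \mathbb{E}\left[\xi\sigma(\xi)\right]$ and $\kappa_{\star}^2 \equiv \mathbb{E}\left[\sigma(\xi)^2\right]-\kappa_{0}^2-\kappa_{1}^2$ for $\xi\sim\mathcal{N}(0,1)$, and likewise for the barred (teacher) coefficients. Both sign and erf are odd, so $\kappa_{0}=\bar{\kappa}_{0}=0$ and no $\kappa_{0}$ term appears above.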
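The coefficients in `COEFICIENTS` can be checked numerically. A minimal sketch, assuming Gauss-Hermite quadrature (our choice, not necessarily how the tabulated values were obtained); `gaussian_coefficients` is a hypothetical helper. Quadrature is accurate for smooth activations such as erf and tanh, but much less so for the discontinuous sign or the kinked relu, whose closed forms are used in the table instead.

```python
import math
import numpy as np

def gaussian_coefficients(sigma, n_nodes=80):
    """Return (kappa0, kappa1, kappastar) of an activation sigma, for xi ~ N(0, 1).

    Probabilists' Gauss-Hermite nodes integrate against exp(-xi^2 / 2).
    """
    xi, weights = np.polynomial.hermite_e.hermegauss(n_nodes)
    weights = weights / np.sqrt(2 * np.pi)   # weights now sum to 1 (standard Gaussian measure)
    s = sigma(xi)
    kappa0 = np.sum(weights * s)              # E[sigma(xi)]
    kappa1 = np.sum(weights * xi * s)         # E[xi sigma(xi)]
    kappastar = np.sqrt(np.sum(weights * s**2) - kappa0**2 - kappa1**2)
    return kappa0, kappa1, kappastar

erf = np.vectorize(math.erf)              # elementwise erf without requiring scipy
print(gaussian_coefficients(erf))         # compare with COEFICIENTS['erf']
print(gaussian_coefficients(np.tanh))     # compare with COEFICIENTS['tanh']
```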

In [4]:
# Gaussian coefficients (kappa0, kappa1, kappastar) of each activation, for xi ~ N(0, 1)
COEFICIENTS = {'relu': (1/np.sqrt(2*np.pi), 0.5, np.sqrt((np.pi-2)/(4*np.pi))), 
               'erf': (0, 2/np.sqrt(3*np.pi), 0.200364), 'tanh': (0, 0.605706, 0.165576),
               'sign': (0, np.sqrt(2/np.pi), np.sqrt(1-2/np.pi))}

In [5]:
d = 1000 # dimension of c
p = 2000 # dimension of x
k = 1000 # dimension of z

F_teacher = np.random.normal(0,1, (d,k)) / np.sqrt(d) # teacher projection
F_student = np.random.normal(0,1, (d,p)) / np.sqrt(d) # student projection

# Coefficients
_, kappa1_teacher, kappastar_teacher = COEFICIENTS['sign']
_, kappa1_student, kappastar_student = COEFICIENTS['erf']

# Covariances
Psi = (kappa1_teacher**2 * F_teacher.T @ F_teacher + kappastar_teacher**2 * np.identity(k))
Omega = (kappa1_student**2 * F_student.T @ F_student + kappastar_student**2 * np.identity(p))
Phi = kappa1_teacher * kappa1_student * F_student.T @ F_teacher
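These formulas are leading-order approximations, so a rough Monte Carlo check at small sizes can be reassuring. The sketch below rests on assumptions of our own: the helper `hidden_manifold_check` is hypothetical, the dimensions are small, and the projection columns are rescaled to unit norm so that each pre-activation is exactly $\mathcal{N}(0,1)$ (unlike `F_teacher` and `F_student` above, whose columns only have norm close to 1). It samples sign teacher and erf student features from a shared $c$ and compares their empirical covariances with the closed forms.

```python
import math
import numpy as np

def hidden_manifold_check(d=100, k=30, p=40, n_samples=40_000, seed=0):
    """Max absolute gap between empirical and closed-form covariances of a small
    hidden-manifold model with sign teacher and erf student (illustrative sketch)."""
    rng = np.random.default_rng(seed)
    # Unit-norm columns (an assumption of this sketch): c @ F[:, i] is exactly N(0, 1)
    F_t = rng.standard_normal((d, k))
    F_t /= np.linalg.norm(F_t, axis=0)
    F_s = rng.standard_normal((d, p))
    F_s /= np.linalg.norm(F_s, axis=0)

    c = rng.standard_normal((n_samples, d))    # latent vectors c ~ N(0, I_d)
    z = np.sign(c @ F_t)                       # teacher features, shape (n_samples, k)
    x = np.vectorize(math.erf)(c @ F_s)        # student features, shape (n_samples, p)

    # (kappa1, kappastar) for sign and erf, same values as in COEFICIENTS
    k1_t, ks_t = np.sqrt(2 / np.pi), np.sqrt(1 - 2 / np.pi)
    k1_s, ks_s = 2 / np.sqrt(3 * np.pi), 0.200364

    theory = {
        "Psi": k1_t**2 * F_t.T @ F_t + ks_t**2 * np.eye(k),
        "Phi": k1_t * k1_s * F_s.T @ F_t,
        "Omega": k1_s**2 * F_s.T @ F_s + ks_s**2 * np.eye(p),
    }
    empirical = {
        "Psi": z.T @ z / n_samples,
        "Phi": x.T @ z / n_samples,
        "Omega": x.T @ x / n_samples,
    }
    return {name: float(np.abs(empirical[name] - theory[name]).max()) for name in theory}

gaps = hidden_manifold_check()
print(gaps)
```

The gaps should be small, roughly a few hundredths at these sizes, coming from Monte Carlo noise and the neglected higher-order terms. Without the unit-norm rescaling, the diagonal entries would also pick up deviations of order $d^{-1/2}$.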

Now that we have our covariances, we can create our instance of `Custom`:

In [6]:
data_model = Custom(teacher_teacher_cov = Psi, 
                    student_student_cov = Omega, 
                    teacher_student_cov = Phi)

Now we need to load our task; let's look at logistic regression. The `model` class takes as input the sample complexity $\alpha = n/p$ and the $\ell_2$ regularisation $\lambda>0$.

In [7]:
task = LogisticRegression(sample_complexity = 0.5,
                          regularisation = 0.01,
                          data_model = data_model)
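For intuition about what the state evolution predicts, the same setting can be simulated at finite size. The helper `simulate_logistic_regression` below is hypothetical and rests on assumptions the package does not confirm here: a sign teacher $y=\rm{sign}(\theta^{0}\cdot z)$, $n = \alpha p$ training samples, and the estimator $\hat{w}=\arg\min_{w}\sum_{\mu=1}^{n}\log(1+e^{-y^{\mu}w\cdot x^{\mu}})+\frac{\lambda}{2}\|w\|^{2}$, minimised with Newton's method plus a backtracking line search.

```python
import numpy as np

def simulate_logistic_regression(Psi, Phi, Omega, alpha, lam, n_test=5_000, seed=0, max_newton=100):
    """Finite-size estimate of the classification test error on a G3M (illustrative sketch)."""
    rng = np.random.default_rng(seed)
    k, p = Psi.shape[0], Omega.shape[0]
    n = int(round(alpha * p))
    joint = np.block([[Psi, Phi.T], [Phi, Omega]])   # covariance of the stacked (z, x)
    theta0 = rng.standard_normal(k)                  # teacher weights (assumed sign teacher)

    def sample(m):
        zx = rng.multivariate_normal(np.zeros(k + p), joint, size=m)
        return zx[:, k:], np.sign(zx[:, :k] @ theta0)   # (x, y) with y in {-1, +1}

    X, y = sample(n)

    def objective(w):
        # logaddexp(0, -m) = log(1 + exp(-m)), computed without overflow
        return np.logaddexp(0.0, -y * (X @ w)).sum() + 0.5 * lam * w @ w

    w = np.zeros(p)
    for _ in range(max_newton):
        s = 0.5 * (1.0 - np.tanh(0.5 * y * (X @ w)))     # sigmoid(-margin), overflow-free
        grad = -X.T @ (y * s) + lam * w
        hess = (X.T * (s * (1.0 - s))) @ X + lam * np.eye(p)
        step = np.linalg.solve(hess, grad)                # Newton direction
        t, f0 = 1.0, objective(w)
        # Backtracking (Armijo) line search keeps each update a descent step
        while objective(w - t * step) > f0 - 0.25 * t * grad @ step and t > 1e-8:
            t *= 0.5
        w = w - t * step
        if np.linalg.norm(t * step) < 1e-8:
            break

    X_test, y_test = sample(n_test)
    return float(np.mean(np.sign(X_test @ w) != y_test))
```

With the matrices built above, `simulate_logistic_regression(Psi, Phi, Omega, alpha=0.5, lam=0.01)` returns one finite-size estimate that could be compared with the `test_error` reported by `get_info`, provided both use the same conventions. At $k=1000$, $p=2000$ the sampling step is slow and memory-hungry, so smaller dimensions are a better starting point.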

All that is left is to initialise the saddle-point equation iterator:

In [8]:
sp = StateEvolution(model = task,
                    initialisation = 'uninformed',
                    tolerance = 1e-7,
                    damping = 0.5,
                    verbose = True,
                    max_steps = 1000)
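The parameters `damping`, `tolerance` and `max_steps` play the usual roles of a damped fixed-point iteration. As a generic illustration only: the package's exact update and convergence criterion are not shown here, so the convention that `damping` weights the previous iterate is an assumption, and `damped_fixed_point` is a hypothetical helper.

```python
import numpy as np

def damped_fixed_point(update, x0, damping=0.5, tolerance=1e-7, max_steps=1000):
    """Iterate x <- (1 - damping) * update(x) + damping * x until the change is below tolerance.

    Returns (x, converged, steps).
    """
    x = np.asarray(x0, dtype=float)
    for t in range(max_steps):
        x_new = (1 - damping) * np.asarray(update(x), dtype=float) + damping * x
        diff = np.max(np.abs(x_new - x))   # largest change across all components
        x = x_new
        if diff < tolerance:
            return x, True, t + 1
    return x, False, max_steps

x_star, converged, steps = damped_fixed_point(np.cos, 1.0)   # x_star is approximately 0.739085
```

Damping trades speed for stability. With `update = lambda v: 1 - v` started at 0, the undamped iteration (`damping=0.0`) oscillates between 0 and 1 and never converges, while `damping=0.5` lands on the fixed point 0.5 after a single update.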

Now we can simply iterate it:

In [9]:
sp.iterate()

t: 0, diff: 478.0152120783315, self overlaps: 0.04902775035219596, teacher-student overlap: 0.05419850145465994
t: 1, diff: 240.006195014993, self overlaps: 0.15233031793698482, teacher-student overlap: 0.11537216518750265
t: 2, diff: 121.55004480881897, self overlaps: 0.3688979585546957, teacher-student overlap: 0.1936650297246116
t: 3, diff: 62.83227770088185, self overlaps: 0.7488293041269313, teacher-student overlap: 0.2898500244839212
t: 4, diff: 33.70665842750992, self overlaps: 1.2894607255376267, teacher-student overlap: 0.3963501023184458
t: 5, diff: 19.018559131874653, self overlaps: 1.932899112630993, teacher-student overlap: 0.50217971757138
t: 6, diff: 11.326343276129812, self overlaps: 2.6103030802214198, teacher-student overlap: 0.5988796957152782
t: 7, diff: 7.084365648846523, self overlaps: 3.2725934923209956, teacher-student overlap: 0.6824724789727779
t: 8, diff: 4.614101813960866, self overlaps: 3.891642914962431, teacher-student overlap: 0.7522656020377083
t: 9, di

Voilà! Now you can check the result with the method `get_info`, which returns everything you might be interested in as a dictionary.

In [13]:
sp.get_info()

{'hyperparameters': {'initialisation': 'uninformed',
  'damping': 0.5,
  'max_steps': 1000,
  'tolerance': 1e-07},
 'status': 1,
 'convergence_time': 62,
 'test_error': 0.37660906800285854,
 'train_loss': 0.10335060611232535,
 'overlaps': {'variance': 19.730186194816874,
  'self_overlap': 7.070624554079193,
  'teacher_student': 1.00514853830573}}
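Because `get_info` returns a nested dictionary, it is convenient to flatten it into one table row per run, for instance to plot a learning curve with the `pandas` and `matplotlib` imports at the top of this notebook. A sketch with two hypothetical helpers: `flatten_info` works on any such dictionary, while `learning_curve` only chains the calls already shown above (`LogisticRegression`, `StateEvolution`, `iterate`, `get_info`) over several sample complexities.

```python
def flatten_info(info, prefix=""):
    """Flatten a nested get_info() dictionary into {"overlaps.self_overlap": ..., ...}."""
    flat = {}
    for key, value in info.items():
        name = f"{prefix}{key}"
        if isinstance(value, dict):
            flat.update(flatten_info(value, prefix=name + "."))
        else:
            flat[name] = value
    return flat


def learning_curve(data_model, sample_complexities, regularisation=0.01):
    """Hypothetical helper: run the state evolution for several alphas, one row per run."""
    # Imported lazily so that flatten_info works without pandas or the package installed
    import pandas as pd
    from state_evolution.models.logistic_regression import LogisticRegression
    from state_evolution.algorithms.state_evolution import StateEvolution

    rows = []
    for alpha in sample_complexities:
        task = LogisticRegression(sample_complexity=alpha,
                                  regularisation=regularisation,
                                  data_model=data_model)
        sp = StateEvolution(model=task, initialisation='uninformed', tolerance=1e-7,
                            damping=0.5, verbose=False, max_steps=1000)
        sp.iterate()
        rows.append({'sample_complexity': alpha, **flatten_info(sp.get_info())})
    return pd.DataFrame(rows)
```

For example, `df = learning_curve(data_model, [0.5, 1.0, 2.0])` followed by `plt.plot(df['sample_complexity'], df['test_error'])` would draw the predicted test error as a function of $\alpha$.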

In [23]:
sp.model.get_info()

{'model': 'logistic_regression', 'sample_complexity': 0.5, 'lambda': 0.01}

In [24]:
sp.model.data_model.get_info()

{'data_model': 'custom', 'teacher_dimension': 1000, 'student_dimension': 2000}