# Overview
This notebook is a short summary of the ***ProxSkip*** optimization algorithm introduced in the following ***[paper](https://proceedings.mlr.press/v162/mishchenko22b.html)***.

# The problem statement
ProxSkip tackles the following class of problems: 

$$ \min_{x \in \mathbb{R}^d} f(x) + \psi(x)$$

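where $f: \mathbb{R}^d \to \mathbb{R}$ is a smooth, convex function (the analysis below additionally assumes strong convexity) and $\psi: \mathbb{R}^d \to \mathbb{R} \cup \{+\infty\}$ is a proper, closed, convex, possibly non-smooth regularizer whose proximity operator is expensive to evaluate. 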

Numerous applications fit this setting: 

1. Signal processing: many signal recovery problems are posed as minimizing a sum of convex functions subject to convex constraints; the constraints can be modeled as the indicator function of the intersection of the constraint sets [1](https://arxiv.org/pdf/0912.3522.pdf) 
2. Machine learning: decentralized / distributed training is crucial for training huge models. The consensus formulation ensures that local solutions (computed on different devices) can be effectively leveraged to minimize the ***global objective*** 

# Prox Gradient Descent: The starting point
Such problems are generally solved with Proximal Gradient Descent:

$$ x_{t+1} = prox_{\gamma_t \psi}(x_t - \gamma_t \nabla f(x_t))$$

where the $prox$ operator is defined as: 

$$ prox_{\gamma \psi}(x) = \arg\min_{y \in \mathbb{R}^d} \left[~\frac{1}{2} \| x - y \| ^ 2 + \gamma \cdot \psi(y)~\right]$$
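The iteration above can be sketched in a few lines of plain Python. This is a minimal sketch under illustrative assumptions: the function `prox_gradient_descent` is hypothetical, a constant step size $\gamma_t = \gamma$ is used, and the toy problem takes $f(x) = \frac{1}{2}\|x - b\|^2$ with $\psi$ the indicator of the nonnegative orthant, whose prox is simply $\max(v_i, 0)$ coordinate-wise.

```python
def prox_gradient_descent(grad_f, prox, x0, gamma, num_iters):
    """Proximal Gradient Descent with a constant step size gamma.

    grad_f(x)       -> gradient of the smooth part f at x
    prox(v, gamma)  -> prox_{gamma * psi}(v)
    """
    x = list(x0)
    for _ in range(num_iters):
        g = grad_f(x)
        v = [xi - gamma * gi for xi, gi in zip(x, g)]   # gradient (forward) step on f
        x = prox(v, gamma)                               # proximal (backward) step on psi
    return x


# Hypothetical toy problem: f(x) = 0.5 * ||x - b||^2, psi = indicator of {x >= 0}.
# The prox of an indicator is the projection onto the set: here max(v_i, 0).
b = [1.5, -2.0, 0.3]


def grad_f(x):
    return [xi - bi for xi, bi in zip(x, b)]


def prox_nonneg(v, gamma):
    return [max(vi, 0.0) for vi in v]   # independent of gamma for an indicator


x_sol = prox_gradient_descent(grad_f, prox_nonneg, [0.0, 0.0, 0.0], gamma=0.5, num_iters=100)
print(x_sol)  # approaches [1.5, 0.0, 0.3], the projection of b onto {x >= 0}
```

Here $f$ is $1$-smooth, so $\gamma = 0.5$ is a safe constant step size.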

Even though the proximity operator is itself an optimization subproblem, closed-form expressions are known for many standard regularizers, such as $\|x\|_1$ and $\|x\|_2$.
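As an illustration, here are minimal pure-Python sketches of these two closed forms (function names are illustrative): the prox of $\gamma\|x\|_1$ is coordinate-wise soft-thresholding, and the prox of $\gamma\|x\|_2$ (the non-squared Euclidean norm) shrinks the whole vector towards the origin, setting it to zero when its norm is at most $\gamma$.

```python
import math


def prox_l1(v, gamma):
    """prox of gamma * ||x||_1: coordinate-wise soft-thresholding at level gamma."""
    return [max(vi - gamma, 0.0) + min(vi + gamma, 0.0) for vi in v]


def prox_l2(v, gamma):
    """prox of gamma * ||x||_2 (non-squared norm): block soft-thresholding."""
    norm = math.sqrt(sum(vi * vi for vi in v))
    if norm <= gamma:
        return [0.0] * len(v)
    scale = 1.0 - gamma / norm
    return [scale * vi for vi in v]


print(prox_l1([3.0, -0.5, -2.0], 1.0))  # [2.0, 0.0, -1.0]
print(prox_l2([3.0, 4.0], 1.0))          # approximately [2.4, 3.2]
```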

However, $\psi$ is generally non-smooth and not differentiable (at least not on its entire domain: $\|x\|_1$, for example, is not differentiable at $0$), and when no closed form is available, computing the ***PROX OPERATOR*** can be quite computationally expensive. 

# Expensive Proximity Operators
## Inherently computationally expensive

The proximity operator bridges the gap between constrained and unconstrained optimization: the constrained problem

$$ 
\min_{x \in \mathbb{R}^d} f(x) \quad \text{subject to} \quad x \in X
$$

can be reformulated as: 

$$ \min_{x \in \mathbb{R}^d} f(x) + \psi(x)$$

where 

$$
\psi(x) =
\begin{cases}
0 & \text{if } x \in X \\
+\infty & \text{if } x \notin X
\end{cases}
$$

For this choice of $\psi$, the prox operator is the Euclidean projection onto $X$, which can itself be a difficult optimization problem for complex sets.
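For simple sets such as a box or a ball, the projection is a one-line formula, but for an intersection of such sets there is generally no closed form, and simply composing the two projections gives the wrong answer. The sketch below computes the projection onto the intersection of a box and a ball with Dykstra's alternating projection algorithm, a standard iterative method chosen here for illustration (it is not part of ProxSkip); the sets, the point and the iteration count are hypothetical.

```python
import math


def project_box(v, lo=0.0, hi=1.0):
    """Projection onto the box [lo, hi]^d (prox of its indicator): coordinate-wise clipping."""
    return [min(max(vi, lo), hi) for vi in v]


def project_ball(v, radius=1.0):
    """Projection onto the Euclidean ball of the given radius centred at the origin."""
    norm = math.sqrt(sum(vi * vi for vi in v))
    return list(v) if norm <= radius else [radius * vi / norm for vi in v]


def dykstra(v, proj_a, proj_b, num_iters=500):
    """Dykstra's algorithm: iterates converge to the projection of v onto A ∩ B."""
    x = list(v)
    p = [0.0] * len(v)   # correction term associated with set A
    q = [0.0] * len(v)   # correction term associated with set B
    for _ in range(num_iters):
        y = proj_a([xi + pi for xi, pi in zip(x, p)])
        p = [xi + pi - yi for xi, pi, yi in zip(x, p, y)]
        x_new = proj_b([yi + qi for yi, qi in zip(y, q)])
        q = [yi + qi - xn for yi, qi, xn in zip(y, q, x_new)]
        x = x_new
    return x


v = [2.0, -1.0]
exact = dykstra(v, project_ball, project_box)   # inner iterative solver
naive = project_box(project_ball(v))            # composing the two projections
print(exact)  # close to [1.0, 0.0]
print(naive)  # about [0.894, 0.0]: feasible, but not the closest point of the intersection
```

Each evaluation of this prox now costs `num_iters` pairs of projections, which is the kind of expensive prox that motivates evaluating it less often.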

## Expensive Communication-wise
The proximity operator also emerges in the decentralized training regime. Assuming $n$ devices, where device $i$ holds the local loss $f_i$, the global training objective is: 

$$ 
f(x) = \frac{1}{n} \sum_{i = 1} ^ n f_i(x)
$$

As explained in [2](https://web.stanford.edu/~boyd/papers/pdf/prox_algs.pdf) (chapter 5), the key to computational efficiency is to break the problem into a sum of separable problems: each device $i$ keeps its own local copy $x_i$ of the model, and each problem can then be solved independently. However, the global objective must not be compromised, hence the addition of the consensus constraint: 

$$
\psi_C(x_1, x_2, \dots, x_n) =
\begin{cases}
0 & \text{if } x_1 = x_2 = \dots = x_n \\
+\infty & \text{otherwise}
\end{cases}
$$

which ensures that the local minimizations (and hence the computational gains) still lead to the global minimum.

The solution to the proximity operator: 
$$
prox_{\gamma \psi_C}(x_1, x_2, \dots, x_n) = \arg\min_{y_1, y_2, \dots, y_n \in \mathbb{R}^d} \left[~\frac{1}{2} \sum_{i=1}^{n}\| x_i - y_i \| ^ 2 + \gamma \cdot \psi_C(y_1, y_2, \dots, y_n)~\right]
$$

is $y_1 = \dots = y_n = \frac{1}{n}\sum_{i=1}^n x_i$, i.e. the average of the $x_i$, which is theoretically straightforward. However, in Federated Learning, even such a simple operation requires a communication round involving all $n$ devices ($O(n)$ communications), which can be prohibitively expensive in modern settings where $n$ is large.
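The prox of $\psi_C$ can be sketched in a few lines. In this minimal sketch, the function name and the toy local models are hypothetical and models are plain lists; the averaging itself is cheap, and in a federated setting its cost comes from the communication needed to form the average.

```python
def consensus_prox(local_models):
    """prox of the consensus indicator psi_C: every local model is replaced by the average.

    local_models[i] is the model x_i held by device i (a list of floats).
    The result does not depend on gamma, because psi_C only takes the values 0 and +inf.
    """
    n = len(local_models)
    d = len(local_models[0])
    # Forming this average is the communication step: in a federated setting every
    # device must send x_i (e.g. to a server) and receive the average back.
    mean = [sum(x[j] for x in local_models) / n for j in range(d)]
    return [list(mean) for _ in range(n)]


models = [[1.0, 0.0], [3.0, 2.0], [2.0, 4.0]]   # hypothetical local models of n = 3 devices
print(consensus_prox(models))  # [[2.0, 2.0], [2.0, 2.0], [2.0, 2.0]]
```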

# ProxSkip: A Provable Solution

The consensus constraint is just one example of prox operators that are expensive because of communication costs. The Federated Learning community has long been seeking local gradient algorithms whose communication complexity improves on $O(\kappa \log \frac{1}{\epsilon})$, where $\kappa = \frac{L}{\mu}$ is the condition number of $f$, without additional assumptions on data similarity or stronger smoothness assumptions. 

The authors introduce ***ProxSkip***, a version of Proximal Gradient Descent in which the proximity operator is evaluated only with probability $p$ in each iteration, i.e. $\frac{1}{p}$ times less frequently on average, together with ***Scaffnew***, an extension of this algorithm to distributed training that achieves communication acceleration without relying on any explicit acceleration mechanism.

Scaffnew converges linearly, with iteration complexity $$O(\kappa \log \frac{1}{\epsilon})$$
and the optimal communication complexity: 
$$O(\sqrt{\kappa} \log \frac{1}{\epsilon})$$ 

The authors further extend vanilla ProxSkip to a stochastic version (with stochastic gradients) and to a decentralized version.
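To make the method concrete, here is a minimal pure-Python sketch of ProxSkip, written from the update rules restated in the analysis below: a gradient step shifted by a control variate $h_t$, then, with probability $p$, a prox step with step size $\frac{\gamma}{p}$, followed by a control-variate update. The function names, the toy problem ($f(x) = \frac{1}{2}\sum_i a_i (x_i - b_i)^2$ and $\psi(x) = \lambda\|x\|_1$, whose solution $x^* = [2, 0]$ can be derived by hand) and the parameter values are illustrative assumptions, not code from the paper.

```python
import math
import random


def proxskip(grad_f, prox, x0, gamma, p, num_iters, seed=0):
    """ProxSkip sketch.

    grad_f(x)      -> gradient of the smooth part f at x
    prox(v, step)  -> prox_{step * psi}(v)
    gamma          -> step size (the analysis requires 0 < gamma <= 1/L)
    p              -> probability of evaluating the prox in a given iteration
    """
    rng = random.Random(seed)
    x = list(x0)
    h = [0.0] * len(x)   # control variate; converges to grad f(x*)
    num_prox_calls = 0
    for _ in range(num_iters):
        g = grad_f(x)
        # Gradient step shifted by the control variate h_t.
        x_hat = [xi - gamma * (gi - hi) for xi, gi, hi in zip(x, g, h)]
        if rng.random() < p:
            # With probability p: evaluate the (expensive) prox with step size gamma / p.
            v = [xh - (gamma / p) * hi for xh, hi in zip(x_hat, h)]
            x_new = prox(v, gamma / p)
            num_prox_calls += 1
        else:
            # With probability 1 - p: skip the prox.
            x_new = x_hat
        # Control-variate update; h is unchanged when the prox is skipped.
        h = [hi + (p / gamma) * (xn - xh) for hi, xn, xh in zip(h, x_new, x_hat)]
        x = x_new
    return x, h, num_prox_calls


# Hypothetical toy problem: f(x) = 0.5 * sum_i a_i (x_i - b_i)^2, psi(x) = lam * ||x||_1.
a, b, lam = [1.0, 2.0], [3.0, -0.2], 1.0
L, mu = max(a), min(a)   # smoothness and strong-convexity constants of f


def grad_f(x):
    return [ai * (xi - bi) for ai, xi, bi in zip(a, x, b)]


def prox_l1(v, step):
    # prox of step * lam * ||x||_1: soft-thresholding at level step * lam
    return [max(vi - step * lam, 0.0) + min(vi + step * lam, 0.0) for vi in v]


gamma = 1.0 / L                  # largest step size allowed by the analysis
p = 1.0 / math.sqrt(L / mu)      # p = 1 / sqrt(kappa)
x_out, h_out, calls = proxskip(grad_f, prox_l1, [0.0, 0.0], gamma, p, num_iters=500)
print(x_out)   # close to x* = [2.0, 0.0]
print(h_out)   # close to grad f(x*) = [-1.0, 0.4]
print(calls)   # about p * 500 on average
```

Plugging in the averaging prox of $\psi_C$ (applied to the stacked local models) would give a Scaffnew-style loop, in which each skipped prox step is a skipped communication round.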

# Convergence Analysis: The Magic Explained


## Convergence Analysis
This section sketches the proofs of the convergence rate and of the rate at which the prox operator is called: 

### Known facts and assumptions:
1. $f$ is $\mu$-strongly convex and $L$-smooth, and $\psi$ is a proper, closed, convex regularizer
2. These assumptions imply that the problem has a unique solution, denoted by $x^* = \arg\min_{x \in \mathbb{R}^d} f(x) + \psi(x)$
3. Two other important facts follow: 
    * Firm nonexpansiveness of the prox operator: $\| prox_{\gamma\psi}(x) - prox_{\gamma\psi}(y) \| ^ 2 + \|(x - prox_{\gamma\psi}(x)) - (y - prox_{\gamma\psi}(y)) \|^2 \leq \|x - y \| ^ 2$ for all $x, y$
    * For every $\gamma > 0$, $x^*$ is a fixed point: $x^* = prox_{\gamma\psi} (x^*  - \gamma \nabla f (x^*))$
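Both facts can be sanity-checked numerically for $\psi = \lambda\|x\|_1$, whose prox is soft-thresholding. The sketch below samples random pairs to test firm nonexpansiveness, then checks the fixed-point property for several step sizes on a hypothetical toy problem ($f(x) = \frac{1}{2}\sum_i a_i (x_i - b_i)^2$, $\lambda = 1$) whose solution $x^* = [2, 0]$ follows from the optimality conditions. A numerical check like this is a sanity check, not a proof.

```python
import random


def soft_threshold(v, tau):
    """prox of tau * ||x||_1: coordinate-wise soft-thresholding."""
    return [max(vi - tau, 0.0) + min(vi + tau, 0.0) for vi in v]


def sq_dist(u, v):
    """Squared Euclidean distance ||u - v||^2."""
    return sum((ui - vi) ** 2 for ui, vi in zip(u, v))


# 1) Firm nonexpansiveness on random pairs (x, y), with P = soft-thresholding at tau = 1:
#    ||P(x) - P(y)||^2 + ||(x - P(x)) - (y - P(y))||^2 <= ||x - y||^2
rng = random.Random(0)
worst_gap = float("inf")
for _ in range(10_000):
    x = [rng.uniform(-3, 3) for _ in range(5)]
    y = [rng.uniform(-3, 3) for _ in range(5)]
    px, py = soft_threshold(x, 1.0), soft_threshold(y, 1.0)
    rx = [xi - pi for xi, pi in zip(x, px)]   # x - P(x)
    ry = [yi - pi for yi, pi in zip(y, py)]   # y - P(y)
    gap = sq_dist(x, y) - (sq_dist(px, py) + sq_dist(rx, ry))
    worst_gap = min(worst_gap, gap)
print("smallest gap:", worst_gap)  # non-negative up to rounding

# 2) Fixed point x* = prox_{gamma psi}(x* - gamma * grad f(x*)) for several gamma.
a, b, lam = [1.0, 2.0], [3.0, -0.2], 1.0
x_star = [2.0, 0.0]
grad_star = [ai * (xi - bi) for ai, xi, bi in zip(a, x_star, b)]
fixed_point_ok = all(
    sq_dist(soft_threshold([xi - g * gi for xi, gi in zip(x_star, grad_star)], g * lam), x_star) < 1e-20
    for g in [0.1, 0.5, 2.0, 10.0]
)
print("fixed point for every gamma:", fixed_point_ok)  # True
```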

Introducing a bit of notation: 
* $h^{*} = \nabla f(x^*)$ 
* $P(\cdot) = prox_{\frac{\gamma}{p} \psi}(\cdot)$
* $x = \hat{x}_{t+1} - \frac{\gamma}{p} h_t$
* $y = x^{*} - \frac{\gamma}{p} h^{*}$, so that $P(y) = x^*$ by the fixed-point property above

1. First, let's rewrite $x_{t+1}$ and $h_{t+1}$ in terms of $x_t$ and $h_t$. Each iteration starts with the shifted gradient step $\hat{x}_{t+1} = x_t - \gamma (\nabla f(x_t) - h_t)$ and then flips a coin:

$$
x_{t + 1} = 
\begin{cases}
P(x) & \text{with probability } p \\
\hat{x}_{t+1} & \text{with probability } 1 - p
\end{cases}
$$

$$
h_{t + 1} = 
\begin{cases}
h_t + \frac{p}{\gamma} (P(x) - \hat{x}_{t+1}) & \text{with probability } p \\
h_t & \text{with probability } 1 - p
\end{cases}
$$

The main result is: 

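$$\mathbb{E}[\Psi(T)] \leq (1 - \xi)^{T} \Psi(0)$$
where $\Psi(t) = \|x_t - x^{*} \| ^ 2 + \frac{\gamma ^ 2}{p^2} \|h_t - h^{*} \| ^ 2$ and $\xi = \min(\gamma \mu, p^2)$. We start from the expectation over the coin flip at iteration $t$: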

\begin{align*}
\mathbb{E}[\Psi(t + 1)] &= p \left(\|P(x) - x^{*} \| ^ 2 + \frac{\gamma ^ 2}{p^2} \|h_t + \frac{p}{\gamma} (P(x) - \hat{x}_{t + 1}) - h^{*} \| ^ 2\right) + (1 - p) \left(\|\hat{x}_{t + 1} - x^{*} \| ^ 2 + \frac{\gamma ^ 2}{p^2} \|h_t - h^{*} \| ^ 2\right) && \text{this can be written as}\\
\mathbb{E}[\Psi(t + 1)] &= p \left(\|P(x) - P(y) \| ^ 2 + \|P(x) - x + y - P(y) \| ^ 2\right) + (1 - p) \left(\|\hat{x}_{t + 1} - x^{*} \| ^ 2 + \frac{\gamma ^ 2}{p^2} \|h_t - h^{*} \| ^ 2\right)  && \text{algebraic manipulation}\\
\mathbb{E}[\Psi(t + 1)] &\leq p \|x - y \| ^ 2 + (1 - p) \left(\|\hat{x}_{t + 1} - x^{*} \| ^ 2 + \frac{\gamma ^ 2}{p^2} \|h_t - h^{*} \| ^ 2\right)  && \text{applying firm nonexpansiveness}
\end{align*}

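Recalling the definitions:

* $x = \hat{x}_{t+1} - \frac{\gamma}{p} h_t$
* $y = x^{*} - \frac{\gamma}{p} h^{*}$

we expand $p\|x - y\|^2$, collect the $\|h_t - h^*\|^2$ terms (their coefficients add up to $p \cdot \frac{\gamma^2}{p^2} + (1 - p) \cdot \frac{\gamma^2}{p^2} = \frac{\gamma^2}{p^2}$), and complete the square: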


\begin{align*}
\mathbb{E}[\Psi(t + 1)] &\leq p \|x - y \| ^ 2 + (1 - p) \left(\|\hat{x}_{t + 1} - x^{*} \| ^ 2 + \frac{\gamma ^ 2}{p^2} \|h_t - h^{*} \| ^ 2\right) \\
\mathbb{E}[\Psi(t + 1)] &\leq \|\hat{x}_{t + 1} - x^{*} \| ^ 2 - 2 \gamma \langle \hat{x}_{t + 1} - x^{*}, h_t - h^{*} \rangle + \frac{\gamma ^ 2}{ p ^ 2} \| h_t - h^{*} \| ^ 2 \\
\mathbb{E}[\Psi(t + 1)] &\leq \|\hat{x}_{t + 1} - x^{*} \| ^ 2 - 2 \gamma \langle \hat{x}_{t + 1} - x^{*}, h_t - h^{*} \rangle + \gamma ^ 2 \|h_t - h^{*} \| ^ 2 + \left(\frac{\gamma ^ 2}{ p ^ 2} - \gamma ^ 2\right)\| h_t - h^{*} \| ^ 2 \\
\mathbb{E}[\Psi(t + 1)] &\leq \|(\hat{x}_{t + 1} - \gamma h_t) - (x^{*} - \gamma h^{*}) \| ^ 2 + \frac{\gamma ^ 2}{ p ^ 2} (1 - p^2)\| h_t - h^{*} \| ^ 2
\end{align*}


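Using the strong convexity and smoothness of $f$, we can upper bound the first term. Note that $\hat{x}_{t+1} - \gamma h_t = x_t - \gamma \nabla f(x_t)$ and $x^* - \gamma h^* = x^* - \gamma \nabla f(x^*)$, and let $D_f(u, v) = f(u) - f(v) - \langle \nabla f(v), u - v \rangle$ denote the Bregman divergence of $f$: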

\begin{align*}
\|(\hat{x}_{t + 1} - \gamma h_t) - (x^{*} - \gamma h^{*}) \| ^ 2 &= \|x_{t} - x^{*} - \gamma(\nabla f(x_t) - \nabla f(x^*)) \| ^ 2\\
\|(\hat{x}_{t + 1} - \gamma h_t) - (x^{*} - \gamma h^{*}) \| ^ 2 &= \|x_{t} - x^{*}\| ^ 2 + \gamma ^ 2 \| \nabla f(x_t) - \nabla f(x^*) \| ^ 2 - 2\gamma \langle \nabla f(x_t) - \nabla f(x^*), x_t - x^{*} \rangle \\
\|(\hat{x}_{t + 1} - \gamma h_t) - (x^{*} - \gamma h^{*}) \| ^ 2 & \leq (1 - \gamma \mu) \|x_{t} - x^{*}\| ^ 2 + \gamma ^ 2 \| \nabla f(x_t) - \nabla f(x^*) \| ^ 2 - 2\gamma D_f(x_t, x^*) && \text{strong convexity: } \langle \nabla f(x_t) - \nabla f(x^*), x_t - x^* \rangle \geq D_f(x_t, x^*) + \tfrac{\mu}{2}\|x_t - x^*\|^2 \\
\|(\hat{x}_{t + 1} - \gamma h_t) - (x^{*} - \gamma h^{*}) \| ^ 2 & \leq (1 - \gamma \mu) \|x_{t} - x^{*}\| ^ 2 - 2 \gamma \left(D_f(x_t, x^*) - \frac{\gamma}{2} \| \nabla f(x_t) - \nabla f(x^*) \| ^ 2 \right)\\
\|(\hat{x}_{t + 1} - \gamma h_t) - (x^{*} - \gamma h^{*}) \| ^ 2 & \leq (1 - \gamma \mu) \|x_{t} - x^{*}\| ^ 2 && \text{since } \| \nabla f(x_t) - \nabla f(x^*) \|^2 \leq 2L D_f(x_t, x^*) \text{ and } 0 < \gamma \leq \tfrac{1}{L}
\end{align*}


Combining both intermediate results, we reach the main result of the paper:
\begin{align}
\mathbb{E}[\Psi(t + 1)] &\leq \|(\hat{x}_{t + 1} - \gamma h_t) - (x^{*} - \gamma h^{*}) \| ^ 2 + \frac{\gamma ^ 2}{ p ^ 2} (1 - p^2)\| h_t - h^{*} \| ^ 2 \\
\mathbb{E}[\Psi(t + 1)] &\leq (1 - \mu \gamma) \|x_t - x^{*}\|^2 + \frac{\gamma ^ 2}{ p ^ 2} (1 - p^2)\| h_t - h^{*} \| ^ 2 \\
\mathbb{E}[\Psi(t + 1)] &\leq (1 - \xi) \left(\|x_t - x^{*}\|^2 + \frac{\gamma ^ 2}{ p ^ 2} \| h_t - h^{*} \| ^ 2\right) && \text{$\xi = \min(\gamma \mu , p^2)$}\\
\mathbb{E}[\Psi(t + 1)] &\leq (1 - \xi) \Psi(t)
\end{align}

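Taking total expectations and unrolling this inequality over $T$ iterations gives $\mathbb{E}[\Psi(T)] \leq (1 - \xi)^T \Psi(0)$, which proves the linear convergence rate of the ***ProxSkip*** method. Since $\Psi$ contains the term $\frac{\gamma^2}{p^2}\|h_t - h^*\|^2$, it also proves that $h_t$ converges to $\nabla f (x^*)$.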
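The one-step bound $\mathbb{E}[\Psi(t+1)] \leq (1 - \xi)\Psi(t)$ holds for any pair $(x_t, h_t)$, so it can be checked numerically without running the algorithm: the expectation over the coin flip is a weighted sum of the two branches. The sketch below does this for a hypothetical toy problem ($f(x) = \frac{1}{2}\sum_i a_i (x_i - b_i)^2$, $\psi(x) = \lambda\|x\|_1$, with $x^* = [2, 0]$ derived by hand), using $\gamma = \frac{1}{L}$ and $p = \frac{1}{\sqrt{\kappa}}$; all names are illustrative.

```python
import math
import random

# Hypothetical toy problem: f(x) = 0.5 * sum_i a_i (x_i - b_i)^2,  psi(x) = lam * ||x||_1.
a, b, lam = [1.0, 2.0], [3.0, -0.2], 1.0
L, mu = max(a), min(a)
x_star = [2.0, 0.0]                                             # solution, derived by hand
h_star = [ai * (xi - bi) for ai, xi, bi in zip(a, x_star, b)]   # grad f(x*) = [-1.0, 0.4]


def grad_f(x):
    return [ai * (xi - bi) for ai, xi, bi in zip(a, x, b)]


def prox_l1(v, step):
    # prox of step * lam * ||x||_1: soft-thresholding at level step * lam
    return [max(vi - step * lam, 0.0) + min(vi + step * lam, 0.0) for vi in v]


def lyapunov(x, h, gamma, p):
    """Psi = ||x - x*||^2 + (gamma / p)^2 * ||h - h*||^2."""
    dx = sum((xi - si) ** 2 for xi, si in zip(x, x_star))
    dh = sum((hi - si) ** 2 for hi, si in zip(h, h_star))
    return dx + (gamma / p) ** 2 * dh


def expected_next_lyapunov(x, h, gamma, p):
    """E[Psi(t+1) | x_t = x, h_t = h], computed exactly by weighting the two coin outcomes."""
    x_hat = [xi - gamma * (gi - hi) for xi, gi, hi in zip(x, grad_f(x), h)]
    # Prox branch (probability p): P = prox with step size gamma / p.
    x_prox = prox_l1([xh - (gamma / p) * hi for xh, hi in zip(x_hat, h)], gamma / p)
    h_prox = [hi + (p / gamma) * (xp - xh) for hi, xp, xh in zip(h, x_prox, x_hat)]
    # Skip branch (probability 1 - p): x_{t+1} = x_hat, h unchanged.
    return p * lyapunov(x_prox, h_prox, gamma, p) + (1 - p) * lyapunov(x_hat, h, gamma, p)


gamma = 1.0 / L
p = 1.0 / math.sqrt(L / mu)
contraction = min(gamma * mu, p ** 2)   # xi in the text
rng = random.Random(0)
worst_ratio = 0.0
for _ in range(10_000):
    x = [rng.uniform(-5, 5) for _ in range(2)]
    h = [rng.uniform(-5, 5) for _ in range(2)]
    ratio = expected_next_lyapunov(x, h, gamma, p) / lyapunov(x, h, gamma, p)
    worst_ratio = max(worst_ratio, ratio)
print(worst_ratio, "<=", 1 - contraction)
```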


## Prox operator rates
Since $(1 - \xi)^T \leq e^{-\xi T}$, the convergence rate implies that for $T \geq \max(\frac{1}{\mu \gamma}, \frac{1}{p^2}) \log (\frac{1}{\epsilon})$ iterations, we have 

$$\mathbb{E}[\Psi(T)] \leq \epsilon \Psi(0)$$

Since the prox operator is evaluated with probability $p$ in each iteration, it is called $p \cdot T$ times on average over $T$ iterations. Hence, after

$$ p \cdot \max(\frac{1}{\mu \gamma}, \frac{1}{p^2}) \log (\frac{1}{\epsilon}) =  \max(\frac{p}{\mu \gamma}, \frac{1}{p}) \log (\frac{1}{\epsilon})$$
prox operator calls (in expectation), we have:

$$\mathbb{E}[\Psi(T)] \leq \epsilon \Psi(0)$$

The next step is to minimize the term $\max(\frac{p}{\mu \gamma}, \frac{1}{p})$. The first term is smallest for the largest admissible step size $\gamma = \frac{1}{L}$, for which $\frac{p}{\mu \gamma} = p \kappa$ with $\kappa = \frac{L}{\mu}$. Since $p\kappa$ increases and $\frac{1}{p}$ decreases with $p$, the maximum is minimized when $p \kappa = \frac{1}{p}$, implying $p = \frac{1}{\sqrt{\kappa}}$.

Thus, for $\gamma = \frac{1}{L}$ and $p = \frac{1}{\sqrt{\kappa}}$, the expected number of prox operator calls is:
$$O(\sqrt{\kappa} \log \frac{1}{\epsilon})$$ 
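The trade-off can be checked numerically. The sketch below evaluates the expected number of prox calls $p \cdot \max(\frac{1}{\gamma\mu}, \frac{1}{p^2}) \log \frac{1}{\epsilon}$ with $\gamma = \frac{1}{L}$ for a few values of $p$. The condition number $\kappa = 10^4$ and accuracy $\epsilon = 10^{-6}$ are hypothetical values chosen for illustration, and the constants hidden by the $O(\cdot)$ are ignored.

```python
import math


def expected_prox_calls(p, gamma_mu, eps):
    """p * T, with T = max(1 / (gamma * mu), 1 / p^2) * log(1 / eps) iterations."""
    num_iters = max(1.0 / gamma_mu, 1.0 / p ** 2) * math.log(1.0 / eps)
    return p * num_iters


kappa, eps = 1e4, 1e-6          # hypothetical condition number L / mu and target accuracy
gamma_mu = 1.0 / kappa          # gamma = 1 / L, hence gamma * mu = 1 / kappa
p_opt = 1.0 / math.sqrt(kappa)  # balances p * kappa against 1 / p

calls = {p: expected_prox_calls(p, gamma_mu, eps) for p in [1.0, 0.1, p_opt, 0.001]}
for p, c in calls.items():
    print(f"p = {p:<6g} expected prox calls ~ {c:,.0f}")
```

With these numbers, $p = 1$ (plain Proximal Gradient Descent) needs about $\kappa \log \frac{1}{\epsilon}$ prox calls, a $p$ that is too small is penalized by the $\frac{1}{p}$ term, and $p = 1/\sqrt{\kappa}$ attains the $\sqrt{\kappa} \log \frac{1}{\epsilon}$ minimum.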


The authors apply standard techniques to prove similar rates for the case of the Stochastic and Federated Learning versions of the algorithm.