In [None]:
import torch
import matplotlib.pyplot as plt
import non_local_boxes
import numpy as np

# Render matplotlib figures inline in the notebook
%matplotlib inline
# Automatically reload edited modules before running a cell
%load_ext autoreload
%autoreload 2

# Raise the resolution in order to have unblurred pictures
plt.rcParams['figure.dpi'] = 300
plt.rcParams['savefig.dpi'] = 300

In [None]:
# Number of columns of the wiring matrices used below; each column gives
# one entry of the histograms ("Number of reruns" on the y-axis).
non_local_boxes.evaluate.nb_columns = 1_000

>This notebook plots the histograms of the outputs of `Algorithms 2` and `3`. In the manuscript, it is especially useful for Figure 9.
>
>Standard reference in Numerical Optimization: https://doi.org/10.1007/b98874

---
# I. Gradient Descent

### Definition

$$
\begin{array}{l}
    x_{k+1}= \texttt{proj}(x_k + \alpha \nabla \phi(x_k))
\end{array}
$$

In [None]:
# Short alias for the projection `proj` applied after every update step below.
projected_wiring = non_local_boxes.utils.projected_wiring

In [None]:
def gradient_descent(starting_W, P, Q, learning_rate, nb_iterations = 400, tolerance=1e-6):
    """Run the projected gradient iteration on all columns of W at once.

    Each iteration performs
    ``W <- projected_wiring(W + learning_rate * grad)``,
    where ``grad`` is obtained by back-propagating a vector of ones through
    ``non_local_boxes.evaluate.phi_flat(W, P, Q)`` (i.e. the gradient of the
    sum of its entries).

    Parameters
    ----------
    starting_W : torch.Tensor
        Initial wirings. It is never modified: the iteration works on a
        detached view, so it does not need ``requires_grad=True`` and any
        gradient already stored on it is ignored.
        (Presumably one wiring per column, as in the rest of the notebook
        -- TODO confirm against `non_local_boxes.utils.random_wiring`.)
    P, Q :
        Boxes forwarded unchanged to ``phi_flat``.
    learning_rate : float
        Step size of the update.
    nb_iterations : int
        Maximal number of iterations.
    tolerance : float
        The iteration stops as soon as the largest absolute change of any
        entry of W between two iterations is below this value.

    Returns
    -------
    torch.Tensor
        The last iterate, always detached from the autograd graph.
    """
    phi_flat = non_local_boxes.evaluate.phi_flat
    # Fresh leaf tensor: no dependence on the caller's requires_grad flag
    # and no accumulation into a gradient left over from a previous call.
    W = starting_W.detach()
    W.requires_grad = True
    for _ in range(nb_iterations):
        values = phi_flat(W, P, Q)
        # The seed is sized from the output itself rather than from the
        # global `nb_columns`, so the function follows the shape of W.
        values.backward(gradient=torch.ones_like(values))
        # The update itself does not need to be tracked by autograd.
        with torch.no_grad():
            W_next = projected_wiring(W + learning_rate*W.grad)
        W_next = W_next.detach()
        if torch.max(torch.abs(W_next - W.detach())) < tolerance:
            return W_next
        W = W_next
        W.requires_grad = True
    return W.detach()

### Histogram of the Gradient Descent

In [None]:
# Short aliases for the three boxes that are mixed in the cells below.
# NOTE(review): presumably the PR box, the shared-randomness box and the
# fully mixed box -- confirm in `non_local_boxes.utils`.
PR, SR, I = (
    non_local_boxes.utils.PR,
    non_local_boxes.utils.SR,
    non_local_boxes.utils.I,
)

In [None]:
# Box under study: convex mixture of PR, SR and I with weights p, q, 1-p-q.
p=0.39
q=0.6
P = p*PR +q*SR + (1-p-q)*I
BoxProduct = non_local_boxes.evaluate.phi_flat

# Parameters of the projected gradient descent.
m = non_local_boxes.evaluate.nb_columns
alpha = 0.01   # learning rate
K=int(1e2)     # maximal number of iterations
epsilon=1e-6   # stopping tolerance

W = gradient_descent(
    starting_W=non_local_boxes.utils.random_wiring(m),
    P=P, Q=P,
    learning_rate=alpha, nb_iterations=K, tolerance=epsilon
)
# One value per column of W, i.e. one value per run.
histogramGD = BoxProduct(W, P, P).tolist()

plt.hist(histogramGD, bins=22, 
         label=f"""Projected Gradient Descent\n(α={alpha}, K=$10^{{{int(np.log10(K))}}}$, ε=$10^{{{int(np.log10(epsilon))}}}$, m=$10^{{{int(np.log10(m))}}}$).""")
plt.xlim(0.3, 1)
plt.ylim(0.9, m)
# Raw string: "\P" and "\m" are not valid Python escape sequences.
plt.xlabel(r"$\Phi(\mathsf{W}_{{out}})$", fontsize=13)
plt.ylabel("Number of reruns", fontsize=14)
plt.yscale("log")
plt.legend(fontsize=13)
plt.show()

---
# II. Line Search

### Definition

$$
\left\{
\begin{array}{l}
    \alpha^*_k = \operatorname*{argmax}_\alpha \phi(x_k + \alpha \nabla \phi(x_k))\\
    x_{k+1}= \texttt{proj}(x_k + \alpha^*_k \nabla \phi(x_k))
\end{array}
\right.
$$

In [None]:
def reorder_list(L, phi):
    """Sort the index list L in place by decreasing value of phi, and return it.

    Parameters
    ----------
    L : list of int
        Indexes into `phi`. The list is reordered in place.
    phi : sequence
        Values used as sorting key: index `i` is ranked by `phi[i]`.

    Returns
    -------
    list of int
        The same list object `L`, with the index of the largest value
        first and the index of the smallest value last. Indexes with equal
        values keep their relative order.
    """
    # list.sort is stable even with reverse=True, so ties are ordered as
    # the adjacent-swap loop it replaces would order them, in O(n log n).
    L.sort(key=lambda index: phi[index], reverse=True)
    return L

# Usage example: the indexes come back ordered by decreasing value of phi.
phi = [0.1, 0.3, 0, 10, 9, 0.5]
L = reorder_list(list(range(len(phi))), phi)
print([phi[index] for index in L])

In [None]:
def select_best_columns(W, P, Q, integer):
    """Keep the `integer` best columns of W and redraw all the other ones.

    The columns are ranked by ``non_local_boxes.evaluate.phi_flat(W, P, Q)``
    (larger is better). The result is a freshly drawn random wiring in which
    the selected columns of W are copied back at the same positions.

    Parameters
    ----------
    W : torch.Tensor
        Current wirings, with `non_local_boxes.evaluate.nb_columns` columns.
    P, Q :
        Boxes forwarded unchanged to ``phi_flat``.
    integer : int
        Number of columns of W to keep. With 0, W is not evaluated and
        every column is redrawn.

    Returns
    -------
    torch.Tensor
        A new tensor that carries no autograd history.

    Notes
    -----
    When several columns have exactly the same value at the selection
    threshold, which of them are kept is decided by ``torch.topk``.
    """
    # Read the number of columns from the package setting rather than from
    # a notebook-level global `m` that may not be defined yet.
    nb_columns = non_local_boxes.evaluate.nb_columns
    if integer == 0:
        return non_local_boxes.utils.random_wiring(nb_columns).detach()

    with torch.no_grad():
        # Evaluate first and draw afterwards, as a selection step would
        # expect: the ranking only depends on the current W.
        values = non_local_boxes.evaluate.phi_flat(W, P, Q)
        best = torch.topk(values, integer).indices
        W_new = non_local_boxes.utils.random_wiring(nb_columns).detach()
        # We keep only the best columns; the other ones stay random.
        W_new[:, best] = W[:, best]

    return W_new

In [None]:
def line_search_with_resets(P, Q, LS_iterations, K_reset, chi):
    """Line search on all columns of W at once, with periodic random resets.

    The procedure runs ``int(1/chi)`` rounds. At the start of round ``j``
    the ``min(m, int(j*m*chi))`` best columns of W are kept and all the
    other ones are redrawn at random (see `select_best_columns`); round 0
    therefore starts from fully random wirings. Each round then performs
    `K_reset` update steps, except the last round which performs
    ``10*K_reset`` of them.

    One update step computes the gradient of ``phi_flat`` at W, tunes one
    step size per column, and applies
    ``W <- projected_wiring(W + alpha*gradient)``. The step sizes start at
    0.01 and are adjusted `LS_iterations` times: multiplied by 0.8 where the
    best gain seen so far exceeds the gain at the trial point, and by 1.3
    otherwise.

    Parameters
    ----------
    P, Q :
        4x4 matrices, forwarded unchanged to ``phi_flat``.
    LS_iterations : int
        Number of step-size adjustments per update step.
    K_reset : int
        Number of update steps between two resets.
    chi : float
        Fraction of additional columns kept at each new round; it also
        fixes the number of rounds, ``int(1/chi)``.

    Returns
    -------
    torch.Tensor
        The final wirings, of shape ``(32, m)`` with
        ``m = non_local_boxes.evaluate.nb_columns``, with
        ``requires_grad=True``.
    """
    # P,Q are 4x4 matrices
    m = non_local_boxes.evaluate.nb_columns
    phi_flat = non_local_boxes.evaluate.phi_flat
    W, external_grad = torch.zeros(32,m), torch.ones(m)
    nb_rounds = int(1/chi)

    for j in range(nb_rounds):
        # Reset some of the wirings:
        W = select_best_columns(W, P, Q, min(m, int(j*m*chi))).detach()
        W.requires_grad=True

        # At the end, we do a lot of steps:
        nb_steps = 10*K_reset if j==nb_rounds-1 else K_reset

        # Line search:
        for _ in range(nb_steps):
            # A single forward pass gives both the current gains and,
            # through backward, the gradient.
            current = phi_flat(W, P, Q)
            current.backward(gradient=external_grad)
            gradient = W.grad
            # The trial evaluations are only compared, never differentiated,
            # so no autograd graph is needed for them.
            with torch.no_grad():
                Gains = current.detach()
                alpha = torch.ones(m)*0.01
                Gains_futur = phi_flat(W + alpha*gradient, P, Q)
                for _ in range(LS_iterations):
                    # 1.0 where the step was too ambitious, 0.0 elsewhere
                    mask = 0.0 + (Gains>Gains_futur)
                    alpha = 0.8*mask*alpha + 1.3*(1-mask)*alpha
                    Gains = torch.max(Gains, Gains_futur)
                    Gains_futur = phi_flat(W + alpha*gradient, P, Q)
                W = projected_wiring(W + alpha*gradient).detach()
            W.requires_grad=True

    return W

### Histogram of the Line Search

In [None]:
# Box under study: convex mixture of PR, SR and I with weights p, q, 1-p-q.
p,q=0.39, 0.6
P = p*PR +q*SR + (1-p-q)*I
BoxProduct = non_local_boxes.evaluate.phi_flat

# Parameters of the line search with resets.
m = non_local_boxes.evaluate.nb_columns
LS_iterations = 10
K_reset=5
chi = 0.3

W=line_search_with_resets(
    P, 
    P, 
    LS_iterations=LS_iterations, 
    K_reset=K_reset, 
    chi=chi
    )
# One value per column of W, i.e. one value per run.
histogramLS = BoxProduct(W, P, P).tolist()

# Scientific notation of chi for the legend. The exponent must be the floor
# of log10(chi): truncating towards zero would display 0.3 as 3·10^0.
chi_exponent = int(np.floor(np.log10(chi)))
chi_mantissa = round(chi / 10.0**chi_exponent)
# Raw f-strings: "\c" is not a valid Python escape sequence.
label_LS = (
    "Line Search with resets\n"
    rf"($K_{{reset}}$={K_reset}, χ=${chi_mantissa}\cdot10^{{{chi_exponent}}}$, "
    rf"m=$10^{{{int(np.log10(m))}}}$, M={LS_iterations})."
)

plt.hist(histogramLS, bins=15, color='purple', label=label_LS)
plt.xlim(0.3, 1)
plt.ylim(0.9, m)
plt.xlabel(r"$\Phi(\mathsf{W}_{{out}})$", fontsize=13)
plt.ylabel("Number of reruns", fontsize=14)
plt.yscale("log")
plt.legend(fontsize=12)
plt.show()

-----
# III. Comparison of GD and LS

In [None]:
# Both histograms on the same axes. The legends use the wording of the
# individual figures ("resets") and typeset the powers of ten in math mode.
plt.hist(histogramLS, bins=20, color='purple',
         label=rf"Line Search with resets ($K_{{reset}}$={K_reset}, χ={chi}, m=$10^{{{int(np.log10(m))}}}$, M={LS_iterations})")
plt.hist(histogramGD, bins=20,
         label=rf"Gradient Descent (α={alpha}, K=$10^{{{int(np.log10(K))}}}$, ε=$10^{{{int(np.log10(epsilon))}}}$, m=$10^{{{int(np.log10(m))}}}$)")
plt.xlim(0.3, 1)
plt.ylim(0.9, m)
# Raw string: "\P" and "\m" are not valid Python escape sequences.
plt.xlabel(r"$\Phi(\mathsf{W}_{{out}})$")
plt.ylabel("Number of reruns")
plt.yscale("log")
plt.legend()
plt.show()