In [1]:
%load_ext autoreload
%autoreload 2
import sys

sys.path.append('/home/iaamini/Documents/ML_practice/ML_fundamentals/ML_fundamentals/algorithms/')

# Recherche Linéaire


On cherche une valeur maximale admissible du pas d'apprentissage en allant selon un direction de descente $p_t$ (vérifiant $p_t^T\nabla\mathcal{L}(w)$.
On suit toujours un règle de mise à jour avec la condition de décroissance de la fonction de coût.
$$\forall{t} \in \mathbb{N}, \hat{\mathcal{L}}(w^{(t+1)}) < \hat{\mathcal{L}}(w^{(t)}) $$

L'algorithme de recherche linéaire consiste à vérifier des conditions que l'on appelle *conditions de Wolfe*.

### Condition d'Armijo

Permet de répondre aux deux situations où l'équation ci-dessus peut être satisfaite sans pour autant atteindre le minimiseur de $\mathcal{L}$.

$$\forall{t} \in \mathbb{N}, \hat{\mathcal{L}}(w^{(t)} + \eta_tp_t) \leq \hat{\mathcal{L}}(w_t) + \alpha\eta_tp_t^{T}\nabla\hat{\mathcal{L}}(w_t)$$

la contrainte de décroissance linéaire implique que le taux de décroissance allant $\hat{\mathcal{L}}(w^t)$ à $\hat{\mathcal{L}}(w^{t+1})$ ne doit pas être plus grand que la descente pondéré par un coefficient alpha.

### Condition de courbure

Cette condition implique que la descente lors de l'itération suivante soit au moins égale à une fraction $\beta \in (\alpha, 1)$. Donc:
$$ \forall{t} \in \mathbb{N}, p^T_t\nabla\hat{\mathcal{L}}(w^{(t)} + \eta_tp_t) \geq \beta p^T_t\nabla\hat{\mathcal{L}}(w^{(t)})$$

In [2]:
import numpy as np
from sklearn.datasets import make_classification
import logistic_regression as lg

In [3]:
def ls_armijo(w, f, fval0, pk, pente, args = (), alpha=1e-3, eta0=1, eta_min=0):
    w_new = w + eta0*pk
    fval_eta0 = f(w_new, X, y)
    print(fval_eta0)
    
#     print(fval0 + alpha*eta0*pente)
    
    if fval_eta0 <= fval0 + alpha*eta0*pente:
        print("returning eta0 at line 9")
        return eta0, fval_eta0
    # Sinon on calcul le minimiseur de l'interpolant quadratic
    eta1 = -(pente) * eta0**2 / 2.0 * (fval_eta0 - fval0 - pente * eta0)
    w_new = w + eta1*pk
    fval_eta1 = f(w_new, X, y)
    
    if fval_eta1 <= fval_eta0 + alpha*eta1*pente:
        print("returning eta1 at line 16")
        return eta1, fval_eta1
    
    while eta1 > eta_min:
        factor = 1 / eta1 - eta0
        
        coeff1 = fval_eta1 - fval0 - eta1 * pente
        coeff2 = fval_eta0 - fval0 - eta0 * pente

        a = (coeff1/(eta0**2) - coeff2/(eta1**2)) / (eta0 - eta1)
        b = (-eta1*coeff1/eta0**2)+eta0*coeff2/(eta1**2)/(eta0-eta1)

        # Calcul des coefficients du polynôme d'interpolation de degré 3 (Eq. 2.33)
#         a, b = factor * np.dot(mat1, mat2).flatten()
        if a != 0:
            delta = b**2 - 3*a*pente
            if delta >= 0:
                eta2 = (-b + np.sqrt(delta)) / 3*a
            else:
                raise ValueError("ls_armijo:problème d'interpolation")
        else:
            eta2 = eta1/2
            
        w_new= w + eta2*pk
        fval_eta2 = f(w_new, X, y)
        
        if fval_eta2 <= fval_eta0 + alpha*eta2*pente:
            print("returning eta2 at line 44")
            return eta2, fval_eta2
        if (eta1 - eta2) > eta1 / 2.0:
            eta2 = eta2 / 2.0
        eta0 = eta1
        eta1 = eta2
        fval_eta0 = fval_eta1
        fval_eta1 = fval_eta2
        
    return None, fval_eta1  
    

In [4]:
def line_search(wk, f, pk, gfk, old_fval, args = (),  alpha=1e-4, eta0=1, eta_min=0):
    """Temporary docstring for the line search
    X: array_like, matrix of data
    y: array_like, true labels
    w: array_like, weight vector
    f: callable, cost function
    pk: array_like, search direction
    gfk: array_like, gradient vector
    old_fval: float, old value of function
    alpha: float, constant coefficient
    eta0: float, int, initial step length
    eta_min: float, int minimum value for step length
    """
#     xk = np.atleast_1d(xk)
    fc = [0]

    def phi(eta1):
        fc[0] += 1
        return f(w + eta1*pk, *args)

    if old_fval is None:
        fval0 = phi(0.)
    else:
        fval0 = old_fval  # compute f(xk) -- done in past loop
    
    # Calcul de la pente
    pente = np.dot(pk, gfk)
    
    eta, phi1 = ls_armijo(w=w, f=f, fval0=fval0,pk=pk, pente=pente, args=args, alpha=alpha,
                            eta0=eta0)
    return alpha, phi1
    

    


In [5]:
def logistic(x):
    return 1 / (1 + np.exp(-x))

def stable_logistic(x):
    """Logistic function (for a scalar valued argument).

    A numerically stable implementation.
    """
    if x > 0:
        return 1/(1 + np.exp(-x))
    else:
        return np.exp(x)/(1 + np.exp(x))

stable_logistic = np.vectorize(stable_logistic)    

def logistic_surrogate_loss(w, X, y):
    # Computing the dot product
    n, d = X.shape
#     print(n)
#     print(d)
    ps = np.dot(X, w[:-1]) + w[-1]
#     print(ps)
    yps = y * ps
    loss = np.where(yps > 0,
                   np.log(1 + np.exp(-yps)),
                   (-yps + np.log(1 + np.exp(yps))))
    loss = loss.sum() / n
    return loss

def gradient_log_surrogate_loss(w, X, y):
    # defining dim variables
    n, d = X.shape
    z = X.dot(w[:-1]) + w[-1]
    z = stable_logistic(y * z)
    z0 = (z - 1) * y
    
    # initiating g: gradient vector
    g = np.zeros(d+1)
    # Computing dot product
#     ps = (np.dot(X, w[1:])) + w[0]
    g[:-1] = X.T.dot(z0)
    g[-1] = z0.sum()
    g /= n
    return g


In [7]:
X, y = make_classification(n_samples=100, n_features=10)
w = np.random.random(X.shape[1]+1)
loss = logistic_surrogate_loss(w, X, y)
g = gradient_log_surrogate_loss(w, X, y)
p = -g

In [8]:
loss

0.8791918395941279

In [10]:
line_search(w, f=logistic_surrogate_loss, args=(X, y), pk=p, gfk=g, old_fval=loss)

0.6535967361921888
returning eta0 at line 9


(0.0001, 0.6535967361921888)

In [12]:
from scipy.optimize import line_search as spls
spls(f=logistic_surrogate_loss, myfprime=gradient_log_surrogate_loss, xk = w, old_fval=loss, pk=p, args=(X, y))

(1.0,
 1,
 1,
 0.6535967361921888,
 0.8791918395941279,
 array([ 0.04763232,  0.06916227,  0.04938514,  0.05991375, -0.19110761,
        -0.01601373,  0.12528538, -0.0107462 ,  0.03594911,  0.00514419,
        -0.20247957]))