In [None]:
import numpy as np
from numpy.linalg import norm
from numpy.random import normal, choice
from numpy.typing import NDArray
from typing import Callable

## Стохастический градиентный спуск RMSProp

In [30]:
def SGD_RMSProp(
    start: NDArray,
    X: NDArray,
    y: NDArray,
    L: Callable,
    L_grad: Callable,
    learning_rate: float = 0.01,
    batch_size: int = 64,
    decay_rate: float = 0.5,
    max_iter: int = 1000,
    tol: float = 1e-7,
    epsilon: float = 1e-8,
    **kwargs
) -> dict:
    """Minimise ``L`` by mini-batch stochastic gradient descent with RMSProp steps.

    Each iteration samples a mini-batch of rows of ``X`` (and the matching entries
    of ``y``) without replacement. It then updates an exponential moving average of
    squared gradients and moves the point by
    ``learning_rate / (sqrt(avg) + epsilon) * grad``.

    Parameters
    ----------
    start : NDArray
        Initial point. It is copied as a float array, so the caller's array is
        never modified.
    X, y : NDArray
        Samples and targets; rows of ``X`` and entries of ``y`` are drawn together.
    L, L_grad : Callable
        Loss and its gradient, both called as ``f(point, batch_X, batch_y, **kwargs)``.
        ``L_grad`` is expected to return an array shaped like ``start``.
    learning_rate : float
        Base step size.
    batch_size : int
        Mini-batch size; capped at ``X.shape[0]``.
    decay_rate : float
        Weight given to the previous moving average of squared gradients.
    max_iter : int
        Iteration limit. At least one iteration is always performed.
    tol : float
        Iteration stops once ``norm(learning_rate * grad)`` falls below ``tol``.
    epsilon : float
        Added to ``sqrt(avg)`` so components whose gradient has always been zero
        do not produce a 0/0 (NaN) step.
    **kwargs
        Forwarded unchanged to ``L`` and ``L_grad``.

    Returns
    -------
    dict
        ``"point"``: final point.
        ``"L_value"`` / ``"grad_value"``: loss and gradient on the last mini-batch,
        evaluated at the point *before* the final update.
        ``"iterations"``: number of updates performed.
    """
    # Work on a float copy: updating in place must not clobber the caller's array
    # (and an integer ``start`` could not take a float in-place update at all).
    curr_point = np.array(start, dtype=float)
    y = np.asarray(y)
    n_samples = X.shape[0]
    # Sampling without replacement cannot draw more rows than exist.
    batch_size = min(batch_size, n_samples)
    run_avg = np.zeros_like(curr_point)
    curr_iter = 0
    while True:
        idx = choice(n_samples, batch_size, replace=False)
        batch_X, batch_y = X[idx, :], y[idx].reshape(idx.shape)

        curr_value = L(curr_point, batch_X, batch_y, **kwargs)
        curr_grad = L_grad(curr_point, batch_X, batch_y, **kwargs)
        run_avg = decay_rate * run_avg + (1 - decay_rate) * curr_grad**2

        curr_point -= learning_rate / (np.sqrt(run_avg) + epsilon) * curr_grad
        # NOTE: the stopping measure is the plain-SGD step length, not the
        # RMSProp-scaled step actually applied above.
        W_error = norm(learning_rate * curr_grad)
        curr_iter += 1
        # Written as a negated "continue" condition so a NaN error stops the loop.
        if not (curr_iter < max_iter and W_error >= tol):
            break

    return {
        "point": curr_point,
        "L_value": curr_value,
        "grad_value": curr_grad,
        "iterations": curr_iter,
    }

### Тест SGD RMSProp

In [34]:
def L(w, X, y):
    """Mean squared error of the linear model ``[X | 1] @ w`` against ``y``.

    A column of ones is appended to ``X``, so the last entry of ``w`` acts as
    the intercept.
    """
    design = np.column_stack([X, np.ones(y.size)])
    residual = design @ w - y
    return norm(residual) ** 2 / y.size


def L_grad(w, X, y):
    """Gradient of the mean squared error with respect to ``w``.

    With ``A = [X | 1]`` (a ones column appended) and ``n = y.size``, this
    returns ``2 * A.T @ (A @ w - y) / n``.
    """
    n = y.size
    design = np.column_stack([X, np.ones(n)])
    residual = design @ w - y
    return 2 * design.T.dot(residual) / n

# Synthetic linear-regression data: y = [X | 1] @ true_w + N(0, 1) noise.
np.random.seed(42)
nrow, ncol = 500, 4
X = normal(0, 1, ncol * nrow).reshape(nrow, ncol)
true_w = np.array([2, -3, 1, 0.5, 4])  # last entry multiplies the ones column (intercept)
assert true_w.size == ncol + 1, "true_w needs one weight per feature plus an intercept"
y = np.hstack([X, np.ones((nrow, 1))]).dot(true_w) + normal(0, 1, nrow)
w_start = normal(0, 1, ncol + 1)

# Pass a copy so w_start is guaranteed to keep the initial point, whether or not
# the optimiser modifies its `start` argument in place.
sgd_rmsprop_res = SGD_RMSProp(
    start=w_start.copy(), X=X, y=y, L=L, L_grad=L_grad, batch_size=100
)

In [35]:
# Report how many steps were taken and the squared distance to the true weights.
rmsprop_n_iter = sgd_rmsprop_res["iterations"]
rmsprop_sq_err = norm(sgd_rmsprop_res["point"] - true_w) ** 2
print(f'Iterations: {rmsprop_n_iter}')
print(f'||w_e-w_t||^2 = {rmsprop_sq_err}')

Iterations: 1000
||w_e-w_t||^2 = 0.008405491722783303
