In [None]:
import matplotlib.pyplot as plt
import numpy as np
import math
import pandas as pd
import random

import plotly.express as px
import plotly.graph_objects as go

import sys; sys.path.append('../src/')
import one_dim_search.dichotomy as dichot
import one_dim_search.fib as fib
import one_dim_search.linear as lin
import one_dim_search.golden as gold
from descent.grad import gradient_descent_iter, get_constant_step_chooser

In [None]:
def gen_matrix(*, cond: float, n: int):
    """Generate the diagonal entries of a positive diagonal matrix.

    The first entry is fixed at 1 and the second at ``cond``, so the ratio of
    the largest to the smallest entry is exactly ``cond``.  The remaining
    ``n - 2`` entries are drawn with ``random.uniform(1, cond)`` (global
    ``random`` module state, so seed it for reproducible results).

    Args:
        cond: Requested ratio max/min of the diagonal entries; must be >= 1.
        n: Number of diagonal entries; must be >= 2 because both extreme
            values are always present in the result.

    Returns:
        A list of ``n`` numbers: ``[1, cond, ...random fill...]``.

    Raises:
        ValueError: If ``cond`` is not >= 1 (including NaN) or ``n < 2``.
    """
    # `not cond >= 1` (rather than `cond < 1`) also rejects NaN.
    if not cond >= 1:
        raise ValueError(f"cond must be >= 1, got {cond!r}")
    # Both extremes are always emitted, so fewer than two entries cannot be
    # honoured; silently returning two would mismatch a length-n vector.
    if n < 2:
        raise ValueError(f"n must be >= 2, got {n!r}")

    min_v = 1
    max_v = cond
    diag = [min_v, max_v]
    while len(diag) < n:
        diag.append(random.uniform(min_v, max_v))
    return diag
    
def create_sq_fun(diag):
    """Build the quadratic form ``f(x) = sum(diag[i] * x[i] ** 2)``.

    Args:
        diag: Sequence or array of coefficients.  It is converted to a numpy
            array once here, so a list passed in is snapshotted at creation
            time instead of being re-converted on every call.

    Returns:
        A function ``f(arg)`` that accepts any array-like ``arg``
        (ndarray, list, tuple) and returns the scalar value of the form.
    """
    weights = np.asarray(diag)

    def f(arg):
        # asarray lets plain lists/tuples through; `list ** 2` would raise.
        return np.sum(weights * np.asarray(arg) ** 2)
    return f
    
def create_sq_fun_grad(diag):
    """Build the gradient of ``sum(diag[i] * x[i] ** 2)``.

    Args:
        diag: Sequence or array of coefficients.  It is converted to a numpy
            array once here, so a list passed in is snapshotted at creation
            time instead of being re-converted on every call.

    Returns:
        A function ``f(arg)`` that accepts any array-like ``arg`` and returns
        the array ``diag * arg * 2`` computed element-wise.
    """
    weights = np.asarray(diag)

    def f(arg):
        # Without asarray a list `diag` times a list `arg` raises, and a list
        # `diag` times an int `arg` silently repeats the list.
        return weights * np.asarray(arg) * 2
    return f

In [None]:
# Small sample problem at notebook scope: a 3-entry diagonal produced by
# gen_matrix(cond=5, n=3), plus the function and gradient closures built on it.
d = gen_matrix(cond=5, n=3)
f = create_sq_fun(d)
f_grad = create_sq_fun_grad(d)

In [None]:
def analyze(step_chooser, *, conds=None, dims=(5, 10, 20, 50, 100), eps=1e-3):
    """Plot gradient-descent iteration counts against the condition number.

    For every ``(cond, n)`` pair a fresh diagonal is generated with
    ``gen_matrix``, gradient descent is started from ``np.random.randn(n)``,
    and the number of items yielded by ``gradient_descent_iter`` is recorded.
    The result is shown as a plotly line chart with one line per ``n``.

    Args:
        step_chooser: Passed through unchanged to ``gradient_descent_iter``.
        conds: Iterable of condition numbers to sweep.  Defaults to
            ``np.linspace(1, 1000, 50)``.
        dims: Iterable of problem dimensions ``n``.
        eps: Passed through as ``eps`` to ``gradient_descent_iter``.

    Returns:
        None.  The figure is displayed via ``fig.show()``.
    """
    if conds is None:
        conds = np.linspace(1, 1000, 50)

    records = []
    for cond in conds:
        for n in dims:
            diag_m = gen_matrix(cond=cond, n=n)
            iterations = gradient_descent_iter(
                f=create_sq_fun(diag_m),
                f_grad=create_sq_fun_grad(diag_m),
                eps=eps,
                start=np.random.randn(n),
                step_chooser=step_chooser,
                # Forwarded as in the original call; its meaning is defined by
                # gradient_descent_iter, which is not visible here.
                _verbose=100000,
            )
            records.append({
                'cond': cond,
                'n': n,
                # Exhaust the iterator and count how many items it yielded.
                'cnt': sum(1 for _ in iterations),
            })

    fig = px.line(records, x='cond', y='cnt', color='n')
    fig.show()

In [None]:
# Sweep using the step chooser built by get_constant_step_chooser(1e-4)
# (presumably a fixed step size of 1e-4 -- confirm in descent.grad).
analyze(get_constant_step_chooser(1e-4))

In [None]:
def generic_step_chooser(one_dim_search):
    """Wrap a one-dimensional search routine into a step chooser.

    Args:
        one_dim_search: Callable invoked as
            ``one_dim_search(left, right, f=..., eps=1e-2)`` that returns a
            pair of numbers.

    Returns:
        A function ``step_chooser(f, x_k, cur_grad)`` that returns the
        midpoint of the pair produced by ``one_dim_search``.
    """
    def step_chooser(f, x_k, cur_grad):
        def along_ray(h):
            # Objective evaluated at the point reached with step h.
            return f(x_k - h * cur_grad)

        # lin.search supplies the initial pair handed to one_dim_search
        # (presumably a bracketing interval -- confirm in one_dim_search.linear).
        left, right = lin.search(0, delta=1e-5, f=along_ray, eps=1e-5, multiplier=2)
        lower, upper = one_dim_search(left, right, f=along_ray, eps=1e-2)
        return (lower + upper) / 2

    return step_chooser

# Same sweep, with the step chosen by gold.search through generic_step_chooser
# (presumably golden-section search -- confirm in one_dim_search.golden).
analyze(generic_step_chooser(gold.search))