In [None]:
# Enable IPython's autoreload extension in mode 2, so that imported modules
# (here: the local test_sampler / optimize_algos files) are re-imported before
# each cell execution and edits to them are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import numpy as np
import scipy as sp
import math
import copy
from matplotlib import pyplot as plt
from dataclasses import dataclass
from tqdm import tqdm

from test_sampler import TestProblem, TestProblem2
from optimize_algos import test_algo_by_problem, ACRCD, ACRCD_star, just_USTM

## TEST ACRCD

In [None]:
# First test problem: 10 x-coordinates and 200 y-coordinates.
# NOTE(review): no RNG seed is set in the visible cells; if TestProblem draws
# random data, the numbers below change between runs -- confirm.
problem_1 = TestProblem(
    x_dim=10,
    y_dim=200,
    gamma=1e-3,
    La=1e-1,
    svxpy_verbose=True,
)
# Show the problem's Lx / Ly attributes (presumably per-block smoothness
# constants -- confirm in test_sampler).
print(f"{problem_1.Lx=}, {problem_1.Ly=}")

In [None]:
# Sanity check: evaluate problem_1 at the all-zeros point of both blocks.
# (An all-ones x, i.e. np.ones(problem_1.x_dim), was the alternative tried here.)
x, y = np.zeros(problem_1.x_dim), np.zeros(problem_1.y_dim)
problem_1.calc(x, y)

In [None]:
# Run ACRCD on problem_1 for 3000 iterations, initialising L1/L2 with the
# problem's own Lx/Ly values.
# Variants tried earlier in this cell: ACRCD_star with the same L values and
# k=500, and ACRCD with L1_init=50 instead of problem_1.Lx.
test_algo_by_problem(
    test_problem=problem_1,
    algo_func=ACRCD,
    L1_init=problem_1.Lx,
    L2_init=problem_1.Ly,
    k=int(3e3),
)

## TEST USTM

In [None]:
# Run USTM on the first test problem and report convergence diagnostics.
test_problem = problem_1

# Starting point: all-ones vector of length x_dim + y_dim (both blocks stacked
# into a single vector).
x0 = np.ones(test_problem.x_dim + test_problem.y_dim)
t_history, value_log, grad_history, A_log, (start_L, L_value) = just_USTM(
    test_problem,
    x0,
    eps_abs=1e-7,
    max_iter=3000,
)

# Re-evaluate the objective and its gradient at the last recorded iterate.
res_f, grad = test_problem.calc_by_one_block(t_history[-1])

print("start f val: ", value_log[0])
print("result val: ", res_f)
print("grad norm: ", np.linalg.norm(grad))
print("solver/analytic f*: ", test_problem.f_star)
print("start, end L: ", start_L, L_value)

# Gradient-norm trace on a log scale.
# NOTE(review): the x-axis is the index into grad_history; presumably one
# entry per USTM iteration -- confirm in optimize_algos.
fig_grad, ax_grad = plt.subplots()
ax_grad.plot(grad_history, label='grad norm')
ax_grad.set_yscale("log")
ax_grad.set_xlabel("history index")
ax_grad.set_ylabel("gradient norm")
ax_grad.set_title("USTM: gradient norm")
ax_grad.legend()
plt.show()

# Objective gap to the reference optimum f* on a log scale.
# Non-positive gaps (value at or below f*) cannot be drawn on a log axis and
# will simply be missing from the curve.
fig_gap, ax_gap = plt.subplots()
ax_gap.plot(np.array(value_log)-test_problem.f_star, label="func_value - f*")
ax_gap.set_yscale("log")
ax_gap.set_xlabel("history index")
ax_gap.set_ylabel("f - f*")
ax_gap.set_title("USTM: objective gap")
ax_gap.legend()
plt.show()

# Вторая задача (решается аналитически)
## ACRCD - правильные L

In [None]:
# Second test problem (described in the section heading as analytically
# solvable), built with La=100 and Lb=10; these attributes are reused below
# as the L1_init / L2_init arguments of the ACRCD runs.
problem_2 = TestProblem2(La=100, Lb=10)

In [None]:
# ACRCD on problem_2, initialised with the problem's own La / Lb values.
test_algo_by_problem(
    test_problem=problem_2,
    algo_func=ACRCD,
    L1_init=problem_2.La,
    L2_init=problem_2.Lb,
)

## ACRCD - неправильные L

In [None]:
# ACRCD on problem_2 with a deliberately mis-specified L1: half of La,
# while L2 keeps the problem's Lb.
test_algo_by_problem(
    test_problem=problem_2,
    algo_func=ACRCD,
    L1_init=problem_2.La / 2,
    L2_init=problem_2.Lb,
)

## ACRCD* - неправильные L

In [None]:
# ACRCD_star on problem_2 with the same mis-specified start (L1 = La / 2,
# L2 = Lb), for comparison with the plain ACRCD run.
test_algo_by_problem(
    test_problem=problem_2,
    algo_func=ACRCD_star,
    L1_init=problem_2.La / 2,
    L2_init=problem_2.Lb,
)

# Первая задача (решается солвером)
## ACRCD и ACRCD*

In [None]:
# ACRCD on problem_1 with hand-picked initial constants L1 = L2 = 100
# instead of the problem's own Lx / Ly.
test_algo_by_problem(
    test_problem=problem_1,
    algo_func=ACRCD,
    L1_init=100,
    L2_init=100,
)

In [None]:
# ACRCD_star on problem_1 with hand-picked initial constants L1 = L2 = 100
# instead of the problem's own Lx / Ly.
test_algo_by_problem(
    test_problem=problem_1,
    algo_func=ACRCD_star,
    L1_init=100,
    L2_init=100,
)