In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import numpy as np
import scipy as sp
import math
import copy
from matplotlib import pyplot as plt
from dataclasses import dataclass
from tqdm import tqdm

from test_sampler import TestProblem, TestProblem2

In [None]:
def nesterov(x_0, grad, L, mu, K):
    """Nesterov's accelerated gradient method with a constant momentum coefficient.

    Every iteration takes a gradient step of size 1/L from the extrapolated
    point y, then extrapolates with momentum
    (sqrt(L) - sqrt(mu)) / (sqrt(L) + sqrt(mu)).

    Args:
        x_0: starting point, shared by both sequences.
        grad: callable invoked as ``grad(y, L)``; presumably returns the gradient
            at ``y`` (TODO confirm why ``L`` is passed through).
        L: smoothness constant; the step size is 1/L.
        mu: strong-convexity constant used in the momentum coefficient.
        K: number of iterations.

    Returns:
        (x_list, y_list): the two iterate sequences, each of length K + 1 and
        starting with ``x_0``.
    """
    # The momentum coefficient does not change between iterations.
    momentum = (np.sqrt(L) - np.sqrt(mu)) / (np.sqrt(L) + np.sqrt(mu))
    xs, ys = [x_0], [x_0]
    for _ in range(K):
        x_prev, y_prev = xs[-1], ys[-1]
        x_next = y_prev - (1 / L) * grad(y_prev, L)
        xs.append(x_next)
        ys.append(x_next + momentum * (x_next - x_prev))
    return xs, ys



In [None]:
# Benchmark instance; the commented line is the alternative problem.
# test_problem = TestProblem(gamma=1e5)
test_problem = TestProblem2(La=1000, Lb=10)

# Block minimizers of x^T A x / 2 + a^T x (the quadratic `f` below): x* = -A^+ a.
# The pseudo-inverse is used; the np.linalg.solve alternative was noted as less stable:
#   x_star = np.linalg.solve(test_problem.A, -test_problem.a)
#   y_star = np.linalg.solve(test_problem.B, -test_problem.b)
x1_star, x2_star = (
    -np.linalg.pinv(test_problem.A) @ test_problem.a,
    -np.linalg.pinv(test_problem.B) @ test_problem.b,
)

# Reference optimum: the first element returned by `calc`.
f_star = test_problem.calc(x1_star, x2_star)[0]
def f(x1, x2):
    """Two-block quadratic objective computed directly from ``test_problem``'s data.

    Returns x1^T A x1 / 2 + a^T x1 + x2^T B x2 / 2 + b^T x2, used to cross-check
    the value reported by ``test_problem.calc``.
    """
    A, a = test_problem.A, test_problem.a
    B, b = test_problem.B, test_problem.b
    # Accumulate left to right so the floating-point summation order is unchanged.
    value = x1.T @ A @ x1 / 2 + a @ x1
    value = value + x2.T @ B @ x2 / 2
    value = value + b @ x2
    return value

# Sanity check: presumably these agree if `calc` evaluates the same quadratic.
f_star, f(x1_star, x2_star)

In [None]:
# ACRCD
# Per-iteration logs appended to by ACRCD: objective value and block-gradient norms.
history, grad_x_norms, grad_y_norms = [], [], []

# Naming: y (paper) = q (code).

def ACRCD(x_0, y_0, K):
    """Accelerated Random Coordinate Descent (ACRCD) over two blocks of variables.

    Minimizes the module-level ``test_problem`` using the fixed block constants
    ``L1 = test_problem.La`` and ``L2 = test_problem.Lb``. At every step one block
    is sampled with probability proportional to ``L_i ** beta`` (beta = 1/2).

    Side effects: once per iteration, appends the first value returned by
    ``test_problem.calc`` at the extrapolation point, and the norms of both block
    gradients, to the module-level lists ``history``, ``grad_x_norms`` and
    ``grad_y_norms``.

    Args:
        x_0: starting point of the first block.
        y_0: starting point of the second block.
        K: number of iterations.

    Returns:
        (x_list, y_list): first/second blocks of the extrapolation points
        ``tau * z + (1 - tau) * q`` at which gradients are evaluated. Each list
        starts with the initial point and has length K + 1.
    """
    # Naming: q (code) = y (paper), z (code) = z (paper).
    x_list = [x_0]
    y_list = [y_0]

    z1_cur, z2_cur = x_0, y_0
    q1_cur, q2_cur = x_0, y_0

    L1 = test_problem.La
    L2 = test_problem.Lb
    beta = 1 / 2

    n_ = L1 ** beta + L2 ** beta
    # L1 and L2 are fixed here, so the sampling distribution is loop-invariant.
    probs = [L1 ** beta / n_, L2 ** beta / n_]

    for i in tqdm(range(K)):
        alpha = (i + 2) / (2 * n_ ** 2)
        tau = 2 / (i + 2)

        x1_upd = tau * z1_cur + (1 - tau) * q1_cur
        x2_upd = tau * z2_cur + (1 - tau) * q2_cur

        result, grad_x, grad_y = test_problem.calc(x1_upd, x2_upd)
        history.append(result.item())
        grad_x_norms.append(np.linalg.norm(grad_x))
        grad_y_norms.append(np.linalg.norm(grad_y))

        index_p = np.random.choice([0, 1], p=probs)

        # The new q is a coordinate gradient step *from x*: only the sampled block
        # moves and the other block is copied from x, not kept from the old q.
        # ACRCD_star makes the same choice (`y2 = x2` / `y1 = x1`). This is what the
        # descent step f(q_new) <= f(x) - ||g_i||^2 / (2 L_i) requires.
        if index_p == 0:
            q1_cur = x1_upd - (1 / L1) * grad_x
            q2_cur = x2_upd
            z1_cur = z1_cur - (1 / L1) * alpha * n_ * grad_x
        else:
            q1_cur = x1_upd
            q2_cur = x2_upd - (1 / L2) * grad_y
            z2_cur = z2_cur - (1 / L2) * alpha * n_ * grad_y

        x_list.append(x1_upd)
        y_list.append(x2_upd)

    return x_list, y_list

# Start ACRCD from the origin; the commented lines are a random-start alternative.
# x0 = np.random.random(test_problem.na)
# y0 = np.random.random(test_problem.nb)
x0 = np.zeros(test_problem.na)
y0 = np.zeros(test_problem.nb)

x_list_ACRCD, y_list_ACRCD = ACRCD(x0, y0, 20000)

# Convergence curves of ACRCD: objective gap and block-gradient norms per iteration.
fig, ax = plt.subplots()
ax.plot(torch.tensor(history) - f_star, label='func')
ax.plot(torch.tensor(grad_x_norms), label='x grad norm')
ax.plot(torch.tensor(grad_y_norms), label='y grad norm')
ax.set_yscale("log")
ax.set(title="ACRCD convergence", xlabel="iteration", ylabel="value (log scale)")
ax.legend()
plt.show()

In [None]:
# Objective value (and block gradients) at the last ACRCD iterate.
res_x, *gradients_x = test_problem.calc(
    x_list_ACRCD[-1],
    y_list_ACRCD[-1],
)
res_x

In [None]:
# Fresh per-iteration logs for ACRCD_star (objective value and block-gradient norms).
history, grad_x1_norms, grad_x2_norms = [], [], []

# Naming: y (paper) = q (code).
# Trial-step counters per block, incremented by ACRCD_star.
count_one = count_two = 0

def ACRCD_star(x1_0, x2_0, K):
    """Adaptive ACRCD: estimates the two block Lipschitz constants on the fly.

    Starts from the guess L1 = L2 = 5000. Each iteration:
      1. samples a block with probability proportional to L_i ** beta;
      2. halves that block's estimate;
      3. doubles it again (at most 100 tries) until the sufficient-decrease test
             ||g_i||^2 / (2 L_i) <= f(x) - f(y) + ADAPTIVE_DELTA
         holds for the coordinate step y = x - g_i / L_i.

    Here f is the first value returned by ``test_problem.calc``.

    Side effects: appends per-iteration f(x) and both block-gradient norms to the
    module-level ``history``, ``grad_x1_norms`` and ``grad_x2_norms`` lists. It also
    increments the global counters ``count_one`` / ``count_two`` once per trial
    step on block 1 / block 2.

    Args:
        x1_0, x2_0: starting points of the two blocks.
        K: number of iterations.

    Returns:
        (x1_list, x2_list, [L1, L2]): block components of the extrapolation points
        (each list starts with the initial point) and the final Lipschitz estimates.
    """
    global count_one, count_two
    # Slack added to the sufficient-decrease test.
    ADAPTIVE_DELTA = 1e-8

    x1_list = [x1_0]
    x2_list = [x2_0]

    z1 = y1 = x1_0
    z2 = y2 = x2_0

    # Initial guess for both block Lipschitz constants; adapted below.
    L1 = L2 = 5000
    beta = 1 / 2

    for i in tqdm(range(K)):
        tau = 2 / (i + 2)

        # Extrapolation point between the z and y sequences.
        x1 = tau * z1 + (1 - tau) * y1
        x2 = tau * z2 + (1 - tau) * y2

#         result, grad_x1, grad_x2 = test_problem.calc(x1, x2)

        res_x, *gradients_x = test_problem.calc(x1, x2) # moved out of the inner loop
        history.append(res_x.item())
        grad_x1_norms.append(np.linalg.norm(gradients_x[0]).item())
        grad_x2_norms.append(np.linalg.norm(gradients_x[1]).item())

        # Sampling uses the estimates from the previous iteration.
        n_ = L1 ** beta + L2 ** beta
        index_p = np.random.choice([0, 1], p=[L1 ** beta / n_,
                                              L2 ** beta / n_])
        Ls = [L1, L2]
        # Optimistically halve the sampled block's estimate before backtracking.
        Ls[index_p] /= 2

        # ADAPTIVE

        inequal_is_true = False
        xs = [x1, x2]
        sampled_gradient_x = gradients_x[index_p]
        # while not inequal_is_true:
        # Backtracking: at most 100 trial steps, doubling Ls[index_p] after each failure.
        for j in range(100):
            # Coordinate step: only the sampled block moves; the other block is copied from x.
            if index_p == 0:
                count_one += 1
                y1 = xs[index_p] - 1 / Ls[index_p] * sampled_gradient_x
                y2 = x2
            else:
                count_two += 1
                y2 = xs[index_p] - 1 / Ls[index_p] * sampled_gradient_x
                y1 = x1

            res_y, *_ = test_problem.calc(y1, y2)

            # Sufficient decrease: ||g_i||^2 / (2 L_i) <= f(x) - f(y) + delta.
            inequal_is_true = 1 / (2 * Ls[index_p]) * np.linalg.norm(sampled_gradient_x) ** 2 <= res_x - res_y + ADAPTIVE_DELTA
#             y_minus_x = ([y1, y2][index_p] - xs[index_p])
#             inequal_is_true = (res_y - res_x - gradients_x[index_p] @ y_minus_x 
#                                 <= Ls[index_p] * (y_minus_x ** 2).sum() / 2 + ADAPTIVE_DELTA)
            if inequal_is_true: break
            # NOTE(review): if all 100 tries fail, the stored estimate ends up one doubling
            # above the value used for the final y.
            Ls[index_p] *= 2
#             if Ls[index_p] > 4 * [test_problem.La, test_problem.Lb][index_p]:
#                 print(i, j, index_p, Ls)
#                 print(res_y, res_x)

        # alpha uses the post-adaptation estimates, unlike the sampling probabilities above.
        L1, L2 = Ls
        n_ = L1 ** beta + L2 ** beta
        alpha = (i + 2) / (2 * n_ ** 2)

        # Only the sampled block of z moves.
        if index_p == 0:
            z1 = z1 - (1 / L1) * alpha * n_ * sampled_gradient_x

        if index_p == 1:
            z2 = z2 - (1 / L2) * alpha * n_ * sampled_gradient_x

        # The last trial y1, y2 carry over as the y sequence for the next iteration.
        x1_list.append(x1)
        x2_list.append(x2)

    return x1_list, x2_list, [L1, L2]


# Uncomment to make the random start and the block sampling reproducible:
# np.random.seed(228)
x1_list_ACRCD, x2_list_ACRCD, Ls = ACRCD_star(
    np.random.random(test_problem.na),  # random start, block 1
    np.random.random(test_problem.nb),  # random start, block 2
    20000,
)


In [None]:
# Convergence curves of ACRCD_star: objective gap and block-gradient norms per iteration.
fig, ax = plt.subplots()
ax.plot(torch.tensor(history) - f_star, label='func')
ax.plot(torch.tensor(grad_x1_norms), label='x1 grad norm')
ax.plot(torch.tensor(grad_x2_norms), label='x2 grad norm')
ax.set_yscale("log")
ax.set(title="ACRCD_star convergence", xlabel="iteration", ylabel="value (log scale)")
ax.legend()
plt.show()

In [None]:
# Objective value (and block gradients) at the last ACRCD_star iterate.
# Uses the ACRCD_star outputs; x_list_ACRCD / y_list_ACRCD hold the plain ACRCD run.
res_x, *gradients_x = test_problem.calc(x1_list_ACRCD[-1], x2_list_ACRCD[-1])
res_x

In [None]:
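# Recorded final Lipschitz estimates Ls = [L1, L2] (presumably from ACRCD_star runs on TestProblem):
# gamma=1e5  -> [1250.0, 10000.0]
# gamma=1e-5 -> [2500.0, 10000.0]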
