In [1]:
import random
import numpy as np
import matplotlib.pyplot as plt

In [2]:
# warnings are not important :)
import warnings
warnings.filterwarnings('ignore')

In [3]:
# Stretch the notebook container to the full browser width so wide outputs fit.
# NOTE(review): `.container` targets the classic Notebook UI; likely a no-op in JupyterLab.
from IPython.display import display, HTML

_FULL_WIDTH_CSS = "<style>.container { width:100% !important; }</style>"
display(HTML(_FULL_WIDTH_CSS))

In [4]:
# import graph module
import sys
sys.path.append("../../")
from tools.graph_tools import *
from oracles.minimization import *
from methods.gradient_tracking import *

In [5]:
def calc_delta_Fx(oracles, x):
    """Stack the local gradients of all nodes into one matrix.

    Node ``i``'s oracle is evaluated at row ``i`` of ``x``, so row ``i`` of the
    result is ``oracles[i].grad(x[i])``.

    Parameters
    ----------
    oracles : sequence
        One oracle per node, each exposing ``grad(point)``.
    x : array_like, 2-D
        Per-node iterates, one row per oracle (``np.ndarray`` or ``np.matrix``).
        Every column is passed to the oracle; the dimension is not hard-coded.

    Returns
    -------
    np.matrix
        The gradients stacked row-wise (``np.matrix`` kept, as in the original
        interface).
    """
    # np.asarray turns an np.matrix into a plain 2-D array so rows[i] is a 1-D
    # vector; np.array(...) copies it so an oracle cannot mutate x through a view.
    rows = np.asarray(x)
    return np.matrix([oracle.grad(np.array(rows[i])) for i, oracle in enumerate(oracles)])

def calc_F(oracles, x):
    """Evaluate every node's local objective at its own iterate.

    Parameters
    ----------
    oracles : sequence
        One oracle per node, each exposing ``func(point)``.
    x : array_like, 2-D
        Per-node iterates, one row per oracle (``np.ndarray`` or ``np.matrix``).
        Every column is passed to the oracle; the dimension is not hard-coded.

    Returns
    -------
    np.matrix
        Built from the list ``[oracles[i].func(x[i]) for each i]``; for scalar
        ``func`` values this is a single row of shape ``(1, len(oracles))``.
    """
    # Convert np.matrix to ndarray so rows[i] is 1-D; copy to avoid aliasing x.
    rows = np.asarray(x)
    return np.matrix([oracle.func(np.array(rows[i])) for i, oracle in enumerate(oracles)])

def calc_error(oracles, x_curr, x_prev):
    """Mean absolute change of the per-node objective values between two iterates.

    Evaluates ``calc_F`` at ``x_curr`` and at ``x_prev``, takes the element-wise
    absolute difference, sums it, and divides by the number of oracles.
    """
    node_count = len(oracles)
    objective_change = calc_F(oracles, x_curr) - calc_F(oracles, x_prev)
    total_change = np.sum(abs(objective_change))
    return total_change / node_count

def _calc_error(oracles, x_curr, x_prev):
    return np.sum( abs(x_curr - np.array([1, 2])) ) / len(oracles)

In [6]:
# Synthetic problem: N nodes, each with a random 2x2 system A_i and b_i = A_i @ x,
# so the shared ground truth x solves every local system exactly.
# NOTE(review): with regcoef > 0 the local minimisers presumably sit slightly
# off x; confirm against oracles.minimization.LinearRegressionL2Oracle.

# Fix the RNG so Restart & Run All reproduces the data, the random graph below
# and the printed results. Outputs saved from an unseeded run will differ.
# `random` is seeded too in case tools.graph_tools draws from it (unverified).
SEED = 0
random.seed(SEED)
np.random.seed(SEED)

N = 40
x = np.array([1, 2])
oracles = []
for i in range(N):
    A = np.random.random((2, 2))
    b = A.dot(x)
    oracles.append(LinearRegressionL2Oracle(A, b, regcoef=0.01))

In [7]:
# Random communication graph over the N nodes, then filled with weights by
# fill_metropolis_weigts (presumably Metropolis-Hastings mixing weights).
# NOTE(review): 0.2 is presumably the edge probability; confirm in tools.graph_tools.
adjacency = make_random_graph_matrix(N, 0.2)
W = fill_metropolis_weigts(adjacency)
# Optional visualisation of the graph:
# make_graph_img(W, fig_size=(20, 20))

In [8]:
# Hyper-parameters forwarded to gradient_tracking; their exact roles are defined
# in methods.gradient_tracking (alpha is presumably the step size).
alpha, theta, mu = 0.01, 1, 0

# Passed as `err`; presumably the tolerance the calc_error value is compared to.
err = 0.01

In [9]:
# Every node starts from the same point [0.5, 1]: one row per node, shape (N, 2).
x0 = np.tile(np.array([0.5, 1.0]), (N, 1))

In [10]:
# Run gradient tracking. The same oracle list is passed as both F and f, and
# the distance-to-known-solution metric _calc_error is used as calc_error.
# At most 3000 iterations; logging disabled.
res = gradient_tracking(
    F=oracles,
    f=oracles,
    calc_delta_Fx=calc_delta_Fx,
    calc_error=_calc_error,
    W=W,
    x0=x0,
    alpha=alpha,
    theta=theta,
    mu=mu,
    err=err,
    max_iter=3000,
    need_log=False,
)

In [11]:
# Print the objective value of every node's final estimate (row i of `res`).
# NOTE(review): every row is scored with oracles[0], not oracles[i]. This is
# presumably a consensus check (all rows should give near-identical values, as
# the recorded output shows); confirm that this is the intent.
res_rows = np.asarray(res)
for i in range(N):
    node_estimate = res_rows[i, [0, 1]]  # fancy indexing: 1-D copy of the first two coordinates
    print('i:', i, "func:", oracles[0].func(node_estimate))

i: 0 func: 0.02373243222999908
i: 1 func: 0.023732432214340727
i: 2 func: 0.023732432223005535
i: 3 func: 0.023732432242692895
i: 4 func: 0.02373243222510876
i: 5 func: 0.02373243221735188
i: 6 func: 0.023732432221662488
i: 7 func: 0.02373243224010466
i: 8 func: 0.02373243222079018
i: 9 func: 0.023732432238224178
i: 10 func: 0.023732432230041435
i: 11 func: 0.023732432218155068
i: 12 func: 0.023732432225688552
i: 13 func: 0.02373243221836101
i: 14 func: 0.02373243223963679
i: 15 func: 0.023732432234557264
i: 16 func: 0.023732432229324835
i: 17 func: 0.023732432217936885
i: 18 func: 0.02373243221106548
i: 19 func: 0.02373243222548318
i: 20 func: 0.023732432216252007
i: 21 func: 0.023732432216473337
i: 22 func: 0.023732432230964075
i: 23 func: 0.023732432260927132
i: 24 func: 0.023732432223350374
i: 25 func: 0.023732432232654848
i: 26 func: 0.02373243222504053
i: 27 func: 0.023732432229250075
i: 28 func: 0.023732432217397768
i: 29 func: 0.023732432235328213
i: 30 func: 0.0237324322290741