In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
import scipy
import matplotlib.pyplot as plt
from pathlib import Path
from read_data import get_network_df, get_corrs, get_full_df
from numba.core import types
import graph_tool
import numba

from numba import njit
from graph_tool.topology import shortest_distance
import pandas as pd
import networkx as nx
import graph_tool as gt
from graph_tool.topology import shortest_distance
import numpy as np
import typing

import numba
from numba.core import types


from sinkhorn import Sinkhorn
from transport_problem import OptimParams, DualOracle, HyperParams

In [None]:
import networkx as nx
import graph_tool as gt
from graph_tool.topology import shortest_distance
import numpy as np
import typing

import numba
from numba.core import types
from tqdm import tqdm

from transport_problem import OptimParams, DualOracle, HyperParams


# graph = None  # TODO: build the graph
# oracle = None  # TODO: create the oracle
# sources, targets = None, None  # TODO: define sources and targets
# oracle_stacker = OracleSinkhornStacker(oracle, graph, sources, targets, l, w, params)


class OracleSinkhornStacker:
    """Callable wrapper that evaluates the dual oracle as a function of ``t`` only.

    On every call the ``la`` and ``mu`` entries of ``optim_params`` are overwritten
    with whatever ``Sinkhorn.run`` returns, so the outer optimiser only has to
    supply the ``t`` block (length ``T_LEN``).
    """

    def __init__(self, oracle: DualOracle, graph, sources, targets, l, w, params,
                 sinkhorn_max_iter: int = 100, verbose: bool = True):
        """
        :param oracle: dual oracle; must expose ``edges_num``, ``zones_num`` and ``t_bar``
        :param graph: graph object forwarded unchanged to ``oracle.get_T_and_predmaps``
        :param sources: source identifiers forwarded to the oracle
        :param targets: target identifiers forwarded to the oracle
        :param l: first positional argument of ``Sinkhorn`` (the caller in this notebook
            passes the row sums of the correspondence matrix)
        :param w: second positional argument of ``Sinkhorn`` (column sums in this notebook)
        :param params: hyper-parameters; only ``params.gamma`` is read by this class
        :param sinkhorn_max_iter: ``max_iter`` passed to ``Sinkhorn`` (100 as before)
        :param verbose: print diagnostic norms on every call (``True`` keeps the
            previous behaviour)
        """
        self.oracle = oracle
        self.graph = graph
        self.sources = sources
        self.targets = targets

        self.T_LEN = oracle.edges_num
        self.LA_LEN = oracle.zones_num
        self.MU_LEN = oracle.zones_num

        # Size of the full stacked vector [t, la, mu]. Note that __call__ itself
        # accepts only the t block, i.e. a vector of length T_LEN.
        self.parameters_vector_size = self.T_LEN + self.LA_LEN + self.MU_LEN

        self.t = oracle.t_bar.copy()
        self.la = np.zeros(oracle.zones_num)
        self.mu = np.zeros(oracle.zones_num)
        self.optim_params = OptimParams(self.t, self.la, self.mu)

        self.sinkhorn = Sinkhorn(l, w, max_iter=sinkhorn_max_iter)
        self.params = params
        self.verbose = verbose

    def __call__(self, vars_block, *args, **kwargs):
        """
        Evaluate the dual function and its gradient with respect to ``t``.

        :param vars_block: the ``t`` block only; its length must equal ``T_LEN``
        :return:
        dual_value - value returned by ``oracle.calc_F_via_d`` for the current t, la, mu
        full_grad - gradient with respect to ``t`` only (``oracle.grad_dF_dt``)
        flows_on_shortest - result of ``oracle.get_flows_on_shortest`` for the current ``d``

        Side effects: updates ``self.optim_params`` (t, la, mu), ``self.d`` and ``self.flows``.
        """
        # Instances built without __init__ (e.g. via __new__) default to verbose output.
        verbose = getattr(self, "verbose", True)

        if verbose:
            print("vars block grad: ", np.linalg.norm(vars_block))
        assert len(vars_block) == self.T_LEN, (
            "vars_block must hold only the t block: expected length %d, got %d"
            % (self.T_LEN, len(vars_block))
        )

        self.optim_params.t = vars_block
        if verbose:
            print("t in optim params grad: ", np.linalg.norm(self.optim_params.t), np.linalg.norm(vars_block[:self.T_LEN]))
            print("la in optim params grad: ", np.linalg.norm(self.optim_params.la))
            print("mu in optim params grad: ", np.linalg.norm(self.optim_params.mu))

        T, pred_maps = self.oracle.get_T_and_predmaps(self.graph, self.optim_params, self.sources, self.targets)
        if verbose:
            print("norm T: ", np.linalg.norm(T))

        # Sinkhorn supplies d together with la and mu, so no gradient step over la/mu is
        # taken here; the la/mu gradients that used to be computed were never used.
        self.d, self.optim_params.la, self.optim_params.mu = self.sinkhorn.run(T / self.params.gamma)

        # Computed once and shared: the same object is stored for get_prime_value() and
        # returned to the caller (previously the identical call was issued twice).
        flows_on_shortest = self.oracle.get_flows_on_shortest(self.sources, self.targets, self.d, pred_maps)
        self.flows = flows_on_shortest

        full_grad = self.oracle.grad_dF_dt(self.optim_params, flows_on_shortest)
        dual_value = self.oracle.calc_F_via_d(self.optim_params, self.d, T)

        return dual_value, full_grad, flows_on_shortest

    def get_prime_value(self):
        """Return ``oracle.prime`` for the flows and ``d`` stored by the latest ``__call__``.

        Requires at least one prior call of the instance; ``self.flows`` and ``self.d``
        do not exist before that.
        """
        return self.oracle.prime(self.flows, self.d)

    def get_init_vars_block(self):
        """Return a fresh copy of ``oracle.t_bar`` to be used as the starting point for ``t``."""
        return self.oracle.t_bar.copy()


# TODO: remove unused variables
def ustm_mincost_mcf(
        oracle_stacker: OracleSinkhornStacker,
        eps_abs: float,
        eps_cons_abs: float,
        max_iter: int = 10000,
        stop_by_crit: bool = True,
) -> tuple:
    """Accelerated primal-dual loop with backtracking on the smoothness estimate ``L_value``.

    Each outer iteration forms an extrapolated point ``y``, takes a step clipped from
    below by ``oracle.t_bar`` to get ``u``, mixes it into ``t`` and doubles ``L_value``
    until ``F(t) <= F(y) + <grad F(y), t - y> + L/2 * |t - y|^2 + eps_abs / 2`` holds.
    Flows and ``d`` observed at ``y`` are averaged with weights ``alpha / A``.
    NOTE(review): the name suggests the universal similar-triangles method - confirm.

    :param oracle_stacker: callable returning ``(dual_value, grad, flows)`` for a ``t``
        vector; must also expose ``d``, ``T_LEN``, ``optim_params``,
        ``get_init_vars_block()`` and ``oracle`` (with ``t_bar`` and ``prime``)
    :param eps_abs: tolerance used in the backtracking test and as the duality-gap threshold
    :param eps_cons_abs: threshold for the constraint violation; applied only when
        ``cons_log`` has entries, and nothing fills it at the moment
    :param max_iter: maximum number of outer iterations
    :param stop_by_crit: stop early once the stopping criteria are met
    :return: 12-tuple ``(t, flows_history, flows_averaged, d_history, d_avaraged,
        history_prime_values, history_dual_values, dgap_log, cons_log, A_log,
        history_dual_values, history_prime_values)``; the last two items repeat earlier
        ones and are kept so existing unpacking code keeps working. ``t`` is ``None``
        when ``max_iter`` is 0.
    """
    dgap_log = []
    cons_log = []  # stays empty: the constraint-violation measure is not computed yet
    A_log = []
    history_dual_values = []
    history_prime_values = []
    d_history = []
    flows_history = []

    oracle = oracle_stacker.oracle
    t_len = oracle_stacker.T_LEN

    A_prev = 0.0

    t_start = oracle_stacker.get_init_vars_block()
    # One array under three names; safe because none of them is modified in place.
    y_start = u_prev = t_prev = np.copy(t_start)

    grad_sum_prev = np.zeros(len(t_start))

    _, grad_y, flows_averaged = oracle_stacker(y_start)
    d_avaraged = oracle_stacker.d.copy()

    # Initial guess for the smoothness constant; refined by backtracking below.
    L_value = np.linalg.norm(grad_y) / 10

    # Pre-bound so that the return statement is valid even when max_iter == 0.
    A = u = t = y = None

    print("start optimizing")
    for _ in tqdm(range(max_iter)):
        while True:
            alpha = 0.5 / L_value + (0.25 / L_value ** 2 + A_prev / L_value) ** 0.5
            A = A_prev + alpha

            y = (alpha * u_prev + A_prev * t_prev) / A
            func_y, grad_y, flows_y = oracle_stacker(y)
            # Snapshot d at y now: the oracle call at t below overwrites oracle_stacker.d,
            # and the averaged d has to come from the same point as flows_y.
            d_y = oracle_stacker.d.copy()

            grad_sum = grad_sum_prev + alpha * grad_y

            u = y_start - grad_sum
            print("count values below t_bar in new t: ", (u[:t_len] < oracle.t_bar).sum())
            # Clip the step so that t never drops below t_bar.
            u[:t_len] = np.maximum(oracle.t_bar, u[:t_len])

            t = (alpha * u + A_prev * t_prev) / A
            func_t = oracle_stacker(t)[0]

            print("norm (t - y): ", np.linalg.norm(t - y))
            print("norm t: ", np.linalg.norm(oracle_stacker.optim_params.t))
            print("norm la: ", np.linalg.norm(oracle_stacker.optim_params.la))
            print("norm mu: ", np.linalg.norm(oracle_stacker.optim_params.mu))
            print()

            # Quadratic upper model at y plus the eps_abs / 2 slack.
            rvalue = (func_y + np.dot(grad_y, t - y) + 0.5 * L_value * np.sum((t - y) ** 2) +
                      0.5 * eps_abs)

            if func_t <= rvalue:
                break

            L_value *= 2
            assert L_value < np.inf

        history_dual_values.append(func_t)
        # Recorded before this iteration's averaging update, as before.
        history_prime_values.append(oracle.prime(flows_averaged, d_avaraged))

        A_prev = A
        L_value /= 2  # let the estimate shrink again for the next iteration

        t_prev = t
        u_prev = u
        grad_sum_prev = grad_sum

        teta = alpha / A
        print("#######################################teta: ", teta)
        flows_averaged = flows_averaged * (1 - teta) + flows_y * teta
        d_avaraged = d_avaraged * (1 - teta) + d_y * teta
        d_history.append(d_y)
        flows_history.append(flows_y)

        dgap_log.append(oracle.prime(flows_averaged, d_avaraged) + func_t)
        A_log.append(A)

        if stop_by_crit:
            gap_reached = dgap_log[-1] <= eps_abs
            # cons_log may be empty; indexing it used to raise IndexError the moment the
            # gap criterion was met, so the constraint check is skipped in that case.
            cons_reached = not cons_log or cons_log[-1] <= eps_cons_abs
            if gap_reached and cons_reached:
                break

    return t, flows_history, flows_averaged, d_history, d_avaraged, history_prime_values, history_dual_values, dgap_log,\
           cons_log, A_log, history_dual_values, history_prime_values


In [None]:
# Hard-coded block sizes of the stacked vector [t, la, mu].
# Nothing else in the visible notebook reads them: OracleSinkhornStacker takes its own
# sizes from oracle.edges_num and oracle.zones_num.
T_LEN, LA_LEN, MU_LEN = 76, 25, 25

In [None]:
# Load the Sioux Falls network and trip table from the local 'SiouxFalls' directory.
data_dir = Path('SiouxFalls')
net_df = get_network_df(data_dir / 'SiouxFalls_net.tntp')
corrs = get_corrs(data_dir / 'SiouxFalls_trips.tntp')

# Rescale by the grand total of the trip table, so corrs sums to one and the edge
# capacities are expressed relative to the same total.
people_count = corrs.sum()
corrs = corrs / people_count
net_df.capacity /= people_count

# Build the graph from the table rows, attaching two 'double' edge properties.
edge_properties = [('capacity', 'double'), ('fft', 'double')]
graph = graph_tool.Graph(net_df.values, eprops=edge_properties)

In [None]:
from transport_problem import HyperParams, DualOracle, OptimParams
# Row and column sums of the normalised correspondence matrix.
l = corrs.sum(axis=1)
w = corrs.sum(axis=0)

# Every zone serves both as a source and as a target.
zones_num = len(l)
sources = np.arange(zones_num)
targets = sources.copy()

# Hyper-parameters, the dual oracle and its Sinkhorn-backed callable wrapper.
params = HyperParams(gamma=10, mu_pow=0.25, rho=0.15)
oracle = DualOracle(graph, l, w, params)
oracle_stacker = OracleSinkhornStacker(
    oracle, graph, sources, targets, l, w, params
)

In [None]:
# Run the solver for a fixed budget of 600 iterations; stop_by_crit=False turns the
# early-stopping check off, so the tolerances only affect the backtracking test.
ustm_result = ustm_mincost_mcf(
    oracle_stacker,
    eps_abs=1e-3,
    eps_cons_abs=1e-6,
    max_iter=600,
    stop_by_crit=False,
)

# The returned tuple repeats the dual/primal histories as its last two items.
(t, flows_history, flows_averaged, d_history, d_avaraged,
 history_prime_values, history_dual_values, dgap_log,
 cons_log, A_log, history_dual_values, history_prime_values) = ustm_result

In [None]:
# Each log holds one value per outer solver iteration, so the x axis is the iteration index.
# Axis labels are added so every figure can be read without the surrounding code.

# Duality gap on a logarithmic scale.
fig, ax = plt.subplots()
ax.plot(dgap_log, label="dual gap")
ax.set_xlabel("iteration")
ax.set_ylabel("dual gap")
ax.set_yscale("log")
ax.legend()
plt.show()

# Primal objective history.
fig, ax = plt.subplots()
ax.plot(history_prime_values, label="history prime function")
ax.set_xlabel("iteration")
ax.set_ylabel("prime function value")
ax.legend()
plt.show()

# Dual objective history.
fig, ax = plt.subplots()
ax.plot(history_dual_values, label="history dual function")
ax.set_xlabel("iteration")
ax.set_ylabel("dual function value")
ax.legend()
plt.show()

In [None]:
# Compute the primal function at an all-zero vector of t values.
# The stacker's __call__ accepts only the t block and asserts len == T_LEN, so the
# vector must not be sized with parameters_vector_size (T_LEN + LA_LEN + MU_LEN);
# that size made this cell fail on the assertion.
# NOTE(review): the solver clips t from below by oracle.t_bar, so zeros may lie outside
# the range it explores - confirm this is the intended evaluation point.
zero_t_block = np.zeros(oracle_stacker.T_LEN)  # renamed from `vars`, which shadowed the builtin
_ = oracle_stacker(zero_t_block)
print("prime function value: ", oracle_stacker.get_prime_value())

In [None]:
# Compute the primal function at an all-zero vector of t values (same computation as the
# preceding cell).
# The stacker's __call__ accepts only the t block and asserts len == T_LEN, so the
# vector is sized with T_LEN rather than parameters_vector_size, which tripped the assertion.
# NOTE(review): the solver clips t from below by oracle.t_bar, so zeros may lie outside
# the range it explores - confirm this is the intended evaluation point.
zero_t_block = np.zeros(oracle_stacker.T_LEN)  # renamed from `vars`, which shadowed the builtin
_ = oracle_stacker(zero_t_block)
print("prime function value: ", oracle_stacker.get_prime_value())