In [5]:
import torch
import config
import numpy as np
import matplotlib.pyplot as plt
import os
from Alg.solving_algorithm import ModelGenerator
from CustomModels.my_models import Integrator
from CustomModels.my_models import weighted_amount
from aml.plotting import *
from Losses.Losses import *
from sklearn import decomposition
from tqdm import tqdm
from scipy.spatial import ConvexHull

def adjust_alpha(alpha_n):
    """Map a weight vector back onto the probability simplex by clip-and-rescale.

    Every entry is clipped to ``[0, 1]`` and the result is divided by its sum so
    the weights add up to one. This is a cheap heuristic, not the Euclidean
    projection onto the simplex.

    Args:
        alpha_n: 1-D sequence or array of mixture weights. It is not modified
            (the previous implementation clipped the caller's array in place).

    Returns:
        A new float ``np.ndarray`` of the same length whose entries lie in
        ``[0, 1]`` and sum to 1. If every entry clips to zero the sum is zero
        and the result is all-NaN (NumPy emits a RuntimeWarning).
    """
    clipped = np.clip(np.asarray(alpha_n, dtype=float), 0.0, 1.0)
    return clipped / np.sum(clipped)

def _mixture_loss(alpha, p_list, shared_integration_supports):
    """Return ``(mixture, loss)`` for the mixture of ``p_list`` weighted by ``alpha``."""
    p = weighted_amount(list_of_distributions=p_list, alpha_list=alpha)
    return p, get_L2_Distrib4D(p, shared_integration_supports)


def _finite_difference_gradient(alpha_n, L_mid, p_list, shared_integration_supports, epsilon_):
    """Estimate d(loss)/d(alpha_j) for every j with finite differences of step ``epsilon_``.

    Perturbed weight vectors are NOT renormalized; only coordinate j moves.
    """
    d = len(p_list)
    gradient_ = np.zeros(shape=(d,))
    for j in range(d):
        if alpha_n[j] < epsilon_:
            # Close to 0: a backward step would go negative, use a forward difference.
            alpha_1 = np.copy(alpha_n)
            alpha_1[j] = alpha_1[j] + epsilon_
            _, L_1 = _mixture_loss(alpha_1, p_list, shared_integration_supports)
            gradient_[j] = (L_1 - L_mid) / epsilon_
        elif alpha_n[j] > 1.0 - epsilon_:
            # Close to 1: a forward step would exceed 1, use a backward difference.
            alpha_2 = np.copy(alpha_n)
            alpha_2[j] = alpha_2[j] - epsilon_
            _, L_2 = _mixture_loss(alpha_2, p_list, shared_integration_supports)
            gradient_[j] = (L_mid - L_2) / epsilon_
        else:
            # Interior: central difference.
            alpha_1 = np.copy(alpha_n)
            alpha_1[j] = alpha_1[j] + epsilon_
            alpha_2 = np.copy(alpha_n)
            alpha_2[j] = alpha_2[j] - epsilon_
            _, L_1 = _mixture_loss(alpha_1, p_list, shared_integration_supports)
            _, L_2 = _mixture_loss(alpha_2, p_list, shared_integration_supports)
            gradient_[j] = (L_1 - L_2) / (2 * epsilon_)
    return gradient_


def _line_search(alpha_n, gradient_, L_mid, lambda_vec, p_list, shared_integration_supports):
    """Try every step size in ``lambda_vec``; return ``(lambda_best, loss_best)``.

    ``lambda_best`` is None when no step gives a loss strictly below ``L_mid``.
    """
    d = len(p_list)
    lambda_best = None
    loss_best = L_mid
    for lambda_ in lambda_vec:
        alpha_after = alpha_n - lambda_ * gradient_
        if np.sum(alpha_after < 0.0) == d:
            # Every weight negative: clipping would zero them all, skip this step.
            continue
        alpha_after = adjust_alpha(alpha_after)
        _, L_after = _mixture_loss(alpha_after, p_list, shared_integration_supports)
        if L_after < loss_best:
            loss_best = L_after
            lambda_best = lambda_
    return lambda_best, loss_best


def grad_descent_from_alpha_in_simplex(alpha_vec, p_list, shared_integration_supports,
                                       epsilon=10**(-3), max_iter=None):
    """Minimize the L2 loss of a mixture of ``p_list`` over its mixture weights.

    Starting from ``alpha_vec`` (renormalized to sum to one), each iteration
    estimates the gradient of the loss w.r.t. the weights by finite differences.
    It then tries every step size in a fixed log-spaced grid (1e4 down to 1e-4).
    After clipping and renormalizing with ``adjust_alpha``, it moves to the step
    with the lowest loss. It stops when no step strictly improves the loss, or
    after ``max_iter`` steps.

    Args:
        alpha_vec: initial weights, one per entry of ``p_list``; need not be normalized.
        p_list: distributions to mix (passed to ``weighted_amount``).
        shared_integration_supports: forwarded to ``get_L2_Distrib4D``.
        epsilon: finite-difference step; defaults to the previously hard-coded 1e-3.
        max_iter: optional cap on the number of descent steps. ``None`` (default)
            keeps the original behaviour of running until no step improves.

    Returns:
        ``(p_mid, L_mid)``: the mixture at the final weights and its loss.
    """
    lambda_vec = np.logspace(start=4, stop=-4, num=10)
    alpha_n = np.copy(alpha_vec)
    alpha_n = alpha_n / np.sum(alpha_n)
    n_iter = 0
    while True:
        p_mid, L_mid = _mixture_loss(alpha_n, p_list, shared_integration_supports)
        if max_iter is not None and n_iter >= max_iter:
            break
        gradient_ = _finite_difference_gradient(alpha_n, L_mid, p_list,
                                                shared_integration_supports, epsilon)
        lambda_best, _ = _line_search(alpha_n, gradient_, L_mid, lambda_vec,
                                      p_list, shared_integration_supports)
        # Testing lambda_best rather than "loss unchanged" also terminates when
        # L_mid is NaN; the old equality test then fell through to None * gradient.
        if lambda_best is None:
            break
        alpha_n = adjust_alpha(alpha_n - lambda_best * gradient_)
        n_iter += 1

    return p_mid, L_mid

def approx_equal(x1,x2,precision):
    """Return True when ``x1`` and ``x2`` differ by strictly less than ``precision``.

    The comparison is absolute, not relative: ``|x1 - x2| < precision``.
    A NaN operand always yields False.
    """
    difference = np.absolute(x1 - x2)
    return bool(difference < precision)

def search_min_in_simplex_with_center(L_0, mid_of_simplex, simplex, p_list, shared_integration_supports,
                                      n_iter=10, eps_vec=None, tol=10**(-4)):
    """Greedy pairwise line search for a lower-loss mixture inside a simplex.

    In the first round the center distribution (the one at ``mid_of_simplex``) is
    mixed with every other vertex as ``(1 - eps) * center + eps * vertex`` for each
    ``eps`` in ``eps_vec``. In later rounds the current best mixture ``p_n`` is
    mixed with every vertex, including the center. ``p_n`` is updated as soon as
    a better candidate appears.

    A candidate is accepted when its loss is lower than the current best OR
    within ``tol`` of it. Presumably this lets the search walk along plateaus,
    but it also means the tracked best loss can rise by up to ``tol`` per
    acceptance.

    Args:
        L_0: loss of the center distribution.
        mid_of_simplex: vertex id of the center; must occur in ``simplex``.
        simplex: array of vertex ids, aligned with ``p_list``.
        p_list: distributions, one per entry of ``simplex``.
        shared_integration_supports: forwarded to ``get_L2_Distrib4D``.
        n_iter: maximum number of rounds (previously hard-coded 10).
        eps_vec: mixing coefficients to try; defaults to ``np.linspace(0.1, 0.9, 5)``.
        tol: absolute tolerance used for ties and the early stop (previously 1e-4).

    Returns:
        ``(p, L)``. This is the best mixture found and its loss, or the center
        distribution and ``L_0`` if nothing ended up below ``L_0``.
    """
    if eps_vec is None:
        eps_vec = np.linspace(0.1, 0.9, 5)
    pos_of_mid = np.argwhere(simplex == mid_of_simplex).flatten()[0]
    center = p_list[pos_of_mid]
    p_n = None
    L_n = np.inf
    for ITER in range(n_iter):
        min_L_in_iter = np.inf
        for i in range(len(simplex)):
            if ITER == 0 and simplex[i] == mid_of_simplex:
                # Mixing the center with itself cannot change it.
                continue
            for eps_ in eps_vec:
                # Round 0 always starts from the center; later rounds start from
                # the running best, which may change inside this loop.
                base = center if ITER == 0 else p_n
                p_j = weighted_amount(list_of_distributions=[base, p_list[i]], alpha_list=[1.0 - eps_, eps_])
                L_j = get_L2_Distrib4D(p_j, shared_integration_supports)
                if L_j < min_L_in_iter:
                    min_L_in_iter = L_j
                if (L_j < L_n) or approx_equal(L_j, L_n, tol):
                    L_n = L_j
                    p_n = p_j
        if approx_equal(L_n, L_0, tol):
            break
        print('ITER {} LOSS {}'.format(ITER, L_n))
        # No candidate in this round beat the center's loss: give up.
        if min_L_in_iter >= L_0:
            break

    if L_n > L_0:
        return center, L_0
    return p_n, L_n

def grad_descent(L_0,p_list,mid_of_siplex,simplex, shared_integration_supports):
    """Minimize the L2 loss over the simplex spanned by ``p_list``, starting at a vertex.

    The descent starts from the weight vector that puts all mass on the vertex
    ``mid_of_siplex``.

    Args:
        L_0: loss at the starting vertex. Not used here; kept so the signature
            stays compatible with the pairwise-search variant that consumed it.
        p_list: distributions, one per entry of ``simplex`` (same order).
        mid_of_siplex: vertex id to start from; must occur in ``simplex``.
        simplex: array of vertex ids.
        shared_integration_supports: forwarded to the loss evaluation.

    Returns:
        ``(p, L)`` as produced by ``grad_descent_from_alpha_in_simplex``.
    """
    start_position = np.argwhere(simplex == mid_of_siplex).flatten()[0]
    start_weights = np.zeros(shape=(len(p_list),))
    start_weights[start_position] = 1.0
    return grad_descent_from_alpha_in_simplex(start_weights, p_list, shared_integration_supports)

In [6]:
# Load the cached per-vertex distributions and losses plus the triangulation.
# Then, from every vertex in order of increasing loss, run a gradient descent
# over the vertices sharing a simplex with it. The best mixture found so far is
# saved to config.Phi_descent_best_p_path.
N = 1000

mg = ModelGenerator(rules=config.rules,
                    cache_dir=config.Phi_cache_dir,
                    clear_cache=False)
vectors = torch.load(config.Phi_vector_representation)
all_p = [torch.load(os.path.join(mg.cache_dir, 'distrib4D_{}.txt'.format(i))) for i in range(N)]
all_v = torch.load(os.path.join(config.task_dir, 'L2_for_Phi.txt'))
simplices = torch.load(os.path.join(config.task_dir, 'triangulation_simplexes.txt'))
# Vertex ids ordered from lowest to highest loss.
sorted_vertesices = list(np.argsort(all_v))

# For each vertex (in loss order), the sorted unique ids of all vertices that
# share at least one simplex with it. This includes the vertex itself when it
# belongs to any simplex.
simpexes_for_omptimization = []
for best_vertex in sorted_vertesices:
    vertex_siplixes = [simplex for simplex in simplices if best_vertex in simplex]
    simpexes_for_omptimization.append(np.unique(vertex_siplixes))

print('number of simplexes for oprimization {}'.format(len(simpexes_for_omptimization)))

d = len(vectors[0])
support_vertexes = torch.load(os.path.join(config.task_dir, 'support_points.txt'))

# Reuse `mg` from above: a second ModelGenerator with identical arguments was
# previously constructed here for no benefit.
shared_integration_supports = Integrator(dir_=config.integrator_dir,
                                         shared_data=mg.shared_data,
                                         clear_cache=True).shared_integration_supports

print('start best {}'.format(np.min(all_v)))
global_L_min = np.inf
p_best = None
max_iter = N
for i in tqdm(range(max_iter)):
    simplex = simpexes_for_omptimization[i]
    if simplex.size == 0:
        # Vertex belongs to no simplex (e.g. dropped by the triangulation):
        # there is nothing to descend over, and grad_descent would fail to
        # locate it in an empty simplex.
        continue

    p_in_simplex = [all_p[el] for el in simplex]
    mid_of_siplex = sorted_vertesices[i]
    BestOfSimplex = all_v[mid_of_siplex]

    L_0 = BestOfSimplex
    # Keep the per-simplex result separate so p_best really is the global best
    # (it used to be overwritten by every simplex's result).
    p_fit, L_best = grad_descent(L_0, p_in_simplex, mid_of_siplex, simplex, shared_integration_supports)
    if L_best < global_L_min:
        global_L_min = L_best
        p_best = p_fit
        torch.save(p_best, config.Phi_descent_best_p_path)
    print('simplex L_0 {} simplex fitted L {}'.format(BestOfSimplex, L_best))

print(global_L_min)

number of simplexes for oprimization 1000
     num_of_rect_in_intersection 17090
start best 0.658098526443756


  0%|          | 1/1000 [00:09<2:31:02,  9.07s/it]

simplex L_0 0.658098526443756 simplex fitted L 0.658098526443756


  0%|          | 2/1000 [00:44<6:51:56, 24.77s/it]

simplex L_0 0.7268416947884779 simplex fitted L 0.7203976909998147


  0%|          | 3/1000 [00:53<4:50:12, 17.47s/it]

simplex L_0 0.7452107457998413 simplex fitted L 0.7452107457998413


  0%|          | 4/1000 [01:03<4:01:06, 14.52s/it]

simplex L_0 0.7466058521185177 simplex fitted L 0.7466058521185177


  0%|          | 5/1000 [01:18<4:00:31, 14.50s/it]

simplex L_0 0.7467783333519011 simplex fitted L 0.744429981461491


  1%|          | 6/1000 [01:37<4:30:20, 16.32s/it]

simplex L_0 0.7588048771206818 simplex fitted L 0.7556086174879798


  1%|          | 7/1000 [01:48<4:00:17, 14.52s/it]

simplex L_0 0.766070964522576 simplex fitted L 0.766070964522576


  1%|          | 8/1000 [02:11<4:43:47, 17.16s/it]

simplex L_0 0.7660936431734452 simplex fitted L 0.7460598616393094


  1%|          | 9/1000 [02:19<3:58:12, 14.42s/it]

simplex L_0 0.7668395082440372 simplex fitted L 0.7668395082440372


  1%|          | 10/1000 [02:42<4:38:26, 16.87s/it]

simplex L_0 0.7715619939997257 simplex fitted L 0.7703827504457901


  1%|          | 11/1000 [02:50<3:54:32, 14.23s/it]

simplex L_0 0.7733347679978654 simplex fitted L 0.7733347679978654


  1%|          | 11/1000 [03:07<4:41:08, 17.06s/it]


KeyboardInterrupt: 