In [80]:
import torch
from src import visualize_latent_space as vls
from functools import partial

In [81]:


#########################################################################################################################################################################
#########################################################################################################################################################################
#########################################################################################################################################################################
#Riemannian Gradient Descent on O(p)
#########################################################################################################################################################################
#########################################################################################################################################################################
#########################################################################################################################################################################

class Op_Riemannian_GD:

    """
    Find an orthogonal transformation that maps the data into the standard simplex,
    using Riemannian gradient descent on O(p), the group of p x p orthogonal matrices.

    The objective is the softplus-smoothed penalty ``simplex_loss_softplus`` evaluated
    at ``data @ W``. Each step projects the Euclidean gradient onto the tangent space of
    O(p) at W, moves along it with the matrix exponential (which keeps W orthogonal),
    and chooses the step size with an Armijo backtracking line search.

    Args:
        data (torch.Tensor of shape n x p): under our settings, the raw ASE that cannot
            be used directly to estimate the model parameters.
        initialization (torch.Tensor of shape p x p, optional): starting point on O(p);
            the p x p identity is used when None.
        softplus_parameter (float): the softplus ``beta``. Bigger values make softplus
            look more like ReLU, i.e. a less smooth penalty.
        tolerance (float): stop once the Frobenius norm of the Riemannian gradient falls
            to or below this value.

    Attributes:
        align_mat: the fitted orthogonal transformation (p x p).
        softplus_loss: the softplus penalty evaluated at ``align_mat``.
        mode: stored by ``update``; only the softplus penalty is implemented.
    """

    def __init__(self, data, initialization = None, softplus_parameter = 5, tolerance = 0.01):

        self.data = data
        self.tolerance = tolerance
        self.initialization = initialization
        self.smoothing = softplus_parameter
        # Only the softplus penalty exists; kept so update(mode=...) has an attribute to overwrite.
        self.mode = "softplus"
        self.align_mat = self.GD_Armijo()
        # Penalty at the fitted transformation (the loss needs a transformation argument).
        self.softplus_loss = self._loss_at(self.align_mat)

    def update(self, mode = None, smoothing = None, tolerance = None):
        """
        Change settings and re-run the optimisation, refreshing ``align_mat`` and
        ``softplus_loss``. ``mode`` is stored but does not affect the fit.
        """
        if mode is not None:
            self.mode = mode
        if smoothing is not None:
            self.smoothing = smoothing
        if tolerance is not None:
            self.tolerance = tolerance
        self.align_mat = self.GD_Armijo()
        self.softplus_loss = self._loss_at(self.align_mat)

    @staticmethod
    def cyclic_permutation_matrices(p, dtype = None, device = None):
        """
        Stack the p cyclic row-rotations of the p x p identity into a (p*p) x p matrix.

        Block i (rows i*p to (i+1)*p) is ``torch.roll(I_p, shifts=i, dims=0)``.
        ``dtype``/``device`` are forwarded to ``torch.eye`` so the result can match the data.
        """
        identity_matrix = torch.eye(p, dtype=dtype, device=device)
        result = torch.zeros((p * p, p), dtype=identity_matrix.dtype, device=identity_matrix.device)
        for i in range(p):
            result[i * p:(i + 1) * p, :] = torch.roll(identity_matrix, shifts=i, dims=0)
        return result

    @staticmethod
    def simplex_loss_softplus(data_set, transformation, smoothing):
        """
        Softplus-smoothed penalty for ``data_set @ transformation`` leaving the simplex.

        With A = data_set @ transformation, Z stacks the p cyclic column permutations of A
        (Z = kron(I_p, A) @ P, P from ``cyclic_permutation_matrices``) and the loss is

            sum(softplus(-Z)) + sum(softplus(rowsum(Z) - 1)),

        where softplus(x, beta) = log(1 + exp(beta * x)) / beta and beta = ``smoothing``.
        The first term penalises negative coordinates, the second penalises rows whose
        coordinates sum to more than 1. Every permutation has the same entries and row
        sums, so this equals p times the same penalty computed on A directly.
        """
        X = data_set
        n, p = X.shape

        # Build helper matrices on the data's device/dtype to avoid CPU/GPU mismatches.
        Ip = torch.eye(p, dtype=X.dtype, device=X.device)
        permu = Op_Riemannian_GD.cyclic_permutation_matrices(p, dtype=X.dtype, device=X.device)

        Z = torch.kron(Ip, X.matmul(transformation)).matmul(permu)

        softplus = torch.nn.Softplus(beta = smoothing)

        negativity_loss = torch.sum(softplus(-Z))

        row_sum_minus_1 = torch.sum(Z, dim = 1) - 1
        simp_loss = torch.sum(softplus(row_sum_minus_1))

        return negativity_loss + simp_loss

    def _loss_at(self, W):
        """Softplus penalty of ``self.data @ W`` without building an autograd graph."""
        with torch.no_grad():
            return Op_Riemannian_GD.simplex_loss_softplus(self.data, W, self.smoothing)

    def deriv_W_softplus(self, W):
        """
        Euclidean gradient of the softplus penalty with respect to W, via autograd.

        W is detached and cloned, so the caller's tensor is never modified, gradients
        cannot accumulate across calls, and non-leaf inputs work.
        """
        W_leaf = W.detach().clone().requires_grad_(True)
        with torch.enable_grad():
            loss = Op_Riemannian_GD.simplex_loss_softplus(self.data.detach(), W_leaf, self.smoothing)
            (grad,) = torch.autograd.grad(loss, W_leaf)
        return grad

    def proj_skew_sym_at_W(self, M, W):
        """
        Projection of M onto the tangent space of O(p) at W: W (W^T M - M^T W) / 2.
        """
        projection = W @ (W.T @ M - M.T @ W)/2
        return projection

    def matrix_exp_at_W(self, xi, W):
        """
        Map a tangent vector xi at W back onto O(p): W @ expm(W^T xi).
        """
        Exp_w_xi = W @ torch.matrix_exp(W.T @ xi)
        return Exp_w_xi

    @staticmethod
    def _rescale(W):
        """Rescale W so its first row has unit Euclidean norm (counteracts numerical drift)."""
        return W * torch.sqrt(1/(W @ W.T)[0, 0])

    def _riemannian_grad(self, W):
        """Riemannian gradient at W: the Euclidean gradient projected onto the tangent space."""
        return self.proj_skew_sym_at_W(self.deriv_W_softplus(W), W)

    def GD_one_step(self, prev_position, step):
        """
        Take one Riemannian gradient step of size ``step`` from ``prev_position``,
        returning another orthogonal matrix.
        """
        W_old = self._rescale(prev_position.detach())
        tangent_deriv = self._riemannian_grad(W_old)
        return self.matrix_exp_at_W(-step * tangent_deriv, W_old)

    def GD_Armijo(self, initial_step = 0.001, shrink = 0.1, sigma = 0.1,
                  max_backtracks = 10, min_decrease = 10e-8, max_iter = None):
        """
        Riemannian gradient descent with Armijo backtracking.

        At each iteration, with g the Riemannian gradient at the current point W, the step
        t (starting at ``initial_step``) is multiplied by ``shrink`` until

            f(Exp_W(-t g)) <= f(W) - sigma * t * ||g||_F^2,

        or until ``max_backtracks`` shrinks have been made (the last candidate is then
        accepted). Iteration stops when ||g||_F <= self.tolerance at the new point, when
        the predicted decrease sigma * t * ||g||_F^2 is <= ``min_decrease`` (note
        10e-8 == 1e-7), or after ``max_iter`` iterations (default 200 * p).

        Returns:
            torch.Tensor: the final p x p orthogonal matrix (no autograd history).
        """
        X = self.data
        n, p = X.shape

        if self.initialization is not None:
            W = self.initialization.detach().clone().to(dtype=X.dtype, device=X.device)
        else:
            W = torch.eye(p, dtype=X.dtype, device=X.device)

        if max_iter is None:
            max_iter = 200 * p

        for _ in range(max_iter):
            W_base = self._rescale(W)
            # Compute the descent direction once per iteration; reused by every backtrack.
            direction = self._riemannian_grad(W_base)
            grad_sq = torch.sum(direction ** 2)
            f_old = self._loss_at(W_base)

            t = initial_step
            W_new = self.matrix_exp_at_W(-t * direction, W_base)
            for _ in range(max_backtracks):
                if self._loss_at(W_new) <= f_old - sigma * t * grad_sq:
                    break
                # Standard geometric backtracking.
                t *= shrink
                W_new = self.matrix_exp_at_W(-t * direction, W_base)

            W = W_new
            grad_norm = torch.linalg.norm(self._riemannian_grad(self._rescale(W)))
            jump = sigma * t * grad_norm ** 2
            if grad_norm <= self.tolerance or jump <= min_decrease:
                break

        return W


In [8]:
import torch
import pandas as pd
from tqdm import tqdm as tm
from src import Simulation as sim
from src import Dir_Reg as DR
from src import ABC_Reg_copy as ABC_Reg 
from src import Align

# Pick the compute device: Apple MPS first, then CUDA, otherwise CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print(device)

cuda


In [2]:
# Goal: investigate how our estimates behave when the model is correctly specified but the embedding dimension is misspecified.
# 1. Generate synthetic data with 8 dimensions (note: the simulation cells below use p = 6).
# 2. Fit the full oracle with the correct dimension to obtain a baseline.
# 3. Embed the graph in a lower dimension (the sweep below also tries higher ones), run the regression, and compare the result to the full oracle.

In [79]:
torch.manual_seed(5)

# Simulation sizes: p columns per alpha_0 row, K groups, T time points, n nodes.
p, K, T, n = 6, 3, 2, 6000
# K x p matrix of 2s passed to the simulator as alpha_0.
# (alternative setting once tried: alpha_0 = torch.eye(p)*9 + 1)
alpha_0 = 2 * torch.ones(K, p)

model = sim.ABC(
    time=T,
    nodes=n,
    beta=[1, 1, -4, 5],
    alpha_0=alpha_0,
)

In [80]:
# Build the estimator from the simulated latent positions and observed adjacency matrices.
temp = ABC_Reg.est(
    two_lat_pos=model.synth_data['lat_pos'],
    two_adj_mat=model.synth_data['obs_adj'],
    groups=K,
)

In [87]:
# Fit in 'OL' mode (presumably the full-oracle baseline) and show the projected beta estimate.
temp.specify_mode('OL', fit=True)
DR.fit.proj_beta(
    temp.fitted.est_result["estimate"],
    DR.fit.gen_constraint(p, True),
)

tensor([ 0.9733,  0.8217, -4.2562,  4.3985])

In [None]:
p, K, T = 6, 3, 2
alpha_0 = torch.ones(K, p)*2

# nodes=30 is only a placeholder; the node count is reset with update_settings below.
model = sim.ABC(time = T,
                nodes = 30,
                beta = [1, 1, -4, 5],
                alpha_0 = alpha_0)

seed_list = list(range(60))
n_list = list(range(1500, 6001, 1500))
p0_list = list(range(2, 8))

# One record per (seed, n, p0) run; converted to a DataFrame at the end.
results = []
for seed_idx, seed in enumerate(seed_list):
    torch.manual_seed(seed)
    # Progress = fraction of seeds started, derived from seed_list (not a hard-coded 60).
    print(round(seed_idx / len(seed_list), 2))
    for n in n_list:
        for p0 in p0_list:
            # NOTE(review): update_settings runs once per p0; if it regenerates synth_data,
            # each p0 is fitted on a different draw -- confirm this is intended.
            model.update_settings(nodes = n)
            estimate = ABC_Reg.est(two_lat_pos = model.synth_data['lat_pos'],
                                   two_adj_mat = model.synth_data['obs_adj'],
                                   groups = K,
                                   )
            # Fit with embedding dimension p0 (possibly misspecified).
            estimate.specify_mode('NO', fit = True, embed_dim = p0)

            beta_est = DR.fit.proj_beta(estimate.fitted.est_result["estimate"],
                                        DR.fit.gen_constraint(p0+1, True)).tolist()
            info_lost = estimate.fitted.est_result["info_lost"]

            result = {
                'seed': seed,
                'n': n,
                'p0': p0,
                'beta1': beta_est[0],
                'beta2': beta_est[1],
                'beta3': beta_est[2],
                'beta4': beta_est[3],
                'info_lost': info_lost,
            }
            results.append(result)

df = pd.DataFrame(results)

0.0
0.02
0.03
0.05
0.07
0.08
0.1
0.12
0.13
0.15
0.17
0.18
0.2
0.22
0.23
0.25
0.27
0.28
0.3
0.32
0.33
0.35
0.37
0.38
0.4
0.42
0.43
0.45
0.47
0.48
0.5
0.52
0.53
0.55
0.57
0.58
0.6
0.62
0.63
0.65


In [115]:
# Refit with embedding dimension p_0 = 3; print "info_lost" and show the projected beta
# estimate (constraint generated for p_0 + 1).
p_0 = 3
temp.specify_mode('NO', fit=True, embed_dim=p_0)
print(temp.fitted.est_result["info_lost"])
DR.fit.proj_beta(
    temp.fitted.est_result["estimate"],
    DR.fit.gen_constraint(p_0 + 1, True),
)

0.032333333333333325


tensor([ 0.3957,  0.0590, -1.8203, 11.5494])

In [83]:
# Refit with embedding dimension p_0 = 4; print "info_lost" and show the projected beta
# estimate (constraint generated for p_0 + 1).
p_0 = 4
temp.specify_mode('NO', fit=True, embed_dim=p_0)
print(temp.fitted.est_result["info_lost"])
DR.fit.proj_beta(
    temp.fitted.est_result["estimate"],
    DR.fit.gen_constraint(p_0 + 1, True),
)

0.09533333333333338


tensor([ 0.3947,  0.0186, -1.5651,  0.8212])

In [91]:
# Refit with embedding dimension p_0 = 5; print "info_lost" and show the projected beta
# estimate (constraint generated for p_0 + 1).
p_0 = 5
temp.specify_mode('NO', fit=True, embed_dim=p_0)
print(temp.fitted.est_result["info_lost"])
DR.fit.proj_beta(
    temp.fitted.est_result["estimate"],
    DR.fit.gen_constraint(p_0 + 1, True),
)


0.2743333333333333


tensor([ 0.4642,  0.4226, -2.3290,  2.4509])

In [85]:
# Refit with embedding dimension p_0 = 6; print "info_lost" and show the projected beta
# estimate (constraint generated for p_0 + 1).
p_0 = 6
temp.specify_mode('NO', fit=True, embed_dim=p_0)
print(temp.fitted.est_result["info_lost"])
DR.fit.proj_beta(
    temp.fitted.est_result["estimate"],
    DR.fit.gen_constraint(p_0 + 1, True),
)

0.43066666666666664


tensor([ 0.6828,  0.8133, -3.9691, -3.2112])

In [86]:
# Refit with embedding dimension p_0 = 7; print "info_lost" and show the projected beta
# estimate (constraint generated for p_0 + 1).
p_0 = 7
temp.specify_mode('NO', fit=True, embed_dim=p_0)
print(temp.fitted.est_result["info_lost"])
DR.fit.proj_beta(
    temp.fitted.est_result["estimate"],
    DR.fit.gen_constraint(p_0 + 1, True),
)

0.7296666666666667


tensor([ 0.9345,  1.1643, -3.5133,  5.0706])