In [23]:
import torch
from src import visualize_latent_space as vls
from functools import partial

In [46]:


#########################################################################################################################################################################
#########################################################################################################################################################################
#########################################################################################################################################################################
#Riemannian Gradient Descent on O(p)
#########################################################################################################################################################################
#########################################################################################################################################################################
#########################################################################################################################################################################

class Op_Riemannian_GD:

    """

    Find an orthogonal transformation that puts the data into the standard simplex, using
    Riemannian gradient descent on O(p), the orthogonal group of dimension p.

    The objective is a softplus (smooth ReLU) penalty on every negative coordinate of
    data @ W and on every row of data @ W whose coordinates sum to more than 1.

    Args:
        data (torch.Tensor, n by p): under our settings, the raw ASE that we cannot use
            directly to estimate the model parameters
        initialization (torch.Tensor, p by p, optional): starting orthogonal matrix; the
            identity is used when omitted
        softplus_parameter (float): the softplus ``beta``; a bigger value makes softplus look
            more like ReLU, i.e. a less smooth penalty
        tolerance (float): stop once the Frobenius norm of the Riemannian gradient is at
            most this value

    Attributes:
        mode: name of the penalty; only "softplus" is implemented, other values are stored
            by ``update`` but not used
        softplus_loss: softplus penalty of the untransformed data (transformation = identity)
        align_mat: the orthogonal transformation found by gradient descent
        aligned_loss: softplus penalty of data @ align_mat

    """

    def __init__(self, data, initialization = None, softplus_parameter = 5, tolerance = 0.01):

        self.data = data
        self.tolerance = tolerance
        self.initialization = initialization
        self.smoothing = softplus_parameter
        self.mode = "softplus"
        self._fit()

    def update(self, mode = None, smoothing = None, tolerance = None):

        """
        Change any of the given settings and re-run the optimisation.

        ``mode`` is only recorded: the softplus penalty is the only one implemented.
        """

        if mode is not None:
            self.mode = mode
        if smoothing is not None:
            self.smoothing = smoothing
        if tolerance is not None:
            self.tolerance = tolerance
        self._fit()

    def _fit(self):

        """ Recompute softplus_loss, align_mat and aligned_loss from the current settings. """

        X = self.data
        p = X.shape[1]
        identity = torch.eye(p, dtype = X.dtype, device = X.device)
        with torch.no_grad():
            self.softplus_loss = self.simplex_loss_softplus(X, identity, self.smoothing)
        self.align_mat = self.GD_Armijo()
        with torch.no_grad():
            self.aligned_loss = self.simplex_loss_softplus(X, self.align_mat, self.smoothing)

    @staticmethod
    def cyclic_permutation_matrices(p):

        """
        Return the p cyclic permutations of the p by p identity stacked vertically,
        as a (p*p) by p tensor; block i is the identity with rows rolled by i.
        """

        identity_matrix = torch.eye(p)
        result = torch.zeros((p * p, p), dtype = identity_matrix.dtype)
        for i in range(p):
            result[i * p:(i + 1) * p, :] = torch.roll(identity_matrix, shifts = i, dims = 0)
        return result

    @staticmethod
    def simplex_loss_softplus(data_set, transformation, smoothing):

        """

        Softplus penalty measuring how far Z = data_set @ transformation is from the standard
        simplex. The softplus function with parameter beta is defined to be:

        softplus(x, beta) = log(1 + exp(x * beta))/beta

        and the loss is

        p * [ sum_ij softplus(-Z_ij) + sum_i softplus(sum_j Z_ij - 1) ].

        The factor p reproduces the earlier formulation that stacked Z @ P_k over all p cyclic
        column permutations P_k (see cyclic_permutation_matrices). Permuting columns changes
        neither the multiset of entries nor the row sums, so every copy contributes the same
        amount. Computing it directly avoids building an (n p) by p^2 Kronecker product.

        """

        Z = data_set.matmul(transformation)
        p = Z.shape[1]

        softplus = torch.nn.Softplus(beta = smoothing)

        negativity_loss = torch.sum(softplus(-Z))
        simp_loss = torch.sum(softplus(torch.sum(Z, dim = 1) - 1))

        return p * (negativity_loss + simp_loss)

    def deriv_W_softplus(self, W):

        """
        Euclidean gradient of simplex_loss_softplus(self.data, W, self.smoothing) with
        respect to W. The caller's W is not modified (no requires_grad / .grad side effects).
        """

        # Differentiate a fresh leaf so gradients never accumulate across calls.
        W_var = W.detach().clone().requires_grad_(True)
        with torch.enable_grad():
            loss = Op_Riemannian_GD.simplex_loss_softplus(self.data, W_var, self.smoothing)
            (grad,) = torch.autograd.grad(loss, W_var)
        return grad

    def proj_skew_sym_at_W(self, M, W):

        """
        Projection of M onto the tangent space of O(p) at W: W (W^T M - M^T W) / 2.
        """

        return W @ (W.T @ M - M.T @ W) / 2

    def matrix_exp_at_W(self, xi, W):

        """
        The retraction: map the tangent step xi back onto O(p) along the geodesic
        starting at W, i.e. W expm(W^T xi).
        """

        return W @ torch.matrix_exp(W.T @ xi)

    @staticmethod
    def _rescale(W):

        """ Divide W by the norm of its first row (guards against numerical drift off O(p)). """

        return W * torch.sqrt(1 / (W @ W.T)[0, 0])

    def _riemannian_grad(self, W):

        """ Riemannian gradient of the softplus loss at W (Euclidean gradient projected onto T_W O(p)). """

        return self.proj_skew_sym_at_W(self.deriv_W_softplus(W), W)

    def GD_one_step(self, prev_position, step):

        """
        Given the current orthogonal transformation and a step size, take one Riemannian
        gradient descent step and return the new orthogonal transformation.
        """

        W_old = self._rescale(prev_position)
        tangent_deriv = self._riemannian_grad(W_old)
        return self.matrix_exp_at_W(-step * tangent_deriv, W_old)

    def GD_Armijo(self):

        """
        Riemannian gradient descent with backtracking (Armijo) line search.

        A trial step size t (starting at 0.001, shrunk by 0.1 up to 10 times) is accepted when
        cost(R_W(-t g)) <= cost(W) - sigma * t * ||g||^2, where g is the Riemannian gradient.
        If no trial passes, the smallest step is taken anyway. Iteration stops when ||g|| <=
        tolerance, when the predicted decrease sigma * t * ||g||^2 falls to 10e-8 or below,
        or after 200 * p iterations.

        Returns:
            torch.Tensor: the p by p orthogonal transformation.
        """

        X = self.data
        p = X.shape[1]

        if self.initialization is not None:
            W = self.initialization.detach().to(X)
        else:
            W = torch.eye(p, dtype = X.dtype, device = X.device)

        def cost(M):
            with torch.no_grad():
                return self.simplex_loss_softplus(X, M, self.smoothing)

        shrink = 0.1          # backtracking factor
        sigma = 0.1           # Armijo sufficient-decrease constant
        t_init = 0.001
        max_backtracks = 10
        min_jump = 10e-8
        max_iter = 200 * p

        W = self._rescale(W)
        rgrad = self._riemannian_grad(W)

        for _ in range(max_iter):
            rgrad_sq = torch.sum(rgrad * rgrad)
            current_cost = cost(W)

            t = t_init
            candidate = self.matrix_exp_at_W(-t * rgrad, W)
            for _ in range(max_backtracks):
                if cost(candidate) <= current_cost - sigma * t * rgrad_sq:
                    break
                t *= shrink
                candidate = self.matrix_exp_at_W(-t * rgrad, W)

            W = self._rescale(candidate)
            rgrad = self._riemannian_grad(W)
            rgrad_sq = torch.sum(rgrad * rgrad).item()

            if rgrad_sq ** 0.5 <= self.tolerance or sigma * t * rgrad_sq <= min_jump:
                break

        return W
