# In [1]:
import os
import numpy as np
import torch
#from torch import nn, optim
#from torchvision import transforms, models
#import torch.nn.functional as F
import torch.distributions as dist

#def ensure_positive_definite(covariance):
#    eigenvalues, eigenvectors = torch.linalg.eigh(covariance)
#    eigenvalues = torch.clamp(eigenvalues, min=1e-6)  # Set small eigenvalues to a minimum positive value
#    covariance = torch.mm(torch.mm(eigenvectors, torch.diag(eigenvalues)), eigenvectors.T)
#    return covariance
def ensure_positive_definite(matrix, eps=1e-4):
    '''
    Rebuild a symmetric matrix so that all of its eigenvalues are at least ``eps``.

    Input:
        matrix: A real symmetric tensor of shape (d, d), or a batch of such
            matrices of shape (..., d, d). ``torch.linalg.eigh`` is used for
            the decomposition, so the input is treated as symmetric.
        eps: Lower bound applied to every eigenvalue before reconstruction.

    Output:
        A tensor with the same shape as ``matrix`` equal to
        V @ diag(max(w, eps)) @ V^T, symmetrized to remove round-off asymmetry.
    '''
    # Eigen-decomposition of the (assumed symmetric) input: matrix = V diag(w) V^T
    eigenvalues, eigenvectors = torch.linalg.eigh(matrix)

    # Floor the spectrum so no eigenvalue is below eps
    floored = torch.clamp(eigenvalues, min=eps)

    # Scaling column j of V by floored[j] is V @ diag(floored); using broadcasting
    # instead of torch.diag keeps this valid for batched input as well.
    rebuilt = (eigenvectors * floored.unsqueeze(-2)) @ eigenvectors.transpose(-2, -1)

    # Average with the transpose so the result is exactly symmetric; floating
    # point error in the product above can otherwise leave a tiny asymmetry.
    return 0.5 * (rebuilt + rebuilt.transpose(-2, -1))

# Example usage:
#covariance = torch.rand(32, 32)  # Replace with your covariance matrix
#positive_definite_covariance = ensure_positive_definite(covariance)

def _skew_normal_kernel(residuals, std, skewness):
    # Skew-normal density 2 * Phi(skewness * z) * phi(z) / std with z = residuals / std.
    standard_normal = dist.Normal(0, 1)
    z = residuals / std
    return 2 * standard_normal.cdf(z * skewness) * (1 / (std * (2 * torch.pi) ** 0.5)) * torch.exp(-0.5 * z.pow(2))


def multi_vars_SkewNormal_density_estimate(values,check=False):
    '''
    Estimate per-feature probability densities with a skew-normal density function.

    Every column of ``values`` is handled independently. The column mean,
    standard deviation (``torch.std`` defaults) and standardized third central
    moment are computed, and the third moment is used directly as the shape
    parameter of the skew-normal density evaluated at each sample.

    Input:
        values: A 2D tensor containing n samples (rows) and d features (columns).
        check: When True, the same density is also evaluated on a grid of 100
            evenly spaced points between each feature's minimum and maximum,
            which can be used to visualize the estimate as a smooth curve.

    Output:
        pdf: kernel values divided by their per-column sum, shape (n, d).
        kernel_values: unnormalized skew-normal densities, shape (n, d).
        [bins, k_check, pdf_check]: returned as a third element only when
            ``check`` is True; grid points, densities on the grid and their
            per-column normalization, each of shape (100, d).
    '''
    mean = torch.mean(values,0)
    std = torch.std(values,0)
    residuals = values - mean
    skewness = torch.mean(residuals ** 3,0) / torch.pow(std, 3)

    kernel_values = _skew_normal_kernel(residuals, std, skewness)
    pdf = kernel_values / (torch.sum(kernel_values,0))
    if not check:
        return pdf,kernel_values

    # Build the evaluation grid with torch so it stays on the device and dtype
    # of `values`; it is detached because it is only a set of plotting points.
    lowest = values.min(0).values.detach()
    highest = values.max(0).values.detach()
    steps = torch.linspace(0, 1, 100, dtype=values.dtype, device=values.device).unsqueeze(1)
    bins = lowest + (highest - lowest) * steps

    # The density on the grid must use the grid residuals in BOTH the CDF and
    # the Gaussian term (not the sample residuals).
    r_check = bins - mean
    k_check = _skew_normal_kernel(r_check, std, skewness)
    pdf_check = k_check / (torch.sum(k_check,0))
    return pdf,kernel_values,[bins,k_check,pdf_check]

def estimate_total_joint_density_skew_normal_training(values):
    '''
    Evaluate a multivariate skew-normal style joint density at the samples,
    with all parameters estimated from the samples themselves.

    The density is 2 * N(x; mean, covariance) * Phi(skewness . z), where z are
    the residuals divided by the per-feature standard deviation (square root
    of the covariance diagonal) and skewness is the per-feature standardized
    third central moment.

    Input:
        values: A 2D tensor containing n samples (rows) and d features (columns).

    Output:
        nor_joint: joint densities divided by their total sum, shape (n, 1).
        jointskewtotal: unnormalized joint densities, shape (n, 1).
    '''
    mean = torch.mean(values, dim=0)
    residuals = values - mean
    covariance = torch.cov(residuals.T)
    std = torch.diagonal(covariance).sqrt()
    skewness = torch.mean(residuals ** 3, dim=0) / torch.pow(std, 3)

    # CDF term: standard normal CDF of the skewness-weighted standardized residuals
    standard_normal = dist.Normal(0, 1)
    cdf = standard_normal.cdf((residuals / std) @ skewness.unsqueeze(1))

    # Gaussian term. MultivariateNormal rejects a covariance that is not
    # positive definite (ValueError from argument validation, RuntimeError
    # from the Cholesky factorization when validation is disabled); in that
    # case retry with a repaired covariance. Other exceptions propagate.
    try:
        multinormal = dist.MultivariateNormal(mean, covariance)
    except (ValueError, RuntimeError):
        covariance = ensure_positive_definite(covariance)
        multinormal = dist.MultivariateNormal(mean, covariance)
    gaussian_density = multinormal.log_prob(values).exp()

    jointskewtotal = 2 * gaussian_density.reshape(cdf.shape) * cdf
    nor_joint = jointskewtotal / (jointskewtotal.sum())
    return nor_joint,jointskewtotal

def estimate_total_joint_density_skew_normal(values,mean,covariance,skewness):
    '''
    Evaluate a multivariate skew-normal style joint density at the samples
    using externally supplied parameters.

    The density is 2 * N(x; mean, covariance) * Phi(skewness . z), where z are
    the residuals ``values - mean`` divided by the square root of the
    covariance diagonal.

    Input:
        values: A 2D tensor containing n samples (rows) and d features (columns).
        mean: Per-feature mean, shape (d,).
        covariance: Covariance matrix, shape (d, d).
        skewness: Per-feature skewness weights, shape (d,).

    Output:
        nor_joint: joint densities divided by their total sum, shape (n, 1).
        jointskewtotal: unnormalized joint densities, shape (n, 1).
    '''
    residuals = values - mean
    std = torch.diagonal(covariance).sqrt()

    # CDF term: standard normal CDF of the skewness-weighted standardized residuals
    standard_normal = dist.Normal(0, 1)
    cdf = standard_normal.cdf((residuals / std) @ skewness.unsqueeze(1))

    # Gaussian term. MultivariateNormal rejects a covariance that is not
    # positive definite (ValueError from argument validation, RuntimeError
    # from the Cholesky factorization when validation is disabled); in that
    # case retry with a repaired covariance. Other exceptions propagate.
    try:
        multinormal = dist.MultivariateNormal(mean, covariance)
    except (ValueError, RuntimeError):
        covariance = ensure_positive_definite(covariance)
        multinormal = dist.MultivariateNormal(mean, covariance)
    gaussian_density = multinormal.log_prob(values).exp()

    jointskewtotal = 2 * gaussian_density.reshape(cdf.shape) * cdf
    nor_joint = jointskewtotal / (jointskewtotal.sum())
    return nor_joint,jointskewtotal



def estimate_total_joint_density(data):
    '''
    Evaluate a multivariate normal joint density, fitted to the data, at the
    data points themselves.

    Input:
        data: A 2D tensor containing n samples (rows) and d features (columns).

    Output:
        nor_joint: joint densities divided by their total sum, shape (n, 1).
        joint_density: unnormalized joint densities, shape (n, 1).
    '''
    # Fit the Gaussian: sample mean and covariance of the data
    mean = torch.mean(data, dim=0)
    covariance = torch.cov(data.T)

    # MultivariateNormal rejects a covariance that is not positive definite
    # (ValueError from argument validation, RuntimeError from the Cholesky
    # factorization when validation is disabled); in that case retry with a
    # repaired covariance. Other exceptions propagate.
    try:
        multinormal = dist.MultivariateNormal(mean, covariance)
    except (ValueError, RuntimeError):
        covariance = ensure_positive_definite(covariance)
        multinormal = dist.MultivariateNormal(mean, covariance)

    # Density at each sample, as a column vector
    joint_density = torch.exp(multinormal.log_prob(data).reshape(-1, 1))
    nor_joint = joint_density / (joint_density.sum())
    return nor_joint, joint_density

def total_correlation_estimation(marginal,joint,weight_sum=True):
    '''
    Compare the log joint density with the sum of log marginal densities.

    Input:
        marginal: marginal densities, one column per feature; the logs are
            summed over dim 1.
        joint: joint densities, broadcastable against the (n, 1) marginal sum.
        weight_sum: when truthy, each log-ratio is multiplied by ``joint``.

    Output:
        log(joint) - sum_k log(marginal_k), optionally weighted by ``joint``.
        A constant 1e-6 is added inside each log to avoid log(0).
    '''
    eps = 1e-6
    summed_marginal_log = torch.log(marginal + eps).sum(dim=1, keepdim=True)
    log_ratio = torch.log(joint + eps) - summed_marginal_log
    return joint * log_ratio if weight_sum else log_ratio

def pair_wise_MI_upper(z,full_cov=False):
    '''
    Build a (d, d) matrix of pairwise scores between the features of ``z``.

    For each feature pair (i, j) the per-feature skew-normal kernels and the
    two-feature joint skew-normal density are passed to
    ``total_correlation_estimation`` and the result is summed over samples.

    Input:
        z: tensor of shape (n, d).
        full_cov: when True every (i, j) entry is computed, diagonal included;
            otherwise only entries with j > i are filled and the rest stay zero.

    Output:
        mi_upper: the (d, d) score matrix, on the same device as ``z``.
        The sum of the absolute values of ``mi_upper``.
    '''
    _, n_features = z.size()
    mi_upper = torch.zeros(n_features, n_features).to(z.device)

    # Marginal kernels are computed once for all features and sliced per pair.
    marginal_kernel = multi_vars_SkewNormal_density_estimate(z)[1]

    for row in range(n_features):
        # Full mode starts every row at column 0; otherwise only columns right of the diagonal.
        first_col = 0 if full_cov else row + 1
        for col in range(first_col, n_features):
            pair = [row, col]
            pair_marginals = marginal_kernel[:, pair]
            joint_kernel = estimate_total_joint_density_skew_normal_training(z[:, pair])[1]
            mi_upper[row, col] = total_correlation_estimation(pair_marginals, joint_kernel).sum()

    return mi_upper, mi_upper.abs().sum()
#_,mi_sum=pair_wise_MI_upper(z)