In [1]:
import torch
import math
from torch.distributions import Dirichlet, Normal
from src import loss_functions as LF

In [164]:
def compute_Gpp_basis(g_prime, g_double_prime, B1, B2, Y, sigma = 0.02):
    """
    Compute empirical estimate of E[(B1 y)^T g''(y) B2 y - g'(y)^T B1^T B2 y]

    Args:
        g_prime: function (N x p) -> (N x p), gradient function
        g_double_prime: function (N x p) -> (N x p x p), Hessian function
        B1: skew-symmetric matrix (p x p)
        B2: skew-symmetric matrix (p x p)
        Y: data matrix (N x p)

    Returns:
        Scalar: empirical estimate of second derivative contribution
    """
    B1 = B1.float()
    B2 = B2.float()
    Y = Y.float()

    B1Y = Y @ B1.T  # (N x p)
    B2Y = Y @ B2.T  # (N x p)

    # Hessian term: (B1 y)^T H (B2 y)
    H = g_double_prime(Y, sigma)        # (N x p x p)
    B2Y_exp = B2Y.unsqueeze(2)   # (N x p x 1)
    HB2Y = torch.bmm(H, B2Y_exp) # (N x p x 1)
    term1 = torch.bmm(B1Y.unsqueeze(1), HB2Y).squeeze()  # (N,)
    
    # Gradient term: g'(y)^T B1^T B2 y
    B1B2Y = Y @ (B2.T @ B1).T  # (N x p)
    gY = g_prime(Y, sigma)            # (N x p)
    term2 = (gY * B1B2Y).sum(dim=1)  # (N,)

    return (term1 - term2).mean().item()


def mollified_relu_grad(Y, sigma=0.02):
    """
    Compute gradient of mollified ReLU loss w.r.t. Y.
    Y: (N x p), sigma: scalar
    Returns: (N x p) gradient matrix
    """
    normal = Normal(0, 1)
    N, p = Y.shape

    # Negative term gradient (for smoothed_relu(-x_j))
    grad_neg = -normal.cdf(-Y / sigma)  # shape: (N x p)

    # Sum constraint gradient (for smoothed_relu(sum(x) - 1))
    w = Y.sum(dim=1, keepdim=True) - 1  # shape: (N x 1)
    grad_sum = normal.cdf(w / (sigma * p**0.5))  # shape: (N x 1)
    grad_sum_expanded = grad_sum.expand(-1, p)

    return grad_neg + grad_sum_expanded


# def mollified_relu_hess(Y, sigma=0.02):
#     """
#     Computes the manual Hessian of the mollified ReLU simplex loss.
#     Args:
#         Y: (N, p) batch of data points
#         sigma: smoothing parameter

#     Returns:
#         Hessian: tensor of shape (N, p, p)
#     """
#     N, p = Y.shape
#     device = Y.device
#     normal = Normal(0.0, 1.0)

#     # First term: ReLU(-x_i) for each coordinate (diagonal Hessian)
#     z_neg = -Y / sigma  # shape (N, p)
#     phi_neg = normal.log_prob(z_neg).exp()  # PDF
#     H_neg = torch.diag_embed(phi_neg / sigma)  # shape (N, p, p)

#     # Second term: ReLU(sum(x_i) - 1) is rank-1
#     sum_x = Y.sum(dim=1, keepdim=True)  # shape (N, 1)
#     z_sum = (sum_x - 1) / (sigma * p**0.5)
#     phi_sum = normal.log_prob(z_sum).exp()  # shape (N, 1)
#     coeff = phi_sum / (sigma * p**0.5)  # shape (N, 1)

#     ones = torch.ones((N, p, 1), device=device)
#     H_sum = coeff.view(-1, 1, 1) * torch.bmm(ones, ones.transpose(1, 2))  # shape (N, p, p)

#     return H_neg + H_sum
def mollified_relu_hess(y, sigma = 0.02):
    
    p = torch.tensor(y.shape[0])
    device = y.device
    normal = Normal(0, 1)
    
    log_phi = normal.log_prob
    
    term_1 = torch.exp(log_phi(y/sigma))/sigma
    print(term_1)
    term_2 = torch.exp(log_phi((torch.sum(y)-1)/(torch.sqrt(p)*sigma)))/(torch.sqrt(p)*sigma)
    print(term_2)
    term_1 = torch.diag(term_1)
    term_2 = term_2 * torch.ones((p, p), device=device)
    
    return(term_1 + term_2)

def skew_basis(p, device = "cuda", dtype = torch.float32):
    basis = []
    for i in range(p):
        for j in range(i + 1, p):
            B = torch.zeros(p, p, device=device, dtype=dtype)
            B[i, j] = 1 / math.sqrt(2)
            B[j, i] = -1 / math.sqrt(2)
            basis.append(B)
    return torch.stack(basis)  # (d, p, p)

def construct_M_from_basis(g_prime, g_double_prime, B, Y):
    """
    Construct M matrix using basis elements B and gradient/Hessian functions.

    Args:
        g_prime: grad function (N x p) -> (N x p)
        g_double_prime: hessian function (N x p) -> (N x p x p)
        B: tensor of skew basis matrices (d x p x p)
        Y: data matrix (N x p)

    Returns:
        M: tensor of shape (d x d)
    """
    d = B.shape[0]
    M = torch.zeros(d, d, device=Y.device, dtype=Y.dtype)

    for i in range(d):
        for j in range(d):
            M[i, j] = compute_Gpp_basis(g_prime, g_double_prime, B[i], B[j], Y)

    return M



In [None]:
p = 6

alpha = torch.ones(p, p, dtype= torch.float32) + torch.eye(p, dtype= torch.float32)*0
# alpha = torch.tensor([[10, 1, 1, 1], [1, 10, 1], [1, 1, 10]], dtype= torch.float64)
# alpha = torch.tensor([[5, 1]], dtype= torch.float64)
n= 30000
K, p = alpha.shape
torch.manual_seed(5)
dir = Dirichlet(alpha)
X = dir.sample((n // K,)).transpose(0, 1).reshape(n, p)[:, :p].to("cuda")

B = skew_basis(p, "cuda")
M = construct_M_from_basis(mollified_relu_grad, mollified_relu_hess, B, X)

In [None]:
eigenvalues, eigenvectors = torch.linalg.eigh(M)
print(eigenvalues)

In [156]:
from src.utils import LossFunctionWrapper
import torch
import math
from torch.distributions import Dirichlet, Normal
from src import loss_functions as LF

def smoothed_relu(z, sigma):
    """
    Smooth approximation of ReLU using Gaussian mollifier.

    Args:
        z: tensor
        sigma: smoothing parameter (scalar or tensor broadcastable to z)

    Returns:
        Smoothed ReLU approximation, same shape as z
    """
    sqrt_2 = math.sqrt(2)
    sqrt_2pi = math.sqrt(2 * math.pi)

    scaled = z / sigma
    phi = torch.exp(-0.5 * scaled**2) / sqrt_2pi
    Phi = 0.5 * (1 + torch.erf(scaled / sqrt_2))

    return sigma * phi + z * Phi

def mollified_relu_simplex_core(x, sigma):
    """
    Smooth version of the ReLU simplex loss using smoothed ReLU.

    Args:
        x: (n, p) tensor
        sigma: smoothing parameter (float or tensor)

    Returns:
        Tensor of shape (n,) with smoothed loss values
    """
    if x.ndim == 1:
        x = x.unsqueeze(0)
    
    n, p = x.shape
    sigma = torch.as_tensor(sigma, dtype=x.dtype, device=x.device)
    p = torch.tensor(p, device=x.device)
    
    # Negative entries penalty
    neg_part = smoothed_relu(-x, sigma).sum(dim=1)

    # Sum constraint penalty
    sum_constraint = torch.sum(x) - 1  # penalize when > 1
    sum_part = smoothed_relu(sum_constraint, sigma*torch.sqrt(p))

    return neg_part + sum_part

In [160]:
temp = LossFunctionWrapper(mollified_relu_simplex_core)

p = 6

alpha = torch.ones(p, p, dtype= torch.float32) + torch.eye(p, dtype= torch.float32)*0
# alpha = torch.tensor([[10, 1, 1, 1], [1, 10, 1], [1, 1, 10]], dtype= torch.float64)
# alpha = torch.tensor([[5, 1]], dtype= torch.float64)
n= 30000
K, p = alpha.shape
torch.manual_seed(5)
dir = Dirichlet(alpha)
X = dir.sample((n // K,)).transpose(0, 1).reshape(n, p)[:, :p].to("cuda")

In [165]:
H1 = temp.hessian(X[0, :], 0.02)

In [166]:
H2 = mollified_relu_hess(X[0, :], 0.02)

tensor([6.0487e-01, 1.7785e+00, 6.8381e-03, 0.0000e+00, 3.9452e-22, 8.5327e+00],
       device='cuda:0')
tensor(8.1434, device='cuda:0')


In [167]:
H1 - H2

tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00],
        [ 0.0000e+00, -2.8610e-06,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00]], device='cuda:0')

In [151]:
x = X[0, :]
x.requires_grad_(True)
loss = mollified_relu_simplex_core(x, 0.02)
grad = torch.autograd.grad(loss, x, create_graph=True)[0]

# Second derivative (Hessian)
p = x.numel()
hessian = torch.zeros(p, p, dtype=x.dtype, device=x.device)

for i in range(p):
    # Compute gradient of grad[i] w.r.t. x again
    grad2 = torch.autograd.grad(grad[i], x, retain_graph=True)[0]
    hessian[i] = grad2
print(hessian)

tensor([[20.5520, 19.9471, 19.9471, 19.9471, 19.9471, 19.9471],
        [19.9471, 21.7256, 19.9471, 19.9471, 19.9471, 19.9471],
        [19.9471, 19.9471, 19.9540, 19.9471, 19.9471, 19.9471],
        [19.9471, 19.9471, 19.9471, 19.9471, 19.9471, 19.9471],
        [19.9471, 19.9471, 19.9471, 19.9471, 19.9471, 19.9471],
        [19.9471, 19.9471, 19.9471, 19.9471, 19.9471, 28.4798]],
       device='cuda:0')


In [155]:
x = X[0, 0]
x.requires_grad_(True)
loss = smoothed_relu(x, 0.02)
grad = torch.autograd.grad(loss, x, create_graph=True)[0]
grad.backward()
x.grad

tensor(0.6049, device='cuda:0')

In [154]:
grad

tensor(0.9959, device='cuda:0', grad_fn=<AddBackward0>)

In [133]:
mollified_relu_grad(X[0, :].unsqueeze(dim = 0), 0.02)

tensor([[0.4959, 0.4861, 0.5000, 0.5000, 0.5000, 0.4037]], device='cuda:0')

tensor([0.4959, 0.4861, 0.5000, 0.5000, 0.5000, 0.4037], device='cuda:0')