In [20]:
import jax.numpy as jnp
import torch

import numpy as np

In [23]:
# Toy inputs shared by the PyTorch and JAX checks below.
# A fixed seed makes "Restart & Run All" draw the same random vectors every time.
N_SAMPLES, DIM = 10, 3
rng = np.random.default_rng(0)
samples = np.ones((N_SAMPLES, DIM))
vectors = rng.standard_normal((N_SAMPLES, DIM))

# Torch test

In [42]:
def score1(x):
    """Toy first-order score used for testing: the identity map.

    Returns ``x`` itself (the same object, not a copy).
    """
    identity_score = x
    return identity_score

def score2(x):
    """Toy second-order score used for testing.

    ``torch.diag`` is mapped over the leading (batch) dimension with
    ``torch.vmap``; for a 2-D ``x`` of shape (n, d) each row becomes a
    (d, d) diagonal matrix, giving a result of shape (n, d, d).
    """
    batched_diag = torch.vmap(torch.diag)
    return batched_diag(x)

In [43]:
# Convert the NumPy inputs with the recommended ``torch.tensor`` factory rather than the
# legacy ``torch.Tensor`` constructor. The default dtype (float32 unless changed) is kept,
# and the batch size is read from the data instead of being hard-coded.
samples_torch = torch.tensor(samples, dtype=torch.get_default_dtype()).reshape(len(samples), -1)
vectors_torch = torch.tensor(vectors, dtype=torch.get_default_dtype()).reshape(len(vectors), -1)

In [50]:
# Second-order score-matching loss with variance reduction: x + sigma*z is paired
# with x - sigma*z, and psi evaluated at the unperturbed x acts as a control variate.


def _hosm_psi(score1, score2, x, n, dim, detach_s1=False):
    """Return ``score2(x) + score1(x) score1(x)^T`` flattened to ``(n, dim * dim)``.

    ``score2(x)`` must be reshapeable to ``(n, dim, dim)``; ``score1(x)`` must be
    2-D (required by the ``'ij, ik -> ijk'`` einsum), presumably ``(n, dim)``.
    If ``detach_s1`` is true, ``score1`` runs under ``torch.no_grad()`` so no
    gradient flows through the first-order score.
    """
    import contextlib

    s2 = score2(x).reshape(n, dim, dim)
    grad_ctx = torch.no_grad() if detach_s1 else contextlib.nullcontext()
    with grad_ctx:
        s1 = score1(x)
    s1_outer = torch.einsum('ij, ik -> ijk', s1, s1)
    # reshape (not view) so a non-contiguous sum cannot raise.
    return (s2 + s1_outer).reshape(n, -1)


def hosm_plus_vr_low_rank(score1, score2, samples, vectors, sigma=0.01, detach_s1=False):
    """Variance-reduced second-order score-matching loss (PyTorch).

    For each row x_i of ``samples`` with direction z_i (row of ``vectors``),
    psi(y) = score2(y) + score1(y) score1(y)^T is evaluated at
    y = x_i + sigma*z_i (psi_+), x_i - sigma*z_i (psi_-) and x_i (psi_0),
    and the loss is

        0.5 * mean_i sum_jk [ psi_+^2 + psi_-^2
                              + 2 * D * ((psi_+ - psi_0) + (psi_- - psi_0)) ]

    with D = (I - z_i z_i^T) / sigma^2.
    NOTE(review): this looks like the expansion of ||psi + D||^2 with the
    model-independent D^2 term dropped and 2*D*psi_0 subtracted as a control
    variate; confirm against the reference derivation.

    Args:
        score1: callable y -> first-order score; must be 2-D, presumably (n, dim).
        score2: callable y -> second-order score reshapeable to (n, dim, dim).
        samples: centre points, combined elementwise with ``vectors``.
        vectors: (n, dim) tensor of perturbation directions z.
        sigma: perturbation scale.
        detach_s1: if True, evaluate ``score1`` under ``torch.no_grad()``.
            Default False keeps gradients through both scores.

    Returns:
        0-dim tensor holding the batch-averaged loss.
    """
    n, dim = vectors.shape

    psi_plus = _hosm_psi(score1, score2, samples + vectors * sigma, n, dim, detach_s1)
    psi_minus = _hosm_psi(score1, score2, samples - vectors * sigma, n, dim, detach_s1)

    # (I - z z^T) / sigma^2; the identity broadcasts over the batch instead of
    # materialising n copies of it.
    eye = torch.eye(dim, dtype=vectors.dtype, device=vectors.device)
    z_outer = torch.einsum('ij, ik -> ijk', vectors, vectors)
    diff = ((eye - z_outer) / (sigma ** 2)).reshape(n, -1)

    psi_center = _hosm_psi(score1, score2, samples, n, dim, detach_s1)

    loss = (psi_plus ** 2 + psi_minus ** 2) + 2 * diff * ((psi_plus - psi_center) + (psi_minus - psi_center))
    return loss.sum(dim=-1).mean(dim=0) / 2.

In [51]:
hosm_plus_vr_low_rank(score1, score2, samples_torch, vectors_torch)

tensor(-16.4649)

# JAX Test

In [52]:
# JAX copies of the same inputs. Unless ``jax_enable_x64`` is set, JAX stores them as
# float32, the same precision as the default-dtype PyTorch tensors. The batch size is
# read from the data instead of being hard-coded.
samples_jax = jnp.asarray(samples).reshape(len(samples), -1)
vectors_jax = jnp.asarray(vectors).reshape(len(vectors), -1)

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


In [None]:
def f(samples, vectors, sigma):
    """Single-sample JAX version of the variance-reduced second-order loss.

    Intended to mirror ``hosm_plus_vr_low_rank`` for one sample.
    psi = s2 + s1 s1^T is evaluated at the centre and at ``samples +/- sigma * vectors``.
    The results are combined as

        0.5 * sum(psi_+^2 + psi_-^2 + 2 * diff * ((psi_+ - psi) + (psi_- - psi)))

    with diff = (I - z z^T) / sigma^2.

    Depends on globals that are not defined in this notebook: ``generator``,
    ``s1_model``, ``s2_model``, ``x0`` and ``t``. TODO: define them before running.

    Args:
        samples: centre point. The 'i,j->ij' outer products require 1-D scores,
            so presumably a single (dim,) sample; use ``jax.vmap`` for a batch.
        vectors: perturbation direction z, same shape as ``samples``.
        sigma: perturbation scale.

    Returns:
        Scalar JAX array.
    """
    # Dimension read from the input (was the undefined global N_dim).
    n_dim = vectors.shape[-1]
    # Perturb by +/- sigma*z as the PyTorch version does (was vectors / sigma,
    # which with sigma=0.01 moved the point by 100*z).
    dW = sigma * vectors

    xp = samples + dW
    xm = samples - dW

    def psi_at(x):
        # s2(x) + s1(x) s1(x)^T. The third argument of grad_TM/proj_hess is
        # presumably the evaluation point (it is the slot that receives xp/xm).
        s1 = generator.grad_TM(s1_model, x0, x, t)
        s2 = generator.proj_hess(s1_model, s2_model, x0, x, t)
        return s2 + jnp.einsum('i,j->ij', s1, s1)

    # NOTE(review): the control variate is evaluated at x0, while xp/xm are
    # centred on `samples`. The two only agree when samples == x0; confirm intent.
    psi = psi_at(x0)
    psip = psi_at(xp)
    psim = psi_at(xm)

    # (I - z z^T) / sigma^2, as in the PyTorch version (was divided by sigma only).
    diff = (jnp.eye(n_dim) - jnp.einsum('i,j->ij', vectors, vectors)) / sigma**2

    loss_s2 = psip**2 + psim**2 + 2 * diff * ((psip - psi) + (psim - psi))
    return 0.5 * jnp.sum(loss_s2)