In [None]:
%pip install --upgrade pip 
%pip install --upgrade jax 
%pip install "flax[all]"

In [5]:
# Caution (0121): JAX worked with Python 3.9 here -> try the new setup on the desktop.

import jax
import jax.numpy as jnp 

import numpy as np 

# Build the same integer range with JAX and with NumPy; both print as [0 1 ... 9].
x_jnp = jnp.arange(10)
x_np = np.arange(10)

# One print call, one array per line (same output as two separate prints).
print(x_jnp, x_np, sep="\n")

[0 1 2 3 4 5 6 7 8 9]
[0 1 2 3 4 5 6 7 8 9]


In [7]:
def selu(x, alpha=1.67, lambda_=1.05):
    """Scaled exponential linear unit, applied elementwise.

    selu(x) = lambda_ * x                     for x > 0
            = lambda_ * alpha * (exp(x) - 1)  for x <= 0

    The defaults (1.67, 1.05) are rounded versions of the standard SELU
    constants alpha ~= 1.6733 and lambda ~= 1.0507.

    Args:
        x: array-like input. Integer inputs are promoted to float by the
            exponential branch.
        alpha: scale of the negative (saturating) branch.
        lambda_: overall output scale.

    Returns:
        Array with the same shape as ``x``.
    """
    # jnp.where still differentiates the branch it does not select. With a raw
    # exp(x), large positive x overflows to inf and the gradient becomes
    # 0 * inf = nan. Clamping the argument to <= 0 keeps both branches finite.
    # expm1 is also more accurate than exp(x) - 1 close to zero.
    negative_branch = alpha * jnp.expm1(jnp.minimum(x, 0))
    return lambda_ * jnp.where(x > 0, x, negative_branch)

x= jnp.arange(1000000) 
%timeit selu(x).block_until_ready() 

selu_jit= jax.jit(selu) 

selu_jit(x).block_until_ready() 
%timeit selu_jit(x).block_until_ready()

1.32 ms ± 55.3 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
383 μs ± 449 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [8]:
def f(x):
    """Cubic test polynomial x^3 + 2x^2 - 2x + 1."""
    return x**3 + 2 * x**2 - 2 * x + 1


# Nest jax.grad to obtain the first three derivatives of f.
dfdx = jax.grad(f)
d2fdx = jax.grad(dfdx)
d3fdx = jax.grad(d2fdx)

# At x = 1: f' = 3x^2 + 4x - 2 = 5, f'' = 6x + 4 = 10, f''' = 6.
print(*(deriv(1.) for deriv in (dfdx, d2fdx, d3fdx)), sep="\n")

5.0
10.0
6.0


In [None]:
## Condition Settings 
import jax
import jax.numpy as jnp 
import jax.random as random 
from jax.scipy.special import logsumexp 
from jax.scipy.stats import norm , t 

# Inputs: unnormalized target gamma(x), initial-state proposal pi_0(x), number of steps K, step size eta, annealing schedule (beta_k), damping coefficient h, mass matrix M, score model

# Algorithm 1 (AIS)

def _gaussian_logpdf(z, mean, var):
    """Log-density of the isotropic Gaussian N(mean, var * I) evaluated at z."""
    dim = jnp.size(z)
    return -0.5 * dim * jnp.log(2.0 * jnp.pi * var) - 0.5 * jnp.sum((z - mean) ** 2) / var


def unadjusted_langevin_ais(log_target_fn, log_initial_fn, K, step_size, beta_schedule, rng_key, dim=None):
    """Annealed importance sampling with unadjusted Langevin (ULA) transitions.

    Moves one particle along the geometric path
        log pi_k(x) = beta_k * log gamma(x) + (1 - beta_k) * log pi_0(x)
    with a single ULA step per level. It accumulates the log importance weight
        log w = log gamma(x_K) - log pi_0(x_0)
                + sum_k [log B_k(x_{k-1} | x_k) - log F_k(x_k | x_{k-1})].
    F_k is the forward ULA kernel N(x + eta * grad log pi_k(x), 2 * eta * I).
    B_k is the same kernel run backwards from x_k.

    Args:
        log_target_fn: unnormalized target log-density log gamma; maps a
            (dim,) array to a scalar.
        log_initial_fn: log-density of the initial distribution pi_0.
            NOTE: x_0 is always drawn from a standard normal N(0, I) here.
            The weights are therefore only valid importance weights if this
            is the (normalized) log-density of N(0, I).
        K: number of annealing / Langevin steps.
        step_size: ULA step size eta (must be > 0).
        beta_schedule: indexable of length >= K + 1. Entries 1..K are read;
            conventionally beta_0 = 0 and beta_K = 1.
        rng_key: JAX PRNG key.
        dim: dimension of the state. Defaults to ``log_initial_fn.ndim``,
            so callers that attach that attribute keep working.

    Returns:
        (x_samples, log_w): a list of the K + 1 states x_0, ..., x_K and the
        scalar log importance weight.
    """
    if dim is None:
        dim = log_initial_fn.ndim

    def log_intermediate_fn(z, beta):
        # Geometric interpolation between the initial and target log-densities.
        return beta * log_target_fn(z) + (1.0 - beta) * log_initial_fn(z)

    # Gradient w.r.t. z only (argnums=0); built once instead of once per step.
    score_fn = jax.grad(log_intermediate_fn)
    transition_var = 2.0 * step_size

    rng_key, subkey = random.split(rng_key)
    x = random.normal(subkey, shape=(dim,))
    log_w = -log_initial_fn(x)
    x_samples = [x]

    for k in range(1, K + 1):
        rng_key, subkey = random.split(rng_key)
        beta_k = beta_schedule[k]

        # Forward ULA move targeting pi_k.
        grad_x = score_fn(x, beta_k)
        noise = random.normal(subkey, shape=x.shape)
        forward_mean = x + step_size * grad_x
        x_new = forward_mean + jnp.sqrt(transition_var) * noise

        # The backward kernel is the same ULA kernel, started from x_new.
        grad_x_new = score_fn(x_new, beta_k)
        backward_mean = x_new + step_size * grad_x_new

        log_F = _gaussian_logpdf(x_new, forward_mean, transition_var)
        log_B = _gaussian_logpdf(x, backward_mean, transition_var)
        log_w = log_w + (log_B - log_F)

        x = x_new
        x_samples.append(x)

    log_w = log_w + log_target_fn(x)
    return x_samples, log_w

def make_gaussian_mixture_logpdf(d, num_components=8, var=1.0, seed=12345, mean_scale=3.0):
    """Build the log-density of an equally weighted isotropic Gaussian mixture.

    Component means are drawn once as ``mean_scale * N(0, I_d)`` from
    ``PRNGKey(seed)``, so identical arguments always produce the same mixture.

    Args:
        d: dimension of the space.
        num_components: number of components, each with weight 1/num_components.
        var: shared isotropic variance of every component.
        seed: integer seed for the component means. The default reproduces
            the original fixed mixture.
        mean_scale: standard deviation used to spread the component means.

    Returns:
        ``logpdf(x)`` maps x of shape (..., d) to log p(x) of shape (...).
        A 1-D input of shape (d,) gives a scalar. The function carries
        ``logpdf.ndim = d``, so it can be used where a ``.ndim`` attribute
        is expected. It also carries ``logpdf.means``, the
        (num_components, d) component means.
    """
    means = random.normal(random.PRNGKey(seed), shape=(num_components, d)) * mean_scale
    # Loop-invariant pieces of every component's log-density.
    log_norm = -0.5 * d * jnp.log(2 * jnp.pi * var)
    log_num_components = jnp.log(num_components)

    def logpdf(x):
        x = jnp.asarray(x)
        # (..., 1, d) - (C, d) -> (..., C, d): one offset per component.
        diffs = x[..., None, :] - means
        sq_dist = jnp.sum(diffs ** 2, axis=-1)
        log_probs = log_norm - 0.5 * sq_dist / var
        # log( (1/C) * sum_c N(x; mu_c, var * I) ), computed stably.
        return logsumexp(log_probs, axis=-1) - log_num_components

    logpdf.ndim = d
    logpdf.means = means
    return logpdf



       

