In [41]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import math
import time
# import plotly.graph_objs as go

In [11]:
print(jnp.argmin(jnp.array([1,1,0,1])))

2


In [42]:
def largest_circle(x,y):
    return min(x,y,1-x,1-y)

# our 'h' function on the boundary
def boundary_function(x,y,z):
    # for sufficiently small eps, we can make this approximation
    index = jnp.argmin( jnp.array([x,y,1-x,1-y]) )
    if index == 1:
        return z
    else:
        return 0

# numpy version
# # given center x,y and radius r, return a random point on the circle boundary
def rand_point(x,y,r):
    theta = np.random.uniform(0, 2 * jnp.pi)
    # Convert spherical coordinates to Cartesian coordinates
    return x + jnp.cos(theta) * r, y + jnp.sin(theta) * r

# jax version
# given center x,y and radius r, return a random point on the circle boundary
# key_index = 0
# def rand_point(x,y,r):
#     key = jax.random.PRNGKey(key_index) 
#     theta = jax.random.uniform(key, minval=0, maxval=2*jnp.pi)
#     key_index += 1
#     # Convert spherical coordinates to Cartesian coordinates
#     return x + jnp.cos(theta) * r, y + jnp.sin(theta) * r

In [35]:
# print(boundary_function(0.1,0.3,3))
key_index = 0
key = jax.random.PRNGKey(key_index)
# Generate a random number between 0 and 2*pi
random_number = jax.random.uniform(key, minval=0, maxval=2*jnp.pi)
print(random_number)

1.0255733


In [36]:
print(key)

[  0 100]


In [2]:

# our 'h' function on the boundary
def boundary_function(x,y,z):
    # for sufficiently small eps, we can make this approximation
    index = jnp.argmin( jnp.array([x,y,1-x,1-y]) )
    if index == 1:
        return z
    else:
        return 0
    
    # our 'h' function on the boundary
def boundary_function_grad(x,y,z):
    # for sufficiently small eps, we can make this approximation
    index = jnp.argmin( jnp.array([x,y,1-x,1-y]) )
    if index == 1:
        return 1
    else:
        return 0

def solve(x = 1/4,y = 1/4, z=0, eps = 1e-4,max_iters = 10000):   
    running_function_value = 0
    # walk on spheres
    for _ in range(max_iters):
        x_new,y_new = x,y 
        r_max = min(x,y,1-x,1-y)
        # check point is outisde eps distance from boundary
        while  r_max> eps:
            # pick random point on circle 
            x_new,y_new = rand_point(x_new,y_new, r_max)
            # calculate the largest circle that will fit in domain
            r_max = min(x_new,y_new,1-x_new,1-y_new)
            
        running_function_value += boundary_function( x_new, y_new, z)

    #taking mean
    numerical_value = running_function_value/max_iters
    return numerical_value

def grad_solve(x_init = 1/4,y_init = 1/4, z=1, eps = 1e-4,max_iters = 10000):
    running_function_value = 0
    # walk on spheres
    for _ in range(max_iters):
        x_new,y_new = x_init,y_init 
        r_max = min(x_new,y_new,1-x_new,1-y_new)
        # check point is outisde eps distance from boundary
        while  r_max> eps:
            # pick random point on circle 
            x_new,y_new = rand_point(x_new,y_new, r_max)
            # calculate the largest circle that will fit in domain
            r_max = min(x_new,y_new,1-x_new,1-y_new)
            
        running_function_value += boundary_function_grad( x_new, y_new, z)

    #taking mean
    numerical_value = running_function_value/max_iters
    return numerical_value

In [65]:
# checking solve
def f(z):
    return solve(z=z)
solution1 = solve(z =10)
print(solution1)

solution2 = f(10.0)
print(solution2)

# Compute the gradient of f with respect to x
grad_f = jax.grad(f)

# Evaluate the gradient at a specific point, e.g., x = 0
x = 10.0
gradient_value = grad_f(x)

print(f'Gradient of function at x={x} is {gradient_value}')

4.342
4.293
Gradient of function at x=10.0 is 0.4334793984889984


In [62]:
for _ in range(10):
    print(grad_f(10.0))

0.45199773
0.43599793
0.4149982
0.429998
0.41899815
0.4619976
0.44299784
0.429998
0.4149982
0.4529977


In [60]:
epsi = 1e-2
# u_eps = 1e-16
# epsi = 10*jnp.sqrt(u_eps)
x_0 = f(z = 10-epsi)
x_1 = f(z = 10+epsi)
print(f"gradient is {(x_1-x_0)/(2*epsi)}")

gradient is -18.079499999995495


In [61]:
print(f(z = 10))
print(f(z = 10.5))
print(f(z = 11))

4.37
4.599
4.983


In [63]:
print(4.983-4.37)

0.6129999999999995
