In [6]:
# import statements
import jax
import matplotlib.pyplot as plt
import math
import time
import numpy as np
import plotly.graph_objs as go
import jax.numpy as jnp

In [44]:
# our 'h' function on the boundary
def boundary_function(x,y,z):
    # for sufficiently small eps, we can make this approximation
    index = jnp.argmin( jnp.array([x,y,1-x,1-y]) )
    if index == 1:
        return z
    else:
        return 0
    
# gradient our 'h' function on the boundary
def boundary_function_grad(x,y,z):
    # for sufficiently small eps, we can make this approximation
    index = jnp.argmin( jnp.array([x,y,1-x,1-y]) )
    if index == 1:
        return 1
    else:
        return 0

# solves for u at specified x,y and temperature z using walk on spheres
def solve(x = 1/4,y = 1/4, z=0, eps = 1e-4,max_iters = 500):   
    running_function_value = 0
    running_gradient_value = 0
    # walk on spheres
    for _ in range(max_iters):
        x_new,y_new = x,y 
        r_max = min(x_new,y_new,1-x_new,1-y_new)
        # check point is outisde eps distance from boundary
        while  r_max> eps:
            # pick random point on circle 
            theta = np.random.uniform(0, 2 * jnp.pi)
            x_new = x_new + jnp.cos(theta) * r_max
            y_new = y_new + jnp.sin(theta) * r_max
            # calculate the largest circle that will fit in domain
            r_max = min(x_new,y_new,1-x_new,1-y_new)
            
        running_function_value += boundary_function( x_new, y_new, z)
        running_gradient_value += boundary_function_grad( x_new, y_new, z)

    #taking mean
    numerical_value = running_function_value/max_iters
    numerical_grad_value = running_gradient_value/max_iters
    return [numerical_value, numerical_grad_value]

# WoS for gradient
def grad_solve(x_init = 1/4,y_init = 1/4, z=1, eps = 1e-4,max_iters = 1000):
    running_function_value = 0
    # walk on spheres
    for _ in range(max_iters):
        x_new,y_new = x_init,y_init 
        r_max = min(x_new,y_new,1-x_new,1-y_new)
        # check point is outisde eps distance from boundary
        while  r_max> eps:
            # pick random point on circle 
            theta = np.random.uniform(0, 2 * jnp.pi)
            x_new = x_new + jnp.cos(theta) * r_max
            y_new = y_new + jnp.sin(theta) * r_max
            # calculate the largest circle that will fit in domain
            r_max = min(x_new,y_new,1-x_new,1-y_new)
                        
        running_function_value += boundary_function_grad( x_new, y_new, z)

    #taking mean
    numerical_value = running_function_value/max_iters
    return numerical_value

In [34]:
print(solve(z=10.0, max_iters= 500))

4.18


In [26]:
# checking solve
def f(z):
    return solve(z=z)
solution1 = solve(z =10.0)
print(solution1)

solution2 = f(10.0)
print(solution2)

# Compute the gradient of f with respect to x
grad_f = jax.grad(f)

# Evaluate the gradient at a specific point, e.g., x = 0
z = 10.0
gradient_value = grad_f(z)
grad_value_wos = grad_solve(z = z)

print(f'Gradient of function at z={z} is {gradient_value}')
print(f'Gradient using WoS  at z={z} is {gradient_value}')

4.278
4.319


KeyboardInterrupt: 

In [55]:
# consider the unit square with temperature z at the x-axis. 
# sequential search for unit square
def seq_search(target_temp = 15, heater_low = 0, heater_high= 100, error_min = 0.1):
    errors_low = []
    errors_high = []
    count = 0
    # initial
    temp_low = solve(z = heater_low)[0]
    low_error =np.abs(temp_low - target_temp) 
    errors_low.append(low_error)

    temp_high = solve(z = heater_high)[0]
    high_error =np.abs(temp_high - target_temp) 
    errors_high.append(high_error)

    if target_temp<temp_low or target_temp>temp_high:
        print(f"Initial heater settings are insufficent")
        print(f"Initial temperatures are: {temp_low} and {temp_high}")
        print("Please readjust")
        return
    #success
    if low_error< error_min:
        return heater_low       
    if high_error< error_min:
        return heater_high
    
    count+=1
    
    # repeat
    while count <=2 or jnp.abs(errors_low[-1] - errors_low[-2])/errors_low[-1] > error_min or jnp.abs(errors_low[-1] - errors_low[-2])/errors_low[-1] > error_min:
           
        # if closer to lower temperature
        if low_error < high_error:
            heater_high = (heater_high+heater_low)/2
            temp_high = solve(z = heater_high)[0]
            high_error =np.abs(temp_high - target_temp) 
            errors_high.append(high_error)
            if high_error< error_min:
                return heater_high
        # if closer to higher temperature
        else:
            heater_low = (heater_high+heater_low)/2
            temp_low = solve(z = heater_low)[0]
            low_error =np.abs(temp_low - target_temp) 
            errors_low.append(low_error)
            if low_error< error_min:
                return heater_low
            
        count += 1
        if count%10 == 0:
            print(f"Iter: {count}, error_low: {low_error}, error_low:{high_error}")
            print(f"temp_low: {heater_low}, temp_high:{heater_high}")

        print(f"Iter: {count}, error_low: {low_error}, error_low:{high_error}")
        print(f"temp_low: {temp_low}, temp_high:{temp_high}")
        print(f"heater_low: {heater_low}, heater_high:{heater_high}")


    print("we reach final step")
    if errors_low[-1] < errors_high[-1]:
        return heater_low
    else:
        return heater_high
        

In [57]:
z_optim = seq_search(15,0,50,0.1)

Iter: 2, error_low: 3.9499999999999993, error_low:5.199999999999999
temp_low: 11.05, temp_high:20.2
heater_low: 25.0, heater_high:50
Iter: 3, error_low: 3.9499999999999993, error_low:2.1000000000000014
temp_low: 11.05, temp_high:17.1
heater_low: 25.0, heater_high:37.5
Iter: 4, error_low: 1.625, error_low:2.1000000000000014
temp_low: 13.375, temp_high:17.1
heater_low: 31.25, heater_high:37.5
Iter: 5, error_low: 1.625, error_low:0.9749999999999996
temp_low: 13.375, temp_high:14.025
heater_low: 31.25, heater_high:34.375
Iter: 6, error_low: 2.4656249999999993, error_low:0.9749999999999996
temp_low: 12.534375, temp_high:14.025
heater_low: 32.8125, heater_high:34.375
Iter: 7, error_low: 2.234375, error_low:0.9749999999999996
temp_low: 12.765625, temp_high:14.025
heater_low: 33.59375, heater_high:34.375
Iter: 8, error_low: 0.4968749999999993, error_low:0.9749999999999996
temp_low: 15.496875, temp_high:14.025
heater_low: 33.984375, heater_high:34.375
Iter: 9, error_low: 0.4968749999999993, err

In [52]:
print(z_optim)

35.9375


In [54]:
print(solve(z = z_optim, max_iters=1000 ))

[15.95625, 0.444]


In [66]:
def gradient_descent(target, learning_rate, num_iterations):
    # Initialize parameters
    z = 30 # start with 10 celcius
    
    for i in range(num_iterations):
        # Compute the function value and gradient
        function_value, gradient = solve(x = 1/4,y = 1/4, z=z, eps = 1e-4,max_iters = 500)
        gradient_of_loss_funct = -2*gradient*(target - function_value)
        # Update parameters
        z = z - learning_rate * gradient_of_loss_funct

        if jnp.abs(function_value-target)<0.1:
            return z
        
        # Optionally print progress
        print(f"Iteration {i+1}: function value = {function_value}, z = {z}")
    
    return z
    

target = 15 # Starting point
learning_rate = 0.05
num_iterations = 100

optimal_val = gradient_descent(target, learning_rate, num_iterations)
print(f"Optimal parameters: {optimal_val}")


Iteration 1: function value = 14.1, z = 30.0423
Iteration 2: function value = 12.13708920000002, z = 30.15796159632
Iteration 3: function value = 12.36476425449123, z = 30.26600626188586
Iteration 4: function value = 13.80129885541997, z = 30.32066703407871
Iteration 5: function value = 13.28045216092648, z = 30.39598322943013
Iteration 6: function value = 13.67819245324358, z = 30.455464569034167
Iteration 7: function value = 13.339493481236946, z = 30.528194754555987
Iteration 8: function value = 12.455503459858877, z = 30.632010213393745
Iteration 9: function value = 12.865444289625406, z = 30.721661553229477
Iteration 10: function value = 12.411551267504727, z = 30.826234882022288
Iteration 11: function value = 13.255280999269628, z = 30.901257799053692
Iteration 12: function value = 12.607713182013876, z = 30.998863101227524
Iteration 13: function value = 14.073483847957274, z = 31.040926934530265
Iteration 14: function value = 12.726780043157405, z = 31.13412895276081
Iteration 1

In [68]:
print(solve(z =33.986))

[14.41006399999997, 0.424]
