In [2]:
import jax
import jax.random as random
import jax.numpy as jnp

import flax

from differentials import expression, domain, boundary, initial

In [3]:
dx = lambda u: jax.grad(u, argnums=0)
dt = lambda u: jax.grad(u, argnums=1)

heat = expression(
    lambda u: lambda x, t: dt(u)(x, t) + dx(dx(u))(x, t),
    var=("x", "t"),
    boundaries=(
        # insulated ends u_x(0, t) = 0
        boundary(
            LHS=lambda u: lambda x, t: dx(u)(x, t),
            RHS=lambda u: lambda x, t: 0.0,
            con=(0.0, "t")
        ),
        # insulated end u_x(L, t) = 0
        boundary(
            LHS=lambda u: lambda x, t: dx(u)(x, t),
            RHS=lambda u: lambda x, t: 0.0,
            con=(1.0, "t")
        ),
        # inital function. u(x, 0) = sin(x)
        initial(
            LHS=lambda u: lambda x, t: u(x, t),
            RHS=lambda u: lambda x, t: jnp.sin(x),
            con=("x", 0.0)
        )
    ),
    x=domain(0, 1),
    t=domain(0, 1)
)

In [8]:
import time
from tqdm import tqdm

In [11]:
u_hat, params = heat.u()

x =jnp.array([1.0,2.0])
xs = heat.matrix(10)

#start val

def make_loss(expression, n=40):
    u_hat, _ = expression.u((4,4))
    # hyper param, num of samples per loss
    xs = expression.matrix(n)
    def loss(params):
        def loss_unit(x):
            error = expression.loss(
                lambda x, t: u_hat.apply(params, jnp.array((x, t)))[0],
                x[0], x[1]  # this is for x and t. No better way exists to do this
            )
            return error
        return jnp.max(jax.vmap(loss_unit)(xs))
        return jnp.mean(jax.vmap(loss_unit)(xs))
        # here there is a contention. What loss is better, the worst point tested, or the average point tested
    return jax.jit(loss)

start = time.time()
heat_loss = make_loss(heat)
heat_loss, grads = jax.value_and_grad(heat_loss)(params)
end = time.time()

print(end - start)

print(heat_loss)
print(grads)




ScopeParamShapeError: Initializer expected to generate shape (4, 5) but got shape (4, 4) instead for parameter "kernel" in "/Dense_2". (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ScopeParamShapeError)

In [None]:
# timing the loss / average loss with a mean function vs a min function. 

def make_mean_loss(expression):
    u_hat, _ = expression.u()
    # hyper param, num of samples per loss
    xs = expression.matrix(10)
    def loss(params):
        def loss_unit(x):
            error = expression.loss(
                lambda x, t: u_hat.apply(params, jnp.array((x, t)))[0],
                x[0], x[1]  # this is for x and t. No better way exists to do this
            )
            return error
        return jnp.mean(jax.vmap(loss_unit)(xs))
    return jax.jit(loss)

def make_max_loss(expression):
    u_hat, _ = expression.u()
    # hyper param, num of samples per loss
    xs = expression.matrix(10)
    def loss(params):
        def loss_unit(x):
            error = expression.loss(
                lambda x, t: u_hat.apply(params, jnp.array((x, t)))[0],
                x[0], x[1]  # this is for x and t. No better way exists to do this
            )
            return error
        return jnp.max(jax.vmap(loss_unit)(xs))
    return jax.jit(loss)

# 
n_10_mean_loss = list()
n_10_mean_time = list()
n_10_max_loss = list()
n_10_max_time = list()

for i in tqdm(range(100)):
    print("n = 10")
    start = time.time()
    heat_loss = make_mean_loss(heat)
    heat_loss, grads = jax.value_and_grad(heat_loss)(params)
    end = time.time()
    
    n_10_mean_loss.append(heat_loss)
    n_10_mean_time.append(end-start)
    
    start = time.time()
    heat_loss = make_max_loss(heat)
    heat_loss, grads = jax.value_and_grad(heat_loss)(params)
    end = time.time()
    
    n_10_max_loss.append(heat_loss)
    n_10_max_time.append(end-start)
    
#.  n = 30

def make_mean_loss(expression):
    u_hat, _ = expression.u()
    # hyper param, num of samples per loss
    xs = expression.matrix(30)
    def loss(params):
        def loss_unit(x):
            error = expression.loss(
                lambda x, t: u_hat.apply(params, jnp.array((x, t)))[0],
                x[0], x[1]  # this is for x and t. No better way exists to do this
            )
            return error
        return jnp.mean(jax.vmap(loss_unit)(xs))
    return jax.jit(loss)

def make_max_loss(expression):
    u_hat, _ = expression.u()
    # hyper param, num of samples per loss
    xs = expression.matrix(30)
    def loss(params):
        def loss_unit(x):
            error = expression.loss(
                lambda x, t: u_hat.apply(params, jnp.array((x, t)))[0],
                x[0], x[1]  # this is for x and t. No better way exists to do this
            )
            return error
        return jnp.max(jax.vmap(loss_unit)(xs))
    return jax.jit(loss)

# 
n_30_mean_loss = list()
n_30_mean_time = list()
n_30_max_loss = list()
n_30_max_time = list()

for i in tqdm(range(100)):
    print("n = 30")
    start = time.time()
    heat_loss = make_mean_loss(heat)
    heat_loss, grads = jax.value_and_grad(heat_loss)(params)
    end = time.time()
    
    n_30_mean_loss.append(heat_loss)
    n_30_mean_time.append(end-start)
    
    start = time.time()
    heat_loss = make_max_loss(heat)
    heat_loss, grads = jax.value_and_grad(heat_loss)(params)
    end = time.time()
    
    n_30_max_loss.append(heat_loss)
    n_30_max_time.append(end-start)

# displaying

n_30_mean_loss_arr = jnp.array(n_30_mean_loss)
n_30_mean_time_arr = jnp.array(n_30_mean_time)
n_30_max_loss_arr = jnp.array(n_30_max_loss)
n_30_max_time_arr = jnp.array(n_30_max_time)

n_10_mean_loss_arr = jnp.array(n_10_mean_loss)
n_10_mean_time_arr = jnp.array(n_10_mean_time)
n_10_max_loss_arr = jnp.array(n_10_max_loss)
n_10_max_time_arr = jnp.array(n_10_max_time)

# Compute mean and standard deviation for n = 30
n_30_mean_loss_mean = jnp.mean(n_30_mean_loss_arr)
n_30_mean_loss_std = jnp.std(n_30_mean_loss_arr)
n_30_mean_time_mean = jnp.mean(n_30_mean_time_arr)
n_30_mean_time_std = jnp.std(n_30_mean_time_arr)

n_30_max_loss_mean = jnp.mean(n_30_max_loss_arr)
n_30_max_loss_std = jnp.std(n_30_max_loss_arr)
n_30_max_time_mean = jnp.mean(n_30_max_time_arr)
n_30_max_time_std = jnp.std(n_30_max_time_arr)

# Compute mean and standard deviation for n = 10
n_10_mean_loss_mean = jnp.mean(n_10_mean_loss_arr)
n_10_mean_loss_std = jnp.std(n_10_mean_loss_arr)
n_10_mean_time_mean = jnp.mean(n_10_mean_time_arr)
n_10_mean_time_std = jnp.std(n_10_mean_time_arr)

n_10_max_loss_mean = jnp.mean(n_10_max_loss_arr)
n_10_max_loss_std = jnp.std(n_10_max_loss_arr)
n_10_max_time_mean = jnp.mean(n_10_max_time_arr)
n_10_max_time_std = jnp.std(n_10_max_time_arr)

# Print results in a summary table format, including the actual statistics
print(f"{'Sample Size':<12} {'Method':<6} {'Metric':<5} {'Mean':<10} {'Standard Deviation':<10}")
print(f"{'n = 30':<12} {'Mean':<6} {'Loss':<5} {n_30_mean_loss_mean:<10.4f} {n_30_mean_loss_std:<10.4f}")
print(f"{'n = 30':<12} {'Mean':<6} {'Time':<5} {n_30_mean_time_mean:<10.4f} {n_30_mean_time_std:<10.4f}")
print(f"{'n = 10':<12} {'Mean':<6} {'Loss':<5} {n_10_mean_loss_mean:<10.4f} {n_10_mean_loss_std:<10.4f}")
print(f"{'n = 10':<12} {'Mean':<6} {'Time':<5} {n_10_mean_time_mean:<10.4f} {n_10_mean_time_std:<10.4f}")
print()
print(f"{'n = 30':<12} {'Max':<6} {'Loss':<5} {n_30_max_loss_mean:<10.4f} {n_30_max_loss_std:<10.4f}")
print(f"{'n = 30':<12} {'Max':<6} {'Time':<5} {n_30_max_time_mean:<10.4f} {n_30_max_time_std:<10.4f}")
print(f"{'n = 10':<12} {'Max':<6} {'Loss':<5} {n_10_max_loss_mean:<10.4f} {n_10_max_loss_std:<10.4f}")
print(f"{'n = 10':<12} {'Max':<6} {'Time':<5} {n_10_max_time_mean:<10.4f} {n_10_max_time_std:<10.4f}")

  0%|                                                                                                                 | 0/100 [00:00<?, ?it/s]

n = 10


  1%|█                                                                                                        | 1/100 [00:08<14:48,  8.97s/it]

n = 10


  2%|██                                                                                                       | 2/100 [00:14<11:22,  6.96s/it]

n = 10


  3%|███▏                                                                                                     | 3/100 [00:20<10:13,  6.32s/it]

n = 10


  4%|████▏                                                                                                    | 4/100 [00:25<09:39,  6.03s/it]

n = 10


In [32]:
# loss metrics