In [1]:
import jax
import jax.numpy as jnp 
from jax import random
import equinox as eqx 
import optax
import matplotlib.pyplot as plt  

## Equação do Calor

A equação do calor é dada através da seguinte equação diferencial:

$$
\dfrac{\partial u}{\partial t} = \alpha \dfrac{\partial^2 u}{\partial x^2}, \hspace{1cm}  x\in [0,1],t\in [0,1]
$$

onde $\alpha = 0.4$, $u(0,t) = u(1,t) = 0$ e $u(x,0) = \sin(\dfrac{n \pi x}{L}), 0 < x < L$ e $n = 1,2,3,\ldots$.

Teremos que $L = 1$, como a largura da barra e $n=1$ a frequencia senoidal da condição inicial.

A solução exata é dada por $u(x,t) = e^{\frac{-n^2 \pi^2 \alpha t}{L^2}} \sin(\dfrac{n \pi x}{L})$

In [2]:
#PARAMETROS
alpha = 0.4
L = 1
N = 1

#DECLARAÇÃO DO DOMÍNIO
x_dom = (0, 1)
t_dom = (0, 1)
batch_size = 1000

#OBTENÇÃO DOS DADOS DE TREINAMENTO
N_SAMPLES = 200
key_x, key_t, key_model = jax.random.split(jax.random.PRNGKey(0),3)

#-----------------------------------------------------------------------------------
#PONTOS INTERNOS
x_samples = random.uniform(key_x, (N_SAMPLES, 1), minval=x_dom[0], maxval=x_dom[1])
t_samples = random.uniform(key_t, (N_SAMPLES, 1), minval=t_dom[0], maxval=t_dom[1])
intern_points = jnp.concatenate([x_samples, t_samples], 1)

#-----------------------------------------------------------------------------------
#Condição de contorno 1: u[0, t]
x_b1 = jnp.zeros_like(x_samples)
bc_1 = jnp.concatenate([x_b1, t_samples], 1)
bc1_target = jnp.zeros_like(x_samples)

#Condição de contorno 1: u[1, t]
x_b2 = jnp.ones_like(x_samples)
bc_2 = jnp.concatenate([x_b2, t_samples], 1)
bc2_target = jnp.zeros_like(x_samples)

#Condição inicial: u[x, 0]
t_ic = jnp.zeros_like(t_samples) 
ic = jnp.concatenate([x_samples, t_ic], 1)
ic_target = jnp.ones_like(t_samples) * jnp.sin((N * jnp.pi * x_samples) / L)

#-----------------------------------------------------------------------------------


geometry = [bc_1, bc_2, ic]
target = [bc1_target, bc2_target, ic_target]

#plt.scatter(x_b1, t_samples)
#plt.scatter(x_b2, t_samples)
#plt.scatter(x_samples, t_ic)
#plt.scatter(x_samples, t_samples)

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


In [3]:

class Linear:
    weight: jnp.ndarray
    bias: jnp.ndarray

    def __init__(self, n_input: int, n_output: int, key = jax.random.PRNGKey(0)):

        self.weight = random.uniform(key, (n_input, n_output))
        self.bias = random.uniform(key, (n_output, ))

    def __call__(self, input):
        return jnp.dot(input, self.weight) + self.bias
    
class NeuralNetwork:

    layers: list[Linear]
    activation: callable

    def __init__(self, layers: list[int], activation_func):

        self.layers = [Linear(m, n) for m, n in zip(layers[:-1], layers[1:])]
        self.activation = activation_func

    def __call__(self, input):
        
        *hidden, last = self.layers

        for layer in hidden:
            out = layer(input)
            return last(self.activation(out))

class PINNs: 

    model: NeuralNetwork 
    residual_function: callable

    def __init__(self, layers: list[int], act_function, residual):
        self.model = NeuralNetwork(layers, act_function)
        self.residual_function = residual
        
    def MSE_loss(self, input, target):
        out_prediction = self.model(input)
        return jnp.mean((target - out_prediction)**2)

    def loss_residual(self, input):
        out = self.residual_function(self.model, input)
        print(out)

    def loss_conditions(self, geometry, target):
        
        loss_list = []
        for g, t in zip(geometry, target):
            l = self.MSE_loss(g, t)
            loss_list.append(l)
        return sum(loss_list)

    def trainning(self, lr, epochs):
        pass

In [24]:
from jax import jacobian, grad, hessian, jacfwd, jacrev

#DECLARAÇÃO DA EQUAÇÃO RESIDUAL 
def eq_residual(model, x, t):

    input = jnp.concatenate(x, t)
    u = model(input)

    u_t = grad(u, argnums=1)(input)
    u_xx = grad(grad(u, argnums=0), argnums=0)(input)

    return u_t - alpha * u_xx

LAYERS = [2, 20, 20, 20, 1]
model = PINNs(layers = LAYERS, act_function = jax.nn.tanh, residual = eq_residual)
#model.loss_conditions(geometry, target)

u = NeuralNetwork(LAYERS, jax.nn.tanh)
input = jnp.concatenate([x_samples, t_samples],1)

def res(model, x, t):
    input = jnp.concatenate((x, t),1)
    return model(input)
    

f = lambda model, x, t: model(jnp.concatenate((x,t),1))
grad(f, argnums=1)(u, x_samples, t_samples)

TypeError: Gradient only defined for scalar-output functions. Output had shape: (200, 1).