# Forward Problem
## Diffusion 

In [1]:
from diffusion import DiffusionNN, make_forward_fn, make_diffusion_loss
import torch
import torchopt
from tqdm import tqdm



In [2]:
#Constants to initialize
x_domain = (-1, 1)
t_domain = (0, 1)
learning_rate = 0.1
n_epochs = 100
batch_size = 30

In [3]:
#Model and functional setup
diffusion_model = DiffusionNN()
diffusion_function = make_forward_fn(diffusion_model)

diffusion_loss = make_diffusion_loss(diffusion_function)

optimizer = torchopt.FuncOptimizer(torchopt.adam(lr=learning_rate))

params = tuple(diffusion_model.parameters())

In [4]:
x = torch.FloatTensor(batch_size).uniform_(x_domain[0], x_domain[1])
t = torch.FloatTensor(batch_size).uniform_(t_domain[0], t_domain[1])
x_t = torch.stack((x,t), dim=1)

In [5]:
x_t

tensor([[-0.8789,  0.5529],
        [-0.3267,  0.2942],
        [ 0.5620,  0.1917],
        [-0.8159,  0.6720],
        [-0.5322,  0.7960],
        [ 0.8122,  0.7291],
        [ 0.7341,  0.3987],
        [ 0.8089,  0.5549],
        [-0.3777,  0.1345],
        [-0.8669,  0.3130],
        [-0.2321,  0.2138],
        [-0.2366,  0.9371],
        [ 0.2720,  0.5531],
        [ 0.5243,  0.1285],
        [ 0.2630,  0.3231],
        [-0.4307,  0.9378],
        [-0.4262,  0.4033],
        [-0.1910,  0.6391],
        [-0.8715,  0.4302],
        [-0.8325,  0.5432],
        [ 0.4995,  0.9007],
        [ 0.4882,  0.5638],
        [-0.8755,  0.5059],
        [-0.2381,  0.2446],
        [-0.8409,  0.7617],
        [ 0.2896,  0.3959],
        [ 0.2005,  0.3107],
        [ 0.4191,  0.0215],
        [-0.4552,  0.1518],
        [-0.0383,  0.7156]])

In [6]:
result = diffusion_function(x_t[0], params)

In [7]:
result.shape

torch.Size([])

In [8]:
from torch.func import grad, vmap

In [37]:
grad_u = grad(diffusion_function)

In [40]:
def dudt(x_t: torch.Tensor, params: torch.Tensor):
    return vmap(grad_u, in_dims=(0, None))(x_t, params)[:, 1:].squeeze()

In [47]:
def dudx(x_t: torch.Tensor, params: torch.Tensor):
    return vmap(grad_u, in_dims=(0, None))(x_t, params)[:, :1].squeeze()

In [45]:
#Defining the second derivative w.r.t. x
def d2udx2(x_t: torch.Tensor, params: torch.Tensor):
    pass

In [41]:
dudt(x_t, params)

tensor([-0.0468, -0.0501, -0.0464, -0.0456, -0.0431, -0.0350, -0.0412, -0.0379,
        -0.0516, -0.0486, -0.0509, -0.0388, -0.0428, -0.0476, -0.0470, -0.0400,
        -0.0489, -0.0443, -0.0479, -0.0470, -0.0344, -0.0408, -0.0473, -0.0506,
        -0.0444, -0.0455, -0.0477, -0.0499, -0.0513, -0.0418],
       grad_fn=<SqueezeBackward0>)

In [48]:
dudx(x_t, params)

tensor([-0.0164, -0.0141, -0.0028, -0.0166, -0.0147,  0.0058,  0.0016,  0.0041,
        -0.0144, -0.0154, -0.0133, -0.0101, -0.0040, -0.0039, -0.0059, -0.0129,
        -0.0148, -0.0112, -0.0160, -0.0164,  0.0026, -0.0002, -0.0163, -0.0133,
        -0.0166, -0.0049, -0.0070, -0.0062, -0.0148, -0.0083],
       grad_fn=<SqueezeBackward0>)

In [54]:
f = grad(dudx)

In [55]:
f(x_t[0], params)

RuntimeError: both arguments to matmul need to be at least 1D, but they are 0D and 2D

In [46]:
d2udx2(x_t, params)

RuntimeError: both arguments to matmul need to be at least 1D, but they are 0D and 2D

In [None]:
#Defining the training process
loss_evolution = []

for epoch in tqdm(range(n_epochs)):
    #Let's start with a uniform distribution of the data 
    x = torch.FloatTensor(batch_size).uniform_(x_domain[0], x_domain[1])
    t = torch.FloatTensor(batch_size).uniform_(t_domain[0], t_domain[1])

    #We compute the loss
    loss = diffusion_loss(x, t, params)
    #Update the parameters with the functional optimizer
    params = optimizer.step(loss, params)
    #Keeping track of the loss 
    loss_evolution.append(float(loss))
    tqdm.write(f"At epoch {epoch + 1} the model has the following loss: {loss}")