## Exercise 5: 2d regression using LBFGS
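Modify the code to solve the partial differential equation (a one-dimensional diffusion equation)
$$
\frac{\partial z}{\partial t} = 0.05 \frac{\partial^2 z}{\partial x^2}
$$
in the domain $t \in (0,1)$ and $x \in (0,1)$.

The initial condition (the boundary condition on $t = 0$) is $z(x,0) = x(1-x)$.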

In [None]:
## This code performs 2D regression with a neural network trained by L-BFGS.
## Exercise: modify the code to solve the PDE
## dz/dt = 0.05 * d2z/dx2
## in the domain t in (0,1) and x in (0,1),
## with the initial condition (the boundary condition at t = 0) z(x,0) = x*(1-x).

import torch
import matplotlib.pyplot as plt
from torch import nn, optim
import numpy as np

# Target function the network is trained to reproduce.
def true_func(x, y):
    """Return sin(r), where r = sqrt(x**2 + y**2) is the distance from the origin.

    Evaluated elementwise on the tensors ``x`` and ``y``.
    """
    radius = torch.sqrt(x ** 2 + y ** 2)
    return torch.sin(radius)

# Create a 100 x 100 grid of points on [-3, 3] x [-3, 3].
# indexing="ij" is torch's current default; passing it explicitly silences the
# warning torch emits when it is omitted. With "ij", xx[i, j] = x[i] and
# yy[i, j] = y[j]: the first grid dimension runs along x, the second along y.
x = torch.linspace(-3, 3, 100)
y = torch.linspace(-3, 3, 100)
xx, yy = torch.meshgrid(x, y, indexing="ij")
zz_true = true_func(xx, yy)

# Optional Gaussian noise on the regression targets. The default of 0.0 fits
# the exact function values. Noise is only drawn when requested, so with the
# default the random-number stream (and hence the network initialisation that
# follows) is unchanged.
noise_std = 0.0
if noise_std > 0:
    zz = zz_true + noise_std * torch.randn_like(zz_true)
else:
    zz = zz_true

# Flatten the grid into (N, 2) network inputs and (N, 1) regression targets.
input_data = torch.cat((xx.reshape(-1, 1), yy.reshape(-1, 1)), 1)
target_data = zz.reshape(-1, 1)

# Fully connected network mapping a 2-D point (x, y) to a scalar:
# 2 inputs -> 100 ReLU hidden units -> 1 output.
hidden_units = 100
layers = [
    nn.Linear(2, hidden_units),
    nn.ReLU(),
    nn.Linear(hidden_units, 1),
]
model = nn.Sequential(*layers)

# Mean-squared-error objective, minimised with L-BFGS at learning rate 0.1.
criterion = nn.MSELoss()
optimizer = optim.LBFGS(model.parameters(), lr=0.1)

# Loss closure: L-BFGS needs to re-evaluate the objective itself.
def closure():
    """Recompute the training loss and its gradients for the optimizer.

    Clears stale gradients, runs a forward pass over all grid points,
    and back-propagates the MSE loss.

    Returns:
        The scalar loss tensor for the current parameters.
    """
    optimizer.zero_grad()
    loss = criterion(model(input_data), target_data)
    loss.backward()
    return loss

# Train the network for 1000 L-BFGS steps; each step may call the closure
# several times. Report the loss every 100 steps.
for t in range(1000):
    loss = optimizer.step(closure)
    if t % 100 == 0:
        print(f"{t}   {loss.item()}")


# Evaluate the trained network on every grid point without building an
# autograd graph, then fold the flat (N, 1) output back into the grid shape.
with torch.no_grad():
    flat_pred = model(input_data)
    zz_pred = flat_pred.view_as(zz)

# Plot the true function and the neural network's approximation side by side.
# The grid uses "ij" indexing, so zz[i, j] is the value at (x[i], y[j]).
# imshow draws the first array axis vertically, so transpose each field to put
# x on the horizontal axis and y on the vertical axis, matching the
# (left, right, bottom, top) extent.
fig, axs = plt.subplots(1, 2, figsize=(12, 6))

panels = ((zz, 'True function'), (zz_pred, 'Neural network'))
for ax, (field, title) in zip(axs, panels):
    c = ax.imshow(field.T, origin='lower', extent=(-3, 3, -3, 3), cmap='viridis')
    fig.colorbar(c, ax=ax)
    ax.set_title(title)
    ax.set_xlabel('x')
    ax.set_ylabel('y')

plt.show()