In [None]:
# --- Imports: standard library first, then third-party ---
import os
import random

import matplotlib
import numpy as np
import torch
from torch.utils.data import TensorDataset

# --- Configuration ---
seed = 30
# Coefficient that multiplies the Laplacian contribution in allen_cahn_forcing
lam = 0.01
font = {'size'   : 16}
matplotlib.rc('font', **font)

# Actually apply the seed: it was declared (and `random` imported) but never
# passed to any generator, so runs were not reproducible.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Output directory; exist_ok avoids the separate existence check
os.makedirs('example', exist_ok=True)


# Functions for the true solution

def allen_cahn_true(x: np.ndarray):
    """Evaluate the manufactured solution u(x, y) = sin(pi*x) * sin(pi*y).

    Args:
        x: array whose last axis holds the coordinates; component 0 is used
            as x and component 1 as y. The usual input is shape (N, 2), but
            any number of leading axes is accepted.

    Returns:
        Array with the leading shape of ``x`` (the last axis is consumed).
    """
    # Ellipsis indexing equals x[:, k] for 2-D input and also covers other ranks
    return np.sin(np.pi*x[..., 0])*np.sin(np.pi*x[..., 1])

def allen_cahn_forcing(x: np.ndarray):
    """Evaluate the forcing term associated with the manufactured solution.

    Computes ``-2*lam*pi**2*u + u**3 - u`` where ``u = allen_cahn_true(x)``
    and ``lam`` is the module-level coefficient.

    Args:
        x: coordinates, in whatever layout ``allen_cahn_true`` accepts.

    Returns:
        Array of forcing values with the same shape as ``allen_cahn_true(x)``.
    """
    # Evaluate the true solution once instead of three times
    u = allen_cahn_true(x)
    return -2*lam*np.pi**2*u + u**3 - u

def allen_cahn_pdv(x: np.array):
    """Analytic gradient of u(x, y) = sin(pi*x) * sin(pi*y).

    Takes an array indexed as ``x[:, 0]`` / ``x[:, 1]`` and returns one row
    per point with the two partial derivatives ``(du/dx, du/dy)``.
    """
    arg_x, arg_y = np.pi*x[:,0], np.pi*x[:,1]
    # du/dx = pi*cos(pi x)*sin(pi y)
    du_dx = np.pi*np.cos(arg_x)*np.sin(arg_y)
    # du/dy = pi*sin(pi x)*cos(pi y)
    du_dy = np.pi*np.sin(arg_x)*np.cos(arg_y)
    return np.stack((du_dx, du_dy), axis=1)

def allen_cahn_hes(x: np.array):
    """Analytic Hessian of u(x, y) = sin(pi*x) * sin(pi*y).

    Takes an array indexed as ``x[:, 0]`` / ``x[:, 1]`` and returns an array
    of shape (N, 2, 2) laid out as ``[[uxx, uxy], [uxy, uyy]]`` per point.
    """
    arg_x, arg_y = np.pi*x[:,0], np.pi*x[:,1]
    # uxx and uyy share the same closed form for this solution
    diag = -np.pi**2*np.sin(arg_x)*np.sin(arg_y)
    mixed = np.pi**2*np.cos(arg_x)*np.cos(arg_y)
    flat = np.stack((diag, mixed, mixed, diag), axis=1)
    return flat.reshape((-1,2,2))

# Domain bounds (shared by both axes) and grid spacing
xmin, xmax = -1., 1.
dx = 0.02

n_rand = 1000


# Collocation grid: nodes start at xmin+dx and stop before xmax on each axis
x = np.arange(xmin+dx, xmax, dx)
y = np.arange(xmin+dx, xmax, dx)
# Flatten the mesh into two column vectors, then pair them into (x, y) rows
x_pts, y_pts = (axis.reshape((-1,1)) for axis in np.meshgrid(x, y))
pts = np.column_stack((x_pts, y_pts))

# Reference values at every node: solution, gradient and Hessian
u_grid = allen_cahn_true(pts).reshape((-1,1))
pdv_grid, hes_grid = allen_cahn_pdv(pts), allen_cahn_hes(pts)

for label, arr in (('u_grid', u_grid), ('pts', pts), ('pdv_grid', pdv_grid), ('hes_grid', hes_grid)):
    print(f'{label}.shape: {arr.shape}')

# Dataset fields, in order: points, solution, gradient, Hessian (all float32)
grid_exampleset = TensorDataset(
    *(torch.tensor(arr, dtype=torch.float32) for arr in (pts, u_grid, pdv_grid, hes_grid))
)


# Boundary conditions
# Nodes along one edge of the square; the same nodes are reused for all four
# edges, so each corner point appears twice (as in the original construction).
edge_nodes = np.arange(xmin, xmax+dx, dx)
at_min = np.full(len(edge_nodes), xmin)
at_max = np.full(len(edge_nodes), xmax)

# Edge order is kept from the original: y=xmin, y=xmax, x=xmin, x=xmax
edges = [
    np.column_stack((edge_nodes, at_min)),  # first boundary:  y = xmin
    np.column_stack((edge_nodes, at_max)),  # second boundary: y = xmax
    np.column_stack((at_min, edge_nodes)),  # third boundary:  x = xmin
    np.column_stack((at_max, edge_nodes)),  # fourth boundary: x = xmax
]

# np.vstack replaces np.row_stack, which is deprecated in NumPy 2.0
pts_bc = np.vstack(edges)
# Boundary values come from the true solution evaluated on the stacked points
u_bc = allen_cahn_true(pts_bc).reshape((-1,1))

print(f'u_bc.shape: {u_bc.shape}')
print(f'pts_bc.shape: {pts_bc.shape}')

bc_exampleset = TensorDataset(torch.tensor(pts_bc, dtype=torch.float32), torch.tensor(u_bc, dtype=torch.float32))


# Plot the solution
import matplotlib.pyplot as plt
from matplotlib import cm
# Kept from the original; older Matplotlib releases need this import for projection='3d'
from mpl_toolkits.mplot3d import Axes3D


def _show_surface_and_heatmap(x_mesh, y_mesh, z, zlabel):
    """Show a 3-D surface of ``z`` over the mesh, followed by a heatmap of the same field."""
    # 3-D surface
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.plot_surface(x_mesh, y_mesh, z, cmap=cm.jet)
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_zlabel(zlabel)
    plt.show()
    plt.close()

    # Heatmap. Rows of z follow y (default meshgrid indexing), so origin='lower'
    # puts the smallest y at the bottom, and extent maps pixel indices to the
    # actual coordinates so the x/y labels are meaningful.
    fig, ax = plt.subplots()
    im = ax.imshow(
        z,
        cmap=cm.jet,
        origin='lower',
        extent=(x_mesh.min(), x_mesh.max(), y_mesh.min(), y_mesh.max()),
    )
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    fig.colorbar(im, ax=ax, orientation='vertical')
    plt.show()
    plt.close()


# Plotting grid (this one includes the boundary nodes)
x = np.arange(xmin, xmax+dx, dx)
y = np.arange(xmin, xmax+dx, dx)
x_pts, y_pts = np.meshgrid(x, y)
plot_pts = np.column_stack((x_pts.reshape((-1,1)), y_pts.reshape((-1,1))))

# Plot the solution on a grid of points, then as a heatmap
z = allen_cahn_true(plot_pts).reshape(x_pts.shape)
_show_surface_and_heatmap(x_pts, y_pts, z, 'u')

# Plot the forcing on a grid of points, then as a heatmap
z = allen_cahn_forcing(plot_pts).reshape(x_pts.shape)
_show_surface_and_heatmap(x_pts, y_pts, z, 'f')



In [None]:
import torch
from torch import nn
import torch
from torch import nn
from torch.func import vmap, jacrev, hessian

# Coefficient multiplying the Laplacian term in the PDE residual and in the
# forcing defined in this cell. The data-generation cell assigns the same name
# with the same value; the two must agree.
lam = 0.01

def allen_cahn_true(x: torch.Tensor):
    """Torch version of the manufactured solution u(x, y) = sin(pi*x) * sin(pi*y).

    ``x`` is indexed as ``x[:, 0]`` and ``x[:, 1]``; one value is returned per row.
    """
    x_coord, y_coord = x[:,0], x[:,1]
    return (torch.pi*x_coord).sin() * (torch.pi*y_coord).sin()

def allen_cahn_forcing(x: torch.Tensor):
    """Torch version of the forcing term for the manufactured solution.

    Computes ``-2*lam*pi**2*u + u**3 - u`` with ``u = allen_cahn_true(x)`` and
    the module-level coefficient ``lam``.

    Args:
        x: coordinates, in whatever layout ``allen_cahn_true`` accepts.

    Returns:
        Tensor of forcing values, one per point.
    """
    # Evaluate the true solution once instead of three times
    u = allen_cahn_true(x)
    return -2*lam*torch.pi**2*u + u**3 - u

def allen_cahn_pdv(x: torch.Tensor):
    """Torch version of the analytic gradient of u(x, y) = sin(pi*x) * sin(pi*y).

    ``x`` is indexed as ``x[:, 0]`` / ``x[:, 1]``; the result has one row per
    point holding ``(du/dx, du/dy)``.
    """
    arg_x, arg_y = torch.pi*x[:,0], torch.pi*x[:,1]
    du_dx = torch.pi*torch.cos(arg_x)*torch.sin(arg_y)
    du_dy = torch.pi*torch.sin(arg_x)*torch.cos(arg_y)
    return torch.stack((du_dx, du_dy), dim=1)

# Physics-informed network for the Allen-Cahn equation
class AllenNet(torch.nn.Module):
    """Fully connected network trained with a PDE-residual loss and a boundary loss.

    The network maps 2 inputs to 1 output through a stack of ``Linear`` + ``Tanh``
    layers followed by a final ``Linear`` layer. The PDE term compares
    ``lam*(uxx + uyy) - u + u**3`` (derivatives taken through the network) with
    ``allen_cahn_forcing(x)``. Both ``lam`` and ``allen_cahn_forcing`` are read
    from module scope.

    Args:
        bc_weight: multiplier of the boundary-condition loss.
        sys_weight: stored on the instance; not used by any method of this class.
        pde_weight: multiplier of the PDE-residual loss.
        hidden_units: widths of the hidden layers, in order.
        device: device the layers are moved to.
        *args, **kwargs: forwarded to ``torch.nn.Module.__init__``.
    """

    def __init__(self,
                 bc_weight: float,
                 sys_weight: float,
                 pde_weight: float,
                 hidden_units: list,
                 device: str,
                 *args,
                 **kwargs) -> None:
        super().__init__(*args, **kwargs)

        # Save the parameters
        self.bc_weight = bc_weight
        self.pde_weight = pde_weight
        self.sys_weight = sys_weight
        self.hidden_units = hidden_units
        self.device = device

        self.in_dim = 2
        self.out_dim = 1
        # Layer widths including the input dimension; list() also accepts tuples
        widths = [self.in_dim] + list(hidden_units)
        # Build the net. Module names are kept as in the original so that
        # state_dict keys do not change.
        net = nn.Sequential()
        for i in range(len(widths)-1):
            net.add_module(f'layer_{i}', nn.Linear(widths[i], widths[i+1]))
            net.add_module(f'activation_{i}', nn.Tanh())
        net.add_module(f'layer_{len(widths)-1}', nn.Linear(widths[-1], self.out_dim))
        # Save the network
        self.net = net.to(self.device)

    def forward(self, x:torch.Tensor) -> torch.Tensor:
        """Apply the network to a batch of points."""
        return self.net(x)

    def forward_single(self, x:torch.Tensor) -> torch.Tensor:
        """Apply the network to one sample; used as the function differentiated under ``vmap``."""
        return self.net(x.reshape((1,-1))).reshape((-1))

    def _pde_lhs(self, x:torch.Tensor):
        """Return ``(u, lam*(uxx + uyy) - u + u**3)`` for a batch, both flattened to 1-D.

        Only the Hessian is needed, so the Jacobian that the original code
        computed (and never used) in every caller is no longer evaluated.
        """
        u = self.forward(x).reshape((-1))
        # Per-sample Hessian; index 0 selects the single network output
        hess = vmap(hessian(self.forward_single))(x)[:,0,:,:]
        laplacian = hess[:,0,0] + hess[:,1,1]
        return u, lam*laplacian - u + u**3

    def _pde_loss(self, pde_pred:torch.Tensor, x:torch.Tensor) -> torch.Tensor:
        """Mean squared mismatch between the network PDE left-hand side and the forcing."""
        return nn.MSELoss()(pde_pred, allen_cahn_forcing(x))

    def _bc_loss(self, x_bc, y_bc, like:torch.Tensor) -> torch.Tensor:
        """Boundary MSE, or a zero scalar (matching ``like``) when no boundary batch is given."""
        if x_bc is None or y_bc is None:
            return like.new_zeros(())
        y_bc_pred = self.forward(x_bc)
        return nn.MSELoss()(y_bc_pred.reshape((-1)), y_bc.reshape((-1)))

    def loss_fn(self,
        x:torch.Tensor,
        x_bc:torch.Tensor=None,
        y_bc:torch.Tensor=None,
    ) -> torch.Tensor:
        """Training loss: ``pde_weight*pde_loss + bc_weight*bc_loss``.

        Args:
            x: collocation points.
            x_bc, y_bc: boundary points and target values. If either is ``None``
                the boundary term is skipped (the original crashed in that case
                even though ``None`` is the declared default).

        Returns:
            Scalar loss tensor.
        """
        # lambda*(uxx + uyy) - u + u^3 = f
        _, pde_pred = self._pde_lhs(x)
        pde_loss = self._pde_loss(pde_pred, x)
        bc_loss = self._bc_loss(x_bc, y_bc, like=pde_loss)

        # Total loss
        return self.pde_weight*pde_loss + self.bc_weight*bc_loss

    def eval_losses(self, step:int,
        x:torch.Tensor,
        y:torch.Tensor,
        x_bc:torch.Tensor=None,
        y_bc:torch.Tensor=None,
        print_to_screen:bool=False,
    ):
        """Compute the individual loss terms for monitoring.

        Args:
            step: step index, returned unchanged and used in the printout.
            x: collocation points.
            y: reference solution values at ``x``.
            x_bc, y_bc: boundary points and values; the boundary loss is
                reported as zero when either is ``None``.
            print_to_screen: print the losses when True.

        Returns:
            Tuple ``(step, out_loss, pde_loss, bc_loss, tot_loss)``.
        """
        # lambda*(uxx + uyy) - u + u^3 = f
        y_pred, pde_pred = self._pde_lhs(x)
        pde_loss = self._pde_loss(pde_pred, x)

        # Loss wrt the true output. Both sides are flattened so an (N,) target
        # cannot silently broadcast against an (N, 1) prediction.
        out_loss = nn.MSELoss()(y_pred, y.reshape((-1)))
        # Boundary condition
        bc_loss = self._bc_loss(x_bc, y_bc, like=pde_loss)

        # Total loss, weighted exactly as in loss_fn (pde_weight was previously
        # missing here, so the reported total disagreed with the trained one).
        tot_loss = self.pde_weight*pde_loss + self.bc_weight*bc_loss

        if print_to_screen:
            print(f'Step: {step}, total loss: {tot_loss}')
            print(f'pde loss: {pde_loss}, out loss {out_loss}, bc loss: {bc_loss}')

        return step, out_loss, pde_loss, bc_loss, tot_loss

    def evaluate_forcing(self, x):
        """Return the network PDE left-hand side ``lam*(uxx + uyy) - u + u**3`` at ``x``."""
        _, pde_pred = self._pde_lhs(x)
        return pde_pred

    def evaluate_consistency(self, x):
        """Return the pointwise absolute difference between the network PDE left-hand side and the forcing."""
        _, pde_pred = self._pde_lhs(x)
        return torch.abs(pde_pred - allen_cahn_forcing(x))
    
    
    


In [None]:
import torch
from torch.utils.data import DataLoader
import numpy as np
import os
from itertools import cycle
from torch.optim import LBFGS, Adam

seed = 30
# Apply the seed before any shuffling or weight initialisation happens; it was
# declared but never used, so runs were not reproducible.
np.random.seed(seed)
torch.manual_seed(seed)

epochs = 100
batch_size = 32 
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using device: {device}')   


# The same grid dataset is used for both training and evaluation
train_dataset = test_dataset = grid_exampleset
bc_dataset = bc_exampleset
# Generate the dataloaders
train_dataloader = DataLoader(train_dataset, batch_size, shuffle=True)
# Evaluation only averages losses, so there is no reason to shuffle
test_dataloader = DataLoader(test_dataset, 10000, shuffle=False)
bc_dataloader = DataLoader(bc_dataset, batch_size, shuffle=True)

print('Data loaded!')

print(train_dataset[:][0].shape)

activation = torch.nn.Tanh()

model = AllenNet(
    bc_weight=1.,
    pde_weight=1.,
    sys_weight=1.,
    hidden_units=[50 for _ in range(4)],
    device=device,
).to(device)


# Training-time logs, appended to by train_loop every `print_every` steps
step_list= []
out_losses_train = []
pde_losses_train = []
tot_losses_train = []
bc_losses_train = []

# Per-epoch evaluation logs, appended to by train_loop
step_list_test = []
out_losses_test = []
pde_losses_test = []
tot_losses_test = []
bc_losses_test = []
times_test = []


opt = Adam(model.parameters(), lr=1e-3)

import time
def train_loop(epochs:int,
        train_dataloader:DataLoader,
        test_dataloader:DataLoader,
        bc_dataloader:DataLoader,
        print_every:int=100):
    """Train the module-level ``model`` with the module-level optimizer ``opt``.

    Each epoch iterates over ``train_dataloader`` (boundary batches are drawn
    by cycling ``bc_dataloader``), then averages the losses over
    ``test_dataloader``. Results are appended to the module-level log lists
    (``step_list``, ``*_losses_train``, ``step_list_test``, ``*_losses_test``,
    ``times_test``).

    Args:
        epochs: number of passes over ``train_dataloader``.
        train_dataloader: batches whose first two fields are points and targets.
        test_dataloader: same layout, used for the per-epoch evaluation.
        bc_dataloader: batches of (boundary points, boundary values).
        print_every: log training losses whenever the global step is a
            multiple of this value.
    """
    for epoch in range(epochs):
        # Training mode for the network
        model.train()
        step_prefix = epoch*len(train_dataloader)
        # Last global step reached in this epoch. Tracked explicitly instead of
        # relying on the loop variable leaking out of the for-loop, which is
        # undefined when the loader yields nothing.
        global_step = step_prefix
        start_time = time.time()
        print(f'Epoch: {epoch}, step_prefix: {step_prefix}')
        for step, (train_data, bc_data) in enumerate(zip(train_dataloader, cycle(bc_dataloader))):
            global_step = step_prefix + step

            # Load batches from dataloaders
            x_train = train_data[0].to(device).float().requires_grad_(True)
            y_train = train_data[1].to(device).float()

            x_bc = bc_data[0].to(device).float()
            y_bc = bc_data[1].to(device).float()

            # Call zero grad on optimizer
            opt.zero_grad()

            loss = model.loss_fn(
                x=x_train, x_bc=x_bc, y_bc=y_bc
            )
            # Backward the loss, calculate gradients
            loss.backward()
            # Optimizer step
            opt.step()

            # Periodic logging on the current training batch
            if global_step % print_every == 0:
                with torch.no_grad():
                    step_val, out_loss_train, pde_loss_train, bc_loss_train, tot_loss_train = model.eval_losses(
                        step = global_step,
                        x=x_train, y=y_train, x_bc=x_bc, y_bc=y_bc, print_to_screen=True,
                    )
                    # Store plain floats so the logs do not keep device tensors alive
                    step_list.append(step_val)
                    tot_losses_train.append(tot_loss_train.item())
                    out_losses_train.append(out_loss_train.item())
                    pde_losses_train.append(pde_loss_train.item())
                    bc_losses_train.append(bc_loss_train.item())

        # Epoch timing covers the training pass only
        stop_time = time.time()
        print(f'Epoch time: {stop_time-start_time}')
        epoch_time = stop_time-start_time
        times_test.append(epoch_time)

        # Calculate and average the loss over the test dataloader
        model.eval()
        out_loss_test = 0.0
        pde_loss_test = 0.0
        tot_loss_test = 0.0
        bc_loss_test = 0.0

        with torch.no_grad():
            for (test_data, bc_data) in zip(test_dataloader, cycle(bc_dataloader)):
                # No requires_grad_ here: nothing is back-propagated during evaluation
                x_test = test_data[0].to(device).float()
                y_test = test_data[1].to(device).float()

                x_bc = bc_data[0].to(device).float()
                y_bc = bc_data[1].to(device).float()

                _, out_loss, pde_loss, bc_loss, tot_loss = model.eval_losses(
                    step=global_step, x=x_test, y=y_test, x_bc=x_bc, y_bc=y_bc)

                out_loss_test += out_loss.item()
                pde_loss_test += pde_loss.item()
                tot_loss_test += tot_loss.item()
                bc_loss_test += bc_loss.item()

        # Guard against an empty test loader (sums are 0.0 in that case)
        n_test_batches = max(len(test_dataloader), 1)
        out_loss_test /= n_test_batches
        pde_loss_test /= n_test_batches
        tot_loss_test /= n_test_batches
        bc_loss_test /= n_test_batches
        # The "test loss" printed below is the averaged total loss
        test_loss = tot_loss_test

        step_list_test.append(global_step)
        out_losses_test.append(out_loss_test)
        pde_losses_test.append(pde_loss_test)
        tot_losses_test.append(tot_loss_test)
        bc_losses_test.append(bc_loss_test)

        print(f"Average test loss: {test_loss}")
        print(f"Average output loss: {out_loss_test}")
        print(f"Average PDE loss: {pde_loss_test}")
        print(f"Average total loss: {tot_loss_test}")
        print(f"Average bc loss: {bc_loss_test}")


# Launch training with the loaders and epoch budget configured above
train_loop(
    epochs=epochs,
    train_dataloader=train_dataloader,
    test_dataloader=test_dataloader,
    bc_dataloader=bc_dataloader,
    print_every=100,
)


# Release cached GPU memory, then leave the network in evaluation mode
torch.cuda.empty_cache()
model.eval()

In [None]:
# %%
import os

import numpy as np
from matplotlib import pyplot as plt
from matplotlib import cm
# NOTE(review): a torch implementation of allen_cahn_true is also defined earlier
# in this notebook; this import needs a `model` module on the path - confirm it exists.
from model import allen_cahn_true


def _show_heatmap(field, title, bounds):
    """Show ``field`` as a heatmap with axes in domain coordinates.

    ``bounds`` is (x_min, x_max, y_min, y_max). Rows of ``field`` follow y
    (default meshgrid indexing), hence origin='lower'.
    """
    fig, ax = plt.subplots()
    im = ax.imshow(field, cmap=cm.jet, origin='lower', extent=bounds)
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    fig.colorbar(im, ax=ax, orientation='vertical')
    fig.suptitle(title)
    plt.show()
    plt.close()


def _moving_average(values, window):
    """Moving average in 'valid' mode; the window is clipped to the number of samples."""
    values = np.asarray(values, dtype=float)
    if values.size == 0:
        return values
    window = max(1, min(window, values.size))
    return np.convolve(values, np.ones(window)/window, mode='valid')


def _plot_losses(steps, pde, out, bc, xlabel):
    """Plot the three loss curves on a log scale against ``steps``."""
    plt.figure()
    plt.plot(steps, pde, label='pde_loss', color='red')
    plt.plot(steps, out, label='out_loss', color='green')
    plt.plot(steps, bc, label='bc_loss', color='purple')
    plt.legend()
    plt.yscale('log')
    plt.title('Losses of the student model')
    plt.xlabel(xlabel)
    plt.ylabel('Loss')
    plt.show()
    plt.close()


# Generate the grid for the true solution
xmin = -1.
xmax = 1.
dx = 0.01
x = np.arange(xmin, xmax+dx, dx)
y = np.arange(xmin, xmax+dx, dx)
x_pts, y_pts = np.meshgrid(x, y)
pts = np.column_stack((x_pts.reshape((-1,1)), y_pts.reshape((-1,1))))
# Convert to NumPy right away so the error below is a NumPy - NumPy operation
# rather than a torch.Tensor - ndarray mix.
u_grid = allen_cahn_true(torch.tensor(pts)).reshape(x_pts.shape).cpu().numpy()

# Inference only: no autograd graph is needed for the prediction
with torch.no_grad():
    u_pred = model(torch.tensor(pts).to(device).float()).cpu().numpy().reshape(x_pts.shape)

u_err = np.abs(u_grid - u_pred)

bounds = (x.min(), x.max(), y.min(), y.max())

# Plot the predicted solution
_show_heatmap(u_pred, 'Predicted solution', bounds)

# Plot the error wrt the true solution
_show_heatmap(u_err, 'Error of the predicted solution', bounds)


# Training losses. New names are used so the log lists filled by train_loop are
# not replaced by arrays, which keeps this cell safe to re-run.
train_steps = torch.tensor(step_list).cpu().numpy()
out_train = torch.tensor(out_losses_train).cpu().numpy()
pde_train = torch.tensor(pde_losses_train).cpu().numpy()
tot_train = torch.tensor(tot_losses_train).cpu().numpy()
bc_train = torch.tensor(bc_losses_train).cpu().numpy()

# Smooth with a moving average of up to N logged values
N = 10
pde_smooth = _moving_average(pde_train, N)
out_smooth = _moving_average(out_train, N)
bc_smooth = _moving_average(bc_train, N)
l = len(out_smooth)
_plot_losses(train_steps[:l], pde_smooth, out_smooth, bc_smooth, 'Training steps')


# Per-epoch evaluation losses
test_steps = torch.tensor(step_list_test).cpu().numpy()
out_test = torch.tensor(out_losses_test).cpu().numpy()
pde_test = torch.tensor(pde_losses_test).cpu().numpy()
tot_test = torch.tensor(tot_losses_test).cpu().numpy()
bc_test = torch.tensor(bc_losses_test).cpu().numpy()
epoch_times = np.array(times_test)

# step_list_test stores the global training step reached at the end of each
# epoch, so the axis is labelled in steps rather than epochs.
_plot_losses(test_steps, pde_test, out_test, bc_test, 'Training steps')