## TODO
* arrange order of material
* new task before any training: load some simulation data and visualize a simulation
    * image plot
    * v/z at a few times
* include multiple ICs for supervised training. update existing code that uses the dataloader etc.!
* provide mask as an input channel always (even for unet)
* new first task: train without using fixed BCs using UnetBC
  * visualize results after training
  * image plots v/z ref/est/diff
  * plot v/z at some chosen time points (ref/est)
  * scatter plot of delta v and delta z ref vs. est (over all locations / time points)
* rename updates to delta_zeta, delta_v
* visualization code
  * visualize numerical simulations: give them an example for plotting a simulation, ask them to do a minor variation on this
  * give instructions
* flux task
  * before implementing flux net, have students graph total mass over time of data-driven BCnet
  * change flux output so z fluxes defined on vel points, velocities still as tendencies. keep BC constraint
  * standard output visualization for fluxnet
  * compare overall error and total mass over time of fluxnet to BCnet
* new task: hybrid net with supervised loss
  * inputs -> net -> flux -> zeta -> u. keep imposing BCs.
  * train on same supervised loss as before
  * plots, comparisons, etc.
* final task: unsupervised learning
  * pde loss from $A\zeta - b$
  * use library of system states, randomly sample from these to generate each batch. restart with new random state when integrating past $t_\text{max}$.
* link to papers: unet, 
* further reading
* student version

optional:
* remove order parameter (multiple inputs time steps)?
* extra tasks: change hyperparams, multiple random seeds
* separate cell and markdown explanation of random seeding func
* include time/space axis information in the hdf5 file, and save it in the dataset objects and use it for plotting
* refactor dataset so it's just a tensordataset. do sequence building etc. in the dataloader instead

In [None]:
# Importing necessary libraries
import h5py
import torch
import numpy as np
import torch.nn as nn
from torch.utils.data import Dataset
from matplotlib import pyplot as plt

%load_ext tensorboard

## SW Equations and Custom Dataset for SW Simulation Data

We are interested in solving the following form of the shallow water equations (SWEs):

$\dfrac{\partial v}{\partial t} + \dfrac{c_D}{h}v|v|+g\dfrac{\partial \zeta}{\partial x} = 0$

$\dfrac{\partial \zeta}{\partial t} + \dfrac{\partial (vh)}{\partial x}=0$

where $v(x, t)$ is the velocity, $\zeta(x, t)$ is the surface disturbance (positive or negative) and $h(x, t) = \zeta(x, t) + h_0$ is the total depth, with $h_0$ the undisturbed depth. $c_D$ and $g$ are the drag coefficient and gravitational acceleration, respectively. The first equation describes Newton's second law ($F=ma$) acting on a fluid parcel, while the second models mass conservation.

## The problem we want to solve
### Time integration of PDEs
Given any **system state** at time $t$ for all locations $x$:
$$\forall x: \left(v(x, t), \zeta(x,t) \right)$$
we want to time-integrate the SWEs to obtain the system state at time $t+\Delta t$:
$$\forall x:\left(v(x, t+\Delta t), \zeta(x, t+\Delta t) \right) = \left(
v(x, t) + \int_t^{t+\Delta t} \frac{\partial v(x, \tau)}{\partial \tau}\, d\tau, 
\zeta(x, t) + \int_t^{t+\Delta t} \frac{\partial \zeta(x, \tau)}{\partial \tau}\, d\tau
\right) $$
We will consider only the case where initial conditions and time-integrated outputs are defined on an evenly spaced grid with spacing $\Delta x$ and a fixed time step $\Delta t$.
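
Applying a one-step update repeatedly is what turns a single $\Delta t$ step into a full simulation. The sketch below illustrates this idea only; `rollout` and `decay_step` are hypothetical names that do not appear in the tutorial code, and the toy step simply halves both fields so the result is easy to check by hand.

```python
from typing import Callable, List, Tuple

State = Tuple[List[float], List[float]]  # (v, zeta) sampled on a spatial grid


def rollout(step: Callable[[State], State], state: State, n_steps: int) -> List[State]:
    """Apply a one-step time integrator repeatedly and keep every state."""
    trajectory = [state]
    for _ in range(n_steps):
        state = step(state)
        trajectory.append(state)
    return trajectory


def decay_step(state: State) -> State:
    """Toy stand-in for one step of size dt: halve both fields."""
    v, zeta = state
    return [0.5 * x for x in v], [0.5 * x for x in zeta]


initial_state: State = ([1.0, 2.0, 0.0], [4.0, 0.0])
trajectory = rollout(decay_step, initial_state, n_steps=3)
final_v, final_zeta = trajectory[-1]
```

Any function with the same signature as `decay_step`, whether a numerical scheme or a trained network, could be passed to `rollout` in its place.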


### Classical PDE integration
Classical physics-based numerical methods compute these updates by calculating partial derivatives in space and time. Explicit methods do these calculations at time steps where $v,\zeta$ are already known, which simplifies the calculation but often requires very small time steps for accuracy and stability, at high computational cost. (Semi-)implicit methods can take larger time steps but must iteratively solve a system of equations at each time step until convergence, which can also be costly.

We'll discuss the discretization and time stepping used to solve this PDE in one of our later tasks. For now, it's enough to know that we've generated some simulation data from numerical simulation code, and we'll use that data to train neural networks and as a "ground truth" reference.

### Problem statement
In this tutorial, our goal is to **train a neural network to carry out time integration of the SWEs** by $\Delta t$, such that the results match a semi-implicit scheme. By replacing the iterative solve of the numerical scheme with a forward pass through a neural network, we aim to produce a fast time-integration method whose computation time does not depend on the input data.

While we will not focus on computation times here due to using simplified, lightweight 1-D versions of the numerical model and deep learning architecture, this technique has demonstrated impressive speed increases compared to classical numerical solvers when applied to 2D and 3D fluid dynamics.

## Dealing with the Dataset

In this section, we'll import and use a custom PyTorch class to load SWE simulation data from an HDF5 file. If you'd like to see how this class works later, you can read through the code [here](https://github.com/alicanbekar/pi_lecture_pytorch/blob/main/sw_dataset.ipynb), but for now that isn't necessary. This dataset class is responsible for:
1. Reading data from an HDF5 file.
2. Normalizing the data, so that $\zeta$ and $v$ both range from roughly -1 to 1.
3. Splitting the data into training, validation, and testing datasets.
4. Retrieving sequences of SWE system state to create inputs and outputs for the trained models.
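
To make steps 2 and 4 concrete, here is a minimal sketch in plain Python. It is not the `SWDataset` implementation: `normalize` and `make_pairs` are hypothetical helpers, and scaling by the largest absolute value is only one assumed way of bringing a field into the range -1 to 1.

```python
from typing import List, Tuple


def normalize(series: List[List[float]]) -> Tuple[List[List[float]], float]:
    """Scale a (time x space) series by its largest absolute value."""
    scale = max(abs(x) for row in series for x in row)
    return [[x / scale for x in row] for row in series], scale


def make_pairs(series: List[List[float]], order: int):
    """Pair `order` consecutive states (input) with the state that follows (target)."""
    pairs = []
    for t in range(order, len(series)):
        pairs.append((series[t - order:t], series[t]))
    return pairs


# Toy series: 4 time steps, 2 grid points
zeta_series = [[0.0, 2.0], [1.0, -4.0], [2.0, 1.0], [0.5, 0.5]]
zeta_norm, zeta_scale = normalize(zeta_series)
pairs = make_pairs(zeta_norm, order=1)
```

With `order=1`, each input holds a single state and the target is the state one time step later, which matches the autoregressive setup used for training below.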

Run the next cell to install the necessary code and download the data.

In [None]:
!pip install -q import-ipynb
import import_ipynb
!wget https://raw.githubusercontent.com/alicanbekar/pi_lecture_pytorch/main/sw_dataset.ipynb
!wget https://raw.githubusercontent.com/alicanbekar/pi_lecture_pytorch/main/simulation_data.h5
%run sw_dataset.ipynb

## Exploring the data
Run the following cells to retrieve and plot some data

In [None]:
data = SWDataset(file_path="simulation_data.h5", normalize=False)

In [None]:
# create pytorch tensors to store zeta/vel
zeta = torch.cat([data[i][0][0] for i in range(len(data))], axis=0)
vel = torch.cat([data[i][0][1] for i in range(len(data))], axis=0)

In [None]:
plt.imshow(vel.detach().numpy())
plt.colorbar()

## Creating the U-Net Model for SW Simulations

U-Net is a convolutional neural network architecture primarily used for biomedical image segmentation. In our case, we will adapt U-Net to handle 1D data from the SW simulations.

The U-Net architecture is symmetric, and it consists of an encoding (downsampling) path, followed by a decoding (upsampling) path. Skip connections are used to pass the information from the encoding path to the decoding path, which helps the network retain spatial details.

Let's walk through the code and its structure:


In [None]:
class UNet(nn.Module):
    def __init__(self, order):
        super(UNet, self).__init__()
        self.order = order

        # Initial convolution layers for two different input types
        self.conv_zeta = nn.Conv1d(self.order, 8, kernel_size=3, padding=1)
        self.conv_vel = nn.Conv1d(self.order, 8, kernel_size=4, padding=1)

        # Encoder (downsampling) blocks
        self.enc1 = self.u_net_block(16, 16)
        self.enc2 = self.u_net_block(16, 32)
        self.enc3 = self.u_net_block(32, 64)
        self.enc4 = self.u_net_block(64, 128)
        self.enc5 = self.u_net_block(128, 256)

        # Pooling layer for downsampling
        self.pool = nn.MaxPool1d(2)

        # Upsampling layers
        self.up1 = nn.ConvTranspose1d(256, 128, kernel_size=2, stride=2)
        self.up2 = nn.ConvTranspose1d(128, 64, kernel_size=2, stride=2)
        self.up3 = nn.ConvTranspose1d(64, 32, kernel_size=2, stride=2)
        self.up4 = nn.ConvTranspose1d(32, 16, kernel_size=2, stride=2)

        # Decoder (upsampling) blocks
        self.dec1 = self.u_net_block(256, 128)
        self.dec2 = self.u_net_block(128, 64)
        self.dec3 = self.u_net_block(64, 32)
        self.dec4 = self.u_net_block(32, 16)

        # Output convolution layers
        self.output_dec_zeta = nn.Conv1d(16, 1, kernel_size=3, padding=1)
        self.output_dec_vel = nn.Conv1d(16, 1, kernel_size=2, padding=1)

    def u_net_block(self, in_channels, out_channels):
        """
        Creates a U-Net block with two convolution layers followed by batch normalization and ReLU activation.
        """
        return nn.Sequential(
            nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm1d(out_channels),
            nn.ReLU(),
            nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm1d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x_zeta, x_vel):
        # Initial convolution operations
        x_zeta = self.conv_zeta(x_zeta)
        x_vel = self.conv_vel(x_vel)

        # Concatenate the two feature maps along the channel dimension
        x_combined = torch.cat([x_vel, x_zeta], dim=1)

        # Encoding process
        e1 = self.enc1(x_combined)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))
        e5 = self.enc5(self.pool(e4))

        # Decoding process with skip connections
        d1 = self.up1(e5)
        d1 = torch.cat([d1, e4], dim=1)
        d1 = self.dec1(d1)

        d2 = self.up2(d1)
        d2 = torch.cat([d2, e3], dim=1)
        d2 = self.dec2(d2)

        d3 = self.up3(d2)
        d3 = torch.cat([d3, e2], dim=1)
        d3 = self.dec3(d3)

        d4 = self.up4(d3)
        d4 = torch.cat([d4, e1], dim=1)
        d4 = self.dec4(d4)

        # Separate output convolutions
        dt_zeta = self.output_dec_zeta(d4)
        dt_vel = self.output_dec_vel(d4)
        return dt_zeta, dt_vel


### Guidelines:

- **u_net_block**: This function should return a sequential block that performs two sets of (convolution -> batch normalization -> ReLU activation). You can chain these operations using `nn.Sequential`.

- **Encoder**: Remember, as you go deeper into the encoder, you are reducing the spatial dimensions (using max pooling) and typically increasing the number of channels.

- **Decoder**: It's the reverse of the encoder. For each block, you will upsample to increase spatial dimensions and typically decrease the number of channels. Make sure to include the skip connections from the encoder. This can be done using torch's concatenation.

- **Output Layers**: The goal is to transform the deep feature maps into our desired output. Depending on the task, this could be a segmentation mask, regression map, etc.

Remember, the architecture of U-Net is symmetric. It might be helpful to sketch the network or list down the sizes of feature maps as you code.
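
Listing the feature-map lengths can be done with a few lines of arithmetic. The sketch below uses the standard output-length formulas for 1D convolutions and transposed convolutions with the kernel sizes from the `UNet` class above. It assumes 256 $\zeta$ points and one additional velocity point; that grid layout is an assumption inferred from the kernel sizes, and the helper names are hypothetical.

```python
def conv1d_out_len(n: int, kernel_size: int, padding: int = 0, stride: int = 1) -> int:
    """Output length of a 1D convolution (dilation 1)."""
    return (n + 2 * padding - kernel_size) // stride + 1


def conv_transpose1d_out_len(n: int, kernel_size: int, stride: int) -> int:
    """Output length of a 1D transposed convolution (no padding, dilation 1)."""
    return (n - 1) * stride + kernel_size


n_zeta = 256        # assumed number of zeta points
n_vel = n_zeta + 1  # assumed number of velocity points (one more than zeta)

# Input layers: both fields end up with the same length
len_zeta_features = conv1d_out_len(n_zeta, kernel_size=3, padding=1)
len_vel_features = conv1d_out_len(n_vel, kernel_size=4, padding=1)

# Encoder: four max-pool steps, each halving the length
encoder_lengths = [len_zeta_features]
for _ in range(4):
    encoder_lengths.append(encoder_lengths[-1] // 2)

# Decoder: each transposed convolution doubles the length again
decoder_lengths = [encoder_lengths[-1]]
for _ in range(4):
    decoder_lengths.append(
        conv_transpose1d_out_len(decoder_lengths[-1], kernel_size=2, stride=2)
    )

# Output layers: zeta keeps its length, velocity regains the extra point
len_zeta_out = conv1d_out_len(decoder_lengths[-1], kernel_size=3, padding=1)
len_vel_out = conv1d_out_len(decoder_lengths[-1], kernel_size=2, padding=1)
```

Under these assumptions the differing kernel sizes in the input and output layers are what let the two fields share one encoder while keeping their own grid sizes.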




### Boundary Conditions can be applied using a boundary mask

In [None]:
class UNetBC(UNet):
    def __init__(self, order, mask):
        super(UNetBC, self).__init__(order)
        self.mask = mask

    def forward(self, x_zeta, x_vel):
        x_vel = x_vel * self.mask
        dt_zeta, dt_vel = super(UNetBC, self).forward(x_zeta, x_vel)
        dt_vel = dt_vel * self.mask
        return dt_zeta, dt_vel
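
The effect of the mask can be checked on a toy tensor. This is a minimal sketch with random stand-in data rather than real network outputs; the shapes and variable names are assumptions chosen for illustration.

```python
import torch

torch.manual_seed(0)

n_vel = 9                        # toy number of velocity points
vel = torch.randn(2, 1, n_vel)   # (batch, channel, space)

# Mask that is 1 in the interior and 0 at the two walls
mask = torch.ones_like(vel)
mask[..., 0] = 0
mask[..., -1] = 0

vel_in = vel * mask                  # masked input, as in UNetBC.forward
dt_vel = torch.randn_like(vel)       # stand-in for the raw network output
vel_next = vel_in + dt_vel * mask    # masked update keeps the walls at zero
```

Because both the input and the predicted update are multiplied by the mask, the velocity at the first and last grid point stays exactly zero after the update, whatever the network produces there.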

## Training the U-Net Model

After defining our U-Net architecture, it's time to set up a training loop. This loop will iteratively update our model's weights using our dataset. Let's break down the steps needed:

1. **Setting Up**: Import necessary libraries, define hyperparameters, initialize computational device, and set random seeds.
2. **Data Loading**: Load the training and validation datasets and create data loaders.
3. **Model & Training Essentials Initialization**: Create mask, model, optimizer, and loss function.
4. **Training Loop**: For each epoch, forward propagate the input through the model, compute the loss, backpropagate the errors, and update the model weights.
5. **Validation Loop**: After training for each epoch, we will evaluate the model's performance on the validation dataset.
6. **Logging & Visualization**: Log metrics such as losses to TensorBoard.
7. **Model Saving**: After all epochs are completed, save the model's state dict.
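
Before looking at the full implementation, the core of steps 3 to 5 can be sketched on a toy problem. The example below is an illustration under simplifying assumptions: a single linear layer stands in for the U-Net, the data are synthetic, and plain SGD replaces the Adam optimizer used later.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy data: the "true" update is 0.5 * state
state = torch.linspace(-1.0, 1.0, 32).unsqueeze(1)
target = state + 0.5 * state

model = nn.Linear(1, 1)  # stand-in for the U-Net
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.MSELoss()

losses = []
model.train()
for epoch in range(100):
    optimizer.zero_grad()              # clear gradients from the previous step
    prediction = state + model(state)  # the network predicts the update
    loss = criterion(prediction, target)
    loss.backward()                    # backpropagate the error
    optimizer.step()                   # update the weights
    losses.append(loss.item())

# Evaluation without gradient tracking, as in a validation loop
model.eval()
with torch.no_grad():
    final_loss = criterion(state + model(state), target).item()
```

As in the trainer below, the model output is treated as an increment that is added to the input state before the loss is computed.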

Let's get started with the skeleton and explanations:


In [None]:
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import random

### Step 1: Setting Up

First, we need to define some hyperparameters, which are constants that determine how the model will be trained. We also set a computational device (either a GPU or CPU) to ensure our tensors and model are loaded onto the right hardware.


In [None]:
# TODO: Define hyperparameters
BATCH_SIZE = 32
EPOCHS = 400
LR = 0.001
ORDER = 1  # Autoregressive model order
NUMTIME = 600  # time steps per simulation
EXP_NAME = 'plain_time_integrator'

# Initialize TensorBoard writer for logging
log_dir = f'runs/exp_{EXP_NAME}'
writer = SummaryWriter(log_dir=log_dir)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def set_seeds(seed=42):
    # This function ensures reproducibility
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

### Step 2: Data Loading

Next, load the training and validation datasets. We also create data loaders that will allow us to fetch batches of data.


In [None]:
train_dataset = SWDataset(file_path="simulation_data.h5", order=ORDER, numtime=NUMTIME, mode="train")
valid_dataset = SWDataset(file_path="simulation_data.h5", order=ORDER, numtime=NUMTIME, mode="valid")

dataloader_train = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
dataloader_valid = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)

### Step 3: Model & Training Essentials Initialization

Before training, initialize the model, the optimizer responsible for weight updates, and the loss function.


In [None]:
batch = next(iter(dataloader_train))
input_vel_batch = batch[0][1]
mask = torch.ones_like(input_vel_batch).to(device)
mask[..., 0] = 0
mask[..., -1] = 0

model = UNetBC(order=ORDER, mask=mask).to(device)
optimizer = optim.Adam(model.parameters(), lr=LR)
criterion = torch.nn.MSELoss()

### Step 4 & 5: Training and Validation Loop

The training loop involves:
1. Setting the model to training mode.
2. Iterating through batches of data from the dataloader.
3. Making predictions using the model.
4. Calculating the loss.
5. Backpropagating to compute gradients.
6. Updating model parameters using the optimizer.

After each training epoch, you'll also run a validation loop to check the model's performance on unseen data. We will develop this training loop as a class which can take our UNet model as input, so when we make modifications to the architecture, we can still use this trainer.


In [None]:
class UNetTrainer:
    def __init__(self, model, device, dataloader_train, dataloader_valid, optimizer, criterion, writer, epochs):
        self.model = model
        self.device = device
        self.dataloader_train = dataloader_train
        self.dataloader_valid = dataloader_valid
        self.writer = writer
        self.epochs = epochs
        self.optimizer = optimizer
        self.criterion = criterion

    def train_epoch(self):
        train_losses = []
        self.model.train()

        for (input_zeta, input_vel), (target_zeta, target_vel) in self.dataloader_train:
            self.optimizer.zero_grad()
            input_zeta, input_vel = input_zeta.to(self.device), input_vel.to(self.device)
            target_zeta, target_vel = target_zeta.to(self.device), target_vel.to(self.device)
            output_zeta_dt, output_vel_dt = self.model(input_zeta, input_vel)
            output_zeta = input_zeta + output_zeta_dt
            output_vel = input_vel + output_vel_dt
            loss_zeta = self.criterion(output_zeta, target_zeta)
            loss_vel = self.criterion(output_vel, target_vel)
            total_loss = (loss_zeta + loss_vel) / 2.0
            total_loss.backward()
            self.optimizer.step()
            train_losses.append(total_loss.item())

        return np.mean(train_losses)

    def validate_epoch(self):
        valid_losses = []
        self.model.eval()
        with torch.no_grad():
            for (input_zeta, input_vel), (targets_zeta, targets_vel) in self.dataloader_valid:
                input_zeta, input_vel = input_zeta.to(self.device), input_vel.to(self.device)
                targets_zeta, targets_vel = targets_zeta.to(self.device), targets_vel.to(self.device)
                output_zeta_dt, output_vel_dt = self.model(input_zeta, input_vel)
                output_zeta = input_zeta + output_zeta_dt
                output_vel = input_vel + output_vel_dt
                loss_zeta = self.criterion(output_zeta, targets_zeta)
                loss_vel = self.criterion(output_vel, targets_vel)
                combined_loss = (loss_zeta + loss_vel) / 2.0
                valid_losses.append(combined_loss.item())

        return np.mean(valid_losses)

    def train(self):
        for epoch in range(self.epochs):
            train_loss = self.train_epoch()
            self.writer.add_scalar("Loss/train", train_loss, epoch)

            valid_loss = self.validate_epoch()
            self.writer.add_scalar("Loss/valid", valid_loss, epoch)

            print(f"Epoch {epoch+1}/{self.epochs} Train Loss: {train_loss:.8f} Valid Loss: {valid_loss:.8f}")

trainer = UNetTrainer(model, device, dataloader_train, dataloader_valid, optimizer, criterion, writer, epochs=EPOCHS)
trainer.train()

### Step 6: Logging & Visualization

We've already added logging functionality in the training loop using TensorBoard's `SummaryWriter`. This will help visualize training and validation loss curves, among other metrics you might want to track.


In [None]:
%tensorboard --logdir $log_dir

### Step 7: Model Saving

Finally, save the model's state dict, which contains the model's learned parameters. Later, you can load this state dict to make predictions with the trained model.


In [None]:
torch.save(model.state_dict(), EXP_NAME + '.pth')
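
Loading works in the reverse direction: build a model with the same architecture, then copy the saved parameters into it. The sketch below uses a small convolution as a stand-in for the trained network and an in-memory buffer instead of a file, so it runs without touching the disk; both choices are assumptions made for illustration.

```python
import io

import torch
import torch.nn as nn

torch.manual_seed(0)
trained = nn.Conv1d(1, 1, kernel_size=3, padding=1)  # stand-in for the trained model

buffer = io.BytesIO()  # in-memory file; a path to a .pth file works the same way
torch.save(trained.state_dict(), buffer)
buffer.seek(0)

restored = nn.Conv1d(1, 1, kernel_size=3, padding=1)  # must match the saved architecture
restored.load_state_dict(torch.load(buffer))
restored.eval()  # switch to inference mode before predicting

x = torch.randn(1, 1, 8)
with torch.no_grad():
    same_output = torch.equal(trained(x), restored(x))
```

For `UNetBC`, note that the mask is stored as a plain attribute, so it is not part of the state dict and has to be passed to the constructor again when the model is rebuilt.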

### Step 8: Learning the flux values instead of tendencies
Hyperbolic conservation laws can be written in the form:

$\dfrac{\partial \mathbf{U}}{\partial t} + \dfrac{\partial \mathbf{F}(\mathbf{U})}{\partial \mathbf{x}}=\mathbf{0}$

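Instead of outputting the tendencies $\mathbf{U}_t$ for elevation $\zeta$ and velocity $v$, the network can output the fluxes $\mathbf{F}$ of these variables on the discretized domain; the update at each grid point is then the difference of neighboring fluxes.
Summed over the domain, these differences cancel in pairs, so the total of each conserved quantity can change only through the fluxes at the domain boundaries. The network therefore satisfies the discrete conservation law by construction.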
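
The conservation property of a flux-difference update can be verified numerically. The following sketch is a toy example in plain Python: the random edge fluxes stand in for network outputs, the ratio `dt_over_dx` is an arbitrary assumed value, and closed walls are modeled by setting the two boundary fluxes to zero.

```python
import random

random.seed(0)

n_cells = 8
dt_over_dx = 0.1  # assumed ratio dt / dx for this toy example

# One flux per cell edge (n_cells + 1 edges); closed walls carry no flux
flux = [random.uniform(-1.0, 1.0) for _ in range(n_cells + 1)]
flux[0] = 0.0
flux[-1] = 0.0

# Update of each cell: difference of the fluxes at its two edges
update = [-dt_over_dx * (flux[i + 1] - flux[i]) for i in range(n_cells)]

# The sum telescopes to -dt/dx * (flux[-1] - flux[0]), which is zero here
total_change = sum(update)
```

Whatever values the interior fluxes take, `total_change` vanishes up to rounding error, which is the discrete counterpart of mass conservation in a closed basin.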

In [None]:
class UNetFlux(UNet):
    def __init__(self, order):
        super(UNetFlux, self).__init__(order)
        self.output_dec_zeta = nn.Conv1d(16, 1, kernel_size=4, padding=1)
        self.output_dec_vel = nn.Conv1d(16, 1, kernel_size=3, padding=1)

    def forward(self, x_zeta, x_vel):
        F_zeta, F_vel = super(UNetFlux, self).forward(x_zeta, x_vel)
        zeta_flux = torch.diff(F_zeta, dim=-1)
        vel_flux = torch.diff(F_vel, dim=-1)
        vel_flux = torch.nn.functional.pad(vel_flux, (1, 1), "constant", 0)
        zeta_flux = torch.nn.functional.pad(zeta_flux, (1, 1), "constant", 0)
        return zeta_flux, vel_flux

model = UNetFlux(order=ORDER).to(device)
optimizer = optim.Adam(model.parameters(), lr=LR)
trainer = UNetTrainer(model, device, dataloader_train, dataloader_valid, optimizer, criterion, writer, epochs=EPOCHS)
trainer.train()

### Step 9: Learning the $\zeta$ values only using the UNet
We can also have the network output only the updates to $\zeta$ and use a hybrid approach to update the field variables. This approach uses an implicit-explicit (IMEX) integration scheme. The discretized SWEs for this case are derived as follows; throughout this step, $u$ denotes the velocity written as $v$ above.

We discretize the momentum equation as follows:

$u^{n+1} = u^n - \Delta t c_D\frac{1}{h}u^n|u^n|- \Delta t g (1-w_{\textbf{imp}}) \frac{\partial \zeta^{n}}{\partial x}-\Delta t g w_{\textbf{imp}} \frac{\partial \zeta^{n+1}}{\partial x}$

where $w_{\textbf{imp}}$ is a fixed parameter controlling weighting between implicit and explicit time stepping. The mass equation is discretized as:

$\zeta^{n+1} = \zeta^n - \Delta t (1-w_{\textbf{imp}}) \frac{\partial h^n u^n}{\partial x}-\Delta t w_{\textbf{imp}} \frac{\partial h^n u^{n+1}}{\partial x}.
$

Recall that $h=h_0+\zeta$, where $h_0$ is the undisturbed water depth. Inserting the momentum equation into the mass conservation equation, we obtain:

$\zeta^{n+1} = \zeta^n - \Delta t (1-w_{\textbf{imp}}) \frac{\partial h^n u^n}{\partial x}-\Delta t w_{\textbf{imp}} \frac{\partial h^nu^*}{\partial x} + \Delta t^2 w_{\textbf{imp}}^2g\frac{\partial}{\partial x}\left(h^n\frac{\partial \zeta^{n+1}}{\partial x}\right)$

where $u^*$ is an explicit prediction for $u$:

$u^* = u^n - \Delta t c_D\frac{1}{h}u^n|u^n|- \Delta t g (1 - w_{\textbf{imp}}) \frac{\partial \zeta^{n}}{\partial x}$

The second-order spatial derivative is discretized using the second-order central finite difference stencil. Then, using $u^*$, we obtain the following expression for the mass equation:

$\zeta^{n+1}_i = \frac{1}{1+c_E+c_W}\bigg[\zeta^n_i+\text{div}_i+c_E\zeta^{n+1}_{i+1}+c_W\zeta^{n+1}_{i-1}\bigg]$

where $\text{div} = - \Delta t (1-w_{\textbf{imp}})\frac{\partial h^n u^n}{\partial x} -\Delta t w_{\textbf{imp}}\frac{\partial h^nu^*}{\partial x}$, while $c_E$ and $c_W$ are defined as

$c_E=\frac{0.5\Delta t^2w_{\textbf{imp}}^2g (h_i^n+h_{i+1}^n)}{\Delta x^2}$ if $h_{i+1}^n>0$ and $0$ otherwise

$c_W =\frac{0.5\Delta t^2w_{\textbf{imp}}^2g (h_i^n+h_{i-1}^n)}{\Delta x^2}$ if $h_{i-1}^n>0$ and $0$ otherwise

The $\zeta$ update equation then describes a linear system of equations in $\zeta^{n+1}$ that can be written in matrix-vector form

$A \zeta^{n+1} = b$

where $A$ is an $N \times N$ tridiagonal matrix ($N=L/\Delta x$) with $A_{k,k}=1$, $A_{k, k - 1} = -\frac{c_W}{1 + c_E + c_W}$, $A_{k, k + 1} = - \frac{c_E}{1 + c_E + c_W}$ and all other elements zero, and $b\in\mathbb R^N$ with $b = \frac{\zeta^n + \text{div}}{1 + c_E + c_W}$. Having obtained $\zeta^{n+1}$, the new velocity $u^{n+1}$ is calculated as

$u^{n+1} = u^* - \Delta t g w_{\textbf{imp}} \frac{\partial \zeta^{n+1}}{\partial x}$
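
To see the linear system in action, the sketch below builds the three diagonals of $A$ and the vector $b$ for a tiny grid and solves $A\zeta^{n+1}=b$ directly. Several things here are assumptions made for illustration: the helper names `build_system` and `thomas_solve` are hypothetical, the values of `zeta_n` and `div` are made up, a missing neighbor outside the domain is treated as a zero coefficient, and the Thomas algorithm is swapped in as a simple direct solver in place of the iterative solve or the neural network.

```python
from typing import List

# Parameter values matching the constants defined later in this tutorial
DT, G, W_IMP, H0, DX = 300.0, 9.81, 0.5, 100.0, 10.0e3


def build_system(zeta_n: List[float], div: List[float]):
    """Return sub-diagonal, diagonal, super-diagonal and right-hand side of A zeta = b."""
    n = len(zeta_n)
    h = [H0 + z for z in zeta_n]
    coef = 0.5 * DT**2 * W_IMP**2 * G / DX**2
    lower, diag, upper, rhs = [], [], [], []
    for i in range(n):
        c_e = coef * (h[i] + h[i + 1]) if i + 1 < n else 0.0  # no neighbor east of the last point
        c_w = coef * (h[i] + h[i - 1]) if i - 1 >= 0 else 0.0  # no neighbor west of the first point
        norm = 1.0 + c_e + c_w
        lower.append(-c_w / norm)
        diag.append(1.0)
        upper.append(-c_e / norm)
        rhs.append((zeta_n[i] + div[i]) / norm)
    return lower, diag, upper, rhs


def thomas_solve(lower, diag, upper, rhs):
    """Solve a tridiagonal system with the Thomas algorithm."""
    n = len(rhs)
    c_prime, d_prime = [0.0] * n, [0.0] * n
    c_prime[0] = upper[0] / diag[0]
    d_prime[0] = rhs[0] / diag[0]
    for i in range(1, n):
        denom = diag[i] - lower[i] * c_prime[i - 1]
        c_prime[i] = upper[i] / denom
        d_prime[i] = (rhs[i] - lower[i] * d_prime[i - 1]) / denom
    x = [0.0] * n
    x[-1] = d_prime[-1]
    for i in range(n - 2, -1, -1):
        x[i] = d_prime[i] - c_prime[i] * x[i + 1]
    return x


zeta_n = [0.0, 0.1, 0.3, 0.1, 0.0, -0.1]
div = [0.0, 0.01, -0.02, 0.01, 0.0, 0.0]
lower, diag, upper, rhs = build_system(zeta_n, div)
zeta_new = thomas_solve(lower, diag, upper, rhs)

# Residual A zeta_new - b, the quantity a PDE loss would penalize
residual = []
for i in range(len(zeta_new)):
    r = diag[i] * zeta_new[i] - rhs[i]
    if i > 0:
        r += lower[i] * zeta_new[i - 1]
    if i + 1 < len(zeta_new):
        r += upper[i] * zeta_new[i + 1]
    residual.append(r)
```

For an exact solution the residual vanishes up to rounding error; for a network prediction, the mean of the squared residual can serve as a training loss.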


###  Our tasks are as follows:

1- First, modify our UNet architecture to output only one channel.

2- Create the loss for updating the $\zeta$ values using the equation system $A \zeta^{n+1} = b$.

3- Update the velocity values using the formula $v^{n+1} = v^* - \Delta t g w_{\textbf{imp}} \frac{\partial \zeta^{n+1}}{\partial x}$. Hence we need a function accomplishing this.

4-  Update the batch size and mask tensor. Call the dataset class for the initial conditions and disable normalization.

5- Modify the training loop of our model.

6- Run the training loop

In [None]:
# New Architecture
class ZetaUNet(UNetBC):
    def __init__(self, order, mask):
        super(ZetaUNet, self).__init__(order, mask)
        self.order = order
        self.mask = mask

    def forward(self, x_zeta, x_vel):
        zeta_dt, _ = super(ZetaUNet, self).forward(x_zeta, x_vel)
        return zeta_dt

In [None]:
# New physics informed loss function
# Parameters for the given dataset of SWE.
CD = 1.0e-3
G = 9.81
DT = 300.0
W_IMP = 0.5
H0 = 100.0
DX = 10.0e3
N = 256

class HybridLoss(nn.Module):
    def __init__(self):
        super(HybridLoss, self).__init__()
        self.mse_loss = nn.MSELoss()

    def forward(self, zeta_dt, zeta_n, vel_n):
        # Total depth at zeta points and averaged onto the interior velocity points
        h_n = H0 + zeta_n
        h_m = (h_n[..., 1:] + h_n[..., :-1]) / 2.0

        zeta_dx = torch.diff(zeta_n, dim=-1) / DX
        # Pad to the velocity grid; the boundary entries are placeholders
        zeta_dx = torch.nn.functional.pad(zeta_dx, (1, 1), "constant", 0)
        h_m = torch.nn.functional.pad(h_m, (1, 1), "constant", 1)
        # Explicit prediction u*: it excludes the implicit pressure-gradient term
        vel_star = (
            vel_n
            - DT * CD * torch.div(torch.mul(torch.abs(vel_n), vel_n), h_m)
            - DT * G * (1 - W_IMP) * zeta_dx
        )
        U = torch.mul(vel_n, h_m)
        U_star = torch.mul(vel_star, h_m)
        U_dx = torch.diff(U, dim=-1) / DX
        U_star_dx = torch.diff(U_star, dim=-1) / DX
        div = -DT * (1 - W_IMP) * U_dx - DT * W_IMP * U_star_dx

        # c_E and c_W at the interior points, following the formulas above
        h_e = h_n[..., 1:-1] + h_n[..., 2:]
        h_w = h_n[..., 1:-1] + h_n[..., :-2]
        c_e = 0.5 * DT**2 * W_IMP**2 * G * h_e / DX**2
        c_w = 0.5 * DT**2 * W_IMP**2 * G * h_w / DX**2

        # Off-diagonals of A; the boundary rows keep only the unit diagonal
        diag_cm1 = torch.div(-c_w, 1 + c_e + c_w)  # A[k, k-1]
        diag_cp1 = torch.div(-c_e, 1 + c_e + c_w)  # A[k, k+1]
        diag_cm1 = torch.nn.functional.pad(diag_cm1, (0, 1), "constant", 0)
        diag_cp1 = torch.nn.functional.pad(diag_cp1, (1, 0), "constant", 0)
        b = torch.div(zeta_n[..., 1:-1] + div[..., 1:-1], 1 + c_e + c_w)
        b = torch.nn.functional.pad(b, (1, 1), "constant", 0)
        zeta_new = zeta_n + zeta_dt

        # Build A on the same device and with the same dtype as the inputs
        n = zeta_n.shape[-1]
        idx = torch.arange(n, device=zeta_n.device)
        A = torch.zeros(*zeta_n.shape, n, device=zeta_n.device, dtype=zeta_n.dtype)
        A[..., idx, idx] = 1.0
        A[..., idx[1:], idx[:-1]] = diag_cm1
        A[..., idx[:-1], idx[1:]] = diag_cp1
        loss = self.mse_loss(torch.matmul(A, zeta_new.unsqueeze(-1)), b.unsqueeze(-1))
        return loss

In [None]:
# Velocity Integration function
def integrate_vel(vel_n, zeta_new, zeta_old):
    # Total depth at zeta points, averaged onto the interior velocity points
    h_n = H0 + zeta_old
    h_m = (h_n[..., 1:] + h_n[..., :-1]) / 2.0

    zeta_old_dx = torch.diff(zeta_old, dim=-1) / DX
    zeta_new_dx = torch.diff(zeta_new, dim=-1) / DX
    # Pad to the velocity grid; the boundary entries are placeholders
    zeta_old_dx = torch.nn.functional.pad(zeta_old_dx, (1, 1), "constant", 0)
    zeta_new_dx = torch.nn.functional.pad(zeta_new_dx, (1, 1), "constant", 0)
    h_m = torch.nn.functional.pad(h_m, (1, 1), "constant", 1)

    # Explicit prediction u*: drag and the explicit part of the pressure gradient
    vel_star = (
        vel_n
        - DT * CD * torch.div(torch.mul(torch.abs(vel_n), vel_n), h_m)
        - DT * G * (1 - W_IMP) * zeta_old_dx
    )
    # The implicit part of the pressure gradient is applied exactly once
    vel_new = vel_star - DT * G * W_IMP * zeta_new_dx
    return vel_new

In [None]:
BATCH_SIZE = 1
train_dataset = SWDataset(file_path="simulation_data.h5", order=ORDER, numtime=NUMTIME, mode="train", normalize=False)
valid_dataset = SWDataset(file_path="simulation_data.h5", order=ORDER, numtime=NUMTIME, mode="valid", normalize=False)

dataloader_train = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)
dataloader_valid = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)

In [None]:
# Training Loop for the hybrid model
class PITrainerZeta:
    def __init__(self, model, device, train_dataset, optimizer, criterion, writer, epochs):
        self.model = model
        self.device = device
        self.train_dataset = train_dataset
        self.optimizer = optimizer
        self.criterion = criterion
        self.writer = writer
        self.EPOCHS = epochs
        self.NUMTIME = NUMTIME
    def train_epoch(self):
        train_losses = []
        self.model.train()

        init_zeta, init_vel = self.train_dataset.get_initial_conditions()
        self.optimizer.zero_grad()
        accumulated_loss = 0
        input_zeta, input_vel = init_zeta.to(self.device), init_vel.to(self.device)

        for _ in range(self.NUMTIME):
            output_zeta_dt = self.model(input_zeta, input_vel)
            next_zeta = input_zeta + output_zeta_dt
            next_vel = integrate_vel(input_vel, next_zeta, input_zeta)

            unsupervised_loss = self.criterion(output_zeta_dt, input_zeta, input_vel)
            accumulated_loss += unsupervised_loss

            input_zeta, input_vel = next_zeta, next_vel

        accumulated_loss.backward()
        self.optimizer.step()
        train_losses.append(accumulated_loss.item())

        return np.mean(train_losses)

    def train(self):
        for epoch in range(self.EPOCHS):
            train_loss = self.train_epoch()
            self.writer.add_scalar("Loss/train", train_loss, epoch)

            print(
                f"Epoch {epoch+1}/{self.EPOCHS} Train Loss: {train_loss:.8f}"
            )

        torch.save(self.model.state_dict(), "unet_model.pth")

batch = next(iter(dataloader_train))
input_vel_batch = batch[0][1]
mask = torch.ones_like(input_vel_batch).to(device)
mask[..., 0] = 0
mask[..., -1] = 0
model = ZetaUNet(order=ORDER, mask=mask).to(device)
optimizer = optim.Adam(model.parameters(), lr=LR)
criterion = HybridLoss()
trainer = PITrainerZeta(model, device, train_dataset, optimizer, criterion, writer, epochs=EPOCHS)
trainer.train()

### Step 10: Physical Constraints Using the semi-implicit discretization

The physical constraints can also be applied using the discretized form of the SWEs. The network can output both velocity and elevation, and the outputs can be forced to satisfy the following discretized equations:

$u^{n+1} = u^n - \Delta t c_D\frac{1}{h}u^n|u^n|- \Delta t g (1-w_{\textbf{imp}}) \frac{\partial \zeta^{n}}{\partial x}-\Delta t g w_{\textbf{imp}} \frac{\partial \zeta^{n+1}}{\partial x}$

$\zeta^{n+1} = \zeta^n - \Delta t (1-w_{\textbf{imp}}) \frac{\partial h^n u^n}{\partial x}-\Delta t w_{\textbf{imp}} \frac{\partial h^n u^{n+1}}{\partial x}.
$



###  Our tasks are as follows:

1- Create the loss for updating the $\zeta$ and $v$ values using the semi implicit discretization of SWE.

2- Call the UNet model which outputs both field variables.

3- Modify the training loop of our model.

4- Run the training loop.

In [None]:
# Physics informed loss function for Semi implicit discretized SWE
class SemiImpLoss(nn.Module):
    def __init__(self):
        super(SemiImpLoss, self).__init__()
        self.mse_loss = nn.MSELoss()

    def forward(self, zeta_dt, vel_dt, zeta_n, vel_n):
        # Total depth at zeta points, averaged onto the interior velocity points
        h_n = H0 + zeta_n
        h_n_tilde = (h_n[..., :-1] + h_n[..., 1:]) / 2.0

        zeta_nx = torch.diff(zeta_n, dim=-1) / DX
        zeta_n1x = torch.diff(zeta_n + zeta_dt, dim=-1) / DX
        zeta_nx = torch.nn.functional.pad(zeta_nx, (1, 1), "constant", 0)
        zeta_n1x = torch.nn.functional.pad(zeta_n1x, (1, 1), "constant", 0)

        h_n_tilde = torch.nn.functional.pad(h_n_tilde, (1, 1), "constant", 1)

        # Spatial derivative of the mass flux h * u at time levels n and n + 1
        U_nx = torch.diff(torch.mul(h_n_tilde, vel_n), dim=-1) / DX
        U_n1x = torch.diff(torch.mul(h_n_tilde, vel_n + vel_dt), dim=-1) / DX

        # Residual of the discretized momentum equation (velocity update)
        mom_residual = vel_dt + DT * (
            CD * torch.div(torch.mul(vel_n, torch.abs(vel_n)), h_n_tilde)
            + G * (1 - W_IMP) * zeta_nx
            + G * W_IMP * zeta_n1x
        )
        # Residual of the discretized mass equation (zeta update)
        mass_residual = zeta_dt + DT * (1 - W_IMP) * U_nx + DT * W_IMP * U_n1x

        loss_mom = torch.mean(mom_residual**2)
        loss_mass = torch.mean(mass_residual**2)

        return loss_mom + loss_mass

In [None]:
# Training loop for the semi-implicit physics-informed model
class PITrainerZetaV:
    def __init__(self, model, device, train_dataset, optimizer, criterion, writer, epochs):
        self.model = model
        self.device = device
        self.train_dataset = train_dataset
        self.optimizer = optimizer
        self.criterion = criterion
        self.writer = writer
        self.EPOCHS = epochs
        self.NUMTIME = NUMTIME  # number of rollout steps per epoch

    def train_epoch(self):
        train_losses = []
        self.model.train()

        init_zeta, init_vel = self.train_dataset.get_initial_conditions()
        self.optimizer.zero_grad()
        accumulated_loss = 0
        input_zeta, input_vel = init_zeta.to(self.device), init_vel.to(self.device)

        # Roll the model forward in time, accumulating the physics loss at every step
        for _ in range(self.NUMTIME):
            output_zeta_dt, output_vel_dt = self.model(input_zeta, input_vel)
            next_zeta = input_zeta + output_zeta_dt
            next_vel = input_vel + output_vel_dt

            unsupervised_loss = self.criterion(output_zeta_dt, output_vel_dt, input_zeta, input_vel)
            accumulated_loss += unsupervised_loss

            input_zeta, input_vel = next_zeta, next_vel

        accumulated_loss.backward()
        self.optimizer.step()
        train_losses.append(accumulated_loss.item())

        return np.mean(train_losses)

    def train(self):
        for epoch in range(self.EPOCHS):
            train_loss = self.train_epoch()
            self.writer.add_scalar("Loss/train", train_loss, epoch)

            print(
                f"Epoch {epoch+1}/{self.EPOCHS} Train Loss: {train_loss:.8f}"
            )

        torch.save(self.model.state_dict(), "unet_model.pth")

# UNetBC outputs both field variables and applies the boundary mask internally
model = UNetBC(order=ORDER, mask=mask).to(device)
optimizer = optim.Adam(model.parameters(), lr=LR)
criterion = SemiImpLoss()
# The trainer needs the dataset itself, since it calls get_initial_conditions()
trainer = PITrainerZetaV(model, device, train_dataset, optimizer, criterion, writer, epochs=EPOCHS)
trainer.train()

### Step 11: Equivariant Convolutions?