In [None]:
import os
import torch
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt

from tqdm import tqdm
from forcing_functions import get_function
from finite_element_code import set_up_4_dof, integrate_rk4, get_mck
from dataclasses import dataclass

# Set the random seed so that results are consistent across trials.
# The trailing ';' keeps the returned Generator object from being echoed as
# this cell's output.
torch.manual_seed(123);

# Creating The SIREN Class
The SIREN network is described in [this paper](https://arxiv.org/abs/2006.09661)

The main points from the paper that will be used here are:
- Neural networks are functions of the form $u(x) = (l_N \circ l_{N-1} \circ \cdots \circ l_1)(x)$
  - Each $l_i$ is a function representing a layer of the network
  - $N$ is the number of layers in the network
  - The hidden layers take the form $l_i(x) = a( W_i x + b_i )$ for $i < N$; the final layer used here, $l_N(x) = W_N x + b_N$, is purely linear
    - $a$ is the activation function (in this case $a(x) = \sin (\omega_0 x)$, where $\omega_0$ is a hyper-parameter)

SIRENs are well suited to problems that differentiate the network with respect to its input: the derivative of a sine is just a phase-shifted sine, so the derivative of a SIREN is itself SIREN-like, because
$$\frac{d}{dx}\sin(x) = \cos(x) = \sin\left(\frac{\pi}{2}-x\right)$$

In [None]:
# Define the layers l_1 ... l_(N-1). l_N is just a linear layer, Wx + b
class SineLayer(nn.Module):
    """
    Adapted from
    https://github.com/vsitzmann/siren/blob/master/explore_siren.ipynb

    One SIREN layer: ``sin(omega_0 * (W x + b))``.

    Weight initialisation follows the SIREN scheme:
      * first layer:  W ~ U(-1/n, 1/n)
      * other layers: W ~ U(-sqrt(6/n)/omega_0, sqrt(6/n)/omega_0)
    where ``n`` is ``in_features``. The bias keeps ``nn.Linear``'s default init.
    """

    def __init__(
        self, in_features, out_features, bias=True, is_first=False, omega_0=30
    ):
        super().__init__()
        self.omega_0 = omega_0
        self.is_first = is_first

        self.in_features = in_features
        self.linear = nn.Linear(in_features, out_features, bias=bias)

        self.init_weights()

    def init_weights(self):
        """Overwrite ``self.linear.weight`` with the SIREN uniform init."""
        if self.is_first:
            bound = 1 / self.in_features
        else:
            bound = np.sqrt(6 / self.in_features) / self.omega_0

        with torch.no_grad():
            self.linear.weight.uniform_(-bound, bound)

    def forward(self, input):
        pre_activation = self.linear(input)
        return torch.sin(self.omega_0 * pre_activation)

# Define the SIREN network
class Siren(nn.Module):
    """
    Adapted from
    https://github.com/vsitzmann/siren/blob/master/explore_siren.ipynb

    A stack of ``SineLayer`` modules: one input layer (``is_first=True``),
    ``hidden_layers`` hidden layers, and an output layer. The output layer is
    another ``SineLayer`` unless ``outermost_linear`` is True, in which case
    it is a plain ``nn.Linear``. That linear layer's weights are drawn from
    the same uniform range as the hidden sine layers.

    ``forward`` returns ``(output, coords)``. ``coords`` is a detached copy of
    the input with ``requires_grad`` enabled, so the output can be
    differentiated with respect to it.
    """

    def __init__(
        self,
        in_features,
        hidden_features,
        hidden_layers,
        out_features,
        outermost_linear=False,
        first_omega_0=30,
        hidden_omega_0=30.0,
    ):
        super().__init__()

        layers = [
            SineLayer(
                in_features, hidden_features, is_first=True, omega_0=first_omega_0
            )
        ]
        # Hidden layers are built in order, so RNG consumption matches a plain loop.
        layers.extend(
            SineLayer(
                hidden_features,
                hidden_features,
                is_first=False,
                omega_0=hidden_omega_0,
            )
            for _ in range(hidden_layers)
        )

        if not outermost_linear:
            layers.append(
                SineLayer(
                    hidden_features,
                    out_features,
                    is_first=False,
                    omega_0=hidden_omega_0,
                )
            )
        else:
            head = nn.Linear(hidden_features, out_features)
            limit = np.sqrt(6 / hidden_features) / hidden_omega_0
            with torch.no_grad():
                head.weight.uniform_(-limit, limit)
            layers.append(head)

        self.net = nn.Sequential(*layers)

    def forward(self, coords):
        # Fresh leaf tensor that tracks gradients, so callers can take
        # derivatives of the output with respect to the input.
        leaf = coords.clone().detach().requires_grad_(True)
        return self.net(leaf), leaf


class OtherActivationNetwork(nn.Module):
    """Same layer layout as ``Siren``, but with a configurable activation.

    ``activation`` is a zero-argument callable (default ``nn.Tanh``). It is
    instantiated once per layer that uses it. All weights keep PyTorch's
    default ``nn.Linear`` initialisation; there is no SIREN-specific init.
    When ``outermost_linear`` is True the last layer is a bare ``nn.Linear``.

    ``forward`` returns ``(output, coords)``, where ``coords`` is a detached,
    gradient-tracking copy of the input (as in ``Siren``).
    """

    def __init__(
        self,
        in_features,
        hidden_features,
        hidden_layers,
        out_features,
        outermost_linear=False,
        activation=nn.Tanh,
    ):
        super().__init__()
        self.activation = activation

        def linear_then_activation(n_in, n_out):
            # Linear map followed by a fresh activation-module instance.
            return nn.Sequential(nn.Linear(n_in, n_out), self.activation())

        layers = [linear_then_activation(in_features, hidden_features)]
        layers += [
            linear_then_activation(hidden_features, hidden_features)
            for _ in range(hidden_layers)
        ]
        layers.append(
            nn.Linear(hidden_features, out_features)
            if outermost_linear
            else linear_then_activation(hidden_features, out_features)
        )

        self.net = nn.Sequential(*layers)

    def forward(self, coords):
        # Fresh leaf tensor so derivatives w.r.t. the input can be taken later.
        leaf = coords.clone().detach().requires_grad_(True)
        return self.net(leaf), leaf


def gradient(y, x, grad_outputs=None):
    """Return the vector-Jacobian product d(y)/d(x) weighted by ``grad_outputs``.

    With the default ``grad_outputs`` (all ones) and an element-wise ``y``,
    this is the element-wise derivative of ``y`` with respect to ``x``. The
    graph is kept (``create_graph=True``), so the result can itself be
    differentiated, e.g. for second derivatives.
    """
    weights = torch.ones_like(y) if grad_outputs is None else grad_outputs
    (grad,) = torch.autograd.grad(y, [x], grad_outputs=weights, create_graph=True)
    return grad

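# The Problem
We simulate a four-degree-of-freedom finite-element model (displacements $u_{1x}, u_{1y}, u_{2x}, u_{2y}$) driven by a forcing function and integrate it with RK4 to obtain a reference solution. We then sample displacement and velocity data for the degrees of freedom in `DOF_DATA`, before `DATA_CUTOFF_TIME`.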


In [None]:
INITIAL_TIME = 0
FINAL_TIME = 2.9
N_TIMESTEPS = 290  # Number of time-steps to simulate
DOF_DATA = [2, 3]  # Which degrees of freedom to give data on
N_DATAPOINTS = 20  # Number of datapoints to train on (upper bound)
FUNCTION = "sine"  # options are "sine" "gaussian" "chirp"
DATA_CUTOFF_TIME = 2  # Time after which to stop giving data

t = torch.linspace(INITIAL_TIME, FINAL_TIME, N_TIMESTEPS, device="cuda")

x, y, E, nu, rho, a0, a1, fdim, free_indices, ndim, sens_dim = set_up_4_dof()
ffun = get_function(
    FUNCTION,
    INITIAL_TIME,
    FINAL_TIME,
    amplitude=-1000,
    total_dimensions=ndim,
    force_dimension=fdim,
)
m, c, k = get_mck(x, y, E, nu, rho, a0, a1, free_indices, device="cuda")
displacements, velocities = integrate_rk4(m, c, k, ffun, t)

# Sample the reference solution at evenly spaced indices in [0, cutoff_index).
# Ceiling division keeps the sample count at most N_DATAPOINTS (a floored step
# produced extra samples). max(1, ...) avoids a zero step when the cutoff
# index is smaller than N_DATAPOINTS.
cutoff_index = int(torch.argmin((t - DATA_CUTOFF_TIME).abs()))
data_indices_step = max(1, -(-cutoff_index // N_DATAPOINTS))
data_indices = torch.arange(0, cutoff_index, data_indices_step)
assert len(data_indices) > 0, "DATA_CUTOFF_TIME leaves no time steps to sample"

# displacement_data[r, s] / velocity_data[r, s]: DOF DOF_DATA[r] at t[data_indices[s]]
i, j = torch.meshgrid(torch.as_tensor(DOF_DATA), data_indices, indexing="ij")
displacement_data = displacements[i, j]
velocity_data = velocities[i, j]
t_data = t[data_indices]

plt.figure()
plt.title("Forcing Function")
plt.xlabel("Time, $t$ [sec]")
plt.ylabel("Force, $f$ [N]")
plt.plot(t.cpu(), ffun(t)[fdim, :].cpu())

plt.figure()
plt.title("Reference Displacement")
plt.xlabel("Time, $t$ [sec]")
plt.ylabel("Displacement, $u$ [m]")
plt.plot(t.cpu(), displacements.cpu().T, linewidth=1)
plt.plot(
    t_data.cpu(),
    displacement_data.cpu().T,
    linestyle="None",
    marker="x",
    color="purple",
)
plt.legend(["u1x", "u1y", "u2x", "u2y"]);

# Define Model Training Function


In [None]:
# Time-step size of the simulation grid `t` (built with torch.linspace, so uniform).
dt = torch.sub(t[1], t[0])


# Define Physics Loss Function
def physics_loss(
    v1_predicted, u0, v0, f0, dt, mass=None, damping=None, stiffness=None
):
    """Mean squared residual of the equation of motion ``M a + C v + K u = f``.

    The acceleration at the current step is approximated by the forward
    difference ``a0 = (v1_predicted - v0) / dt``.

    Every state/force argument is flattened to an ``(n, 1)`` column vector
    first, so row ``(1, n)`` and column ``(n, 1)`` inputs give the same
    result. Without this, a ``(1, n)`` prediction minus an ``(n, 1)``
    velocity would broadcast to an ``(n, n)`` matrix.

    Args:
        v1_predicted: predicted velocity at the next time step (n entries).
        u0: displacement at the current time step (n entries).
        v0: velocity at the current time step (n entries).
        f0: external force at the current time step. Assumed to have the
            same n entries as the state (TODO confirm against ``ffun``'s output).
        dt: time-step size.
        mass, damping, stiffness: optional (n, n) matrices. When omitted,
            the module-level ``m``, ``c`` and ``k`` are used.

    Returns:
        Scalar tensor: the mean of the squared residual entries.
    """
    mass = m if mass is None else mass
    damping = c if damping is None else damping
    stiffness = k if stiffness is None else stiffness

    v1 = v1_predicted.reshape(-1, 1)
    u0 = u0.reshape(-1, 1)
    v0 = v0.reshape(-1, 1)
    f0 = f0.reshape(-1, 1)

    a0_predicted = (v1 - v0) / dt
    residual_vector = mass @ a0_predicted + damping @ v0 + stiffness @ u0 - f0
    return residual_vector.square().mean()


# Container for the per-epoch loss curves
@dataclass
class LossHistory:
    """Loss values recorded by the training loop, one float per epoch per list."""

    # physics-residual loss, already multiplied by the physics weight
    physics_loss: "list[float]"
    # data-mismatch loss
    data_loss: "list[float]"
    # physics_loss + data_loss
    total_loss: "list[float]"


# Function that trains the model
def train_model(model, optimizer, n_epochs, physics_weight, update_interval=10):
    """Train ``model`` by rolling it out over the full time grid once per epoch.

    Each epoch starts from rest (u = v = 0). At every time step the model is
    fed ``[u, v, max(f)]`` and predicts the velocity at the next step. The
    displacement is advanced with explicit Euler using the current velocity.
    The detached prediction becomes the next velocity input, so gradients
    only flow through a single step at a time.

    Loss terms (summed over all steps, then one optimizer step per epoch):
      * physics: ``physics_weight`` times ``physics_loss`` for every step;
      * data: MSE between the predicted next state and the reference samples
        for the DOFs in ``DOF_DATA`` at the sampled indices ``data_indices``.

    Args:
        model: returns ``(next_velocity, coords)``; next_velocity expected (1, 4).
        optimizer: torch optimizer over the model's parameters.
        n_epochs: number of roll-outs / optimizer steps.
        physics_weight: multiplier applied to each physics residual term.
        update_interval: refresh the progress-bar text every this many epochs.

    Returns:
        ``(model, loss_history)`` with one entry per epoch in each list.
    """
    device = t.device

    # Map each sampled time index to its column in displacement_data /
    # velocity_data. Index 0 is the rest state the roll-out starts from; it
    # is never predicted, so it is left out.
    data_column = {
        int(index): column
        for column, index in enumerate(data_indices.tolist())
        if int(index) > 0
    }
    # Rescale the data term to be comparable with the physics term, which is
    # accumulated over 4 DOFs at every time step.
    data_scale = (4 * len(t)) / (max(len(data_column), 1) * len(DOF_DATA))

    loss_history = LossHistory([], [], [])
    pbar = tqdm(range(n_epochs))

    for epoch in pbar:
        optimizer.zero_grad()

        # Accumulators are tensors so .item() works even if no sample is hit.
        phys_loss = torch.zeros((), device=device)
        data_loss = torch.zeros((), device=device)

        u0 = torch.zeros(1, 4, device=device)
        v0 = torch.zeros(1, 4, device=device)

        for time_index in range(len(t)):
            f0 = ffun(t[time_index])

            # NOTE(review): f0.max() is the largest entry of the force vector.
            # With a negative applied force (amplitude=-1000) and zero entries
            # elsewhere this feature would be 0; confirm it is the intended input.
            f0_magnitude = f0.max()

            model_input = torch.hstack((u0, v0, f0_magnitude.reshape(1, -1)))
            v1_predicted, _ = model(model_input)

            # Transpose the (1, 4) prediction to a column so that v1 - v0 is
            # (4, 1) inside physics_loss instead of broadcasting to (4, 4).
            phys_loss = phys_loss + physics_weight * physics_loss(
                v1_predicted=v1_predicted.T,
                u0=u0.T,
                v0=v0.T,
                f0=f0,
                dt=dt,
            )

            # The prediction made at step `time_index` is the state at
            # t[time_index + 1], so compare it with data sampled there.
            column = data_column.get(time_index + 1)
            if column is not None:
                # NOTE: u1_predicted is built from detached state only, so the
                # displacement term changes the reported loss but adds no gradient.
                u1_predicted = u0 + v0 * dt

                displacement_error = torch.nn.functional.mse_loss(
                    u1_predicted[:, DOF_DATA].reshape(-1),
                    displacement_data[:, column],
                )
                velocity_error = torch.nn.functional.mse_loss(
                    v1_predicted[:, DOF_DATA].reshape(-1),
                    velocity_data[:, column],
                )
                data_loss = data_loss + (displacement_error + velocity_error) * data_scale

            # Advance the state: explicit Euler for u, model prediction for v.
            u0 = u0 + v0 * dt
            v0 = v1_predicted.detach()

        loss = data_loss + phys_loss

        # Store losses for graphing later
        loss_history.data_loss.append(data_loss.item())
        loss_history.physics_loss.append(phys_loss.item())
        loss_history.total_loss.append(loss.item())

        # Backpropagate and update the network parameters
        loss.backward()
        optimizer.step()

        if epoch % update_interval == 0:
            pbar.set_description_str(
                f"Epoch: {epoch} - L_p={float(phys_loss): 4g} - L_c={float(data_loss): 4g}"
            )

    return model, loss_history

## Define Visualization Function

In [None]:
def visualize(model, savefile, restrict_bounds_to_true_sol=False):
    """Roll the model out over the whole time grid and plot it against the reference.

    Starting from rest at t[0], which is the same initial state ``train_model``
    uses, the displacement is advanced with explicit Euler using the current
    velocity, and the model supplies the next velocity. A 2x2 grid shows the
    predicted displacement of each DOF (u1x, u1y, u2x, u2y) against the
    reference solution. For DOFs in ``DOF_DATA`` it also shows the sampled
    training data. The figure is written to ``savefile``.

    Args:
        model: network mapping a (1, 9) input ``[u, v, f]`` to the next
            velocity, returning ``(output, coords)``.
        savefile: path passed to ``Figure.savefig``.
        restrict_bounds_to_true_sol: clamp every y-axis to the range of the
            reference displacements.
    """
    device = t.device

    # Row k holds the predicted state [u1x u1y u2x u2y v1x v1y v2x v2y] at t[k].
    # Row 0 stays zero: the roll-out starts from rest.
    all_state_vectors = torch.zeros(len(t), 8, device=device)
    u0 = torch.zeros(1, 4, device=device)
    v0 = torch.zeros(1, 4, device=device)

    # Inference only: no autograd graph is needed for plotting.
    with torch.no_grad():
        for time_index in range(len(t) - 1):
            f0 = ffun(t[time_index])

            # NOTE(review): must stay the same force feature as in train_model.
            f0_magnitude = f0.max()

            model_input = torch.hstack((u0, v0, f0_magnitude.reshape(1, -1)))
            v1_predicted, _ = model(model_input)

            # Step k produces the state at t[k + 1]; store it there so the
            # prediction lines up in time with the reference solution.
            u0 = u0 + v0 * dt
            v0 = v1_predicted.detach()
            all_state_vectors[time_index + 1, :4] = u0.reshape(-1)
            all_state_vectors[time_index + 1, 4:] = v0.reshape(-1)

    t_cpu = t.cpu()
    t_data = t[data_indices].cpu()

    fig, axes = plt.subplots(2, 2, sharex=True, sharey=True)
    for i in range(axes.shape[0]):  # i-th node
        for j in range(axes.shape[1]):  # j-th direction: 0 -> x, 1 -> y
            ax = axes[i, j]

            # EXAMPLE: node 1, y-direction -> i = 0, j = 1 -> node_idx = 2*0 + 1 = 1 (second element in the array)
            node_idx = 2 * i + j

            model_prediction_u = all_state_vectors[:, node_idx]
            true_u = displacements[node_idx, :].squeeze()

            ax.plot(t_cpu, true_u.cpu(), color="red", label="Reference")
            ax.plot(
                t_cpu,
                model_prediction_u.cpu(),
                color="#808080",
                linestyle="dashed",
                label="Prediction",
            )

            if node_idx in DOF_DATA:
                # Rows of displacement_data follow the order of DOF_DATA.
                u_data = displacement_data[DOF_DATA.index(node_idx)].squeeze()
                ax.plot(
                    t_data,
                    u_data.cpu(),
                    color="black",
                    marker="x",
                    linestyle="none",
                    label="Data",
                )

            if i == 0 and j == 0:
                ax.legend()

            ax.set_title(f"$u_{{ {i + 1} { 'x' if j == 0 else 'y' } }}$")

            if restrict_bounds_to_true_sol:
                ax.set_ylim(displacements.min().item(), displacements.max().item())

    fig.suptitle("Prediction vs. Ground Truth")
    fig.supxlabel("Time, $t$ [sec]")
    fig.supylabel("Displacement, $u$ [m]")
    fig.tight_layout()
    fig.savefig(savefile, bbox_inches="tight")
    plt.close(fig)


def plot_loss(loss_history, savefile):
    """Plot the per-epoch loss curves on a log scale and save them to ``savefile``.

    ``loss_history`` provides ``total_loss``, ``data_loss`` and
    ``physics_loss`` lists. The total is drawn solid black; the data and
    weighted physics terms are drawn dashed in blue and red.
    """
    fig, ax = plt.subplots()
    curves = (
        (loss_history.total_loss, dict(color="black", label="Total")),
        (
            loss_history.data_loss,
            dict(color="blue", linestyle="dashed", label="Data"),
        ),
        (
            loss_history.physics_loss,
            dict(color="red", linestyle="dashed", label="Physics (Weighted)"),
        ),
    )
    for values, line_style in curves:
        ax.plot(values, **line_style)

    ax.legend()
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.set_title("Loss History")
    ax.set_yscale("log")
    fig.tight_layout()
    fig.savefig(savefile, bbox_inches="tight")
    plt.close(fig)

# Perform Training

In [None]:
LEARNING_RATE = 1e-4
N_EPOCHS = 10_000

# Input: 4 displacements + 4 velocities + 1 forcing feature.
# Output: the next-step velocity of each of the 4 DOFs.
model = Siren(
    in_features=4 + 4 + 1,
    hidden_features=32,
    hidden_layers=3,
    out_features=4,
    outermost_linear=True,
)
model = model.to(device="cuda")

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# NOTE(review): physics_weight=1e-16 may make the physics term negligible
# next to the data term; confirm this weighting is intended.
model, loss_history = train_model(
    model=model,
    optimizer=optimizer,
    n_epochs=N_EPOCHS,
    physics_weight=1e-16,
)

visualize(model=model, savefile="test.png", restrict_bounds_to_true_sol=True)
plot_loss(loss_history=loss_history, savefile="testloss.png")