In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from model import PINN
import os
import argparse

def _positive_int(text):
    """argparse ``type=`` converter that accepts only integers >= 1."""
    value = int(text)  # a ValueError here is reported by argparse as an invalid value
    if value < 1:
        raise argparse.ArgumentTypeError("must be a positive integer, got %r" % text)
    return value


# Command-line options.
# - `type=_positive_int` makes `plot_every` a real int. An untyped option yields a
#   str (or None when omitted), and `epoch % plot_every` in the plotter then fails.
# - `parse_known_args` ignores arguments this parser does not define, such as the
#   `-f <connection file>` flag an IPython/Jupyter kernel is launched with. This
#   lets the cell also run inside a notebook instead of exiting with an error.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--plot-every",
    type=_positive_int,
    default=1000,
    help="save a prediction plot every N epochs (default: 1000)",
)
args, _unknown_args = parser.parse_known_args()
plot_every = args.plot_every

In [2]:
# Define callbacks
def epoch_logger(epoch, model, physics_loss, data_loss, **kw):
    """Training callback: print a one-line loss summary on every 1000th epoch.

    Args:
        epoch: Current training epoch; only multiples of 1000 are reported.
        model: The model being trained; its ``E()`` value is included.
        physics_loss: Scalar physics-residual loss.
        data_loss: Sequence whose items 0 and 1 are the node-3 and node-4
            y-velocity data losses.
        **kw: Remaining callback arguments; ignored.
    """
    if epoch % 1000 == 0:
        node3_loss = data_loss[0]
        node4_loss = data_loss[1]
        summary = (
            "Epoch %d: Physics Loss %.4g, Node3YVel Loss %.4g, Node4YVel Loss %.4g, E = %.4g"
            % (epoch, physics_loss, node3_loss, node4_loss, model.E())
        )
        print(summary)

# Output directory for the per-epoch figures written by the plotter callback.
# `exist_ok=True` makes this idempotent on re-runs. It also avoids the race
# between checking with `os.path.exists` and creating with `os.mkdir`.
os.makedirs("plots", exist_ok=True)

def plotter(epoch, model, data_t, u_pred_t, **kw):
    """Training callback: save a figure of the network predictions.

    Runs every ``plot_every`` epochs. One subplot row is drawn per output column
    of ``u_pred_t``. Columns 1 and 3 are overlaid with ``model.node3_vel_y`` and
    ``model.node4_vel_y`` as "Data" points. The figure is written to
    ``plots/<epoch>.png``.

    Args:
        epoch: Current training epoch.
        model: The model being trained. ``E()``, ``node3_vel_y`` and
            ``node4_vel_y`` are read from it.
        data_t: Sample times, used as the x-axis.
        u_pred_t: Predictions at ``data_t``, indexed as ``[:, column]``
            (assumed 2-D with one column per network output; TODO confirm).
        **kw: Remaining callback arguments; ignored.
    """
    # `plot_every` is the module-level CLI setting. `int()` also accepts the
    # string argparse returns for an untyped option. None or <= 0 disables
    # plotting rather than raising inside the training loop.
    every = int(plot_every) if plot_every is not None else 0
    if every <= 0 or epoch % every != 0:
        return

    E = model.E()
    t = data_t.detach().cpu()
    v_pred = u_pred_t.detach().cpu()

    # Output columns that have measured data to overlay, mapped to the model
    # attribute holding that data.
    measured = {1: "node3_vel_y", 3: "node4_vel_y"}

    n_outputs = v_pred.shape[1]
    # squeeze=False keeps a 2-D array of Axes even for a single output column.
    fig, axs = plt.subplots(
        n_outputs, 1, figsize=(6, 3 * n_outputs), sharex=True, squeeze=False
    )
    try:
        for col, axes in enumerate(axs[:, 0]):
            axes.plot(t, v_pred[:, col], label="Prediction")
            attr = measured.get(col)
            if attr is not None:
                axes.plot(
                    t,
                    getattr(model, attr).detach().cpu(),
                    marker=".",
                    markersize=3,
                    linestyle="None",
                    label="Data",
                )
            axes.set_ylabel("output %d" % col)
            # Without a legend the "Prediction"/"Data" labels are never shown.
            axes.legend()
        axs[-1, 0].set_xlabel("t")
        # NOTE(review): the title shows E * 1e8, while the epoch logger prints E
        # unscaled. Confirm which units are intended.
        fig.suptitle("Epoch = %d\nE = %.5g" % (epoch, E * 1e8))
        fig.savefig(os.path.join("plots", "%d.png" % epoch))
    finally:
        # Always release this figure so repeated callbacks don't accumulate
        # open figures, even if plotting raised.
        plt.close(fig)


# Define model
# Fix the RNG seeds so that weight initialisation, and any other random draws
# made by PINN, are reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

layers = [1] + 3 * [32] + [4]  # layer widths passed to PINN (assumed: 1 in, 3 x 32 hidden, 4 out)
sigmas = [1, 10, 50]  # passed to PINN unchanged; meaning is defined in model.py
model = PINN(layers, sigmas)
model.load_ops_data()

# Optimise the network weights together with `model.a`. Append `a` only if
# `model.parameters()` does not already yield it. PyTorch warns when a param
# group contains the same tensor twice, and that tensor may then be stepped
# twice per update.
params = list(model.parameters())
if not any(p is model.a for p in params):
    params.append(model.a)

model.compile(
    torch.optim.Adam(params, lr=1e-3),
    callbacks=[epoch_logger, plotter],
    # NOTE(review): presumably ordered [physics, node3 y-vel, node4 y-vel] to
    # match epoch_logger's output. In the recorded run, the physics loss logged
    # as 0 and E stayed at 0.256 for over 11000 epochs. With a 1e-12 physics
    # weight, E may receive almost no gradient; confirm this weighting is intended.
    loss_weights=[1e-12, 1e2, 1e2],
)
# Runs training; the callbacks above log and plot progress.
model.train()

Epoch 0: Physics Loss 0, Node3YVel Loss 1244, Node4YVel Loss 487.5, E = 0.256
Epoch 1000: Physics Loss 0, Node3YVel Loss 39.34, Node4YVel Loss 89.33, E = 0.256
Epoch 2000: Physics Loss 0, Node3YVel Loss 19.32, Node4YVel Loss 81.89, E = 0.256
Epoch 3000: Physics Loss 0, Node3YVel Loss 17.24, Node4YVel Loss 76.89, E = 0.256
Epoch 4000: Physics Loss 0, Node3YVel Loss 15.61, Node4YVel Loss 72.91, E = 0.256
Epoch 5000: Physics Loss 0, Node3YVel Loss 14.15, Node4YVel Loss 69.4, E = 0.256
Epoch 6000: Physics Loss 0, Node3YVel Loss 13.04, Node4YVel Loss 66.09, E = 0.256
Epoch 7000: Physics Loss 0, Node3YVel Loss 12.25, Node4YVel Loss 62.9, E = 0.256
Epoch 8000: Physics Loss 0, Node3YVel Loss 11.65, Node4YVel Loss 59.74, E = 0.256
Epoch 9000: Physics Loss 0, Node3YVel Loss 11.08, Node4YVel Loss 56.66, E = 0.256
Epoch 10000: Physics Loss 0, Node3YVel Loss 10.51, Node4YVel Loss 53.76, E = 0.256
Epoch 11000: Physics Loss 0, Node3YVel Loss 9.938, Node4YVel Loss 51.17, E = 0.256
Epoch 12000: Physics

KeyboardInterrupt: 