## Initialization

In [1]:
import torch, argparse
import numpy as np

import os, sys

# sys.path.append(os.path.abspath(os.path.join('..')))

# Absolute path of the current working directory, and of the directory
# that contains it.
THIS_DIR = os.path.abspath(os.curdir)
PARENT_DIR = os.path.split(THIS_DIR)[0]

# Module lookup order: THIS_DIR is tried first, PARENT_DIR only as a
# last-resort fallback at the end of the search path.
sys.path[:0] = [THIS_DIR]
sys.path.append(PARENT_DIR)

from nn_models import MLP
from hnn import HNN
from utils import L2_loss, rk4
from data import get_dataset

In [2]:
# Show the two directories that were added to the module search path,
# one per line.
print(THIS_DIR, PARENT_DIR, sep="\n")

/Users/ZongyuWu/hamiltonian-nn/experiment-spring
/Users/ZongyuWu/hamiltonian-nn


## Helper Functions

In [3]:
# Module-level logging interval (in gradient steps); kept so the losses
# can be compared with the HNN logs at the same steps.
print_every = 200


def print_results(results, print_every=200):
    """Print the recorded train/test loss at regularly spaced steps.

    Args:
        results: mapping with "train_loss" and "test_loss" sequences that
            are indexed by step.
        print_every: stride between printed steps; step 0 is always the
            first one printed when any loss was recorded.
    """
    template = "step {}, train_loss {:.4e}, test_loss {:.4e}"
    n_steps = len(results["train_loss"])
    for step in range(0, n_steps, print_every):
        train_loss = results["train_loss"][step]
        test_loss = results["test_loss"][step]
        print(template.format(step, train_loss, test_loss))
    # NOTE: no final "mean +/- std" summary is printed here; that would
    # need "train_std" / "test_std" entries in `results`.

def print_best(results):
    """Print the step with the lowest recorded test loss.

    Steps are taken from the length of results["train_loss"]; on ties the
    earliest step wins. The train loss recorded at that same step is
    printed alongside the test loss.

    Args:
        results: mapping with "train_loss" and "test_loss" sequences that
            are indexed by step.
    """
    # min() keeps the first of several equal minima, and default=0 leaves
    # the index at 0 when there are no steps to scan.
    best_step = min(
        range(len(results["train_loss"])),
        key=lambda step: results["test_loss"][step],
        default=0,
    )
    message = "best test loss at step {}, train_loss {:.4e}, test_loss {:.4e}"
    print(
        message.format(
            best_step,
            results["train_loss"][best_step],
            results["test_loss"][best_step],
        )
    )

In [4]:
def get_args(argv=None):
    """Build the experiment configuration from command-line style options.

    Inside a notebook the process's ``sys.argv`` belongs to the kernel
    launcher (it carries the kernel's connection-file option), not to this
    experiment. To keep those entries from aborting the parse or being
    prefix-matched onto an experiment option such as ``--field_type``:

    * option abbreviations are disabled (``allow_abbrev=False``), and
    * unrecognised arguments are ignored instead of raising.

    Args:
        argv: optional list of argument strings to parse. ``None`` (the
            default) parses ``sys.argv[1:]``; pass ``[]`` to get the pure
            defaults.

    Returns:
        argparse.Namespace holding one attribute per option defined below,
        plus ``feature=True``.
    """
    parser = argparse.ArgumentParser(description=None, allow_abbrev=False)
    parser.add_argument(
        "--input_dim", default=2, type=int, help="dimensionality of input tensor"
    )
    parser.add_argument(
        "--hidden_dim", default=200, type=int, help="hidden dimension of mlp"
    )
    parser.add_argument("--learn_rate", default=1e-3, type=float, help="learning rate")
    parser.add_argument(
        "--nonlinearity", default="tanh", type=str, help="neural net nonlinearity"
    )
    parser.add_argument(
        "--total_steps", default=2000, type=int, help="number of gradient steps"
    )
    parser.add_argument(
        "--print_every",
        default=200,
        type=int,
        help="number of gradient steps between prints",
    )
    parser.add_argument(
        "--name", default="spring", type=str, help="only one option right now"
    )
    parser.add_argument(
        "--baseline",
        dest="baseline",
        action="store_true",
        help="run baseline or experiment?",
    )
    parser.add_argument(
        "--use_rk4",
        dest="use_rk4",
        action="store_true",
        help="integrate derivative with RK4",
    )
    parser.add_argument(
        "--verbose", dest="verbose", action="store_true", help="verbose?"
    )
    parser.add_argument(
        "--kan", dest="kan", action="store_true", help="use kan instead of mlp?"
    )
    parser.add_argument(
        "--field_type",
        default="solenoidal",
        type=str,
        help="type of vector field to learn",
    )
    parser.add_argument("--seed", default=0, type=int, help="random seed")
    parser.add_argument(
        "--save_dir", default=THIS_DIR, type=str, help="where to save the trained model"
    )
    parser.set_defaults(feature=True)

    # parse_known_args returns (namespace, leftovers); the leftovers are the
    # arguments that do not belong to this experiment and are dropped.
    args, _ignored = parser.parse_known_args(argv)
    return args

In [5]:
def train(args):
    """Train one model (HNN or baseline) and record its loss curves.

    Reads from ``args``: seed, verbose, baseline, input_dim, hidden_dim,
    nonlinearity, field_type, learn_rate, use_rk4, total_steps and
    print_every.

    Each iteration logs the train loss computed *before* that iteration's
    optimizer update and the test loss computed *after* it. Once the loop
    ends, the mean squared error on both splits is printed together with
    the std of the element-wise squared errors divided by the square root
    of the first dimension's size.

    Returns:
        (model, stats): the trained model and a dict with per-step
        "train_loss" and "test_loss" lists of floats
        (``args.total_steps + 1`` entries each).
    """
    # Seed torch and numpy before the model and the dataset are created.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    if args.verbose:
        if args.baseline:
            print("Training baseline model:")
        else:
            print("Training HNN model:")

    # Model and optimizer. The baseline network has one output per input
    # dimension; otherwise the wrapped network has two outputs.
    if args.baseline:
        output_dim = args.input_dim
    else:
        output_dim = 2
    backbone = MLP(args.input_dim, args.hidden_dim, output_dim, args.nonlinearity)
    model = HNN(
        args.input_dim,
        differentiable_model=backbone,
        field_type=args.field_type,
        baseline=args.baseline,
    )
    optimizer = torch.optim.Adam(
        model.parameters(), args.learn_rate, weight_decay=1e-4
    )

    # Data: inputs are created with requires_grad=True, targets are plain
    # tensors.
    dataset = get_dataset(seed=args.seed)
    x = torch.tensor(dataset["x"], requires_grad=True, dtype=torch.float32)
    test_x = torch.tensor(dataset["test_x"], requires_grad=True, dtype=torch.float32)
    dxdt = torch.Tensor(dataset["dx"])
    test_dxdt = torch.Tensor(dataset["test_dx"])

    def predict(inputs):
        """Predicted derivative, RK4-based when args.use_rk4 is set."""
        if args.use_rk4:
            return model.rk4_time_derivative(inputs)
        return model.time_derivative(inputs)

    # Full-batch training loop.
    stats = {"train_loss": [], "test_loss": []}
    progress = "step {}, train_loss {:.4e}, test_loss {:.4e}"
    for step in range(args.total_steps + 1):
        # One gradient step on the training split.
        loss = L2_loss(dxdt, predict(x))
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Evaluate the freshly updated model on the test split.
        test_loss = L2_loss(test_dxdt, predict(test_x))

        # Bookkeeping.
        train_value = loss.item()
        test_value = test_loss.item()
        stats["train_loss"].append(train_value)
        stats["test_loss"].append(test_value)
        if args.verbose and step % args.print_every == 0:
            print(progress.format(step, train_value, test_value))

    # Final summary; always uses time_derivative, whatever args.use_rk4 is.
    train_sq_err = (dxdt - model.time_derivative(x)) ** 2
    test_sq_err = (test_dxdt - model.time_derivative(test_x)) ** 2
    summary = "Final train loss {:.4e} +/- {:.4e}\nFinal test loss {:.4e} +/- {:.4e}"
    print(
        summary.format(
            train_sq_err.mean().item(),
            train_sq_err.std().item() / np.sqrt(train_sq_err.shape[0]),
            test_sq_err.mean().item(),
            test_sq_err.std().item() / np.sqrt(test_sq_err.shape[0]),
        )
    )

    return model, stats

## Model Training

In [6]:
# Train with the default configuration, with progress printing switched on.
args = get_args()
args.verbose = True
model_hnn, stats_hnn = train(args)

# Save the weights as <save_dir>/<name>[-rk4]-hnn.tar (or ...-baseline.tar
# when args.baseline is set).
# exist_ok=True creates the directory only when it is missing, without a
# separate existence check that could race with another process.
os.makedirs(args.save_dir, exist_ok=True)
label = "-baseline" if args.baseline else "-hnn"
if args.use_rk4:
    label = "-rk4" + label
path = os.path.join(args.save_dir, "{}{}.tar".format(args.name, label))
torch.save(model_hnn.state_dict(), path)

Training HNN model:
step 0, train_loss 1.7496e-01, test_loss 1.8012e-01
step 200, train_loss 4.7808e-03, test_loss 5.0553e-03
step 400, train_loss 4.7097e-03, test_loss 4.9127e-03
step 600, train_loss 5.0039e-03, test_loss 4.9029e-03
step 800, train_loss 4.6973e-03, test_loss 4.8961e-03
step 1000, train_loss 4.6943e-03, test_loss 4.8916e-03
step 1200, train_loss 4.6927e-03, test_loss 4.8941e-03
step 1400, train_loss 4.6901e-03, test_loss 4.8879e-03
step 1600, train_loss 4.6899e-03, test_loss 4.8929e-03
step 1800, train_loss 4.6944e-03, test_loss 4.8899e-03
step 2000, train_loss 4.6875e-03, test_loss 4.8922e-03
Final train loss 4.6874e-03 +/- 8.5865e-05
Final test loss 4.8922e-03 +/- 9.3158e-05


In [7]:
# Print the losses of every single step (stride 1), then the step that
# reached the lowest test loss.
print_results(stats_hnn, print_every=1)
print_best(results=stats_hnn)

step 0, train_loss 1.7496e-01, test_loss 1.8012e-01
step 1, train_loss 1.7955e-01, test_loss 1.7560e-01
step 2, train_loss 1.7134e-01, test_loss 1.7801e-01
step 3, train_loss 1.6928e-01, test_loss 1.8047e-01
step 4, train_loss 1.6984e-01, test_loss 1.7642e-01
step 5, train_loss 1.6668e-01, test_loss 1.7016e-01
step 6, train_loss 1.6273e-01, test_loss 1.6590e-01
step 7, train_loss 1.6090e-01, test_loss 1.6370e-01
step 8, train_loss 1.6017e-01, test_loss 1.6169e-01
step 9, train_loss 1.5824e-01, test_loss 1.5964e-01
step 10, train_loss 1.5518e-01, test_loss 1.5845e-01
step 11, train_loss 1.5248e-01, test_loss 1.5809e-01
step 12, train_loss 1.5079e-01, test_loss 1.5712e-01
step 13, train_loss 1.4920e-01, test_loss 1.5440e-01
step 14, train_loss 1.4676e-01, test_loss 1.5042e-01
step 15, train_loss 1.4376e-01, test_loss 1.4644e-01
step 16, train_loss 1.4104e-01, test_loss 1.4310e-01
step 17, train_loss 1.3877e-01, test_loss 1.4011e-01
step 18, train_loss 1.3627e-01, test_loss 1.3711e-01
ste

In [8]:
# Reuse the same configuration, switched to the baseline model.
args.baseline = True
model_baseline, stats_baseline = train(args)

# Save the weights as <save_dir>/<name>[-rk4]-baseline.tar.
# exist_ok=True creates the directory only when it is missing, without a
# separate existence check that could race with another process.
os.makedirs(args.save_dir, exist_ok=True)
label = "-baseline" if args.baseline else "-hnn"
if args.use_rk4:
    label = "-rk4" + label
path = os.path.join(args.save_dir, "{}{}.tar".format(args.name, label))
torch.save(model_baseline.state_dict(), path)

Training baseline model:
step 0, train_loss 2.2651e-01, test_loss 3.3679e-01
step 200, train_loss 4.6939e-03, test_loss 4.8829e-03
step 400, train_loss 5.3013e-03, test_loss 6.0300e-03
step 600, train_loss 5.5114e-03, test_loss 6.2766e-03
step 800, train_loss 4.6927e-03, test_loss 4.8818e-03
step 1000, train_loss 4.6922e-03, test_loss 4.8815e-03
step 1200, train_loss 4.6921e-03, test_loss 4.8815e-03
step 1400, train_loss 4.6917e-03, test_loss 4.8813e-03
step 1600, train_loss 4.6916e-03, test_loss 4.8813e-03
step 1800, train_loss 4.6912e-03, test_loss 4.8811e-03
step 2000, train_loss 4.7084e-03, test_loss 5.0752e-03
Final train loss 4.9363e-03 +/- 9.0650e-05
Final test loss 5.0752e-03 +/- 9.7005e-05


In [9]:
# Print the losses of every single step (stride 1), then the step that
# reached the lowest test loss.
print_results(stats_baseline, print_every=1)
print_best(results=stats_baseline)

step 0, train_loss 2.2651e-01, test_loss 3.3679e-01
step 1, train_loss 3.0509e-01, test_loss 2.1577e-01
step 2, train_loss 1.9913e-01, test_loss 1.6612e-01
step 3, train_loss 1.6575e-01, test_loss 1.9338e-01
step 4, train_loss 2.0112e-01, test_loss 1.7193e-01
step 5, train_loss 1.7892e-01, test_loss 1.3108e-01
step 6, train_loss 1.3168e-01, test_loss 1.2522e-01
step 7, train_loss 1.1788e-01, test_loss 1.4264e-01
step 8, train_loss 1.3005e-01, test_loss 1.3804e-01
step 9, train_loss 1.2531e-01, test_loss 1.0760e-01
step 10, train_loss 9.8823e-02, test_loss 8.0798e-02
step 11, train_loss 7.7522e-02, test_loss 7.4396e-02
step 12, train_loss 7.5737e-02, test_loss 7.5963e-02
step 13, train_loss 7.9542e-02, test_loss 6.7957e-02
step 14, train_loss 7.1266e-02, test_loss 5.2820e-02
step 15, train_loss 5.4101e-02, test_loss 4.4409e-02
step 16, train_loss 4.2982e-02, test_loss 4.6593e-02
step 17, train_loss 4.2928e-02, test_loss 4.8724e-02
step 18, train_loss 4.4190e-02, test_loss 4.1640e-02
ste