## Initialization

In [1]:
import torch, argparse
import numpy as np

import os, sys

# sys.path.append(os.path.abspath(os.path.join('..')))

THIS_DIR = os.path.abspath(os.path.join('.'))
PARENT_DIR = os.path.dirname(THIS_DIR)
sys.path.append(PARENT_DIR)

from nn_models import MLP
from hnn import HNN
from utils import L2_loss, rk4
from data import get_dataset

print(THIS_DIR)
print(PARENT_DIR)

/Users/ZongyuWu/hamiltonian-nn/experiment-spring
/Users/ZongyuWu/hamiltonian-nn


## Help Functions

In [2]:
# print squared loss at specific steps for comparison with HNN
print_every = 200
def print_results(results, print_every=200):
    for step in range(0, len(results["train_loss"]), print_every):
        print(
            "step {}, train_loss {:.4e}, test_loss {:.4e}".format(
                step,
                results["train_loss"][step],
                results["test_loss"][step],
            )
        )
    # print('Final train loss {:.4e} +/- {:.4e}\nFinal test loss {:.4e} +/- {:.4e}'
    #     .format(results["train_loss"][-1], results["train_std"][-1],
    #             results["test_loss"][-1], results["test_std"][-1]))

def print_best(results):
    curr_min = 0

    for step in range(0, len(results["train_loss"])):
        if results["test_loss"][step] < results["test_loss"][curr_min]:
            curr_min = step
    print(
        "best test loss at step {}, train_loss {:.4e}, test_loss {:.4e}".format(
            curr_min,
            results["train_loss"][curr_min],
            results["test_loss"][curr_min],
        )
    )

In [3]:
def get_args():
    parser = argparse.ArgumentParser(description=None)
    parser.add_argument(
        "--input_dim", default=2, type=int, help="dimensionality of input tensor"
    )
    parser.add_argument(
        "--hidden_dim", default=200, type=int, help="hidden dimension of mlp"
    )
    parser.add_argument("--learn_rate", default=1e-3, type=float, help="learning rate")
    parser.add_argument(
        "--nonlinearity", default="tanh", type=str, help="neural net nonlinearity"
    )
    parser.add_argument(
        "--total_steps", default=2000, type=int, help="number of gradient steps"
    )
    parser.add_argument(
        "--print_every",
        default=200,
        type=int,
        help="number of gradient steps between prints",
    )
    parser.add_argument(
        "--name", default="spring", type=str, help="only one option right now"
    )
    parser.add_argument(
        "--baseline",
        dest="baseline",
        action="store_true",
        help="run baseline or experiment?",
    )
    parser.add_argument(
        "--use_rk4",
        dest="use_rk4",
        action="store_true",
        help="integrate derivative with RK4",
    )
    parser.add_argument(
        "--verbose", dest="verbose", action="store_true", help="verbose?"
    )
    parser.add_argument(
        "--kan", dest="kan", action="store_true", help="use kan instead of mlp?"
    )
    parser.add_argument(
        "--field_type",
        default="solenoidal",
        type=str,
        help="type of vector field to learn",
    )
    parser.add_argument("--seed", default=0, type=int, help="random seed")
    parser.add_argument(
        "--save_dir", default=THIS_DIR, type=str, help="where to save the trained model"
    )
    parser.set_defaults(feature=True)
    return parser.parse_args()

In [4]:
def train(args):
    # set random seed
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # init model and optimizer
    if args.verbose:
        print("Training baseline model:" if args.baseline else "Training HNN model:")

    output_dim = args.input_dim if args.baseline else 2
    nn_model = MLP(args.input_dim, args.hidden_dim, output_dim, args.nonlinearity)
    model = HNN(
        args.input_dim,
        differentiable_model=nn_model,
        field_type=args.field_type,
        baseline=args.baseline,
    )
    optim = torch.optim.Adam(model.parameters(), args.learn_rate, weight_decay=1e-4)

    # arrange data
    data = get_dataset(seed=args.seed)
    x = torch.tensor(data["x"], requires_grad=True, dtype=torch.float32)
    test_x = torch.tensor(data["test_x"], requires_grad=True, dtype=torch.float32)
    dxdt = torch.Tensor(data["dx"])
    test_dxdt = torch.Tensor(data["test_dx"])

    # vanilla train loop
    stats = {"train_loss": [], "test_loss": []}
    for step in range(args.total_steps + 1):

        # train step
        dxdt_hat = (
            model.rk4_time_derivative(x) if args.use_rk4 else model.time_derivative(x)
        )
        loss = L2_loss(dxdt, dxdt_hat)
        loss.backward()
        optim.step()
        optim.zero_grad()

        # run test data
        test_dxdt_hat = (
            model.rk4_time_derivative(test_x)
            if args.use_rk4
            else model.time_derivative(test_x)
        )
        test_loss = L2_loss(test_dxdt, test_dxdt_hat)

        # logging
        stats["train_loss"].append(loss.item())
        stats["test_loss"].append(test_loss.item())
        if args.verbose and step % args.print_every == 0:
            print(
                "step {}, train_loss {:.4e}, test_loss {:.4e}".format(
                    step, loss.item(), test_loss.item()
                )
            )

    train_dxdt_hat = model.time_derivative(x)
    train_dist = (dxdt - train_dxdt_hat) ** 2
    test_dxdt_hat = model.time_derivative(test_x)
    test_dist = (test_dxdt - test_dxdt_hat) ** 2
    print(
        "Final train loss {:.4e} +/- {:.4e}\nFinal test loss {:.4e} +/- {:.4e}".format(
            train_dist.mean().item(),
            train_dist.std().item() / np.sqrt(train_dist.shape[0]),
            test_dist.mean().item(),
            test_dist.std().item() / np.sqrt(test_dist.shape[0]),
        )
    )

    return model, stats

## Create Dataset

In [5]:
# create dataset
data = get_dataset(seed=0, mode=3, noise_std=0.05, keep_frequencies=10)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x = torch.tensor(data["x"], requires_grad=True, dtype=torch.float32).to(device)
test_x = torch.tensor(data["test_x"], requires_grad=True, dtype=torch.float32).to(device)
dxdt = torch.Tensor(data["dx"]).to(device)
test_dxdt = torch.Tensor(data["test_dx"]).to(device)

# dataset['train_input'], dataset['train_label'],dataset['test_input'], dataset['test_label']
dataset = {
    "train_input": x,
    "train_label": dxdt,
    "test_input": test_x,
    "test_label": test_dxdt,
}

In [6]:
dataset["train_input"].shape, dataset["train_label"].shape

(torch.Size([7500, 2]), torch.Size([50, 2, 300]))

## Model Training

In [7]:

args = get_args()
# args.total_steps = 2000
args.verbose = True
model_hnn, stats_hnn = train(args)

# save
os.makedirs(args.save_dir) if not os.path.exists(args.save_dir) else None
label = '-hnn'
path = "{}/{}{}.tar".format(args.save_dir, args.name, label)
torch.save(model_hnn.state_dict(), path)

Training HNN model:
step 0, train_loss 8.3700e-01, test_loss 8.4779e-01
step 200, train_loss 4.1400e-04, test_loss 4.9047e-04
step 400, train_loss 1.7449e-04, test_loss 2.0220e-04
step 600, train_loss 1.0028e-04, test_loss 1.1732e-04
step 800, train_loss 8.1844e-05, test_loss 9.4788e-05
step 1000, train_loss 8.5017e-05, test_loss 8.8129e-05
step 1200, train_loss 6.8045e-05, test_loss 7.9083e-05
step 1400, train_loss 6.1737e-05, test_loss 6.9726e-05
step 1600, train_loss 5.7677e-05, test_loss 6.6365e-05
step 1800, train_loss 5.4834e-05, test_loss 6.2784e-05
step 2000, train_loss 1.4473e-04, test_loss 2.3001e-04
Final train loss 2.2422e-04 +/- 2.9853e-06
Final test loss 2.3001e-04 +/- 3.1967e-06


In [8]:
print_results(stats_hnn, 1)
print_best(stats_hnn)

step 0, train_loss 8.3700e-01, test_loss 8.4779e-01
step 1, train_loss 8.2794e-01, test_loss 8.4367e-01
step 2, train_loss 8.2402e-01, test_loss 8.3460e-01
step 3, train_loss 8.1504e-01, test_loss 8.2635e-01
step 4, train_loss 8.0684e-01, test_loss 8.2051e-01
step 5, train_loss 8.0107e-01, test_loss 8.1471e-01
step 6, train_loss 7.9539e-01, test_loss 8.0737e-01
step 7, train_loss 7.8823e-01, test_loss 7.9926e-01
step 8, train_loss 7.8032e-01, test_loss 7.9160e-01
step 9, train_loss 7.7287e-01, test_loss 7.8451e-01
step 10, train_loss 7.6600e-01, test_loss 7.7711e-01
step 11, train_loss 7.5881e-01, test_loss 7.6874e-01
step 12, train_loss 7.5062e-01, test_loss 7.5950e-01
step 13, train_loss 7.4156e-01, test_loss 7.4987e-01
step 14, train_loss 7.3211e-01, test_loss 7.3994e-01
step 15, train_loss 7.2239e-01, test_loss 7.2925e-01
step 16, train_loss 7.1195e-01, test_loss 7.1728e-01
step 17, train_loss 7.0026e-01, test_loss 7.0394e-01
step 18, train_loss 6.8723e-01, test_loss 6.8946e-01
ste

In [9]:
args.baseline = True
model_baseline, stats_baseline = train(args)

# save
os.makedirs(args.save_dir) if not os.path.exists(args.save_dir) else None
label = '-baseline'
path = "{}/{}{}.tar".format(args.save_dir, args.name, label)
torch.save(model_baseline.state_dict(), path)

Training baseline model:
step 0, train_loss 8.9177e-01, test_loss 9.7633e-01
step 200, train_loss 1.8746e-03, test_loss 1.0905e-03
step 400, train_loss 5.7407e-05, test_loss 3.5991e-05
step 600, train_loss 3.7400e-05, test_loss 7.1743e-05
step 800, train_loss 1.7748e-03, test_loss 1.2458e-03
step 1000, train_loss 2.2464e-05, test_loss 2.6716e-05
step 1200, train_loss 2.2518e-05, test_loss 2.6828e-05
step 1400, train_loss 2.1751e-05, test_loss 2.5885e-05
step 1600, train_loss 4.3046e-05, test_loss 5.0439e-05
step 1800, train_loss 2.8079e-05, test_loss 3.6848e-05
step 2000, train_loss 2.2038e-05, test_loss 2.6113e-05
Final train loss 2.1999e-05 +/- 4.0849e-07
Final test loss 2.6113e-05 +/- 4.4631e-07


In [10]:
print_results(stats_baseline, 1)
print_best(stats_baseline)

step 0, train_loss 8.9177e-01, test_loss 9.7633e-01
step 1, train_loss 9.5587e-01, test_loss 8.2766e-01
step 2, train_loss 8.0879e-01, test_loss 7.8033e-01
step 3, train_loss 7.6315e-01, test_loss 7.7991e-01
step 4, train_loss 7.6415e-01, test_loss 7.1956e-01
step 5, train_loss 7.0481e-01, test_loss 6.4505e-01
step 6, train_loss 6.3109e-01, test_loss 6.0629e-01
step 7, train_loss 5.9309e-01, test_loss 5.8841e-01
step 8, train_loss 5.7607e-01, test_loss 5.5222e-01
step 9, train_loss 5.4087e-01, test_loss 4.9717e-01
step 10, train_loss 4.8690e-01, test_loss 4.4882e-01
step 11, train_loss 4.3965e-01, test_loss 4.1928e-01
step 12, train_loss 4.1118e-01, test_loss 3.9483e-01
step 13, train_loss 3.8769e-01, test_loss 3.5939e-01
step 14, train_loss 3.5307e-01, test_loss 3.1649e-01
step 15, train_loss 3.1087e-01, test_loss 2.8031e-01
step 16, train_loss 2.7532e-01, test_loss 2.5555e-01
step 17, train_loss 2.5118e-01, test_loss 2.3272e-01
step 18, train_loss 2.2897e-01, test_loss 2.0340e-01
ste