## Initialization

In [42]:
import sys
import os

# Add the parent directory to sys.path
THIS_DIR = os.path.abspath(os.path.join('.'))
sys.path.insert(0, THIS_DIR)
sys.path.append(os.path.abspath(os.path.join('..')))
import torch, argparse
import numpy as np
from nn_models import MLP
from hnn import HNN
from data import get_dataset, get_trajectory
from utils import L2_loss, rk4

import scipy.integrate

solve_ivp = scipy.integrate.solve_ivp
print(THIS_DIR)


/Users/ZongyuWu/hamiltonian-nn/experiment-spring


## Help Functions

In [43]:
# print squared loss at specific steps for comparison with HNN
print_every = 200
def print_results(results, print_every=200):
    for step in range(0, len(results["train_loss"]), print_every):
        print(
            "step {}, train_loss {:.4e}, test_loss {:.4e}".format(
                step,
                results["train_loss"][step],
                results["test_loss"][step],
            )
        )
    print('Final train loss {:.4e} +/- {:.4e}\nFinal test loss {:.4e} +/- {:.4e}'
        .format(results["train_loss"][-1], results["train_std"][-1],
                results["test_loss"][-1], results["test_std"][-1]))

def print_best(results):
    curr_min = 0

    for step in range(0, len(results["train_loss"])):
        if results["test_loss"][step] < results["test_loss"][curr_min]:
            curr_min = step
    print(
        "best test loss at step {}, train_loss {:.4e}, test_loss {:.4e}".format(
            curr_min,
            results["train_loss"][curr_min],
            results["test_loss"][curr_min],
        )
    )


In [44]:
def integrate_model(model, t_span, y0, **kwargs):
    
    def fun(t, np_x):
        x = torch.tensor( np_x, requires_grad=True, dtype=torch.float32).view(1,2)
        dx = model.time_derivative(x).data.numpy().reshape(-1)
        return dx

    return solve_ivp(fun=fun, t_span=t_span, y0=y0, **kwargs)

In [45]:
def get_args():
    parser = argparse.ArgumentParser(description=None)
    parser.add_argument(
        "--input_dim", default=2, type=int, help="dimensionality of input tensor"
    )
    parser.add_argument(
        "--hidden_dim", default=200, type=int, help="hidden dimension of mlp"
    )
    parser.add_argument("--learn_rate", default=1e-3, type=float, help="learning rate")
    parser.add_argument(
        "--nonlinearity", default="tanh", type=str, help="neural net nonlinearity"
    )
    parser.add_argument(
        "--total_iterations", default=10, type=int, help="number of active learning iterations"
    )
    parser.add_argument(
        "--epoch_per_iter", default=200, type=int, help="number of epochs per iteration"
    )
    parser.add_argument(
        "--sample_per_iter", default=25, type=int, help="number of samples generated per iteration"
    )
    parser.add_argument(
        "--print_every",
        default=200,
        type=int,
        help="number of gradient steps between prints",
    )
    parser.add_argument(
        "--name", default="spring", type=str, help="only one option right now"
    )
    parser.add_argument(
        "--baseline",
        dest="baseline",
        action="store_true",
        help="run baseline or experiment?",
    )
    parser.add_argument(
        "--use_rk4",
        dest="use_rk4",
        action="store_true",
        help="integrate derivative with RK4",
    )
    parser.add_argument(
        "--verbose", dest="verbose", action="store_true", help="verbose?"
    )
    parser.add_argument(
        "--kan", dest="kan", action="store_true", help="use kan instead of mlp?"
    )
    parser.add_argument(
        "--field_type",
        default="solenoidal",
        type=str,
        help="type of vector field to learn",
    )
    parser.add_argument("--seed", default=0, type=int, help="random seed")
    parser.add_argument(
        "--save_dir", default=THIS_DIR, type=str, help="where to save the trained model"
    )
    parser.set_defaults(feature=True)
    return parser.parse_args()


In [46]:
def train(args):
    # Set random seed
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # Initialize model and optimizer
    if args.verbose:
        print("Training baseline model:" if args.baseline else "Training HNN model:")
    output_dim = args.input_dim if args.baseline else 2
    nn_model = MLP(args.input_dim, args.hidden_dim, output_dim, args.nonlinearity)
    model = HNN(
        args.input_dim,
        differentiable_model=nn_model,
        field_type=args.field_type,
        baseline=args.baseline,
    )
    optim = torch.optim.Adam(model.parameters(), args.learn_rate, weight_decay=1e-4)

    # Generate initial dataset
    points_per_sample = 30 # Change if made any change to t_span or timescale. Adjust this to change the number of points per sample. 
    data = get_dataset(seed=args.seed, samples=args.sample_per_iter * 2) # 2x for train/test split
    x = torch.tensor(data["x"], requires_grad=True, dtype=torch.float32)
    test_x = torch.tensor(data["test_x"], requires_grad=True, dtype=torch.float32)
    dxdt = torch.Tensor(data["dx"])
    test_dxdt = torch.Tensor(data["test_dx"])
    
    # Active learning loop
    stats = {"train_loss": [], "test_loss": []}
    for iter in range(args.total_iterations):
        # Train model for epoch_per_iter epochs
        for epoch in range(args.epoch_per_iter + 1):
            # Train one step
            dxdt_hat = (
                model.rk4_time_derivative(x) if args.use_rk4 else model.time_derivative(x)
            )
            loss = L2_loss(dxdt, dxdt_hat)
            loss.backward()
            optim.step()
            optim.zero_grad()

            # run test data
            test_dxdt_hat = (
                model.rk4_time_derivative(test_x)
                if args.use_rk4
                else model.time_derivative(test_x)
            )
            test_loss = L2_loss(test_dxdt, test_dxdt_hat)

            # logging
            stats["train_loss"].append(loss.item())
            stats["test_loss"].append(test_loss.item())
            if args.verbose and epoch % args.print_every == 0:
                print(
                    "iter {} step {}, train_loss {:.4e}, test_loss {:.4e}".format(
                        iter, epoch, loss.item(), test_loss.item()
                    )
                )
                
        # Generate new train data, randomly sample inputs
        data = {}
        xs, dxs = [], []
        for s in range(args.sample_per_iter):
            q, p, dq, dp, t = get_trajectory(t_span=[0, 3], timescale=10, radius=None, y0=None, noise_std=0.1)
            xs.append(np.stack([q, p]).T)
            dxs.append(np.stack([dq, dp]).T)
        data["x"] = np.concatenate(xs)
        data["dx"] = np.concatenate(dxs).squeeze()
        
        # Merge new data with previous train set
        x = torch.cat([x, torch.tensor(data["x"], requires_grad=True, dtype=torch.float32)])
        dxdt = torch.cat([dxdt, torch.Tensor(data["dx"])])
        
        # From the initial point of each sample, integrate their path and get the hamiltonian
        t_span=[0,10]
        t_eval = np.linspace(t_span[0], t_span[1], 1000)
        kwargs = {'t_eval': t_eval, 'rtol': 1e-12}
        hamiltonians = []
        for i in range(0, len(x), points_per_sample):
            path = integrate_model(model, t_span, x[i].detach().numpy(), **kwargs)
            hamiltonian = model(torch.Tensor(path.y.T))[1].detach().numpy().squeeze()
            hamiltonians.append((i, hamiltonian))
        
        # Sort hamiltonians by std 
        hamiltonians.sort(key=lambda h: h[1].std().item(), reverse=True)

        # Select train set by top args.sample_per_iter hamiltonians
        x = torch.stack([x[i] for i, _ in hamiltonians[:args.sample_per_iter * points_per_sample]])
        dxdt = torch.stack([dxdt[i] for i, _ in hamiltonians[:args.sample_per_iter * points_per_sample]])

    # Get final train and test loss
    train_dxdt_hat = model.time_derivative(x)
    train_dist = (dxdt - train_dxdt_hat) ** 2
    test_dxdt_hat = model.time_derivative(test_x)
    test_dist = (test_dxdt - test_dxdt_hat) ** 2
    print(
        "Final train loss {:.4e} +/- {:.4e}\nFinal test loss {:.4e} +/- {:.4e}".format(
            train_dist.mean().item(),
            train_dist.std().item() / np.sqrt(train_dist.shape[0]),
            test_dist.mean().item(),
            test_dist.std().item() / np.sqrt(test_dist.shape[0]),
        )
    )

    return model, stats

## Model Training

In [47]:
args = get_args()
args.verbose = True
model, stats = train(args)

# save
os.makedirs(args.save_dir) if not os.path.exists(args.save_dir) else None
label = "-active-hnn" 
path = "{}/{}{}.tar".format(args.save_dir, args.name, label)
torch.save(model.state_dict(), path)


Training HNN model:
iter 0 step 0, train_loss 8.0827e-01, test_loss 7.7982e-01
iter 0 step 200, train_loss 3.6939e-02, test_loss 3.6360e-02
iter 1 step 0, train_loss 3.7899e-02, test_loss 5.8920e-02
iter 1 step 200, train_loss 2.9611e-02, test_loss 4.4670e-02
iter 2 step 0, train_loss 4.7311e-02, test_loss 6.0883e-02
iter 2 step 200, train_loss 2.0326e-02, test_loss 4.7045e-02
iter 3 step 0, train_loss 5.3151e-02, test_loss 8.1623e-02
iter 3 step 200, train_loss 1.8468e-02, test_loss 5.7951e-02
iter 4 step 0, train_loss 7.4321e-02, test_loss 5.6075e-02
iter 4 step 200, train_loss 2.4988e-02, test_loss 8.8482e-02
iter 5 step 0, train_loss 5.1097e-02, test_loss 8.6034e-02
iter 5 step 200, train_loss 1.6459e-02, test_loss 5.6507e-02
iter 6 step 0, train_loss 5.0103e-02, test_loss 5.4365e-02
iter 6 step 200, train_loss 1.9843e-02, test_loss 5.7485e-02
iter 7 step 0, train_loss 5.8281e-02, test_loss 5.6575e-02
iter 7 step 200, train_loss 1.5808e-02, test_loss 6.0593e-02
iter 8 step 0, train

In [48]:
args.baseline = True
model, stats = train(args)

# save
os.makedirs(args.save_dir) if not os.path.exists(args.save_dir) else None
label = "-active-baseline" 
path = "{}/{}{}.tar".format(args.save_dir, args.name, label)
torch.save(model.state_dict(), path)


Training baseline model:
iter 0 step 0, train_loss 8.6160e-01, test_loss 9.0597e-01
iter 0 step 200, train_loss 3.8308e-02, test_loss 3.8129e-02
iter 1 step 0, train_loss 4.0415e-02, test_loss 5.6570e-02
iter 1 step 200, train_loss 3.1882e-02, test_loss 4.0645e-02
iter 2 step 0, train_loss 4.0925e-02, test_loss 4.4863e-02
iter 2 step 200, train_loss 2.7554e-02, test_loss 3.9633e-02
iter 3 step 0, train_loss 4.4358e-02, test_loss 1.1742e-01
iter 3 step 200, train_loss 2.9178e-02, test_loss 4.5318e-02
iter 4 step 0, train_loss 6.1258e-02, test_loss 4.5581e-02
iter 4 step 200, train_loss 3.8507e-02, test_loss 4.5118e-02
iter 5 step 0, train_loss 3.3670e-02, test_loss 4.5419e-02
iter 5 step 200, train_loss 2.8148e-02, test_loss 3.9434e-02
iter 6 step 0, train_loss 3.3291e-02, test_loss 4.0455e-02
iter 6 step 200, train_loss 2.7397e-02, test_loss 4.0383e-02
iter 7 step 0, train_loss 4.3547e-02, test_loss 4.4896e-02
iter 7 step 200, train_loss 3.3889e-02, test_loss 3.8631e-02
iter 8 step 0, 