In [1]:
import numpy as np
import torch
import random
import time
import argparse
import gc

import os
from datetime import datetime
from functools import partial, wraps
from shadow.plot import *
import matplotlib.pyplot as plt
plt.rcParams['font.family'] = 'sans-serif'
plt.rcParams['font.sans-serif'] = 'DejaVu Sans'

from graph_loader import graph_loader

import sys
MAINPATH = ".."  # nopep8
sys.path.append(MAINPATH)  # nopep8
from src import io_file
from src.HGNN import *
from src.models import *
# from src import models
# import importlib
# importlib.reload(models)

def namestr(obj, namespace):
    """Return the names in ``namespace`` that are bound to exactly ``obj``.

    Matching uses identity (``is``), not equality, so only bindings to the very
    same object are reported. Names come back in the namespace's iteration order.
    """
    matches = []
    for key, value in namespace.items():
        if value is obj:
            matches.append(key)
    return matches

def pprint(*args, namespace=globals()):
    """Print each argument as ``name: value``, using the variable name bound to it.

    Names are looked up by identity in ``namespace``. The default is the module
    globals dict; it is evaluated once at definition time, but ``globals()``
    returns the live dict, so later assignments are still visible.

    If a value is not bound to any name in ``namespace`` (e.g. ``pprint(x + 1)``),
    ``<unnamed>`` is printed as its label instead of raising IndexError.
    """
    for arg in args:
        names = namestr(arg, namespace)
        label = names[0] if names else "<unnamed>"
        print(f"{label}: {arg}")

In [2]:
# Run configuration read by the output-path helper `_filename` below.
datapoints = None   # appended to the timestamped run-directory name
rname = True        # truthy -> non-"data" outputs go to the timestamped directory
withdata = None     # optional suffix for the fixed "0" run directory

randfilename = "{}_{}".format(datetime.now().strftime("%m-%d-%Y_%H-%M-%S"), datapoints)

PSYS = "LJ3dab"        # system label used in output directory names
TAG = "HGNN"           # default tag used in output directory names
out_dir = "../results"  # root directory for all outputs

In [3]:
def _filename(name, tag=TAG):
    """Build an output path under ``out_dir`` and create its parent directory.

    Layout: ``{out_dir}/{PSYS}-{tag}/{rstring}/{name}``, where ``rstring`` is:
      - ``randfilename`` when ``rname`` is truthy and ``tag != "data"``;
      - ``"0"`` when ``tag == "data"`` or ``withdata`` is None;
      - ``"0_{withdata}"`` otherwise.

    ``tag`` defaults to the value of ``TAG`` at definition time. ``name`` may
    contain sub-directories; they are created too. Any run of repeated slashes
    in the result is collapsed to one. The path is printed and returned.
    """
    import re

    if rname and tag != "data":
        rstring = randfilename
    elif tag == "data" or withdata is None:
        rstring = "0"
    else:
        rstring = f"0_{withdata}"

    # Collapse "//", "///", ... (e.g. from a trailing slash in out_dir) in one pass.
    filename = re.sub(r"/{2,}", "/", f"{out_dir}/{PSYS}-{tag}/{rstring}/{name}")
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    print("===", filename, "===")
    return filename

def OUT(f):
    """Adapt ``f`` so that its first argument is a bare file name.

    The returned wrapper resolves ``file`` through ``_filename(file, tag=tag)``,
    which also creates the parent directory. It then calls ``f`` with the
    resolved path and every remaining argument.

    ``tag`` is keyword-only and defaults to the value of ``TAG`` at the time
    ``OUT`` is called. The wrapper carries ``f``'s metadata via ``functools.wraps``.
    """
    @wraps(f)
    def wrapper(file, *args, tag=TAG, **kwargs):
        resolved_path = _filename(file, tag=tag)
        return f(resolved_path, *args, **kwargs)

    return wrapper

# Path-resolving wrappers around the io_file helpers. Each takes a bare file
# name (plus an optional keyword-only `tag`) and calls the underlying io_file
# function with the full path built by _filename, which also creates the directory.
loadfile = OUT(io_file.loadfile)
savefile = OUT(io_file.savefile)
save_ovito = OUT(io_file.save_ovito)

In [4]:
def get_arg_parser():
    """Build the command-line parser holding this experiment's hyperparameters.

    Returns:
        argparse.ArgumentParser with directory, model-size and training options.
        ``--kT`` is parsed as a float so fractional values such as ``0.5`` are
        accepted. Its default stays the int ``1``; argparse only applies ``type``
        to string defaults.
    """
    parser = argparse.ArgumentParser(description='StriderNET Torch arguments')

    # Directories
    parser.add_argument('--out_dir', default='./Output/', help='Output directory')

    # Model args
    parser.add_argument('--node_emb_size', type=int, default=5, help='node embedding size')
    parser.add_argument('--edge_emb_size', type=int, default=5, help='edge embedding size')
    parser.add_argument('--hidden_emb_size', type=int, default=5, help='Hidden embedding size')

    parser.add_argument('--fa_layers', type=int, default=2, help='Initial node embedding MLP layers')
    parser.add_argument('--fb_layers', type=int, default=2, help='Initial edge embedding MLP layers')
    parser.add_argument('--fe_layers', type=int, default=2, help='Edge update MLP layers')
    parser.add_argument('--fv_layers', type=int, default=2, help='node update MLP layers')
    parser.add_argument('--MLP1_layers', type=int, default=2, help='MLP layers for node F from edge attribute')
    parser.add_argument('--MLP1_F_out_dim', type=int, default=3, help='Force dimensions')
    # parser.add_argument('--MLP2_layers', type=int, default=3,help='Displacement prediction MLP layers')
    # parser.add_argument('--sigma', type=float, default=2.0,help='Displacement scaling factor')
    parser.add_argument('--message_passing_steps', type=int, default=2, help='No. of message passing steps')
    # parser.add_argument('--alpha',type=float,default=1e-6,help='Multivariate Gaussian standard deviation hyperparameter')
    # parser.add_argument('--disp_cutoff', type=float, default=0.1,help='Cutoff for predicted displacement, for training stability')

    # Training args
    parser.add_argument('--epochs', type=int, default=10000, help='No. of training epochs')
    parser.add_argument('--lr', type=float, default=1e-4, help='Learning Rate')
    parser.add_argument('--b_sz', type=int, default=20, help='Training batch size')
    parser.add_argument('--seed', type=int, default=42, help='Seed value')
    parser.add_argument('--dt', type=float, default=1e-5, help='dt value')
    # float (was int): an int type rejected fractional values like `--kT 0.5`.
    parser.add_argument('--kT', type=float, default=1, help='kT value')
    parser.add_argument('--A', type=int, default=1, help='A value')
    parser.add_argument('--B', type=int, default=125, help='B value')
    # parser.add_argument('--train_len_ep', type=int, default=10,help='Length of optimization trajectory episode during training')
    # parser.add_argument('--val_len_ep', type=int, default=20,help='Length of optimization trajectory episode during validation')
    # parser.add_argument('--val_freq', type=int, default=10,help='Frequency of validation')
    parser.add_argument('--cuda', action='store_true', help='use CUDA')
    return parser



In [5]:
import torch.autograd as autograd

def get_zdot_lambda(N, Dim, hamiltonian, drag=None, constraints=None, external_force=None):
    """Build constrained Hamiltonian equations of motion in PyTorch.

    The phase-space state is z = [x, p] with ``2 * N * Dim`` entries. ``x`` is
    read from ``state_graph["position"]``; ``state_graph["velocity"]`` is treated
    as the conjugate momentum ``p``. With constraint matrix ``A`` and multipliers
    ``λ`` the result is::

        zdot = J ∇H + ∇D + [0, f_ext] + J Aᵀ λ,   J = [[0, I], [-I, 0]]

    Args:
        N, Dim: number of particles and spatial dimension; position and velocity
            are flattened to ``N * Dim`` entries each.
        hamiltonian: callable ``(state_graph, params) -> tensor``. The output is
            summed to a scalar before differentiation.
        drag: optional callable with the same signature. Its gradient enters zdot
            with a plus sign, as in the formula above. Defaults to the constant 0.
        constraints: optional callable ``(state_graph, params)`` returning the
            constraint Jacobian, reshaped to ``(-1, 2 * N * Dim)``. Defaults to
            one all-zero row, i.e. no constraint.
        external_force: optional callable returning a force shaped like
            ``state_graph["velocity"]``, added to pdot. Defaults to zeros.

    Returns:
        ``(fn_zdot, lambda_force)``. Both take ``(state_graph, params)`` and return
        a ``(2 * N, Dim)`` tensor: the full [xdot; pdot], or the constraint term
        ``J Aᵀ λ`` only.

    ``state_graph`` must support ``copy.copy`` and item assignment (e.g. a dict).
    The callables receive a shallow copy in which position/velocity are
    grad-enabled tensors; the caller's object is never modified.
    """
    import copy

    dim = N * Dim
    I = torch.eye(dim)
    # Canonical symplectic matrix; J @ J == -I.
    J = torch.zeros((2 * dim, 2 * dim))
    J[:dim, dim:] = I
    J[dim:, :dim] = -I
    # (An identity "J2 = blockdiag(I, I)" used to be multiplied in; it was a no-op.)

    def _grad_enabled(t):
        # Reuse tensors already in an autograd graph; otherwise start a fresh leaf.
        t = torch.as_tensor(t)
        return t if t.requires_grad else t.detach().requires_grad_(True)

    def _grad_z(fn, state_graph, params):
        """Return d fn / d [position, velocity] as one flat (2 * dim,) vector."""
        x = _grad_enabled(state_graph["position"])
        p = _grad_enabled(state_graph["velocity"])
        local_state = copy.copy(state_graph)
        local_state["position"] = x
        local_state["velocity"] = p
        value = fn(local_state, params)
        if torch.is_tensor(value) and value.requires_grad:
            # create_graph=True so zdot itself stays differentiable (e.g. for training).
            gx, gp = autograd.grad(value.sum(), (x, p), create_graph=True, allow_unused=True)
        else:
            # Value is constant w.r.t. the state (e.g. the default drag of 0.0).
            gx = gp = None
        gx = torch.zeros_like(x) if gx is None else gx
        gp = torch.zeros_like(p) if gp is None else gp
        return torch.cat([gx.flatten(), gp.flatten()])

    if drag is None:
        def drag(state_graph, params):
            return 0.0

    if external_force is None:
        def external_force(state_graph, params):
            return torch.zeros_like(state_graph["velocity"])

    if constraints is None:
        def constraints(state_graph, params):
            return torch.zeros((1, 2 * dim))

    def _solve(state_graph, params):
        """Shared core of fn_zdot / lambda_force: returns (J, Aᵀ, S, λ)."""
        dH = _grad_z(hamiltonian, state_graph, params)
        Jz = J.to(dH)  # match the state's device and dtype
        dD = -Jz @ _grad_z(drag, state_graph, params).to(dH)
        f_ext = torch.as_tensor(external_force(state_graph, params)).to(dH).flatten()
        F = -Jz @ torch.cat([dH.new_zeros(dim), f_ext])
        S = dH + dD + F
        A = torch.as_tensor(constraints(state_graph, params)).to(dH).reshape(-1, 2 * dim)
        A_T = A.t()
        # pinverse tolerates singular A J Aᵀ (e.g. the all-zero default constraint).
        INV = torch.pinverse(A @ Jz @ A_T)
        lam = -INV @ A @ Jz @ S
        return Jz, A_T, S, lam

    def fn_zdot(state_graph, params):
        Jz, A_T, S, lam = _solve(state_graph, params)
        zdot = Jz @ (S + A_T @ lam)
        return zdot.reshape(2 * N, Dim)

    def lambda_force(state_graph, params):
        Jz, A_T, _, lam = _solve(state_graph, params)
        return (Jz @ A_T @ lam).reshape(2 * N, Dim)

    return fn_zdot, lambda_force


def get_constraints(N, Dim, phi_, mass=None):
    """Build the Jacobian of the constraint vector psi(z) = [phi(x); phidot(x, p)].

    Args:
        N, Dim: number of particles and spatial dimension; ``x`` and ``p`` are
            ``(N, Dim)`` tensors.
        phi_: callable receiving ``x`` reshaped to ``(N, Dim)`` and returning the
            constraint values. Assumed to be a 1-D tensor of ``m`` constraints
            (TODO confirm for callers).
        mass: divisor applied to the flattened momenta; defaults to 1.0. It must
            broadcast against ``N * Dim`` entries, so a scalar is always safe.

    Returns:
        ``fn(x, p, params)``, which returns d psi / d z for
        ``z = vstack([x, p])``, with shape ``psi.shape + (2 * N, Dim)``.
        ``params`` is unused. Reshape the result to ``(-1, 2 * N * Dim)`` to get
        the constraint matrix ``A``.
    """
    if mass is None:
        mass = 1.0

    def phi(x):
        return phi_(x.reshape(N, Dim))

    def phidot(x, p):
        # d phi / d x over the flattened coordinates, as an (m, N*Dim) matrix.
        # create_graph=True keeps it differentiable w.r.t. x, which Dpsi needs.
        Dphi = autograd.functional.jacobian(phi, x.flatten(), create_graph=True)
        Dphi = Dphi.reshape(-1, N * Dim)
        pm = p.flatten() / mass
        return Dphi @ pm

    def psi(z):
        # z stacks x over p along dim 0; chunk (not split) yields the two halves.
        x, p = torch.chunk(z, 2)
        return torch.vstack([phi(x), phidot(x, p)])

    def Dpsi(z):
        return autograd.functional.jacobian(psi, z)

    def fn(x, p, params):
        z = torch.vstack([x, p])
        return Dpsi(z)

    return fn


In [6]:
# Seed every RNG before any stochastic step (batching, model init) so that
# Restart & Run All reproduces the same batches and initial weights.
SEED = 42  # same value as the --seed default in get_arg_parser
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

g_loader = graph_loader()
# Batch size matches the --b_sz default (20).
Train_loader = g_loader.create_batched_States(Batch_size=20)
# First batch; used below to size the model inputs and the integrator.
Init_batch = next(iter(Train_loader))

100%|██████████| 10000/10000 [00:28<00:00, 350.51it/s]


In [7]:
model = HGNN(
    # Input feature widths are read off the first training batch.
    in_edge_feats=Init_batch['edge_attr'].shape[1],
    in_node_feats=Init_batch['x'].shape[1],
    in_type_ohe_size=Init_batch['type'].shape[1],
    # Embedding sizes (node/edge differ from the get_arg_parser defaults of 5).
    node_emb_size=8,
    edge_emb_size=8,
    hidden_emb_size=5,
    # NOTE(review): presumably an extra kinetic-energy MLP; confirm in src.HGNN.
    use_ke_model=True,
    kemlp_layers=2,
    # Sub-network depths, mirroring the --*_layers options of get_arg_parser.
    fa_layers=2,
    fb_layers=2,
    fe_layers=2,
    fv_layers=2,
    MLP1_layers=2,
    message_passing_steps=2,
)

In [8]:
# Build the equations of motion with the HGNN model as the Hamiltonian.
# The last two axes of the batch positions are passed as (N, Dim).
# NOTE(review): get_zdot_lambda calls hamiltonian(state_graph, params); confirm
# HGNN's forward accepts that signature. If positions are batched/concatenated,
# shape[-2:] may not be (particles per system, Dim). Confirm as well.
zdot_model, lamda_force_model = get_zdot_lambda(
    *Init_batch['position'].shape[-2:], model)