In [12]:
import numpy as onp
import scipy.stats as st
from jax.config import config  
config.update('jax_enable_x64', True)
from models import Pol_Net, Policy_Net

import jax
import jax.numpy as jnp
from jax import random,jit,lax
import jax
from jax import ops
from jax.example_libraries import optimizers
from jax.lax import fori_loop
import jraph
import time

import jax_md
from jax_md import space, smap, energy, minimize, quantity, simulate
from typing import Any, NamedTuple,  Optional, Union
from matplotlib.colors import Normalize
from matplotlib.cm import ScalarMappable
from jax._src import prng
Array = Any
KeyArray = Union[Array, prng.PRNGKeyArray]
import timeit
import matplotlib
import matplotlib.pyplot as plt

#Debugging
from Systems import lennard_jones, SW_Silicon, CSH
import Systems
from Generic_system import Generic_system
from Utils import *
from Optimizers import *




def get_index(elem,arr):
    """Return the position of the first entry of ``arr`` that equals ``elem``.

    ``size=1`` gives the result a static shape so the lookup can be used under
    ``jit``/``vmap``; if nothing matches, the padded index 0 is returned.
    """
    hits=jnp.nonzero(elem==arr,size=1)
    first_axis_hits=hits[0]
    return first_axis_hits[0]

# Batched lookup: maps over the leading axis of the elements, sharing ``arr``.
get_indeces_fn=jax.jit(jax.vmap(get_index,in_axes=(0,None)))


def Batch_choose_topK_e_greedy(Mux_Muy,probs,key,B_sz,K=1,epsilon=1.0):#eps= 1.0 during training
    """Pick ``K`` distinct nodes from every graph of a batch.

    ``probs`` is split into ``B_sz`` consecutive chunks of
    ``N = Mux_Muy.shape[0] // B_sz`` entries (one chunk per graph) and each
    chunk is renormalised to sum to one.  For every graph one uniform number is
    drawn: if it is below ``epsilon`` the nodes are sampled without replacement
    according to the renormalised probabilities, otherwise they are sampled
    uniformly without replacement.

    Node *indices* are sampled directly.  Sampling probability values and
    looking their position up afterwards returns the first matching position
    for every tied value, which breaks the "without replacement" guarantee
    whenever two nodes share a probability.

    Args:
        Mux_Muy: per-node network output; only its leading dimension is used,
            to derive the number of nodes per graph.
        probs: per-node selection scores, flat over the batch.
        key: JAX PRNG key (must be concrete, the branch is taken in Python).
        B_sz: number of graphs in the batch.
        K: number of nodes to pick per graph (must not exceed ``N``).
        epsilon: probability of sampling from ``probs`` rather than uniformly.

    Returns:
        node_indeces: ``(B_sz, K)`` int32 NumPy array of within-graph indices.
        node_probs: ``(B_sz, K)`` array holding the renormalised probability of
            each chosen node.
    """
    N=Mux_Muy.shape[0]//B_sz
    node_indeces=onp.zeros((B_sz,K),dtype='int32')
    node_probs=jnp.zeros((B_sz,K))
    key1, key2 = random.split(key, 2)
    keys=random.split(key2,B_sz)
    # Fetch the branch-selection draws once instead of one device read per graph.
    sample_p = onp.asarray(random.uniform(key1,(B_sz,)))
    for p in range(B_sz):
        myprobs=probs[p*N:(p+1)*N].reshape(-1)
        myprobs=myprobs*(1/jnp.sum(myprobs))
        if(sample_p[p]<epsilon):
            # Follow the policy: weighted sampling without replacement.
            ind=random.choice(keys[p],N,shape=(K,),replace=False,p=myprobs)
        else:
            # Explore: uniform sampling without replacement.
            ind=random.choice(keys[p],N,shape=(K,),replace=False)
        node_indeces[p,:]=ind
        node_probs=node_probs.at[p,:].set(myprobs[ind])
    return node_indeces, node_probs

#Disp Function

def pdf_multivariate_gauss(x, mu, cov):
    """Unnormalised multivariate normal density ``exp(-0.5 (x-mu)^T cov^-1 (x-mu))``.

    The usual normalisation constant ``1/((2*pi)^(d/2) * det(cov)^(1/2))`` is
    deliberately left out, so the result lies in (0, 1] and equals 1 at
    ``x == mu``.  It is a scaled density, not a probability.

    Args:
        x: sample vector of length d.
        mu: mean vector of length d.
        cov: d x d covariance matrix.

    Returns:
        The unnormalised density evaluated at ``x``.
    """
    centred = x - mu
    precision = jnp.linalg.inv(cov)
    exponent = (-1/2) * (centred.T.dot(precision)).dot(centred)
    return jnp.exp(exponent)

# Maps the density over two leading axes of (x, mu, cov).
vmap_pdf_multivariate_gauss=jax.jit(jax.vmap(jax.vmap(pdf_multivariate_gauss)))

def Batch_pred_disp_vec(Mu,key,B_sz=1,std=0.01,spatial_dim=3):
    """Sample one displacement per node from an isotropic Gaussian centred on ``Mu``.

    The log-density is computed directly as ``-0.5*|d - Mu|^2/std^2`` (the
    unnormalised Gaussian with covariance ``std^2 * I``) instead of taking
    ``log(exp(.))``, so it stays finite even when the density itself
    underflows to zero.

    Args:
        Mu: means; the trailing axis must have length ``spatial_dim``.
        key: JAX PRNG key.
        B_sz: kept for backward compatibility; the batch shape is read from ``Mu``.
        std: standard deviation shared by all spatial components.
        spatial_dim: number of spatial components.

    Returns:
        ``(Pred_disp, log_probs, probs)``: the sampled displacements (same shape
        as ``Mu``), their unnormalised log-density and density (``Mu`` without
        its trailing axis).
    """
    mean = Mu
    # One broadcast covariance instead of B_sz*K copies assembled in a Python list.
    cov = jnp.broadcast_to(jnp.eye(spatial_dim)*(std**2),mean.shape[:-1]+(spatial_dim,spatial_dim))
    Pred_disp= jax.random.multivariate_normal(key,mean,cov)
    diff = Pred_disp-mean
    log_probs = (-1/2)*jnp.sum(diff**2,axis=-1)/(std**2)
    return Pred_disp, log_probs, jnp.exp(log_probs)
    #Does not perform the displacement of the nodes here; only predicts the displacement vector
          


#Discounted reward function
@jax.jit
def get_discounted_returns(Rewards,Y=0.9):
    """Discounted return-to-go for every step of every trajectory.

    For ``Rewards`` of shape (batch, steps) the result has the same shape and
    satisfies ``G[:, k] = Rewards[:, k] + Y * G[:, k+1]`` with ``G`` taken as
    zero past the last step.
    """
    n_steps=Rewards.shape[1]
    returns=jnp.zeros(Rewards.shape)
    running=onp.zeros((Rewards.shape[0],))
    # Walk backwards so each step can reuse the return of the step after it.
    for step in reversed(range(n_steps)):
        running=Rewards[:,step]+Y*running
        returns=returns.at[:,step].set(running)
    return returns
    
#Defining loss function
@jax.jit
def Traj_Loss_fn(*,log_probs, Returns):
    """Per-trajectory sum over steps (axis 1) of ``log_probs * Returns``.

    Both arguments are keyword-only; the result has one entry per row.
    """
    weighted=jnp.multiply(log_probs,Returns)
    return weighted.sum(axis=1)

#Grad_add function
@jax.jit
def add_grads(grad1,grad2):
    """Leaf-wise sum of two gradient pytrees with identical structure.

    Uses ``jax.tree_util.tree_map``; the top-level ``jax.tree_map`` alias is
    deprecated and no longer exists in recent JAX releases.
    """
    return jax.tree_util.tree_map(lambda x,y:x+y,grad1,grad2)

@jax.jit
def scalar_mult_grad(k,grad):
    """Multiply every leaf of the gradient pytree ``grad`` by the scalar ``k``.

    Uses ``jax.tree_util.tree_map``; the top-level ``jax.tree_map`` alias is
    deprecated and no longer exists in recent JAX releases.
    """
    return jax.tree_util.tree_map(lambda x:k*x,grad)

def print_log(log,is_plot=False,epoch_id=0,Batch_id=0):
    """Print the per-step rollout statistics stored in ``log``, one table per graph.

    ``log`` maps 'Max_Mu', 'Mean_Mu', 'Max_Disp', 'Mean_Disp', 'Total_prob',
    'Reward', 'd_PE' and 'PE' to (n_graphs, n_steps) arrays and 'States' to a
    list of states.  When ``is_plot`` is True every stored state is first
    plotted through the notebook-global ``Sys``.
    NOTE(review): the plotting branch reads ``log['Node_id']``, which the
    ``loss_fn`` in this notebook does not fill in - confirm before enabling it.
    """
    reward=log['Reward']
    n_graphs=reward.shape[0]
    n_steps=reward.shape[1]
    if(is_plot==True):
        for step,state in enumerate(log['States']):
            Sys.plot_batched(state,epoch_id=epoch_id,batch_id=Batch_id,step_id=step,node_ids=log['Node_id'][step],Edges=True,save=False)
    columns=('Max_Mu','Mean_Mu','Max_Disp','Mean_Disp','Total_prob','Reward','d_PE','PE')
    row_format="\t%8.5f  %8.5f  %8.5f  %8.5f  %8.5f  %8.5f  %8.5f  %8.5f"
    for graph in range(n_graphs):
        print("\n#GraphNo. ",graph+1)
        print("\nStep\tMax_Mu\tMean_Mu\t   Max_Disp  Mean_Disp\tLog_Total_prob\tReward\t d_PE\t  PE")
        for step in range(n_steps):
            print(step+1,row_format%tuple(log[name][graph][step] for name in columns))

        

In [5]:
#Defining model

#from JMDSystem import MDTuple, My_system
# Base PRNG key for the notebook.
key = random.PRNGKey(119)
# Project-level system wrapper; provides dataset creation, node displacement and plotting.
Sys=Generic_system()

# Build the batched 'CSH' dataset (3-D, N=152, 100 samples, batches of 3) together with
# the batched helper functions used by the later cells.
# Note: `shift_fn` appears twice in the unpacking target, so the later (11th) returned
# value is the one that stays bound to the name.
Train,Val,Test, shift_fn,Batch_pair_cutoffs,Batch_pair_sigma,Batch_Disp_Vec_fn,Batch_Node_energy_fn,Batch_Total_energy_fn,displacement_fn, shift_fn,Disp_Vec_fn =Sys.create_batched_States(random.PRNGKey(147),System='CSH',spatial_dimension=3,N=152, N_sample =100,Batch_size=3)    

[-1029908.49464028 -1029926.15007278 -1029971.64720984 -1030044.89145933
 -1030145.60295065 -1030273.13941423 -1030426.29824724 -1030603.29639151
 -1030802.33444609 -1031022.34919555 -1031262.429675   -1031520.87204498
 -1031795.7888478  -1032085.82260453 -1032389.51272812 -1032705.06057534
 -1033030.81348671 -1033365.35368776 -1033707.35922477 -1034055.61278503
 -1034409.00898816 -1034766.5437388  -1035127.31610485 -1035490.52969655
 -1035855.47073072 -1036221.5168519  -1036588.12343782 -1036954.80990915
 -1037321.16625166 -1037686.83758151 -1038051.52244854 -1038414.95822403
 -1038776.92697089 -1039137.25065287 -1039495.77567141 -1039852.38009444
 -1040206.95903613 -1040559.43489854 -1040909.74571237 -1041257.84492182
 -1041603.70145962 -1041947.28895437 -1042288.59742456 -1042627.62292729
 -1042964.36851668 -1043298.84314007 -1043631.05845305 -1043961.03688839
 -1044288.79757686 -1044614.36462925 -1044937.76659542 -1045259.02932403
 -1045578.18001264 -1045895.25131587 -1046210.27660

In [None]:
# Scratch cell mirroring the reference-energy computation inside loss_fn.
# NOTE(review): `Systate_temp`, `B_sz` and `len_ep` are not defined at notebook level in the
# cells shown before this one, and `k` below is not bound outside the comprehension, so this
# cell cannot run on a fresh kernel as written.
Test_R=[Systate_temp.R[k] for k in range(B_sz)]
    
Batch_Total_energy_fn(FIRE_desc(1e-2,len_ep,Test_R[k],Total_energy_fn,shift_fn)[1])
    

In [6]:
@jax.jit
def Total_energy_fn(R):
    """Total energy of a single configuration ``R``.

    Adds a leading batch axis of length one, evaluates the batched energy
    function and returns its only entry.
    """
    single_batch=R[None,:,:]
    return Batch_Total_energy_fn(single_batch)[0]

# Policy network with its hyper-parameters (their meaning is defined by models.Pol_Net).
model=Pol_Net(edge_emb_size=48
    ,node_emb_size=48
    ,fa_layers=2
    ,fb_layers=2
    ,fv_layers=2
    ,fe_layers=2
    ,MLP1_layers=1
    ,MLP2_layers=4
    ,spatial_dim=3
    ,sigma=50
    ,train=True
    ,message_passing_steps=1)
    
#Initializing model parameters
# Only key2 seeds the initialisation; key1 is not used by this cell.
key1, key2 =random.split(random.PRNGKey(147), 2)
# The parameters are initialised against the graph of the first training batch.
params = model.init(key2, Train[0].Graph)

In [9]:
#loss function
def loss_fn(params,Systate,key,spatial_dim=3,len_ep=10,Batch_id=1):
    """Roll the policy out for ``len_ep`` steps and return the trajectory loss.

    At every step the policy network is applied to the current batched graph,
    ``K_disp`` nodes per graph are selected, a Gaussian displacement is sampled
    around the clipped network output of every node, and only the selected
    nodes are moved through ``Sys.multi_disp_node``.  Rewards are measured
    against ``PE1[-1]``, the last energy of the ``FIRE_desc`` reference run
    started from the initial positions of each graph.

    Relies on notebook globals: ``model``, ``Sys``, ``shift_fn``,
    ``Batch_pair_cutoffs``, ``Batch_pair_sigma``, ``Batch_Disp_Vec_fn``,
    ``Batch_Node_energy_fn``, ``Batch_Total_energy_fn``, ``Total_energy_fn``,
    ``FIRE_desc``, ``Batch_choose_topK_e_greedy``, ``Batch_pred_disp_vec``,
    ``Traj_Loss_fn`` and ``get_discounted_returns``.

    Args:
        params: policy-network parameters (the differentiated argument).
        Systate: batched initial state; ``.N``, ``.R``, ``.Graph`` and ``.pe`` are read.
        key: JAX PRNG key, split once per step.
        spatial_dim: spatial dimensionality; used for the reshape of the network
            output and forwarded to the displacement sampler.
        len_ep: number of rollout steps.
        Batch_id: accepted for interface compatibility; not used.

    Returns:
        ``(loss, (final_state, log))`` where ``loss`` is the batch mean of the
        per-graph trajectory loss and ``log`` holds the per-step statistics
        plus the list of visited states.
    """
    K_disp=100            # nodes displaced per graph at every step
    disp_std=1e-6         # std of the Gaussian displacement noise
    log_prob_scale=10     # the summed log-probability is divided by this before it is logged/used
    B_sz=Systate.N.shape[0]
    Systate_temp=Systate
    apply_fn=model.apply
    log_length=len_ep
    log = {
        'Max_Mu': jnp.zeros((B_sz,log_length,)),
        'Max_Disp': jnp.zeros((B_sz,log_length,)),
        'Mean_Mu': jnp.zeros((B_sz,log_length,)),
        'Mean_Disp': jnp.zeros((B_sz,log_length,)),
        'Total_prob':jnp.zeros((B_sz,log_length,)),
        'Reward':jnp.zeros((B_sz,log_length,)),
        'd_PE':jnp.zeros((B_sz,log_length,)),
        'PE':jnp.zeros((B_sz,log_length,)),
        'States':[]}
    # Reference energies: one FIRE_desc run per graph from its initial positions.
    PE1=onp.zeros((len_ep,B_sz))
    for k in range(B_sz):
        PE1[:,k]=Batch_Total_energy_fn(FIRE_desc(1e-2,len_ep,Systate.R[k],Total_energy_fn,shift_fn)[1])
    # Row index that broadcasts against the (B_sz, K_disp) chosen node indices.
    batch_rows=jnp.arange(B_sz)[:,None]
    for i in range(len_ep):
        key, key1,key2 = random.split(key, 3)

        #1: Pass through Policy_net
        (Batch_G, Batch_node_probs, Batch_Mux_Muy),mutated_vars = apply_fn(params, Systate_temp.Graph,mutable=['batch_stats'])
        Batch_Mux_Muy=jnp.clip(Batch_Mux_Muy,-10,10)
        Batch_Mu=Batch_Mux_Muy.reshape(B_sz,-1,spatial_dim)

        #2: Choose nodes and displacements from the predicted distributions
        Batch_chosen_node_indeces,_=Batch_choose_topK_e_greedy(Batch_Mux_Muy,Batch_node_probs,key=key1,B_sz=B_sz,K=K_disp,epsilon=1.0)#eps= 1.0 during training
        # Mask that is 1 for every spatial component of a chosen node and 0 elsewhere.
        Node_mask=jnp.zeros(Batch_Mu.shape).at[batch_rows,Batch_chosen_node_indeces,:].set(1.0)
        Batch_Disp_vec, Batch_log_disp_prob,Batch_prob_disp= Batch_pred_disp_vec(Batch_Mu,key2,B_sz=B_sz,std=disp_std,spatial_dim=spatial_dim)
        # Only the chosen nodes contribute to the action log-probability.
        Log_Pi_a_given_s=jnp.sum(Batch_log_disp_prob*Node_mask[:,:,0],axis=1)

        #3: Displace the chosen nodes with the predicted displacement
        Systate_new,Batch_d_PE=Sys.multi_disp_node(Batch_Disp_vec*Node_mask,Systate_temp,shift_fn,Batch_pair_cutoffs,Batch_pair_sigma,Batch_Disp_Vec_fn,Batch_Node_energy_fn,Batch_Total_energy_fn)
        log['d_PE']=log['d_PE'].at[:,i].set(Batch_d_PE)
        log['PE']=log['PE'].at[:,i].set(Systate_temp.pe)
        Mu_magnitude=jnp.sum(Batch_Mux_Muy**2,axis=1).reshape((B_sz,-1,))
        Disp_magnitude=jnp.sum(Batch_Disp_vec**2,axis=2)

        log['Max_Mu']=log['Max_Mu'].at[:,i].set(jnp.sqrt(jnp.max(Mu_magnitude,axis=1)))
        log['Mean_Mu']=log['Mean_Mu'].at[:,i].set(jnp.sqrt(jnp.mean(Mu_magnitude,axis=1)))

        log['Max_Disp']=log['Max_Disp'].at[:,i].set(jnp.sqrt(jnp.max(Disp_magnitude,axis=1)))
        log['Mean_Disp']=log['Mean_Disp'].at[:,i].set(jnp.sqrt(jnp.mean(Disp_magnitude,axis=1)))

        log['Total_prob']=log['Total_prob'].at[:,i].set(Log_Pi_a_given_s/log_prob_scale)
        # Reward grows as the new energy drops below the reference minimisation result.
        PE_gap=Systate_new.pe-PE1[-1,:]
        log['Reward']=log['Reward'].at[:,i].set(-1*PE_gap+jnp.exp(-(0.05)*PE_gap))
        log['States']+=[Systate_temp]
        #Update current state
        Systate_temp=Systate_new
    loss_batch=Traj_Loss_fn(log_probs=log['Total_prob'],Returns=get_discounted_returns(log['Reward']))  #Shape: (B_sz,)
    #Mean of the per-graph losses
    loss=jnp.sum(loss_batch)/B_sz
    return loss, (Systate_temp,log)    #Returns updated graph


In [10]:
    
#Defining optimizer
import optax
import flax
from flax.training import train_state
from flax import serialization
from flax.training import checkpoints as ckp
import matplotlib.pyplot as plt

#config.update('jax_disable_jit', True)
#config.update("jax_debug_nans", True)

#schedule = optax.warmup_cosine_decay_schedule(
#  init_value=1e-8,
#  peak_value=0.001,
#   warmup_steps=50,
#   decay_steps=500,
#   end_value=0.0,
# )

# Optimiser: element-wise clipping of the gradients to [-0.5, 0.5], followed by Adam.
tx = optax.chain(optax.clip(0.5), optax.adam(learning_rate=0.005))
opt_state = tx.init(params)

# loss_fn returns (loss, aux), hence has_aux=True.
loss_grad_fn = jax.value_and_grad(loss_fn, has_aux=True)

# Initial loss and gradients on the first training batch; init_grad also serves
# later as the template for the zeroed gradient accumulator.
(loss, (Systate_init, log_init)), init_grad = loss_grad_fn(params, Train[0], random.PRNGKey(147))
print_log(log_init)
print("Initial Loss",loss)



2023-01-16 10:48:15.020032: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /home/civil/btech/ce1180169/anaconda3/envs/JaxEqv_Graph/lib
2023-01-16 10:48:15.020637: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /home/civil/btech/ce1180169/anaconda3/envs/JaxEqv_Graph/lib



#GraphNo.  1

Step	Max_Mu	Mean_Mu	   Max_Disp  Mean_Disp	Log_Total_prob	Reward	 d_PE	  PE
1 	 0.07082   0.06177   0.07082   0.06177  -15.00247  -17338.77741  -122.73729  -1051282.42417
2 	 0.07082   0.06173   0.07082   0.06173  -15.72501  -17226.46423  -112.31317  -1051405.16146
3 	 0.07082   0.06165   0.07082   0.06165  -14.42475  -17285.71634  59.25211  -1051517.47463
4 	 0.07082   0.06164   0.07082   0.06164  -16.34712  -17444.24633  158.52999  -1051458.22253
5 	 0.07082   0.06152   0.07082   0.06152  -13.82348  -17774.33139  330.08506  -1051299.69253

#GraphNo.  2

Step	Max_Mu	Mean_Mu	   Max_Disp  Mean_Disp	Log_Total_prob	Reward	 d_PE	  PE
1 	 0.07082   0.06195   0.07082   0.06195  -13.82765  -25327.09366  -132.84095  -1032085.82260
2 	 0.07082   0.06189   0.07082   0.06189  -12.54239  -25348.47027  21.37660  -1032218.66356
3 	 0.07082   0.06191   0.07082   0.06191  -14.63570  -25368.07904  19.60878  -1032197.28695
4 	 0.07082   0.06181   0.07082   0.06181  -14.75977  -25480.31708

In [13]:
#Defining optimizer
import optax
import flax
from flax.training import train_state
from flax import serialization
from flax.training import checkpoints as ckp
import matplotlib.pyplot as plt

#config.update('jax_disable_jit', True)
#config.update("jax_debug_nans", True)

#schedule = optax.warmup_cosine_decay_schedule(
#  init_value=1e-8,
#  peak_value=0.001,
#   warmup_steps=50,
#   decay_steps=500,
#   end_value=0.0,
# )

# NOTE(review): this repeats the optimiser setup and initial-loss evaluation of the
# previous optimiser cell; running it again resets opt_state.
clip_then_adam = (optax.clip(0.5), optax.adam(learning_rate=0.005))
tx = optax.chain(*clip_then_adam)
del clip_then_adam
opt_state = tx.init(params)
loss_grad_fn = jax.value_and_grad(loss_fn, has_aux=True)

#Initial loss and gradients computation
initial_key = random.PRNGKey(147)
(loss, (Systate_init, log_init)), init_grad = loss_grad_fn(params, Train[0], initial_key)
del initial_key
print_log(log_init)
print("Initial Loss",loss)



#Training and Validation

# Training configuration.
Train_loss_data   =  []
Train_loss_epochs =  []
Val_loss_data     =  []
Val_loss_epochs   =  []
Batch_size        =  2
len_ep            =  15
Max_dataset_size  =  300
Print_freq        =  1
plot_freq         =  100
Model_save_freq   =  10
N_epoch           =  800
Val_freq          =  20
Reset_dataset_every= 100
keys1=random.split(random.PRNGKey(117721817),N_epoch)
# NOTE(review): `Traj` is read below but is not initialised in any cell shown before this
# one - confirm it is created elsewhere before running on a fresh kernel.
for i in range(N_epoch):
    if((i+1)%Reset_dataset_every==0):
        print("Re creating dataset with new states")
        print("Curr_Traj_shape:",Traj.shape)
        #Sampling the accumulated trajectories
        Sampling_Indeces=jax.random.permutation(random.PRNGKey(961*(i+1)), jnp.arange(0,Traj.shape[0],1))
        Traj=Traj[Sampling_Indeces[:min(Max_dataset_size,Traj.shape[0])],...]
        # NOTE(review): the dataset is rebuilt with System='LJ', N=100 although the initial
        # dataset cell uses System='CSH', N=152 - confirm this is intended.
        Train,Val,Test, shift_fn,Batch_pair_cutoffs,Batch_pair_sigma,Batch_Disp_Vec_fn,Batch_Node_energy_fn,Batch_Total_energy_fn,displacement_fn, shift_fn,Disp_Vec_fn =Sys.create_batched_States(random.PRNGKey(147),System='LJ',spatial_dimension=3,N=100, N_sample =100,Batch_size=4,Traj=Traj)
    # Zeroed accumulator with the same tree structure as the gradients.
    grads_acc=scalar_mult_grad(0.0,init_grad)
    keys2=random.split(keys1[i],Batch_size)

    Batch_loss=0
    for p in range(Batch_size):
        train_batch=Train[(Batch_size*i+p)%len(Train)]
        ((loss_val,(Systate,log)),grads) = loss_grad_fn(params,train_batch,keys2[p],len_ep=len_ep)
        print("Batch: ",p,' ',loss_val)
        print_log(log,is_plot=False,epoch_id=i,Batch_id=p)
        Batch_loss+=loss_val
        grads_acc=add_grads(grads_acc,grads)
        # Keep the final positions of rollouts that lowered the summed potential energy.
        if(i>0 and jnp.sum(train_batch.pe)>jnp.sum(Systate.pe)):
            Traj=jnp.concatenate([Traj,Systate.R])

    if((i+1)%Val_freq==0):
        Val_Batch_loss=0
        for p in range(Batch_size):
            val_batch=Val[(Batch_size*i+p)%len(Val)]
            ((val_loss_val,(val_Systate,val_log)),val_grads) = loss_grad_fn(params,val_batch,keys2[p],len_ep=20)
            print_log(val_log,is_plot=False,epoch_id=i,Batch_id=p)
            print("D_PE_sum:",jnp.sum(val_Systate.pe-val_batch.pe))
            Val_Batch_loss+=val_loss_val
        # Record one averaged validation loss per validation epoch, after all
        # validation batches have been accumulated.
        Val_loss_data+=[Val_Batch_loss/Batch_size]
        Val_loss_epochs+=[i]

    updates, opt_state = tx.update(scalar_mult_grad(1/Batch_size,grads_acc), opt_state,params)
    Train_loss_data+=[Batch_loss/Batch_size]
    Train_loss_epochs+=[i]
    params = optax.apply_updates(params, updates)
    if i % Print_freq == 0:
        if((i+1)%Val_freq==0):
            print('Val-Loss step {}: '.format(i), Val_Batch_loss/Batch_size)
        print('Loss step {}:{} '.format(i, Batch_loss/Batch_size))
    if(i%Model_save_freq==0):
        ckp.save_checkpoint("./checkpoints/",params,i,overwrite=True,keep=400)



#GraphNo.  1

Step	Max_Mu	Mean_Mu	   Max_Disp  Mean_Disp	Log_Total_prob	Reward	 d_PE	  PE
1 	 0.07082   0.06177   0.07082   0.06177  -15.00247  -17338.77741  -122.73729  -1051282.42417
2 	 0.07082   0.06173   0.07082   0.06173  -15.72501  -17226.46423  -112.31317  -1051405.16146
3 	 0.07082   0.06165   0.07082   0.06165  -14.42475  -17285.71634  59.25211  -1051517.47463
4 	 0.07082   0.06164   0.07082   0.06164  -16.34712  -17444.24633  158.52999  -1051458.22253
5 	 0.07082   0.06152   0.07082   0.06152  -13.82348  -17774.33139  330.08506  -1051299.69253

#GraphNo.  2

Step	Max_Mu	Mean_Mu	   Max_Disp  Mean_Disp	Log_Total_prob	Reward	 d_PE	  PE
1 	 0.07082   0.06195   0.07082   0.06195  -13.82765  -25327.09366  -132.84095  -1032085.82260
2 	 0.07082   0.06189   0.07082   0.06189  -12.54239  -25348.47027  21.37660  -1032218.66356
3 	 0.07082   0.06191   0.07082   0.06191  -14.63570  -25368.07904  19.60878  -1032197.28695
4 	 0.07082   0.06181   0.07082   0.06181  -14.75977  -25480.31708

10 	 0.07082   0.06150   0.07082   0.06150  -14.86562  -93298.96259  759.99255  -1031724.60938
11 	 0.07082   0.06110   0.07082   0.06110  -12.47190  -94118.86392  819.90133  -1030964.61683
12 	 0.07082   0.06162   0.07082   0.06162  -13.70016  -94943.49003  824.62611  -1030144.71550
13 	 0.07082   0.06136   0.07082   0.06136  -17.50707  -95953.52270  1010.03266  -1029320.08938
14 	 0.07082   0.06144   0.07082   0.06144  -14.63734  -97151.35415  1197.83145  -1028310.05672
15 	 0.07082   0.06127   0.07082   0.06127  -15.15121  -98168.64706  1017.29291  -1027112.22527

#GraphNo.  3

Step	Max_Mu	Mean_Mu	   Max_Disp  Mean_Disp	Log_Total_prob	Reward	 d_PE	  PE
1 	 0.07082   0.06140   0.07082   0.06140  -16.39374  -75634.16193  -126.20011  -1056429.22242
2 	 0.07082   0.06139   0.07082   0.06139  -14.39656  -75479.07686  -155.08507  -1056555.42253
3 	 0.07082   0.06147   0.07082   0.06147  -15.03966  -75482.62904   3.55218  -1056710.50759
4 	 0.07082   0.06157   0.07082   0.06157  -13.18098 

KeyboardInterrupt: 

In [None]:
# coding=utf-8
# Copyright 2021 The Learn2hop Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Re-creation of jax_md/nn.py to allow for vmap-enabled functions.

This file replicates jax_md/nn.py which contains neural network primatives, but
extends a number of the functions in order to allow them be vmapped. For
example, the current implementation of the radial basis functions in nn.py
uses a gather of species indexes which does not compose with vmaps. To get
this functionality, we reproduce the code using masks.
"""

import jax
import jax.numpy as jnp


@jax.jit
def _behler_parrinello_cutoff_fn(dr, cutoff_distance = 8.0):
    """Function of pairwise distance that smoothly goes to zero at the cutoff."""
    return jnp.where((dr < cutoff_distance) & (dr > 1e-7),0.5 * (jnp.cos(np.pi * dr / cutoff_distance) + 1), 0)


def radial_symmetry_functions(metric_mapped,
                              etas,
                              cutoff_distance,
                              num_species = 1):
    """Returns a function that computes radial symmetry functions.

        This is a re-implementation of the radial symmetry functions within
        nn.py but allows for vmapping by making use of input masking rather than
        array indexing.

        Args:
        metric_mapped: displacement function that computes distances in
          spatial positions between all pairs of atoms
        etas: [num_etas], list corresponding to strength of interaction terms
          (example values left by the author:
          [0.0009, 0.01, 0.02, 0.035, 0.06, 0.1, 0.2, 0.4])
        cutoff_distance: neighbors whose distance is larger than cutoff_distance do
          not contribute to each others symmetry functions. The contribution of a
          neighbor to the symmetry function and its derivative goes to zero at this
          distance.
        num_species: total number of species

        Returns:
        A function that computes the radial symmetry functions from inputs, yielding
        output of shape [num_atoms,num_etas* num_species] to maintain the type
        consistency from nn.py
        """
    def radial_fn(eta, dr):
        # Gaussian in the pair distance, damped smoothly to zero at the cutoff.
        cutoffs = _behler_parrinello_cutoff_fn(dr, cutoff_distance)
        return jnp.exp(-eta * dr**2) * cutoffs

    #@jax.jit
    def compute_fun(positions, species):
        def return_radial(atom_type):
            mask = species == atom_type
            dr = metric_mapped(positions, positions)
            # One radial map per eta: shape [num_etas, num_atoms, num_atoms].
            radial = jax.vmap(radial_fn, (0, None))(etas, dr)
            radial_masked = radial * mask.reshape([1, -1, 1])
            return jnp.sum(radial_masked, axis=-1)

        radial_vmap = jax.vmap(return_radial)
        # jnp.arange: this notebook imports NumPy as `onp`, so a bare `np` is not defined.
        radial_symmetry = radial_vmap(jnp.arange(num_species))
        # [num_species, num_etas, num_atoms] -> [num_atoms, num_species, num_etas]
        radial_symmetry = jnp.transpose(radial_symmetry, (2, 0, 1))
        return radial_symmetry.reshape([radial_symmetry.shape[0], -1])

    return compute_fun


In [None]:
# Plot the cutoff against the actual distance values rather than the sample index.
dr_grid=jnp.arange(0.001,10,0.01)
plt.plot(dr_grid,_behler_parrinello_cutoff_fn(dr_grid, cutoff_distance = 8.0))
plt.xlabel("dr")
plt.ylabel("cutoff value");

In [None]:
# Gaussian widths of the radial symmetry functions.
etas=jnp.array([0.001, 0.01 ,0.05, 0.1, 0.2, 0.4,1,2])
# Feature function for one species with a cutoff distance of 2.5.
# NOTE(review): `pair_dist_fn` is defined in a later cell; on a fresh kernel that cell
# has to be run before this one.
radial_feats_fn=radial_symmetry_functions(pair_dist_fn,etas,2.5,1)
# Map over the leading axis of both positions and species, then compile.
radial_feats_fn=jax.jit(jax.vmap(radial_feats_fn,(0,0)))

In [None]:
@jax.jit
def pair_dist_fn(R1,R2):
    """Distances obtained from the displacement vectors between ``R1`` and ``R2``.

    ``Disp_Vec_fn`` (notebook global) supplies the displacement vectors and
    ``space.distance`` reduces them to distances.
    """
    return space.distance(Disp_Vec_fn(R1, R2))

In [None]:
# Positions of the first training batch, used as test input for the radial features.
Test_R=Train[0].R[:]
# All-zero species array with the shape of the batch's species array
# (jnp.zeros default dtype, not the dtype of the original species).
Test_species=jnp.zeros(Train[0].species[:].shape)

In [None]:
Test_species

In [None]:
Test_R.shape

In [None]:
pair_dist_fn(Test_R,Test_R).shape

In [None]:
# NOTE(review): `radial_feats_fn` is already wrapped in jax.jit(jax.vmap(...)) by an earlier
# cell, so this applies a second vmap level over the next axis as well; confirm that this
# is intended for the shapes of Test_R / Test_species.
radial_feats=jax.jit(jax.vmap(radial_feats_fn,(0,0)))(Test_R,Test_species)

In [None]:
radial_feats.shape

In [None]:
radial_feats[:,0,:]

In [None]:
radial_feats[:10,1,:]

In [None]:
radial_feats[:10]

In [None]:
import seaborn as sns

In [None]:
# sns.distplot is deprecated; histplot with a KDE overlay and density scaling is its
# documented replacement.  The JAX array is converted to NumPy before plotting.
sns.histplot(onp.asarray(radial_feats[:,0,0]),kde=True,stat="density");

In [None]:
radial_feats

In [None]:

#Training and Validation

# NOTE(review): this cell is an older variant of the training loop that appears earlier in
# the notebook.  It differs in the dataset refresh (no random subsampling, 9-value unpack)
# and in the direction of the energy comparison - confirm which variant is current.
Train_loss_data   =  []
Train_loss_epochs =  []
Val_loss_data     =  []
Val_loss_epochs   =  []
Batch_size        =  3
len_ep            =  10
Max_dataset_size  =  1000
Print_freq        =  1
plot_freq         =  100
Model_save_freq   =  10
N_epoch           =  800
Val_freq          =  20
Reset_dataset_every= 1
keys1=random.split(random.PRNGKey(117721817),N_epoch)
# NOTE(review): `Traj` is read below but is not initialised in any cell shown before this
# one - confirm it is created elsewhere before running on a fresh kernel.
for i in range(N_epoch):
    if((i+1)%Reset_dataset_every==0):
        print("Re creating dataset with new states")
        print("Curr_Traj_shape:",Traj.shape)
        #Sampling the accumulated trajectories
        # NOTE(review): the permutation is computed but not applied; Traj is truncated in order.
        Sampling_Indeces=jax.random.permutation(random.PRNGKey(961*(i+1)), jnp.arange(0,Traj.shape[0],1))
        Traj=Traj[:min(Max_dataset_size,Traj.shape[0]),...]
        # NOTE(review): nine names are unpacked here, whereas the other calls to
        # create_batched_States in this notebook unpack twelve values.
        Train,Val,Test, shift_fn,Batch_pair_cutoffs,Batch_pair_sigma,Batch_Disp_Vec_fn,Batch_Node_energy_fn,Batch_Total_energy_fn=Sys.create_batched_States(random.PRNGKey(147),System='LJ',spatial_dimension=3,N=100, N_sample =Traj.shape[0],Batch_size=4,Traj=Traj)
    # Zeroed accumulator with the same tree structure as the gradients.
    grads_acc=scalar_mult_grad(0.0,init_grad)
    keys2=random.split(keys1[i],Batch_size)

    Batch_loss=0
    for p in range(Batch_size):
        train_batch=Train[(Batch_size*i+p)%len(Train)]
        ((loss_val,(Systate,log)),grads) = loss_grad_fn(params,train_batch,keys2[p],len_ep=len_ep)
        print("Batch: ",p,' ',loss_val)
        print_log(log,is_plot=False,epoch_id=i,Batch_id=p)
        Batch_loss+=loss_val
        grads_acc=add_grads(grads_acc,grads)
        # NOTE(review): this keeps rollouts whose summed energy went UP ('<'); the executed
        # training cell uses '>' for the same test.
        if(i>0 and jnp.sum(train_batch.pe)<jnp.sum(Systate.pe)):
            Traj=jnp.concatenate([Traj,Systate.R])

    if((i+1)%Val_freq==0):
        Val_Batch_loss=0
        for p in range(Batch_size):
            val_batch=Val[(Batch_size*i+p)%len(Val)]
            ((val_loss_val,(val_Systate,val_log)),val_grads) = loss_grad_fn(params,val_batch,keys2[p],len_ep=20)
            print_log(val_log,is_plot=False,epoch_id=i,Batch_id=p)
            print("D_PE_sum:",jnp.sum(val_Systate.pe-val_batch.pe))
            Val_Batch_loss+=val_loss_val
        # Record one averaged validation loss per validation epoch, after all
        # validation batches have been accumulated.
        Val_loss_data+=[Val_Batch_loss/Batch_size]
        Val_loss_epochs+=[i]

    updates, opt_state = tx.update(scalar_mult_grad(1/Batch_size,grads_acc), opt_state,params)
    Train_loss_data+=[Batch_loss/Batch_size]
    Train_loss_epochs+=[i]
    params = optax.apply_updates(params, updates)
    if i % Print_freq == 0:
        if((i+1)%Val_freq==0):
            print('Val-Loss step {}: '.format(i), Val_Batch_loss/Batch_size)
        print('Loss step {}:{} '.format(i, Batch_loss/Batch_size))
    if(i%Model_save_freq==0):
        ckp.save_checkpoint("./checkpoints/",params,i,overwrite=True,keep=400)




