In [1]:
import qmctorch

INFO:QMCTorch|  ____    __  ______________             _
INFO:QMCTorch| / __ \  /  |/  / ___/_  __/__  ________/ /  
INFO:QMCTorch|/ /_/ / / /|_/ / /__  / / / _ \/ __/ __/ _ \ 
INFO:QMCTorch|\___\_\/_/  /_/\___/ /_/  \___/_/  \__/_//_/ 


# Graph Neural Networks

Graph neural networks (GNNs) have been studied extensively. See, for example, the Deep Graph Library (https://www.dgl.ai/) and its chemistry package, DGL-LifeSci (https://github.com/awslabs/dgl-lifesci).

In particular, the paper Molecular Property Prediction: A Multilevel Quantum Interactions Modeling Perspective (https://arxiv.org/abs/1906.11081), whose model is already implemented in DGL-LifeSci (https://github.com/awslabs/dgl-lifesci/blob/master/examples/README.md), offers an interesting way of extending the definition of Jastrow factors.

# GNN Jastrow Factors


Instead of defining the electron-electron Jastrow factor through the Pade Jastrow or a fully connected network, we can consider the different connection graphs of the system and use these graphs as inputs to a graph neural network.



We can first consider the connection graph between all the electrons. In this input graph, each node represents a given electron and an edge exists between every pair of electrons. The distance between two electrons can be used as an edge feature to encode their relative positions.
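The complete electron-electron graph described above can be sketched in plain Python. `build_ee_graph` is a hypothetical helper written for illustration, not QMCTorch's internal graph construction. Edges are stored as directed pairs in both directions, as GNN libraries usually do, and the inter-electronic distance is attached to each edge as its feature.

```python
import itertools
import math


def build_ee_graph(positions):
    """Complete electron-electron graph (hypothetical helper).

    positions: list of (x, y, z) tuples, one per electron.
    Returns (src, dst, dist): one directed edge i -> j for every ordered
    pair i != j, with the distance r_ij as the edge feature.
    """
    src, dst, dist = [], [], []
    for i, j in itertools.permutations(range(len(positions)), 2):
        src.append(i)
        dst.append(j)
        dist.append(math.dist(positions[i], positions[j]))
    return src, dst, dist


# 4 electrons, as in LiH with nup = ndown = 2 (arbitrary coordinates)
elec = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 2.0, 0.0), (0.0, 0.0, 3.0)]
src, dst, dist = build_ee_graph(elec)
print(len(src))  # 12 directed edges = 4 * 3
```

For N electrons the graph always has N(N-1) directed edges, so its size grows quadratically with the number of electrons.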

We can also consider the connection graph between the electrons and the nuclei. In this graph, each electron and each atom is represented by a node. Edges exist only between electrons and atoms, not between pairs of electrons (edges between nuclei can optionally be included).
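The electron-nucleus graph can be sketched the same way. `build_en_graph` is again a hypothetical helper, not QMCTorch's API. Electrons occupy node indices 0 to nelec-1 and nuclei follow them. The optional nucleus-nucleus edges are controlled by an illustrative `nuc_nuc_edges` flag.

```python
import math


def build_en_graph(elec_pos, atom_pos, nuc_nuc_edges=False):
    """Electron-nucleus graph (hypothetical helper).

    Nodes 0..nelec-1 are electrons, nodes nelec..nelec+natom-1 are nuclei.
    Every electron is linked to every nucleus (both directions), there are
    no electron-electron edges, and nucleus-nucleus edges are optional.
    """
    nelec = len(elec_pos)
    nodes = list(elec_pos) + list(atom_pos)
    src, dst, dist = [], [], []

    def add_pair(i, j):
        # store the undirected pair as two directed edges sharing r_ij
        d = math.dist(nodes[i], nodes[j])
        src.extend([i, j])
        dst.extend([j, i])
        dist.extend([d, d])

    for i in range(nelec):
        for a in range(len(atom_pos)):
            add_pair(i, nelec + a)
    if nuc_nuc_edges:
        for a in range(len(atom_pos)):
            for b in range(a + 1, len(atom_pos)):
                add_pair(nelec + a, nelec + b)
    return src, dst, dist


# LiH nuclei, matching atomic_pos in the JastrowFactorGraph example
atoms = [(0.0, 0.0, 0.0), (0.0, 0.0, 3.015)]
elec = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 2.0, 0.0), (0.0, 0.0, 3.0)]
src, dst, dist = build_en_graph(elec, atoms)
print(len(src))  # 4 electrons * 2 nuclei * 2 directions = 16
```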

Expressing the structure of the electron/nuclei system as a graph makes it possible to capture different interactions (elec-elec, elec-elec-elec, elec-nuclei, elec-elec-nuclei terms, etc.) in a very flexible way through convolutions over the graphs: each convolution layer lets a node aggregate information from its neighbors, so stacking layers brings in progressively higher-order terms (see https://arxiv.org/abs/1906.11081).
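To illustrate how stacking convolutions builds higher-order terms, here is a toy message-passing sketch in plain Python. It is not MGCN: the distance filter w(r) = exp(-r), the residual update and the sum readout are illustrative assumptions. On a chain 0-1-2, node 0 is insensitive to node 2 after one layer but depends on it after two, i.e. the second layer introduces a three-body (0, 1, 2) contribution.

```python
import math


def conv_layer(h, src, dst, dist, cutoff=5.0):
    """One message-passing step (toy model, not MGCN).

    Node i receives w(r_ij) * h_j over every incoming edge j -> i, with the
    assumed filter w(r) = exp(-r) truncated at `cutoff`. Messages are built
    from the old features h, so all nodes are updated synchronously.
    """
    new_h = list(h)
    for j, i, r in zip(src, dst, dist):
        if r < cutoff:
            new_h[i] += math.exp(-r) * h[j]
    return new_h


def propagate(h, src, dst, dist, n_layers):
    """Stack n_layers convolutions: each node then sees n_layers hops away."""
    for _ in range(n_layers):
        h = conv_layer(h, src, dst, dist)
    return h


def readout(h):
    """Sum pooling: a permutation-invariant scalar from the node features."""
    return sum(h)


# chain graph 0 - 1 - 2 (edges in both directions, all distances 1.0)
src = [0, 1, 1, 2]
dst = [1, 0, 2, 1]
dist = [1.0, 1.0, 1.0, 1.0]

h_a = [1.0, 1.0, 1.0]
h_b = [1.0, 1.0, 5.0]  # only node 2 differs

# one layer: node 0 has only seen node 1, so the two cases agree
print(propagate(h_a, src, dst, dist, 1)[0] == propagate(h_b, src, dst, dist, 1)[0])  # True
# two layers: node 0 now depends on node 2 through node 1
print(propagate(h_a, src, dst, dist, 2)[0] == propagate(h_b, src, dst, dist, 2)[0])  # False
```

A scalar such as `readout(propagate(...))` could then be mapped to a Jastrow factor; how the real GNN Jastrow performs that final mapping is not specified here.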

The `JastrowFactorGraph` class orchestrates the calculation of such Jastrow factors and accepts different graph neural networks for the elec-elec graph (`ee_model`) and the elec-nuc graph (`en_model`).

In [5]:
import torch
from qmctorch.wavefunction.jastrows.graph.jastrow_graph import JastrowFactorGraph
from qmctorch.wavefunction.jastrows.graph.mgcn.mgcn_predictor import MGCNPredictor

nup, ndown = 2, 2
nelec = nup + ndown

atom_types = ["Li", "H"]
atomic_pos = torch.tensor([[0., 0., 0.],
                           [0., 0., 3.015]])


jastrow = JastrowFactorGraph(nup, ndown,
                             atomic_pos,
                             atom_types,
                             ee_model=MGCNPredictor,
                             ee_model_kwargs={'n_layers': 3,
                                              'feats': 32,
                                              'cutoff': 5.0,
                                              'gap': 1.0},
                             en_model=MGCNPredictor,
                             en_model_kwargs={'n_layers': 3,
                                              'feats': 32,
                                              'cutoff': 5.0,
                                              'gap': 1.0})
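The `cutoff` and `gap` arguments typically control how each interparticle distance is expanded into a feature vector before convolution. The sketch below assumes QMCTorch's MGCN follows the DGL-LifeSci `RBFExpansion` convention: ceil(cutoff / gap) Gaussian centres evenly spaced on [0, cutoff], with width parameter gamma = 1 / gap. This convention is an assumption, and `rbf_expand` is a hypothetical stand-alone function, not part of either library.

```python
import math


def rbf_expand(r, cutoff=5.0, gap=1.0, low=0.0):
    """Gaussian radial-basis expansion of a distance r (hypothetical helper).

    Assumed convention (as in DGL-LifeSci's RBFExpansion):
    n_centers = ceil((cutoff - low) / gap) centres evenly spaced on
    [low, cutoff] (endpoints included) and gamma = 1 / gap.
    """
    n_centers = int(math.ceil((cutoff - low) / gap))
    step = (cutoff - low) / (n_centers - 1) if n_centers > 1 else 0.0
    centers = [low + k * step for k in range(n_centers)]
    gamma = 1.0 / gap
    return [math.exp(-gamma * (r - c) ** 2) for c in centers]


# with the settings used above (cutoff=5.0, gap=1.0): 5 centres at
# 0, 1.25, 2.5, 3.75, 5.0
feats = rbf_expand(2.5)
print(len(feats))  # 5
```

A smaller `gap` produces more, narrower Gaussians, i.e. a finer but higher-dimensional description of each distance.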


# SlaterJastrow wave function with MGCN

The `SlaterJastrowGraph` class makes it possible to use GNN Jastrow factors in a Slater-Jastrow wave function.
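For reference, a multi-determinant Slater-Jastrow wave function has the standard form below, with the Jastrow factor J(R) multiplying a sum of spin-up and spin-down determinant products. Writing J as the exponential of the summed outputs of the elec-elec and elec-nuc networks is an assumption about how `JastrowFactorGraph` combines them, not something this notebook states.

```latex
\Psi(R) = J(R) \sum_{n=1}^{N_{\mathrm{conf}}} c_n \,
          \det A_n^{\uparrow}\!\left(r_1,\dots,r_{N_\uparrow}\right)\,
          \det A_n^{\downarrow}\!\left(r_{N_\uparrow+1},\dots,r_{N}\right),
\qquad
J(R) = \exp\!\big( f_{ee}(\mathcal{G}_{ee}) + f_{en}(\mathcal{G}_{en}) \big)
```

With `configs='single_double(2,2)'`, the log of the cell below reports N_conf = 4.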

In [10]:
import torch
from qmctorch.scf import Molecule
from qmctorch.wavefunction import SlaterJastrowGraph
from qmctorch.wavefunction.jastrows.graph.mgcn.mgcn_predictor import MGCNPredictor


mol = Molecule(
    atom='Li 0 0 0; H 0 0 3.14',
    unit='bohr',
    calculator='pyscf',
    basis='sto-3g',
    redo_scf=True)

wf = SlaterJastrowGraph(mol,
                        kinetic='auto',
                        include_all_mo=False,
                        configs='single_double(2,2)',
                        ee_model=MGCNPredictor,
                        ee_model_kwargs={'n_layers': 3,
                                         'feats': 32,
                                         'cutoff': 5.0,
                                         'gap': 1.0},
                        en_model=MGCNPredictor,
                        en_model_kwargs={'n_layers': 3,
                                         'feats': 32,
                                         'cutoff': 5.0,
                                         'gap': 1.0})

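# random walker positions: one row of nelec*3 coordinates per walker
nbatch = 10
pos = torch.rand(nbatch, wf.nelec * 3)

# evaluate the wave function for every walker in the batch
wfval = wf(pos)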
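`pos` stores each walker as one flat row of `wf.nelec*3` coordinates. Assuming the ordering x1, y1, z1, x2, ... (the usual QMCTorch layout, stated here as an assumption), a graph Jastrow has to regroup each row into per-electron 3D points before it can build the graphs. The plain-Python sketch below does this with a hypothetical `split_walker` helper and computes the smallest electron-electron distance per walker.

```python
import math
import random


def split_walker(flat_row, nelec, ndim=3):
    """Regroup a flat walker [x1, y1, z1, x2, ...] into per-electron points."""
    if len(flat_row) != nelec * ndim:
        raise ValueError("expected nelec * ndim coordinates")
    return [tuple(flat_row[k * ndim:(k + 1) * ndim]) for k in range(nelec)]


nelec, nbatch = 4, 10
# uniform random coordinates in [0, 1), like torch.rand in the cell above
batch = [[random.random() for _ in range(nelec * 3)] for _ in range(nbatch)]
walkers = [split_walker(row, nelec) for row in batch]

# smallest electron-electron distance in each walker
min_ree = [
    min(math.dist(w[i], w[j]) for i in range(nelec) for j in range(i + 1, nelec))
    for w in walkers
]
print(len(walkers), len(walkers[0]))  # 10 4
```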

INFO:QMCTorch|
INFO:QMCTorch| SCF Calculation
INFO:QMCTorch|  Removing LiH_pyscf_sto-3g.hdf5 and redo SCF calculations
INFO:QMCTorch|  Running scf  calculation
converged SCF energy = -7.85928101642664
INFO:QMCTorch|  Molecule name       : LiH
INFO:QMCTorch|  Number of electrons : 4
INFO:QMCTorch|  SCF calculator      : pyscf
INFO:QMCTorch|  Basis set           : sto-3g
INFO:QMCTorch|  SCF                 : HF
INFO:QMCTorch|  Number of AOs       : 6
INFO:QMCTorch|  Number of MOs       : 6
INFO:QMCTorch|  SCF Energy          : -7.859 Hartree
INFO:QMCTorch|
INFO:QMCTorch| Wave Function
INFO:QMCTorch|  Jastrow factor      : False
INFO:QMCTorch|  Highest MO included : 3
INFO:QMCTorch|  Configurations      : single_double(2,2)
INFO:QMCTorch|  Number of confs     : 4
INFO:QMCTorch|  Kinetic energy      : auto
INFO:QMCTorch|  Number var  param   : 37
INFO:QMCTorch|  Cuda support        : False
INFO:QMCTorch|
INFO:QMCTorch| Wave Function
INFO:QMCTorch|  Jastrow factor      : True
INFO:QMCTorch|

# TODO Explore the different architectures of MGCN 

As for the fully connected networks, it would be useful to assess how accurately MGCN Jastrow factors predict the total energy of the test molecules (H2, LiH, Li2, N2). MGCN is, of course, only one possible option, and new GNNs can also be defined to compute the Jastrow factors.