In [1]:
from  deepqmc.wf.paulinet.cusp import CuspCorrection, ElectronicAsymptotic
from  deepqmc.wf.paulinet.gto import GTOBasis
from  deepqmc.wf.paulinet.molorb import MolecularOrbital
from  deepqmc.wf.paulinet.omni import OmniSchNet
from  deepqmc.wf.paulinet.pyscfext import pyscf_from_mol
#from deepqmc.wf import PauliNet



from deepqmc import Molecule
from deepqmc.physics import pairwise_diffs, pairwise_distance
from deepqmc.plugins import PLUGINS
from deepqmc.torchext import sloglindet, triu_flat
from deepqmc.wf import WaveFunction

from functools import partial

import torch
from torch import nn

from deepqmc.torchext import SSP, get_log_dnn

from deepqmc.wf.paulinet.schnet import ElectronicSchNet, SubnetFactory

#__version__ = '0.3.0'
#__all__ = ['OmniSchNet']
import numpy 

# %load test.py
import torch 
from torch_geometric.data import Data
from torch_geometric.nn import MessagePassing, radius_graph
from torch_geometric.data import Batch
from torch_geometric.data import DataLoader

import ase
import torch.nn as nn
import torch.nn.functional as Func
from torch.nn import Embedding, Sequential, Linear, ModuleList, Module
import numpy as np
import math

from PaiNN import PaiNNElecNuc
from message import MessagePassPaiNN_NE
from torch_geometric.nn import radius

  return torch._C._cuda_getDeviceCount() > 0


In [2]:
from deepqmc import Molecule

# Named molecule via Molecule.from_name; this `mol` is reassigned to an
# explicit LiH definition further down the notebook
molecule_name = 'LiH'
mol = Molecule.from_name(molecule_name)

### Jastrow, backflow, bipartite graphs, and batching

In [16]:
class Jastrow(nn.Module):
    r"""Reduce per-particle embeddings to one scalar per sample.

    A log-DNN maps ``embedding_dim`` features to a single output. With
    ``sum_first`` the embeddings are summed over dim -2 before the network;
    otherwise the network runs per particle and its outputs are summed.
    """

    def __init__(
        self, embedding_dim, activation_factory=SSP, *, n_layers=3, sum_first=True
    ):
        super().__init__()
        self.net = get_log_dnn(embedding_dim, 1, activation_factory, n_layers=n_layers)
        self.sum_first = sum_first

    def forward(self, xs):
        if not self.sum_first:
            out = self.net(xs).sum(dim=-2)
        else:
            out = self.net(xs.sum(dim=-2))
        # drop the trailing size-1 output channel of the network
        return out.squeeze(dim=-1)
    
class Bipartite(Data):
    r"""One electron-nucleus bipartite graph, batchable with PyTorch Geometric.

    ``edge_index`` row 0 indexes nuclei (``coord_nuc``) and row 1 indexes
    electrons (``coord_elec``). When batching, the two rows are offset by the
    number of nuclei and electrons of each preceding graph, respectively.

    All constructor arguments default to ``None`` because PyG re-creates
    ``Data`` subclasses without arguments (e.g. ``Batch.to_data_list``).
    """

    def __init__(
        self,
        edge_index=None,
        coord_elec=None,
        coord_nuc=None,
        s_nuc=None,
        v_nuc=None,
        num_nodes=None,
    ):
        super(Bipartite, self).__init__()
        self.edge_index = edge_index
        self.coord_elec = coord_elec
        self.coord_nuc = coord_nuc
        self.s_nuc = s_nuc
        self.v_nuc = v_nuc
        self.num_nodes = num_nodes

    def __inc__(self, key, value, *args, **kwargs):
        # PyG >= 2 passes an extra `store` argument; accept and forward it so
        # the override works with both the old and new signatures.
        if key == 'edge_index':
            return torch.tensor([[self.coord_nuc.size(0)], [self.coord_elec.size(0)]])
        return super().__inc__(key, value, *args, **kwargs)
        
class BatchGraphNuc(nn.Module):
    r"""Build a batched electron-nucleus radius graph.

    Args:
        dim: unused; kept for backward compatibility (positional callers,
            e.g. ``OmniPaiNN``, pass a value here).
        cut_off (float): radius within which an electron-nucleus pair becomes
            an edge (default 5.0, the previously hard-coded value).
    """

    def __init__(self, dim=1, cut_off=5.0):
        super(BatchGraphNuc, self).__init__()
        self.dim = dim
        self.cut_off = cut_off

    def forward(self, s_nuc, v_nuc, coord_elec, coord_nuc):
        r"""Batch per-sample bipartite graphs.

        Args:
            s_nuc, v_nuc: per-sample nuclear scalar / vector features,
                iterated along dim 0
            coord_elec: electron coordinates, ``shape[:2] == (batch, n_elec)``
            coord_nuc: nuclear coordinates, repeated once per sample

        Returns:
            ``(s_nuc, v_nuc, edge_index, edge_attr)`` where ``edge_attr[k]`` is
            ``electron - nucleus`` for the pair in ``edge_index[:, k]``.
        """
        batch_dim, n_elec = coord_elec.shape[:2]
        coord_nuc = coord_nuc.repeat(batch_dim, 1, 1)
        data_list = [
            Bipartite(radius(e, n, self.cut_off), e, n, sn, vn, n_elec)
            for e, n, sn, vn in zip(coord_elec, coord_nuc, s_nuc, v_nuc)
        ]
        batch = Batch.from_data_list(data_list)
        # Edge features must follow edge_index order: row 0 indexes nuclei,
        # row 1 electrons (matching Bipartite.__inc__). Building them from the
        # dense pair grid would misalign them with the radius-graph edges.
        row, col = batch.edge_index
        edge_attr = batch.coord_elec[col] - batch.coord_nuc[row]
        return (batch.s_nuc, batch.v_nuc, batch.edge_index, edge_attr)
    
class BatchGraphElec(nn.Module):
    r"""Build a batched electron-electron radius graph.

    Args:
        cut_off (float): radius within which two electrons of the same sample
            are connected (no self-loops).
    """

    def __init__(self, cut_off=5.0):
        super(BatchGraphElec, self).__init__()
        self.cut_off = cut_off

    def forward(self, s, v, rs):
        r"""Batch per-sample electron graphs.

        Args:
            s, v: per-sample scalar / vector electron features, iterated along dim 0
            rs: per-sample electron coordinates, iterated along dim 0

        Returns:
            ``(x, v, edge_index, edge_attr)`` with ``edge_attr[k] = r[row_k] - r[col_k]``.
        """
        data = Batch.from_data_list(
            [Data(x=s_i, v=v_i, r=r_i) for s_i, v_i, r_i in zip(s, v, rs)]
        )
        # `batch=` keeps edges within each sample
        edge_index = radius_graph(data.r, r=self.cut_off, batch=data.batch, loop=False)
        row, col = edge_index
        edge_attr = data.r[row] - data.r[col]
        return data.x, data.v, edge_index, edge_attr
    
class Backflow(nn.Module):
    r"""Independent log-DNNs, one per backflow channel.

    Each network maps ``embedding_dim`` features to ``n_orbitals`` outputs
    (with a bias on the last layer); the channel outputs are stacked along
    dim 1.
    """

    def __init__(
        self,
        embedding_dim,
        n_orbitals,
        n_backflows,
        activation_factory=SSP,
        *,
        n_layers=3,
    ):
        super().__init__()
        self.nets = nn.ModuleList()
        for _ in range(n_backflows):
            self.nets.append(
                get_log_dnn(
                    embedding_dim,
                    n_orbitals,
                    activation_factory,
                    n_layers=n_layers,
                    last_bias=True,
                )
            )

    def forward(self, xs):
        channel_outputs = [net(xs) for net in self.nets]
        return torch.stack(channel_outputs, dim=1)
    
class BackflowPaiNN(nn.Module):
    r"""Project per-electron features of size ``embedding_dim`` to one value.

    Applied to vector features laid out as ``[..., 3, embedding_dim]`` this
    yields a ``[..., 3]`` output. The two ``Linear`` layers have no
    nonlinearity between them, so the overall map is affine in ``xs``
    (NOTE(review): confirm this is intended).

    Args:
        embedding_dim (int): feature size of the last input axis
        n_backflows, num_electrons: accepted for interface compatibility; unused
    """

    def __init__(
        self,
        embedding_dim,
        n_backflows,
        num_electrons
    ):
        super().__init__()
        self.net = nn.Sequential(
            Linear(embedding_dim, embedding_dim),
            Linear(embedding_dim, 1),
        )

    def forward(self, xs):
        # squeeze only the size-1 output channel; a bare squeeze() would also
        # drop a batch or electron axis of size 1
        return self.net(xs).squeeze(dim=-1)
    

In [36]:
class OmniPaiNN(nn.Module):
    r"""PaiNN-based many-body network over electron-electron and
    electron-nucleus graphs.

    Produces a Jastrow factor from the electronic scalar features and a
    per-electron 3-vector from the electronic vector features.

    Args:
        n_nuc (int): number of nuclei
        n_up, n_down (int): numbers of spin-up / spin-down electrons
        n_orbitals: accepted for compatibility with the omni-factory call
            signature; unused
        n_backflows: forwarded to :class:`BackflowPaiNN`
        embedding_dim (int): size of scalar features and of each vector
            feature channel
        num_nodes: forwarded to ``PaiNNElecNuc`` (NOTE(review): meaning is
            defined there; default 4 matches LiH's electron count — confirm
            it should be ``n_up + n_down`` in general)
        cut_off (float): radius of the electron-electron graph. It is also
            passed positionally to :class:`BatchGraphNuc`, which does not
            treat its first argument as a radius.
        n_rbf, num_interactions: currently unused
    """

    def __init__(
        self,
        n_nuc,
        n_up,
        n_down,
        n_orbitals,
        n_backflows,
        *,
        embedding_dim=128,
        num_nodes=4,
        cut_off=5.0,
        n_rbf=20,
        num_interactions=3,
    ):
        super().__init__()
        self.ElectronicPaiNN = PaiNNElecNuc(embedding_dim, embedding_dim, num_nodes)
        self.jastrow = Jastrow(embedding_dim, sum_first=True)
        self.backflow = BackflowPaiNN(embedding_dim, n_backflows, n_up + n_down)
        self.batch_E = BatchGraphElec(cut_off)
        self.batch_N = BatchGraphNuc(cut_off)
        spin_idxs = torch.tensor(
            (n_up + n_down) * [0] if n_up == n_down else n_up * [0] + n_down * [1]
        )
        # Non-persistent buffers follow .to(device) like parameters but are
        # not written to the state dict, so existing checkpoints still load.
        self.register_buffer('spin_idxs', spin_idxs, persistent=False)
        self.register_buffer('nuc_idxs', torch.arange(n_nuc), persistent=False)
        self.X = nn.Embedding(1 if n_up == n_down else 2, embedding_dim)
        self.Y = nn.Embedding(n_nuc, embedding_dim)
        self.eb = embedding_dim

    def forward(self, rs, rn):
        r"""Run PaiNN message passing and read out Jastrow and shift.

        Args:
            rs: electron coordinates, ``rs.shape[:2] == (batch, n_elec)``
            rn: nuclear coordinates, ``rn.shape[0] == n_nuc``

        Returns:
            ``(jastrow, None, shift)``: Jastrow value per sample, no orbital
            backflow, and a per-electron 3-vector.
        """
        batch_dim, n_elec = rs.shape[:2]
        n_nuc = rn.shape[0]

        # Initial scalar features come from spin / nucleus embeddings
        s_e = self.X(self.spin_idxs.repeat(batch_dim, 1))
        s_n = self.Y(self.nuc_idxs.repeat(batch_dim, 1))
        # Vector features start at zero, on the embeddings' device and dtype
        v_e = s_e.new_zeros(batch_dim, n_elec, self.eb, 3)
        v_n = s_n.new_zeros(batch_dim, n_nuc, self.eb, 3)

        # Batched e-e and e-n graphs
        s_e, v_e, edge_index, edge_attr = self.batch_E(s_e, v_e, rs)
        s_n, v_n, edge_index_n, edge_attr_n = self.batch_N(s_n, v_n, rs, rn)

        scalars, vectors = self.ElectronicPaiNN(
            s_e, v_e, s_n, v_n, edge_index, edge_attr, edge_index_n, edge_attr_n
        )
        scalars = scalars.reshape(batch_dim, n_elec, -1)
        vectors = vectors.reshape(batch_dim, n_elec, self.eb, -1)

        jastrow = self.jastrow(scalars)
        # move the embedding axis last so BackflowPaiNN projects it away
        shift = self.backflow(torch.transpose(vectors, -2, -1))
        return jastrow, None, shift

In [43]:
def eval_slater(xs):
    """Determinant over the last two axes; 1 for empty (0x0) matrices."""
    if xs.shape[-1] != 0:
        return torch.det(xs.contiguous())
    return xs.new_ones(xs.shape[:-2])


def eval_log_slater(xs):
    """Sign and log|det| over the last two axes; (1, 0) for 0x0 matrices."""
    batch_shape = xs.shape[:-2]
    if xs.shape[-1] != 0:
        return xs.contiguous().slogdet()
    return xs.new_ones(batch_shape), xs.new_zeros(batch_shape)


class PauliNet(WaveFunction):
    r"""Slater-Jastrow-backflow wave function with a pluggable "omni" network.

    Molecular orbitals from :class:`MolecularOrbital` are combined into
    spin-up / spin-down determinants for every configuration in ``confs``.
    The omni network, selected by ``omni_factory``, contributes a Jastrow
    factor. Depending on its type it also contributes one of:

    * orbital or determinant backflow factors, for distance-based networks
      such as ``OmniSchNet`` called with ``(dists_nuc, dists_elec)``;
    * a per-electron coordinate shift, for :class:`OmniPaiNN` called with raw
      coordinates.
    """

    OMNI_FACTORIES = {'omni_schnet': OmniSchNet, 'omni_paiNN': OmniPaiNN}

    def __init__(
        self,
        mol,
        basis,
        n_configurations=1,
        n_orbitals=None,
        return_log=True,
        use_sloglindet='training',
        *,
        cusp_correction=True,
        cusp_electrons=True,
        backflow_type='orbital',
        backflow_channels=1,
        backflow_transform='mult',
        rc_scaling=1.0,
        cusp_alpha=10.0,
        freeze_embed=False,
        omni_factory='omni_paiNN',  # or 'omni_schnet'
        omni_kwargs=None,
        cut_off=1.5,
        n_rbf=20,
        num_interactions=3,
        mb_embedding_dim=128,
    ):
        r"""
        Args:
            mol (:class:`~deepqmc.Molecule`): molecule whose wave function
                is represented
            basis (:class:`GTOBasis`): basis for the molecular orbitals
            n_configurations (int): number of determinant configurations
            n_orbitals (int): number of molecular orbitals, defaults to
                ``max(n_up, n_down)``
            return_log (bool): return ``(log|psi|, sign)`` instead of ``psi``
            use_sloglindet (str): ``'never'``, ``'training'`` or ``'always'``
            backflow_type (str): ``'orbital'`` or ``'det'``
            backflow_transform (str): ``'mult'``, ``'add'`` or ``'both'``
            omni_factory (str | callable | None): key of
                :attr:`OMNI_FACTORIES`, a factory, or a falsy value to disable
                the many-body network
            omni_kwargs (dict): keyword arguments for the omni factory. When
                ``omni_factory`` is a string, this is a dict keyed by factory
                name.
            cut_off, n_rbf, num_interactions, mb_embedding_dim: unused by this
                class; configure the omni network through ``omni_kwargs``
        """
        import logging

        log = logging.getLogger(__name__)
        assert use_sloglindet in {'never', 'training', 'always'}
        assert return_log or use_sloglindet == 'never'
        super().__init__(mol)
        n_up, n_down = self.n_up, self.n_down
        n_orbitals = n_orbitals or max(n_up, n_down)
        # first configuration is the ground state; extra ones are random
        # orbital selections (overwritten by from_pyscf for CAS references)
        confs = [list(range(n_up)) + list(range(n_down))] + [
            sum((torch.randperm(n_orbitals)[:n].tolist() for n in (n_up, n_down)), [])
            for _ in range(n_configurations - 1)
        ]
        self.register_buffer('confs', torch.tensor(confs))
        self.conf_coeff = (
            nn.Linear(n_configurations, 1, bias=False)
            if n_configurations > 1
            else nn.Identity()
        )
        self.mo = MolecularOrbital(
            mol,
            basis,
            n_orbitals,
            cusp_correction=cusp_correction,
            rc_scaling=rc_scaling,
        )
        self.cusp_same, self.cusp_anti = (
            (ElectronicAsymptotic(cusp=cusp, alpha=cusp_alpha) for cusp in (0.25, 0.5))
            if cusp_electrons
            else (None, None)
        )
        backflow_spec = {
            'orbital': [n_orbitals, backflow_channels],
            'det': [max(n_up, n_down), len(self.confs) * backflow_channels],
        }[backflow_type]
        if backflow_transform == 'both':
            # half of the channels multiply, the other half add
            backflow_spec[1] *= 2
        self.backflow_type = backflow_type
        self.backflow_transform = backflow_transform
        if 'paulinet.omni_factory' in PLUGINS:
            log.info('Using a plugin for paulinet.omni_factory')
            omni_factory = PLUGINS['paulinet.omni_factory']
        elif isinstance(omni_factory, str):
            if omni_kwargs:
                omni_kwargs = omni_kwargs[omni_factory]
            omni_factory = self.OMNI_FACTORIES[omni_factory]
        self.omni = (
            omni_factory(
                len(mol.coords), n_up, n_down, *backflow_spec, **(omni_kwargs or {})
            )
            if omni_factory
            else None
        )
        self.return_log = return_log
        if freeze_embed:
            self.requires_grad_embeddings_(False)
        self.n_determinants = len(self.confs) * backflow_channels
        if n_up <= 1 or n_down <= 1:
            # TODO implement sloglindet for special cases n=0 and n=1
            self.use_sloglindet = 'never'
            log.warning(
                'Setting use_sloglindet to "never" as not implemented for n=0 and n=1.'
            )
        else:
            self.use_sloglindet = use_sloglindet

    def requires_grad_classes_(self, classes, requires_grad):
        """Set ``requires_grad`` on the direct parameters of all submodules of
        the given class(es); returns ``self``."""
        for m in self.modules():
            if isinstance(m, classes):
                for p in m.parameters(recurse=False):
                    p.requires_grad_(requires_grad)
        return self

    def requires_grad_cusps_(self, requires_grad):
        return self.requires_grad_classes_(CuspCorrection, requires_grad)

    def requires_grad_embeddings_(self, requires_grad):
        return self.requires_grad_classes_(nn.Embedding, requires_grad)

    def requires_grad_nets_(self, requires_grad):
        return self.requires_grad_classes_(nn.Linear, requires_grad)

    @classmethod
    def DEFAULTS(cls):
        # Absolute imports: relative imports fail when this class is defined
        # outside the deepqmc.wf.paulinet package (e.g. in a notebook).
        from deepqmc.wf.paulinet.omni import Backflow, Jastrow
        from deepqmc.wf.paulinet.schnet import ElectronicSchNet, SubnetFactory

        return {
            (cls.from_hf, 'kwargs'): cls.from_pyscf,
            (cls.from_pyscf, 'kwargs'): cls,
            (cls, 'omni_kwargs'): cls.OMNI_FACTORIES,
            (OmniSchNet, 'schnet_kwargs'): ElectronicSchNet,
            (OmniSchNet, 'mf_schnet_kwargs'): (ElectronicSchNet, ['version']),
            (OmniSchNet, 'subnet_kwargs'): SubnetFactory,
            (OmniSchNet, 'mf_subnet_kwargs'): SubnetFactory,
            (OmniSchNet, 'jastrow_kwargs'): Jastrow,
            (OmniSchNet, 'backflow_kwargs'): Backflow,
        }

    @classmethod
    def from_pyscf(
        cls,
        mf,
        *,
        init_weights=True,
        freeze_mos=True,
        freeze_confs=False,
        conf_cutoff=1e-2,
        conf_limit=None,
        **kwargs,
    ):
        r"""Construct a :class:`PauliNet` instance from a finished PySCF_ calculation.

        Args:
            mf (:class:`pyscf.scf.hf.RHF` | :class:`pyscf.mcscf.mc1step.CASSCF`):
                restricted (multireference) HF calculation
            init_weights (bool): whether molecular orbital coefficients and
                configuration coefficients are initialized from the HF calculation
            freeze_mos (bool): whether the MO coefficients are frozen for
                gradient optimization
            freeze_confs (bool): whether the configuration coefficients are
                frozen for gradient optimization
            conf_cutoff (float): determinants with a linear coefficient above
                this threshold are included in the determinant expansion
            conf_limit (int): if given, at maximum the given number of configurations
                with the largest linear coefficients are used in the ansatz
            kwargs: all other arguments are passed to the :class:`PauliNet`
                constructor

        .. _PySCF: http://pyscf.org
        """
        assert not (set(kwargs) & {'n_configurations', 'n_orbitals'})
        n_up, n_down = mf.mol.nelec
        if hasattr(mf, 'fcisolver'):
            if conf_limit:
                conf_cutoff = max(
                    np.sort(abs(mf.ci.flatten()))[-conf_limit:][0] - 1e-10, conf_cutoff
                )
            for tol in [conf_cutoff, conf_cutoff + 2e-10]:
                conf_coeff, *confs = zip(
                    *mf.fcisolver.large_ci(
                        mf.ci, mf.ncas, mf.nelecas, tol=tol, return_strs=False
                    )
                )
                if not conf_limit or len(conf_coeff) <= conf_limit:
                    break
            else:
                raise AssertionError()
            # discard the last ci wave function if degenerate
            ns_dbl = n_up - mf.nelecas[0], n_down - mf.nelecas[1]
            conf_coeff = torch.tensor(conf_coeff)
            # prepend the doubly occupied core orbitals to each active-space config
            confs = [
                [
                    torch.arange(n_dbl, dtype=torch.long).expand(len(conf_coeff), -1),
                    torch.tensor(cfs, dtype=torch.long) + n_dbl,
                ]
                for n_dbl, cfs in zip(ns_dbl, confs)
            ]
            confs = [torch.cat(cfs, dim=-1) for cfs in confs]
            confs = torch.cat(confs, dim=-1)
            kwargs['n_configurations'] = len(confs)
            kwargs['n_orbitals'] = confs.max().item() + 1
        else:
            confs = None
        mol = Molecule(
            mf.mol.atom_coords().astype('float32'),
            mf.mol.atom_charges(),
            mf.mol.charge,
            mf.mol.spin,
        )
        basis = GTOBasis.from_pyscf(mf.mol)
        wf = cls(mol, basis, **kwargs)
        if init_weights:
            wf.mo.init_from_pyscf(mf, freeze_mos=freeze_mos)
            if confs is not None:
                wf.confs.detach().copy_(confs)
                if len(confs) > 1:
                    wf.conf_coeff.weight.detach().copy_(conf_coeff)
                if freeze_confs:
                    wf.conf_coeff.weight.requires_grad_(False)
        return wf

    @classmethod
    def from_hf(cls, mol, *, basis='6-311g', cas=None, workdir=None, **kwargs):
        r"""Construct a :class:`PauliNet` instance by running a HF calculation.

        This is the top-level interface.

        Args:
            mol (:class:`~deepqmc.Molecule`): molecule whose wave function
                is represented
            basis (str): basis of the internal HF calculation
            cas ((int, int)): tuple of the number of active orbitals and number of
                active electrons for a complete active space multireference
                HF calculation
            workdir (str): path where PySCF calculations are cached
            kwargs: all other arguments are passed to :func:`PauliNet.from_pyscf`
        """
        mf, mc = pyscf_from_mol(mol, basis, cas, workdir)
        assert bool(cas) == bool(mc)
        # `cls` rather than `PauliNet` so subclasses construct themselves
        wf = cls.from_pyscf(mc or mf, **kwargs)
        wf.mf = mf
        return wf

    def pop_charges(self):
        """Population charges from the stored HF calculation, if any."""
        try:
            mf = self.mf
        except AttributeError:
            return super().pop_charges()
        return self.mol.charges.new(mf.pop(verbose=0)[1])

    def _backflow_op(self, xs, fs):
        r"""Apply backflow factors ``fs`` to orbital values ``xs``.

        * ``'mult'``: ``xs * (1 + 2 tanh(fs / 4))``
        * ``'add'``: ``xs + 0.1 * rms(xs) * tanh(fs / 4)``, where the RMS is
          taken over the last axis
        * ``'both'``: ``fs`` is split in half along dim 1; the first half
          multiplies and the second half adds

        Raises:
            ValueError: for an unknown ``backflow_transform``
        """
        if self.backflow_transform == 'mult':
            fs_mult, fs_add = fs, None
        elif self.backflow_transform == 'add':
            fs_mult, fs_add = None, fs
        elif self.backflow_transform == 'both':
            half = fs.shape[1] // 2
            fs_mult, fs_add = fs[:, :half], fs[:, half:]
        else:
            raise ValueError(
                f'Unknown backflow_transform: {self.backflow_transform!r}'
            )
        if fs_add is not None:
            # envelope computed from the untransformed orbitals
            envel = (xs ** 2).mean(dim=-1, keepdim=True).sqrt()
        if fs_mult is not None:
            xs = xs * (1 + 2 * torch.tanh(fs_mult / 4))
        if fs_add is not None:
            xs = xs + 0.1 * envel * torch.tanh(fs_add / 4)
        return xs

    def forward(self, rs):  # noqa: C901
        r"""Evaluate the wave function at electron coordinates ``rs``.

        Args:
            rs: electron coordinates, ``rs.shape[:2] == (batch, n_elec)``

        Returns:
            ``(log|psi|, sign)`` if ``return_log`` else ``psi``
        """
        batch_dim, n_elec = rs.shape[:2]
        assert n_elec == self.confs.shape[1]
        n_atoms = len(self.mol)
        coords = self.mol.coords
        J, fs = None, None
        shift_based = isinstance(self.omni, OmniPaiNN)
        if shift_based:
            # OmniPaiNN takes raw coordinates and returns (J, fs, ps). The shift
            # ps moves the electrons before the orbitals and cusps are evaluated.
            J, fs, ps = self.omni(rs, coords)
            if ps is not None:
                rs = rs + ps
        diffs_nuc = pairwise_diffs(torch.cat([coords, rs.flatten(end_dim=1)]), coords)
        dists_elec = pairwise_distance(rs, rs)
        xs = self.mo(diffs_nuc)
        # get orbitals as [bs, 1, i, mu]
        xs = xs.view(batch_dim, 1, n_elec, -1)
        if self.omni is not None and not shift_based:
            # distance-based omni networks return jastrow J and backflow fs
            # (as [bs, q, i, mu/nu])
            dists_nuc = (
                diffs_nuc[n_atoms:, :, 3].sqrt().view(batch_dim, n_elec, n_atoms)
            )
            J, fs = self.omni(dists_nuc, dists_elec)
        if fs is not None and self.backflow_type == 'orbital':
            xs = self._backflow_op(xs, fs)
        # form dets as [bs, q, p, i, nu]
        conf_up, conf_down = self.confs[:, : self.n_up], self.confs[:, self.n_up :]
        det_up = xs[:, :, : self.n_up, conf_up].transpose(-3, -2)
        det_down = xs[:, :, self.n_up :, conf_down].transpose(-3, -2)
        if fs is not None and self.backflow_type == 'det':
            n_conf = len(self.confs)
            fs = fs.unflatten(1, ((None, fs.shape[1] // n_conf), (None, n_conf)))
            det_up = self._backflow_op(det_up, fs[..., : self.n_up, : self.n_up])
            det_down = self._backflow_op(det_down, fs[..., self.n_up :, : self.n_down])
            # with open-shell systems, part of the backflow output is not used
        if self.use_sloglindet == 'always' or (
            self.use_sloglindet == 'training' and not self.sampling
        ):
            bf_dim = det_up.shape[-4]
            if isinstance(self.conf_coeff, nn.Linear):
                conf_coeff = self.conf_coeff.weight[0]
                conf_coeff = conf_coeff.expand(bf_dim, -1).flatten() / np.sqrt(bf_dim)
            else:
                conf_coeff = det_up.new_ones(1)
            det_up = det_up.flatten(start_dim=-4, end_dim=-3).contiguous()
            det_down = det_down.flatten(start_dim=-4, end_dim=-3).contiguous()
            sign, psi = sloglindet(conf_coeff, det_up, det_down)
            sign = sign.detach()
        else:
            if self.return_log:
                sign_up, det_up = eval_log_slater(det_up)
                sign_down, det_down = eval_log_slater(det_down)
                xs = det_up + det_down
                xs_shift = xs.flatten(start_dim=1).max(dim=-1).values
                # the exp-normalize trick, to avoid over/underflow of the exponential
                xs = sign_up * sign_down * torch.exp(xs - xs_shift[:, None, None])
            else:
                det_up = eval_slater(det_up)
                det_down = eval_slater(det_down)
                xs = det_up * det_down
            psi = self.conf_coeff(xs).squeeze(dim=-1).mean(dim=-1)
            if self.return_log:
                psi, sign = psi.abs().log() + xs_shift, psi.sign().detach()
        if self.cusp_same:
            cusp_same = self.cusp_same(
                torch.cat(
                    [triu_flat(dists_elec[:, idxs, idxs]) for idxs in self.spin_slices],
                    dim=1,
                )
            )
            cusp_anti = self.cusp_anti(
                dists_elec[:, : self.n_up, self.n_up :].flatten(start_dim=1)
            )
            psi = (
                psi + cusp_same + cusp_anti
                if self.return_log
                else psi * torch.exp(cusp_same + cusp_anti)
            )
        if J is not None:
            psi = psi + J if self.return_log else psi * torch.exp(J)
        return (psi, sign) if self.return_log else psi

In [6]:
# def pairwise_distance(coords1, coords2):
#     return (coords1[..., :, None, :] - coords2[..., None, :, :]).norm(dim=-1)

# def pairwise_diffs(coords1, coords2, axes_offset=True):
#     diffs = coords1[..., :, None, :] - coords2[..., None, :, :]
#     if axes_offset:
#         diffs = offset_from_axes(diffs)
#     return torch.cat([diffs, (diffs ** 2).sum(dim=-1, keepdim=True)], dim=-1)
# def offset_from_axes(rs):
#     eps = rs.new_tensor(100 * torch.finfo(rs.dtype).eps)
#     offset = torch.where(rs < 0, -eps, eps)
#     return torch.where(rs.abs() < eps, rs + offset, rs)

In [29]:
# LiH: Li (charge 3) at the origin, H (charge 1) at x = 3.0; neutral, spin 0
lih_coords = [[0.0, 0.0, 0.0], [3.0, 0.0, 0.0]]
lih_charges = [3, 1]
mol = Molecule(coords=lih_coords, charges=lih_charges, charge=0, spin=0)

In [44]:
CAS_SPACE = (2, 4)  # (number of active orbitals, number of active electrons)
net = PauliNet.from_hf(mol, cas=CAS_SPACE)

converged SCF energy = -7.98461406083101
CASSCF energy = -7.98461406083102
CASCI E = -7.98461406083102  E(CI) = -8.98461406083102  S^2 = 0.0000000


In [50]:
from deepqmc import train
#train(net)

### Smoke test: random inputs through OmniPaiNN

In [11]:
# Random inputs for an OmniPaiNN smoke test; seeded so reruns are reproducible
torch.manual_seed(0)

n_up = 2
n_down = 2
batch_dim = 10
num_elec = n_up + n_down
num_nuc = 2
rn = torch.rand(num_nuc, 3)
rs = torch.rand(batch_dim, num_elec, 3)

# F: number of features per scalar / vector channel
F = 128

# Embedding indices for the scalar features (same construction as in OmniPaiNN)
spin_idxs = torch.tensor(
    (n_up + n_down) * [0] if n_up == n_down else n_up * [0] + n_down * [1]
).repeat(batch_dim, 1)
nuc_idxs = torch.arange(num_nuc).repeat(batch_dim, 1)

# Initial vector features are all zeros
v0 = torch.zeros(batch_dim, num_elec, F, 3, dtype=torch.float)
vn = torch.zeros(batch_dim, num_nuc, F, 3, dtype=torch.float)





In [19]:
# OmniPaiNN takes five positional arguments (n_nuc, n_up, n_down, n_orbitals,
# n_backflows); all hyperparameters are keyword-only.
test = OmniPaiNN(
    num_nuc,
    n_up,
    n_down,
    1,
    1,
    embedding_dim=F,
    num_nodes=num_elec,
    cut_off=1.5,
    num_interactions=3,
)
jast, fs, ps = test(rs, rn)