In [3]:
import torch
import dqc
import dqc.xc
import dqc.utils

In [2]:
# Parse a semicolon-separated "<element> x y z" molecule description into
# atomic numbers and a float64 position tensor with one row per atom.
# NOTE(review): coordinates are kept as written (presumably atomic units / Bohr,
# DQC's default length unit) — confirm before using other geometries.
atomzs, atomposs = dqc.parse_moldesc("H -1 0 0; H 1 0 0")
print(atomzs)  # atomic numbers
atomposs  # positions, shown as the cell output

tensor([1, 1])


tensor([[-1.,  0.,  0.],
        [ 1.,  0.,  0.]], dtype=torch.float64)

In [3]:
# Hartree-Fock reference calculation: the force is the negative gradient of
# the energy with respect to the atomic positions.
# requires_grad_() flags the position tensor in place (and returns it), so
# `atomposs` itself becomes differentiable before the molecule is built.
mol = dqc.Mol((atomzs, atomposs.requires_grad_()), basis="3-21G")
qc = dqc.HF(mol).run()
ene = qc.energy()
force = torch.autograd.grad(ene, atomposs)[0].neg()
force

tensor([[ 0.1033, -0.0000, -0.0000],
        [-0.1033, -0.0000, -0.0000]], dtype=torch.float64)

In [4]:
class MyLDAX(dqc.xc.CustomXC):
    """Toy LDA-type exchange functional with energy density ``a * rho ** p``.

    ``a`` and ``p`` are stored as attributes, so when they are passed in as
    ``torch.nn.Parameter`` objects the resulting energy can be differentiated
    with respect to them.

    Args:
        a: prefactor of the energy density.
        p: exponent applied to the (absolute) density.
    """

    def __init__(self, a, p):
        super().__init__()
        self.a = a
        self.p = p

    @property
    def family(self):
        """Functional family code: 1 for LDA, 2 for GGA, 4 for MGGA."""
        return 1

    def get_edensityxc(self, densinfo):
        """Return the exchange energy density evaluated on ``densinfo``.

        A spin-polarized input (``dqc.utils.SpinParam``) is handled with the
        exchange spin-scaling relation: each spin channel is evaluated at twice
        its density and the two results are averaged.
        """
        if not isinstance(densinfo, dqc.utils.SpinParam):
            # abs() plus a tiny offset guards rho ** p against producing NaN
            density = densinfo.value.abs() + 1e-15
            return self.a * density ** self.p
        edens_up = self.get_edensityxc(densinfo.u * 2)
        edens_dn = self.get_edensityxc(densinfo.d * 2)
        return 0.5 * (edens_up + edens_dn)

In [5]:
# Trainable parameters of the custom functional a * rho ** p:
# prefactor `a` starts at 1.0 and exponent `p` at 2.0, both float64 scalars.
a, p = (torch.nn.Parameter(torch.tensor(init, dtype=torch.float64)) for init in (1.0, 2.0))
myxc = MyLDAX(a=a, p=p)  # custom exchange functional built from the trainable parameters

In [6]:
# Rebuild the same two-hydrogen molecule with freshly parsed, differentiable
# positions for the custom-functional Kohn-Sham calculation that follows.
atomzs, atomposs = dqc.parse_moldesc("H -1 0 0; H 1 0 0")
# requires_grad_() flags `atomposs` in place and returns it for the Mol tuple
mol = dqc.Mol((atomzs, atomposs.requires_grad_()), basis="3-21G")

In [7]:
# Kohn-Sham calculation with the custom functional. The returned energy keeps
# its autograd history (grad_fn), so it can be differentiated with respect to
# the functional parameters and the atomic positions in the following cells.
ks = dqc.KS(mol, xc=myxc).run()
ene = ks.energy()
ene

tensor(-0.4645, dtype=torch.float64, grad_fn=<AddBackward0>)

In [8]:
# Sensitivity of the Kohn-Sham energy to the functional's parameters.
# retain_graph=True keeps the graph of `ene` alive, because it is
# differentiated again (with respect to the positions) in the next cell.
grad_a, grad_p = torch.autograd.grad(outputs=ene, inputs=(a, p), retain_graph=True)
print(grad_a, grad_p)

tensor(0.0711, dtype=torch.float64) tensor(-0.2108, dtype=torch.float64)


In [9]:
# Forces from the custom-functional energy: negative gradient w.r.t. positions.
# Without retain_graph, the graph of `ene` is freed after this call, so `ene`
# cannot be differentiated again afterwards.
force = torch.autograd.grad(outputs=ene, inputs=atomposs)[0].neg()
force

tensor([[-2.8237e-02, -9.7578e-19, -2.7105e-18],
        [ 2.8237e-02, -1.0842e-18, -1.7347e-18]], dtype=torch.float64)