In [15]:
import numpy as np
import torch
import dxtb
from dxtb.typing import DD
from tqdm import tqdm
from tad_mctc.data.molecules import mols as samples
from dxtb import Calculator
from dxtb.config import ConfigCache

dd: DD = {"dtype": torch.double, "device": torch.device("cuda:0")}
cache_config = ConfigCache(enabled=False, density=True, fock=True, overlap=True)
opts = {"verbosity": 2}

# Molecule
sample = samples["vancoh2"]
numbers = sample["numbers"].to(dd["device"])
print(f"len(numbers): {len(numbers)}")
positions = sample["positions"].clone().to(**dd)
charges = torch.tensor(0.0, **dd)

pos = positions.clone().requires_grad_(True)

# instantiate a dxtb calculator
# opts = dict(opts, **{"scf_mode": "reconnect", "scp_mode": "potential"})
opts = dict(opts, **{"scf_mode": "implicit", "scp_mode": "charges"})
# opts = dict(opts, **{"scf_mode": "reconnect"})

calc = dxtb.Calculator(numbers, dxtb.GFN1_XTB, opts=opts, **dd, timer=True)
dxtb.timer.reset()
calc.opts.cache = cache_config

print(f"Calc scf mode: {calc.opts.scf.scf_mode}")
print(f"Calc scp mode: {calc.opts.scf.scp_mode}")

S = calc.integrals.build_overlap(pos)
print(f"S.shape: {S.shape}")
dxtb.timer.print(v=1)


len(numbers): 176
Calc scf mode: 1
Calc scp mode: 1
S.shape: torch.Size([550, 550])


Timings
-------

[1mObjective                Time (s)        % Total[0m
------------------------------------------------
------------------------------------------------
Sum                    [37m     0.000           0.00[0m
[1mTotal                       0.085         100.00[0m


In [17]:
with torch.autograd.profiler.profile(record_shapes=True, use_cuda=False) as prof:
    torch.autograd.grad(S.sum(), pos, retain_graph=True)

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=50))

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
autograd::engine::evaluate_function: Int2c_V2Backwar...         0.02%       7.654us        50.43%      25.156ms      25.156ms             1  
                                       Int2c_V2Backward         0.89%     442.125us        50.39%      25.137ms      25.137ms             1  
                                               Int2c_V2        27.55%      13.740ms        27.55%      13.742ms      13.742ms             1  
                                           aten::einsum         0.06%      30.446us        21.80%      10.876ms       5.438ms             2  
      