In [1]:
%reload_ext autoreload
%autoreload 2

import h5py
# Number of carbon atoms of the alkane to benchmark; selects the HDF5 group.
N_Cs = 9
mol_name = f"alkane_{N_Cs}_carbons"

# Fetch the requested group directly by key. Scanning all groups and reading
# the loop variable afterwards would report whichever group was iterated
# last, which is only the requested one by coincidence.
with h5py.File('../dxtb/dxtb-gpu/gpu-cpu_analysis/rdkit/alkanes_data_500.hdf5', 'r') as f:
    if mol_name not in f:
        # Fail here with a clear message instead of a NameError further down.
        raise KeyError(f"Group {mol_name!r} not found in the HDF5 file")
    data = f[mol_name]
    # `[:]` copies the datasets into memory so they outlive the file handle.
    atomic_numbers = data['atomic_numbers'][:]
    coordinates = data['coordinates'][:]

print(f"Number of carbon atoms in {mol_name}: {N_Cs}")
print(f"Nb of atoms: {len(atomic_numbers)}")

Number of carbon atoms in alkane_9_carbons: 9
Nb of atoms: 29


# Batched

In [18]:
import dxtb
from dxtb._src.typing import DD
import torch
from ase.build import molecule

# Batched GPU-vs-CPU comparison: run the same batch of identical molecules on
# both devices, time energy + autograd forces, and compare the results.
# NOTE(review): "maxiter": 1 presumably caps the SCF at a single iteration, so
# the energies are benchmark values rather than converged results - confirm.
opts = {"scf_mode": "implicit", "batch_mode": 1, "int_driver": "libcint", "maxiter": 1}
batch_size = 128
results = {}

for device in ["cuda:0", "cpu"]:
    print(f"\nDevice: {device}")
    on_gpu = "cuda" in device
    dd = {"dtype": torch.float32, "device": torch.device(device)}

    # Replicate the single molecule `batch_size` times along a new leading axis.
    numbers = torch.tensor(atomic_numbers, device=dd["device"], dtype=torch.int32)
    positions = torch.tensor(coordinates, device=dd["device"], dtype=dd["dtype"])
    numbers = torch.stack([numbers] * batch_size)
    positions = torch.stack([positions] * batch_size).requires_grad_()
    charges = torch.zeros((batch_size,), device=dd["device"], dtype=dd["dtype"])

    calc = dxtb.Calculator(numbers, dxtb.GFN1_XTB, **dd, opts=opts, timer=True)

    # Reset after construction so calculator setup is excluded from the timings.
    dxtb.timer.reset()
    e = calc.get_energy(positions, chrg=charges)

    # CUDA kernels are launched asynchronously: synchronize so that pending
    # energy kernels are not attributed to the force timer.
    if on_gpu:
        torch.cuda.synchronize()
    dxtb.timer.start("Forces autograd")
    # `e.sum()` is a single reduction; the builtin `sum(e)` would iterate the
    # tensor element by element and build one add node per batch entry.
    forces = torch.autograd.grad(e.sum(), positions, retain_graph=True)[0]
    # Wait for the backward kernels before stopping the timer.
    if on_gpu:
        torch.cuda.synchronize()
    dxtb.timer.stop("Forces autograd")
    dxtb.timer.print(v=0)

    # Store detached CPU copies so both devices can be compared afterwards.
    results[device] = {
        "energy": e.detach().cpu(),
        "forces": forces.detach().cpu()
    }

# Compare results
energy_diff = (results["cuda:0"]["energy"] - results["cpu"]["energy"]).abs().max()
forces_diff = (results["cuda:0"]["forces"] - results["cpu"]["forces"]).abs().max()

print("\n[Comparison]")
print(f"GPU energy: {results['cuda:0']['energy'].mean().item():.6e}")
print(f"CPU energy: {results['cpu']['energy'].mean().item():.6e}")
print(f"Max energy diff: {energy_diff.item():.6e}")
print(f"Max forces diff: {forces_diff.item():.6e}")



Device: cuda:0


Timings
-------

[1mObjective                Time (s)        % Total[0m
------------------------------------------------
[1mClassicals                  0.011           1.55[0m
 - Repulsion           [37m     0.001           6.43[0m
 - DispersionD3        [37m     0.003          27.53[0m
 - Halogen             [37m     0.007          65.63[0m
[1mIntegrals                   0.504          71.90[0m
 - Overlap             [37m     0.502          99.65[0m
 - Core Hamiltonian    [37m     0.002           0.35[0m
[1mSCF                         0.018           2.63[0m
 - Interaction Cache   [37m     0.001           3.94[0m
 - Potential           [37m     0.012          64.70[0m
 - Fock build          [37m     0.000           0.59[0m
 - Diagonalize         [37m     0.008          43.24[0m
 - Density             [37m     0.001           3.39[0m
 - Charges             [37m     0.001           2.95[0m
[1mcupy_eigh                   0.004           0.

In [None]:
import torch
import dxtb
from dxtb._src.typing import DD
import ase.build.molecule as molecule
import torch.profiler

# Profile the batched energy + autograd-force evaluation on each device with
# torch.profiler and print both the dxtb timer report and the profiler table.
opts = {"scf_mode": "implicit", "batch_mode": 2, "int_driver": "libcint"}
batch_size = 100

for device in ["cpu", "cuda:0"]:
    on_gpu = "cuda" in device
    dd = {"dtype": torch.float64, "device": torch.device(device)}

    # Replicate the single molecule `batch_size` times along a new leading axis.
    numbers = torch.tensor(atomic_numbers, device=dd["device"], dtype=torch.int32)
    positions = torch.tensor(coordinates, device=dd["device"], dtype=dd["dtype"])
    numbers = torch.stack([numbers] * batch_size)
    positions = torch.stack([positions] * batch_size).requires_grad_()
    charges = torch.zeros((batch_size,), device=dd["device"], dtype=dd["dtype"])

    # Drain pending GPU work (host-to-device copies above) before timing starts.
    if on_gpu:
        torch.cuda.synchronize()
    dxtb.timer.reset()
    calc = dxtb.Calculator(numbers, dxtb.GFN1_XTB, **dd, opts=opts, timer=True)

    # Only request CUDA activity tracing when actually running on the GPU.
    activities = [torch.profiler.ProfilerActivity.CPU]
    if on_gpu:
        activities.append(torch.profiler.ProfilerActivity.CUDA)

    with torch.profiler.profile(
        activities=activities,
        record_shapes=True,
        profile_memory=True,
        with_stack=True
    ) as prof:
        e = calc.get_energy(positions, chrg=charges)
        dxtb.timer.start("Forces autograd")
        # `e.sum()` is a single reduction; the builtin `sum(e)` would iterate
        # the tensor and build one add node per batch entry.
        forces = torch.autograd.grad(e.sum(), positions, retain_graph=True)[0]
        # CUDA kernels run asynchronously: synchronize *before* stopping the
        # timer so the backward kernels are included in the measurement.
        if on_gpu:
            torch.cuda.synchronize()
        dxtb.timer.stop("Forces autograd")

    dxtb.timer.print(v=0)
    print(f"\n--- PyTorch Profiler Summary ({device}) ---")
    sort_key = "cuda_time_total" if on_gpu else "cpu_time_total"
    print(prof.key_averages().table(sort_by=sort_key, row_limit=20))


# Non-Batched

In [None]:
import dxtb
from dxtb._src.typing import DD
import torch
import ase.build.molecule as molecule

# Non-batched reference: a single molecule, energy and forces obtained via
# the calculator's own methods, timed on each device.
opts = {"scf_mode": "implicit", "batch_mode": 0}

for device in ["cpu", "cuda:0"]:
    dd = {"dtype": torch.float64, "device": torch.device(device)}
    numbers = torch.tensor(atomic_numbers, device=dd["device"], dtype=torch.int64)
    # Take the dtype from `dd` so positions always match the calculator's dtype.
    positions = torch.tensor(coordinates, device=dd["device"], dtype=dd["dtype"]).requires_grad_()

    dxtb.timer.reset()
    calc = dxtb.Calculator(numbers, dxtb.GFN1_XTB, **dd, opts=opts)
    e = calc.get_energy(positions)
    forces = calc.get_forces(positions)
    # CUDA kernels run asynchronously; wait for them before reporting timings,
    # consistent with the profiling cell.
    if "cuda" in device:
        torch.cuda.synchronize()
    dxtb.timer.print(v=0)


# Line plot

In [None]:
import dxtb
from dxtb._src.typing import DD
import torch
import ase.build.molecule as molecule
import matplotlib.pyplot as plt

# Setup: sweep over batch sizes and record the total dxtb wall time per device.
opts = {"scf_mode": "implicit", "batch_mode": 1}

batch_sizes = [2, 4, 8, 16, 32, 64, 128, 256]
# One list of totals per device, aligned index-by-index with `batch_sizes`.
timings = {"cpu": [], "cuda:0": []}

# NOTE(review): the first CUDA run presumably also pays one-off warm-up costs
# (context creation, kernel loading) - consider a discarded warm-up run.
for batch_size in batch_sizes:
    for device in ["cpu", "cuda:0"]:
        dd: DD = {"dtype": torch.float64, "device": torch.device(device)}
        numbers = torch.tensor(atomic_numbers, device=dd["device"], dtype=torch.int32)
        positions = torch.tensor(coordinates, device=dd["device"], dtype=dd["dtype"])

        # Replicate the single molecule along a new leading batch axis.
        numbers = torch.stack([numbers] * batch_size)
        positions = torch.stack([positions] * batch_size)
        charges = torch.zeros((batch_size,), device=dd["device"], dtype=dd["dtype"])

        dxtb.timer.reset()
        calc = dxtb.Calculator(numbers, dxtb.GFN1_XTB, **dd, opts=opts)
        e = calc.get_energy(positions, chrg=charges)
        # CUDA kernels run asynchronously; wait for them so the recorded total
        # reflects finished GPU work rather than just kernel launches.
        if "cuda" in device:
            torch.cuda.synchronize()
        dxtb.timer.print(v=5)
        total = dxtb.timer.timers["total"].elapsed_time
        timings[device].append(total)
        print(f"[{device}] Batch {batch_size} → {total:.2f}s")

In [None]:
# Plot measured total time against batch size for each device, together with
# an ideal linear-scaling line through one anchor measurement per device.
plt.figure()

# The linear baseline is anchored at batch size 4 (not the smallest size, 2).
# The anchor position does not depend on the device, so look it up once.
anchor_batch_size = 4
anchor_idx = batch_sizes.index(anchor_batch_size)

for device in timings:
    actual_times = timings[device]

    # Plot actual timings
    plt.plot(batch_sizes, actual_times, label=f"{device}", marker='o', linestyle='-')

    # Measured time at the anchor batch size for this device
    anchor_time = actual_times[anchor_idx]

    # Linear scaling line: time grows proportionally to batch size from the anchor
    linear_times = [anchor_time * (b / anchor_batch_size) for b in batch_sizes]
    plt.plot(batch_sizes, linear_times, linestyle='--', label=f"{device} (linear)", alpha=0.6)

plt.xlabel("Batch Size")
plt.ylabel("Total Time (s)")
plt.title("dxtb Total Time vs. Batch Size with Linear Scaling Baseline")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()
