### Generate labeled DFT and $GW$ Data
Running the command below should generate three chkfiles in the `./test_chk` directory. This is the slow step: it labels the data with the $GW$ calculations that we want to accelerate with ML. You may skip it and use the existing chkfiles in the `test_chk` folder.

Arguments:
1. `--xc pbe0` : DFT functional is PBE0
2. `--gf gw` : do a DFT+$GW$ calculation and save everything to disk
3. `--src` : folder with .xyz files to use as structures for DFT
4. `-o` : folder in which to write .chk files with all the calculation outputs
5. `-b ccpvdz` : use the cc-pVDZ basis
6. `--gw_nw2 30` : evaluate sigmaI and gHF on 30 Gauss-Legendre (GL) points (numerical integration within GWGF still uses the full 100 GL points)

In [2]:
# !python -m mlgf.workflow.generate_splitxyz --xc pbe0 --gf gw --src ./test_xyz -o ./test_chk -b ccpvdz --gw_nw2 30
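
To make the `--gw_nw2 30` grid more concrete, here is a minimal sketch of a scaled Gauss-Legendre grid on the positive imaginary-frequency axis. The rational map and the scale `x0 = 0.5` follow the convention found in PySCF's `gw_ac` module; whether mlgf builds its grid in exactly this way is an assumption, and `scaled_legendre_grid` is a hypothetical helper name. Under that assumption the first and last points come out near 7.8e-4 and 321, which is consistent with the 30-point frequency tensor printed in the training log later in this notebook.

```python
import numpy as np

def scaled_legendre_grid(n_points, x0=0.5):
    """Map Gauss-Legendre nodes from (-1, 1) onto (0, inf).

    Hypothetical helper; x0 = 0.5 is assumed, not read from mlgf.
    """
    x, w = np.polynomial.legendre.leggauss(n_points)
    freqs = x0 * (1.0 + x) / (1.0 - x)      # mapped nodes
    wts = w * 2.0 * x0 / (1.0 - x) ** 2     # weights times the Jacobian of the map
    return freqs, wts

freqs, wts = scaled_legendre_grid(30)
print(freqs[0], freqs[-1])

# Sanity check: the integral of 1 / (1 + w^2) over (0, inf) is pi / 2
print(np.sum(wts / (1.0 + freqs**2)), np.pi / 2)
```

The quadrature check at the end is a quick way to confirm that nodes and weights belong together before using them for a frequency integration.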

### Inspect a chkfile by loading with Moldatum.load_chk()
Here is a summary of the attributes in the Moldatum object:
1. `dm_gw` - $GW$-level correlated density matrix, MO basis
2. `dm_hf` - mean-field (DFT or Hartree-Fock) density matrix, AO basis
3. `e_tot` - DFT total energy
4. `ef` - DFT fermi level
5.  `fock` - DFT fock matrix, AO basis
6.  `freqs` - iomega points on which sigmaI (i.e. self-energy) is evaluated, typically not used, instead omega_fit is used
7.   `hcore` - DFT core hamiltonian, AO basis
8.   `mo_coeff` - DFT MO coefficients
9.   `mo_energy` - DFT KS orbital energies
10.   `mo_occ` - DFT MO occupation numbers
11.   `nocc` - number of occupied MOs
12.   `omega_ac` - subset of `omega_fit` used for AC fitting to obtain spectrum and QP energies
13.   `omega_fit` - full set of `ef` + i*`freqs` points on which the self-energy is evaluated; used for density matrix integration
14.   `ovlp` - overlap matrix S, AO basis
15.   `sigmaI` - self-energy evaluated on imaginary frequencies, MO basis
16.   `time_gw`, `time_rks` - calculation times for $GW$ and DFT
17.   `vj` - DFT coulomb matrix J, AO basis
18.   `vk` - DFT exchange matrix K, AO basis
19.   `vk_hf` - $V_\text{xc} - V_k$  (equal to $V_k / 2$ for a Hartree-Fock calculation)
20.   `vxc` - $V_\text{xc}$ from DFT
21.   `wts` - Gauss-Legendre weights for density-matrix integration
22.   `xc` - the DFT exchange-correlation functional
23.   `inds_core` - the AO/SAIAO indices corresponding to core orbitals
24.   `C_ao_saiao` - rotation AO to SAIAO
25.   `my_feature_saiao` - features above in the SAIAO equivariant basis
26.   `fname` - chkfile name that was read to create the Moldatum object

In [3]:
from mlgf.data import Moldatum
import os

current_dir = os.getcwd()
chkfile_src_dir = f'{current_dir}/test_chk'

my_chkfile = f'{chkfile_src_dir}/methane.chk'
my_mlf = Moldatum.load_chk(my_chkfile)
print(my_mlf.keys())

# Note that there are some additional feature matrices compared to the raw chkfile loaded without Moldatum.load_chk() because we call get_saiao_features inside Moldatum.load_chk()

dict_keys(['coeff', 'dm_gw', 'dm_hf', 'e_tot', 'ef', 'fock', 'freqs', 'hcore', 'mo_coeff', 'mo_energy', 'mo_occ', 'mol', 'nocc', 'omega_ac', 'omega_fit', 'ovlp', 'sigmaI', 'time_gw', 'time_rks', 'vj', 'vk', 'vk_hf', 'vxc', 'wts', 'xc', 'inds_core', 'C_ao_saiao', 'sigma_saiao', 'dm_saiao', 'fock_saiao', 'hcore_saiao', 'vj_saiao', 'vk_saiao', 'C_saiao_mo', 'vxc_saiao', 'cat_orbtype_principal', 'cat_orbtype_angular', 'cat_orbtype_saiao', 'fock_iao', 'C_ao_iao', 'C_iao_saiao', 'hcore+vj_saiao', 'atomic_charge_saiao', 'boys_saiao', 'fname'])
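
With this many entries, a quick shape summary is often easier to read than the raw key list. The sketch below assumes only that the object supports `.keys()` and `[]` indexing, as the cells in this notebook do with `Moldatum`. `summarize` is a hypothetical helper, and the toy dictionary with its array shapes is made-up stand-in data, not the contents of a real chkfile.

```python
import numpy as np

def summarize(entries):
    """Return {key: short description} for a dict-like object of arrays and scalars."""
    summary = {}
    for key in entries.keys():
        value = entries[key]
        if isinstance(value, np.ndarray):
            summary[key] = f"array shape={value.shape} dtype={value.dtype}"
        else:
            summary[key] = type(value).__name__
    return summary

# Stand-in data; in the notebook one would pass my_mlf instead.
toy = {
    "nocc": 5,
    "mo_energy": np.zeros(34),
    "sigmaI": np.zeros((34, 34, 30), dtype=complex),
}
toy_summary = summarize(toy)
for key, description in toy_summary.items():
    print(f"{key:12s} {description}")
```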


In [4]:
# we can also load the file as a mol and scf with pyscf.lib
from pyscf import lib
my_chkfile = f'{chkfile_src_dir}/methane.chk'
my_scf_data = lib.chkfile.load(my_chkfile, 'scf')
my_mol = lib.chkfile.load_mol(my_chkfile)
print(my_mol)
print(my_scf_data)

### Read and modify MBGF-Net metadata
This code chunk edits the file paths involved in the MBGF-Net workflow. It first loads the metadata in `gnn_config.json`, then modifies the path variables to point to your working directory. Finally, it overwrites `gnn_config.json` with the new file paths for the next two steps.

In [6]:
import json
import os

json_file = 'gnn_config.json'


with open(json_file) as f:
    job = json.load(f)


current_dir = os.getcwd()
chkfile_src_dir = f'{current_dir}/test_chk'

# training data: methane, ammonia, water (not ethane, this is the test case)
my_train_files = [f'{chkfile_src_dir}/methane.chk', f'{chkfile_src_dir}/water.chk', f'{chkfile_src_dir}/ammonia.chk']

my_test_joblib_file = f'{current_dir}/test_gnn_orchestrator.joblib' # pic
torch_data_root = f'{current_dir}/torch_data_root' # where the graph objects are written to
dset_store_dir = f'{current_dir}/dset_tmp_dir' # temporary directory, deleted after prepare_gnn_data.py

job['train_files'] = my_train_files
job['gnn_orch_file_to_copy'] = my_test_joblib_file
job['gnn_orch_file'] = my_test_joblib_file
job['torch_data_root'] = torch_data_root
job['dset_store_dir'] = dset_store_dir


with open(json_file, 'w') as f:
    json.dump(job, f, indent=4)
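
Since the next steps read their inputs from the rewritten config, it can help to confirm that every path under `train_files` actually exists before launching them. The following is a minimal sketch of such a check; `missing_train_files` is a hypothetical helper and not part of mlgf, and the demonstration uses a throwaway config in a temporary directory rather than the real `gnn_config.json`.

```python
import json
import os
import tempfile

def missing_train_files(config_path):
    """Return the entries of job['train_files'] that do not exist on disk."""
    with open(config_path) as f:
        job = json.load(f)
    return [path for path in job.get("train_files", []) if not os.path.isfile(path)]

# Self-contained demonstration: one file is created, one is left missing.
with tempfile.TemporaryDirectory() as tmp:
    present = os.path.join(tmp, "methane.chk")
    absent = os.path.join(tmp, "water.chk")
    open(present, "w").close()
    demo_config = os.path.join(tmp, "demo_config.json")
    with open(demo_config, "w") as f:
        json.dump({"train_files": [present, absent]}, f)
    demo_missing = [os.path.basename(p) for p in missing_train_files(demo_config)]
    print(demo_missing)
```

In the notebook, `missing_train_files(json_file)` returning an empty list would indicate that all three training chkfiles are in place.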

### Create the `GraphOrchestrator` class with mlgf.model.prepare_gnn_data and save to joblib
The above metadata specification dictates the creation of the `GraphOrchestrator` class, which controls many aspects of the GNN workflow, including:
1. Sourcing DFT and $GW$ calculations saved in .chk files (numpy h5 format) as the training data
2. Transformation/standardization of DFT features $X_{ii}$ and $X_{ij}$ 
3. Converting DFT and $GW$ data for each system into a custom graph object that inherits from `torch_geometric.data.Data`. Each graph object holds node features, edge features, node-pair indices for each edge, SAIAO to MO rotation matrices, MO energies, number of occupied orbitals, and indices that map graph attributes (which may have edges and nodes removed) to full self-energy indices (full nmo x nmo x nomega)
4. `torch_geometric.data.Dataset` object that holds indexable Graph objects from file reads
5. Storing training hyperparameters for the loss function (`loss_kwargs`) and the MBGF-Net architecture (`model_kwargs`)
6. Predicting the full self-energy of an unseen DFT calculation from .chk file with MBGF-Net 

Starting from a list of chkfiles with DFT and $GW$ data, the intended workflow relies on a metadata config read as .json:
1. Call mlgf.model.prepare_gnn_data to prepare the graph objects for training the self-energy GNN
2. Call mlgf.model.train_graph_ensemble to train the GNN with `model_kwargs`, `loss_kwargs`, and `train_config`
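
Outside a notebook, the two steps above could be chained from a small driver script. This is only a sketch: the module names and the `--json_spec` flag are taken from the commands used in this notebook, while `workflow_commands` and `run_workflow` are hypothetical helper names introduced here.

```python
import subprocess
import sys

def workflow_commands(json_spec):
    """Build the two command lines in the order they should run."""
    modules = ["mlgf.model.prepare_gnn_data", "mlgf.model.train_graph_ensemble"]
    return [[sys.executable, "-m", module, "--json_spec", json_spec] for module in modules]

def run_workflow(json_spec):
    """Run data preparation, then training; raises if either step fails."""
    for command in workflow_commands(json_spec):
        subprocess.run(command, check=True)

# Print the commands without executing them.
for command in workflow_commands("gnn_config.json"):
    print("python", " ".join(command[1:]))
```

Using `check=True` means a failure in data preparation stops the script before training starts on incomplete graph objects.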


In [7]:
!python -m mlgf.model.prepare_gnn_data --json_spec gnn_config.json

None
gnn_orch_file:  /root/capsule/code/mlgf/examples/test_gnn_orchestrator.joblib
dset_store_dir:  /root/capsule/code/mlgf/examples/dset_tmp_dir
torch_data_root:  /root/capsule/code/mlgf/examples/torch_data_root
-----Checking integrity of data on rank 0-----
Aug 19 14:34:53 ####### 
                ####### libDMET   version 0.5
                ####### A periodic DMET library for lattice model and realistic solid.
                ####### 


#### Inspect The `GraphOrchestrator` object before training the GNN
Load with joblib and inspect its GraphDataset object, which contains Graphs that can be indexed like elements of a list.

In [8]:
import joblib
gnn_orch = joblib.load('test_gnn_orchestrator.joblib') # controls most of the data processing workflow
print(gnn_orch.data) # GraphDataset object is the data attribute (essentially a list of Graph objects)
print(gnn_orch.data[0]) # GraphDataset can be indexed like list to get all the Graph objects

GraphDataset(3)
Graph(x=[33, 46], edge_index=[2, 854], edge_attr=[854, 60], nodes_plus_edges=887, nmo=34, mo_energy=[34], node_indices_nonzero=[33], edge_indices_nonzero=[854, 2], undirected=True, fname='/root/capsule/code/mlgf/examples/test_chk/methane.chk', ef=-0.1548177835534029, nomega=30, iomega=[30], num_elements_inv=[1], n_edges=854, n_nodes=33, homo_ind=4, lumo_ind=5, C_lo_mo=[1156], sigma_ii=[33, 60], sigma_ij=[854, 60])
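
The printed graph has 33 nodes for `nmo=34`, which illustrates that nodes and edges can be dropped, and that index arrays such as `node_indices_nonzero` and `edge_indices_nonzero` are needed to map graph quantities back to the full nmo x nmo x nomega self-energy. A minimal NumPy sketch of that scatter step is shown below. It is an illustration under stated assumptions, not the mlgf implementation: `scatter_to_full` is a hypothetical name, the values are treated as complex arrays (the real/imaginary layout of the 60 stored columns is not assumed), and each edge is mirrored because the graph reports `undirected=True`.

```python
import numpy as np

def scatter_to_full(sigma_ii, sigma_ij, node_idx, edge_idx, n_orb):
    """Place node (diagonal) and edge (off-diagonal) values into a dense array.

    sigma_ii: (n_nodes, nomega) complex values for retained diagonal elements
    sigma_ij: (n_edges, nomega) complex values for retained off-diagonal elements
    node_idx: (n_nodes,) orbital index of each node
    edge_idx: (n_edges, 2) orbital index pair of each edge
    """
    nomega = sigma_ii.shape[1]
    full = np.zeros((n_orb, n_orb, nomega), dtype=complex)
    full[node_idx, node_idx, :] = sigma_ii
    rows, cols = edge_idx[:, 0], edge_idx[:, 1]
    full[rows, cols, :] = sigma_ij
    full[cols, rows, :] = sigma_ij  # assumed symmetric fill for an undirected graph
    return full

# Toy example: 4 orbitals, orbital 2 has no node, 2 retained edges, 2 frequencies.
node_idx = np.array([0, 1, 3])
edge_idx = np.array([[0, 1], [1, 3]])
sigma_ii = np.arange(6).reshape(3, 2) + 0j
sigma_ij = np.array([[10 + 1j, 11 + 1j], [20 + 2j, 21 + 2j]])
full = scatter_to_full(sigma_ii, sigma_ij, node_idx, edge_idx, n_orb=4)
print(full.shape)
```

Elements that have no node or edge simply stay zero in the dense array.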




### Train MBGF-Net
Key json config variables:
1. train_config: the parameters for Adam gradient descent: batch_size, epochs, learning_rate, learning_rate schedule
2. model_kwargs: parameters controlling sizes of GNN modules (scale_y = 0.001 means we train self-energy in units of mH)
3. loss_kwargs: loss parameters


In [8]:
!python -m mlgf.model.train_graph_ensemble --json_spec gnn_config.json

gnn_orch_file: /root/capsule/code/mlgf/examples/test_gnn_orchestrator.joblib
train_config:  {'epochs': 400, 'learning_rate': 0.001, 'weight_decay': 0.0, 'optimizer': 'Adam', 'dropout': 0.0, 'batch_size': 1, 'shuffle': True, 'num_workers': 0, 'cosine_t0': 100, 'cos_range': [100, 300]}
loss_kwargs:  {'frontier_weight': 0.1, 'smoothness_weight': 0.0, 'gradient_weight': 0.1, 'frontier_range': [10, 10], 'always_in_batch': []}
Device used for training:  cpu
Model total number of params:  667212
tensor([7.7784e-04, 4.1166e-03, 1.0198e-02, 1.9157e-02, 3.1191e-02, 4.6578e-02,
        6.5682e-02, 8.8980e-02, 1.1708e-01, 1.5078e-01, 1.9107e-01, 2.3926e-01,
        2.9704e-01, 3.6665e-01, 4.5105e-01, 5.5426e-01, 6.8185e-01, 8.4163e-01,
        1.0449e+00, 1.3084e+00, 1.6581e+00, 2.1352e+00, 2.8096e+00, 3.8062e+00,
        5.3674e+00, 8.0150e+00, 1.3050e+01, 2.4514e+01, 6.0730e+01, 3.2140e+02],
       dtype=torch.float64)
-------Loss function initialization of precision and device-------
attr x_val

time for Adam #67: 0.10s, main loop value (final batch loss): 1.510808e+00, step_size: 1.000000e-03
time for Adam #68: 0.09s, main loop value (final batch loss): 2.308365e+00, step_size: 1.000000e-03
time for Adam #69: 0.09s, main loop value (final batch loss): 1.364740e+00, step_size: 1.000000e-03
time for Adam #70: 0.09s, main loop value (final batch loss): 3.386496e+00, step_size: 1.000000e-03
time for Adam #71: 0.09s, main loop value (final batch loss): 2.170771e+00, step_size: 1.000000e-03
time for Adam #72: 0.09s, main loop value (final batch loss): 2.076168e+00, step_size: 1.000000e-03
time for Adam #73: 0.09s, main loop value (final batch loss): 2.043852e+00, step_size: 1.000000e-03
time for Adam #74: 0.09s, main loop value (final batch loss): 1.184565e+00, step_size: 1.000000e-03
time for Adam #75: 0.09s, main loop value (final batch loss): 2.932502e+00, step_size: 1.000000e-03
time for Adam #76: 0.10s, main loop value (final batch loss): 2.825708e+00, step_size: 1.000000e-03


time for Adam #149: 0.09s, main loop value (final batch loss): 4.618000e-01, step_size: 5.157054e-04
time for Adam #150: 0.10s, main loop value (final batch loss): 4.530957e-01, step_size: 5.000000e-04
time for Adam #151: 0.09s, main loop value (final batch loss): 4.453891e-01, step_size: 4.842946e-04
time for Adam #152: 0.09s, main loop value (final batch loss): 3.209853e-01, step_size: 4.686047e-04
time for Adam #153: 0.09s, main loop value (final batch loss): 3.180005e-01, step_size: 4.529458e-04
time for Adam #154: 0.09s, main loop value (final batch loss): 3.162759e-01, step_size: 4.373334e-04
time for Adam #155: 0.09s, main loop value (final batch loss): 4.201361e-01, step_size: 4.217828e-04
time for Adam #156: 0.09s, main loop value (final batch loss): 4.161770e-01, step_size: 4.063093e-04
time for Adam #157: 0.09s, main loop value (final batch loss): 3.080525e-01, step_size: 3.909284e-04
time for Adam #158: 0.09s, main loop value (final batch loss): 3.077980e-01, step_size: 3.7

time for Adam #231: 0.09s, main loop value (final batch loss): 2.955916e-01, step_size: 7.810417e-04
time for Adam #232: 0.09s, main loop value (final batch loss): 3.041283e-01, step_size: 7.679134e-04
time for Adam #233: 0.09s, main loop value (final batch loss): 2.740042e-01, step_size: 7.545207e-04
time for Adam #234: 0.09s, main loop value (final batch loss): 2.893398e-01, step_size: 7.408768e-04
time for Adam #235: 0.09s, main loop value (final batch loss): 2.824707e-01, step_size: 7.269952e-04
time for Adam #236: 0.10s, main loop value (final batch loss): 2.792503e-01, step_size: 7.128896e-04
time for Adam #237: 0.09s, main loop value (final batch loss): 2.477064e-01, step_size: 6.985739e-04
time for Adam #238: 0.09s, main loop value (final batch loss): 2.432878e-01, step_size: 6.840623e-04
time for Adam #239: 0.09s, main loop value (final batch loss): 2.691729e-01, step_size: 6.693690e-04
time for Adam #240: 0.09s, main loop value (final batch loss): 2.688382e-01, step_size: 6.5

time for Adam #313: 0.10s, main loop value (final batch loss): 2.004197e-01, step_size: 1.000000e-03
time for Adam #314: 0.10s, main loop value (final batch loss): 2.055272e-01, step_size: 1.000000e-03
time for Adam #315: 0.10s, main loop value (final batch loss): 2.021253e-01, step_size: 1.000000e-03
time for Adam #316: 0.10s, main loop value (final batch loss): 2.016004e-01, step_size: 1.000000e-03
time for Adam #317: 0.09s, main loop value (final batch loss): 1.323020e-01, step_size: 1.000000e-03
time for Adam #318: 0.09s, main loop value (final batch loss): 1.350814e-01, step_size: 1.000000e-03
time for Adam #319: 0.09s, main loop value (final batch loss): 1.470975e-01, step_size: 1.000000e-03
time for Adam #320: 0.10s, main loop value (final batch loss): 1.926186e-01, step_size: 1.000000e-03
time for Adam #321: 0.09s, main loop value (final batch loss): 2.199366e-01, step_size: 1.000000e-03
time for Adam #322: 0.09s, main loop value (final batch loss): 1.277759e-01, step_size: 1.0

time for Adam #395: 0.09s, main loop value (final batch loss): 1.343646e-01, step_size: 1.000000e-03
time for Adam #396: 0.09s, main loop value (final batch loss): 1.557217e-01, step_size: 1.000000e-03
time for Adam #397: 0.09s, main loop value (final batch loss): 1.724749e-01, step_size: 1.000000e-03
time for Adam #398: 0.09s, main loop value (final batch loss): 1.051174e-01, step_size: 1.000000e-03
time for Adam #399: 0.09s, main loop value (final batch loss): 1.379246e-01, step_size: 1.000000e-03
Rank 0 finished training!
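
The `step_size` values in the log above are consistent with a constant learning rate outside `cos_range` and cosine annealing with restarts every `cosine_t0` steps inside it. The sketch below reproduces that pattern. The formula is inferred from the logged numbers, not taken from the mlgf source, and `step_size` is a hypothetical function name.

```python
import math

def step_size(step, base_lr=1e-3, t0=100, cos_range=(100, 300)):
    """Cosine annealing with restarts every t0 steps inside cos_range, constant outside.

    Inferred from the training log; the annealing floor is assumed to be zero.
    """
    start, stop = cos_range
    if step < start or step >= stop:
        return base_lr
    phase = ((step - start) % t0) / t0
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * phase))

# Compare against a few logged steps (67, 149, 150, 231, 313).
for step in (67, 149, 150, 231, 313):
    print(step, f"{step_size(step):.6e}")
```

Reading the schedule this way, the first 100 steps act as a constant-rate phase, followed by two annealing cycles, and the final 100 steps return to the base learning rate.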


### Predict $\Sigma(i\omega)$ with MBGF-Net
Using the `gnn_orch` object, take an unseen molecule from a chkfile (ethane) and compute its self-energy in the SAIAO basis.

In [11]:
import joblib
gnn_orch = joblib.load('test_gnn_orchestrator.joblib')

# predict sigma(iw) in the SAIAO basis from a chkfile not seen in the training dataset
example_chkfile = 'test_chk/ethane.chk'
sigma_ml = gnn_orch.predict_full_sigma(example_chkfile) 




### Post-process the predicted $\textit{GW}$ self-energy to get the density matrix, quasiparticle energies (QPE), and photoemission spectrum


In [13]:
# the pdset attribute contains the Moldatum object that was mounted for an MBGF-Net prediction
mlf = gnn_orch.pdset[0]

# extract the SAIAO to MO rotation for sigma(iw)
C_saiao_mo = mlf['C_saiao_mo']

# rotate sigma(iw) to MO basis
from mlgf.lib.helpers import sigma_lo_mo
sigma_ml_mo = sigma_lo_mo(sigma_ml, C_saiao_mo) 

# extract properties from the MBGF and self-energy; takes sigma in the MO basis, which can be the machine-learned or the true self-energy
from mlgf.workflow.get_ml_info import get_properties
import numpy as np
eta = 0.01 # band broadening for DOS
freqs = np.linspace(-1, 1, 201) # real frequency points on which to evaluate the DOS
properties = 'dqm' # a short string denoting which properties to compute from sigma: d - dos, q - qpe, m - density matrix

properties_ml = get_properties(sigma_ml_mo, mlf, freqs, eta, properties = properties)
properties_true = get_properties(mlf['sigmaI'], mlf, freqs, eta, properties = properties)
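
For readers who want to see what a basis rotation of the self-energy looks like, here is a minimal NumPy sketch that rotates each frequency slice from an orthonormal local basis to the MO basis. It is written under assumptions: the index convention `C_lo_mo[p, i]` (local orbital `p`, MO `i`) and the array layout `sigma[p, q, w]` are guesses, and `rotate_sigma` is a hypothetical stand-in, not the actual `sigma_lo_mo` from `mlgf.lib.helpers`.

```python
import numpy as np

def rotate_sigma(sigma_lo, C_lo_mo):
    """Apply sigma_mo[:, :, w] = C^dagger sigma_lo[:, :, w] C for every frequency w."""
    return np.einsum("pi,pqw,qj->ijw", C_lo_mo.conj(), sigma_lo, C_lo_mo)

# Toy check with a random orthogonal rotation and a random complex self-energy.
rng = np.random.default_rng(0)
C, _ = np.linalg.qr(rng.normal(size=(3, 3)))
sigma_lo = rng.normal(size=(3, 3, 2)) + 1j * rng.normal(size=(3, 3, 2))
sigma_mo = rotate_sigma(sigma_lo, C)
sigma_back = rotate_sigma(sigma_mo, C.T)

print(np.allclose(sigma_back, sigma_lo))
```

Because the rotation is unitary, the trace of each frequency slice is unchanged, which gives a cheap consistency check on real data as well.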

### Compare the property predictions derived from MLGF with the true reference(s)
Ethane should have a poor LUMO prediction, because the training data lacks a C-C single bond.

In [16]:
au_to_ev = 27.21
nocc = mlf['nocc']

homo_error = properties_ml['qpe'][nocc-1] - properties_true['qpe'][nocc-1]
print(f'HOMO Error (eV): {homo_error*au_to_ev:0.4f}') 

lumo_error = properties_ml['qpe'][nocc] - properties_true['qpe'][nocc]
print(f'LUMO Error (eV): {lumo_error*au_to_ev:0.4f}') 

dm_mae = np.mean(np.abs(properties_ml['dm'] - properties_true['dm']))
print(f'Density matrix MAE: {dm_mae:0.4f}') 

# dipoles and quadrupoles, note predicted dm is in MO basis
from mlgf.lib.dm_helper import dm_mo_to_ao, get_dipole, scalar_quadrupole
from pyscf import dft, lib
mol = lib.chkfile.load_mol(example_chkfile)
rks = dft.RKS(mol)
rks.xc = mlf['xc'].decode('utf-8')
scf_data = lib.chkfile.load(example_chkfile, 'scf')
rks.__dict__.update(scf_data)
dipole_ml = get_dipole(rks, dm_mo_to_ao(properties_ml['dm'], scf_data['mo_coeff']))
dipole_true = get_dipole(rks, dm_mo_to_ao(properties_true['dm'], scf_data['mo_coeff']))

quadrupole_ml = scalar_quadrupole(mol, dm_mo_to_ao(properties_ml['dm'], scf_data['mo_coeff']))
quadrupole_true = scalar_quadrupole(mol, dm_mo_to_ao(properties_true['dm'], scf_data['mo_coeff']))

print(f'Dipole error (Debye): {(dipole_ml-dipole_true):0.4f}') 
print(f'Quadrupole error (Debye⋅Å): {(quadrupole_ml-quadrupole_true):0.4f}') 


HOMO Error (eV): 0.0977
LUMO Error (eV): 0.5847
Density matrix MAE: 0.0005
Dipole error (Debye): 0.2615
Quadrupole error (Debye⋅Å): -0.1182
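
The dipole and quadrupole evaluations above rely on rotating the density matrix from the MO basis to the AO basis. For real MO coefficients the standard transformation is $D_\text{AO} = C D_\text{MO} C^T$, and a useful sanity check is that $\text{Tr}(D_\text{AO} S)$ returns the electron count. The sketch below demonstrates this on toy data; `dm_mo_to_ao_sketch` is a hypothetical stand-in, and it is an assumption that `dm_mo_to_ao` in `mlgf.lib.dm_helper` follows the same convention.

```python
import numpy as np

def dm_mo_to_ao_sketch(dm_mo, mo_coeff):
    """Standard transformation D_ao = C D_mo C^T for real MO coefficients."""
    return mo_coeff @ dm_mo @ mo_coeff.T

rng = np.random.default_rng(0)
n = 4
a = rng.normal(size=(n, n))
ovlp = a @ a.T + n * np.eye(n)                         # toy symmetric positive-definite overlap
evals, evecs = np.linalg.eigh(ovlp)
mo_coeff = evecs @ np.diag(evals ** -0.5) @ evecs.T    # S^(-1/2), so that C^T S C = 1
dm_mo = np.diag([2.0, 2.0, 0.0, 0.0])                  # two doubly occupied orbitals
dm_ao = dm_mo_to_ao_sketch(dm_mo, mo_coeff)

n_elec = np.trace(dm_ao @ ovlp)                        # should recover Tr(D_mo)
print(n_elec)
```

Applying the same trace check to `properties_ml['dm']` with the `ovlp` and `mo_coeff` stored in the chkfile would show how well the predicted density matrix conserves the electron number.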
