### Generate labeled DFT and $GW$ Data
Running the cell below should generate four chkfiles, one per molecule, in the `./test_chk` directory.

Arguments:
1. `--calc`: type of calculation; `dft+gwac` runs a DFT calculation followed by a G0W0 calculation
2. `--charge`: system charge; the default is 0 (neutral)
3. `--basis`: basis set, e.g. `ccpvdz`
4. `--xyz_file`: input file with the xyz structure of the system
5. `--chk_file`: .chk file to which outputs are written
6. `--json_spec`: JSON file holding a dict of keyword arguments for the calculation, e.g. `xc` is the exchange-correlation functional for DFT
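
The four invocations in the next cell differ only in the molecule name, so they can be generated from a list. The following is a minimal sketch that uses only the flags shown in this tutorial; the helper name `build_generate_cmd` is hypothetical and not part of `mlgf`. It only prints the commands and does not run them.

```python
import shlex

MOLECULES = ["ammonia", "methane", "water", "ethane"]


def build_generate_cmd(name, basis="ccpvdz", calc="dft+gwac",
                       json_spec="generate_config.json"):
    """Return the argument list for one mlgf.workflow.generate run.

    Hypothetical helper: it only assembles the flags used in this tutorial.
    """
    return [
        "python", "-m", "mlgf.workflow.generate",
        "--calc", calc,
        "--xyz_file", f"./test_xyz/{name}.xyz",
        "--chk_file", f"./test_chk/{name}.chk",
        "--basis", basis,
        "--json_spec", json_spec,
    ]


commands = [build_generate_cmd(name) for name in MOLECULES]
for cmd in commands:
    # Print only; pass cmd to subprocess.run(cmd, check=True) to execute it.
    print(shlex.join(cmd))
```

Keeping each command as a list (rather than one long string) avoids shell-quoting problems if a path ever contains spaces.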

In [1]:
!python -m mlgf.workflow.generate --calc dft+gwac --xyz_file ./test_xyz/ammonia.xyz --chk_file ./test_chk/ammonia.chk --basis ccpvdz --json_spec generate_config.json
!python -m mlgf.workflow.generate --calc dft+gwac --xyz_file ./test_xyz/methane.xyz --chk_file ./test_chk/methane.chk --basis ccpvdz --json_spec generate_config.json
!python -m mlgf.workflow.generate --calc dft+gwac --xyz_file ./test_xyz/water.xyz --chk_file ./test_chk/water.chk --basis ccpvdz --json_spec generate_config.json
!python -m mlgf.workflow.generate --calc dft+gwac --xyz_file ./test_xyz/ethane.xyz --chk_file ./test_chk/ethane.chk --basis ccpvdz --json_spec generate_config.json

Found existing SCF calculation from ./test_chk/ammonia.chk
Done loading previous SCF!

******** <class 'fcdmft.gw.mol.gw_ac.GWAC'> ********
method = GWAC
GW nocc = 5, nvir = 24
frozen orbitals = None
off-diagonal self-energy = True
GW density matrix = True
density-fitting for exchange = False
outcore for self-energy= False
broadening parameter = 5.000e-03
grid size for W is 100
grid size for self-energy is 30
analytic continuation method = pade
imaginary frequency cutoff = 5.0
Pade points = 18
Pade step ratio = 0.667
use perturbative linearized QP eqn = False
QPE max iter = 100
QPE tolerance = 1.0e-06

Starting get_sigma_diag main loop with 100 frequency points.
   0    1    2    3    4    5    6    7    8    9   10   11 
  12   13   14   15   16   17   18   19   20   21   22   23 
  24   25   26   27   28   29   30   31   32   33   34   35 
  36   37   38   39   40   41   42   43   44   45   46   47 
  48   49   50   51   52   53   54   55   56   57   58   59 
  60   61   62   63   64

### Inspect a chkfile by loading with Data.load_chk()

In [2]:
from mlgf.data import Data
import os

current_dir = os.getcwd()
chkfile_src_dir = f'{current_dir}/test_chk'

my_chkfile = f'{chkfile_src_dir}/methane.chk'
my_mlf = Data.load_chk(my_chkfile)
print(my_mlf.keys())

# Note that there are some additional feature matrices compared to the raw chkfile loaded without Data.load_chk(), because get_saiao_features is called inside Data.load_chk()

Apr 08 16:43:16 ####### 
                ####### libDMET   version 0.5
                ####### A periodic DMET library for lattice model and realistic solid.
                ####### 
dict_keys(['dm_gw', 'dm_hf', 'e_corr', 'e_hf', 'e_mf', 'e_tot', 'ef', 'fock', 'freqs', 'hcore', 'mo_coeff', 'mo_energy', 'mo_occ', 'nocc', 'omega_fit', 'ovlp', 'sigmaI', 'time_gw', 'vj', 'vk', 'vk_hf', 'vxc', 'wts', 'xc', 'C_ao_saiao', 'sigma_saiao', 'dm_saiao', 'fock_saiao', 'hcore_saiao', 'vj_saiao', 'vk_saiao', 'C_saiao_mo', 'vxc_saiao', 'cat_orbtype_principal', 'cat_orbtype_angular', 'cat_orbtype_saiao', 'fock_iao', 'C_ao_iao', 'C_iao_saiao', 'hcore+vj_saiao', 'inds_core', 'atomic_charge_saiao', 'boys_saiao', 'fname', 'basis'])


In [3]:
# can also load the file as a mol and scf with pyscf.lib
from pyscf import lib
my_chkfile = f'{chkfile_src_dir}/methane.chk'
my_scf_data = lib.chkfile.load(my_chkfile, 'scf')
my_mol = lib.chkfile.load_mol(my_chkfile)


### Read and modify MBGF-Net metadata
This chunk edits the file paths involved in the MBGF-Net workflow. It first loads the metadata in `gnn_config.json`, then modifies the path variables to point to your working directory. Finally, it overwrites `gnn_config.json` with the new file paths for the next two steps.

In [4]:
import json
import os

json_file = 'gnn_config.json'


with open(json_file) as f:
    job = json.load(f)


current_dir = os.getcwd()
chkfile_src_dir = f'{current_dir}/test_chk'
my_train_files = [f'{chkfile_src_dir}/{f}' for f in os.listdir(chkfile_src_dir)]
my_test_joblib_file = f'{current_dir}/test_gnn_orchestrator.joblib'
torch_data_root = f'{current_dir}/torch_data_root'
dset_store_dir = f'{current_dir}/dset_tmp_dir'

job['train_files'] = my_train_files
job['gnn_orch_file_to_copy'] = my_test_joblib_file
job['gnn_orch_file'] = my_test_joblib_file
job['torch_data_root'] = torch_data_root
job['dset_store_dir'] = dset_store_dir


with open(json_file, 'w') as f:
    json.dump(job, f, indent=4)
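
Because `os.listdir` returns every entry in `test_chk`, any stray file or any molecule you intend to hold out for testing also ends up in `train_files`. The sketch below shows one way to restrict the list to `.chk` files and optionally exclude some of them. The helper `list_chkfiles` is hypothetical, and whether to hold out a molecule such as ethane is a choice left to the reader; the demo runs on a throwaway directory.

```python
import os
import tempfile


def list_chkfiles(src_dir, exclude=()):
    """Return sorted paths of *.chk files in src_dir, skipping names in exclude."""
    return sorted(
        os.path.join(src_dir, name)
        for name in os.listdir(src_dir)
        if name.endswith(".chk") and name not in exclude
    )


# Demo on a temporary directory so the sketch runs anywhere.
with tempfile.TemporaryDirectory() as demo_dir:
    for name in ["ammonia.chk", "methane.chk", "water.chk", "ethane.chk", "notes.txt"]:
        open(os.path.join(demo_dir, name), "w").close()
    train_files = list_chkfiles(demo_dir, exclude=("ethane.chk",))
    train_names = [os.path.basename(path) for path in train_files]
    print(train_names)  # ['ammonia.chk', 'methane.chk', 'water.chk']
```

Sorting the result also makes the order of `train_files` reproducible, which `os.listdir` alone does not guarantee.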

### Inspect the `GraphOrchestrator` class that was pickled to joblib
The above metadata specification dictates the creation of the `GraphOrchestrator` class, which controls many aspects of the GNN workflow, including:
1. Sourcing DFT and $GW$ calculations saved in .chk files (HDF5 format holding NumPy arrays) as the training data
2. Transformation/standardization of DFT features $X_{ii}$ and $X_{ij}$
3. Converting DFT and $GW$ data for each system into a custom graph object that inherits from `torch_geometric.data.Data`. Each graph object holds node features, edge features, node-pair indices for each edge, SAIAO-to-MO rotation matrices, MO energies, the number of occupied orbitals, and indices that map graph attributes (which may have edges and nodes removed) to full self-energy indices (full nmo x nmo x nomega)
4. Building a `torch_geometric.data.Dataset` object that holds indexable Graph objects read from file
5. Storing training hyperparameters for the loss function (`loss_kwargs`) and the MBGF-Net architecture (`model_kwargs`)
6. Predicting the full self-energy of an unseen DFT calculation from .chk file with MBGF-Net 
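
Item 3 mentions indices that map node and edge attributes back to the full nmo x nmo x nomega self-energy. The sketch below illustrates that idea on a toy system. It is not the `GraphOrchestrator` implementation: the function name, the use of complex arrays, and the convention that an undirected edge is stored in both directions are all assumptions made for illustration.

```python
import numpy as np


def assemble_full_sigma(sigma_ii, sigma_ij, node_idx, edge_idx, nmo):
    """Scatter node (diagonal) and edge (off-diagonal) values into a dense
    (nmo, nmo, nomega) array; entries without a node or edge stay zero."""
    nomega = sigma_ii.shape[1]
    sigma = np.zeros((nmo, nmo, nomega), dtype=complex)
    sigma[node_idx, node_idx, :] = sigma_ii
    sigma[edge_idx[:, 0], edge_idx[:, 1], :] = sigma_ij
    return sigma


# Toy system: 3 orbitals, orbital 2 has no node, and one undirected edge
# stored as both (0, 1) and (1, 0).
nmo, nomega = 3, 2
node_idx = np.array([0, 1])
edge_idx = np.array([[0, 1], [1, 0]])
sigma_ii = np.array([[1 + 1j, 2 + 2j], [3 + 3j, 4 + 4j]])
sigma_ij = np.array([[0.5j, 0.25j], [0.5j, 0.25j]])

sigma_full = assemble_full_sigma(sigma_ii, sigma_ij, node_idx, edge_idx, nmo)
print(sigma_full.shape)  # (3, 3, 2)
```

The array shapes mirror the `node_indices_nonzero` (one index per node) and `edge_indices_nonzero` (one index pair per edge) attributes of the Graph object.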

Starting from a list of chkfiles with DFT and $GW$ data, the intended workflow relies on a metadata config read as .json:
1. Call `mlgf.model.prepare_gnn_data` to prepare the graph objects for training the self-energy GNN
2. Call `mlgf.model.train_graph_ensemble` to train the GNN with `model_kwargs`, `loss_kwargs`, and `train_config`


In [5]:
!python -m mlgf.model.prepare_gnn_data --json_spec gnn_config.json

saiao
None
gnn_orch_file:  /gpfs/gibbs/project/zhu/scv22/mlgf/examples/workflow_gnn/test_gnn_orchestrator.joblib
dset_store_dir:  /gpfs/gibbs/project/zhu/scv22/mlgf/examples/workflow_gnn/dset_tmp_dir
torch_data_root:  /gpfs/gibbs/project/zhu/scv22/mlgf/examples/workflow_gnn/torch_data_root
-----Checking integrity of data on rank 0-----
Apr 08 16:43:44 ####### 
                ####### libDMET   version 0.5
                ####### A periodic DMET library for lattice model and realistic solid.
                ####### 
9.467319377520766e-05
Tr(dm_gw(iw)) == particle_number test failed for mlf_chkfile: /gpfs/gibbs/project/zhu/scv22/mlgf/examples/workflow_gnn/test_chk/methane.chk
1.95299800473947e-05
Tr(dm_gw(iw)) == particle_number test failed for mlf_chkfile: /gpfs/gibbs/project/zhu/scv22/mlgf/examples/workflow_gnn/test_chk/ammonia.chk
8.821478481024769e-05
Tr(dm_gw(iw)) == particle_number test failed for mlf_chkfile: /gpfs/gibbs/project/zhu/scv22/mlgf/examples/workflow_gnn/test_chk/water.

##### Inspect the `GraphOrchestrator` object

In [6]:
import joblib
gnn_obj = joblib.load('test_gnn_orchestrator.joblib') # controls most of the data processing workflow
print(gnn_obj.data) # GraphDataset object is the data attribute (essentially a list of Graph objects)
print(gnn_obj.data[0]) # GraphDataset can be indexed like list to get all the Graph objects

GraphDataset(4)
Graph(x=[33, 46], edge_index=[2, 856], edge_attr=[856, 60], nodes_plus_edges=889, nmo=34, mo_energy=[34], node_indices_nonzero=[33], edge_indices_nonzero=[856, 2], undirected=True, fname='/gpfs/gibbs/project/zhu/scv22/mlgf/examples/workflow_gnn/test_chk/methane.chk', ef=-0.15481777732521734, nomega=30, iomega=[30], num_elements_inv=[1], n_edges=856, n_nodes=33, homo_ind=4, lumo_ind=5, C_lo_mo=[1156], sigma_ii=[33, 60], sigma_ij=[856, 60])
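
The sizes printed for the methane graph can be cross-checked with a little arithmetic. The numbers in the sketch below are copied from the output above; the interpretations in the comments (a flattened square rotation matrix, two values per frequency point) are assumptions inferred from the sizes, not statements about the `mlgf` internals.

```python
# Numbers copied from the printed Graph for methane.chk.
n_nodes, n_edges, nodes_plus_edges = 33, 856, 889
nmo, nomega = 34, 30
c_lo_mo_len, sigma_width = 1156, 60

assert n_nodes + n_edges == nodes_plus_edges  # 33 + 856 = 889
assert nmo * nmo == c_lo_mo_len               # consistent with a flattened 34 x 34 matrix
assert 2 * nomega == sigma_width              # consistent with two values per frequency
                                              # (assumed: real and imaginary parts)

removed_nodes = nmo - n_nodes
print(f"orbitals without a node: {removed_nodes}")  # orbitals without a node: 1
```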


### Train MBGF-Net
Key json config variables:
1. `train_config`: parameters for Adam gradient descent: `batch_size`, `epochs`, `learning_rate`, and the learning-rate schedule
2. `model_kwargs`: parameters controlling the sizes of the GNN modules (`scale_y = 0.001` means the self-energy is trained in units of millihartree, mHa)
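
To make the `scale_y` remark concrete, here is a minimal sketch of the unit conversion it implies. It assumes the training target is the self-energy in Hartree divided by `scale_y`; how MBGF-Net applies the factor internally is not shown in this tutorial, and the two helper names are hypothetical.

```python
SCALE_Y = 0.001  # value quoted in the text: one scaled unit corresponds to 1 mHa


def to_training_units(sigma_hartree, scale_y=SCALE_Y):
    """Convert a self-energy value from Hartree to the scaled training target."""
    return sigma_hartree / scale_y


def from_training_units(y_scaled, scale_y=SCALE_Y):
    """Convert a scaled model output back to Hartree."""
    return y_scaled * scale_y


y = to_training_units(-0.0125)   # about -12.5 in mHa
sigma_back = from_training_units(y)
print(y, sigma_back)
```

Working in mHa keeps typical target values of order one or larger, which is a common reason to rescale regression targets.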


In [7]:
!python -m mlgf.model.train_graph_ensemble --json_spec gnn_config.json

gnn_orch_file: /gpfs/gibbs/project/zhu/scv22/mlgf/examples/workflow_gnn/test_gnn_orchestrator.joblib
train_config:  {'epochs': 300, 'learning_rate': 0.005, 'weight_decay': 0.0, 'optimizer': 'Adam', 'dropout': 0.0, 'batch_size': 2, 'shuffle': True, 'num_workers': 0, 'cosine_t0': 100, 'cos_range': [100, 10000]}
loss_kwargs:  {'frontier_weight': 0.1, 'smoothness_weight': 0.0, 'gradient_weight': 0.1, 'frontier_range': [10, 10], 'always_in_batch': []}
Device used for training:  cpu
Model total number of params:  667212
tensor([7.7784e-04, 4.1166e-03, 1.0198e-02, 1.9157e-02, 3.1191e-02, 4.6578e-02,
        6.5682e-02, 8.8980e-02, 1.1708e-01, 1.5078e-01, 1.9107e-01, 2.3926e-01,
        2.9704e-01, 3.6665e-01, 4.5105e-01, 5.5426e-01, 6.8185e-01, 8.4163e-01,
        1.0449e+00, 1.3084e+00, 1.6581e+00, 2.1352e+00, 2.8096e+00, 3.8062e+00,
        5.3674e+00, 8.0150e+00, 1.3050e+01, 2.4514e+01, 6.0730e+01, 3.2140e+02],
       dtype=torch.float64)
-------Loss function initialization of precision an

### Predict $\Sigma(i\omega)$ with MBGF-Net
Use the `GraphOrchestrator` object (`gnn_obj` below) as the wrapper that creates Graph objects from unseen chkfiles and converts the predicted torch tensor to a complex NumPy array.

In [8]:
import joblib
gnn_obj_file = 'test_gnn_orchestrator.joblib'
gnn_obj = joblib.load(gnn_obj_file)

# predict sigma(iw) in the SAIAO basis from a chkfile not seen in the training dataset
example_chkfile = 'test_chk/ethane.chk'
sigma_ml = gnn_obj.predict_full_sigma(example_chkfile) 


### Post-process the predicted $GW$ self-energy to get the density matrix, QPEs, photoemission spectrum, and BSE S1


In [9]:
# the pdset attribute contains the Moldatum object that was mounted for an MBGF-Net prediction
mlf = gnn_obj.pdset[0]

# extract the SAIAO to MO rotation for sigma(iw)
C_saiao_mo = mlf['C_saiao_mo']

# rotate sigma(iw) to MO basis
from mlgf.lib.ml_helper import sigma_lo_mo
sigma_ml_mo = sigma_lo_mo(sigma_ml, C_saiao_mo) 

# compute properties from the MBGF and self-energy; get_properties takes sigma in the MO basis, which can be the machine-learned or the true self-energy
from mlgf.workflow.get_ml_info import get_properties
import numpy as np
eta = 0.01 # band broadening for DOS
freqs = np.linspace(-1, 1, 201) # real frequency points on which to evaluate the DOS
properties = 'dqmb' # a short string denoting which properties to compute from sigma: d - dos, q - qpe, m - density matrix, b - bse

# the indices of the sigmaI points used for analytic continuation for QPE
ac_idx = [ 0,  2,  3,  5,  6,  8,  9, 11, 12, 14, 15, 16, 17, 19, 20, 21, 22, 23]
properties_ml = get_properties(sigma_ml_mo, mlf, freqs, eta, properties = properties, ac_idx = ac_idx)
properties_true = get_properties(mlf['sigmaI'], mlf, freqs, eta, properties = properties, ac_idx = ac_idx)
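
To illustrate the roles of `eta` and `freqs`, the sketch below evaluates a Lorentzian-broadened density of states on the same real-frequency grid. This is a simplified stand-in, not what `get_properties` computes: it broadens a list of hypothetical quasiparticle energies directly instead of building the Green's function from the self-energy, and the function name `lorentzian_dos` is an assumption.

```python
import numpy as np


def lorentzian_dos(energies, freqs, eta):
    """Sum of normalized Lorentzians of half-width eta centered on each energy."""
    diff = freqs[:, None] - np.asarray(energies)[None, :]
    return (eta / np.pi / (diff**2 + eta**2)).sum(axis=1)


eta = 0.01                       # broadening, as in the cell above
freqs = np.linspace(-1, 1, 201)  # real-frequency grid, as in the cell above
levels = [-0.5, 0.2]             # hypothetical quasiparticle energies (Hartree)

dos = lorentzian_dos(levels, freqs, eta)
print(dos.shape)  # (201,)
```

A smaller `eta` gives sharper peaks, but the grid spacing (0.01 here) must then shrink as well, or peaks can fall between grid points.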

### Compare the property predictions derived from MLGF with the true reference(s)

In [10]:
au_to_ev = 27.2114  # Hartree to eV
nocc = mlf['nocc']

homo_error = properties_ml['qpe'][nocc-1] - properties_true['qpe'][nocc-1]
print(f'HOMO Error (eV): {homo_error*au_to_ev:0.4f}') 

lumo_error = properties_ml['qpe'][nocc] - properties_true['qpe'][nocc]
print(f'LUMO Error (eV): {lumo_error*au_to_ev:0.4f}') 

dm_mae = np.mean(np.abs(properties_ml['dm'] - properties_true['dm']))
print(f'Density matrix MAE: {dm_mae:0.4f}') 

# dipoles and quadrupoles, note predicted dm is in MO basis
from mlgf.lib.dm_helper import dm_mo_to_ao, get_dipole, scalar_quadrupole
from pyscf import dft, lib
mol = lib.chkfile.load_mol(example_chkfile)
rks = dft.RKS(mol)
rks.xc = mlf['xc'].decode('utf-8')
scf_data = lib.chkfile.load(example_chkfile, 'scf')
rks.__dict__.update(scf_data)
dipole_ml = get_dipole(rks, dm_mo_to_ao(properties_ml['dm'], scf_data['mo_coeff']))
dipole_true = get_dipole(rks, dm_mo_to_ao(properties_true['dm'], scf_data['mo_coeff']))

quadrupole_ml = scalar_quadrupole(mol, dm_mo_to_ao(properties_ml['dm'], scf_data['mo_coeff']))
quadrupole_true = scalar_quadrupole(mol, dm_mo_to_ao(properties_true['dm'], scf_data['mo_coeff']))

print(f'Dipole error (Debye): {(dipole_ml-dipole_true):0.4f}') 
print(f'Quadrupole error (Debye⋅Å): {(quadrupole_ml-quadrupole_true):0.4f}') 

if 'b' in properties:
    s1_ml = properties_ml['bse_exci_s'][0]*au_to_ev
    s1_true = properties_true['bse_exci_s'][0]*au_to_ev
    print(f'BSE S1 error (eV): {(s1_ml-s1_true):0.4f}') 
    

HOMO Error (eV): 0.0048
LUMO Error (eV): 0.0656
Density matrix MAE: 0.0002
Dipole error (Debye): 0.0890
Quadrupole error (Debye⋅Å): -0.0552
BSE S1 error (eV): 0.0565
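
The dipole and quadrupole calculations above first rotate the predicted density matrix from the MO basis to the AO basis. For real orbitals the standard transformation is $P^{AO} = C P^{MO} C^T$, with $C$ the MO coefficient matrix. The sketch below demonstrates it on a toy example; `dm_mo_to_ao_sketch` is a hypothetical name, and the `mlgf` function `dm_mo_to_ao` may differ in its details.

```python
import numpy as np


def dm_mo_to_ao_sketch(dm_mo, mo_coeff):
    """Rotate a density matrix from the MO basis to the AO basis: C @ P_MO @ C^T."""
    return mo_coeff @ dm_mo @ mo_coeff.T


# Toy closed-shell example: 2 orbitals, the lower one doubly occupied.
# Orthonormal AOs are assumed (overlap S = I), so C is a plain rotation.
theta = 0.3
mo_coeff = np.array([[np.cos(theta), -np.sin(theta)],
                     [np.sin(theta),  np.cos(theta)]])
dm_mo = np.diag([2.0, 0.0])

dm_ao = dm_mo_to_ao_sketch(dm_mo, mo_coeff)
n_elec = np.trace(dm_ao)  # with S = I the trace equals the electron count
print(n_elec)
```

With a non-orthogonal AO basis the electron count is `np.trace(dm_ao @ S)` instead, where `S` is the AO overlap matrix.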
