In [1]:
#!/usr/bin/env python

"""Module to train for a folder with formatted dataset."""
import csv
import os
import sys
import time
from jarvis.core.atoms import Atoms
from alignn.data import get_train_val_loaders
from alignn.train import train_dgl
from alignn.config import TrainingConfig
from jarvis.db.jsonutils import loadjson
import argparse
import glob
import torch

# Prefer a CUDA device when one is visible, otherwise fall back to the CPU.
# train_for_folder uses this module-level ``device`` as ``map_location`` and
# ``.to()`` target when restarting from a checkpoint.
device = "cpu"
if torch.cuda.is_available():
    device = torch.device("cuda")


parser = argparse.ArgumentParser(
    description="Atomistic Line Graph Neural Network"
)
parser.add_argument(
    "--root_dir",
    default="./",
    # The training code reads ``<root_dir>/id_prop.csv``.
    help="Folder with id_prop.csv, structure files",
)
parser.add_argument(
    "--config_name",
    default="./config.json",
    help="Name of the config file",
)

parser.add_argument(
    "--file_format", default="poscar", help="poscar/cif/xyz/pdb file format."
)

# ``keep_data_order`` is intentionally not exposed on the command line; the
# value from the training config file is used.

# ``type=`` makes argparse reject malformed numbers at parse time instead of
# failing later inside train_for_folder (which still accepts strings).
parser.add_argument(
    "--classification_threshold",
    type=float,
    default=None,
    help="Floating point threshold for converting into 0/1 class"
    + ", use only for classification tasks",
)

parser.add_argument(
    "--batch_size", type=int, default=None, help="Batch size, generally 64"
)

parser.add_argument(
    "--epochs", type=int, default=None, help="Number of epochs, generally 300"
)

parser.add_argument(
    "--output_dir",
    default="./",
    help="Folder to save outputs",
)

parser.add_argument(
    "--device",
    default=None,
    help="set device for training the model [e.g. cpu, cuda, cuda:2]",
)

parser.add_argument(
    "--restart_model_path",
    default=None,
    help="Checkpoint file path for model",
)


def _load_training_config(config_name):
    """Load ``config_name`` via ``loadjson``; coerce a dict into TrainingConfig.

    Raises:
        Exception: any error raised by ``TrainingConfig`` validation is
            printed and re-raised rather than silently continuing with a
            raw dict (which would later fail on ``config.model``).
    """
    config = loadjson(config_name)
    if isinstance(config, dict):
        try:
            config = TrainingConfig(**config)
        except Exception as exp:
            print("Check", exp)
            raise
    return config


def _load_restart_model(restart_model_path):
    """Rebuild an ALIGNN model from a previous run folder and load weights.

    The folder must contain ``config.json`` (its ``"model"`` entry is passed
    to ``ALIGNNConfig``) and at least one ``*.pt`` file whose ``"model"`` key
    holds a state dict.

    Raises:
        FileNotFoundError: if no ``*.pt`` checkpoint is present.
    """
    print("Restarting model from:", restart_model_path)
    from alignn.models.alignn import ALIGNN, ALIGNNConfig

    rest_config = loadjson(os.path.join(restart_model_path, "config.json"))
    print("rest_config", rest_config)
    model = ALIGNN(ALIGNNConfig(**rest_config["model"]))
    # glob() order is filesystem-dependent; sort so the pick is reproducible
    # (the lexicographically last checkpoint is used).
    checkpoints = sorted(glob.glob(os.path.join(restart_model_path, "*.pt")))
    if not checkpoints:
        raise FileNotFoundError(
            "No *.pt checkpoint found in " + str(restart_model_path)
        )
    checkpoint = checkpoints[-1]
    print("Checkpoint file", checkpoint)
    model.load_state_dict(torch.load(checkpoint, map_location=device)["model"])
    model.to(device)
    return model


def _read_atoms(file_path, file_format):
    """Parse one structure file into a jarvis ``Atoms`` object.

    Raises:
        NotImplementedError: for formats other than poscar/cif/xyz/pdb.
    """
    if file_format == "poscar":
        return Atoms.from_poscar(file_path)
    if file_format == "cif":
        return Atoms.from_cif(file_path)
    if file_format == "xyz":
        # Note using 500 angstrom as box size
        return Atoms.from_xyz(file_path, box_size=500)
    if file_format == "pdb":
        # Note using 500 angstrom as box size
        # Recommended install pytraj
        # conda install -c ambermd pytraj
        return Atoms.from_pdb(file_path, max_lat=500)
    raise NotImplementedError("File format not implemented", file_format)


def _read_id_prop(root_dir, id_prop_dat, file_format):
    """Read ``id_prop.csv`` and the structures it references.

    Each non-blank row is ``file_name, target_1[, target_2, ...]`` with the
    file name relative to ``root_dir``.

    Returns:
        (dataset, target_lengths): ``dataset`` is a list of dicts with keys
        ``"atoms"`` (``Atoms.to_dict()``), ``"jid"`` (file name) and
        ``"target"`` (a float for single-valued rows, else a list of
        floats); ``target_lengths`` holds the number of targets per row.

    Raises:
        ValueError: if a row has no target values.
    """
    # newline="" is what the csv module expects for files it parses.
    with open(id_prop_dat, "r", newline="") as f:
        rows = [row for row in csv.reader(f) if row]  # skip blank lines

    dataset = []
    target_lengths = []
    for row in rows:
        file_name = row[0]
        atoms = _read_atoms(os.path.join(root_dir, file_name), file_format)
        values = [float(v) for v in row[1:]]
        if not values:
            raise ValueError(
                "No target values for " + file_name + " in " + id_prop_dat
            )
        target_lengths.append(len(values))
        dataset.append(
            {
                "atoms": atoms.to_dict(),
                "jid": file_name,
                "target": values[0] if len(values) == 1 else values,
            }
        )
    return dataset, target_lengths


def train_for_folder(
    root_dir="./",
    config_name="config.json",
    classification_threshold=None,
    batch_size=None,
    epochs=None,
    restart_model_path=None,
    file_format="poscar",
    output_dir=None,
):
    """Train a model on the structures listed in ``<root_dir>/id_prop.csv``.

    Args:
        root_dir: folder containing ``id_prop.csv`` and the structure files.
        config_name: path to the JSON training config (loaded via loadjson).
        classification_threshold: if given, overrides the config's threshold
            (converted with ``float``).
        batch_size: if given, overrides the config's batch size (``int``).
        epochs: if given, overrides the config's epoch count (``int``).
        restart_model_path: folder with ``config.json`` and ``*.pt``
            checkpoint(s) to initialise the model from; ``None`` lets
            ``train_dgl`` build a fresh model.
        file_format: one of ``poscar``, ``cif``, ``xyz``, ``pdb``.
        output_dir: if given, overrides the config's output directory.

    Raises:
        ValueError: for multi-output classification, rows with differing
            numbers of targets, or rows without targets.
        NotImplementedError: for an unsupported ``file_format``.
        FileNotFoundError: if ``restart_model_path`` has no ``*.pt`` file.
    """
    id_prop_dat = os.path.join(root_dir, "id_prop.csv")
    config = _load_training_config(config_name)

    if classification_threshold is not None:
        config.classification_threshold = float(classification_threshold)
    if output_dir is not None:
        config.output_dir = output_dir
    if batch_size is not None:
        config.batch_size = int(batch_size)
    if epochs is not None:
        config.epochs = int(epochs)

    if restart_model_path is not None:
        model = _load_restart_model(restart_model_path)
    else:
        model = None

    dataset, target_lengths = _read_id_prop(root_dir, id_prop_dat, file_format)
    multioutput = any(n != 1 for n in target_lengths)
    # Compare raw per-row counts: mixing single- and multi-valued rows used
    # to crash with TypeError (len() of a float) instead of this check.
    lists_length_equal = len(set(target_lengths)) <= 1

    # Check the effective threshold, so one set only in the config file is
    # also caught.
    if multioutput and config.classification_threshold is not None:
        raise ValueError("Classification for multi-output not implemented.")
    if not lists_length_equal:
        # TODO: Pad with NaN
        raise ValueError("Make sure the outputs are of same size.")
    config.model.output_features = target_lengths[0] if multioutput else 1

    (
        train_loader,
        val_loader,
        test_loader,
        prepare_batch,
    ) = get_train_val_loaders(
        dataset_array=dataset,
        target=config.target,
        n_train=config.n_train,
        n_val=config.n_val,
        n_test=config.n_test,
        train_ratio=config.train_ratio,
        val_ratio=config.val_ratio,
        test_ratio=config.test_ratio,
        batch_size=config.batch_size,
        atom_features=config.atom_features,
        neighbor_strategy=config.neighbor_strategy,
        standardize=config.atom_features != "cgcnn",
        id_tag=config.id_tag,
        pin_memory=config.pin_memory,
        workers=config.num_workers,
        save_dataloader=config.save_dataloader,
        use_canonize=config.use_canonize,
        filename=config.filename,
        cutoff=config.cutoff,
        max_neighbors=config.max_neighbors,
        output_features=config.model.output_features,
        classification_threshold=config.classification_threshold,
        target_multiplication_factor=config.target_multiplication_factor,
        standard_scalar_and_pca=config.standard_scalar_and_pca,
        keep_data_order=config.keep_data_order,
        output_dir=config.output_dir,
    )
    t1 = time.time()
    train_dgl(
        config,
        model,
        train_val_test_loaders=[
            train_loader,
            val_loader,
            test_loader,
            prepare_batch,
        ],
    )
    t2 = time.time()
    print("Time taken (s):", t2 - t1)


if __name__ == "__main__":
    args = parser.parse_args(sys.argv[1:])
    if args.device is not None:
        # Rebinds this module's ``device``, which train_for_folder uses when
        # loading a restart checkpoint.
        # NOTE(review): whether train_dgl also honours this is not visible
        # here — confirm.
        device = torch.device(args.device)
    train_for_folder(
        root_dir=args.root_dir,
        config_name=args.config_name,
        classification_threshold=args.classification_threshold,
        output_dir=args.output_dir,
        batch_size=args.batch_size,
        epochs=args.epochs,
        # Previously hard-coded to "poscar", silently ignoring --file_format.
        file_format=args.file_format,
        restart_model_path=args.restart_model_path,
    )

MAX val: 6.796
MIN val: 0.0
MAD: 1.061926993296
Baseline MAE: 1.0578761642499999
data range 6.796 0.0
Converting to graphs!


800it [00:37, 21.38it/s]


df                                                  atoms  \
0    {'lattice_mat': [[2.518121486183809, -5.829129...   
1    {'lattice_mat': [[6.817903676856986, 0.0, 0.0]...   
2    {'lattice_mat': [[6.817904509400231, 0.0, 0.0]...   
3    {'lattice_mat': [[6.8178961386046355, 0.0, 0.0...   
4    {'lattice_mat': [[6.817901950461007, 0.0, 0.0]...   
..                                                 ...   
795  {'lattice_mat': [[6.817897409267758, 0.0, 0.0]...   
796  {'lattice_mat': [[6.817904509400231, 0.0, 0.0]...   
797  {'lattice_mat': [[3.4089525047476097, -6.77023...   
798  {'lattice_mat': [[9.10930382796073, 0.0, 0.0],...   
799  {'lattice_mat': [[6.81790448930608, 0.0, 0.0],...   

                        jid   target  
0    POSCAR-hMOF-52788.vasp  0.00000  
1    POSCAR-hMOF-53420.vasp  3.48533  
2    POSCAR-hMOF-53231.vasp  2.76507  
3    POSCAR-hMOF-53001.vasp  0.00000  
4    POSCAR-hMOF-52851.vasp  1.16529  
..                      ...      ...  
795  POSCAR-hMOF-52689.vasp

100%|██████████| 800/800 [00:01<00:00, 578.13it/s]


data range 5.78814 0.0
Converting to graphs!


100it [00:04, 21.45it/s]


df                                                 atoms                     jid  \
0   {'lattice_mat': [[11.102239927739806, 0.0, 0.0...  POSCAR-hMOF-52757.vasp   
1   {'lattice_mat': [[6.817904509400231, 0.0, 0.0]...  POSCAR-hMOF-53286.vasp   
2   {'lattice_mat': [[6.817898332844221, 0.0, 0.0]...  POSCAR-hMOF-53322.vasp   
3   {'lattice_mat': [[6.817897409267758, 0.0, 0.0]...  POSCAR-hMOF-52676.vasp   
4   {'lattice_mat': [[6.817897998136963, 0.0, 0.0]...  POSCAR-hMOF-53495.vasp   
..                                                ...                     ...   
95  {'lattice_mat': [[6.817903676856986, 0.0, 0.0]...  POSCAR-hMOF-52786.vasp   
96  {'lattice_mat': [[6.817901950461007, 0.0, 0.0]...  POSCAR-hMOF-52861.vasp   
97  {'lattice_mat': [[6.817904509400231, 0.0, 0.0]...  POSCAR-hMOF-53263.vasp   
98  {'lattice_mat': [[6.8178961386046355, 0.0, 0.0...  POSCAR-hMOF-53089.vasp   
99  {'lattice_mat': [[6.817897409267758, 0.0, 0.0]...  POSCAR-hMOF-52695.vasp   

     target  
0   3.1580

100%|██████████| 100/100 [00:00<00:00, 589.81it/s]


data range 6.29285 0.0
Converting to graphs!


100it [00:04, 21.09it/s]


df                                                 atoms                     jid  \
0   {'lattice_mat': [[9.109307733088174, 0.0, 0.0]...  POSCAR-hMOF-53428.vasp   
1   {'lattice_mat': [[6.817904509400231, 0.0, 0.0]...  POSCAR-hMOF-53333.vasp   
2   {'lattice_mat': [[6.817901950461007, 0.0, 0.0]...  POSCAR-hMOF-52797.vasp   
3   {'lattice_mat': [[6.817904509400231, 0.0, 0.0]...  POSCAR-hMOF-53248.vasp   
4   {'lattice_mat': [[6.817904509400231, 0.0, 0.0]...  POSCAR-hMOF-53361.vasp   
..                                                ...                     ...   
95  {'lattice_mat': [[6.8178961386046355, 0.0, 0.0...  POSCAR-hMOF-53010.vasp   
96  {'lattice_mat': [[13.63580022564499, 0.0, 0.0]...  POSCAR-hMOF-53375.vasp   
97  {'lattice_mat': [[6.817897409267758, 0.0, 0.0]...  POSCAR-hMOF-52698.vasp   
98  {'lattice_mat': [[6.817901950461007, 0.0, 0.0]...  POSCAR-hMOF-52867.vasp   
99  {'lattice_mat': [[6.817898507678153, 0.0, 0.0]...  POSCAR-hMOF-52665.vasp   

     target  
0   2.3075

100%|██████████| 100/100 [00:00<00:00, 576.24it/s]


n_train: 800
n_val  : 100
n_test : 100
version='112bbedebdaecf59fb18e11c929080fb2f358246' dataset='user_data' target='target' atom_features='basic' neighbor_strategy='k-nearest' id_tag='jid' random_seed=123 classification_threshold=None n_val=None n_test=None n_train=None train_ratio=0.8 val_ratio=0.1 test_ratio=0.1 target_multiplication_factor=None epochs=3 batch_size=32 weight_decay=1e-05 learning_rate=0.001 filename='sample' warmup_steps=2000 criterion='mse' optimizer='adamw' scheduler='onecycle' pin_memory=False save_dataloader=False write_checkpoint=True write_predictions=True store_outputs=True progress=True log_tensorboard=False standard_scalar_and_pca=False use_canonize=True num_workers=0 cutoff=8.0 cutoff_extra=3.0 max_neighbors=12 keep_data_order=False normalize_graph_level_loss=False distributed=False data_parallel=False n_early_stopping=None output_dir='./' model=ALIGNNConfig(name='alignn', alignn_layers=4, gcn_layers=4, atom_input_features=11, edge_input_features=80, tripl

  from tqdm.autonotebook import tqdm


[1/25]   4%|4          [00:00<?]