diff --git a/MANIFEST.in b/MANIFEST.in index 24c9839..7ab3152 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,2 +1,4 @@ include nablaDFT/links/* include nablaDFT/data/* +include nablaDFT/equiformer_v2/Jd.pt +include nablaDFT/escn/Jd.pt \ No newline at end of file diff --git a/config/datamodule/nablaDFT_ase.yaml b/config/datamodule/nablaDFT_ase.yaml index 7b10ea7..1a426cc 100644 --- a/config/datamodule/nablaDFT_ase.yaml +++ b/config/datamodule/nablaDFT_ase.yaml @@ -1,7 +1,6 @@ # Dataset config for ASE nablaDFT -_target_: nablaDFT.dataset.NablaDFT +_target_: nablaDFT.dataset.ASENablaDFT -type_of_nn: ASE split: ${job_type} datapath: ./datasets/nablaDFT/train/raw dataset_name: ${dataset_name} diff --git a/config/datamodule/nablaDFT_ase_test.yaml b/config/datamodule/nablaDFT_ase_test.yaml index 3907e44..2b11c2a 100644 --- a/config/datamodule/nablaDFT_ase_test.yaml +++ b/config/datamodule/nablaDFT_ase_test.yaml @@ -1,7 +1,6 @@ # Dataset config for ASE nablaDFT -_target_: nablaDFT.dataset.NablaDFT +_target_: nablaDFT.dataset.ASENablaDFT -type_of_nn: ASE split: ${job_type} datapath: ./datasets/nablaDFT/test/raw dataset_name: ${dataset_name} diff --git a/config/datamodule/nablaDFT_hamiltonian.yaml b/config/datamodule/nablaDFT_hamiltonian.yaml new file mode 100644 index 0000000..3f55569 --- /dev/null +++ b/config/datamodule/nablaDFT_hamiltonian.yaml @@ -0,0 +1,13 @@ +# Dataset config for torch geometric nablaDFT +_target_: nablaDFT.dataset.PyGHamiltonianDataModule + +root: ./datasets/nablaDFT/hamiltonian +dataset_name: ${dataset_name} +train_size: 0.9 +val_size: 0.1 +seed: 23 +# Dataloader args +batch_size: 8 +num_workers: 4 +persistent_workers: True +pin_memory: True diff --git a/config/datamodule/nablaDFT_hamiltonian_test.yaml b/config/datamodule/nablaDFT_hamiltonian_test.yaml new file mode 100644 index 0000000..47b4a66 --- /dev/null +++ b/config/datamodule/nablaDFT_hamiltonian_test.yaml @@ -0,0 +1,11 @@ +# Dataset config for torch geometric nablaDFT +_target_: nablaDFT.dataset.PyGHamiltonianDataModule + +root: ./datasets/nablaDFT/hamiltonian +dataset_name: ${dataset_name} +seed: 23 +# Dataloader args +batch_size: 2 +num_workers: 4 +persistent_workers: True +pin_memory: True \ No newline at end of file diff --git a/config/datamodule/nablaDFT_pyg.yaml b/config/datamodule/nablaDFT_pyg.yaml index 893a8ab..5c31eef 100644 --- a/config/datamodule/nablaDFT_pyg.yaml +++ b/config/datamodule/nablaDFT_pyg.yaml @@ -1,12 +1,14 @@ # Dataset config for torch geometric nablaDFT -_target_: nablaDFT.dataset.NablaDFT +_target_: nablaDFT.dataset.PyGNablaDFTDataModule -type_of_nn: PyG -split: ${job_type} root: ./datasets/nablaDFT/train dataset_name: ${dataset_name} train_size: 0.9 val_size: 0.1 +seed: 23 +# Dataloader args batch_size: 32 num_workers: 8 -seed: 23 +persistent_workers: True +pin_memory: True + diff --git a/config/datamodule/nablaDFT_pyg_test.yaml b/config/datamodule/nablaDFT_pyg_test.yaml index 50ba6e8..5c81a7d 100644 --- a/config/datamodule/nablaDFT_pyg_test.yaml +++ b/config/datamodule/nablaDFT_pyg_test.yaml @@ -1,9 +1,10 @@ # Dataset config for torch geometric nablaDFT -_target_: nablaDFT.dataset.NablaDFT +_target_: nablaDFT.dataset.PyGNablaDFTDataModule -type_of_nn: PyG -split: ${job_type} root: ./datasets/nablaDFT/test dataset_name: ${dataset_name} +# Dataloader args batch_size: 32 -num_workers: 12 \ No newline at end of file +num_workers: 12 +persistent_workers: True +pin_memory: True \ No newline at end of file diff --git a/config/dimenetplusplus.yaml b/config/dimenetplusplus.yaml index 
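Reviewer note: the datamodule configs above now point `_target_` directly at concrete classes (ASENablaDFT, PyGNablaDFTDataModule, PyGHamiltonianDataModule). A minimal sketch of how such a config resolves to an object via Hydra's `instantiate`; the keys mirror config/datamodule/nablaDFT_pyg.yaml, and the `${dataset_name}` interpolation is filled in by hand here since the parent experiment config is not loaded.

```python
from hydra.utils import instantiate
from omegaconf import OmegaConf

# Keys mirror config/datamodule/nablaDFT_pyg.yaml; ${dataset_name} is resolved
# manually instead of through the parent experiment config.
cfg = OmegaConf.create(
    {
        "_target_": "nablaDFT.dataset.PyGNablaDFTDataModule",
        "root": "./datasets/nablaDFT/train",
        "dataset_name": "dataset_train_tiny",
        "train_size": 0.9,
        "val_size": 0.1,
        "seed": 23,
        "batch_size": 32,
        "num_workers": 8,
        "persistent_workers": True,
        "pin_memory": True,
    }
)

# instantiate() imports the _target_ class and calls it with the remaining keys.
datamodule = instantiate(cfg)
```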
930cf80..6ad943f 100644 --- a/config/dimenetplusplus.yaml +++ b/config/dimenetplusplus.yaml @@ -1,6 +1,6 @@ # Global variables name: DimeNet++ -dataset_name: dataset_train_2k +dataset_name: dataset_train_tiny max_steps: 1000000 job_type: train pretrained: False diff --git a/config/equiformer_v2_oc20.yaml b/config/equiformer_v2_oc20.yaml index d9712bb..5c7ed16 100644 --- a/config/equiformer_v2_oc20.yaml +++ b/config/equiformer_v2_oc20.yaml @@ -1,6 +1,6 @@ # Global variables -name: Equiformer_v2_OC20 -dataset_name: dataset_train_2k +name: Equiformer_v2 +dataset_name: dataset_train_tiny max_steps: 1000000 warmup_steps: 0 job_type: train diff --git a/config/escn-oc.yaml b/config/escn-oc.yaml index 14c6933..c975d16 100644 --- a/config/escn-oc.yaml +++ b/config/escn-oc.yaml @@ -1,6 +1,6 @@ # Global variables name: ESCN-OC -dataset_name: dataset_train_2k +dataset_name: dataset_train_tiny max_steps: 1000000 warmup_steps: 0 job_type: train diff --git a/config/gemnet-oc.yaml b/config/gemnet-oc.yaml index e4fbf52..0b900f5 100644 --- a/config/gemnet-oc.yaml +++ b/config/gemnet-oc.yaml @@ -1,6 +1,6 @@ # Global variables name: GemNet-OC -dataset_name: dataset_train_2k +dataset_name: dataset_train_tiny max_steps: 1000000 warmup_steps: 0 job_type: train diff --git a/config/gemnet-oc_test.yaml b/config/gemnet-oc_test.yaml index 29ef46a..0cac57e 100644 --- a/config/gemnet-oc_test.yaml +++ b/config/gemnet-oc_test.yaml @@ -1,6 +1,6 @@ # Global variables name: GemNet-OC -dataset_name: dataset_test_conformations_2k +dataset_name: dataset_test_conformations_tiny max_steps: 1000000 warmup_steps: 0 job_type: test diff --git a/config/graphormer3d.yaml b/config/graphormer3d.yaml index 05f6d70..01bba18 100644 --- a/config/graphormer3d.yaml +++ b/config/graphormer3d.yaml @@ -1,6 +1,6 @@ # Global variables -name: Graphormer3D-half -dataset_name: dataset_train_2k +name: Graphormer3D-small +dataset_name: dataset_train_tiny max_steps: 1000000 warmup_steps: 60000 job_type: train diff --git a/config/model/graphormer3d-base.yaml b/config/model/graphormer3d-base.yaml index 40dd41b..dbf9cb9 100644 --- a/config/model/graphormer3d-base.yaml +++ b/config/model/graphormer3d-base.yaml @@ -20,7 +20,7 @@ optimizer: lr: 3e-4 lr_scheduler: - _target_: nablaDFT.graphormer.schedulers.get_linear_schedule_with_warmup + _target_: nablaDFT.schedulers.get_linear_schedule_with_warmup _partial_: true num_warmup_steps: ${warmup_steps} num_training_steps: ${max_steps} diff --git a/config/model/graphormer3d-half.yaml b/config/model/graphormer3d-half.yaml index dbbebaa..a440b66 100644 --- a/config/model/graphormer3d-half.yaml +++ b/config/model/graphormer3d-half.yaml @@ -20,7 +20,7 @@ optimizer: lr: 3e-4 lr_scheduler: - _target_: nablaDFT.graphormer.schedulers.get_linear_schedule_with_warmup + _target_: nablaDFT.schedulers.get_linear_schedule_with_warmup _partial_: true num_warmup_steps: ${warmup_steps} num_training_steps: ${max_steps} diff --git a/config/model/graphormer3d-large.yaml b/config/model/graphormer3d-large.yaml index cc989cc..f7bf26a 100644 --- a/config/model/graphormer3d-large.yaml +++ b/config/model/graphormer3d-large.yaml @@ -22,7 +22,7 @@ optimizer: weight_decay: 1e-3 lr_scheduler: - _target_: nablaDFT.graphormer.schedulers.get_linear_schedule_with_warmup + _target_: nablaDFT.schedulers.get_linear_schedule_with_warmup _partial_: true num_warmup_steps: ${warmup_steps} num_training_steps: ${max_steps} diff --git a/config/model/painn-oc.yaml b/config/model/painn-oc.yaml index 5c21880..c13ff17 100644 --- a/config/model/painn-oc.yaml +++ 
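Reviewer note: the scheduler blocks retargeted above (and the optimizer blocks below) rely on Hydra's `_partial_: true`, which makes `instantiate` return a factory rather than a constructed object. A small sketch of the mechanism, using `torch.optim.AdamW` since it is importable here; the schedulers behave the same way and are called later with the optimizer.

```python
from functools import partial

import torch
from hydra.utils import instantiate
from omegaconf import OmegaConf

# With `_partial_: true`, instantiate() returns a functools.partial so the
# LightningModule can supply the model parameters (or optimizer) afterwards.
cfg = OmegaConf.create({"_target_": "torch.optim.AdamW", "_partial_": True, "lr": 3e-4})
optimizer_factory = instantiate(cfg)
assert isinstance(optimizer_factory, partial)

model = torch.nn.Linear(4, 1)
optimizer = optimizer_factory(model.parameters())  # completed later, in configure_optimizers()
```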
b/config/model/painn-oc.yaml @@ -3,11 +3,11 @@ _target_: nablaDFT.painn_pyg.PaiNNLightning model_name: "PAINN-OC" net: _target_: nablaDFT.painn_pyg.PaiNN - hidden_channels: 512 + hidden_channels: 128 num_layers: 6 - num_rbf: 128 - cutoff: 12.0 - max_neighbors: 50 + num_rbf: 100 + cutoff: 5.0 + max_neighbors: 100 rbf: name: 'gaussian' envelope: @@ -17,22 +17,21 @@ net: direct_forces: false use_pbc: false otf_graph: true - num_elements: 65 + num_elements: 100 optimizer: _target_: torch.optim.AdamW _partial_: true - amsgrad: true - betas: [0.9, 0.95] - lr: 1e-3 + lr: 5e-4 weight_decay: 0 lr_scheduler: _target_: torch.optim.lr_scheduler.ReduceLROnPlateau _partial_: true factor: 0.8 - patience: 10 + patience: 100 + min_lr: 1e-6 losses: energy: @@ -41,12 +40,7 @@ losses: _target_: nablaDFT.gemnet_oc.loss.L2Loss loss_coefs: energy: 1.0 - forces: 100.0 - -ema: - _target_: torch_ema.ExponentialMovingAverage - _partial_: true - decay: 0.9999 + forces: 1.0 metric: _target_: torchmetrics.MultitaskWrapper @@ -55,4 +49,4 @@ metric: energy: _target_: torchmetrics.MeanAbsoluteError forces: - _target_: torchmetrics.MeanAbsoluteError + _target_: torchmetrics.MeanAbsoluteError \ No newline at end of file diff --git a/config/model/qhnet.yaml b/config/model/qhnet.yaml new file mode 100644 index 0000000..0ef7ed4 --- /dev/null +++ b/config/model/qhnet.yaml @@ -0,0 +1,54 @@ +_target_: nablaDFT.qhnet.QHNetLightning + +model_name: "QHNet" +net: + _target_: nablaDFT.qhnet.QHNet + _convert_: partial + sh_lmax: 4 + hidden_size: 128 + bottle_hidden_size: 32 + num_gnn_layers: 5 + max_radius: 12 + num_nodes: 83 + radius_embed_dim: 32 + orbitals: + 1: [0, 0, 1] + 6: [0, 0, 0, 1, 1, 2] + 7: [0, 0, 0, 1, 1, 2] + 8: [0, 0, 0, 1, 1, 2] + 9: [0, 0, 0, 1, 1, 2] + 16: [0, 0, 0, 0, 1, 1, 1, 2] + 17: [0, 0, 0, 0, 1, 1, 1, 2] + 35: [0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2] + +optimizer: + _target_: torch.optim.AdamW + _partial_: true + amsgrad: true + betas: [0.9, 0.95] + lr: 5e-4 + +lr_scheduler: + _target_: torch.optim.lr_scheduler.ReduceLROnPlateau + _partial_: true + factor: 0.8 + patience: 10 + min_lr: 1e-6 + +losses: + hamiltonian: + _target_: nablaDFT.qhnet.loss.HamiltonianLoss +loss_coefs: + hamiltonian: 1.0 + +metric: + _target_: torchmetrics.MultitaskWrapper + _convert_: all + task_metrics: + hamiltonian: + _target_: nablaDFT.qhnet.MaskedMeanAbsoluteError + +ema: + _target_: torch_ema.ExponentialMovingAverage + _partial_: true + decay: 0.9999 \ No newline at end of file diff --git a/config/painn-oc.yaml b/config/painn-oc.yaml index 1d06933..0df1706 100644 --- a/config/painn-oc.yaml +++ b/config/painn-oc.yaml @@ -1,8 +1,8 @@ # Global variables name: PaiNN-OC -dataset_name: dataset_train_2k +dataset_name: dataset_train_tiny max_steps: 1000000 -job_type: test +job_type: train pretrained: False ckpt_path: null # path to checkpoint for training resume or test run diff --git a/config/painn.yaml b/config/painn.yaml index 47ce455..9a10dcb 100644 --- a/config/painn.yaml +++ b/config/painn.yaml @@ -1,6 +1,6 @@ # Global variables name: PaiNN -dataset_name: dataset_train_2k +dataset_name: dataset_train_tiny max_steps: 1000000 job_type: train pretrained: False diff --git a/config/qhnet.yaml b/config/qhnet.yaml new file mode 100644 index 0000000..55a4b78 --- /dev/null +++ b/config/qhnet.yaml @@ -0,0 +1,26 @@ +# Global variables +name: QHNet +dataset_name: dataset_train_tiny +max_steps: 1000000 +warmup_steps: 0 +job_type: train +pretrained: False +ckpt_path: null # path to checkpoint for training resume or test run + +# configs +defaults: + 
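Reviewer note: the qhnet.yaml config above adds an `ema` block (torch_ema.ExponentialMovingAverage, decay 0.9999), and later hunks in this diff write the EMA-averaged weights into checkpoints. A minimal sketch of how torch_ema is typically wired into a loop; the model and data here are toy placeholders, not the QHNet task.

```python
import torch
from torch_ema import ExponentialMovingAverage

model = torch.nn.Linear(8, 1)                      # toy stand-in for the real network
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4)
ema = ExponentialMovingAverage(model.parameters(), decay=0.9999)

for _ in range(10):                                # toy training loop
    x, y = torch.randn(4, 8), torch.randn(4, 1)
    loss = torch.nn.functional.mse_loss(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema.update()                                   # track shadow weights after each step

# Evaluate (or save a checkpoint) with the averaged weights swapped in temporarily.
with ema.average_parameters():
    val_loss = torch.nn.functional.mse_loss(model(torch.randn(4, 8)), torch.randn(4, 1))
```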
- _self_ + - datamodule: nablaDFT_hamiltonian.yaml # dataset config + - model: qhnet.yaml # model config + - callbacks: default.yaml # pl callbacks config + - loggers: wandb.yaml # pl loggers config + - trainer: train.yaml # trainer config + +# need this to set working dir as current dir +hydra: + output_subdir: null + run: + dir: . +original_work_dir: ${hydra:runtime.cwd} + +seed: 23 \ No newline at end of file diff --git a/config/schnet.yaml b/config/schnet.yaml index 50205e0..a28b5a2 100644 --- a/config/schnet.yaml +++ b/config/schnet.yaml @@ -1,6 +1,6 @@ # Global variables name: SchNet -dataset_name: dataset_train_2k +dataset_name: dataset_train_tiny max_steps: 1000000 job_type: train pretrained: False diff --git a/config/schnet_test.yaml b/config/schnet_test.yaml index a804365..79ef52e 100644 --- a/config/schnet_test.yaml +++ b/config/schnet_test.yaml @@ -1,6 +1,6 @@ # Global variables name: SchNet -dataset_name: dataset_test_conformations_2k +dataset_name: dataset_test_conformations_tiny max_steps: 1000000 job_type: test pretrained: False diff --git a/config/trainer/train.yaml b/config/trainer/train.yaml index 093f8f4..2dffd85 100644 --- a/config/trainer/train.yaml +++ b/config/trainer/train.yaml @@ -5,7 +5,8 @@ accelerator: "gpu" devices: [0] strategy: _target_: pytorch_lightning.strategies.ddp.DDPStrategy - + # QHNet has unused params, uncomment line for train + # find_unused_parameters: True max_steps: ${max_steps} # example of additional arguments for trainer diff --git a/nablaDFT/__init__.py b/nablaDFT/__init__.py index bc469f0..c23f60a 100644 --- a/nablaDFT/__init__.py +++ b/nablaDFT/__init__.py @@ -7,3 +7,5 @@ from . import escn from . import ase_model from . import painn_pyg + +from . import schedulers diff --git a/nablaDFT/ase_model/task.py b/nablaDFT/ase_model/task.py index 2e90685..2aecc23 100644 --- a/nablaDFT/ase_model/task.py +++ b/nablaDFT/ase_model/task.py @@ -1,11 +1,9 @@ from typing import Any, Dict, List, Optional, Type -import pytorch_lightning as pl import schnetpack as spk import torch from schnetpack.model.base import AtomisticModel from schnetpack.task import UnsupervisedModelOutput -from torch import nn class AtomisticTaskFixed(spk.task.AtomisticTask): @@ -21,20 +19,18 @@ def __init__( scheduler_monitor: Optional[str] = None, warmup_steps: int = 0, ): - pl.LightningModule.__init__(self) + + super(AtomisticTaskFixed, self).__init__( + model=model, + outputs=outputs, + optimizer_cls=optimizer_cls, + optimizer_args=optimizer_args, + scheduler_cls=scheduler_cls, + scheduler_args=scheduler_args, + scheduler_monitor=scheduler_monitor, + warmup_steps=warmup_steps + ) self.model_name = model_name - self.model = model - self.optimizer_cls = optimizer_cls - self.optimizer_kwargs = optimizer_args - self.scheduler_cls = scheduler_cls - self.scheduler_kwargs = scheduler_args - self.schedule_monitor = scheduler_monitor - self.outputs = nn.ModuleList(outputs) - - self.grad_enabled = len(self.model.required_derivatives) > 0 - self.lr = optimizer_args["lr"] - self.warmup_steps = warmup_steps - self.save_hyperparameters(ignore=["model"]) def predict_step(self, batch, batch_idx): torch.set_grad_enabled(self.grad_enabled) @@ -55,10 +51,11 @@ def predict_step(self, batch, batch_idx): def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: # reshape model.postprocessors (AddOffsets) # otherwise during test error will occur - checkpoint["state_dict"]["model.postprocessors.0.mean"] = checkpoint[ - "state_dict" - ]["model.postprocessors.0.mean"].reshape(1) - + if 
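Reviewer note: the commented `find_unused_parameters` hint added to config/trainer/train.yaml above corresponds to the Trainer construction below. A minimal sketch of the equivalent Python call, not a drop-in replacement for the Hydra-built trainer; DDP errors out when some parameters receive no gradient (as with QHNet) unless this flag is enabled.

```python
import pytorch_lightning as pl
from pytorch_lightning.strategies import DDPStrategy

# Equivalent of config/trainer/train.yaml with the commented flag switched on.
trainer = pl.Trainer(
    accelerator="gpu",
    devices=[0],
    strategy=DDPStrategy(find_unused_parameters=True),
    max_steps=1_000_000,
)
```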
checkpoint["state_dict"].get("model.postprocessors.0.mean", None): + checkpoint["state_dict"]["model.postprocessors.0.mean"] = checkpoint[ + "state_dict" + ]["model.postprocessors.0.mean"].reshape(1) + # override base class method def predict_without_postprocessing(self, batch): pred = self(batch) diff --git a/nablaDFT/dataset/__init__.py b/nablaDFT/dataset/__init__.py index 04fa2e4..28fdb0a 100644 --- a/nablaDFT/dataset/__init__.py +++ b/nablaDFT/dataset/__init__.py @@ -1,2 +1,13 @@ -from .nablaDFT_dataset import NablaDFT -from .hamiltonian_dataset import HamiltonianDataset \ No newline at end of file +from .nablaDFT_dataset import ( + ASENablaDFT, + PyGNablaDFTDataModule, + PyGHamiltonianDataModule, +) # PyTorch Lightning interfaces for datasets +from .hamiltonian_dataset import ( + HamiltonianDataset, + HamiltonianDatabase, +) # database interface for Hamiltonian datasets +from .pyg_datasets import ( + PyGNablaDFT, + PyGHamiltonianNablaDFT, +) # PyTorch Geometric interfaces for datasets diff --git a/nablaDFT/dataset/atoms_datamodule.py b/nablaDFT/dataset/atoms_datamodule.py index b3130a5..88c58ee 100644 --- a/nablaDFT/dataset/atoms_datamodule.py +++ b/nablaDFT/dataset/atoms_datamodule.py @@ -1,8 +1,6 @@ -# Overrided AtomsDataModule from SchNetPack -from copy import copy +"""Overrided AtomsDataModule from SchNetPack""" from typing import List, Dict, Optional, Union -import pytorch_lightning as pl import torch from schnetpack.data import ( @@ -10,14 +8,17 @@ load_dataset, AtomsLoader, SplittingStrategy, - AtomsDataModule + AtomsDataModule, ) class AtomsDataModule(AtomsDataModule): """ - Overrided AtomsDataModule from SchNetPack with additional - methods for prediction task. + Overrided AtomsDataModule from SchNetPack with predict_dataloader + method and overrided setup for prediction task. + + Args: + split (str): string contains type of task/dataset, must be one of ['train', 'test', 'predict'] """ def __init__( @@ -47,44 +48,6 @@ def __init__( splitting: Optional[SplittingStrategy] = None, pin_memory: Optional[bool] = False, ): - """ - Args: - split: string contains type of task/dataset, must be one of ['train', 'test', 'predict'] - datapath: path to dataset - batch_size: (train) batch size - num_train: number of training examples (absolute or relative) - num_val: number of validation examples (absolute or relative) - num_test: number of test examples (absolute or relative) - split_file: path to npz file with data partitions - format: dataset format - load_properties: subset of properties to load - val_batch_size: validation batch size. If None, use test_batch_size, then - batch_size. - test_batch_size: test batch size. If None, use val_batch_size, then - batch_size. - transforms: Preprocessing transform applied to each system separately before - batching. - train_transforms: Overrides transform_fn for training. - val_transforms: Overrides transform_fn for validation. - test_transforms: Overrides transform_fn for testing. - num_workers: Number of data loader workers. - num_val_workers: Number of validation data loader workers - (overrides num_workers). - num_test_workers: Number of test data loader workers - (overrides num_workers). - property_units: Dictionary from property to corresponding unit as a string - (eV, kcal/mol, ...). - distance_unit: Unit of the atom positions and cell as a string - (Ang, Bohr, ...). - data_workdir: Copy data here as part of setup, e.g. to a local file - system for faster performance. 
- cleanup_workdir_stage: Determines after which stage to remove the data - workdir - splitting: Method to generate train/validation/test partitions - (default: RandomSplit) - pin_memory: If true, pin memory of loaded data to GPU. Default: Will be - set to true, when GPUs are used. - """ super().__init__( datapath=datapath, batch_size=batch_size, @@ -108,12 +71,17 @@ def __init__( data_workdir=data_workdir, cleanup_workdir_stage=cleanup_workdir_stage, splitting=splitting, - pin_memory=pin_memory + pin_memory=pin_memory, ) self.split = split self._predict_dataloader = None - + def setup(self, stage: Optional[str] = None): + """Overrided method from original AtomsDataModule class + + Args: + stage (str): trainer stage, must be one of ['fit', 'test', 'predict'] + """ # check whether data needs to be copied if self.data_workdir is None: datapath = self.datapath @@ -143,12 +111,13 @@ def setup(self, stage: Optional[str] = None): self._setup_transforms() def predict_dataloader(self) -> AtomsLoader: + """Describes predict dataloader, used for prediction task""" if self._predict_dataloader is None: self._predict_dataloader = AtomsLoader( self.test_dataset, batch_size=self.test_batch_size, num_workers=self.num_test_workers, pin_memory=self._pin_memory, - shuffle=False + shuffle=False, ) return self._predict_dataloader diff --git a/nablaDFT/dataset/hamiltonian_dataset.py b/nablaDFT/dataset/hamiltonian_dataset.py index 52eafd7..1361ad7 100644 --- a/nablaDFT/dataset/hamiltonian_dataset.py +++ b/nablaDFT/dataset/hamiltonian_dataset.py @@ -22,7 +22,11 @@ class HamiltonianDatabase: C (Norb, Norb) core hamiltonian in atomic units moses_id () (int) molecule id in MOSES dataset conformer_id () (int) conformation id + + Args: + filename (str): path to database. """ + def __init__(self, filename, flags=apsw.SQLITE_OPEN_READONLY): self.db = filename self.connections = {} # allow multiple connections (needed for multi-threading) @@ -38,12 +42,24 @@ def __getitem__(self, idx): data = cursor.execute( """SELECT * FROM data WHERE id IN (""" + str(idx)[1:-1] + ")" ).fetchall() - return [self._unpack_data_tuple(i) for i in data] + ids = cursor.execute( + """SELECT * FROM dataset_ids WHERE id IN (""" + str(idx)[1:-1] + ")" + ).fetchall() + moses_ids, conformer_ids = [i[1] for i in ids], [i[2] for i in ids] + unpacked_data = [ + (*self._unpack_data_tuple(chunk), moses_ids[i], conformer_ids[i]) + for i, chunk in enumerate(data) + ] + return unpacked_data else: data = cursor.execute( """SELECT * FROM data WHERE id=""" + str(idx) ).fetchone() - return self._unpack_data_tuple(data) + ids = cursor.execute( + """SELECT * FROM dataset_ids WHERE id=""" + str(idx) + ).fetchall()[0] + moses_id, conformer_id = ids[1], ids[2] + return (*self._unpack_data_tuple(data), moses_id, conformer_id) def _unpack_data_tuple(self, data): N = ( @@ -244,8 +260,21 @@ def Z(self): return self._deblob(data[2], dtype=np.int32, shape=(N,)) - class HamiltonianDataset(torch.utils.data.Dataset): + """PyTorch interface for nablaDFT Hamiltonian databases. + + Collates hamiltonian, overlap and core hamiltonian matrices + from batch into block diagonal matrix. + + Args: + - filepath (str): path to database. + - max_batch_orbitals (int): maximum number of orbitals in one batch. + - max_batch_atoms (int): maximum number of atoms in one batch.. + - max_squares (int): maximum size of block diagonal matrix in one batch. + - subset() path to saved numpy array with selected indexes. + - dtype (torch.dtype): defines torch.dtype for tensors. 
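Reviewer note: with the `__getitem__` changes above, integer access to HamiltonianDatabase now also yields the MOSES and conformer ids. A usage sketch, assuming the seven-field tuple layout (Z, R, E, F, H, S, C) implied by the class docstring and by PyGHamiltonianNablaDFT.get later in this diff; the .db filename is a placeholder for an already downloaded split.

```python
from nablaDFT.dataset import HamiltonianDatabase

db = HamiltonianDatabase("dataset_train_tiny.db")  # placeholder path to a downloaded split

# Integer index: the unpacked data tuple plus the two ids appended by __getitem__.
Z, R, E, F, H, S, C, moses_id, conformer_id = db[0]
print(Z.shape, H.shape, moses_id, conformer_id)

# A list of indices returns a list of such tuples.
for sample in db[[0, 1, 2]]:
    *_, m_id, c_id = sample
```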
+ """ + def __init__( self, filepath, @@ -302,7 +331,6 @@ def collate_fn(self, batch, return_filtered=False): tuple((int(z), int(l)) for l in self.database.get_orbitals(z)) ) local_orbitals_number += sum(2 * l + 1 for _, l in local_orbitals[-1]) - # print (local_orbitals_number, orbitals_number, len(local_orbitals), len(orbitals)) if ( orbitals_number + local_orbitals_number > self.max_batch_orbitals or len(local_orbitals) + len(orbitals) > self.max_batch_atoms diff --git a/nablaDFT/dataset/nablaDFT_dataset.py b/nablaDFT/dataset/nablaDFT_dataset.py index fcb3b06..c84044d 100644 --- a/nablaDFT/dataset/nablaDFT_dataset.py +++ b/nablaDFT/dataset/nablaDFT_dataset.py @@ -1,27 +1,42 @@ +"""Module defines Pytorch Lightning DataModule interfaces for various NablaDFT datasets""" + import json import os -from typing import Optional, List +from typing import Optional from urllib import request as request import numpy as np import torch from ase.db import connect -from torch.utils.data import Subset -from torch_geometric.data.lightning import LightningDataset -from torch_geometric.data import InMemoryDataset, Data +from pytorch_lightning import LightningDataModule +from torch.utils.data import random_split +from torch_geometric.loader import DataLoader from schnetpack.data import AtomsDataFormat, load_dataset import nablaDFT from .atoms_datamodule import AtomsDataModule - -from .hamiltonian_dataset import HamiltonianDatabase, HamiltonianDataset +from .pyg_datasets import PyGNablaDFT, PyGHamiltonianNablaDFT class ASENablaDFT(AtomsDataModule): + """PytorchLightning interface for nablaDFT ASE datasets. + + Args: + split (str): type of split, must be one of ['train', 'test', 'predict']. + dataset_name (str): split name from links .json or filename of existing file from datapath directory. + datapath (str): path to existing dataset directory or location for download. + train_ratio (float): dataset part used for training. + val_ratio (float): dataset part used for validation. + test_ratio (float): dataset part used for test or prediction. + train_transforms (Callable): data transform, called for every sample in training dataset. + val_transforms (Callable): data transform, called for every sample in validation dataset. + test_transforms (Callable): data transform, called for every sample in test dataset. 
+ """ + def __init__( self, split: str, - dataset_name: str = "dataset_train_2k", + dataset_name: str = "dataset_train_tiny", datapath: str = "database", data_workdir: Optional[str] = "logs", batch_size: int = 2000, @@ -34,6 +49,7 @@ def __init__( format: Optional[AtomsDataFormat] = AtomsDataFormat.ASE, **kwargs, ): + """""" super().__init__( split=split, datapath=datapath, @@ -60,7 +76,7 @@ def prepare_data(self): if self.split == "predict" and not exists: raise FileNotFoundError("Specified dataset not found") elif self.split != "predict" and not exists: - with open(nablaDFT.__path__[0] + "/links/energy_databases_v2.json") as f: + with open(nablaDFT.__path__[0] + "/links/energy_databases.json") as f: data = json.load(f) if self.train_ratio != 0: url = data["train_databases"][self.dataset_name] @@ -91,258 +107,132 @@ def prepare_data(self): self.dataset = load_dataset(self.datapath, self.format) -class HamiltonianNablaDFT(HamiltonianDataset): - def __init__( - self, - datapath="database", - dataset_name="dataset_train_2k", - max_batch_orbitals=1200, - max_batch_atoms=150, - max_squares=4802, - subset=None, - dtype=torch.float32, - ): - self.dtype = dtype - if not os.path.exists(datapath): - os.makedirs(datapath) - with open(nablaDFT.__path__[0] + "/links/hamiltonian_databases.json") as f: - data = json.load(f) - url = data["train_databases"][dataset_name] - filepath = datapath + "/" + dataset_name + ".db" - if not os.path.exists(filepath): - request.urlretrieve(url, filepath) - self.database = HamiltonianDatabase(filepath) - max_orbitals = [] - for z in self.database.Z: - max_orbitals.append( - tuple((int(z), int(l)) for l in self.database.get_orbitals(z)) - ) - max_orbitals = tuple(max_orbitals) - self.max_orbitals = max_orbitals - self.max_batch_orbitals = max_batch_orbitals - self.max_batch_atoms = max_batch_atoms - self.max_squares = max_squares - self.subset = None - if subset: - self.subset = np.load(subset) - +class PyGDataModule(LightningDataModule): + """Parent class which encapsulates PyG dataset to use with Pytorch Lightning Trainer. + In order to add new dataset variant, define children class with setup() method. -class PyGNablaDFT(InMemoryDataset): - """Dataset adapter for ASE2PyG conversion. - Based on https://github.com/atomicarchitects/equiformer/blob/master/datasets/pyg/md17.py + Args: + root (str): path to directory with r'raw/' subfolder with existing dataset or download location. + dataset_name (str): split name from links .json or filename of existing file from datapath directory. + train_size (float): part of dataset used for training, must be in [0, 1]. + val_size (float): part of dataset used for validation, must be in [0, 1]. + seed (int): seed number, used for torch.Generator object during train/val split. + **kwargs: arguments for torch.DataLoader. 
""" - db_suffix = ".db" - - @property - def raw_file_names(self) -> List[str]: - return [(self.dataset_name + self.db_suffix)] - - @property - def processed_file_names(self) -> str: - return f"{self.dataset_name}_{self.split}.pt" - def __init__( self, - datapath: str = "database", - dataset_name: str = "dataset_train_2k", - split: str = "train", - transform=None, - pre_transform=None, - ): + root: str, + dataset_name: str, + train_size: float = 0.9, + val_size: float = 0.1, + seed: int = 23, + **kwargs, + ) -> None: + super().__init__() + self.dataset_train = None + self.dataset_val = None + self.dataset_test = None + self.dataset_predict = None + + self.root = root self.dataset_name = dataset_name - self.datapath = datapath - self.split = split - self.data_all, self.slices_all = [], [] - self.offsets = [0] - super(PyGNablaDFT, self).__init__(datapath, transform, pre_transform) - - for path in self.processed_paths: - data, slices = torch.load(path) - self.data_all.append(data) - self.slices_all.append(slices) - self.offsets.append( - len(slices[list(slices.keys())[0]]) - 1 + self.offsets[-1] - ) - - def len(self) -> int: - return sum( - len(slices[list(slices.keys())[0]]) - 1 for slices in self.slices_all - ) - - def get(self, idx): - data_idx = 0 - while data_idx < len(self.data_all) - 1 and idx >= self.offsets[data_idx + 1]: - data_idx += 1 - self.data = self.data_all[data_idx] - self.slices = self.slices_all[data_idx] - return super(PyGNablaDFT, self).get(idx - self.offsets[data_idx]) - - def download(self) -> None: - with open(nablaDFT.__path__[0] + "/links/energy_databases_v2.json", "r") as f: - data = json.load(f) - url = data[f"{self.split}_databases"][self.dataset_name] - request.urlretrieve(url, self.raw_paths[0]) - - def process(self) -> None: - db = connect(self.raw_paths[0]) - samples = [] - for db_row in db.select(): - z = torch.from_numpy(db_row.numbers).long() - positions = torch.from_numpy(db_row.positions).float() - y = torch.from_numpy(np.array(db_row.data["energy"])).float() - # TODO: temp workaround for dataset w/o forces - forces = db_row.data.get("forces", None) - if forces is not None: - forces = torch.from_numpy(np.array(forces)).float() - samples.append(Data(z=z, pos=positions, y=y, forces=forces)) - - if self.pre_filter is not None: - samples = [data for data in samples if self.pre_filter(data)] - - if self.pre_transform is not None: - samples = [self.pre_transform(data) for data in samples] - - data, slices = self.collate(samples) - torch.save((data, slices), self.processed_paths[0]) - - -# From https://github.com/torchmd/torchmd-net/blob/72cdc6f077b2b880540126085c3ed59ba1b6d7e0/torchmdnet/utils.py#L54 -def train_val_split(dset_len, train_size, val_size, seed, order=None): - assert (train_size is None) + ( - val_size is None - ) <= 1, "Only one of train_size, val_size is allowed to be None." - is_float = ( - isinstance(train_size, float), - isinstance(val_size, float), - ) - - train_size = round(dset_len * train_size) if is_float[0] else train_size - val_size = round(dset_len * val_size) if is_float[1] else val_size - - if train_size is None: - train_size = dset_len - val_size - elif val_size is None: - val_size = dset_len - train_size - - if train_size + val_size > dset_len: - if is_float[1]: - val_size -= 1 - elif is_float[0]: - train_size -= 1 - - assert train_size >= 0 and val_size >= 0, ( - f"One of training ({train_size}), validation ({val_size})" - f" splits ended up with a negative size." 
- ) - - total = train_size + val_size - - assert dset_len >= total, ( - f"The dataset ({dset_len}) is smaller than the " - f"combined split sizes ({total})." - ) - if total < dset_len: - print(f"{dset_len - total} samples were excluded from the dataset") - - idxs = np.arange(dset_len, dtype=int) - if order is None: - idxs = np.random.default_rng(seed).permutation(idxs) - - idx_train = idxs[:train_size] - idx_val = idxs[train_size:total] - - if order is not None: - idx_train = [order[i] for i in idx_train] - idx_val = [order[i] for i in idx_val] - - return np.array(idx_train), np.array(idx_val) - - -# From: https://github.com/torchmd/torchmd-net/blob/72cdc6f077b2b880540126085c3ed59ba1b6d7e0/torchmdnet/utils.py#L112 -def make_splits( - dataset_len, - train_size, - val_size, - seed, - filename=None, # path to save split index - splits=None, - order=None, -): - if splits is not None: - splits = np.load(splits) - idx_train = splits["idx_train"] - idx_val = splits["idx_val"] - else: - idx_train, idx_val = train_val_split( - dataset_len, train_size, val_size, seed, order - ) - - if filename is not None: - np.savez(filename, idx_train=idx_train, idx_val=idx_val) - - return ( - torch.from_numpy(idx_train), - torch.from_numpy(idx_val), - ) + self.seed = seed + self.sizes = [train_size, val_size] + dataloader_keys = [ + "batch_size", "num_workers", + "pin_memory", "persistent_workers" + ] + self.dataloader_kwargs = {} + for key in dataloader_keys: + val = kwargs.get(key, None) + self.dataloader_kwargs[key] = val + if val is not None: + del kwargs[key] + self.kwargs = kwargs + + def dataloader(self, dataset, **kwargs): + return DataLoader(dataset, **kwargs) + + def setup(self, stage: str) -> None: + raise NotImplementedError + + def train_dataloader(self): + return self.dataloader(self.dataset_train, shuffle=True, **self.dataloader_kwargs) + + def val_dataloader(self): + return self.dataloader(self.dataset_val, shuffle=False, **self.dataloader_kwargs) + + def test_dataloader(self): + return self.dataloader(self.dataset_test, shuffle=False, **self.dataloader_kwargs) + + def predict_dataloader(self): + return self.dataloader(self.dataset_predict, shuffle=False, **self.dataloader_kwargs) + + +class PyGHamiltonianDataModule(PyGDataModule): + """DataModule for Hamiltonian nablaDFT dataset, subclass of PyGDataModule. + + Keyword arguments: + hamiltonian (bool): retrieve from database molecule's full hamiltonian matrix. True by default. + include_overlap (bool): retrieve from database molecule's overlab matrix. + include_core (bool): retrieve from databaes molecule's core hamiltonian matrix. + **kwargs: arguments for torch.DataLoader and PyGDataModule instance. See PyGDatamodule docs. 
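Reviewer note: a usage sketch for the Hamiltonian datamodule described above. Based on the constructor and setup() shown in this file, dataloader kwargs are split off by the parent PyGDataModule while the remaining flags are forwarded to PyGHamiltonianNablaDFT; paths and flag values here are illustrative.

```python
from nablaDFT.dataset import PyGHamiltonianDataModule

datamodule = PyGHamiltonianDataModule(
    root="./datasets/nablaDFT/hamiltonian",
    dataset_name="dataset_train_tiny",
    train_size=0.9,
    val_size=0.1,
    seed=23,
    # DataLoader kwargs, collected into dataloader_kwargs by PyGDataModule ...
    batch_size=8,
    num_workers=4,
    persistent_workers=True,
    pin_memory=True,
    # ... while the remaining kwargs reach PyGHamiltonianNablaDFT in setup():
    include_overlap=True,
    include_core=False,
)

datamodule.setup("fit")  # downloads/processes the split on first use
batch = next(iter(datamodule.train_dataloader()))
print(batch.y.shape, len(batch.hamiltonian))
```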
+ """ + def __init__( + self, + root: str, + dataset_name: str, + train_size: float = None, + val_size: float = None, + **kwargs, + ) -> None: + super().__init__(root, dataset_name, train_size, val_size, **kwargs) -def get_PyG_nablaDFT_datasets( - root: str, - split: str, - dataset_name: str, - train_size: float = None, - val_size: float = None, - batch_size: int = None, - num_workers: int = None, - seed: int = None, -): - dataset = PyGNablaDFT(root, dataset_name, split=split) - if split == "train": - idx_train, idx_val = make_splits( - len(dataset), - train_size, - val_size, - seed, - filename=os.path.join(root, "splits.npz"), - splits=None, - ) - train_dataset = Subset(dataset, idx_train) - val_dataset = Subset(dataset, idx_val) - test_dataset = None - pred_dataset = None - else: - train_dataset = None - val_dataset = None - if split == "predict": - pred_dataset = dataset - test_dataset = None - else: - test_dataset = dataset - pred_dataset = None + def setup(self, stage: str) -> None: + if stage == "fit": + dataset = PyGHamiltonianNablaDFT( + self.root, self.dataset_name, "train", **self.kwargs + ) + self.dataset_train, self.dataset_val = random_split( + dataset, self.sizes, generator=torch.Generator().manual_seed(self.seed) + ) + elif stage == "test": + self.dataset_test = PyGHamiltonianNablaDFT( + self.root, self.dataset_name, "test", **self.kwargs + ) + elif stage == "predict": + self.dataset_predict = PyGHamiltonianNablaDFT( + self.root, self.dataset_name, "predict", **self.kwargs + ) - pl_datamodule = LightningDataset( - train_dataset=train_dataset, - val_dataset=val_dataset, - test_dataset=test_dataset, - pred_dataset=pred_dataset, - batch_size=batch_size, - num_workers=num_workers, - pin_memory=True, - persistent_workers=True, - ) - return pl_datamodule +class PyGNablaDFTDataModule(PyGDataModule): + """DataModule for nablaDFT dataset, subclass of PyGDataModule. + See PyGDatamodule doc.""" -class NablaDFT: - def __init__(self, type_of_nn, *args, **kwargs): - valid = {"ASE", "Hamiltonian", "PyG"} - if type_of_nn not in valid: - raise ValueError("results: type of nn must be one of %r." 
% valid) - self.type_of_nn = type_of_nn - if self.type_of_nn == "ASE": - self.dataset = ASENablaDFT(*args, **kwargs) - elif self.type_of_nn == "Hamiltonian": - self.dataset = HamiltonianNablaDFT(*args, **kwargs) - else: - self.dataset = get_PyG_nablaDFT_datasets(*args, **kwargs) + def __init__( + self, + root: str, + dataset_name: str, + train_size: float = None, + val_size: float = None, + **kwargs, + ) -> None: + super().__init__(root, dataset_name, train_size, val_size, **kwargs) + + def setup(self, stage: str) -> None: + if stage == "fit": + dataset = PyGNablaDFT(self.root, self.dataset_name, "train", **self.kwargs) + self.dataset_train, self.dataset_val = random_split( + dataset, self.sizes, generator=torch.Generator().manual_seed(self.seed) + ) + elif stage == "test": + self.dataset_test = PyGNablaDFT( + self.root, self.dataset_name, "test", **self.kwargs + ) + elif stage == "predict": + self.dataset_predict = PyGNablaDFT( + self.root, self.dataset_name, "predict", **self.kwargs + ) diff --git a/nablaDFT/dataset/pyg_datasets.py b/nablaDFT/dataset/pyg_datasets.py new file mode 100644 index 0000000..0932bf6 --- /dev/null +++ b/nablaDFT/dataset/pyg_datasets.py @@ -0,0 +1,235 @@ +"""Module describes PyTorch Geometric interfaces for various NablaDFT datasets""" + +import json +import os +import logging +from typing import List, Callable +from urllib import request as request + +from tqdm import tqdm +import numpy as np +import torch +from ase.db import connect +from torch_geometric.data import InMemoryDataset, Data, Dataset + +import nablaDFT +from .hamiltonian_dataset import HamiltonianDatabase +from nablaDFT.utils import tqdm_download_hook, get_file_size + +logger = logging.getLogger(__name__) + + +class PyGNablaDFT(InMemoryDataset): + """Pytorch Geometric interface for nablaDFT datasets. + Based on https://github.com/atomicarchitects/equiformer/blob/master/datasets/pyg/md17.py + + Args: + datapath (str): path to existing dataset directory or location for download. + dataset_name (str): split name from links .json or filename of existing file from datapath directory. + split (str): type of split, must be one of ['train', 'test', 'predict']. + transform (Callable): callable data transform, called on every access to element. + pre_transform (Callable): callable data transform, called on every access to element. 
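Reviewer note: a minimal direct-access sketch for PyGNablaDFT as defined in this new module; the first call triggers download and processing of the split, and the field names follow the Data objects built in process() below. Paths are illustrative.

```python
from nablaDFT.dataset import PyGNablaDFT

dataset = PyGNablaDFT(
    datapath="./datasets/nablaDFT/train",
    dataset_name="dataset_train_tiny",
    split="train",
)

sample = dataset[0]  # a torch_geometric.data.Data object
print(len(dataset), sample.z.shape, sample.pos.shape, sample.y, sample.forces.shape)
```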
+ """ + + db_suffix = ".db" + + @property + def raw_file_names(self) -> List[str]: + return [(self.dataset_name + self.db_suffix)] + + @property + def processed_file_names(self) -> str: + return f"{self.dataset_name}_{self.split}.pt" + + def __init__( + self, + datapath: str = "database", + dataset_name: str = "dataset_train_tiny", + split: str = "train", + transform: Callable = None, + pre_transform: Callable = None, + ): + self.dataset_name = dataset_name + self.datapath = datapath + self.split = split + self.data_all, self.slices_all = [], [] + self.offsets = [0] + super(PyGNablaDFT, self).__init__(datapath, transform, pre_transform) + + for path in self.processed_paths: + data, slices = torch.load(path) + self.data_all.append(data) + self.slices_all.append(slices) + self.offsets.append( + len(slices[list(slices.keys())[0]]) - 1 + self.offsets[-1] + ) + + def len(self) -> int: + return sum( + len(slices[list(slices.keys())[0]]) - 1 for slices in self.slices_all + ) + + def get(self, idx): + data_idx = 0 + while data_idx < len(self.data_all) - 1 and idx >= self.offsets[data_idx + 1]: + data_idx += 1 + self.data = self.data_all[data_idx] + self.slices = self.slices_all[data_idx] + return super(PyGNablaDFT, self).get(idx - self.offsets[data_idx]) + + def download(self) -> None: + with open(nablaDFT.__path__[0] + "/links/energy_databases.json", "r") as f: + data = json.load(f) + url = data[f"{self.split}_databases"][self.dataset_name] + file_size = get_file_size(url) + with tqdm( + unit="B", + unit_scale=True, + unit_divisor=1024, + miniters=1, + total=file_size, + desc=f"Downloading split: {self.dataset_name}", + ) as t: + request.urlretrieve( + url, self.raw_paths[0], reporthook=tqdm_download_hook(t) + ) + + def process(self) -> None: + db = connect(self.raw_paths[0]) + samples = [] + for db_row in tqdm(db.select(), total=len(db)): + z = torch.from_numpy(db_row.numbers.copy()).long() + positions = torch.from_numpy(db_row.positions.copy()).float() + y = torch.from_numpy(np.array(db_row.data["energy"]).copy()).float() + forces = torch.from_numpy(np.array(db_row.data["forces"]).copy()).float() + samples.append(Data(z=z, pos=positions, y=y, forces=forces)) + + if self.pre_filter is not None: + samples = [data for data in samples if self.pre_filter(data)] + + if self.pre_transform is not None: + samples = [self.pre_transform(data) for data in samples] + + data, slices = self.collate(samples) + torch.save((data, slices), self.processed_paths[0]) + logger.info(f"Saved processed dataset: {self.processed_paths[0]}") + + +class PyGHamiltonianNablaDFT(Dataset): + """Pytorch Geometric interface for nablaDFT Hamiltonian datasets. + + Args: + datapath (str): path to existing dataset directory or location for download. + dataset_name (str): split name from links .json or filename of existing file from datapath directory. + split (str): type of split, must be one of ['train', 'test', 'predict']. + include_hamiltonian (bool): if True, retrieves full Hamiltonian matrices from database. + include_overlap (bool): if True, retrieves overlap matrices from database. + include_core (bool): if True, retrieves core Hamiltonian matrices from database. + dtype (torch.dtype): defines torch.dtype for energy, positions and forces tensors. + transform (Callable): callable data transform, called on every access to element. + pre_transform (Callable): callable data transform, called on every access to element. + + Note: + Hamiltonian matrix for each molecule has different shape. 
PyTorch Geometric tries to concatenate + each torch.Tensor in batch, so in order to make batch from data we leave all hamiltonian matrices + in numpy array form. During train, these matrices will be yield as List[np.array]. + """ + + db_suffix = ".db" + + @property + def raw_file_names(self) -> List[str]: + return [(self.dataset_name + self.db_suffix)] + + @property + def processed_file_names(self) -> str: + return f"{self.dataset_name}_{self.split}.pt" + + def __init__( + self, + datapath: str = "database", + dataset_name: str = "dataset_train_tiny", + split: str = "train", + include_hamiltonian: bool = True, + include_overlap: bool = False, + include_core: bool = False, + dtype=torch.float32, + transform: Callable = None, + pre_transform: Callable = None, + ): + self.dataset_name = dataset_name + self.datapath = datapath + self.split = split + self.data_all, self.slices_all = [], [] + self.offsets = [0] + self.dtype = dtype + self.include_hamiltonian = include_hamiltonian + self.include_overlap = include_overlap + self.include_core = include_core + + super(PyGHamiltonianNablaDFT, self).__init__(datapath, transform, pre_transform) + + self.max_orbitals = self._get_max_orbitals(datapath, dataset_name) + self.db = HamiltonianDatabase(self.raw_paths[0]) + + def len(self) -> int: + return len(self.db) + + def get(self, idx): + data = self.db[idx] + z = torch.tensor(data[0].copy()).long() + positions = torch.tensor(data[1].copy()).to(self.dtype) + # see notes + hamiltonian = data[4].copy() + if self.include_overlap: + overlap = data[5].copy() + else: + overlap = None + if self.include_core: + core = data[6].copy() + else: + core = None + y = torch.from_numpy(data[2].copy()).to(self.dtype) + forces = torch.from_numpy(data[3].copy()).to(self.dtype) + data = Data( + z=z, + pos=positions, + y=y, + forces=forces, + hamiltonian=hamiltonian, + overlap=overlap, + core=core, + ) + if self.pre_transform is not None: + data = self.pre_transform(data) + return data + + def download(self) -> None: + with open(nablaDFT.__path__[0] + "/links/hamiltonian_databases.json") as f: + data = json.load(f) + url = data[f"{self.split}_databases"][self.dataset_name] + file_size = get_file_size(url) + with tqdm( + unit="B", + unit_scale=True, + unit_divisor=1024, + miniters=1, + total=file_size, + desc=f"Downloading split: {self.dataset_name}", + ) as t: + request.urlretrieve( + url, self.raw_paths[0], reporthook=tqdm_download_hook(t) + ) + + def _get_max_orbitals(self, datapath, dataset_name): + db_path = os.path.join(datapath, "raw/" + dataset_name + self.db_suffix) + if not os.path.exists(db_path): + self.download() + database = HamiltonianDatabase(db_path) + max_orbitals = [] + for z in database.Z: + max_orbitals.append( + tuple((int(z), int(l)) for l in database.get_orbitals(z)) + ) + max_orbitals = tuple(max_orbitals) + return max_orbitals diff --git a/nablaDFT/dataset/split.py b/nablaDFT/dataset/split.py index df9e8c4..b049d45 100644 --- a/nablaDFT/dataset/split.py +++ b/nablaDFT/dataset/split.py @@ -40,6 +40,8 @@ def split(self, dataset, *split_sizes): class TestSplit(SplittingStrategy): """Splitting strategy that allows to put all dataset elements in test split without index permutations. + Used for schnetpack datasets to overcome limitation + when train and val split are empty. 
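Reviewer note: the note above can be checked with a toy batch. Assuming PyG 2.x default collation, tensor attributes are concatenated while attributes kept as numpy arrays fall through to the "keep as a list" branch, one entry per molecule.

```python
import numpy as np
import torch
from torch_geometric.data import Batch, Data

# Two "molecules" with different orbital counts, hence different Hamiltonian shapes.
d1 = Data(z=torch.tensor([1, 8]), pos=torch.randn(2, 3), hamiltonian=np.random.rand(6, 6))
d2 = Data(z=torch.tensor([6]), pos=torch.randn(1, 3), hamiltonian=np.random.rand(14, 14))

batch = Batch.from_data_list([d1, d2])
# z and pos are concatenated; the numpy matrices stay as a per-molecule list.
print(type(batch.hamiltonian), [h.shape for h in batch.hamiltonian])
```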
""" def split(self, dataset, *split_sizes): diff --git a/nablaDFT/dimenetplusplus/dimenetplusplus.py b/nablaDFT/dimenetplusplus/dimenetplusplus.py index b6dd06a..6b60464 100644 --- a/nablaDFT/dimenetplusplus/dimenetplusplus.py +++ b/nablaDFT/dimenetplusplus/dimenetplusplus.py @@ -144,13 +144,7 @@ def step( ) -> Union[Tuple[Any, Dict], Any]: predictions_energy, predictions_forces = self.forward(batch) loss_energy = self.loss(predictions_energy, batch.y) - # TODO: temp workaround - if hasattr(batch, "forces"): - loss_forces = self.loss(predictions_forces, batch.forces) - else: - loss_forces = torch.zeros(1).to(self.device) - predictions_forces = torch.zeros(1).to(self.device) - forces = torch.zeros(1).to(self.device) + loss_forces = self.loss(predictions_forces, batch.forces) loss = self.loss_forces_coef * loss_forces + self.loss_energy_coef * loss_energy if calculate_metrics: preds = {"energy": predictions_energy, "forces": predictions_forces} diff --git a/nablaDFT/equiformer_v2/equiformer_v2_oc20.py b/nablaDFT/equiformer_v2/equiformer_v2_oc20.py index 1de12e2..fa24169 100644 --- a/nablaDFT/equiformer_v2/equiformer_v2_oc20.py +++ b/nablaDFT/equiformer_v2/equiformer_v2_oc20.py @@ -698,11 +698,7 @@ def step(self, batch, calculate_metrics: bool = False): y = batch.y # make dense batch from PyG batch energy_out, forces_out = self.net(batch) - # TODO: temp workaround - if hasattr(batch, "forces"): - forces = batch.forces - else: - forces = forces_out.clone() + forces = batch.forces preds = {"energy": energy_out, "forces": forces_out} target = {"energy": y, "forces": forces} loss = self._calculate_loss(preds, target) @@ -754,8 +750,7 @@ def validation_step(self, batch, batch_idx): def test_step(self, batch, batch_idx): bsz = self._get_batch_size(batch) - with self.ema.average_parameters(): - loss, metrics = self.step(batch, calculate_metrics=True) + loss, metrics = self.step(batch, calculate_metrics=True) self.log( "test/loss", loss, @@ -795,7 +790,6 @@ def on_fit_start(self) -> None: self._check_devices() def on_test_start(self) -> None: - self._instantiate_ema() self._check_devices() def on_validation_epoch_end(self) -> None: @@ -803,7 +797,11 @@ def on_validation_epoch_end(self) -> None: def on_test_epoch_end(self) -> None: self._reduce_metrics(step_type="test") - + + def on_save_checkpoint(self, checkpoint) -> None: + with self.ema.average_parameters(): + checkpoint['state_dict'] = self.state_dict() + def _calculate_loss(self, y_pred, y_true) -> float: total_loss = 0.0 for name, loss in self.hparams.losses.items(): @@ -837,12 +835,11 @@ def _reduce_metrics(self, step_type: str = "train"): def _check_devices(self): self.hparams.metric = self.hparams.metric.to(self.device) - if self.ema is not None: - self.ema.to(self.device) def _instantiate_ema(self): if self.ema is not None: self.ema = self.ema(self.parameters()) + self.ema.to(self.device) def _get_batch_size(self, batch): """Function for batch size infer.""" diff --git a/nablaDFT/escn/escn.py b/nablaDFT/escn/escn.py index 72af98a..529e196 100644 --- a/nablaDFT/escn/escn.py +++ b/nablaDFT/escn/escn.py @@ -1074,11 +1074,7 @@ def step(self, batch, calculate_metrics: bool = False): y = batch.y # make dense batch from PyG batch energy_out, forces_out = self.net(batch) - # TODO: temp workaround - if hasattr(batch, "forces"): - forces = batch.forces - else: - forces = forces_out.clone() + forces = batch.forces preds = {"energy": energy_out, "forces": forces_out} target = {"energy": y, "forces": forces} loss = self._calculate_loss(preds, target) @@ 
-1130,8 +1126,7 @@ def validation_step(self, batch, batch_idx): def test_step(self, batch, batch_idx): bsz = self._get_batch_size(batch) - with self.ema.average_parameters(): - loss, metrics = self.step(batch, calculate_metrics=True) + loss, metrics = self.step(batch, calculate_metrics=True) self.log( "test/loss", loss, @@ -1171,7 +1166,6 @@ def on_fit_start(self) -> None: self._check_devices() def on_test_start(self) -> None: - self._instantiate_ema() self._check_devices() def on_validation_epoch_end(self) -> None: @@ -1179,6 +1173,10 @@ def on_validation_epoch_end(self) -> None: def on_test_epoch_end(self) -> None: self._reduce_metrics(step_type="test") + + def on_save_checkpoint(self, checkpoint) -> None: + with self.ema.average_parameters(): + checkpoint['state_dict'] = self.state_dict() def _calculate_loss(self, y_pred, y_true) -> float: total_loss = 0.0 @@ -1213,12 +1211,11 @@ def _reduce_metrics(self, step_type: str = "train"): def _check_devices(self): self.hparams.metric = self.hparams.metric.to(self.device) - if self.ema is not None: - self.ema.to(self.device) def _instantiate_ema(self): if self.ema is not None: self.ema = self.ema(self.parameters()) + self.ema.to(self.device) def _get_batch_size(self, batch): """Function for batch size infer.""" diff --git a/nablaDFT/gemnet_oc/gemnet_oc.py b/nablaDFT/gemnet_oc/gemnet_oc.py index aaf050f..be99651 100644 --- a/nablaDFT/gemnet_oc/gemnet_oc.py +++ b/nablaDFT/gemnet_oc/gemnet_oc.py @@ -1431,11 +1431,7 @@ def forward(self, data: Data): def step(self, batch, calculate_metrics: bool = False): energy_out, forces_out = self.net(batch) - # TODO: temp workaround - if hasattr(batch, "forces"): - forces = batch.forces - else: - forces = forces_out.clone() + forces = batch.forces preds = {"energy": energy_out, "forces": forces_out} target = {"energy": batch.y, "forces": forces} loss = self._calculate_loss(preds, target) @@ -1487,8 +1483,7 @@ def validation_step(self, batch, batch_idx): def test_step(self, batch, batch_idx): bsz = self._get_batch_size(batch) - with self.ema.average_parameters(): - loss, metrics = self.step(batch, calculate_metrics=True) + loss, metrics = self.step(batch, calculate_metrics=True) self.log( "test/loss", loss, @@ -1528,7 +1523,6 @@ def on_fit_start(self) -> None: self._check_devices() def on_test_start(self) -> None: - self._instantiate_ema() self._check_devices() def on_validation_epoch_end(self) -> None: @@ -1537,6 +1531,10 @@ def on_validation_epoch_end(self) -> None: def on_test_epoch_end(self) -> None: self._reduce_metrics(step_type="test") + def on_save_checkpoint(self, checkpoint) -> None: + with self.ema.average_parameters(): + checkpoint['state_dict'] = self.state_dict() + def _calculate_loss(self, y_pred, y_true) -> float: total_loss = 0.0 for name, loss in self.hparams.losses.items(): @@ -1569,13 +1567,12 @@ def _reduce_metrics(self, step_type: str = "train"): self.hparams.metric.reset() def _check_devices(self): - self.hparams.metric = self.hparams.metric.to(self.device) - if self.ema is not None: - self.ema.to(self.device) + self.hparams.metric = self.hparams.metric.to(self.device) def _instantiate_ema(self): if self.ema is not None: self.ema = self.ema(self.parameters()) + self.ema.to(self.device) def _get_batch_size(self, batch): """Function for batch size infer.""" diff --git a/nablaDFT/graphormer/__init__.py b/nablaDFT/graphormer/__init__.py index 751acc5..1e95b58 100644 --- a/nablaDFT/graphormer/__init__.py +++ b/nablaDFT/graphormer/__init__.py @@ -1,2 +1 @@ from .graphormer_3d import 
Graphormer3DLightning, Graphormer3D -from .schedulers import get_linear_schedule_with_warmup diff --git a/nablaDFT/graphormer/graphormer_3d.py b/nablaDFT/graphormer/graphormer_3d.py index 1e970c7..cf7c4f3 100644 --- a/nablaDFT/graphormer/graphormer_3d.py +++ b/nablaDFT/graphormer/graphormer_3d.py @@ -12,8 +12,8 @@ @torch.jit.script -def softmax_dropout(input, dropout_prob: float, is_training: bool): - return F.dropout(F.softmax(input, -1), dropout_prob, is_training) +def softmax_dropout(x, dropout_prob: float, is_training: bool): + return F.dropout(F.softmax(x, -1), dropout_prob, is_training) class SelfMultiheadAttention(nn.Module): @@ -380,17 +380,9 @@ def step( y = batch.y energy_out, forces_out, mask_out = self(batch) loss_energy = self.loss(energy_out, y) - # TODO: temp workaround for datasets w/o forces - if hasattr(batch, "forces"): - forces, mask_forces = to_dense_batch( - batch.forces, batch.batch, batch_size=bsz - ) - masked_forces_out = forces_out * mask_forces.unsqueeze(-1) - loss_forces = self.loss(masked_forces_out, forces) - else: - loss_forces = torch.zeros(1).to(self.device) - masked_forces_out = torch.zeros(1).to(self.device) - forces = torch.zeros(1).to(self.device) + forces, mask_forces = to_dense_batch(batch.forces, batch.batch, batch_size=bsz) + masked_forces_out = forces_out * mask_forces.unsqueeze(-1) + loss_forces = self.loss(masked_forces_out, forces) loss = self.loss_forces_coef * loss_forces + self.loss_energy_coef * loss_energy if calculate_metrics: preds = {"energy": energy_out, "forces": masked_forces_out} @@ -456,7 +448,9 @@ def test_step(self, batch, batch_idx): return loss def predict_step(self, data, **kwargs): - energy_out, forces_out, _ = self(data) + """Note: predictions output consistent with PyG networks""" + energy_out, forces_out, mask = self(data) + forces_out = torch.masked_select(forces_out, mask).reshape(-1, 3) return energy_out, forces_out def configure_optimizers(self): @@ -515,50 +509,3 @@ def _get_batch_size(self, batch): """Function for batch size infer.""" bsz = batch.batch.max().detach().item() + 1 # get batch size return bsz - - -from functools import partial -from torch.optim.lr_scheduler import LambdaLR - - -def get_linear_schedule_with_warmup( - optimizer, num_warmup_steps, num_training_steps, last_epoch=-1 -): - # from https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/optimization.py#L104 - """ - Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after - a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer. - - Args: - optimizer ([`~torch.optim.Optimizer`]): - The optimizer for which to schedule the learning rate. - num_warmup_steps (`int`): - The number of steps for the warmup phase. - num_training_steps (`int`): - The total number of training steps. - last_epoch (`int`, *optional*, defaults to -1): - The index of the last epoch when resuming training. - - Return: - `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 
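Reviewer note: the reworked Graphormer3D predict_step above masks the dense force predictions back to the flat per-atom layout used by the PyG models. The dense/flat round-trip it relies on looks like this (toy data, not the model's actual outputs).

```python
import torch
from torch_geometric.utils import to_dense_batch

# Two molecules with 3 and 2 atoms, flattened PyG-style.
forces = torch.randn(5, 3)
batch_index = torch.tensor([0, 0, 0, 1, 1])

dense, mask = to_dense_batch(forces, batch_index)  # dense: (2, 3, 3), mask: (2, 3)
flat_again = dense[mask]                           # drop the padded rows

assert torch.equal(flat_again, forces)
```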
- """ - - lr_lambda = partial( - _get_linear_schedule_with_warmup_lr_lambda, - num_warmup_steps=num_warmup_steps, - num_training_steps=num_training_steps, - ) - return LambdaLR(optimizer, lr_lambda, last_epoch) - - -def _get_linear_schedule_with_warmup_lr_lambda( - current_step: int, *, num_warmup_steps: int, num_training_steps: int -): - # from https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/optimization.py#L98 - if current_step < num_warmup_steps: - return float(current_step) / float(max(1, num_warmup_steps)) - return max( - 0.0, - float(num_training_steps - current_step) - / float(max(1, num_training_steps - num_warmup_steps)), - ) diff --git a/nablaDFT/graphormer/schedulers.py b/nablaDFT/graphormer/schedulers.py deleted file mode 100644 index d941dc8..0000000 --- a/nablaDFT/graphormer/schedulers.py +++ /dev/null @@ -1,45 +0,0 @@ -from functools import partial -from torch.optim.lr_scheduler import LambdaLR - - -def get_linear_schedule_with_warmup( - optimizer, num_warmup_steps, num_training_steps, last_epoch=-1 -): - # from https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/optimization.py#L104 - """ - Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after - a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer. - - Args: - optimizer ([`~torch.optim.Optimizer`]): - The optimizer for which to schedule the learning rate. - num_warmup_steps (`int`): - The number of steps for the warmup phase. - num_training_steps (`int`): - The total number of training steps. - last_epoch (`int`, *optional*, defaults to -1): - The index of the last epoch when resuming training. - - Return: - `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 
- """ - - lr_lambda = partial( - _get_linear_schedule_with_warmup_lr_lambda, - num_warmup_steps=num_warmup_steps, - num_training_steps=num_training_steps, - ) - return LambdaLR(optimizer, lr_lambda, last_epoch) - - -def _get_linear_schedule_with_warmup_lr_lambda( - current_step: int, *, num_warmup_steps: int, num_training_steps: int -): - # from https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/optimization.py#L98 - if current_step < num_warmup_steps: - return float(current_step) / float(max(1, num_warmup_steps)) - return max( - 0.0, - float(num_training_steps - current_step) - / float(max(1, num_training_steps - num_warmup_steps)), - ) diff --git a/nablaDFT/links/energy_databases.json b/nablaDFT/links/energy_databases.json index f3791c5..0070cdd 100644 --- a/nablaDFT/links/energy_databases.json +++ b/nablaDFT/links/energy_databases.json @@ -1,16 +1,17 @@ { "train_databases": { - "dataset_train_100k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/train_100k_energy.db", - "dataset_train_10k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/train_10k_energy.db", - "dataset_train_5k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/train_5k_energy.db", - "dataset_train_2k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/train_2k_energy.db" + "dataset_train_full": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/energy_databases/train_full_v2_formation_energy.db", + "dataset_train_large": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/energy_databases/train_100k_v2_formation_energy_w_forces.db", + "dataset_train_medium": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/energy_databases/train_10k_v2_formation_energy_w_forces.db", + "dataset_train_small": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/energy_databases/train_5k_v2_formation_energy_w_forces.db", + "dataset_train_tiny": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/energy_databases/train_2k_v2_formation_energy_w_forces.db" }, "test_databases": { - "dataset_test_structures": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/test_random.db", - "dataset_test_scaffolds": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/test_scaffolds.db", - "dataset_test_conformations_100k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/test_100k_conformers.db", - "dataset_test_conformations_10k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/test_10k_conformers.db.db", - "dataset_test_conformations_5k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/test_5k_conformers.db", - "dataset_test_conformations_2k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/test_2k_conformers.db" + "dataset_test_structures": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/energy_databases/test_full_structures_v2_formation_energy_forces.db", + "dataset_test_scaffolds": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/energy_databases/test_full_scaffolds_v2_formation_energy_forces.db", + "dataset_test_conformations_large": 
"https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/energy_databases/test_100k_conformers_v2_formation_energy_w_forces.db", + "dataset_test_conformations_medium": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/energy_databases/test_10k_conformers_v2_formation_energy_w_forces.db", + "dataset_test_conformations_small": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/energy_databases/test_5k_conformers_v2_formation_energy_w_forces.db", + "dataset_test_conformations_tiny": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/energy_databases/test_2k_conformers_v2_formation_energy_w_forces.db" } -} \ No newline at end of file +} diff --git a/nablaDFT/links/energy_databases_v2.json b/nablaDFT/links/energy_databases_v2.json deleted file mode 100644 index c94874c..0000000 --- a/nablaDFT/links/energy_databases_v2.json +++ /dev/null @@ -1,18 +0,0 @@ -{ - "train_databases": { - "dataset_full": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/train_full_v2_formation_energy.db", - "dataset_train_100k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/train100k_v2_formation_energy_w_forces_wo_outliers.db", - "dataset_train_10k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/train10k_v2_formation_energy_w_forces.db", - "dataset_train_5k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/train5k_v2_formation_energy_w_forces.db", - "dataset_train_2k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/train2k_v2_formation_energy_w_forces.db" - }, - "test_databases": { - "dataset_test_structures": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/test_full_structures_v2_formation_energy.db", - "dataset_test_scaffolds": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/test_full_scaffolds_v2_formation_energy.db", - "dataset_test_conformations_full": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/test_full_conformers_v2_formation_energy.db", - "dataset_test_conformations_100k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/test_100k_conformers_v2_formation_energy_w_forces.db", - "dataset_test_conformations_10k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/test_10k_conformers_v2_formation_energy_w_forces.db", - "dataset_test_conformations_5k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/test_5k_conformers_v2_formation_energy_w_forces.db", - "dataset_test_conformations_2k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/test_2k_conformers_v2_formation_energy_w_forces.db" - } -} diff --git a/nablaDFT/links/hamiltonian_databases.json b/nablaDFT/links/hamiltonian_databases.json index bc6849a..2646558 100644 --- a/nablaDFT/links/hamiltonian_databases.json +++ b/nablaDFT/links/hamiltonian_databases.json @@ -1,16 +1,16 @@ { "train_databases": { - "dataset_train_100k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/moses_db/dataset_train_100k.db", - "dataset_train_10k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/moses_db/dataset_train_10k.db", - "dataset_train_5k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/moses_db/dataset_train_5k.db", - "dataset_train_2k": 
"https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/moses_db/dataset_train_2k.db" + "dataset_train_large": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/hamiltonian_databases/train_100k.db", + "dataset_train_medium": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/hamiltonian_databases/train_10k.db", + "dataset_train_small": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/hamiltonian_databases/train_5k.db", + "dataset_train_tiny": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/hamiltonian_databases/train_2k.db" }, "test_databases": { - "dataset_test_structures": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/moses_db/dataset_test_random.db", - "dataset_test_scaffolds": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/moses_db/dataset_test_scaffolds.db", - "dataset_test_conformations_100k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/moses_db/dataset_test_100k_conformers.db", - "dataset_test_conformations_10k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/moses_db/dataset_test_10k_conformers.db", - "dataset_test_conformations_5k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/moses_db/dataset_test_5k_conformers.db", - "dataset_test_conformations_2k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/moses_db/dataset_test_2k_conformers.db" + "dataset_test_structures": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/hamiltonian_databases/test_structures.db", + "dataset_test_scaffolds": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/hamiltonian_databases/test_scaffolds.db", + "dataset_test_conformations_large": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/hamiltonian_databases/test_100k_conformers.db", + "dataset_test_conformations_medium": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/hamiltonian_databases/test_10k_conformers.db", + "dataset_test_conformations_small": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/hamiltonian_databases/test_5k_conformers.db", + "dataset_test_conformations_tiny": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/hamiltonian_databases/test_2k_conformers.db" } } \ No newline at end of file diff --git a/nablaDFT/links/models_checkpoints.json b/nablaDFT/links/models_checkpoints.json index 628dc3f..44c33d1 100644 --- a/nablaDFT/links/models_checkpoints.json +++ b/nablaDFT/links/models_checkpoints.json @@ -1,53 +1,65 @@ { "SchNet": { - "dataset_train_100k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/SchNet/schnet_100k.ckpt", - "dataset_train_10k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/SchNet/schnet_10k.ckpt", - "dataset_train_5k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/SchNet/schnet_5k.ckpt", - "dataset_train_2k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/SchNet/schnet_2k.ckpt" + "dataset_train_large": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/SchNet/schnet_100k.ckpt", + "dataset_train_medium": 
"https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/SchNet/schnet_10k.ckpt", + "dataset_train_small": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/SchNet/schnet_5k.ckpt", + "dataset_train_tiny": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/SchNet/schnet_2k.ckpt" }, "PaiNN": { - "dataset_train_100k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/PaiNN/painn_100k.ckpt", - "dataset_train_10k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/models_checkpoints/painn/painn_10k.ckpt", - "dataset_train_5k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/models_checkpoints/painn/painn_5k_split.pt", - "dataset_train_2k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/models_checkpoints/painn/painn_2k_split.pt" + "dataset_train_large": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/PaiNN/painn_100k.ckpt", + "dataset_train_medium": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/PaiNN/painn_10k.ckpt", + "dataset_train_small": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/PaiNN/painn_5k.ckpt", + "dataset_train_tiny": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/PaiNN/painn_2k.ckpt" }, "DimeNet++": { - "dataset_train_100k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/DimeNet++/DimeNet++_dataset_train_100k_epoch=0258.ckpt", - "dataset_train_10k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/nablaDFTv2/models_checkpoints/DimeNet++/DimeNet++_dataset_train_10k_epoch=0651.ckpt", - "dataset_train_5k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/nablaDFTv2/models_checkpoints/DimeNet++/DimeNet++_dataset_train_5k_epoch=0545.ckpt", - "dataset_train_2k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/nablaDFTv2/models_checkpoints/DimeNet++/DimeNet++_dataset_train_2k_epoch=0442.ckpt" + "dataset_train_large": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/DimeNet%2b%2b/DimeNet%2b%2b_dataset_train_100k_epoch=0258.ckpt", + "dataset_train_medium": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/DimeNet%2b%2b/DimeNet%2b%2b_dataset_train_10k_epoch=0651.ckpt", + "dataset_train_small": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/DimeNet%2b%2b/DimeNet%2b%2b_dataset_train_5k_epoch=0442.ckpt", + "dataset_train_tiny": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/DimeNet%2b%2b/DimeNet%2b%2b_dataset_train_2k_epoch=0545.ckpt" }, "PhiSNet": { - "dataset_train_100k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/phisnet/model_100k_split.pt", - "dataset_train_10k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/phisnet/model_10k_split.pt", - "dataset_train_5k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/phisnet/model_5k_split.pt", - 
"dataset_train_2k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/phisnet/model_2k_split.pt" + "dataset_train_large": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/phisnet/model_100k_split.pt", + "dataset_train_medium": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/phisnet/model_10k_split.pt", + "dataset_train_small": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/phisnet/model_5k_split.pt", + "dataset_train_tiny": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/phisnet/model_2k_split.pt" }, "SchNOrb": { - "dataset_train_100k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/schnorb/model_100k_split.pt", - "dataset_train_10k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/schnorb/model_10k_split.pt", - "dataset_train_5k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/schnorb/model_5k_split.pt", - "dataset_train_2k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/schnorb/model_2k_split.pt" + "dataset_train_large": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/schnorb/model_100k_split.pt", + "dataset_train_medium": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/schnorb/model_10k_split.pt", + "dataset_train_small": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/schnorb/model_5k_split.pt", + "dataset_train_tiny": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/schnorb/model_2k_split.pt" }, "GemNet-OC": { - "dataset_train_100k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/GemNet-OC/GemNet-OC_dataset_train_100k_epoch=085_val_loss=0.028551.ckpt", - "dataset_train_10k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/GemNet-OC/GemNet-OC_dataset_train_10k_epoch=337_val_loss=0.060103.ckpt", - "dataset_train_5k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/GemNet-OC/GemNet-OC_dataset_train_5k_epoch=427_val_loss=0.078626.ckpt", - "dataset_train_2k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/GemNet-OC/GemNet-OC_dataset_train_2k_epoch=400_val_loss=0.142715.ckpt" + "dataset_train_large": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/GemNet-OC/GemNet-OC_dataset_train_100k_epoch=085_val_loss=0.028551.ckpt", + "dataset_train_medium": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/GemNet-OC/GemNet-OC_dataset_train_10k_epoch=337_val_loss=0.060103.ckpt", + "dataset_train_small": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/GemNet-OC/GemNet-OC_dataset_train_5k_epoch=427_val_loss=0.078626.ckpt", + "dataset_train_tiny": 
"https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/GemNet-OC/GemNet-OC_dataset_train_2k_epoch=400_val_loss=0.142715.ckpt" }, "ESCN-OC": { - "dataset_train_100k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/ESCN-OC/ESCN-OC_dataset_train_100k_epoch=028_val_loss=0.029561.ckpt", - "dataset_train_10k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/ESCN-OC/ESCN-OC_dataset_train_10k_epoch=189_val_loss=0.052877.ckpt", - "dataset_train_5k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/ESCN-OC/ESCN-OC_dataset_train_5k_epoch=175_val_loss=0.065441.ckpt", - "dataset_train_2k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/ESCN-OC/ESCN-OC_dataset_train_2k_epoch=253_val_loss=0.124682.ckpt" + "dataset_train_large": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/ESCN-OC/ESCN-OC_dataset_train_100k_epoch=028_val_loss=0.029561.ckpt", + "dataset_train_medium": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/ESCN-OC/ESCN-OC_dataset_train_10k_epoch=189_val_loss=0.052877.ckpt", + "dataset_train_small": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/ESCN-OC/ESCN-OC_dataset_train_5k_epoch=175_val_loss=0.065441.ckpt", + "dataset_train_tiny": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/ESCN-OC/ESCN-OC_dataset_train_2k_epoch=253_val_loss=0.124682.ckpt" }, "Equiformer_v2": { - "dataset_train_100k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/Equiformer_OC20/Equiformer_OC20_dataset_train_100k_epoch=010_val_loss=0.302613.ckpt", - "dataset_train_10k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/Equiformer_OC20/Equiformer_OC20_dataset_train_10k_epoch=107_val_loss=0.337899.ckpt", - "dataset_train_5k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/Equiformer_OC20/Equiformer_OC20_5k.ckptt", - "dataset_train_2k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/Equiformer_OC20/Equiformer_OC20_dataset_train_2k_epoch=409_val_loss=0.354093.ckpt" + "dataset_train_large": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/Equiformer_OC20/Equiformer_OC20_dataset_train_100k_epoch=010_val_loss=0.302613.ckpt", + "dataset_train_medium": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/Equiformer_OC20/Equiformer_OC20_dataset_train_10k_epoch=107_val_loss=0.337899.ckpt", + "dataset_train_small": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/Equiformer_OC20/Equiformer_OC20_5k.ckpt", + "dataset_train_tiny": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/Equiformer_OC20/Equiformer_OC20_dataset_train_2k_epoch=409_val_loss=0.354093.ckpt" + }, + "Graphormer3D-small": { + "dataset_train_large": 
"https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/graphormer/small/Graphormer3D-small_dataset_train_100k_epoch_420_val_loss_0.005773.ckpt", + "dataset_train_medium": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/graphormer/small/Graphormer3D-small_dataset_train_10k_epoch_1095_val_loss_0.010159.ckpt", + "dataset_train_small": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/graphormer/small/Graphormer3D-small_dataset_train_5k_epoch_1331_val_loss_0.012179.ckpt", + "dataset_train_tiny": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/graphormer/small/Graphormer3D-small_dataset_train_2k_epoch_672_val_loss_0.018476.ckpt" + }, + "QHNet": { + "dataset_train_large": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/QHNet/QHNet-100k_dataset_train_100k_epoch=017_val_loss=0.002419.ckpt", + "dataset_train_medium": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/QHNet/QHNet-10k_dataset_train_10k_epoch=034_val_loss=0.001594.ckpt", + "dataset_train_small": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/QHNet/QHNet-5k_dataset_train_5k_epoch=067_val_loss=0.002631.ckpt", + "dataset_train_tiny": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/QHNet/QHNet-2k_dataset_train_2k_epoch=055_val_loss=0.002044.ckpt" } } \ No newline at end of file diff --git a/nablaDFT/optimization/calculator.py b/nablaDFT/optimization/calculator.py index 9a50685..dc64051 100644 --- a/nablaDFT/optimization/calculator.py +++ b/nablaDFT/optimization/calculator.py @@ -13,6 +13,14 @@ class BatchwiseCalculator: """ Base class calculator for neural network models for batchwise optimization. + Args: + model (nn.Module): loaded trained model. 
+ device (str): device used for calculations (default="cpu") + energy_key (str): name of energies in model (default="energy") + force_key (str): name of forces in model (default="forces") + energy_unit (str): energy units used by model (default="eV") + position_unit (str): position units used by model (default="Angstrom") + dtype (torch.dtype): required data type for the model input (default: torch.float32) """ def __init__( @@ -25,34 +33,6 @@ def __init__( position_unit: str = "Ang", dtype: torch.dtype = torch.float32, ): - """ - model: - path to trained model or trained model - - atoms_converter: - Class used to convert ase Atoms objects to schnetpack input - - device: - device used for calculations (default="cpu") - - auxiliary_output_modules: - auxiliary module to manipulate output properties (e.g., prior energy or forces) - - energy_key: - name of energies in model (default="energy") - - force_key: - name of forces in model (default="forces") - - energy_unit: - energy units used by model (default="eV") - - position_unit: - position units used by model (default="Angstrom") - - dtype: - required data type for the model input (default: torch.float32) - """ self.results = None self.atoms = None @@ -93,11 +73,10 @@ def _requires_calculation(self, property_keys: List[str], atoms: List[ase.Atoms] def get_forces( self, atoms: List[ase.Atoms], fixed_atoms_mask: Optional[List[int]] = None ) -> np.array: - """ - atoms: - - fixed_atoms_mask: - list of indices corresponding to atoms with positions fixed in space. + """Return atom's forces. + Args: + atoms (List[ase.Atoms]): list of ase.Atoms objects. + fixed_atoms_mask (optional, List[int]): list of indices corresponding to atoms with positions fixed in space. """ if self._requires_calculation( property_keys=[self.energy_key, self.force_key], atoms=atoms @@ -118,8 +97,11 @@ def calculate(self, atoms: List[ase.Atoms]) -> None: class PyGBatchwiseCalculator(BatchwiseCalculator): - """Batchwise calculator for PyTorch Geometric models for batchwise optimization""" - + """Batchwise calculator for PyTorch Geometric models for batchwise optimization + Args: + model (nn.Module): loaded PyG model. + """ + def __init__( self, model: nn.Module, @@ -157,7 +139,12 @@ def calculate(self, atoms: List[ase.Atoms]) -> None: class SpkBatchwiseCalculator(BatchwiseCalculator): - """Batchwise calculator for SchNetPack models for batchwise optimization""" + """Batchwise calculator for SchNetPack models for batchwise optimization. + + Args: + model (nn.Module): loaded train schnetpack model. + atoms_converter (AtomsConverter): Class used to convert ase Atoms objects to schnetpack input. + """ def __init__( self, diff --git a/nablaDFT/optimization/opt_utils.py b/nablaDFT/optimization/opt_utils.py index 8fc07ce..7eaa54f 100644 --- a/nablaDFT/optimization/opt_utils.py +++ b/nablaDFT/optimization/opt_utils.py @@ -10,6 +10,11 @@ def np_scatter_add(updates, indices, shape): def atoms_list_to_PYG(ase_atoms_list, device): + """Function to convert ase.Atoms object to PyG data batches. + Args: + ase_atoms_list (List[ase.Atoms]): list of ase.Atoms object to convert. + device (str): task device. 
+ """ data = [] for ase_atoms in ase_atoms_list: z = torch.from_numpy(ase_atoms.numbers).long() diff --git a/nablaDFT/optimization/optimizers.py b/nablaDFT/optimization/optimizers.py index cddc75f..f6577f7 100644 --- a/nablaDFT/optimization/optimizers.py +++ b/nablaDFT/optimization/optimizers.py @@ -305,7 +305,6 @@ class ASEBatchwiseLBFGS(BatchwiseOptimizer): extension/adaptation of the ase.optimize.LBFGS optimizer particularly designed for batch-wise relaxation of atomic structures. The inverse Hessian is approximated for each sample separately, which allows for optimizing batches of different structures/compositions. - """ atoms = None @@ -326,7 +325,7 @@ def __init__( fixed_atoms_mask: Optional[List[int]] = None, verbose: bool = False, ): - """Parameters: + """Args: calculator: This calculator provides properties such as forces and energy, which can be used for MD simulations or diff --git a/nablaDFT/optimization/pyg_ase_interface.py b/nablaDFT/optimization/pyg_ase_interface.py index c3c437e..8cde418 100644 --- a/nablaDFT/optimization/pyg_ase_interface.py +++ b/nablaDFT/optimization/pyg_ase_interface.py @@ -40,6 +40,17 @@ class PYGCalculator(Calculator): """ ASE calculator for pytorch geometric machine learning models. + Args: + model_file (str): path to trained model + energy_key (str): name of energies in model (default="energy") + force_key (str): name of forces in model (default="forces") + energy_unit (str, float): energy units used by model (default="kcal/mol") + position_unit (str, float): position units used by model (default="Angstrom") + device (torch.device): device used for calculations (default="cpu") + dtype (torch.dtype): select model precision (default=float32) + converter (callable): converter used to set up input batches + additional_inputs (dict): additional inputs required for some transforms in the converter. + **kwargs: Additional arguments for basic ASE calculator class. """ energy = "energy" @@ -58,19 +69,6 @@ def __init__( dtype: torch.dtype = torch.float32, **kwargs, ): - """ - Args: - model_file (str): path to trained model - energy_key (str): name of energies in model (default="energy") - force_key (str): name of forces in model (default="forces") - energy_unit (str, float): energy units used by model (default="kcal/mol") - position_unit (str, float): position units used by model (default="Angstrom") - device (torch.device): device used for calculations (default="cpu") - dtype (torch.dtype): select model precision (default=float32) - converter (callable): converter used to set up input batches - additional_inputs (dict): additional inputs required for some transforms in the converter. - **kwargs: Additional arguments for basic ase calculator class - """ Calculator.__init__(self, **kwargs) self.energy_key = energy_key @@ -124,8 +122,8 @@ def calculate( """ Args: atoms (ase.Atoms): ASE atoms object. - properties (list of str): select properties computed and stored to results. - system_changes (list of str): List of changes for ASE. + properties (List[str]): select properties computed and stored to results. + system_changes (List[str]): List of changes for ASE. 
""" # First call original calculator to set atoms attribute # (see https://wiki.fysik.dtu.dk/ase/_modules/ase/calculators/calculator.html#Calculator) @@ -151,6 +149,19 @@ def calculate( class PYGAseInterface: """ Interface for ASE calculations (optimization and molecular dynamics) + + Args: + molecule_path (str): Path to initial geometry + working_dir (str): Path to directory where files should be stored + model_file (str): path to trained model + energy_key (str): name of energies in model (default="energy") + force_key (str): name of forces in model (default="forces") + energy_unit (str, float): energy units used by model (default="kcal/mol") + position_unit (str, float): position units used by model (default="Angstrom") + device (torch.device): device used for calculations (default="cpu") + dtype (torch.dtype): select model precision (default=float32) + optimizer_class (ase.optimize.optimizer): ASE optimizer used for structure relaxation. + fixed_atoms (list(int)): list of indices corresponding to atoms with positions fixed in space. """ def __init__( @@ -168,21 +179,6 @@ def __init__( optimizer_class: type = QuasiNewton, fixed_atoms: Optional[List[int]] = None, ): - """ - Args: - molecule_path: Path to initial geometry - working_dir: Path to directory where files should be stored - model_file (str): path to trained model - energy_key (str): name of energies in model (default="energy") - force_key (str): name of forces in model (default="forces") - energy_unit (str, float): energy units used by model (default="kcal/mol") - position_unit (str, float): position units used by model (default="Angstrom") - device (torch.device): device used for calculations (default="cpu") - dtype (torch.dtype): select model precision (default=float32) - optimizer_class (ase.optimize.optimizer): ASE optimizer used for structure relaxation. - fixed_atoms (list(int)): list of indices corresponding to atoms with positions fixed in space. - - """ # Setup directory self.working_dir = working_dir if not os.path.exists(self.working_dir): diff --git a/nablaDFT/optimization/task.py b/nablaDFT/optimization/task.py index 6bf64f5..5c9831b 100644 --- a/nablaDFT/optimization/task.py +++ b/nablaDFT/optimization/task.py @@ -1,4 +1,4 @@ -from typing import Optional, List +from typing import List import tqdm from ase.db import connect @@ -7,7 +7,15 @@ class BatchwiseOptimizeTask: - """Use for batchwise molecules conformations geometry optimization.""" + """Use for batchwise molecules conformations geometry optimization. + + Args: + input_datapath (str): path to ASE database with molecules. + output_datapath (str): path to output database. + optimizer (BatchwiseOptimizer): used for molecule geometry optimization. + converter (AtomsConverter): optional, mandatory for SchNetPack models. + batch_size (int): number of samples per batch. + """ def __init__( self, @@ -16,14 +24,6 @@ def __init__( optimizer: BatchwiseOptimizer, batch_size: int, ) -> None: - """ - Args: - input_datapath (str): path to ASE database with molecules. - output_datapath (str): path to output database. - optimizer (BatchwiseOptimizer): used for molecule geometry optimization. - converter (AtomsConverter): optional, mandatory for SchNetPack models. - batch_size (int): number of samples per batch. 
- """ self.optimizer = optimizer self.bs = batch_size self.data_db_conn = None diff --git a/nablaDFT/painn_pyg/painn.py b/nablaDFT/painn_pyg/painn.py index 2650ddf..b96fa01 100644 --- a/nablaDFT/painn_pyg/painn.py +++ b/nablaDFT/painn_pyg/painn.py @@ -4,7 +4,7 @@ import torch from torch import nn -from torch_geometric.nn import MessagePassing +from torch_geometric.nn import MessagePassing, radius_graph from torch_scatter import scatter, segment_coo from torch.optim import Optimizer from torch.optim.lr_scheduler import LRScheduler @@ -16,18 +16,14 @@ AtomEmbedding, RadialBasis, ScaledSiLU, - ScaleFactor, ) -# from torch_geometric.nn import radius_graph from .utils import ( compute_neighbors, get_edge_id, get_pbc_distances, - radius_graph, radius_graph_pbc, repeat_blocks, - load_scales_compat, ) @@ -54,7 +50,6 @@ def __init__( use_pbc: bool = True, otf_graph: bool = True, num_elements: int = 83, - scale_file: Optional[str] = None, ) -> None: super(PaiNN, self).__init__() @@ -90,22 +85,75 @@ def __init__( PaiNNMessage(hidden_channels, num_rbf).jittable() ) self.update_layers.append(PaiNNUpdate(hidden_channels)) - setattr(self, "upd_out_scalar_scale_%d" % i, ScaleFactor()) self.out_energy = nn.Sequential( nn.Linear(hidden_channels, hidden_channels // 2), - ScaledSiLU(), + nn.SiLU(), nn.Linear(hidden_channels // 2, 1), ) if self.regress_forces is True and self.direct_forces is True: self.out_forces = PaiNNOutput(hidden_channels) + self.reset_parameters() - self.inv_sqrt_2 = 1 / math.sqrt(2.0) - self.reset_parameters() + @torch.enable_grad() + def forward(self, data): + pos = data.pos + batch = data.batch + z = data.z.long() + + if self.regress_forces and not self.direct_forces: + pos = pos.requires_grad_(True) + + ( + edge_index, + neighbors, + edge_dist, + edge_vector, + id_swap, + ) = self.generate_graph_values(data) + + assert z.dim() == 1 and z.dtype == torch.long + + edge_rbf = self.radial_basis(edge_dist) # rbf * envelope + + x = self.atom_emb(z) + vec = torch.zeros(x.size(0), 3, x.size(1), device=x.device) + + #### Interaction blocks ############################################### - load_scales_compat(self, scale_file) + for i in range(self.num_layers): + dx, dvec = self.message_layers[i](x, vec, edge_index, edge_rbf, edge_vector) + + x = x + dx + vec = vec + dvec + + dx, dvec = self.update_layers[i](x, vec) + + x = x + dx + vec = vec + dvec + + #### Output block ##################################################### + per_atom_energy = self.out_energy(x).squeeze(1) + energy = scatter(per_atom_energy, batch, dim=0) + + if self.regress_forces: + if self.direct_forces: + forces = self.out_forces(x, vec) + return energy, forces + else: + forces = -1 * ( + torch.autograd.grad( + energy, + pos, + grad_outputs=torch.ones_like(energy), + create_graph=self.training, + )[0] + ) + return energy, forces + else: + return energy def reset_parameters(self) -> None: nn.init.xavier_uniform_(self.out_energy[0].weight) @@ -308,11 +356,6 @@ def generate_graph_values(self, data): empty_image = neighbors == 0 if torch.any(empty_image): print(f"An image has no neighbors! 
#images = {empty_image.sum().item()}") - # raise ValueError( - # f"An image has no neighbors: id={data.id[empty_image]}, " - # f"sid={data.sid[empty_image]}, fid={data.fid[empty_image]}" - # ) - # Symmetrize edges for swapping in symmetric message passing ( edge_index, @@ -338,67 +381,6 @@ def generate_graph_values(self, data): id_swap, ) - - @torch.enable_grad() - def forward(self, data): - pos = data.pos - batch = data.batch - z = data.z.long() - - if self.regress_forces and not self.direct_forces: - pos = pos.requires_grad_(True) - - ( - edge_index, - neighbors, - edge_dist, - edge_vector, - id_swap, - ) = self.generate_graph_values(data) - - assert z.dim() == 1 and z.dtype == torch.long - - edge_rbf = self.radial_basis(edge_dist) # rbf * envelope - - x = self.atom_emb(z) - vec = torch.zeros(x.size(0), 3, x.size(1), device=x.device) - - #### Interaction blocks ############################################### - - for i in range(self.num_layers): - dx, dvec = self.message_layers[i](x, vec, edge_index, edge_rbf, edge_vector) - - x = x + dx - vec = vec + dvec - x = x * self.inv_sqrt_2 - - dx, dvec = self.update_layers[i](x, vec) - - x = x + dx - vec = vec + dvec - x = getattr(self, "upd_out_scalar_scale_%d" % i)(x) - - #### Output block ##################################################### - per_atom_energy = self.out_energy(x).squeeze(1) - energy = scatter(per_atom_energy, batch, dim=0) - - if self.regress_forces: - if self.direct_forces: - forces = self.out_forces(x, vec) - return energy, forces - else: - forces = -1 * ( - torch.autograd.grad( - per_atom_energy, - pos, - grad_outputs=torch.ones_like(per_atom_energy), - create_graph=self.training, - )[0] - ) - return energy, forces - else: - return energy - def _generate_graph( self, data, @@ -471,7 +453,6 @@ def _generate_graph( j, i = edge_index distance_vec = data.pos[j] - data.pos[i] edge_dist = (data.pos[i] - data.pos[j]).pow(2).sum(dim=-1).sqrt() - # edge_dist = distance_vec.norm(dim=-1) cell_offsets = torch.zeros(edge_index.shape[1], 3, device=data.pos.device) cell_offset_distances = torch.zeros_like( cell_offsets, device=data.pos.device @@ -514,15 +495,10 @@ def __init__( self.x_proj = nn.Sequential( nn.Linear(hidden_channels, hidden_channels), - ScaledSiLU(), + nn.SiLU(), nn.Linear(hidden_channels, hidden_channels * 3), ) self.rbf_proj = nn.Linear(num_rbf, hidden_channels * 3) - - self.inv_sqrt_3 = 1 / math.sqrt(3.0) - self.inv_sqrt_h = 1 / math.sqrt(hidden_channels) - self.x_layernorm = nn.LayerNorm(hidden_channels) - self.reset_parameters() def reset_parameters(self) -> None: @@ -532,10 +508,9 @@ def reset_parameters(self) -> None: self.x_proj[2].bias.data.fill_(0) nn.init.xavier_uniform_(self.rbf_proj.weight) self.rbf_proj.bias.data.fill_(0) - self.x_layernorm.reset_parameters() def forward(self, x, vec, edge_index, edge_rbf, edge_vector): - xh = self.x_proj(self.x_layernorm(x)) + xh = self.x_proj(x) # TODO(@abhshkdz): Nans out with AMP here during backprop. Debug / fix. 
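+        # rbf_proj maps each edge's radial basis to 3 * hidden_channels filter values; in message()
+        # the product xh_j * rbfh_ij is split into three chunks that gate the scalar and vector updates.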
rbfh = self.rbf_proj(edge_rbf) @@ -554,10 +529,7 @@ def forward(self, x, vec, edge_index, edge_rbf, edge_vector): def message(self, xh_j, vec_j, rbfh_ij, r_ij): x, xh2, xh3 = torch.split(xh_j * rbfh_ij, self.hidden_channels, dim=-1) - xh2 = xh2 * self.inv_sqrt_3 - vec = vec_j * xh2.unsqueeze(1) + xh3.unsqueeze(1) * r_ij.unsqueeze(2) - vec = vec * self.inv_sqrt_h return x, vec @@ -587,13 +559,9 @@ def __init__(self, hidden_channels) -> None: self.vec_proj = nn.Linear(hidden_channels, hidden_channels * 2, bias=False) self.xvec_proj = nn.Sequential( nn.Linear(hidden_channels * 2, hidden_channels), - ScaledSiLU(), + nn.SiLU(), nn.Linear(hidden_channels, hidden_channels * 3), ) - - self.inv_sqrt_2 = 1 / math.sqrt(2.0) - self.inv_sqrt_h = 1 / math.sqrt(hidden_channels) - self.reset_parameters() def reset_parameters(self) -> None: @@ -605,7 +573,7 @@ def reset_parameters(self) -> None: def forward(self, x, vec): vec1, vec2 = torch.split(self.vec_proj(vec), self.hidden_channels, dim=-1) - vec_dot = (vec1 * vec2).sum(dim=1) * self.inv_sqrt_h + vec_dot = (vec1 * vec2).sum(dim=1) # NOTE: Can't use torch.norm because the gradient is NaN for input = 0. # Add an epsilon offset to make sure sqrt is always positive. @@ -615,7 +583,6 @@ def forward(self, x, vec): xvec1, xvec2, xvec3 = torch.split(x_vec_h, self.hidden_channels, dim=-1) dx = xvec1 + xvec2 * vec_dot - dx = dx * self.inv_sqrt_2 dvec = xvec3.unsqueeze(1) * vec1 @@ -702,14 +669,12 @@ def __init__( optimizer: Optimizer, lr_scheduler: LRScheduler, losses: Dict, - ema, metric, loss_coefs ) -> None: super(PaiNNLightning, self).__init__() - self.ema = ema self.net = net - self.save_hyperparameters(logger=True, ignore=["net", "ema"]) + self.save_hyperparameters(logger=True, ignore=["net"]) def forward(self, data): energy, forces = self.net(data) @@ -722,11 +687,7 @@ def step( y = batch.y # make dense batch from PyG batch energy_out, forces_out = self.net(batch) - # TODO: temp workaround - if hasattr(batch, "forces"): - forces = batch.forces - else: - forces = forces_out.clone() + forces = batch.forces preds = {"energy": energy_out, "forces": forces_out} target = {"energy": y, "forces": forces} loss = self._calculate_loss(preds, target) @@ -752,8 +713,7 @@ def training_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx): bsz = self._get_batch_size(batch) - with self.ema.average_parameters(): - loss, metrics = self.step(batch, calculate_metrics=True) + loss, metrics = self.step(batch, calculate_metrics=True) self.log( "val/loss", loss, @@ -778,8 +738,7 @@ def validation_step(self, batch, batch_idx): def test_step(self, batch, batch_idx): bsz = self._get_batch_size(batch) - with self.ema.average_parameters(): - loss, metrics = self.step(batch, calculate_metrics=True) + loss, metrics = self.step(batch, calculate_metrics=True) self.log( "test/loss", loss, @@ -811,15 +770,10 @@ def configure_optimizers(self): } return {"optimizer": optimizer} - def on_before_zero_grad(self, optimizer: Optimizer) -> None: - self.ema.update() - def on_fit_start(self) -> None: - self._instantiate_ema() self._check_devices() def on_test_start(self) -> None: - self._instantiate_ema() self._check_devices() def on_validation_epoch_end(self) -> None: @@ -859,13 +813,7 @@ def _reduce_metrics(self, step_type: str = "train"): def _check_devices(self): self.hparams.metric = self.hparams.metric.to(self.device) - if self.ema is not None: - self.ema.to(self.device) - - def _instantiate_ema(self): - if self.ema is not None: - self.ema = self.ema(self.parameters()) - + 
def _get_batch_size(self, batch): """Function for batch size infer.""" bsz = batch.batch.max().detach().item() + 1 # get batch size diff --git a/nablaDFT/pipelines.py b/nablaDFT/pipelines.py index dfee265..03f5cd0 100644 --- a/nablaDFT/pipelines.py +++ b/nablaDFT/pipelines.py @@ -30,12 +30,17 @@ def predict( ): """Function for prediction loop execution. Saves model prediction to "predictions" directory. + + Args: + ckpt_path (str): path to model checkpoint. + config (DictConfig): config for task. see r'config/' for examples. """ - trainer.logger = False # need this to disable save_hyperparameters during predict, otherwise OmegaConf DictConf can't be dumped to YAML + trainer.logger = False # need this to disable save_hyperparameters during predict, + # otherwise OmegaConf DictConf can't be dumped to YAML pred_path = os.path.join(os.getcwd(), "predictions") os.makedirs(pred_path, exist_ok=True) predictions = trainer.predict( - model=model, datamodule=datamodule.dataset, ckpt_path=ckpt_path + model=model, datamodule=datamodule, ckpt_path=ckpt_path ) torch.save(predictions, f"{pred_path}/{config.name}_{config.dataset_name}.pt") @@ -43,6 +48,10 @@ def predict( def optimize(config: DictConfig, ckpt_path: str): """Function for batched molecules optimization. Uses model defined in config. + + Args: + ckpt_path (str): path to model checkpoint. + config (DictConfig): config for task. see r'config/' for examples. """ output_dir = config.get("output_dir") if not os.path.exists(output_dir): @@ -66,6 +75,12 @@ def optimize(config: DictConfig, ckpt_path: str): def run(config: DictConfig): + """Main function to perform task runs on nablaDFT datasets. + Refer to r'nablaDFT/README.md' for detailed description of run configuration. + + Args: + config (DictConfig): config for task. see r'config/' for examples. + """ if config.get("seed"): seed_everything(config.seed) job_type = config.get("job_type") @@ -105,9 +120,9 @@ def run(config: DictConfig): # Datamodule datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule) if job_type == "train": - trainer.fit(model=model, datamodule=datamodule.dataset, ckpt_path=ckpt_path) + trainer.fit(model=model, datamodule=datamodule, ckpt_path=ckpt_path) elif job_type == "test": - trainer.test(model=model, datamodule=datamodule.dataset, ckpt_path=ckpt_path) + trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path) elif job_type == "predict": predict(trainer, model, datamodule, ckpt_path, config) # Finalize diff --git a/nablaDFT/qhnet/__init__.py b/nablaDFT/qhnet/__init__.py new file mode 100644 index 0000000..d4ead39 --- /dev/null +++ b/nablaDFT/qhnet/__init__.py @@ -0,0 +1,2 @@ +from .qhnet import QHNet, QHNetLightning +from . 
masked_mae import MaskedMeanAbsoluteError diff --git a/nablaDFT/qhnet/layers.py b/nablaDFT/qhnet/layers.py new file mode 100644 index 0000000..ae8fdf4 --- /dev/null +++ b/nablaDFT/qhnet/layers.py @@ -0,0 +1,622 @@ +import math +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F +from e3nn import o3 +from torch_scatter import scatter +from e3nn.nn import FullyConnectedNet +from e3nn.o3 import Linear, TensorProduct + + +def prod(x): + """Compute the product of a sequence.""" + out = 1 + for a in x: + out *= a + return out + + +def ShiftedSoftPlus(x): + return torch.nn.functional.softplus(x) - math.log(2.0) + + +def softplus_inverse(x): + if not isinstance(x, torch.Tensor): + x = torch.tensor(x) + return x + torch.log(-torch.expm1(-x)) + + +def get_nonlinear(nonlinear: str): + if nonlinear.lower() == 'ssp': + return ShiftedSoftPlus + elif nonlinear.lower() == 'silu': + return F.silu + elif nonlinear.lower() == 'tanh': + return F.tanh + elif nonlinear.lower() == 'abs': + return torch.abs + else: + raise NotImplementedError + + +def get_feasible_irrep(irrep_in1, irrep_in2, cutoff_irrep_out, tp_mode="uvu"): + irrep_mid = [] + instructions = [] + + for i, (_, ir_in) in enumerate(irrep_in1): + for j, (_, ir_edge) in enumerate(irrep_in2): + for ir_out in ir_in * ir_edge: + if ir_out in cutoff_irrep_out: + if (cutoff_irrep_out.count(ir_out), ir_out) not in irrep_mid: + k = len(irrep_mid) + irrep_mid.append((cutoff_irrep_out.count(ir_out), ir_out)) + else: + k = irrep_mid.index((cutoff_irrep_out.count(ir_out), ir_out)) + instructions.append((i, j, k, tp_mode, True)) + + irrep_mid = o3.Irreps(irrep_mid) + normalization_coefficients = [] + for ins in instructions: + ins_dict = { + 'uvw': (irrep_in1[ins[0]].mul * irrep_in2[ins[1]].mul), + 'uvu': irrep_in2[ins[1]].mul, + 'uvv': irrep_in1[ins[0]].mul, + 'uuw': irrep_in1[ins[0]].mul, + 'uuu': 1, + 'uvuv': 1, + 'uvu 0.0: + alpha /= x + normalization_coefficients += [math.sqrt(alpha)] + + irrep_mid, p, _ = irrep_mid.sort() + instructions = [ + (i_in1, i_in2, p[i_out], mode, train, alpha) + for (i_in1, i_in2, i_out, mode, train), alpha + in zip(instructions, normalization_coefficients) + ] + return irrep_mid, instructions + + +def cutoff_function(x, cutoff): + zeros = torch.zeros_like(x) + x_ = torch.where(x < cutoff, x, zeros) + return torch.where(x < cutoff, torch.exp(-x_**2/((cutoff-x_)*(cutoff+x_))), zeros) + + +class ExponentialBernsteinRadialBasisFunctions(nn.Module): + def __init__(self, num_basis_functions, cutoff, ini_alpha=0.5): + super(ExponentialBernsteinRadialBasisFunctions, self).__init__() + self.num_basis_functions = num_basis_functions + self.ini_alpha = ini_alpha + # compute values to initialize buffers + logfactorial = np.zeros((num_basis_functions)) + for i in range(2,num_basis_functions): + logfactorial[i] = logfactorial[i-1] + np.log(i) + v = np.arange(0,num_basis_functions) + n = (num_basis_functions-1)-v + logbinomial = logfactorial[-1]-logfactorial[v]-logfactorial[n] + #register buffers and parameters + self.register_buffer('cutoff', torch.tensor(cutoff, dtype=torch.float32)) + self.register_buffer('logc', torch.tensor(logbinomial, dtype=torch.float32)) + self.register_buffer('n', torch.tensor(n, dtype=torch.float32)) + self.register_buffer('v', torch.tensor(v, dtype=torch.float32)) + self.register_parameter('_alpha', nn.Parameter(torch.tensor(1.0, dtype=torch.float32))) + self.reset_parameters() + + def reset_parameters(self): + nn.init.constant_(self._alpha, softplus_inverse(self.ini_alpha)) 
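+    # forward() evaluates Bernstein polynomials of exp(-alpha * r) in log space
+    # (logc + n * (-alpha * r) + v * log(1 - exp(-alpha * r))) for numerical stability,
+    # then scales them by the smooth cutoff_function, returning num_basis_functions radial features per distance.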
+ + def forward(self, r): + alpha = F.softplus(self._alpha) + x = - alpha * r + x = self.logc + self.n * x + self.v * torch.log(- torch.expm1(x) ) + rbf = cutoff_function(r, self.cutoff) * torch.exp(x) + return rbf + + +class NormGate(torch.nn.Module): + def __init__(self, irrep): + super(NormGate, self).__init__() + self.irrep = irrep + self.norm = o3.Norm(self.irrep) + + num_mul, num_mul_wo_0 = 0, 0 + for mul, ir in self.irrep: + num_mul += mul + if ir.l != 0: + num_mul_wo_0 += mul + + self.mul = o3.ElementwiseTensorProduct( + self.irrep[1:], o3.Irreps(f"{num_mul_wo_0}x0e")) + self.fc = nn.Sequential( + nn.Linear(num_mul, num_mul), + nn.SiLU(), + nn.Linear(num_mul, num_mul)) + + self.num_mul = num_mul + self.num_mul_wo_0 = num_mul_wo_0 + + def forward(self, x): + norm_x = self.norm(x)[:, self.irrep.slices()[0].stop:] + f0 = torch.cat([x[:, self.irrep.slices()[0]], norm_x], dim=-1) + gates = self.fc(f0) + gated = self.mul(x[:, self.irrep.slices()[0].stop:], gates[:, self.irrep.slices()[0].stop:]) + x = torch.cat([gates[:, self.irrep.slices()[0]], gated], dim=-1) + return x + + +class ConvLayer(torch.nn.Module): + def __init__( + self, + irrep_in_node, + irrep_hidden, + irrep_out, + sh_irrep, + edge_attr_dim, + node_attr_dim, + invariant_layers=1, + invariant_neurons=32, + avg_num_neighbors=None, + nonlinear='ssp', + use_norm_gate=True, + edge_wise=False, + ): + super(ConvLayer, self).__init__() + self.avg_num_neighbors = avg_num_neighbors + self.edge_attr_dim = edge_attr_dim + self.node_attr_dim = node_attr_dim + self.edge_wise = edge_wise + self.use_norm_gate = use_norm_gate + self.irrep_in_node = irrep_in_node if isinstance(irrep_in_node, o3.Irreps) else o3.Irreps(irrep_in_node) + self.irrep_hidden = irrep_hidden \ + if isinstance(irrep_hidden, o3.Irreps) else o3.Irreps(irrep_hidden) + self.irrep_out = irrep_out if isinstance(irrep_out, o3.Irreps) else o3.Irreps(irrep_out) + self.sh_irrep = sh_irrep if isinstance(sh_irrep, o3.Irreps) else o3.Irreps(sh_irrep) + self.nonlinear_layer = get_nonlinear(nonlinear) + + self.irrep_tp_out_node, instruction_node = get_feasible_irrep( + self.irrep_in_node, self.sh_irrep, self.irrep_hidden, tp_mode='uvu') + + self.tp_node = TensorProduct( + self.irrep_in_node, + self.sh_irrep, + self.irrep_tp_out_node, + instruction_node, + shared_weights=False, + internal_weights=False, + ) + + self.fc_node = FullyConnectedNet( + [self.edge_attr_dim] + invariant_layers * [invariant_neurons] + [self.tp_node.weight_numel], + self.nonlinear_layer + ) + + num_mul = 0 + for mul, ir in self.irrep_in_node: + num_mul = num_mul + mul + + self.layer_l0 = FullyConnectedNet( + [num_mul + self.irrep_in_node[0][0]] + invariant_layers * [invariant_neurons] + [self.tp_node.weight_numel], + self.nonlinear_layer + ) + + self.linear_out = Linear( + irreps_in=self.irrep_tp_out_node, + irreps_out=self.irrep_out, + internal_weights=True, + shared_weights=True, + biases=True + ) + self.irrep_linear_out, instruction_node = get_feasible_irrep( + self.irrep_in_node, o3.Irreps("0e"), self.irrep_in_node) + if use_norm_gate: + # if this first layer, then it doesn't need this + self.norm_gate = NormGate(self.irrep_in_node) + self.linear_node = Linear( + irreps_in=self.irrep_in_node, + irreps_out=self.irrep_linear_out, + internal_weights=True, + shared_weights=True, + biases=True + ) + self.linear_node_pre = Linear( + irreps_in=self.irrep_in_node, + irreps_out=self.irrep_linear_out, + internal_weights=True, + shared_weights=True, + biases=True + ) + self.inner_product = 
InnerProduct(self.irrep_in_node) + + def forward(self, data, x): + edge_dst, edge_src = data.edge_index[0], data.edge_index[1] + + if self.use_norm_gate: + pre_x = self.linear_node_pre(x) + s0 = self.inner_product(pre_x[edge_dst], pre_x[edge_src])[:, self.irrep_in_node.slices()[0].stop:] + s0 = torch.cat([pre_x[edge_dst][:, self.irrep_in_node.slices()[0]], + pre_x[edge_dst][:, self.irrep_in_node.slices()[0]], s0], dim=-1) + x = self.norm_gate(x) + x = self.linear_node(x) + else: + s0 = self.inner_product(x[edge_dst], x[edge_src])[:, self.irrep_in_node.slices()[0].stop:] + s0 = torch.cat([x[edge_dst][:, self.irrep_in_node.slices()[0]], + x[edge_dst][:, self.irrep_in_node.slices()[0]], s0], dim=-1) + + self_x = x + + edge_features = self.tp_node( + x[edge_src], data.edge_sh, self.fc_node(data.edge_attr) * self.layer_l0(s0)) + + if self.edge_wise: + out = edge_features + else: + out = scatter(edge_features, edge_dst, dim=0, dim_size=len(x)) + + if self.irrep_in_node == self.irrep_out: + out = out + self_x + + out = self.linear_out(out) + return out + + +class InnerProduct(torch.nn.Module): + def __init__(self, irrep_in): + super(InnerProduct, self).__init__() + self.irrep_in = o3.Irreps(irrep_in).simplify() + irrep_out = o3.Irreps([(mul, "0e") for mul, _ in self.irrep_in]) + instr = [(i, i, i, "uuu", False, 1/ir.dim) for i, (mul, ir) in enumerate(self.irrep_in)] + self.tp = o3.TensorProduct(self.irrep_in, self.irrep_in, irrep_out, instr, irrep_normalization="component") + self.irrep_out = irrep_out.simplify() + + def forward(self, features_1, features_2): + out = self.tp(features_1, features_2) + return out + + +class ConvNetLayer(torch.nn.Module): + def __init__( + self, + irrep_in_node, + irrep_hidden, + irrep_out, + sh_irrep, + edge_attr_dim, + node_attr_dim, + resnet: bool = True, + use_norm_gate=True, + edge_wise=False, + ): + super(ConvNetLayer, self).__init__() + self.nonlinear_scalars = {1: "ssp", -1: "tanh"} + self.nonlinear_gates = {1: "ssp", -1: "abs"} + + self.irrep_in_node = irrep_in_node if isinstance(irrep_in_node, o3.Irreps) else o3.Irreps(irrep_in_node) + self.irrep_hidden = irrep_hidden if isinstance(irrep_hidden, o3.Irreps) \ + else o3.Irreps(irrep_hidden) + self.irrep_out = irrep_out if isinstance(irrep_out, o3.Irreps) else o3.Irreps(irrep_out) + self.sh_irrep = sh_irrep if isinstance(sh_irrep, o3.Irreps) else o3.Irreps(sh_irrep) + + self.edge_attr_dim = edge_attr_dim + self.node_attr_dim = node_attr_dim + self.resnet = resnet and self.irrep_in_node == self.irrep_out + + self.conv = ConvLayer( + irrep_in_node=self.irrep_in_node, + irrep_hidden=self.irrep_hidden, + sh_irrep=self.sh_irrep, + irrep_out=self.irrep_out, + edge_attr_dim=self.edge_attr_dim, + node_attr_dim=self.node_attr_dim, + invariant_layers=1, + invariant_neurons=32, + avg_num_neighbors=None, + nonlinear='ssp', + use_norm_gate=use_norm_gate, + edge_wise=edge_wise + ) + + def forward(self, data, x): + old_x = x + x = self.conv(data, x) + if self.resnet and self.irrep_out == self.irrep_in_node: + x = old_x + x + return x + + +class PairNetLayer(torch.nn.Module): + def __init__(self, + irrep_in_node, + irrep_bottle_hidden, + irrep_out, + sh_irrep, + edge_attr_dim, + node_attr_dim, + resnet: bool = True, + invariant_layers=1, + invariant_neurons=8, + nonlinear='ssp'): + super(PairNetLayer, self).__init__() + self.nonlinear_scalars = {1: "ssp", -1: "tanh"} + self.nonlinear_gates = {1: "ssp", -1: "abs"} + self.invariant_layers = invariant_layers + self.invariant_neurons = invariant_neurons + self.irrep_in_node = 
irrep_in_node if isinstance(irrep_in_node, o3.Irreps) else o3.Irreps(irrep_in_node) + self.irrep_bottle_hidden = irrep_bottle_hidden \ + if isinstance(irrep_bottle_hidden, o3.Irreps) else o3.Irreps(irrep_bottle_hidden) + self.irrep_out = irrep_out if isinstance(irrep_out, o3.Irreps) else o3.Irreps(irrep_out) + self.sh_irrep = sh_irrep if isinstance(sh_irrep, o3.Irreps) else o3.Irreps(sh_irrep) + + self.edge_attr_dim = edge_attr_dim + self.node_attr_dim = node_attr_dim + self.nonlinear_layer = get_nonlinear(nonlinear) + + self.irrep_tp_in_node, _ = get_feasible_irrep(self.irrep_in_node, o3.Irreps("0e"), self.irrep_bottle_hidden) + self.irrep_tp_out_node_pair, instruction_node_pair = get_feasible_irrep( + self.irrep_tp_in_node, self.irrep_tp_in_node, self.irrep_bottle_hidden, tp_mode='uuu') + + self.irrep_tp_out_node_pair_msg, instruction_node_pair_msg = get_feasible_irrep( + self.irrep_tp_in_node, self.sh_irrep, self.irrep_bottle_hidden, tp_mode='uvu') + + self.linear_node_pair = Linear( + irreps_in=self.irrep_in_node, + irreps_out=self.irrep_tp_in_node, + internal_weights=True, + shared_weights=True, + biases=True + ) + + self.linear_node_pair_n = Linear( + irreps_in=self.irrep_in_node, + irreps_out=self.irrep_in_node, + internal_weights=True, + shared_weights=True, + biases=True + ) + self.linear_node_pair_inner = Linear( + irreps_in=self.irrep_in_node, + irreps_out=self.irrep_in_node, + internal_weights=True, + shared_weights=True, + biases=True + ) + + self.tp_node_pair = TensorProduct( + self.irrep_tp_in_node, + self.irrep_tp_in_node, + self.irrep_tp_out_node_pair, + instruction_node_pair, + shared_weights=False, + internal_weights=False, + ) + + self.irrep_tp_out_node_pair_2, instruction_node_pair_2 = get_feasible_irrep( + self.irrep_tp_out_node_pair, self.irrep_tp_out_node_pair, self.irrep_bottle_hidden, tp_mode='uuu') + + self.fc_node_pair = FullyConnectedNet( + [self.edge_attr_dim] + invariant_layers * [invariant_neurons] + [self.tp_node_pair.weight_numel], + self.nonlinear_layer + ) + + if self.irrep_in_node == self.irrep_out and resnet: + self.resnet = True + else: + self.resnet = False + + self.linear_node_pair = Linear( + irreps_in=self.irrep_tp_out_node_pair, + irreps_out=self.irrep_out, + internal_weights=True, + shared_weights=True, + biases=True + ) + self.norm_gate = NormGate(self.irrep_tp_out_node_pair) + self.inner_product = InnerProduct(self.irrep_in_node) + self.norm = o3.Norm(self.irrep_in_node) + num_mul = 0 + for mul, ir in self.irrep_in_node: + num_mul = num_mul + mul + + self.norm_gate_pre = NormGate(self.irrep_tp_out_node_pair) + self.fc = nn.Sequential( + nn.Linear(self.irrep_in_node[0][0] + num_mul, self.irrep_in_node[0][0]), + nn.SiLU(), + nn.Linear(self.irrep_in_node[0][0], self.tp_node_pair.weight_numel)) + + @property + def device(self): + return next(self.parameters()).device + + def forward(self, data, node_attr, node_pair_attr=None): + dst, src = data.full_edge_index + node_attr_0 = self.linear_node_pair_inner(node_attr) + s0 = self.inner_product(node_attr_0[dst], node_attr_0[src])[:, self.irrep_in_node.slices()[0].stop:] + s0 = torch.cat([node_attr_0[dst][:, self.irrep_in_node.slices()[0]], + node_attr_0[src][:, self.irrep_in_node.slices()[0]], s0], dim=-1) + + node_attr = self.norm_gate_pre(node_attr) + node_attr = self.linear_node_pair_n(node_attr) + + node_pair = self.tp_node_pair(node_attr[src], node_attr[dst], + self.fc_node_pair(data.full_edge_attr) * self.fc(s0)) + + node_pair = self.norm_gate(node_pair) + node_pair = 
self.linear_node_pair(node_pair) + + if self.resnet and node_pair_attr is not None: + node_pair = node_pair + node_pair_attr + return node_pair + + +class SelfNetLayer(torch.nn.Module): + def __init__(self, + irrep_in_node, + irrep_bottle_hidden, + irrep_out, + sh_irrep, + edge_attr_dim, + node_attr_dim, + resnet: bool = True, + nonlinear='ssp'): + super(SelfNetLayer, self).__init__() + self.nonlinear_scalars = {1: "ssp", -1: "tanh"} + self.nonlinear_gates = {1: "ssp", -1: "abs"} + self.sh_irrep = sh_irrep + self.irrep_in_node = irrep_in_node if isinstance(irrep_in_node, o3.Irreps) else o3.Irreps(irrep_in_node) + self.irrep_bottle_hidden = irrep_bottle_hidden \ + if isinstance(irrep_bottle_hidden, o3.Irreps) else o3.Irreps(irrep_bottle_hidden) + self.irrep_out = irrep_out if isinstance(irrep_out, o3.Irreps) else o3.Irreps(irrep_out) + + self.edge_attr_dim = edge_attr_dim + self.node_attr_dim = node_attr_dim + self.resnet = resnet + self.nonlinear_layer = get_nonlinear(nonlinear) + + self.irrep_tp_in_node, _ = get_feasible_irrep(self.irrep_in_node, o3.Irreps("0e"), self.irrep_bottle_hidden) + self.irrep_tp_out_node, instruction_node = get_feasible_irrep( + self.irrep_tp_in_node, self.irrep_tp_in_node, self.irrep_bottle_hidden, tp_mode='uuu') + + # - Build modules - + self.linear_node_1 = Linear( + irreps_in=self.irrep_in_node, + irreps_out=self.irrep_in_node, + internal_weights=True, + shared_weights=True, + biases=True + ) + + self.linear_node_2 = Linear( + irreps_in=self.irrep_in_node, + irreps_out=self.irrep_in_node, + internal_weights=True, + shared_weights=True, + biases=True + ) + self.tp = TensorProduct( + self.irrep_tp_in_node, + self.irrep_tp_in_node, + self.irrep_tp_out_node, + instruction_node, + shared_weights=True, + internal_weights=True + ) + self.norm_gate = NormGate(self.irrep_out) + self.norm_gate_1 = NormGate(self.irrep_in_node) + self.norm_gate_2 = NormGate(self.irrep_in_node) + self.linear_node_3 = Linear( + irreps_in=self.irrep_tp_out_node, + irreps_out=self.irrep_out, + internal_weights=True, + shared_weights=True, + biases=True + ) + + def forward(self, data, x, old_fii): + old_x = x + xl = self.norm_gate_1(x) + xl = self.linear_node_1(xl) + xr = self.norm_gate_2(x) + xr = self.linear_node_2(xr) + x = self.tp(xl, xr) + if self.resnet: + x = x + old_x + x = self.norm_gate(x) + x = self.linear_node_3(x) + if self.resnet and old_fii is not None: + x = old_fii + x + return x + + @property + def device(self): + return next(self.parameters()).device + + +class Expansion(nn.Module): + def __init__(self, irrep_in, irrep_out_1, irrep_out_2): + super(Expansion, self).__init__() + self.irrep_in = irrep_in + self.irrep_out_1 = irrep_out_1 + self.irrep_out_2 = irrep_out_2 + self.instructions = self.get_expansion_path(irrep_in, irrep_out_1, irrep_out_2) + self.num_path_weight = sum(prod(ins[-1]) for ins in self.instructions if ins[3]) + self.num_bias = sum([prod(ins[-1][1:]) for ins in self.instructions if ins[0] == 0]) + if self.num_path_weight > 0: + self.weights = nn.Parameter(torch.rand(self.num_path_weight + self.num_bias)) + self.num_weights = self.num_path_weight + self.num_bias + + def forward(self, x_in, weights=None, bias_weights=None): + batch_num = x_in.shape[0] + if len(self.irrep_in) == 1: + x_in_s = [x_in.reshape(batch_num, self.irrep_in[0].mul, self.irrep_in[0].ir.dim)] + else: + x_in_s = [ + x_in[:, i].reshape(batch_num, mul_ir.mul, mul_ir.ir.dim) + for i, mul_ir in zip(self.irrep_in.slices(), self.irrep_in)] + + outputs = {} + flat_weight_index = 0 + 
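+        # bias weights use their own running offset and are only applied for instructions with ins[0] == 0
+        # (the first input irrep block), mirroring how flat_weight_index walks through the per-path weights.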
bias_weight_index = 0 + for ins in self.instructions: + mul_ir_in = self.irrep_in[ins[0]] + mul_ir_out1 = self.irrep_out_1[ins[1]] + mul_ir_out2 = self.irrep_out_2[ins[2]] + x1 = x_in_s[ins[0]] + x1 = x1.reshape(batch_num, mul_ir_in.mul, mul_ir_in.ir.dim) + w3j_matrix = o3.wigner_3j(ins[1], ins[2], ins[0]).to(self.device).type(x1.type()) + if ins[3] is True or weights is not None: + weight = weights[:, flat_weight_index:flat_weight_index + prod(ins[-1])].reshape([-1] + ins[-1]) + result = torch.einsum(f"bwuv, bwk-> buvk", weight, x1) + if ins[0] == 0 and bias_weights is not None: + bias_weight = bias_weights[:,bias_weight_index:bias_weight_index + prod(ins[-1][1:])].\ + reshape([-1] + ins[-1][1:]) + bias_weight_index += prod(ins[-1][1:]) + result = result + bias_weight.unsqueeze(-1) + result = torch.einsum(f"ijk, buvk->buivj", w3j_matrix, result) / mul_ir_in.mul + flat_weight_index += prod(ins[-1]) + else: + result = torch.einsum( + f"uvw, ijk, bwk-> buivj", torch.ones(ins[-1]).type(x1.type()).to(self.device), w3j_matrix, + x1.reshape(batch_num, mul_ir_in.mul, mul_ir_in.ir.dim) + ) + result = result.reshape(batch_num, mul_ir_out1.dim, mul_ir_out2.dim) + key = (ins[1], ins[2]) + if key in outputs.keys(): + outputs[key] = outputs[key] + result + else: + outputs[key] = result + + rows = [] + for i in range(len(self.irrep_out_1)): + blocks = [] + for j in range(len(self.irrep_out_2)): + if (i, j) not in outputs.keys(): + blocks += [torch.zeros((x_in.shape[0], self.irrep_out_1[i].dim, self.irrep_out_2[j].dim), + device=x_in.device).type(x_in.type())] + else: + blocks += [outputs[(i, j)]] + rows.append(torch.cat(blocks, dim=-1)) + output = torch.cat(rows, dim=-2) + return output + + def get_expansion_path(self, irrep_in, irrep_out_1, irrep_out_2): + instructions = [] + for i, (num_in, ir_in) in enumerate(irrep_in): + for j, (num_out1, ir_out1) in enumerate(irrep_out_1): + for k, (num_out2, ir_out2) in enumerate(irrep_out_2): + if ir_in in ir_out1 * ir_out2: + instructions.append([i, j, k, True, 1.0, [num_in, num_out1, num_out2]]) + return instructions + + @property + def device(self): + return next(self.parameters()).device + + def __repr__(self): + return f'{self.irrep_in} -> {self.irrep_out_1}x{self.irrep_out_1} and bias {self.num_bias}' \ + f'with parameters {self.num_path_weight}' \ No newline at end of file diff --git a/nablaDFT/qhnet/loss.py b/nablaDFT/qhnet/loss.py new file mode 100644 index 0000000..cdfbab9 --- /dev/null +++ b/nablaDFT/qhnet/loss.py @@ -0,0 +1,16 @@ +import torch +import torch.nn as nn + + +class HamiltonianLoss(nn.Module): + def __init__(self) -> None: + super(HamiltonianLoss, self).__init__() + + def forward(self, pred, target, mask): + diff = pred - target + mse = torch.mean(diff**2) + mae = torch.mean(torch.abs(diff)) + mse *= (pred.numel() / mask.sum()) + mae *= (pred.numel() / mask.sum()) + rmse = torch.sqrt(mse) + return rmse + mae diff --git a/nablaDFT/qhnet/masked_mae.py b/nablaDFT/qhnet/masked_mae.py new file mode 100644 index 0000000..0e2d345 --- /dev/null +++ b/nablaDFT/qhnet/masked_mae.py @@ -0,0 +1,19 @@ +import torch +from torch import Tensor + +from torchmetrics.functional.regression.mae import _mean_absolute_error_update +from torchmetrics.regression import MeanAbsoluteError + + +class MaskedMeanAbsoluteError(MeanAbsoluteError): + """Overloaded MAE for usage with block diagonal matrix. + Mask calculated from target tensor.""" + + def update(self, preds: Tensor, target: Tensor) -> None: + """Update state with predictions and targets. 
+ Both inputs are block diagonal Torch Tensors.""" + sum_abs_error, _ = _mean_absolute_error_update(preds, target) + num_obs = torch.count_nonzero(target).item() + + self.sum_abs_error += sum_abs_error + self.total += num_obs diff --git a/nablaDFT/qhnet/qhnet.py b/nablaDFT/qhnet/qhnet.py new file mode 100644 index 0000000..348a6f6 --- /dev/null +++ b/nablaDFT/qhnet/qhnet.py @@ -0,0 +1,471 @@ +from typing import Dict, List + +import numpy as np +import torch +from torch import nn +from torch.optim import Optimizer +from torch.optim.lr_scheduler import LRScheduler +from torch_geometric.data import Data +from torch_cluster import radius_graph +from e3nn import o3 +from e3nn.o3 import Linear +import pytorch_lightning as pl + +from .layers import ExponentialBernsteinRadialBasisFunctions, ConvNetLayer, PairNetLayer, SelfNetLayer, Expansion, get_nonlinear + + +class QHNet(nn.Module): + """Modified QHNet from paper + Args: + orbitals (Dict): defines orbitals for each atom type from the dataset. + """ + def __init__(self, + in_node_features=1, + sh_lmax=4, + hidden_size=128, + bottle_hidden_size=32, + num_gnn_layers=5, + max_radius=12, + num_nodes=10, + radius_embed_dim=32, # maximum nuclear charge (+1, i.e. 87 for up to Rn) for embeddings, can be kept at default + orbitals: Dict = None): + super(QHNet, self).__init__() + # store hyperparameter values + self.order = sh_lmax + self.sh_irrep = o3.Irreps.spherical_harmonics(lmax=self.order) + self.hs = hidden_size + self.hbs = bottle_hidden_size + self.radius_embed_dim = radius_embed_dim + self.max_radius = max_radius + self.num_gnn_layers = num_gnn_layers + self.node_embedding = nn.Embedding(num_nodes, self.hs) + self.hidden_irrep = o3.Irreps(f'{self.hs}x0e + {self.hs}x1o + {self.hs}x2e + {self.hs}x3o + {self.hs}x4e') + self.hidden_bottle_irrep = o3.Irreps(f'{self.hbs}x0e + {self.hbs}x1o + {self.hbs}x2e + {self.hbs}x3o + {self.hbs}x4e') + self.hidden_irrep_base = o3.Irreps(f'{self.hs}x0e + {self.hs}x1e + {self.hs}x2e + {self.hs}x3e + {self.hs}x4e') # in use + self.input_irrep = o3.Irreps(f'{self.hs}x0e') + self.distance_expansion = ExponentialBernsteinRadialBasisFunctions(self.radius_embed_dim, self.max_radius) + self.num_fc_layer = 1 + + orbital_mask, max_s, max_p, max_d = self._get_mask(orbitals) # max_* used below to define output representation + self.orbital_mask = orbital_mask + + self.e3_gnn_layer = nn.ModuleList() + self.e3_gnn_node_pair_layer = nn.ModuleList() + self.e3_gnn_node_layer = nn.ModuleList() + self.udpate_layer = nn.ModuleList() + self.start_layer = 2 + for i in range(self.num_gnn_layers): + input_irrep = self.input_irrep if i == 0 else self.hidden_irrep + self.e3_gnn_layer.append(ConvNetLayer( + irrep_in_node=input_irrep, + irrep_hidden=self.hidden_irrep, + irrep_out=self.hidden_irrep, + edge_attr_dim=self.radius_embed_dim, + node_attr_dim=self.hs, + sh_irrep=self.sh_irrep, + resnet=True, + use_norm_gate=True if i != 0 else False + )) + + if i > self.start_layer: + self.e3_gnn_node_layer.append(SelfNetLayer( + irrep_in_node=self.hidden_irrep_base, + irrep_bottle_hidden=self.hidden_irrep_base, + irrep_out=self.hidden_irrep_base, + sh_irrep=self.sh_irrep, + edge_attr_dim=self.radius_embed_dim, + node_attr_dim=self.hs, + resnet=True, + )) + + self.e3_gnn_node_pair_layer.append(PairNetLayer( + irrep_in_node=self.hidden_irrep_base, + irrep_bottle_hidden=self.hidden_irrep_base, + irrep_out=self.hidden_irrep_base, + sh_irrep=self.sh_irrep, + edge_attr_dim=self.radius_embed_dim, + node_attr_dim=self.hs, + 
invariant_layers=self.num_fc_layer, + invariant_neurons=self.hs, + resnet=True, + )) + + self.nonlinear_layer = get_nonlinear('ssp') + self.expand_ii, self.expand_ij, self.fc_ii, self.fc_ij, self.fc_ii_bias, self.fc_ij_bias = \ + nn.ModuleDict(), nn.ModuleDict(), nn.ModuleDict(), nn.ModuleDict(), nn.ModuleDict(), nn.ModuleDict() + for name in {"hamiltonian"}: + input_expand_ii = o3.Irreps(f"{self.hbs}x0e + {self.hbs}x1e + {self.hbs}x2e + {self.hbs}x3e + {self.hbs}x4e") + + self.expand_ii[name] = Expansion( + input_expand_ii, + o3.Irreps(f"{max_s}x0e + {max_p}x1e + {max_d}x2e"), # here we define which basis we use + o3.Irreps(f"{max_s}x0e + {max_p}x1e + {max_d}x2e") # here we define which basis we use + ) + self.fc_ii[name] = torch.nn.Sequential( + nn.Linear(self.hs, self.hs), + nn.SiLU(), + nn.Linear(self.hs, self.expand_ii[name].num_path_weight) + ) + self.fc_ii_bias[name] = torch.nn.Sequential( + nn.Linear(self.hs, self.hs), + nn.SiLU(), + nn.Linear(self.hs, self.expand_ii[name].num_bias) + ) + self.expand_ij[name] = Expansion( + o3.Irreps(f'{self.hbs}x0e + {self.hbs}x1e + {self.hbs}x2e + {self.hbs}x3e + {self.hbs}x4e'), + o3.Irreps(f"{max_s}x0e + {max_p}x1e + {max_d}x2e"), # here we define which basis we use + o3.Irreps(f"{max_s}x0e + {max_p}x1e + {max_d}x2e") # here we define which basis we use + ) + + self.fc_ij[name] = torch.nn.Sequential( + nn.Linear(self.hs * 2, self.hs), + nn.SiLU(), + nn.Linear(self.hs, self.expand_ij[name].num_path_weight) + ) + + self.fc_ij_bias[name] = torch.nn.Sequential( + nn.Linear(self.hs * 2, self.hs), + nn.SiLU(), + nn.Linear(self.hs, self.expand_ij[name].num_bias) + ) + + self.output_ii = Linear(self.hidden_irrep, self.hidden_bottle_irrep) + self.output_ij = Linear(self.hidden_irrep, self.hidden_bottle_irrep) + + def set(self): + for key in self.orbital_mask.keys(): + self.orbital_mask[key] = self.orbital_mask[key].to(self.device) + + def get_number_of_parameters(self): + num = 0 + for param in self.parameters(): + if param.requires_grad: + num += param.numel() + return num + + @property + def device(self): + return next(self.parameters()).device + + def forward(self, data, keep_blocks=False): + node_attr, edge_index, rbf_new, edge_sh, _ = self.build_graph(data, self.max_radius) + node_attr = self.node_embedding(node_attr) + data.node_attr, data.edge_index, data.edge_attr, data.edge_sh = \ + node_attr, edge_index, rbf_new, edge_sh + + _, full_edge_index, full_edge_attr, full_edge_sh, transpose_edge_index = \ + self.build_graph(data, max_radius=10000) + + data.full_edge_index, data.full_edge_attr, data.full_edge_sh = \ + full_edge_index, full_edge_attr, full_edge_sh + + full_dst, full_src = data.full_edge_index + + fii = None + fij = None + for layer_idx, layer in enumerate(self.e3_gnn_layer): + node_attr = layer(data, node_attr) + if layer_idx > self.start_layer: + fii = self.e3_gnn_node_layer[layer_idx-self.start_layer-1](data, node_attr, fii) + fij = self.e3_gnn_node_pair_layer[layer_idx-self.start_layer-1](data, node_attr, fij) + + fii = self.output_ii(fii) + fij = self.output_ij(fij) + hamiltonian_diagonal_matrix = self.expand_ii['hamiltonian']( + fii, self.fc_ii['hamiltonian'](data.node_attr), self.fc_ii_bias['hamiltonian'](data.node_attr)) + node_pair_embedding = torch.cat([data.node_attr[full_dst], data.node_attr[full_src]], dim=-1) + hamiltonian_non_diagonal_matrix = self.expand_ij['hamiltonian']( + fij, self.fc_ij['hamiltonian'](node_pair_embedding), + self.fc_ij_bias['hamiltonian'](node_pair_embedding)) + if keep_blocks is False: + 
hamiltonian_matrix = self.build_final_matrix( + data, hamiltonian_diagonal_matrix, hamiltonian_non_diagonal_matrix) + hamiltonian_matrix = hamiltonian_matrix + hamiltonian_matrix.transpose(-1, -2) + return hamiltonian_matrix + else: + ret_hamiltonian_diagonal_matrix = hamiltonian_diagonal_matrix +\ + hamiltonian_diagonal_matrix.transpose(-1, -2) + + # the transpose should consider the i, j indices + ret_hamiltonian_non_diagonal_matrix = hamiltonian_non_diagonal_matrix + \ + hamiltonian_non_diagonal_matrix[transpose_edge_index].transpose(-1, -2) + + results = {} + results['hamiltonian_diagonal_blocks'] = ret_hamiltonian_diagonal_matrix + results['hamiltonian_non_diagonal_blocks'] = ret_hamiltonian_non_diagonal_matrix + return results + + def build_graph(self, data, max_radius, edge_index=None): + node_attr = data.z.squeeze() + + if edge_index is None: + radius_edges = radius_graph(data.pos, max_radius, data.batch, max_num_neighbors=data.num_nodes) + else: + radius_edges = data.full_edge_index + + dst, src = radius_edges + edge_vec = data.pos[dst.long()] - data.pos[src.long()] + rbf = self.distance_expansion(edge_vec.norm(dim=-1).unsqueeze(-1)).squeeze().type(data.pos.type()) + + edge_sh = o3.spherical_harmonics( + self.sh_irrep, edge_vec[:, [1, 2, 0]], + normalize=True, normalization='component').type(data.pos.type()) + + start_edge_index = 0 + all_transpose_index = [] + for graph_idx in range(data.ptr.shape[0] - 1): + num_nodes = data.ptr[graph_idx + 1] - data.ptr[graph_idx] + graph_edge_index = radius_edges[:, start_edge_index:start_edge_index+num_nodes*(num_nodes-1)] + sub_graph_edge_index = graph_edge_index - data.ptr[graph_idx] + bias = (sub_graph_edge_index[0] < sub_graph_edge_index[1]).type(torch.int) + transpose_index = sub_graph_edge_index[0] * (num_nodes - 1) + sub_graph_edge_index[1] - bias + transpose_index = transpose_index + start_edge_index + all_transpose_index.append(transpose_index) + start_edge_index = start_edge_index + num_nodes*(num_nodes-1) + + return node_attr, radius_edges, rbf, edge_sh, torch.cat(all_transpose_index, dim=-1) + + def build_final_matrix(self, data, diagonal_matrix, non_diagonal_matrix): + # concatenate the blocks together and then select once.
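`build_final_matrix` below stitches the predicted diagonal (per-atom) and non-diagonal (per-pair) blocks into one square Hamiltonian per molecule, selecting rows and columns with each atom's orbital mask and finally block-diagonalising the molecules of a batch. The stitching pattern, sketched with hypothetical block sizes rather than real nablaDFT data:

```python
# Hypothetical sizes: two atoms with 2 and 3 orbitals -> a 5 x 5 molecular matrix.
import torch

H_00 = torch.randn(2, 2)   # diagonal block of atom 0
H_11 = torch.randn(3, 3)   # diagonal block of atom 1
H_01 = torch.randn(2, 3)   # off-diagonal block (atom 0 rows, atom 1 columns)
H_10 = torch.randn(3, 2)   # off-diagonal block (atom 1 rows, atom 0 columns)

col_0 = torch.cat([H_00, H_10], dim=-2)    # stack the blocks of the first atom column
col_1 = torch.cat([H_01, H_11], dim=-2)
H_mol = torch.cat([col_0, col_1], dim=-1)  # 5 x 5 Hamiltonian for one molecule

# Molecules in a batch are then merged into a single block-diagonal matrix,
# mirroring torch.block_diag(*final_matrix) at the end of the method.
H_batch = torch.block_diag(H_mol, torch.randn(4, 4))
print(H_mol.shape, H_batch.shape)          # torch.Size([5, 5]) torch.Size([9, 9])
```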
+ final_matrix = [] + dst, src = data.full_edge_index + for graph_idx in range(data.ptr.shape[0] - 1): + matrix_block_col = [] + for src_idx in range(data.ptr[graph_idx], data.ptr[graph_idx+1]): + matrix_col = [] + for dst_idx in range(data.ptr[graph_idx], data.ptr[graph_idx+1]): + if src_idx == dst_idx: + matrix_col.append(diagonal_matrix[src_idx].index_select( + -2, self.orbital_mask[data.z[dst_idx].item()]).index_select( + -1, self.orbital_mask[data.z[src_idx].item()]) + ) + else: + mask1 = (src == src_idx) + mask2 = (dst == dst_idx) + index = torch.where(mask1 & mask2)[0].item() + + matrix_col.append( + non_diagonal_matrix[index].index_select( + -2, self.orbital_mask[data.z[dst_idx].item()]).index_select( + -1, self.orbital_mask[data.z[src_idx].item()])) + matrix_block_col.append(torch.cat(matrix_col, dim=-2)) + final_matrix.append(torch.cat(matrix_block_col, dim=-1)) + final_matrix = torch.block_diag(*final_matrix) + return final_matrix + + def _get_mask(self, orbitals): + # get orbitals by z + # retrieve max orbital and get ranges for mask + max_z = max(orbitals.keys()) + _, counts = np.unique(orbitals[max_z], return_counts=True) + s_max, p_max, d_max = counts # max orbital number per type + s_range = [i for i in range(s_max)] + p_range = [i + max(s_range) + 1 for i in range(p_max * 3)] + d_range = [i + max(p_range) + 1 for i in range(d_max * 5)] + ranges = [s_range, p_range, d_range] + orbs_count = [1, 3, 5] # orbital count per type + # create mask for each atom type + atom_orb_masks = {} + for atom in orbitals.keys(): + _, orb_count = np.unique(orbitals[atom], return_counts=True) + mask = [] + for idx, val in enumerate(orb_count): + mask.extend(ranges[idx][:orb_count[idx] * orbs_count[idx]]) + atom_orb_masks[atom] = torch.tensor(mask) + return atom_orb_masks, s_max, p_max, d_max + + +class QHNetLightning(pl.LightningModule): + def __init__( + self, + model_name: str, + net: nn.Module, + optimizer: Optimizer, + lr_scheduler: LRScheduler, + losses: Dict, + ema, + metric, + loss_coefs, + ) -> None: + super(QHNetLightning, self).__init__() + self.net = net + self.ema = ema + self.save_hyperparameters(logger=True, ignore=['net']) + + def forward(self, data: Data): + hamiltonian = self.net(data) + return hamiltonian + + def step(self, batch, calculate_metrics: bool = False): + hamiltonian_out = self.net(batch) + hamiltonian = batch.hamiltonian + preds = {'hamiltonian': hamiltonian_out} + masks = torch.block_diag(*[torch.ones_like(torch.from_numpy(H)) for H in hamiltonian]) + hamiltonian = torch.block_diag(*[torch.from_numpy(H) for H in hamiltonian]).to(self.device) + target = {'hamiltonian': hamiltonian} + loss = self._calculate_loss(preds, target, masks) + if calculate_metrics: + metrics = self._calculate_metrics(preds, target) + return loss, metrics + return loss + + def training_step(self, batch, batch_idx): + bsz = self._get_batch_size(batch) + loss = self.step(batch, calculate_metrics=False) + self._log_current_lr() + self.log( + "train/loss", + loss, + on_step=True, + on_epoch=True, + logger=True, + sync_dist=True, + batch_size=bsz, + ) + return loss + + def validation_step(self, batch, batch_idx): + bsz = self._get_batch_size(batch) + with self.ema.average_parameters(): + loss, metrics = self.step(batch, calculate_metrics=True) + self.log( + "val/loss", + loss, + prog_bar=True, + on_step=True, + on_epoch=True, + logger=True, + sync_dist=True, + batch_size=bsz, + ) + # workaround for checkpoint callback + self.log( + "val_loss", + loss, + on_step=False, + on_epoch=True, + 
logger=False, + sync_dist=True, + batch_size=bsz, + ) + return loss + + def test_step(self, batch, batch_idx): + bsz = self._get_batch_size(batch) + loss, metrics = self.step(batch, calculate_metrics=True) + self.log( + "test/loss", + loss, + prog_bar=True, + on_step=True, + on_epoch=True, + logger=True, + sync_dist=True, + batch_size=bsz, + ) + return loss + + def predict_step(self, data) -> List[torch.Tensor]: + hamiltonian = self(data) + sizes = self._get_hamiltonian_sizes(data) + hamiltonians_list = [] + for idx in range(1, len(sizes)): + H = hamiltonian[sizes[idx-1]:sizes[idx], + sizes[idx-1]:sizes[idx]] + hamiltonians_list.append(H) + return hamiltonians_list + + def configure_optimizers(self): + optimizer = self.hparams.optimizer(params=self.parameters()) + if self.hparams.lr_scheduler is not None: + scheduler = self.hparams.lr_scheduler(optimizer=optimizer) + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "interval": "epoch", + "monitor": "val_loss", + "frequency": 1, + }, + } + return {"optimizer": optimizer} + + def on_before_zero_grad(self, optimizer: Optimizer) -> None: + self.ema.update() + + def on_fit_start(self) -> None: + self._instantiate_ema() + self._check_devices() + + def on_test_start(self) -> None: + self._instantiate_ema() + self._check_devices() + + def on_predict_start(self) -> None: + self._instantiate_ema() + self._check_devices() + + def on_validation_epoch_end(self) -> None: + self._reduce_metrics(step_type="val") + + def on_test_epoch_end(self) -> None: + self._reduce_metrics(step_type="test") + + def on_save_checkpoint(self, checkpoint) -> None: + with self.ema.average_parameters(): + checkpoint['state_dict'] = self.state_dict() + + def _calculate_loss(self, y_pred, y_true, masks) -> float: + total_loss = 0.0 + for name, loss in self.hparams.losses.items(): + total_loss += self.hparams.loss_coefs[name] * loss( + y_pred[name], y_true[name], masks + ) + return total_loss + + def _calculate_metrics(self, y_pred, y_true) -> Dict: + """Function for metrics calculation during step.""" + metric = self.hparams.metric(y_pred, y_true) + return metric + + def _log_current_lr(self) -> None: + opt = self.optimizers() + current_lr = opt.optimizer.param_groups[0]["lr"] + self.log("LR", current_lr, logger=True) + + def _reduce_metrics(self, step_type: str = "train"): + metric = self.hparams.metric.compute() + for key in metric.keys(): + self.log( + f"{step_type}/{key}", + metric[key], + logger=True, + on_step=False, + on_epoch=True, + sync_dist=True, + ) + self.hparams.metric.reset() + + def _check_devices(self): + self.hparams.metric = self.hparams.metric.to(self.device) + self.net.set() + if self.ema is not None: + self.ema.to(self.device) + + def _instantiate_ema(self): + if self.ema is not None: + self.ema = self.ema(self.parameters()) + + def _get_batch_size(self, batch): + """Function for batch size infer.""" + bsz = batch.batch.max().detach().item() + 1 # get batch size + return bsz + + def _get_hamiltonian_sizes(self, batch): + sizes = [0] + for idx in range(batch.ptr.shape[0] - 1): + atoms = batch.z[batch.ptr[idx]: batch.ptr[idx + 1]] + size = sum([self.net.orbital_mask[atom.item()].shape[0] for atom in atoms]) + sizes.append(size + sum(sizes)) + return sizes + \ No newline at end of file diff --git a/nablaDFT/schedulers.py b/nablaDFT/schedulers.py new file mode 100644 index 0000000..e1ad7e9 --- /dev/null +++ b/nablaDFT/schedulers.py @@ -0,0 +1,114 @@ +from functools import partial +from torch.optim.lr_scheduler import LambdaLR 
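The two factory functions defined below follow the Hugging Face `transformers` schedulers they link to: a linear ramp from 0 to the optimizer's base learning rate over the warmup steps, then a linear (or polynomial) decay to the end value. A quick sanity check of the linear variant on a toy optimizer, assuming the package is importable as `nablaDFT` (the numbers are illustrative only):

```python
# Toy check of the linear warmup schedule defined below.
import torch
from nablaDFT.schedulers import get_linear_schedule_with_warmup

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.AdamW([param], lr=3e-4)
sched = get_linear_schedule_with_warmup(opt, num_warmup_steps=10, num_training_steps=100)

lrs = []
for _ in range(100):
    opt.step()
    sched.step()
    lrs.append(opt.param_groups[0]["lr"])

# lr ramps to the base value at step 10, then decays linearly towards 0.
assert abs(max(lrs) - 3e-4) < 1e-12
assert lrs[-1] < lrs[50] < lrs[9]
```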
+ + +def get_linear_schedule_with_warmup( + optimizer, num_warmup_steps, num_training_steps, last_epoch=-1 +): + # from https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/optimization.py#L104 + """ + Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after + a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The total number of training steps. + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + + lr_lambda = partial( + _get_linear_schedule_with_warmup_lr_lambda, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + ) + return LambdaLR(optimizer, lr_lambda, last_epoch) + + +def _get_linear_schedule_with_warmup_lr_lambda( + current_step: int, *, num_warmup_steps: int, num_training_steps: int +): + # from https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/optimization.py#L98 + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + return max( + 0.0, + float(num_training_steps - current_step) + / float(max(1, num_training_steps - num_warmup_steps)), + ) + + +def _get_polynomial_decay_schedule_with_warmup_lr_lambda( + current_step: int, + *, + num_warmup_steps: int, + num_training_steps: int, + lr_end: float, + power: float, + lr_init: int, +): + # from https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/optimization.py#L218 + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + elif current_step > num_training_steps: + return lr_end / lr_init # as LambdaLR multiplies by lr_init + else: + lr_range = lr_init - lr_end + decay_steps = num_training_steps - num_warmup_steps + pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps + decay = lr_range * pct_remaining**power + lr_end + return decay / lr_init # as LambdaLR multiplies by lr_init + + +def get_polynomial_decay_schedule_with_warmup( + optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.0, last_epoch=-1 +): + # from https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/optimization.py#L239 + """ + Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the + optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the + initial lr set in the optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The total number of training steps. + lr_end (`float`, *optional*, defaults to 1e-7): + The end LR. + power (`float`, *optional*, defaults to 1.0): + Power factor. + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. 
+ + Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT + implementation at + https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37 + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + + """ + + lr_init = optimizer.defaults["lr"] + if not (lr_init > lr_end): + raise ValueError(f"lr_end ({lr_end}) must be smaller than initial lr ({lr_init})") + + lr_lambda = partial( + _get_polynomial_decay_schedule_with_warmup_lr_lambda, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + lr_end=lr_end, + power=power, + lr_init=lr_init, + ) + return LambdaLR(optimizer, lr_lambda, last_epoch) \ No newline at end of file diff --git a/nablaDFT/utils.py b/nablaDFT/utils.py index fd165d4..a53f96a 100644 --- a/nablaDFT/utils.py +++ b/nablaDFT/utils.py @@ -7,6 +7,7 @@ import logging import warnings +from tqdm import tqdm import pytorch_lightning as pl from pytorch_lightning import LightningModule from pytorch_lightning.utilities import rank_zero_only @@ -26,6 +27,37 @@ logger = logging.getLogger() +def get_file_size(url: str) -> int: + """Returns file size in bytes + + Args: + url (str): url of file to download + """ + req = request.Request(url, method="HEAD") + with request.urlopen(req) as f: + file_size = f.headers.get('Content-Length') + return int(file_size) + + +def tqdm_download_hook(t): + """Wraps a tqdm progress bar instance.""" + last_block = [0] + + def update_to(blocks_count: int, block_size: int, total_size: int): + """Adds progress bar for request.urlretrieve() method + Args: + - blocks_count (int): transferred blocks count. + - block_size (int): size of block in bytes. + - total_size (int): size of requested file. + """ + if total_size not in (None, -1): + t.total = total_size + displayed = t.update((blocks_count - last_block[0]) * block_size) + last_block[0] = blocks_count + return displayed + return update_to + + def seed_everything(seed=42): random.seed(seed) os.environ["PYTHONHASHSEED"] = str(seed) @@ -61,7 +93,11 @@ def init_wandb(): def download_model(config: DictConfig) -> str: - """Downloads best model checkpoint from vault.""" + """Downloads best model checkpoint from vault. + + Args: + config (DictConfig): config for task. see r'config/' for examples. + """ model_name = config.get("name") ckpt_path = os.path.join( hydra.utils.get_original_cwd(), @@ -74,12 +110,20 @@ def download_model(config: DictConfig) -> str: with open(nablaDFT.__path__[0] + "/links/models_checkpoints.json", "r") as f: data = json.load(f) url = data[f"{model_name}"]["dataset_train_100k"] - request.urlretrieve(url, ckpt_path) + file_size = get_file_size(url) + with tqdm(unit="B", unit_scale=True, unit_divisor=1024, miniters=1, total=file_size, desc=f"Downloading {model_name} checkpoint") as t: + request.urlretrieve(url, ckpt_path, reporthook=tqdm_download_hook(t)) logging.info(f"Downloaded {model_name} 100k checkpoint to {ckpt_path}") return ckpt_path def load_model(config: DictConfig, ckpt_path: str) -> LightningModule: + """Instantiates model and loads model weights from checkpoint. + + Args: + config (DictConfig): config for task. see r'config/' for examples. + ckpt_path (str): path to checkpoint.
+ """ model: LightningModule = hydra.utils.instantiate(config.model) if ckpt_path is None: warnings.warn( diff --git a/setup.py b/setup.py index 5089771..49dfa13 100644 --- a/setup.py +++ b/setup.py @@ -1,12 +1,23 @@ +import sys import os import io from setuptools import setup, find_packages + def read(fname): with io.open(os.path.join(os.path.dirname(__file__), fname), encoding="utf-8") as f: return f.read() +def get_python_version(): + version_info = sys.version_info + major = version_info[0] + minor = version_info[1] + return f"cp{major}{minor}" + + +CUDA = "cu121" +PYTHON_VERSION = get_python_version() setup( @@ -16,29 +27,33 @@ def read(fname): url="https://github.com/AIRI-Institute/nablaDFT", packages=find_packages(), include_package_data=True, - python_requires=">=3.6", + python_requires=">=3.8", install_requires=[ - "numpy", - "sympy", - "ase>=3.21", - "h5py", - "apsw", - "schnetpack>=2.0.0", + "numpy>=1.26", + "sympy==1.12", + "ase==3.22.1", + "h5py==3.10.0", + "apsw==3.45.1.0", + "schnetpack==2.0.4", "tensorboardX", "pyyaml", - "hydra-core>=1.1.0", - "pytorch_lightning>=1.9.0", - "torch-geometric>=2.3.1", - "torchmetrics", + "hydra-core==1.2.0", + "torch==2.2.0", + f"torch-scatter @ https://data.pyg.org/whl/torch-2.2.0%2Bcu121/torch_scatter-2.1.2%2Bpt22cu121-{PYTHON_VERSION}-{PYTHON_VERSION}-linux_x86_64.whl", + f"torch-sparse @ https://data.pyg.org/whl/torch-2.2.0%2Bcu121/torch_sparse-0.6.18%2Bpt22cu121-{PYTHON_VERSION}-{PYTHON_VERSION}-linux_x86_64.whl", + f"torch-cluster @ https://data.pyg.org/whl/torch-2.2.0%2Bcu121/torch_cluster-1.6.3%2Bpt22cu121-{PYTHON_VERSION}-{PYTHON_VERSION}-linux_x86_64.whl", + "pytorch_lightning==2.1.4", + "torch-geometric==2.4.0", + "torchmetrics==1.0.1", "hydra-colorlog>=1.1.0", "rich", "fasteners", "dirsync", - "torch-ema", + "torch-ema==0.3", "matscipy", "python-dotenv", - "wandb", - "e3nn" + "wandb==0.16.3", + "e3nn==0.5.1" ], license="MIT", description="nablaDFT: Large-Scale Conformational Energy and Hamiltonian Prediction benchmark and dataset",